diff options
author | Justin Le <justin@jle.im> | 2015-12-22 22:15:53 -0800 |
---|---|---|
committer | Justin Le <justin@jle.im> | 2015-12-22 22:15:53 -0800 |
commit | e97a22d5bbdde5b96d9401a7abb25534a2d45bd1 (patch) | |
tree | 92398355b123ff7acf1789b6b06efd33892280b9 /packages/base/src | |
parent | 35a7f3355611cd20994f36b43acbd7413e09f558 (diff) |
NFData instances for various Static types, and mapping and outer product methods to Domain
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Internal/Static.hs | 17 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 27 |
2 files changed, 44 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs index 0068313..419ff07 100644 --- a/packages/base/src/Internal/Static.hs +++ b/packages/base/src/Internal/Static.hs | |||
@@ -11,6 +11,7 @@ | |||
11 | {-# LANGUAGE FlexibleInstances #-} | 11 | {-# LANGUAGE FlexibleInstances #-} |
12 | {-# LANGUAGE TypeOperators #-} | 12 | {-# LANGUAGE TypeOperators #-} |
13 | {-# LANGUAGE ViewPatterns #-} | 13 | {-# LANGUAGE ViewPatterns #-} |
14 | {-# LANGUAGE BangPatterns #-} | ||
14 | 15 | ||
15 | {- | | 16 | {- | |
16 | Module : Internal.Static | 17 | Module : Internal.Static |
@@ -28,6 +29,7 @@ import qualified Numeric.LinearAlgebra as LA | |||
28 | import Numeric.LinearAlgebra hiding (konst,size,R,C) | 29 | import Numeric.LinearAlgebra hiding (konst,size,R,C) |
29 | import Internal.Vector as D hiding (R,C) | 30 | import Internal.Vector as D hiding (R,C) |
30 | import Internal.ST | 31 | import Internal.ST |
32 | import Control.DeepSeq | ||
31 | import Data.Proxy(Proxy) | 33 | import Data.Proxy(Proxy) |
32 | import Foreign.Storable(Storable) | 34 | import Foreign.Storable(Storable) |
33 | import Text.Printf | 35 | import Text.Printf |
@@ -50,6 +52,9 @@ lift2F | |||
50 | -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) | 52 | -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) |
51 | lift2F f (Dim u) (Dim v) = Dim (f u v) | 53 | lift2F f (Dim u) (Dim v) = Dim (f u v) |
52 | 54 | ||
55 | instance NFData t => NFData (Dim n t) where | ||
56 | rnf (Dim (force -> !_)) = () | ||
57 | |||
53 | -------------------------------------------------------------------------------- | 58 | -------------------------------------------------------------------------------- |
54 | 59 | ||
55 | newtype R n = R (Dim n (Vector ℝ)) | 60 | newtype R n = R (Dim n (Vector ℝ)) |
@@ -75,6 +80,18 @@ mkL x = L (Dim (Dim x)) | |||
75 | mkM :: Matrix ℂ -> M m n | 80 | mkM :: Matrix ℂ -> M m n |
76 | mkM x = M (Dim (Dim x)) | 81 | mkM x = M (Dim (Dim x)) |
77 | 82 | ||
83 | instance NFData (R n) where | ||
84 | rnf (R (force -> !_)) = () | ||
85 | |||
86 | instance NFData (C n) where | ||
87 | rnf (C (force -> !_)) = () | ||
88 | |||
89 | instance NFData (L n m) where | ||
90 | rnf (L (force -> !_)) = () | ||
91 | |||
92 | instance NFData (M n m) where | ||
93 | rnf (M (force -> !_)) = () | ||
94 | |||
78 | -------------------------------------------------------------------------------- | 95 | -------------------------------------------------------------------------------- |
79 | 96 | ||
80 | type V n t = Dim n (Vector t) | 97 | type V n t = Dim n (Vector t) |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 843c727..4de4d7a 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -416,6 +416,9 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve | |||
416 | dot :: forall n . (KnownNat n) => vec n -> vec n -> field | 416 | dot :: forall n . (KnownNat n) => vec n -> vec n -> field |
417 | cross :: vec 3 -> vec 3 -> vec 3 | 417 | cross :: vec 3 -> vec 3 -> vec 3 |
418 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n | 418 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n |
419 | dvmap :: forall n. (field -> field) -> vec n -> vec n | ||
420 | dmmap :: forall n m. (field -> field) -> mat n m -> mat n m | ||
421 | outer :: forall n m. vec n -> vec m -> mat n m | ||
419 | 422 | ||
420 | 423 | ||
421 | instance Domain ℝ R L | 424 | instance Domain ℝ R L |
@@ -425,6 +428,9 @@ instance Domain ℝ R L | |||
425 | dot = dotR | 428 | dot = dotR |
426 | cross = crossR | 429 | cross = crossR |
427 | diagR = diagRectR | 430 | diagR = diagRectR |
431 | dvmap = mapR | ||
432 | dmmap = mapL | ||
433 | outer = outerR | ||
428 | 434 | ||
429 | instance Domain ℂ C M | 435 | instance Domain ℂ C M |
430 | where | 436 | where |
@@ -433,6 +439,9 @@ instance Domain ℂ C M | |||
433 | dot = dotC | 439 | dot = dotC |
434 | cross = crossC | 440 | cross = crossC |
435 | diagR = diagRectC | 441 | diagR = diagRectC |
442 | dvmap = mapC | ||
443 | dmmap = mapM' | ||
444 | outer = outerC | ||
436 | 445 | ||
437 | -------------------------------------------------------------------------------- | 446 | -------------------------------------------------------------------------------- |
438 | 447 | ||
@@ -472,6 +481,15 @@ crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3 | |||
472 | z2 = x!2*y!0-x!0*y!2 | 481 | z2 = x!2*y!0-x!0*y!2 |
473 | z3 = x!0*y!1-x!1*y!0 | 482 | z3 = x!0*y!1-x!1*y!0 |
474 | 483 | ||
484 | outerR :: R n -> R m -> L n m | ||
485 | outerR (R (Dim x)) (R (Dim y)) = mkL (LA.outer x y) | ||
486 | |||
487 | mapR :: (ℝ -> ℝ) -> R n -> R n | ||
488 | mapR f (R (Dim v)) = R (Dim (LA.cmap f v)) | ||
489 | |||
490 | mapM' :: (ℂ -> ℂ) -> M n m -> M n m | ||
491 | mapM' f (M (Dim (Dim m))) = M (Dim (Dim (LA.cmap f m))) | ||
492 | |||
475 | -------------------------------------------------------------------------------- | 493 | -------------------------------------------------------------------------------- |
476 | 494 | ||
477 | mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n | 495 | mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n |
@@ -510,6 +528,15 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) | |||
510 | z2 = x!2*y!0-x!0*y!2 | 528 | z2 = x!2*y!0-x!0*y!2 |
511 | z3 = x!0*y!1-x!1*y!0 | 529 | z3 = x!0*y!1-x!1*y!0 |
512 | 530 | ||
531 | outerC :: C n -> C m -> M n m | ||
532 | outerC (C (Dim x)) (C (Dim y)) = mkM (LA.outer x y) | ||
533 | |||
534 | mapC :: (ℂ -> ℂ) -> C n -> C n | ||
535 | mapC f (C (Dim v)) = C (Dim (LA.cmap f v)) | ||
536 | |||
537 | mapL :: (ℝ -> ℝ) -> L n m -> L n m | ||
538 | mapL f (L (Dim (Dim m))) = L (Dim (Dim (LA.cmap f m))) | ||
539 | |||
513 | -------------------------------------------------------------------------------- | 540 | -------------------------------------------------------------------------------- |
514 | 541 | ||
515 | diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n | 542 | diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n |