summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/base/src/Internal/Static.hs17
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs27
2 files changed, 44 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs
index 0068313..419ff07 100644
--- a/packages/base/src/Internal/Static.hs
+++ b/packages/base/src/Internal/Static.hs
@@ -11,6 +11,7 @@
11{-# LANGUAGE FlexibleInstances #-} 11{-# LANGUAGE FlexibleInstances #-}
12{-# LANGUAGE TypeOperators #-} 12{-# LANGUAGE TypeOperators #-}
13{-# LANGUAGE ViewPatterns #-} 13{-# LANGUAGE ViewPatterns #-}
14{-# LANGUAGE BangPatterns #-}
14 15
15{- | 16{- |
16Module : Internal.Static 17Module : Internal.Static
@@ -28,6 +29,7 @@ import qualified Numeric.LinearAlgebra as LA
28import Numeric.LinearAlgebra hiding (konst,size,R,C) 29import Numeric.LinearAlgebra hiding (konst,size,R,C)
29import Internal.Vector as D hiding (R,C) 30import Internal.Vector as D hiding (R,C)
30import Internal.ST 31import Internal.ST
32import Control.DeepSeq
31import Data.Proxy(Proxy) 33import Data.Proxy(Proxy)
32import Foreign.Storable(Storable) 34import Foreign.Storable(Storable)
33import Text.Printf 35import Text.Printf
@@ -50,6 +52,9 @@ lift2F
50 -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) 52 -> Dim n (c t) -> Dim n (c t) -> Dim n (c t)
51lift2F f (Dim u) (Dim v) = Dim (f u v) 53lift2F f (Dim u) (Dim v) = Dim (f u v)
52 54
55instance NFData t => NFData (Dim n t) where
56 rnf (Dim (force -> !_)) = ()
57
53-------------------------------------------------------------------------------- 58--------------------------------------------------------------------------------
54 59
55newtype R n = R (Dim n (Vector ℝ)) 60newtype R n = R (Dim n (Vector ℝ))
@@ -75,6 +80,18 @@ mkL x = L (Dim (Dim x))
75mkM :: Matrix ℂ -> M m n 80mkM :: Matrix ℂ -> M m n
76mkM x = M (Dim (Dim x)) 81mkM x = M (Dim (Dim x))
77 82
83instance NFData (R n) where
84 rnf (R (force -> !_)) = ()
85
86instance NFData (C n) where
87 rnf (C (force -> !_)) = ()
88
89instance NFData (L n m) where
90 rnf (L (force -> !_)) = ()
91
92instance NFData (M n m) where
93 rnf (M (force -> !_)) = ()
94
78-------------------------------------------------------------------------------- 95--------------------------------------------------------------------------------
79 96
80type V n t = Dim n (Vector t) 97type V n t = Dim n (Vector t)
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index 843c727..4de4d7a 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -416,6 +416,9 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve
416 dot :: forall n . (KnownNat n) => vec n -> vec n -> field 416 dot :: forall n . (KnownNat n) => vec n -> vec n -> field
417 cross :: vec 3 -> vec 3 -> vec 3 417 cross :: vec 3 -> vec 3 -> vec 3
418 diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n 418 diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n
419 dvmap :: forall n. (field -> field) -> vec n -> vec n
420 dmmap :: forall n m. (field -> field) -> mat n m -> mat n m
421 outer :: forall n m. vec n -> vec m -> mat n m
419 422
420 423
421instance Domain ℝ R L 424instance Domain ℝ R L
@@ -425,6 +428,9 @@ instance Domain ℝ R L
425 dot = dotR 428 dot = dotR
426 cross = crossR 429 cross = crossR
427 diagR = diagRectR 430 diagR = diagRectR
431 dvmap = mapR
432 dmmap = mapL
433 outer = outerR
428 434
429instance Domain ℂ C M 435instance Domain ℂ C M
430 where 436 where
@@ -433,6 +439,9 @@ instance Domain ℂ C M
433 dot = dotC 439 dot = dotC
434 cross = crossC 440 cross = crossC
435 diagR = diagRectC 441 diagR = diagRectC
442 dvmap = mapC
443 dmmap = mapM'
444 outer = outerC
436 445
437-------------------------------------------------------------------------------- 446--------------------------------------------------------------------------------
438 447
@@ -472,6 +481,15 @@ crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3
472 z2 = x!2*y!0-x!0*y!2 481 z2 = x!2*y!0-x!0*y!2
473 z3 = x!0*y!1-x!1*y!0 482 z3 = x!0*y!1-x!1*y!0
474 483
484outerR :: R n -> R m -> L n m
485outerR (R (Dim x)) (R (Dim y)) = mkL (LA.outer x y)
486
487mapR :: (ℝ -> ℝ) -> R n -> R n
488mapR f (R (Dim v)) = R (Dim (LA.cmap f v))
489
490mapM' :: (ℂ -> ℂ) -> M n m -> M n m
491mapM' f (M (Dim (Dim m))) = M (Dim (Dim (LA.cmap f m)))
492
475-------------------------------------------------------------------------------- 493--------------------------------------------------------------------------------
476 494
477mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n 495mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n
@@ -510,6 +528,15 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3])
510 z2 = x!2*y!0-x!0*y!2 528 z2 = x!2*y!0-x!0*y!2
511 z3 = x!0*y!1-x!1*y!0 529 z3 = x!0*y!1-x!1*y!0
512 530
531outerC :: C n -> C m -> M n m
532outerC (C (Dim x)) (C (Dim y)) = mkM (LA.outer x y)
533
534mapC :: (ℂ -> ℂ) -> C n -> C n
535mapC f (C (Dim v)) = C (Dim (LA.cmap f v))
536
537mapL :: (ℝ -> ℝ) -> L n m -> L n m
538mapL f (L (Dim (Dim m))) = L (Dim (Dim (LA.cmap f m)))
539
513-------------------------------------------------------------------------------- 540--------------------------------------------------------------------------------
514 541
515diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n 542diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n