diff options
author | Justin Le <justin@jle.im> | 2016-01-07 00:44:23 -0800 |
---|---|---|
committer | Justin Le <justin@jle.im> | 2016-01-07 00:44:23 -0800 |
commit | 69e0f0c19251cd2abe08cc5f0711f6ac7e42a2ce (patch) | |
tree | 88bb0ec6e918c466a79bb840329453cca5a59616 /packages/base/src/Numeric/LinearAlgebra | |
parent | cc18707dc26ae27b338bc7ff033921b7e4946294 (diff) |
added zipWith to Domain typeclass, and fixed scalar expansion bugs for mapX functions from neglect of extract
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 31 |
1 files changed, 20 insertions, 11 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index bd0e593..b551cd9 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -69,6 +69,7 @@ import Numeric.LinearAlgebra hiding ( | |||
69 | eigenvalues,eigenvaluesSH,build, | 69 | eigenvalues,eigenvaluesSH,build, |
70 | qr,size,dot,chol,range,R,C,sym,mTm,unSym) | 70 | qr,size,dot,chol,range,R,C,sym,mTm,unSym) |
71 | import qualified Numeric.LinearAlgebra as LA | 71 | import qualified Numeric.LinearAlgebra as LA |
72 | import qualified Numeric.LinearAlgebra.Devel as LA | ||
72 | import Data.Proxy(Proxy) | 73 | import Data.Proxy(Proxy) |
73 | import Internal.Static | 74 | import Internal.Static |
74 | import Control.Arrow((***)) | 75 | import Control.Arrow((***)) |
@@ -394,7 +395,6 @@ withVector v f = | |||
394 | Nothing -> error "static/dynamic mismatch" | 395 | Nothing -> error "static/dynamic mismatch" |
395 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) | 396 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) |
396 | 397 | ||
397 | |||
398 | withMatrix | 398 | withMatrix |
399 | :: forall z | 399 | :: forall z |
400 | . Matrix ℝ | 400 | . Matrix ℝ |
@@ -418,9 +418,10 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve | |||
418 | dot :: forall n . (KnownNat n) => vec n -> vec n -> field | 418 | dot :: forall n . (KnownNat n) => vec n -> vec n -> field |
419 | cross :: vec 3 -> vec 3 -> vec 3 | 419 | cross :: vec 3 -> vec 3 -> vec 3 |
420 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n | 420 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n |
421 | dvmap :: forall n. (field -> field) -> vec n -> vec n | 421 | dvmap :: forall n. KnownNat n => (field -> field) -> vec n -> vec n |
422 | dmmap :: forall n m. (field -> field) -> mat n m -> mat n m | 422 | dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m |
423 | outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m | 423 | outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m |
424 | zipWith :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n | ||
424 | 425 | ||
425 | 426 | ||
426 | instance Domain ℝ R L | 427 | instance Domain ℝ R L |
@@ -433,6 +434,7 @@ instance Domain ℝ R L | |||
433 | dvmap = mapR | 434 | dvmap = mapR |
434 | dmmap = mapL | 435 | dmmap = mapL |
435 | outer = outerR | 436 | outer = outerR |
437 | zipWith = zipWithR | ||
436 | 438 | ||
437 | instance Domain ℂ C M | 439 | instance Domain ℂ C M |
438 | where | 440 | where |
@@ -444,6 +446,7 @@ instance Domain ℂ C M | |||
444 | dvmap = mapC | 446 | dvmap = mapC |
445 | dmmap = mapM' | 447 | dmmap = mapM' |
446 | outer = outerC | 448 | outer = outerC |
449 | zipWith = zipWithC | ||
447 | 450 | ||
448 | -------------------------------------------------------------------------------- | 451 | -------------------------------------------------------------------------------- |
449 | 452 | ||
@@ -486,11 +489,14 @@ crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3 | |||
486 | outerR :: (KnownNat m, KnownNat n) => R n -> R m -> L n m | 489 | outerR :: (KnownNat m, KnownNat n) => R n -> R m -> L n m |
487 | outerR (extract -> x) (extract -> y) = mkL (LA.outer x y) | 490 | outerR (extract -> x) (extract -> y) = mkL (LA.outer x y) |
488 | 491 | ||
489 | mapR :: (ℝ -> ℝ) -> R n -> R n | 492 | mapR :: KnownNat n => (ℝ -> ℝ) -> R n -> R n |
490 | mapR f (R (Dim v)) = R (Dim (LA.cmap f v)) | 493 | mapR f (extract -> v) = mkR (LA.cmap f v) |
494 | |||
495 | zipWithR :: KnownNat n => (ℝ -> ℝ -> ℝ) -> R n -> R n -> R n | ||
496 | zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y) | ||
491 | 497 | ||
492 | mapM' :: (ℂ -> ℂ) -> M n m -> M n m | 498 | mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m |
493 | mapM' f (M (Dim (Dim m))) = M (Dim (Dim (LA.cmap f m))) | 499 | mapM' f (extract -> m) = mkM (LA.cmap f m) |
494 | 500 | ||
495 | -------------------------------------------------------------------------------- | 501 | -------------------------------------------------------------------------------- |
496 | 502 | ||
@@ -533,11 +539,14 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) | |||
533 | outerC :: (KnownNat m, KnownNat n) => C n -> C m -> M n m | 539 | outerC :: (KnownNat m, KnownNat n) => C n -> C m -> M n m |
534 | outerC (extract -> x) (extract -> y) = mkM (LA.outer x y) | 540 | outerC (extract -> x) (extract -> y) = mkM (LA.outer x y) |
535 | 541 | ||
536 | mapC :: (ℂ -> ℂ) -> C n -> C n | 542 | mapC :: KnownNat n => (ℂ -> ℂ) -> C n -> C n |
537 | mapC f (C (Dim v)) = C (Dim (LA.cmap f v)) | 543 | mapC f (extract -> v) = mkC (LA.cmap f v) |
544 | |||
545 | zipWithC :: KnownNat n => (ℂ -> ℂ -> ℂ) -> C n -> C n -> C n | ||
546 | zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y) | ||
538 | 547 | ||
539 | mapL :: (ℝ -> ℝ) -> L n m -> L n m | 548 | mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m |
540 | mapL f (L (Dim (Dim m))) = L (Dim (Dim (LA.cmap f m))) | 549 | mapL f (extract -> m) = mkL (LA.cmap f m) |
541 | 550 | ||
542 | -------------------------------------------------------------------------------- | 551 | -------------------------------------------------------------------------------- |
543 | 552 | ||