summaryrefslogtreecommitdiff
path: root/packages
diff options
context:
space:
mode:
authorJustin Le <justin@jle.im>2016-01-07 00:44:23 -0800
committerJustin Le <justin@jle.im>2016-01-07 00:44:23 -0800
commit69e0f0c19251cd2abe08cc5f0711f6ac7e42a2ce (patch)
tree88bb0ec6e918c466a79bb840329453cca5a59616 /packages
parentcc18707dc26ae27b338bc7ff033921b7e4946294 (diff)
added zipWith to Domain typeclass, and fixed scalar expansion bugs for mapX functions from neglect of extract
Diffstat (limited to 'packages')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs31
1 files changed, 20 insertions, 11 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index bd0e593..b551cd9 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -69,6 +69,7 @@ import Numeric.LinearAlgebra hiding (
69 eigenvalues,eigenvaluesSH,build, 69 eigenvalues,eigenvaluesSH,build,
70 qr,size,dot,chol,range,R,C,sym,mTm,unSym) 70 qr,size,dot,chol,range,R,C,sym,mTm,unSym)
71import qualified Numeric.LinearAlgebra as LA 71import qualified Numeric.LinearAlgebra as LA
72import qualified Numeric.LinearAlgebra.Devel as LA
72import Data.Proxy(Proxy) 73import Data.Proxy(Proxy)
73import Internal.Static 74import Internal.Static
74import Control.Arrow((***)) 75import Control.Arrow((***))
@@ -394,7 +395,6 @@ withVector v f =
394 Nothing -> error "static/dynamic mismatch" 395 Nothing -> error "static/dynamic mismatch"
395 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) 396 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m)
396 397
397
398withMatrix 398withMatrix
399 :: forall z 399 :: forall z
400 . Matrix ℝ 400 . Matrix ℝ
@@ -418,9 +418,10 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve
418 dot :: forall n . (KnownNat n) => vec n -> vec n -> field 418 dot :: forall n . (KnownNat n) => vec n -> vec n -> field
419 cross :: vec 3 -> vec 3 -> vec 3 419 cross :: vec 3 -> vec 3 -> vec 3
420 diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n 420 diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n
421 dvmap :: forall n. (field -> field) -> vec n -> vec n 421 dvmap :: forall n. KnownNat n => (field -> field) -> vec n -> vec n
422 dmmap :: forall n m. (field -> field) -> mat n m -> mat n m 422 dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m
423 outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m 423 outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m
424 zipWith :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n
424 425
425 426
426instance Domain ℝ R L 427instance Domain ℝ R L
@@ -433,6 +434,7 @@ instance Domain ℝ R L
433 dvmap = mapR 434 dvmap = mapR
434 dmmap = mapL 435 dmmap = mapL
435 outer = outerR 436 outer = outerR
437 zipWith = zipWithR
436 438
437instance Domain ℂ C M 439instance Domain ℂ C M
438 where 440 where
@@ -444,6 +446,7 @@ instance Domain ℂ C M
444 dvmap = mapC 446 dvmap = mapC
445 dmmap = mapM' 447 dmmap = mapM'
446 outer = outerC 448 outer = outerC
449 zipWith = zipWithC
447 450
448-------------------------------------------------------------------------------- 451--------------------------------------------------------------------------------
449 452
@@ -486,11 +489,14 @@ crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3
486outerR :: (KnownNat m, KnownNat n) => R n -> R m -> L n m 489outerR :: (KnownNat m, KnownNat n) => R n -> R m -> L n m
487outerR (extract -> x) (extract -> y) = mkL (LA.outer x y) 490outerR (extract -> x) (extract -> y) = mkL (LA.outer x y)
488 491
489mapR :: (ℝ -> ℝ) -> R n -> R n 492mapR :: KnownNat n => (ℝ -> ℝ) -> R n -> R n
490mapR f (R (Dim v)) = R (Dim (LA.cmap f v)) 493mapR f (extract -> v) = mkR (LA.cmap f v)
494
495zipWithR :: KnownNat n => (ℝ -> ℝ -> ℝ) -> R n -> R n -> R n
496zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y)
491 497
492mapM' :: (ℂ -> ℂ) -> M n m -> M n m 498mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m
493mapM' f (M (Dim (Dim m))) = M (Dim (Dim (LA.cmap f m))) 499mapM' f (extract -> m) = mkM (LA.cmap f m)
494 500
495-------------------------------------------------------------------------------- 501--------------------------------------------------------------------------------
496 502
@@ -533,11 +539,14 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3])
533outerC :: (KnownNat m, KnownNat n) => C n -> C m -> M n m 539outerC :: (KnownNat m, KnownNat n) => C n -> C m -> M n m
534outerC (extract -> x) (extract -> y) = mkM (LA.outer x y) 540outerC (extract -> x) (extract -> y) = mkM (LA.outer x y)
535 541
536mapC :: (ℂ -> ℂ) -> C n -> C n 542mapC :: KnownNat n => (ℂ -> ℂ) -> C n -> C n
537mapC f (C (Dim v)) = C (Dim (LA.cmap f v)) 543mapC f (extract -> v) = mkC (LA.cmap f v)
544
545zipWithC :: KnownNat n => (ℂ -> ℂ -> ℂ) -> C n -> C n -> C n
546zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y)
538 547
539mapL :: (ℝ -> ℝ) -> L n m -> L n m 548mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m
540mapL f (L (Dim (Dim m))) = L (Dim (Dim (LA.cmap f m))) 549mapL f (extract -> m) = mkL (LA.cmap f m)
541 550
542-------------------------------------------------------------------------------- 551--------------------------------------------------------------------------------
543 552