From 69e0f0c19251cd2abe08cc5f0711f6ac7e42a2ce Mon Sep 17 00:00:00 2001 From: Justin Le Date: Thu, 7 Jan 2016 00:44:23 -0800 Subject: added zipWith to Domain typeclass, and fixed scalar expansion bugs for mapX functions from neglect of extract --- packages/base/src/Numeric/LinearAlgebra/Static.hs | 31 +++++++++++++++-------- 1 file changed, 20 insertions(+), 11 deletions(-) (limited to 'packages/base/src') diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index bd0e593..b551cd9 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs @@ -69,6 +69,7 @@ import Numeric.LinearAlgebra hiding ( eigenvalues,eigenvaluesSH,build, qr,size,dot,chol,range,R,C,sym,mTm,unSym) import qualified Numeric.LinearAlgebra as LA +import qualified Numeric.LinearAlgebra.Devel as LA import Data.Proxy(Proxy) import Internal.Static import Control.Arrow((***)) @@ -394,7 +395,6 @@ withVector v f = Nothing -> error "static/dynamic mismatch" Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) - withMatrix :: forall z . Matrix ℝ @@ -418,9 +418,10 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve dot :: forall n . (KnownNat n) => vec n -> vec n -> field cross :: vec 3 -> vec 3 -> vec 3 diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n - dvmap :: forall n. (field -> field) -> vec n -> vec n - dmmap :: forall n m. (field -> field) -> mat n m -> mat n m + dvmap :: forall n. KnownNat n => (field -> field) -> vec n -> vec n + dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m + zipWith :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n instance Domain ℝ R L @@ -433,6 +434,7 @@ instance Domain ℝ R L dvmap = mapR dmmap = mapL outer = outerR + zipWith = zipWithR instance Domain ℂ C M where @@ -444,6 +446,7 @@ instance Domain ℂ C M dvmap = mapC dmmap = mapM' outer = outerC + zipWith = zipWithC -------------------------------------------------------------------------------- @@ -486,11 +489,14 @@ crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3 outerR :: (KnownNat m, KnownNat n) => R n -> R m -> L n m outerR (extract -> x) (extract -> y) = mkL (LA.outer x y) -mapR :: (ℝ -> ℝ) -> R n -> R n -mapR f (R (Dim v)) = R (Dim (LA.cmap f v)) +mapR :: KnownNat n => (ℝ -> ℝ) -> R n -> R n +mapR f (extract -> v) = mkR (LA.cmap f v) + +zipWithR :: KnownNat n => (ℝ -> ℝ -> ℝ) -> R n -> R n -> R n +zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y) -mapM' :: (ℂ -> ℂ) -> M n m -> M n m -mapM' f (M (Dim (Dim m))) = M (Dim (Dim (LA.cmap f m))) +mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m +mapM' f (extract -> m) = mkM (LA.cmap f m) -------------------------------------------------------------------------------- @@ -533,11 +539,14 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) outerC :: (KnownNat m, KnownNat n) => C n -> C m -> M n m outerC (extract -> x) (extract -> y) = mkM (LA.outer x y) -mapC :: (ℂ -> ℂ) -> C n -> C n -mapC f (C (Dim v)) = C (Dim (LA.cmap f v)) +mapC :: KnownNat n => (ℂ -> ℂ) -> C n -> C n +mapC f (extract -> v) = mkC (LA.cmap f v) + +zipWithC :: KnownNat n => (ℂ -> ℂ -> ℂ) -> C n -> C n -> C n +zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y) -mapL :: (ℝ -> ℝ) -> L n m -> L n m -mapL f (L (Dim (Dim m))) = L (Dim (Dim (LA.cmap f m))) +mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m +mapL f (extract -> m) = mkL (LA.cmap f m) -------------------------------------------------------------------------------- -- cgit v1.2.3