From e97a22d5bbdde5b96d9401a7abb25534a2d45bd1 Mon Sep 17 00:00:00 2001 From: Justin Le Date: Tue, 22 Dec 2015 22:15:53 -0800 Subject: NFData instances for various Static types, and mapping and outer product methods to Domain --- packages/base/src/Internal/Static.hs | 17 ++++++++++++++ packages/base/src/Numeric/LinearAlgebra/Static.hs | 27 +++++++++++++++++++++++ 2 files changed, 44 insertions(+) (limited to 'packages/base/src') diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs index 0068313..419ff07 100644 --- a/packages/base/src/Internal/Static.hs +++ b/packages/base/src/Internal/Static.hs @@ -11,6 +11,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE BangPatterns #-} {- | Module : Internal.Static @@ -28,6 +29,7 @@ import qualified Numeric.LinearAlgebra as LA import Numeric.LinearAlgebra hiding (konst,size,R,C) import Internal.Vector as D hiding (R,C) import Internal.ST +import Control.DeepSeq import Data.Proxy(Proxy) import Foreign.Storable(Storable) import Text.Printf @@ -50,6 +52,9 @@ lift2F -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) lift2F f (Dim u) (Dim v) = Dim (f u v) +instance NFData t => NFData (Dim n t) where + rnf (Dim (force -> !_)) = () + -------------------------------------------------------------------------------- newtype R n = R (Dim n (Vector ℝ)) @@ -75,6 +80,18 @@ mkL x = L (Dim (Dim x)) mkM :: Matrix ℂ -> M m n mkM x = M (Dim (Dim x)) +instance NFData (R n) where + rnf (R (force -> !_)) = () + +instance NFData (C n) where + rnf (C (force -> !_)) = () + +instance NFData (L n m) where + rnf (L (force -> !_)) = () + +instance NFData (M n m) where + rnf (M (force -> !_)) = () + -------------------------------------------------------------------------------- type V n t = Dim n (Vector t) diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 843c727..4de4d7a 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs @@ -416,6 +416,9 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve dot :: forall n . (KnownNat n) => vec n -> vec n -> field cross :: vec 3 -> vec 3 -> vec 3 diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n + dvmap :: forall n. (field -> field) -> vec n -> vec n + dmmap :: forall n m. (field -> field) -> mat n m -> mat n m + outer :: forall n m. vec n -> vec m -> mat n m instance Domain ℝ R L @@ -425,6 +428,9 @@ instance Domain ℝ R L dot = dotR cross = crossR diagR = diagRectR + dvmap = mapR + dmmap = mapL + outer = outerR instance Domain ℂ C M where @@ -433,6 +439,9 @@ instance Domain ℂ C M dot = dotC cross = crossC diagR = diagRectC + dvmap = mapC + dmmap = mapM' + outer = outerC -------------------------------------------------------------------------------- @@ -472,6 +481,15 @@ crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3 z2 = x!2*y!0-x!0*y!2 z3 = x!0*y!1-x!1*y!0 +outerR :: R n -> R m -> L n m +outerR (R (Dim x)) (R (Dim y)) = mkL (LA.outer x y) + +mapR :: (ℝ -> ℝ) -> R n -> R n +mapR f (R (Dim v)) = R (Dim (LA.cmap f v)) + +mapM' :: (ℂ -> ℂ) -> M n m -> M n m +mapM' f (M (Dim (Dim m))) = M (Dim (Dim (LA.cmap f m))) + -------------------------------------------------------------------------------- mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n @@ -510,6 +528,15 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) z2 = x!2*y!0-x!0*y!2 z3 = x!0*y!1-x!1*y!0 +outerC :: C n -> C m -> M n m +outerC (C (Dim x)) (C (Dim y)) = mkM (LA.outer x y) + +mapC :: (ℂ -> ℂ) -> C n -> C n +mapC f (C (Dim v)) = C (Dim (LA.cmap f v)) + +mapL :: (ℝ -> ℝ) -> L n m -> L n m +mapL f (L (Dim (Dim m))) = L (Dim (Dim (LA.cmap f m))) + -------------------------------------------------------------------------------- diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n -- cgit v1.2.3