From 05e40db4fdc85b73f38ae5e105db0d523176debe Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 10 Jun 2014 16:10:14 +0200 Subject: Domain class --- packages/base/src/Numeric/HMatrix.hs | 235 +++++++++++++++++++++++++---------- 1 file changed, 167 insertions(+), 68 deletions(-) (limited to 'packages/base/src/Numeric/HMatrix.hs') diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs index 421333a..34f4346 100644 --- a/packages/base/src/Numeric/HMatrix.hs +++ b/packages/base/src/Numeric/HMatrix.hs @@ -21,7 +21,7 @@ Copyright : (c) Alberto Ruiz 2006-14 License : BSD3 Stability : experimental -Experimental interface for real arrays with statically checked dimensions. +Experimental interface with statically checked dimensions. -} @@ -37,9 +37,11 @@ module Numeric.HMatrix( unrow, uncol, eye, - diagR, diag, + diag, blockAt, matrix, + -- * Complex + C, M, Her, her, ๐‘–, -- * Products (<>),(#>),(<ยท>), -- * Linear Systems @@ -48,11 +50,11 @@ module Numeric.HMatrix( svd, svdTall, svdFlat, Eigen(..), withNullspace, -- * Misc - Disp(..), + mean, + Disp(..), Domain(..), withVector, withMatrix, toRows, toColumns, - Sized(..), Diag(..), Sym, sym, - module Numeric.LinearAlgebra.HMatrix + Sized(..), Diag(..), Sym, sym ) where @@ -124,17 +126,8 @@ ud2 :: L m n -> Matrix โ„ ud2 (L (Dim (Dim x))) = x --------------------------------------------------------------------------------- -------------------------------------------------------------------------------- -diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ„ -> R k -> L m n -diagR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) - where - ev = extract v - zeros = LA.konst x (max 0 ((min m' n') - size ev)) - m' = fromIntegral . natVal $ (undefined :: Proxy m) - n' = fromIntegral . natVal $ (undefined :: Proxy n) - diag :: KnownNat n => R n -> Sq n diag = diagR 0 @@ -201,65 +194,37 @@ isKonst (unwrap -> x) n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int +isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (โ„‚,(Int,Int)) +isKonstC (unwrap -> x) + | singleM x = Just (x `atIndex` (0,0), (m',n')) + | otherwise = Nothing + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + infixr 8 <> (<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n +(<>) = mulR -(isKonst -> Just (a,(_,k))) <> (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k) - -(isDiag -> Just (0,a,_)) <> (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) - where - v = a' * b' - n = min (size a) (size b) - a' = subVector 0 n a - b' = subVector 0 n b - -(isDiag -> Just (0,a,_)) <> (extract -> b) = mkL (asColumn a * takeRows (size a) b) - -(extract -> a) <> (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) - -a <> b = mkL (extract a LA.<> extract b) infixr 8 #> (#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m -(isDiag -> Just (0, w, _)) #> v = mkR (w * subVector 0 (size w) (extract v)) -m #> v = mkR (extract m LA.#> extract v) +(#>) = appR infixr 8 <ยท> (<ยท>) :: R n -> R n -> โ„ -(ud1 -> u) <ยท> (ud1 -> v) - | singleV u || singleV v = sumElements (u * v) - | otherwise = udot u v +(<ยท>) = dotR -------------------------------------------------------------------------------- -{- -class Minim (n :: Nat) (m :: Nat) - where - type Mini n m :: Nat - -instance forall (n :: Nat) . Minim n n - where - type Mini n n = n - - -instance forall (n :: Nat) (m :: Nat) . (n <= m+1) => Minim n m - where - type Mini n m = n - -instance forall (n :: Nat) (m :: Nat) . (m <= n+1) => Minim n m - where - type Mini n m = m --} - class Diag m d | m -> d where takeDiag :: m -> d - instance forall n . (KnownNat n) => Diag (L n n) (R n) where takeDiag m = mkR (LA.takeDiag (extract m)) @@ -316,6 +281,15 @@ sym :: KnownNat n => Sq n -> Sym n sym m = Sym $ (m + tr m)/2 +๐‘– :: Sized โ„‚ s c => s +๐‘– = konst iC + +newtype Her n = Her (M n n) + +her :: KnownNat n => M n n -> Her n +her m = Her $ (m + LA.tr m)/2 + + instance KnownNat n => Eigen (Sym n) (R n) (L n n) where @@ -375,21 +349,6 @@ toColumns :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R m] toColumns (LA.toColumns . extract -> vs) = map mkR vs -splittest - = do - let v = range :: R 7 - a = snd (split v) :: R 4 - print $ a - print $ snd . headTail . snd . headTail $ v - print $ first (vec3 1 2 3) - print $ second (vec3 1 2 3) - print $ third (vec3 1 2 3) - print $ (snd $ splitRows eye :: L 4 6) - where - first v = fst . headTail $ v - second v = first . snd . headTail $ v - third v = first . snd . headTail . snd . headTail $ v - -------------------------------------------------------------------------------- build @@ -428,9 +387,133 @@ withMatrix a f = Just (SomeNat (_ :: Proxy n)) -> f (mkL a :: L m n) +-------------------------------------------------------------------------------- + +class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat vec + where + mul :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => mat m k -> mat k n -> mat m n + app :: forall m n . (KnownNat m, KnownNat n) => mat m n -> vec n -> vec m + dot :: forall n . (KnownNat n) => vec n -> vec n -> field + cross :: vec 3 -> vec 3 -> vec 3 + diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n + + +instance Domain โ„ R L + where + mul = mulR + app = appR + dot = dotR + cross = crossR + diagR = diagRectR + +instance Domain โ„‚ C M + where + mul = mulC + app = appC + dot = dotC + cross = crossC + diagR = diagRectC + +-------------------------------------------------------------------------------- + +mulR :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n + +mulR (isKonst -> Just (a,(_,k))) (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k) + +mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) + where + v = a' * b' + n = min (size a) (size b) + a' = subVector 0 n a + b' = subVector 0 n b + +mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (size a) b) + +mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) + +mulR a b = mkL (extract a LA.<> extract b) + + +appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m +appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (size w) (extract v)) +appR m v = mkR (extract m LA.#> extract v) + + +dotR :: R n -> R n -> โ„ +dotR (ud1 -> u) (ud1 -> v) + | singleV u || singleV v = sumElements (u * v) + | otherwise = udot u v + + +crossR :: R 3 -> R 3 -> R 3 +crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3 + where + z1 = x!1*y!2-x!2*y!1 + z2 = x!2*y!0-x!0*y!2 + z3 = x!0*y!1-x!1*y!0 + +-------------------------------------------------------------------------------- + +mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n + +mulC (isKonstC -> Just (a,(_,k))) (isKonstC -> Just (b,_)) = konst (a * b * fromIntegral k) + +mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k) + where + v = a' * b' + n = min (size a) (size b) + a' = subVector 0 n a + b' = subVector 0 n b + +mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (size a) b) + +mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (size b) a * asRow b) + +mulC a b = mkM (extract a LA.<> extract b) + + +appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m +appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (size w) (extract v)) +appC m v = mkC (extract m LA.#> extract v) + + +dotC :: KnownNat n => C n -> C n -> โ„‚ +dotC (unwrap -> u) (unwrap -> v) + | singleV u || singleV v = sumElements (conj u * v) + | otherwise = u LA.<ยท> v + + +crossC :: C 3 -> C 3 -> C 3 +crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) + where + z1 = x!1*y!2-x!2*y!1 + z2 = x!2*y!0-x!0*y!2 + z3 = x!0*y!1-x!1*y!0 + +-------------------------------------------------------------------------------- + +diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ„ -> R k -> L m n +diagRectR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) + where + ev = extract v + zeros = LA.konst x (max 0 ((min m' n') - size ev)) + m' = fromIntegral . natVal $ (undefined :: Proxy m) + n' = fromIntegral . natVal $ (undefined :: Proxy n) + + +diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ„‚ -> C k -> M m n +diagRectC x v = mkM (asRow (vjoin [scalar x, ev, zeros])) + where + ev = extract v + zeros = LA.konst x (max 0 ((min m' n') - size ev)) + m' = fromIntegral . natVal $ (undefined :: Proxy m) + n' = fromIntegral . natVal $ (undefined :: Proxy n) -------------------------------------------------------------------------------- +mean :: (KnownNat n, 1<=n) => R n -> โ„ +mean v = v <ยท> (1/dim) + test :: (Bool, IO ()) test = (ok,info) where @@ -490,6 +573,22 @@ test = (ok,info) precD = 1 + 2 * vjoin[ud1 u, 6] LA.<ยท> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v +splittest + = do + let v = range :: R 7 + a = snd (split v) :: R 4 + print $ a + print $ snd . headTail . snd . headTail $ v + print $ first (vec3 1 2 3) + print $ second (vec3 1 2 3) + print $ third (vec3 1 2 3) + print $ (snd $ splitRows eye :: L 4 6) + where + first v = fst . headTail $ v + second v = first . snd . headTail $ v + third v = first . snd . headTail . snd . headTail $ v + + instance (KnownNat n', KnownNat m') => Testable (L n' m') where checkT _ = test -- cgit v1.2.3