summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Static.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs95
1 files changed, 41 insertions, 54 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index 388d165..213c42c 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -64,7 +64,7 @@ import GHC.TypeLits
64import Numeric.LinearAlgebra.HMatrix hiding ( 64import Numeric.LinearAlgebra.HMatrix hiding (
65 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vector,matrix,linspace,toRows,toColumns, 65 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vector,matrix,linspace,toRows,toColumns,
66 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build, 66 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build,
67 qr) 67 qr,size)
68import qualified Numeric.LinearAlgebra.HMatrix as LA 68import qualified Numeric.LinearAlgebra.HMatrix as LA
69import Data.Proxy(Proxy) 69import Data.Proxy(Proxy)
70import Numeric.LinearAlgebra.Static.Internal 70import Numeric.LinearAlgebra.Static.Internal
@@ -107,20 +107,20 @@ matrix :: (KnownNat m, KnownNat n) => [ℝ] -> L m n
107matrix = fromList 107matrix = fromList
108 108
109linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n 109linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n
110linspace (a,b) = mkR (LA.linspace d (a,b)) 110linspace (a,b) = v
111 where 111 where
112 d = fromIntegral . natVal $ (undefined :: Proxy n) 112 v = mkR (LA.linspace (size v) (a,b))
113 113
114range :: forall n . KnownNat n => R n 114range :: forall n . KnownNat n => R n
115range = mkR (LA.linspace d (1,fromIntegral d)) 115range = v
116 where 116 where
117 d = fromIntegral . natVal $ (undefined :: Proxy n) 117 v = mkR (LA.linspace d (1,fromIntegral d))
118 d = size v
118 119
119dim :: forall n . KnownNat n => R n 120dim :: forall n . KnownNat n => R n
120dim = mkR (scalar d) 121dim = v
121 where 122 where
122 d = fromIntegral . natVal $ (undefined :: Proxy n) 123 v = mkR (scalar (fromIntegral $ size v))
123
124 124
125-------------------------------------------------------------------------------- 125--------------------------------------------------------------------------------
126 126
@@ -140,7 +140,7 @@ eye = diag 1
140-------------------------------------------------------------------------------- 140--------------------------------------------------------------------------------
141 141
142blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n 142blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n
143blockAt x r c a = mkL res 143blockAt x r c a = res
144 where 144 where
145 z = scalar x 145 z = scalar x
146 z1 = LA.konst x (r,c) 146 z1 = LA.konst x (r,c)
@@ -148,13 +148,8 @@ blockAt x r c a = mkL res
148 ra = min (rows a) . max 0 $ m'-r 148 ra = min (rows a) . max 0 $ m'-r
149 ca = min (cols a) . max 0 $ n'-c 149 ca = min (cols a) . max 0 $ n'-c
150 sa = subMatrix (0,0) (ra, ca) a 150 sa = subMatrix (0,0) (ra, ca) a
151 m' = fromIntegral . natVal $ (undefined :: Proxy m) 151 (m',n') = size res
152 n' = fromIntegral . natVal $ (undefined :: Proxy n) 152 res = mkL $ fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]]
153 res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]]
154
155
156
157
158 153
159-------------------------------------------------------------------------------- 154--------------------------------------------------------------------------------
160 155
@@ -189,22 +184,15 @@ type GL = (KnownNat n, KnownNat m) => L m n
189type GSq = KnownNat n => Sq n 184type GSq = KnownNat n => Sq n
190 185
191isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) 186isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int))
192isKonst (unwrap -> x) 187isKonst s@(unwrap -> x)
193 | singleM x = Just (x `atIndex` (0,0), (m',n')) 188 | singleM x = Just (x `atIndex` (0,0), (size s))
194 | otherwise = Nothing 189 | otherwise = Nothing
195 where
196 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
197 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
198 190
199 191
200isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int)) 192isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int))
201isKonstC (unwrap -> x) 193isKonstC s@(unwrap -> x)
202 | singleM x = Just (x `atIndex` (0,0), (m',n')) 194 | singleM x = Just (x `atIndex` (0,0), (size s))
203 | otherwise = Nothing 195 | otherwise = Nothing
204 where
205 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
206 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
207
208 196
209 197
210infixr 8 <> 198infixr 8 <>
@@ -256,7 +244,7 @@ svd (extract -> m) = (mkL u, mkR s', mkL v)
256 where 244 where
257 (u,s,v) = LA.svd m 245 (u,s,v) = LA.svd m
258 s' = vjoin [s, z] 246 s' = vjoin [s, z]
259 z = LA.konst 0 (max 0 (cols m - size s)) 247 z = LA.konst 0 (max 0 (cols m - LA.size s))
260 248
261 249
262svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) 250svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n)
@@ -333,7 +321,7 @@ withCompactSVD
333 -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z) 321 -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z)
334 -> z 322 -> z
335withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f = 323withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f =
336 case someNatVal $ fromIntegral $ size s of 324 case someNatVal $ fromIntegral $ LA.size s of
337 Nothing -> error "static/dynamic mismatch" 325 Nothing -> error "static/dynamic mismatch"
338 Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k) 326 Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k)
339 327
@@ -350,7 +338,7 @@ qr (extract -> x) = (mkL q, mkL r)
350 338
351split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) 339split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p))
352split (extract -> v) = ( mkR (subVector 0 p' v) , 340split (extract -> v) = ( mkR (subVector 0 p' v) ,
353 mkR (subVector p' (size v - p') v) ) 341 mkR (subVector p' (LA.size v - p') v) )
354 where 342 where
355 p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int 343 p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int
356 344
@@ -383,10 +371,9 @@ build
383 :: forall m n . (KnownNat n, KnownNat m) 371 :: forall m n . (KnownNat n, KnownNat m)
384 => (ℝ -> ℝ -> ℝ) 372 => (ℝ -> ℝ -> ℝ)
385 -> L m n 373 -> L m n
386build f = mkL $ LA.build (m',n') f 374build f = r
387 where 375 where
388 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 376 r = mkL $ LA.build (size r) f
389 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
390 377
391-------------------------------------------------------------------------------- 378--------------------------------------------------------------------------------
392 379
@@ -396,7 +383,7 @@ withVector
396 -> (forall n . (KnownNat n) => R n -> z) 383 -> (forall n . (KnownNat n) => R n -> z)
397 -> z 384 -> z
398withVector v f = 385withVector v f =
399 case someNatVal $ fromIntegral $ size v of 386 case someNatVal $ fromIntegral $ LA.size v of
400 Nothing -> error "static/dynamic mismatch" 387 Nothing -> error "static/dynamic mismatch"
401 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) 388 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m)
402 389
@@ -451,19 +438,19 @@ mulR (isKonst -> Just (a,(_,k))) (isKonst -> Just (b,_)) = konst (a * b * fromIn
451mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) 438mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k)
452 where 439 where
453 v = a' * b' 440 v = a' * b'
454 n = min (size a) (size b) 441 n = min (LA.size a) (LA.size b)
455 a' = subVector 0 n a 442 a' = subVector 0 n a
456 b' = subVector 0 n b 443 b' = subVector 0 n b
457 444
458mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (size a) b) 445mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (LA.size a) b)
459 446
460mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) 447mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (LA.size b) a * asRow b)
461 448
462mulR a b = mkL (extract a LA.<> extract b) 449mulR a b = mkL (extract a LA.<> extract b)
463 450
464 451
465appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m 452appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m
466appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (size w) (extract v)) 453appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (LA.size w) (extract v))
467appR m v = mkR (extract m LA.#> extract v) 454appR m v = mkR (extract m LA.#> extract v)
468 455
469 456
@@ -489,19 +476,19 @@ mulC (isKonstC -> Just (a,(_,k))) (isKonstC -> Just (b,_)) = konst (a * b * from
489mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k) 476mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k)
490 where 477 where
491 v = a' * b' 478 v = a' * b'
492 n = min (size a) (size b) 479 n = min (LA.size a) (LA.size b)
493 a' = subVector 0 n a 480 a' = subVector 0 n a
494 b' = subVector 0 n b 481 b' = subVector 0 n b
495 482
496mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (size a) b) 483mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (LA.size a) b)
497 484
498mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (size b) a * asRow b) 485mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (LA.size b) a * asRow b)
499 486
500mulC a b = mkM (extract a LA.<> extract b) 487mulC a b = mkM (extract a LA.<> extract b)
501 488
502 489
503appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m 490appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m
504appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (size w) (extract v)) 491appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (LA.size w) (extract v))
505appC m v = mkC (extract m LA.#> extract v) 492appC m v = mkC (extract m LA.#> extract v)
506 493
507 494
@@ -521,21 +508,21 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3])
521-------------------------------------------------------------------------------- 508--------------------------------------------------------------------------------
522 509
523diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n 510diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n
524diagRectR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) 511diagRectR x v = r
525 where 512 where
513 r = mkL (asRow (vjoin [scalar x, ev, zeros]))
526 ev = extract v 514 ev = extract v
527 zeros = LA.konst x (max 0 ((min m' n') - size ev)) 515 zeros = LA.konst x (max 0 ((min m' n') - LA.size ev))
528 m' = fromIntegral . natVal $ (undefined :: Proxy m) 516 (m',n') = size r
529 n' = fromIntegral . natVal $ (undefined :: Proxy n)
530 517
531 518
532diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n 519diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n
533diagRectC x v = mkM (asRow (vjoin [scalar x, ev, zeros])) 520diagRectC x v = r
534 where 521 where
522 r = mkM (asRow (vjoin [scalar x, ev, zeros]))
535 ev = extract v 523 ev = extract v
536 zeros = LA.konst x (max 0 ((min m' n') - size ev)) 524 zeros = LA.konst x (max 0 ((min m' n') - LA.size ev))
537 m' = fromIntegral . natVal $ (undefined :: Proxy m) 525 (m',n') = size r
538 n' = fromIntegral . natVal $ (undefined :: Proxy n)
539 526
540-------------------------------------------------------------------------------- 527--------------------------------------------------------------------------------
541 528
@@ -578,10 +565,10 @@ test = (ok,info)
578 tm = lmat 0 [1..] 565 tm = lmat 0 [1..]
579 566
580 lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n 567 lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n
581 lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z 568 lmat z xs = r
582 where 569 where
583 m' = fromIntegral . natVal $ (undefined :: Proxy m) 570 r = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z
584 n' = fromIntegral . natVal $ (undefined :: Proxy n) 571 (m',n') = size r
585 572
586 sm :: GSq 573 sm :: GSq
587 sm = lmat 0 [1..] 574 sm = lmat 0 [1..]
@@ -595,7 +582,7 @@ test = (ok,info)
595 m = LA.matrix 3 [1..30] 582 m = LA.matrix 3 [1..30]
596 583
597 precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v 584 precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v
598 precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v 585 precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (LA.size (ud1 u) +1, LA.size (ud1 v)) LA.#> ud1 v
599 586
600 587
601splittest 588splittest