diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static.hs')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 95 |
1 files changed, 41 insertions, 54 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 388d165..213c42c 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -64,7 +64,7 @@ import GHC.TypeLits | |||
64 | import Numeric.LinearAlgebra.HMatrix hiding ( | 64 | import Numeric.LinearAlgebra.HMatrix hiding ( |
65 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vector,matrix,linspace,toRows,toColumns, | 65 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vector,matrix,linspace,toRows,toColumns, |
66 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build, | 66 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build, |
67 | qr) | 67 | qr,size) |
68 | import qualified Numeric.LinearAlgebra.HMatrix as LA | 68 | import qualified Numeric.LinearAlgebra.HMatrix as LA |
69 | import Data.Proxy(Proxy) | 69 | import Data.Proxy(Proxy) |
70 | import Numeric.LinearAlgebra.Static.Internal | 70 | import Numeric.LinearAlgebra.Static.Internal |
@@ -107,20 +107,20 @@ matrix :: (KnownNat m, KnownNat n) => [ℝ] -> L m n | |||
107 | matrix = fromList | 107 | matrix = fromList |
108 | 108 | ||
109 | linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n | 109 | linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n |
110 | linspace (a,b) = mkR (LA.linspace d (a,b)) | 110 | linspace (a,b) = v |
111 | where | 111 | where |
112 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 112 | v = mkR (LA.linspace (size v) (a,b)) |
113 | 113 | ||
114 | range :: forall n . KnownNat n => R n | 114 | range :: forall n . KnownNat n => R n |
115 | range = mkR (LA.linspace d (1,fromIntegral d)) | 115 | range = v |
116 | where | 116 | where |
117 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 117 | v = mkR (LA.linspace d (1,fromIntegral d)) |
118 | d = size v | ||
118 | 119 | ||
119 | dim :: forall n . KnownNat n => R n | 120 | dim :: forall n . KnownNat n => R n |
120 | dim = mkR (scalar d) | 121 | dim = v |
121 | where | 122 | where |
122 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 123 | v = mkR (scalar (fromIntegral $ size v)) |
123 | |||
124 | 124 | ||
125 | -------------------------------------------------------------------------------- | 125 | -------------------------------------------------------------------------------- |
126 | 126 | ||
@@ -140,7 +140,7 @@ eye = diag 1 | |||
140 | -------------------------------------------------------------------------------- | 140 | -------------------------------------------------------------------------------- |
141 | 141 | ||
142 | blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n | 142 | blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n |
143 | blockAt x r c a = mkL res | 143 | blockAt x r c a = res |
144 | where | 144 | where |
145 | z = scalar x | 145 | z = scalar x |
146 | z1 = LA.konst x (r,c) | 146 | z1 = LA.konst x (r,c) |
@@ -148,13 +148,8 @@ blockAt x r c a = mkL res | |||
148 | ra = min (rows a) . max 0 $ m'-r | 148 | ra = min (rows a) . max 0 $ m'-r |
149 | ca = min (cols a) . max 0 $ n'-c | 149 | ca = min (cols a) . max 0 $ n'-c |
150 | sa = subMatrix (0,0) (ra, ca) a | 150 | sa = subMatrix (0,0) (ra, ca) a |
151 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 151 | (m',n') = size res |
152 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | 152 | res = mkL $ fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] |
153 | res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] | ||
154 | |||
155 | |||
156 | |||
157 | |||
158 | 153 | ||
159 | -------------------------------------------------------------------------------- | 154 | -------------------------------------------------------------------------------- |
160 | 155 | ||
@@ -189,22 +184,15 @@ type GL = (KnownNat n, KnownNat m) => L m n | |||
189 | type GSq = KnownNat n => Sq n | 184 | type GSq = KnownNat n => Sq n |
190 | 185 | ||
191 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) | 186 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) |
192 | isKonst (unwrap -> x) | 187 | isKonst s@(unwrap -> x) |
193 | | singleM x = Just (x `atIndex` (0,0), (m',n')) | 188 | | singleM x = Just (x `atIndex` (0,0), (size s)) |
194 | | otherwise = Nothing | 189 | | otherwise = Nothing |
195 | where | ||
196 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
197 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
198 | 190 | ||
199 | 191 | ||
200 | isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int)) | 192 | isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int)) |
201 | isKonstC (unwrap -> x) | 193 | isKonstC s@(unwrap -> x) |
202 | | singleM x = Just (x `atIndex` (0,0), (m',n')) | 194 | | singleM x = Just (x `atIndex` (0,0), (size s)) |
203 | | otherwise = Nothing | 195 | | otherwise = Nothing |
204 | where | ||
205 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
206 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
207 | |||
208 | 196 | ||
209 | 197 | ||
210 | infixr 8 <> | 198 | infixr 8 <> |
@@ -256,7 +244,7 @@ svd (extract -> m) = (mkL u, mkR s', mkL v) | |||
256 | where | 244 | where |
257 | (u,s,v) = LA.svd m | 245 | (u,s,v) = LA.svd m |
258 | s' = vjoin [s, z] | 246 | s' = vjoin [s, z] |
259 | z = LA.konst 0 (max 0 (cols m - size s)) | 247 | z = LA.konst 0 (max 0 (cols m - LA.size s)) |
260 | 248 | ||
261 | 249 | ||
262 | svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) | 250 | svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) |
@@ -333,7 +321,7 @@ withCompactSVD | |||
333 | -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z) | 321 | -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z) |
334 | -> z | 322 | -> z |
335 | withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f = | 323 | withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f = |
336 | case someNatVal $ fromIntegral $ size s of | 324 | case someNatVal $ fromIntegral $ LA.size s of |
337 | Nothing -> error "static/dynamic mismatch" | 325 | Nothing -> error "static/dynamic mismatch" |
338 | Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k) | 326 | Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k) |
339 | 327 | ||
@@ -350,7 +338,7 @@ qr (extract -> x) = (mkL q, mkL r) | |||
350 | 338 | ||
351 | split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) | 339 | split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) |
352 | split (extract -> v) = ( mkR (subVector 0 p' v) , | 340 | split (extract -> v) = ( mkR (subVector 0 p' v) , |
353 | mkR (subVector p' (size v - p') v) ) | 341 | mkR (subVector p' (LA.size v - p') v) ) |
354 | where | 342 | where |
355 | p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int | 343 | p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int |
356 | 344 | ||
@@ -383,10 +371,9 @@ build | |||
383 | :: forall m n . (KnownNat n, KnownNat m) | 371 | :: forall m n . (KnownNat n, KnownNat m) |
384 | => (ℝ -> ℝ -> ℝ) | 372 | => (ℝ -> ℝ -> ℝ) |
385 | -> L m n | 373 | -> L m n |
386 | build f = mkL $ LA.build (m',n') f | 374 | build f = r |
387 | where | 375 | where |
388 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 376 | r = mkL $ LA.build (size r) f |
389 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
390 | 377 | ||
391 | -------------------------------------------------------------------------------- | 378 | -------------------------------------------------------------------------------- |
392 | 379 | ||
@@ -396,7 +383,7 @@ withVector | |||
396 | -> (forall n . (KnownNat n) => R n -> z) | 383 | -> (forall n . (KnownNat n) => R n -> z) |
397 | -> z | 384 | -> z |
398 | withVector v f = | 385 | withVector v f = |
399 | case someNatVal $ fromIntegral $ size v of | 386 | case someNatVal $ fromIntegral $ LA.size v of |
400 | Nothing -> error "static/dynamic mismatch" | 387 | Nothing -> error "static/dynamic mismatch" |
401 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) | 388 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) |
402 | 389 | ||
@@ -451,19 +438,19 @@ mulR (isKonst -> Just (a,(_,k))) (isKonst -> Just (b,_)) = konst (a * b * fromIn | |||
451 | mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) | 438 | mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) |
452 | where | 439 | where |
453 | v = a' * b' | 440 | v = a' * b' |
454 | n = min (size a) (size b) | 441 | n = min (LA.size a) (LA.size b) |
455 | a' = subVector 0 n a | 442 | a' = subVector 0 n a |
456 | b' = subVector 0 n b | 443 | b' = subVector 0 n b |
457 | 444 | ||
458 | mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (size a) b) | 445 | mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (LA.size a) b) |
459 | 446 | ||
460 | mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) | 447 | mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (LA.size b) a * asRow b) |
461 | 448 | ||
462 | mulR a b = mkL (extract a LA.<> extract b) | 449 | mulR a b = mkL (extract a LA.<> extract b) |
463 | 450 | ||
464 | 451 | ||
465 | appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m | 452 | appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m |
466 | appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (size w) (extract v)) | 453 | appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (LA.size w) (extract v)) |
467 | appR m v = mkR (extract m LA.#> extract v) | 454 | appR m v = mkR (extract m LA.#> extract v) |
468 | 455 | ||
469 | 456 | ||
@@ -489,19 +476,19 @@ mulC (isKonstC -> Just (a,(_,k))) (isKonstC -> Just (b,_)) = konst (a * b * from | |||
489 | mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k) | 476 | mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k) |
490 | where | 477 | where |
491 | v = a' * b' | 478 | v = a' * b' |
492 | n = min (size a) (size b) | 479 | n = min (LA.size a) (LA.size b) |
493 | a' = subVector 0 n a | 480 | a' = subVector 0 n a |
494 | b' = subVector 0 n b | 481 | b' = subVector 0 n b |
495 | 482 | ||
496 | mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (size a) b) | 483 | mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (LA.size a) b) |
497 | 484 | ||
498 | mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (size b) a * asRow b) | 485 | mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (LA.size b) a * asRow b) |
499 | 486 | ||
500 | mulC a b = mkM (extract a LA.<> extract b) | 487 | mulC a b = mkM (extract a LA.<> extract b) |
501 | 488 | ||
502 | 489 | ||
503 | appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m | 490 | appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m |
504 | appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (size w) (extract v)) | 491 | appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (LA.size w) (extract v)) |
505 | appC m v = mkC (extract m LA.#> extract v) | 492 | appC m v = mkC (extract m LA.#> extract v) |
506 | 493 | ||
507 | 494 | ||
@@ -521,21 +508,21 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) | |||
521 | -------------------------------------------------------------------------------- | 508 | -------------------------------------------------------------------------------- |
522 | 509 | ||
523 | diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n | 510 | diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n |
524 | diagRectR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) | 511 | diagRectR x v = r |
525 | where | 512 | where |
513 | r = mkL (asRow (vjoin [scalar x, ev, zeros])) | ||
526 | ev = extract v | 514 | ev = extract v |
527 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | 515 | zeros = LA.konst x (max 0 ((min m' n') - LA.size ev)) |
528 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 516 | (m',n') = size r |
529 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
530 | 517 | ||
531 | 518 | ||
532 | diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n | 519 | diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n |
533 | diagRectC x v = mkM (asRow (vjoin [scalar x, ev, zeros])) | 520 | diagRectC x v = r |
534 | where | 521 | where |
522 | r = mkM (asRow (vjoin [scalar x, ev, zeros])) | ||
535 | ev = extract v | 523 | ev = extract v |
536 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | 524 | zeros = LA.konst x (max 0 ((min m' n') - LA.size ev)) |
537 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 525 | (m',n') = size r |
538 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
539 | 526 | ||
540 | -------------------------------------------------------------------------------- | 527 | -------------------------------------------------------------------------------- |
541 | 528 | ||
@@ -578,10 +565,10 @@ test = (ok,info) | |||
578 | tm = lmat 0 [1..] | 565 | tm = lmat 0 [1..] |
579 | 566 | ||
580 | lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n | 567 | lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n |
581 | lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z | 568 | lmat z xs = r |
582 | where | 569 | where |
583 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 570 | r = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z |
584 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | 571 | (m',n') = size r |
585 | 572 | ||
586 | sm :: GSq | 573 | sm :: GSq |
587 | sm = lmat 0 [1..] | 574 | sm = lmat 0 [1..] |
@@ -595,7 +582,7 @@ test = (ok,info) | |||
595 | m = LA.matrix 3 [1..30] | 582 | m = LA.matrix 3 [1..30] |
596 | 583 | ||
597 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v | 584 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v |
598 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v | 585 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (LA.size (ud1 u) +1, LA.size (ud1 v)) LA.#> ud1 v |
599 | 586 | ||
600 | 587 | ||
601 | splittest | 588 | splittest |