diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 95 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs | 107 |
2 files changed, 101 insertions, 101 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 388d165..213c42c 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -64,7 +64,7 @@ import GHC.TypeLits | |||
64 | import Numeric.LinearAlgebra.HMatrix hiding ( | 64 | import Numeric.LinearAlgebra.HMatrix hiding ( |
65 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vector,matrix,linspace,toRows,toColumns, | 65 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vector,matrix,linspace,toRows,toColumns, |
66 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build, | 66 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build, |
67 | qr) | 67 | qr,size) |
68 | import qualified Numeric.LinearAlgebra.HMatrix as LA | 68 | import qualified Numeric.LinearAlgebra.HMatrix as LA |
69 | import Data.Proxy(Proxy) | 69 | import Data.Proxy(Proxy) |
70 | import Numeric.LinearAlgebra.Static.Internal | 70 | import Numeric.LinearAlgebra.Static.Internal |
@@ -107,20 +107,20 @@ matrix :: (KnownNat m, KnownNat n) => [ℝ] -> L m n | |||
107 | matrix = fromList | 107 | matrix = fromList |
108 | 108 | ||
109 | linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n | 109 | linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n |
110 | linspace (a,b) = mkR (LA.linspace d (a,b)) | 110 | linspace (a,b) = v |
111 | where | 111 | where |
112 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 112 | v = mkR (LA.linspace (size v) (a,b)) |
113 | 113 | ||
114 | range :: forall n . KnownNat n => R n | 114 | range :: forall n . KnownNat n => R n |
115 | range = mkR (LA.linspace d (1,fromIntegral d)) | 115 | range = v |
116 | where | 116 | where |
117 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 117 | v = mkR (LA.linspace d (1,fromIntegral d)) |
118 | d = size v | ||
118 | 119 | ||
119 | dim :: forall n . KnownNat n => R n | 120 | dim :: forall n . KnownNat n => R n |
120 | dim = mkR (scalar d) | 121 | dim = v |
121 | where | 122 | where |
122 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 123 | v = mkR (scalar (fromIntegral $ size v)) |
123 | |||
124 | 124 | ||
125 | -------------------------------------------------------------------------------- | 125 | -------------------------------------------------------------------------------- |
126 | 126 | ||
@@ -140,7 +140,7 @@ eye = diag 1 | |||
140 | -------------------------------------------------------------------------------- | 140 | -------------------------------------------------------------------------------- |
141 | 141 | ||
142 | blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n | 142 | blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n |
143 | blockAt x r c a = mkL res | 143 | blockAt x r c a = res |
144 | where | 144 | where |
145 | z = scalar x | 145 | z = scalar x |
146 | z1 = LA.konst x (r,c) | 146 | z1 = LA.konst x (r,c) |
@@ -148,13 +148,8 @@ blockAt x r c a = mkL res | |||
148 | ra = min (rows a) . max 0 $ m'-r | 148 | ra = min (rows a) . max 0 $ m'-r |
149 | ca = min (cols a) . max 0 $ n'-c | 149 | ca = min (cols a) . max 0 $ n'-c |
150 | sa = subMatrix (0,0) (ra, ca) a | 150 | sa = subMatrix (0,0) (ra, ca) a |
151 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 151 | (m',n') = size res |
152 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | 152 | res = mkL $ fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] |
153 | res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] | ||
154 | |||
155 | |||
156 | |||
157 | |||
158 | 153 | ||
159 | -------------------------------------------------------------------------------- | 154 | -------------------------------------------------------------------------------- |
160 | 155 | ||
@@ -189,22 +184,15 @@ type GL = (KnownNat n, KnownNat m) => L m n | |||
189 | type GSq = KnownNat n => Sq n | 184 | type GSq = KnownNat n => Sq n |
190 | 185 | ||
191 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) | 186 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) |
192 | isKonst (unwrap -> x) | 187 | isKonst s@(unwrap -> x) |
193 | | singleM x = Just (x `atIndex` (0,0), (m',n')) | 188 | | singleM x = Just (x `atIndex` (0,0), (size s)) |
194 | | otherwise = Nothing | 189 | | otherwise = Nothing |
195 | where | ||
196 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
197 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
198 | 190 | ||
199 | 191 | ||
200 | isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int)) | 192 | isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int)) |
201 | isKonstC (unwrap -> x) | 193 | isKonstC s@(unwrap -> x) |
202 | | singleM x = Just (x `atIndex` (0,0), (m',n')) | 194 | | singleM x = Just (x `atIndex` (0,0), (size s)) |
203 | | otherwise = Nothing | 195 | | otherwise = Nothing |
204 | where | ||
205 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
206 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
207 | |||
208 | 196 | ||
209 | 197 | ||
210 | infixr 8 <> | 198 | infixr 8 <> |
@@ -256,7 +244,7 @@ svd (extract -> m) = (mkL u, mkR s', mkL v) | |||
256 | where | 244 | where |
257 | (u,s,v) = LA.svd m | 245 | (u,s,v) = LA.svd m |
258 | s' = vjoin [s, z] | 246 | s' = vjoin [s, z] |
259 | z = LA.konst 0 (max 0 (cols m - size s)) | 247 | z = LA.konst 0 (max 0 (cols m - LA.size s)) |
260 | 248 | ||
261 | 249 | ||
262 | svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) | 250 | svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) |
@@ -333,7 +321,7 @@ withCompactSVD | |||
333 | -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z) | 321 | -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z) |
334 | -> z | 322 | -> z |
335 | withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f = | 323 | withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f = |
336 | case someNatVal $ fromIntegral $ size s of | 324 | case someNatVal $ fromIntegral $ LA.size s of |
337 | Nothing -> error "static/dynamic mismatch" | 325 | Nothing -> error "static/dynamic mismatch" |
338 | Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k) | 326 | Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k) |
339 | 327 | ||
@@ -350,7 +338,7 @@ qr (extract -> x) = (mkL q, mkL r) | |||
350 | 338 | ||
351 | split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) | 339 | split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) |
352 | split (extract -> v) = ( mkR (subVector 0 p' v) , | 340 | split (extract -> v) = ( mkR (subVector 0 p' v) , |
353 | mkR (subVector p' (size v - p') v) ) | 341 | mkR (subVector p' (LA.size v - p') v) ) |
354 | where | 342 | where |
355 | p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int | 343 | p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int |
356 | 344 | ||
@@ -383,10 +371,9 @@ build | |||
383 | :: forall m n . (KnownNat n, KnownNat m) | 371 | :: forall m n . (KnownNat n, KnownNat m) |
384 | => (ℝ -> ℝ -> ℝ) | 372 | => (ℝ -> ℝ -> ℝ) |
385 | -> L m n | 373 | -> L m n |
386 | build f = mkL $ LA.build (m',n') f | 374 | build f = r |
387 | where | 375 | where |
388 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 376 | r = mkL $ LA.build (size r) f |
389 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
390 | 377 | ||
391 | -------------------------------------------------------------------------------- | 378 | -------------------------------------------------------------------------------- |
392 | 379 | ||
@@ -396,7 +383,7 @@ withVector | |||
396 | -> (forall n . (KnownNat n) => R n -> z) | 383 | -> (forall n . (KnownNat n) => R n -> z) |
397 | -> z | 384 | -> z |
398 | withVector v f = | 385 | withVector v f = |
399 | case someNatVal $ fromIntegral $ size v of | 386 | case someNatVal $ fromIntegral $ LA.size v of |
400 | Nothing -> error "static/dynamic mismatch" | 387 | Nothing -> error "static/dynamic mismatch" |
401 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) | 388 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) |
402 | 389 | ||
@@ -451,19 +438,19 @@ mulR (isKonst -> Just (a,(_,k))) (isKonst -> Just (b,_)) = konst (a * b * fromIn | |||
451 | mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) | 438 | mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) |
452 | where | 439 | where |
453 | v = a' * b' | 440 | v = a' * b' |
454 | n = min (size a) (size b) | 441 | n = min (LA.size a) (LA.size b) |
455 | a' = subVector 0 n a | 442 | a' = subVector 0 n a |
456 | b' = subVector 0 n b | 443 | b' = subVector 0 n b |
457 | 444 | ||
458 | mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (size a) b) | 445 | mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (LA.size a) b) |
459 | 446 | ||
460 | mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) | 447 | mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (LA.size b) a * asRow b) |
461 | 448 | ||
462 | mulR a b = mkL (extract a LA.<> extract b) | 449 | mulR a b = mkL (extract a LA.<> extract b) |
463 | 450 | ||
464 | 451 | ||
465 | appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m | 452 | appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m |
466 | appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (size w) (extract v)) | 453 | appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (LA.size w) (extract v)) |
467 | appR m v = mkR (extract m LA.#> extract v) | 454 | appR m v = mkR (extract m LA.#> extract v) |
468 | 455 | ||
469 | 456 | ||
@@ -489,19 +476,19 @@ mulC (isKonstC -> Just (a,(_,k))) (isKonstC -> Just (b,_)) = konst (a * b * from | |||
489 | mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k) | 476 | mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k) |
490 | where | 477 | where |
491 | v = a' * b' | 478 | v = a' * b' |
492 | n = min (size a) (size b) | 479 | n = min (LA.size a) (LA.size b) |
493 | a' = subVector 0 n a | 480 | a' = subVector 0 n a |
494 | b' = subVector 0 n b | 481 | b' = subVector 0 n b |
495 | 482 | ||
496 | mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (size a) b) | 483 | mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (LA.size a) b) |
497 | 484 | ||
498 | mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (size b) a * asRow b) | 485 | mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (LA.size b) a * asRow b) |
499 | 486 | ||
500 | mulC a b = mkM (extract a LA.<> extract b) | 487 | mulC a b = mkM (extract a LA.<> extract b) |
501 | 488 | ||
502 | 489 | ||
503 | appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m | 490 | appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m |
504 | appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (size w) (extract v)) | 491 | appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (LA.size w) (extract v)) |
505 | appC m v = mkC (extract m LA.#> extract v) | 492 | appC m v = mkC (extract m LA.#> extract v) |
506 | 493 | ||
507 | 494 | ||
@@ -521,21 +508,21 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) | |||
521 | -------------------------------------------------------------------------------- | 508 | -------------------------------------------------------------------------------- |
522 | 509 | ||
523 | diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n | 510 | diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n |
524 | diagRectR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) | 511 | diagRectR x v = r |
525 | where | 512 | where |
513 | r = mkL (asRow (vjoin [scalar x, ev, zeros])) | ||
526 | ev = extract v | 514 | ev = extract v |
527 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | 515 | zeros = LA.konst x (max 0 ((min m' n') - LA.size ev)) |
528 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 516 | (m',n') = size r |
529 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
530 | 517 | ||
531 | 518 | ||
532 | diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n | 519 | diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n |
533 | diagRectC x v = mkM (asRow (vjoin [scalar x, ev, zeros])) | 520 | diagRectC x v = r |
534 | where | 521 | where |
522 | r = mkM (asRow (vjoin [scalar x, ev, zeros])) | ||
535 | ev = extract v | 523 | ev = extract v |
536 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | 524 | zeros = LA.konst x (max 0 ((min m' n') - LA.size ev)) |
537 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 525 | (m',n') = size r |
538 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
539 | 526 | ||
540 | -------------------------------------------------------------------------------- | 527 | -------------------------------------------------------------------------------- |
541 | 528 | ||
@@ -578,10 +565,10 @@ test = (ok,info) | |||
578 | tm = lmat 0 [1..] | 565 | tm = lmat 0 [1..] |
579 | 566 | ||
580 | lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n | 567 | lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n |
581 | lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z | 568 | lmat z xs = r |
582 | where | 569 | where |
583 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 570 | r = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z |
584 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | 571 | (m',n') = size r |
585 | 572 | ||
586 | sm :: GSq | 573 | sm :: GSq |
587 | sm = lmat 0 [1..] | 574 | sm = lmat 0 [1..] |
@@ -595,7 +582,7 @@ test = (ok,info) | |||
595 | m = LA.matrix 3 [1..30] | 582 | m = LA.matrix 3 [1..30] |
596 | 583 | ||
597 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v | 584 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v |
598 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v | 585 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (LA.size (ud1 u) +1, LA.size (ud1 v)) LA.#> ud1 v |
599 | 586 | ||
600 | 587 | ||
601 | splittest | 588 | splittest |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs index 7968d77..339ef7d 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs | |||
@@ -7,13 +7,10 @@ | |||
7 | {-# LANGUAGE FunctionalDependencies #-} | 7 | {-# LANGUAGE FunctionalDependencies #-} |
8 | {-# LANGUAGE FlexibleContexts #-} | 8 | {-# LANGUAGE FlexibleContexts #-} |
9 | {-# LANGUAGE ScopedTypeVariables #-} | 9 | {-# LANGUAGE ScopedTypeVariables #-} |
10 | {-# LANGUAGE EmptyDataDecls #-} | ||
11 | {-# LANGUAGE Rank2Types #-} | 10 | {-# LANGUAGE Rank2Types #-} |
12 | {-# LANGUAGE FlexibleInstances #-} | 11 | {-# LANGUAGE FlexibleInstances #-} |
13 | {-# LANGUAGE TypeOperators #-} | 12 | {-# LANGUAGE TypeOperators #-} |
14 | {-# LANGUAGE ViewPatterns #-} | 13 | {-# LANGUAGE ViewPatterns #-} |
15 | {-# LANGUAGE GADTs #-} | ||
16 | |||
17 | 14 | ||
18 | {- | | 15 | {- | |
19 | Module : Numeric.LinearAlgebra.Static.Internal | 16 | Module : Numeric.LinearAlgebra.Static.Internal |
@@ -28,7 +25,7 @@ module Numeric.LinearAlgebra.Static.Internal where | |||
28 | 25 | ||
29 | import GHC.TypeLits | 26 | import GHC.TypeLits |
30 | import qualified Numeric.LinearAlgebra.HMatrix as LA | 27 | import qualified Numeric.LinearAlgebra.HMatrix as LA |
31 | import Numeric.LinearAlgebra.HMatrix hiding (konst) | 28 | import Numeric.LinearAlgebra.HMatrix hiding (konst,size) |
32 | import Data.Packed as D | 29 | import Data.Packed as D |
33 | import Data.Packed.ST | 30 | import Data.Packed.ST |
34 | import Data.Proxy(Proxy) | 31 | import Data.Proxy(Proxy) |
@@ -83,7 +80,7 @@ ud :: Dim n (Vector t) -> Vector t | |||
83 | ud (Dim v) = v | 80 | ud (Dim v) = v |
84 | 81 | ||
85 | mkV :: forall (n :: Nat) t . t -> Dim n t | 82 | mkV :: forall (n :: Nat) t . t -> Dim n t |
86 | mkV = Dim | 83 | mkV = Dim |
87 | 84 | ||
88 | 85 | ||
89 | vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) | 86 | vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) |
@@ -92,9 +89,9 @@ vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) | |||
92 | where | 89 | where |
93 | du = fromIntegral . natVal $ (undefined :: Proxy n) | 90 | du = fromIntegral . natVal $ (undefined :: Proxy n) |
94 | dv = fromIntegral . natVal $ (undefined :: Proxy m) | 91 | dv = fromIntegral . natVal $ (undefined :: Proxy m) |
95 | u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du | 92 | u' | du > 1 && LA.size u == 1 = LA.konst (u D.@> 0) du |
96 | | otherwise = u | 93 | | otherwise = u |
97 | v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv | 94 | v' | dv > 1 && LA.size v == 1 = LA.konst (v D.@> 0) dv |
98 | | otherwise = v | 95 | | otherwise = v |
99 | 96 | ||
100 | 97 | ||
@@ -132,7 +129,7 @@ gvect st xs' | |||
132 | | otherwise = abort (show xs) | 129 | | otherwise = abort (show xs) |
133 | where | 130 | where |
134 | (xs,rest) = splitAt d xs' | 131 | (xs,rest) = splitAt d xs' |
135 | ok = size v == d && null rest | 132 | ok = LA.size v == d && null rest |
136 | v = LA.fromList xs | 133 | v = LA.fromList xs |
137 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 134 | d = fromIntegral . natVal $ (undefined :: Proxy n) |
138 | abort info = error $ st++" "++show d++" can't be created from elements "++info | 135 | abort info = error $ st++" "++show d++" can't be created from elements "++info |
@@ -153,7 +150,7 @@ gmat st xs' | |||
153 | (xs,rest) = splitAt (m'*n') xs' | 150 | (xs,rest) = splitAt (m'*n') xs' |
154 | v = LA.fromList xs | 151 | v = LA.fromList xs |
155 | x = reshape n' v | 152 | x = reshape n' v |
156 | ok = rem (size v) n' == 0 && size x == (m',n') && null rest | 153 | ok = rem (LA.size v) n' == 0 && LA.size x == (m',n') && null rest |
157 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 154 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int |
158 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 155 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
159 | abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info | 156 | abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info |
@@ -162,66 +159,84 @@ gmat st xs' | |||
162 | 159 | ||
163 | class Num t => Sized t s d | s -> t, s -> d | 160 | class Num t => Sized t s d | s -> t, s -> d |
164 | where | 161 | where |
165 | konst :: t -> s | 162 | konst :: t -> s |
166 | unwrap :: s -> d | 163 | unwrap :: s -> d t |
167 | fromList :: [t] -> s | 164 | fromList :: [t] -> s |
168 | extract :: s -> d | 165 | extract :: s -> d t |
169 | 166 | create :: d t -> Maybe s | |
170 | singleV v = size v == 1 | 167 | size :: s -> IndexOf d |
168 | |||
169 | singleV v = LA.size v == 1 | ||
171 | singleM m = rows m == 1 && cols m == 1 | 170 | singleM m = rows m == 1 && cols m == 1 |
172 | 171 | ||
173 | 172 | ||
174 | instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ) | 173 | instance forall n. KnownNat n => Sized ℂ (C n) Vector |
175 | where | 174 | where |
175 | size _ = fromIntegral . natVal $ (undefined :: Proxy n) | ||
176 | konst x = mkC (LA.scalar x) | 176 | konst x = mkC (LA.scalar x) |
177 | unwrap (C (Dim v)) = v | 177 | unwrap (C (Dim v)) = v |
178 | fromList xs = C (gvect "C" xs) | 178 | fromList xs = C (gvect "C" xs) |
179 | extract (unwrap -> v) | 179 | extract s@(unwrap -> v) |
180 | | singleV v = LA.konst (v!0) d | 180 | | singleV v = LA.konst (v!0) (size s) |
181 | | otherwise = v | 181 | | otherwise = v |
182 | where | 182 | create v |
183 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 183 | | LA.size v == size r = Just r |
184 | | otherwise = Nothing | ||
185 | where | ||
186 | r = mkC v :: C n | ||
184 | 187 | ||
185 | 188 | ||
186 | instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) | 189 | instance forall n. KnownNat n => Sized ℝ (R n) Vector |
187 | where | 190 | where |
191 | size _ = fromIntegral . natVal $ (undefined :: Proxy n) | ||
188 | konst x = mkR (LA.scalar x) | 192 | konst x = mkR (LA.scalar x) |
189 | unwrap (R (Dim v)) = v | 193 | unwrap (R (Dim v)) = v |
190 | fromList xs = R (gvect "R" xs) | 194 | fromList xs = R (gvect "R" xs) |
191 | extract (unwrap -> v) | 195 | extract s@(unwrap -> v) |
192 | | singleV v = LA.konst (v!0) d | 196 | | singleV v = LA.konst (v!0) (size s) |
193 | | otherwise = v | 197 | | otherwise = v |
194 | where | 198 | create v |
195 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 199 | | LA.size v == size r = Just r |
200 | | otherwise = Nothing | ||
201 | where | ||
202 | r = mkR v :: R n | ||
196 | 203 | ||
197 | 204 | ||
198 | 205 | ||
199 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) | 206 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) Matrix |
200 | where | 207 | where |
208 | size _ = ((fromIntegral . natVal) (undefined :: Proxy m) | ||
209 | ,(fromIntegral . natVal) (undefined :: Proxy n)) | ||
201 | konst x = mkL (LA.scalar x) | 210 | konst x = mkL (LA.scalar x) |
202 | fromList xs = L (gmat "L" xs) | 211 | fromList xs = L (gmat "L" xs) |
203 | unwrap (L (Dim (Dim m))) = m | 212 | unwrap (L (Dim (Dim m))) = m |
204 | extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' | 213 | extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' |
205 | extract (unwrap -> a) | 214 | extract s@(unwrap -> a) |
206 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | 215 | | singleM a = LA.konst (a `atIndex` (0,0)) (size s) |
207 | | otherwise = a | 216 | | otherwise = a |
217 | create x | ||
218 | | LA.size x == size r = Just r | ||
219 | | otherwise = Nothing | ||
208 | where | 220 | where |
209 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 221 | r = mkL x :: L m n |
210 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
211 | 222 | ||
212 | 223 | ||
213 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ) | 224 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) Matrix |
214 | where | 225 | where |
226 | size _ = ((fromIntegral . natVal) (undefined :: Proxy m) | ||
227 | ,(fromIntegral . natVal) (undefined :: Proxy n)) | ||
215 | konst x = mkM (LA.scalar x) | 228 | konst x = mkM (LA.scalar x) |
216 | fromList xs = M (gmat "M" xs) | 229 | fromList xs = M (gmat "M" xs) |
217 | unwrap (M (Dim (Dim m))) = m | 230 | unwrap (M (Dim (Dim m))) = m |
218 | extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' | 231 | extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' |
219 | extract (unwrap -> a) | 232 | extract s@(unwrap -> a) |
220 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | 233 | | singleM a = LA.konst (a `atIndex` (0,0)) (size s) |
221 | | otherwise = a | 234 | | otherwise = a |
235 | create x | ||
236 | | LA.size x == size r = Just r | ||
237 | | otherwise = Nothing | ||
222 | where | 238 | where |
223 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 239 | r = mkM x :: M m n |
224 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
225 | 240 | ||
226 | -------------------------------------------------------------------------------- | 241 | -------------------------------------------------------------------------------- |
227 | 242 | ||
@@ -254,8 +269,8 @@ isDiagg (Dim (Dim x)) | |||
254 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 269 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
255 | v = flatten x | 270 | v = flatten x |
256 | z = v `atIndex` 0 | 271 | z = v `atIndex` 0 |
257 | y = subVector 1 (size v-1) v | 272 | y = subVector 1 (LA.size v-1) v |
258 | ny = size y | 273 | ny = LA.size y |
259 | zeros = LA.konst 0 (max 0 (min m' n' - ny)) | 274 | zeros = LA.konst 0 (max 0 (min m' n' - ny)) |
260 | yz = vjoin [y,zeros] | 275 | yz = vjoin [y,zeros] |
261 | 276 | ||
@@ -263,39 +278,37 @@ isDiagg (Dim (Dim x)) | |||
263 | 278 | ||
264 | instance forall n . KnownNat n => Show (R n) | 279 | instance forall n . KnownNat n => Show (R n) |
265 | where | 280 | where |
266 | show (R (Dim v)) | 281 | show s@(R (Dim v)) |
267 | | singleV v = "("++show (v!0)++" :: R "++show d++")" | 282 | | singleV v = "("++show (v!0)++" :: R "++show d++")" |
268 | | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" | 283 | | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" |
269 | where | 284 | where |
270 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 285 | d = size s |
271 | 286 | ||
272 | instance forall n . KnownNat n => Show (C n) | 287 | instance forall n . KnownNat n => Show (C n) |
273 | where | 288 | where |
274 | show (C (Dim v)) | 289 | show s@(C (Dim v)) |
275 | | singleV v = "("++show (v!0)++" :: C "++show d++")" | 290 | | singleV v = "("++show (v!0)++" :: C "++show d++")" |
276 | | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" | 291 | | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" |
277 | where | 292 | where |
278 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 293 | d = size s |
279 | 294 | ||
280 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) | 295 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) |
281 | where | 296 | where |
282 | show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' | 297 | show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' |
283 | show (L (Dim (Dim x))) | 298 | show s@(L (Dim (Dim x))) |
284 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' | 299 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' |
285 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" | 300 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" |
286 | where | 301 | where |
287 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 302 | (m',n') = size s |
288 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
289 | 303 | ||
290 | instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) | 304 | instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) |
291 | where | 305 | where |
292 | show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' | 306 | show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' |
293 | show (M (Dim (Dim x))) | 307 | show s@(M (Dim (Dim x))) |
294 | | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' | 308 | | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' |
295 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" | 309 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" |
296 | where | 310 | where |
297 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 311 | (m',n') = size s |
298 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
299 | 312 | ||
300 | -------------------------------------------------------------------------------- | 313 | -------------------------------------------------------------------------------- |
301 | 314 | ||