summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs95
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs107
2 files changed, 101 insertions, 101 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index 388d165..213c42c 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -64,7 +64,7 @@ import GHC.TypeLits
64import Numeric.LinearAlgebra.HMatrix hiding ( 64import Numeric.LinearAlgebra.HMatrix hiding (
65 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vector,matrix,linspace,toRows,toColumns, 65 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vector,matrix,linspace,toRows,toColumns,
66 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build, 66 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build,
67 qr) 67 qr,size)
68import qualified Numeric.LinearAlgebra.HMatrix as LA 68import qualified Numeric.LinearAlgebra.HMatrix as LA
69import Data.Proxy(Proxy) 69import Data.Proxy(Proxy)
70import Numeric.LinearAlgebra.Static.Internal 70import Numeric.LinearAlgebra.Static.Internal
@@ -107,20 +107,20 @@ matrix :: (KnownNat m, KnownNat n) => [ℝ] -> L m n
107matrix = fromList 107matrix = fromList
108 108
109linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n 109linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n
110linspace (a,b) = mkR (LA.linspace d (a,b)) 110linspace (a,b) = v
111 where 111 where
112 d = fromIntegral . natVal $ (undefined :: Proxy n) 112 v = mkR (LA.linspace (size v) (a,b))
113 113
114range :: forall n . KnownNat n => R n 114range :: forall n . KnownNat n => R n
115range = mkR (LA.linspace d (1,fromIntegral d)) 115range = v
116 where 116 where
117 d = fromIntegral . natVal $ (undefined :: Proxy n) 117 v = mkR (LA.linspace d (1,fromIntegral d))
118 d = size v
118 119
119dim :: forall n . KnownNat n => R n 120dim :: forall n . KnownNat n => R n
120dim = mkR (scalar d) 121dim = v
121 where 122 where
122 d = fromIntegral . natVal $ (undefined :: Proxy n) 123 v = mkR (scalar (fromIntegral $ size v))
123
124 124
125-------------------------------------------------------------------------------- 125--------------------------------------------------------------------------------
126 126
@@ -140,7 +140,7 @@ eye = diag 1
140-------------------------------------------------------------------------------- 140--------------------------------------------------------------------------------
141 141
142blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n 142blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n
143blockAt x r c a = mkL res 143blockAt x r c a = res
144 where 144 where
145 z = scalar x 145 z = scalar x
146 z1 = LA.konst x (r,c) 146 z1 = LA.konst x (r,c)
@@ -148,13 +148,8 @@ blockAt x r c a = mkL res
148 ra = min (rows a) . max 0 $ m'-r 148 ra = min (rows a) . max 0 $ m'-r
149 ca = min (cols a) . max 0 $ n'-c 149 ca = min (cols a) . max 0 $ n'-c
150 sa = subMatrix (0,0) (ra, ca) a 150 sa = subMatrix (0,0) (ra, ca) a
151 m' = fromIntegral . natVal $ (undefined :: Proxy m) 151 (m',n') = size res
152 n' = fromIntegral . natVal $ (undefined :: Proxy n) 152 res = mkL $ fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]]
153 res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]]
154
155
156
157
158 153
159-------------------------------------------------------------------------------- 154--------------------------------------------------------------------------------
160 155
@@ -189,22 +184,15 @@ type GL = (KnownNat n, KnownNat m) => L m n
189type GSq = KnownNat n => Sq n 184type GSq = KnownNat n => Sq n
190 185
191isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) 186isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int))
192isKonst (unwrap -> x) 187isKonst s@(unwrap -> x)
193 | singleM x = Just (x `atIndex` (0,0), (m',n')) 188 | singleM x = Just (x `atIndex` (0,0), (size s))
194 | otherwise = Nothing 189 | otherwise = Nothing
195 where
196 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
197 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
198 190
199 191
200isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int)) 192isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int))
201isKonstC (unwrap -> x) 193isKonstC s@(unwrap -> x)
202 | singleM x = Just (x `atIndex` (0,0), (m',n')) 194 | singleM x = Just (x `atIndex` (0,0), (size s))
203 | otherwise = Nothing 195 | otherwise = Nothing
204 where
205 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
206 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
207
208 196
209 197
210infixr 8 <> 198infixr 8 <>
@@ -256,7 +244,7 @@ svd (extract -> m) = (mkL u, mkR s', mkL v)
256 where 244 where
257 (u,s,v) = LA.svd m 245 (u,s,v) = LA.svd m
258 s' = vjoin [s, z] 246 s' = vjoin [s, z]
259 z = LA.konst 0 (max 0 (cols m - size s)) 247 z = LA.konst 0 (max 0 (cols m - LA.size s))
260 248
261 249
262svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) 250svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n)
@@ -333,7 +321,7 @@ withCompactSVD
333 -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z) 321 -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z)
334 -> z 322 -> z
335withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f = 323withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f =
336 case someNatVal $ fromIntegral $ size s of 324 case someNatVal $ fromIntegral $ LA.size s of
337 Nothing -> error "static/dynamic mismatch" 325 Nothing -> error "static/dynamic mismatch"
338 Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k) 326 Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k)
339 327
@@ -350,7 +338,7 @@ qr (extract -> x) = (mkL q, mkL r)
350 338
351split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) 339split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p))
352split (extract -> v) = ( mkR (subVector 0 p' v) , 340split (extract -> v) = ( mkR (subVector 0 p' v) ,
353 mkR (subVector p' (size v - p') v) ) 341 mkR (subVector p' (LA.size v - p') v) )
354 where 342 where
355 p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int 343 p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int
356 344
@@ -383,10 +371,9 @@ build
383 :: forall m n . (KnownNat n, KnownNat m) 371 :: forall m n . (KnownNat n, KnownNat m)
384 => (ℝ -> ℝ -> ℝ) 372 => (ℝ -> ℝ -> ℝ)
385 -> L m n 373 -> L m n
386build f = mkL $ LA.build (m',n') f 374build f = r
387 where 375 where
388 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 376 r = mkL $ LA.build (size r) f
389 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
390 377
391-------------------------------------------------------------------------------- 378--------------------------------------------------------------------------------
392 379
@@ -396,7 +383,7 @@ withVector
396 -> (forall n . (KnownNat n) => R n -> z) 383 -> (forall n . (KnownNat n) => R n -> z)
397 -> z 384 -> z
398withVector v f = 385withVector v f =
399 case someNatVal $ fromIntegral $ size v of 386 case someNatVal $ fromIntegral $ LA.size v of
400 Nothing -> error "static/dynamic mismatch" 387 Nothing -> error "static/dynamic mismatch"
401 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) 388 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m)
402 389
@@ -451,19 +438,19 @@ mulR (isKonst -> Just (a,(_,k))) (isKonst -> Just (b,_)) = konst (a * b * fromIn
451mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) 438mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k)
452 where 439 where
453 v = a' * b' 440 v = a' * b'
454 n = min (size a) (size b) 441 n = min (LA.size a) (LA.size b)
455 a' = subVector 0 n a 442 a' = subVector 0 n a
456 b' = subVector 0 n b 443 b' = subVector 0 n b
457 444
458mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (size a) b) 445mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (LA.size a) b)
459 446
460mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) 447mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (LA.size b) a * asRow b)
461 448
462mulR a b = mkL (extract a LA.<> extract b) 449mulR a b = mkL (extract a LA.<> extract b)
463 450
464 451
465appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m 452appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m
466appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (size w) (extract v)) 453appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (LA.size w) (extract v))
467appR m v = mkR (extract m LA.#> extract v) 454appR m v = mkR (extract m LA.#> extract v)
468 455
469 456
@@ -489,19 +476,19 @@ mulC (isKonstC -> Just (a,(_,k))) (isKonstC -> Just (b,_)) = konst (a * b * from
489mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k) 476mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k)
490 where 477 where
491 v = a' * b' 478 v = a' * b'
492 n = min (size a) (size b) 479 n = min (LA.size a) (LA.size b)
493 a' = subVector 0 n a 480 a' = subVector 0 n a
494 b' = subVector 0 n b 481 b' = subVector 0 n b
495 482
496mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (size a) b) 483mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (LA.size a) b)
497 484
498mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (size b) a * asRow b) 485mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (LA.size b) a * asRow b)
499 486
500mulC a b = mkM (extract a LA.<> extract b) 487mulC a b = mkM (extract a LA.<> extract b)
501 488
502 489
503appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m 490appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m
504appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (size w) (extract v)) 491appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (LA.size w) (extract v))
505appC m v = mkC (extract m LA.#> extract v) 492appC m v = mkC (extract m LA.#> extract v)
506 493
507 494
@@ -521,21 +508,21 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3])
521-------------------------------------------------------------------------------- 508--------------------------------------------------------------------------------
522 509
523diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n 510diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n
524diagRectR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) 511diagRectR x v = r
525 where 512 where
513 r = mkL (asRow (vjoin [scalar x, ev, zeros]))
526 ev = extract v 514 ev = extract v
527 zeros = LA.konst x (max 0 ((min m' n') - size ev)) 515 zeros = LA.konst x (max 0 ((min m' n') - LA.size ev))
528 m' = fromIntegral . natVal $ (undefined :: Proxy m) 516 (m',n') = size r
529 n' = fromIntegral . natVal $ (undefined :: Proxy n)
530 517
531 518
532diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n 519diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n
533diagRectC x v = mkM (asRow (vjoin [scalar x, ev, zeros])) 520diagRectC x v = r
534 where 521 where
522 r = mkM (asRow (vjoin [scalar x, ev, zeros]))
535 ev = extract v 523 ev = extract v
536 zeros = LA.konst x (max 0 ((min m' n') - size ev)) 524 zeros = LA.konst x (max 0 ((min m' n') - LA.size ev))
537 m' = fromIntegral . natVal $ (undefined :: Proxy m) 525 (m',n') = size r
538 n' = fromIntegral . natVal $ (undefined :: Proxy n)
539 526
540-------------------------------------------------------------------------------- 527--------------------------------------------------------------------------------
541 528
@@ -578,10 +565,10 @@ test = (ok,info)
578 tm = lmat 0 [1..] 565 tm = lmat 0 [1..]
579 566
580 lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n 567 lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n
581 lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z 568 lmat z xs = r
582 where 569 where
583 m' = fromIntegral . natVal $ (undefined :: Proxy m) 570 r = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z
584 n' = fromIntegral . natVal $ (undefined :: Proxy n) 571 (m',n') = size r
585 572
586 sm :: GSq 573 sm :: GSq
587 sm = lmat 0 [1..] 574 sm = lmat 0 [1..]
@@ -595,7 +582,7 @@ test = (ok,info)
595 m = LA.matrix 3 [1..30] 582 m = LA.matrix 3 [1..30]
596 583
597 precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v 584 precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v
598 precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v 585 precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (LA.size (ud1 u) +1, LA.size (ud1 v)) LA.#> ud1 v
599 586
600 587
601splittest 588splittest
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
index 7968d77..339ef7d 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs
@@ -7,13 +7,10 @@
7{-# LANGUAGE FunctionalDependencies #-} 7{-# LANGUAGE FunctionalDependencies #-}
8{-# LANGUAGE FlexibleContexts #-} 8{-# LANGUAGE FlexibleContexts #-}
9{-# LANGUAGE ScopedTypeVariables #-} 9{-# LANGUAGE ScopedTypeVariables #-}
10{-# LANGUAGE EmptyDataDecls #-}
11{-# LANGUAGE Rank2Types #-} 10{-# LANGUAGE Rank2Types #-}
12{-# LANGUAGE FlexibleInstances #-} 11{-# LANGUAGE FlexibleInstances #-}
13{-# LANGUAGE TypeOperators #-} 12{-# LANGUAGE TypeOperators #-}
14{-# LANGUAGE ViewPatterns #-} 13{-# LANGUAGE ViewPatterns #-}
15{-# LANGUAGE GADTs #-}
16
17 14
18{- | 15{- |
19Module : Numeric.LinearAlgebra.Static.Internal 16Module : Numeric.LinearAlgebra.Static.Internal
@@ -28,7 +25,7 @@ module Numeric.LinearAlgebra.Static.Internal where
28 25
29import GHC.TypeLits 26import GHC.TypeLits
30import qualified Numeric.LinearAlgebra.HMatrix as LA 27import qualified Numeric.LinearAlgebra.HMatrix as LA
31import Numeric.LinearAlgebra.HMatrix hiding (konst) 28import Numeric.LinearAlgebra.HMatrix hiding (konst,size)
32import Data.Packed as D 29import Data.Packed as D
33import Data.Packed.ST 30import Data.Packed.ST
34import Data.Proxy(Proxy) 31import Data.Proxy(Proxy)
@@ -83,7 +80,7 @@ ud :: Dim n (Vector t) -> Vector t
83ud (Dim v) = v 80ud (Dim v) = v
84 81
85mkV :: forall (n :: Nat) t . t -> Dim n t 82mkV :: forall (n :: Nat) t . t -> Dim n t
86mkV = Dim 83mkV = Dim
87 84
88 85
89vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) 86vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
@@ -92,9 +89,9 @@ vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
92 where 89 where
93 du = fromIntegral . natVal $ (undefined :: Proxy n) 90 du = fromIntegral . natVal $ (undefined :: Proxy n)
94 dv = fromIntegral . natVal $ (undefined :: Proxy m) 91 dv = fromIntegral . natVal $ (undefined :: Proxy m)
95 u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du 92 u' | du > 1 && LA.size u == 1 = LA.konst (u D.@> 0) du
96 | otherwise = u 93 | otherwise = u
97 v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv 94 v' | dv > 1 && LA.size v == 1 = LA.konst (v D.@> 0) dv
98 | otherwise = v 95 | otherwise = v
99 96
100 97
@@ -132,7 +129,7 @@ gvect st xs'
132 | otherwise = abort (show xs) 129 | otherwise = abort (show xs)
133 where 130 where
134 (xs,rest) = splitAt d xs' 131 (xs,rest) = splitAt d xs'
135 ok = size v == d && null rest 132 ok = LA.size v == d && null rest
136 v = LA.fromList xs 133 v = LA.fromList xs
137 d = fromIntegral . natVal $ (undefined :: Proxy n) 134 d = fromIntegral . natVal $ (undefined :: Proxy n)
138 abort info = error $ st++" "++show d++" can't be created from elements "++info 135 abort info = error $ st++" "++show d++" can't be created from elements "++info
@@ -153,7 +150,7 @@ gmat st xs'
153 (xs,rest) = splitAt (m'*n') xs' 150 (xs,rest) = splitAt (m'*n') xs'
154 v = LA.fromList xs 151 v = LA.fromList xs
155 x = reshape n' v 152 x = reshape n' v
156 ok = rem (size v) n' == 0 && size x == (m',n') && null rest 153 ok = rem (LA.size v) n' == 0 && LA.size x == (m',n') && null rest
157 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 154 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
158 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 155 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
159 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info 156 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info
@@ -162,66 +159,84 @@ gmat st xs'
162 159
163class Num t => Sized t s d | s -> t, s -> d 160class Num t => Sized t s d | s -> t, s -> d
164 where 161 where
165 konst :: t -> s 162 konst :: t -> s
166 unwrap :: s -> d 163 unwrap :: s -> d t
167 fromList :: [t] -> s 164 fromList :: [t] -> s
168 extract :: s -> d 165 extract :: s -> d t
169 166 create :: d t -> Maybe s
170singleV v = size v == 1 167 size :: s -> IndexOf d
168
169singleV v = LA.size v == 1
171singleM m = rows m == 1 && cols m == 1 170singleM m = rows m == 1 && cols m == 1
172 171
173 172
174instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ) 173instance forall n. KnownNat n => Sized ℂ (C n) Vector
175 where 174 where
175 size _ = fromIntegral . natVal $ (undefined :: Proxy n)
176 konst x = mkC (LA.scalar x) 176 konst x = mkC (LA.scalar x)
177 unwrap (C (Dim v)) = v 177 unwrap (C (Dim v)) = v
178 fromList xs = C (gvect "C" xs) 178 fromList xs = C (gvect "C" xs)
179 extract (unwrap -> v) 179 extract s@(unwrap -> v)
180 | singleV v = LA.konst (v!0) d 180 | singleV v = LA.konst (v!0) (size s)
181 | otherwise = v 181 | otherwise = v
182 where 182 create v
183 d = fromIntegral . natVal $ (undefined :: Proxy n) 183 | LA.size v == size r = Just r
184 | otherwise = Nothing
185 where
186 r = mkC v :: C n
184 187
185 188
186instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) 189instance forall n. KnownNat n => Sized ℝ (R n) Vector
187 where 190 where
191 size _ = fromIntegral . natVal $ (undefined :: Proxy n)
188 konst x = mkR (LA.scalar x) 192 konst x = mkR (LA.scalar x)
189 unwrap (R (Dim v)) = v 193 unwrap (R (Dim v)) = v
190 fromList xs = R (gvect "R" xs) 194 fromList xs = R (gvect "R" xs)
191 extract (unwrap -> v) 195 extract s@(unwrap -> v)
192 | singleV v = LA.konst (v!0) d 196 | singleV v = LA.konst (v!0) (size s)
193 | otherwise = v 197 | otherwise = v
194 where 198 create v
195 d = fromIntegral . natVal $ (undefined :: Proxy n) 199 | LA.size v == size r = Just r
200 | otherwise = Nothing
201 where
202 r = mkR v :: R n
196 203
197 204
198 205
199instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) 206instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) Matrix
200 where 207 where
208 size _ = ((fromIntegral . natVal) (undefined :: Proxy m)
209 ,(fromIntegral . natVal) (undefined :: Proxy n))
201 konst x = mkL (LA.scalar x) 210 konst x = mkL (LA.scalar x)
202 fromList xs = L (gmat "L" xs) 211 fromList xs = L (gmat "L" xs)
203 unwrap (L (Dim (Dim m))) = m 212 unwrap (L (Dim (Dim m))) = m
204 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' 213 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n'
205 extract (unwrap -> a) 214 extract s@(unwrap -> a)
206 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') 215 | singleM a = LA.konst (a `atIndex` (0,0)) (size s)
207 | otherwise = a 216 | otherwise = a
217 create x
218 | LA.size x == size r = Just r
219 | otherwise = Nothing
208 where 220 where
209 m' = fromIntegral . natVal $ (undefined :: Proxy m) 221 r = mkL x :: L m n
210 n' = fromIntegral . natVal $ (undefined :: Proxy n)
211 222
212 223
213instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ) 224instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) Matrix
214 where 225 where
226 size _ = ((fromIntegral . natVal) (undefined :: Proxy m)
227 ,(fromIntegral . natVal) (undefined :: Proxy n))
215 konst x = mkM (LA.scalar x) 228 konst x = mkM (LA.scalar x)
216 fromList xs = M (gmat "M" xs) 229 fromList xs = M (gmat "M" xs)
217 unwrap (M (Dim (Dim m))) = m 230 unwrap (M (Dim (Dim m))) = m
218 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' 231 extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n'
219 extract (unwrap -> a) 232 extract s@(unwrap -> a)
220 | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') 233 | singleM a = LA.konst (a `atIndex` (0,0)) (size s)
221 | otherwise = a 234 | otherwise = a
235 create x
236 | LA.size x == size r = Just r
237 | otherwise = Nothing
222 where 238 where
223 m' = fromIntegral . natVal $ (undefined :: Proxy m) 239 r = mkM x :: M m n
224 n' = fromIntegral . natVal $ (undefined :: Proxy n)
225 240
226-------------------------------------------------------------------------------- 241--------------------------------------------------------------------------------
227 242
@@ -254,8 +269,8 @@ isDiagg (Dim (Dim x))
254 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 269 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
255 v = flatten x 270 v = flatten x
256 z = v `atIndex` 0 271 z = v `atIndex` 0
257 y = subVector 1 (size v-1) v 272 y = subVector 1 (LA.size v-1) v
258 ny = size y 273 ny = LA.size y
259 zeros = LA.konst 0 (max 0 (min m' n' - ny)) 274 zeros = LA.konst 0 (max 0 (min m' n' - ny))
260 yz = vjoin [y,zeros] 275 yz = vjoin [y,zeros]
261 276
@@ -263,39 +278,37 @@ isDiagg (Dim (Dim x))
263 278
264instance forall n . KnownNat n => Show (R n) 279instance forall n . KnownNat n => Show (R n)
265 where 280 where
266 show (R (Dim v)) 281 show s@(R (Dim v))
267 | singleV v = "("++show (v!0)++" :: R "++show d++")" 282 | singleV v = "("++show (v!0)++" :: R "++show d++")"
268 | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" 283 | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")"
269 where 284 where
270 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 285 d = size s
271 286
272instance forall n . KnownNat n => Show (C n) 287instance forall n . KnownNat n => Show (C n)
273 where 288 where
274 show (C (Dim v)) 289 show s@(C (Dim v))
275 | singleV v = "("++show (v!0)++" :: C "++show d++")" 290 | singleV v = "("++show (v!0)++" :: C "++show d++")"
276 | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" 291 | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")"
277 where 292 where
278 d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 293 d = size s
279 294
280instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) 295instance forall m n . (KnownNat m, KnownNat n) => Show (L m n)
281 where 296 where
282 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' 297 show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n'
283 show (L (Dim (Dim x))) 298 show s@(L (Dim (Dim x)))
284 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' 299 | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n'
285 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" 300 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")"
286 where 301 where
287 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 302 (m',n') = size s
288 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
289 303
290instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) 304instance forall m n . (KnownNat m, KnownNat n) => Show (M m n)
291 where 305 where
292 show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' 306 show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n'
293 show (M (Dim (Dim x))) 307 show s@(M (Dim (Dim x)))
294 | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' 308 | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n'
295 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" 309 | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")"
296 where 310 where
297 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 311 (m',n') = size s
298 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
299 312
300-------------------------------------------------------------------------------- 313--------------------------------------------------------------------------------
301 314