diff options
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Real.hs | 231 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 334 |
2 files changed, 282 insertions, 283 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index aa48687..97c462e 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs | |||
@@ -62,28 +62,13 @@ import Numeric.HMatrix hiding ( | |||
62 | import qualified Numeric.HMatrix as LA | 62 | import qualified Numeric.HMatrix as LA |
63 | import Data.Proxy(Proxy) | 63 | import Data.Proxy(Proxy) |
64 | import Numeric.LinearAlgebra.Static | 64 | import Numeric.LinearAlgebra.Static |
65 | import Text.Printf | ||
66 | import Control.Arrow((***)) | 65 | import Control.Arrow((***)) |
67 | 66 | ||
68 | 67 | ||
69 | 𝑖 :: Sized ℂ s c => s | 68 | 𝑖 :: Sized ℂ s c => s |
70 | 𝑖 = konst i_C | 69 | 𝑖 = konst iC |
71 | 70 | ||
72 | instance forall n . KnownNat n => Show (R n) | ||
73 | where | ||
74 | show (ud1 -> v) | ||
75 | | singleV v = "("++show (v!0)++" :: R "++show d++")" | ||
76 | | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" | ||
77 | where | ||
78 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
79 | 71 | ||
80 | instance forall n . KnownNat n => Show (C n) | ||
81 | where | ||
82 | show (C (Dim v)) | ||
83 | | singleV v = "("++show (v!0)++" :: C "++show d++")" | ||
84 | | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" | ||
85 | where | ||
86 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
87 | 72 | ||
88 | 73 | ||
89 | 74 | ||
@@ -91,12 +76,6 @@ ud1 :: R n -> Vector ℝ | |||
91 | ud1 (R (Dim v)) = v | 76 | ud1 (R (Dim v)) = v |
92 | 77 | ||
93 | 78 | ||
94 | mkR :: Vector ℝ -> R n | ||
95 | mkR = R . Dim | ||
96 | |||
97 | mkC :: Vector ℂ -> C n | ||
98 | mkC = C . Dim | ||
99 | |||
100 | infixl 4 & | 79 | infixl 4 & |
101 | (&) :: forall n . KnownNat n | 80 | (&) :: forall n . KnownNat n |
102 | => R n -> ℝ -> R (n+1) | 81 | => R n -> ℝ -> R (n+1) |
@@ -143,95 +122,12 @@ dim = mkR (scalar d) | |||
143 | 122 | ||
144 | -------------------------------------------------------------------------------- | 123 | -------------------------------------------------------------------------------- |
145 | 124 | ||
146 | newtype L m n = L (Dim m (Dim n (Matrix ℝ))) | ||
147 | |||
148 | newtype M m n = M (Dim m (Dim n (Matrix ℂ))) | ||
149 | 125 | ||
150 | ud2 :: L m n -> Matrix ℝ | 126 | ud2 :: L m n -> Matrix ℝ |
151 | ud2 (L (Dim (Dim x))) = x | 127 | ud2 (L (Dim (Dim x))) = x |
152 | 128 | ||
153 | 129 | ||
154 | mkL :: Matrix ℝ -> L m n | ||
155 | mkL x = L (Dim (Dim x)) | ||
156 | |||
157 | mkM :: Matrix ℂ -> M m n | ||
158 | mkM x = M (Dim (Dim x)) | ||
159 | |||
160 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) | ||
161 | where | ||
162 | show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' | ||
163 | show (ud2 -> x) | ||
164 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' | ||
165 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" | ||
166 | where | ||
167 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
168 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
169 | |||
170 | instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) | ||
171 | where | ||
172 | show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' | ||
173 | show (M (Dim (Dim x))) | ||
174 | | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' | ||
175 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" | ||
176 | where | ||
177 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
178 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
179 | |||
180 | |||
181 | -------------------------------------------------------------------------------- | 130 | -------------------------------------------------------------------------------- |
182 | |||
183 | instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ) | ||
184 | where | ||
185 | konst x = mkC (LA.scalar x) | ||
186 | unwrap (C (Dim v)) = v | ||
187 | fromList xs = C (gvect "C" xs) | ||
188 | extract (unwrap -> v) | ||
189 | | singleV v = LA.konst (v!0) d | ||
190 | | otherwise = v | ||
191 | where | ||
192 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
193 | |||
194 | |||
195 | instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) | ||
196 | where | ||
197 | konst x = mkR (LA.scalar x) | ||
198 | unwrap = ud1 | ||
199 | fromList xs = R (gvect "R" xs) | ||
200 | extract (unwrap -> v) | ||
201 | | singleV v = LA.konst (v!0) d | ||
202 | | otherwise = v | ||
203 | where | ||
204 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
205 | |||
206 | |||
207 | |||
208 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) | ||
209 | where | ||
210 | konst x = mkL (LA.scalar x) | ||
211 | fromList xs = L (gmat "L" xs) | ||
212 | unwrap = ud2 | ||
213 | extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' | ||
214 | extract (unwrap -> a) | ||
215 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | ||
216 | | otherwise = a | ||
217 | where | ||
218 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
219 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
220 | |||
221 | |||
222 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ) | ||
223 | where | ||
224 | konst x = mkM (LA.scalar x) | ||
225 | fromList xs = M (gmat "M" xs) | ||
226 | unwrap (M (Dim (Dim m))) = m | ||
227 | extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' | ||
228 | extract (unwrap -> a) | ||
229 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | ||
230 | | otherwise = a | ||
231 | where | ||
232 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
233 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
234 | |||
235 | -------------------------------------------------------------------------------- | 131 | -------------------------------------------------------------------------------- |
236 | 132 | ||
237 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n | 133 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n |
@@ -269,41 +165,6 @@ blockAt x r c a = mkL res | |||
269 | 165 | ||
270 | -------------------------------------------------------------------------------- | 166 | -------------------------------------------------------------------------------- |
271 | 167 | ||
272 | class Disp t | ||
273 | where | ||
274 | disp :: Int -> t -> IO () | ||
275 | |||
276 | |||
277 | instance (KnownNat m, KnownNat n) => Disp (L m n) | ||
278 | where | ||
279 | disp n x = do | ||
280 | let a = extract x | ||
281 | let su = LA.dispf n a | ||
282 | printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) | ||
283 | |||
284 | instance (KnownNat m, KnownNat n) => Disp (M m n) | ||
285 | where | ||
286 | disp n x = do | ||
287 | let a = extract x | ||
288 | let su = LA.dispcf n a | ||
289 | printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) | ||
290 | |||
291 | |||
292 | instance KnownNat n => Disp (R n) | ||
293 | where | ||
294 | disp n v = do | ||
295 | let su = LA.dispf n (asRow $ extract v) | ||
296 | putStr "R " >> putStr (tail . dropWhile (/='x') $ su) | ||
297 | |||
298 | instance KnownNat n => Disp (C n) | ||
299 | where | ||
300 | disp n v = do | ||
301 | let su = LA.dispcf n (asRow $ extract v) | ||
302 | putStr "C " >> putStr (tail . dropWhile (/='x') $ su) | ||
303 | |||
304 | |||
305 | -------------------------------------------------------------------------------- | ||
306 | |||
307 | 168 | ||
308 | row :: R n -> L 1 n | 169 | row :: R n -> L 1 n |
309 | row = mkL . asRow . ud1 | 170 | row = mkL . asRow . ud1 |
@@ -344,28 +205,6 @@ isKonst (unwrap -> x) | |||
344 | 205 | ||
345 | 206 | ||
346 | 207 | ||
347 | isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) | ||
348 | isDiag (L x) = isDiagg x | ||
349 | |||
350 | isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int)) | ||
351 | isDiagC (M x) = isDiagg x | ||
352 | |||
353 | |||
354 | isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int)) | ||
355 | isDiagg (Dim (Dim x)) | ||
356 | | singleM x = Nothing | ||
357 | | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) | ||
358 | | otherwise = Nothing | ||
359 | where | ||
360 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
361 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
362 | v = flatten x | ||
363 | z = v `atIndex` 0 | ||
364 | y = subVector 1 (size v-1) v | ||
365 | ny = size y | ||
366 | zeros = LA.konst 0 (max 0 (min m' n' - ny)) | ||
367 | yz = vjoin [y,zeros] | ||
368 | |||
369 | 208 | ||
370 | infixr 8 <> | 209 | infixr 8 <> |
371 | (<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n | 210 | (<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n |
@@ -397,74 +236,6 @@ infixr 8 <·> | |||
397 | | singleV u || singleV v = sumElements (u * v) | 236 | | singleV u || singleV v = sumElements (u * v) |
398 | | otherwise = udot u v | 237 | | otherwise = udot u v |
399 | 238 | ||
400 | |||
401 | instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m) | ||
402 | where | ||
403 | tr a@(isDiag -> Just _) = mkL (extract a) | ||
404 | tr (extract -> a) = mkL (tr a) | ||
405 | |||
406 | instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m) | ||
407 | where | ||
408 | tr a@(isDiagC -> Just _) = mkM (extract a) | ||
409 | tr (extract -> a) = mkM (tr a) | ||
410 | |||
411 | |||
412 | -------------------------------------------------------------------------------- | ||
413 | |||
414 | adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b | ||
415 | adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b)) | ||
416 | adaptDiag f a b = f a b | ||
417 | |||
418 | isFull m = isDiag m == Nothing && not (singleM (unwrap m)) | ||
419 | |||
420 | |||
421 | lift1L f (L v) = L (f v) | ||
422 | lift2L f (L a) (L b) = L (f a b) | ||
423 | lift2LD f = adaptDiag (lift2L f) | ||
424 | |||
425 | |||
426 | instance (KnownNat n, KnownNat m) => Num (L n m) | ||
427 | where | ||
428 | (+) = lift2LD (+) | ||
429 | (*) = lift2LD (*) | ||
430 | (-) = lift2LD (-) | ||
431 | abs = lift1L abs | ||
432 | signum = lift1L signum | ||
433 | negate = lift1L negate | ||
434 | fromInteger = L . Dim . Dim . fromInteger | ||
435 | |||
436 | instance (KnownNat n, KnownNat m) => Fractional (L n m) | ||
437 | where | ||
438 | fromRational = L . Dim . Dim . fromRational | ||
439 | (/) = lift2LD (/) | ||
440 | |||
441 | -------------------------------------------------------------------------------- | ||
442 | |||
443 | adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b | ||
444 | adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b)) | ||
445 | adaptDiagC f a b = f a b | ||
446 | |||
447 | isFullC m = isDiagC m == Nothing && not (singleM (unwrap m)) | ||
448 | |||
449 | lift1M f (M v) = M (f v) | ||
450 | lift2M f (M a) (M b) = M (f a b) | ||
451 | lift2MD f = adaptDiagC (lift2M f) | ||
452 | |||
453 | instance (KnownNat n, KnownNat m) => Num (M n m) | ||
454 | where | ||
455 | (+) = lift2MD (+) | ||
456 | (*) = lift2MD (*) | ||
457 | (-) = lift2MD (-) | ||
458 | abs = lift1M abs | ||
459 | signum = lift1M signum | ||
460 | negate = lift1M negate | ||
461 | fromInteger = M . Dim . Dim . fromInteger | ||
462 | |||
463 | instance (KnownNat n, KnownNat m) => Fractional (M n m) | ||
464 | where | ||
465 | fromRational = M . Dim . Dim . fromRational | ||
466 | (/) = lift2MD (/) | ||
467 | |||
468 | -------------------------------------------------------------------------------- | 239 | -------------------------------------------------------------------------------- |
469 | 240 | ||
470 | {- | 241 | {- |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 6acd9a3..2647f20 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -21,14 +21,7 @@ Stability : provisional | |||
21 | 21 | ||
22 | -} | 22 | -} |
23 | 23 | ||
24 | module Numeric.LinearAlgebra.Static( | 24 | module Numeric.LinearAlgebra.Static where |
25 | Dim(..), | ||
26 | R(..), C(..), | ||
27 | lift1F, lift2F, | ||
28 | vconcat, gvec2, gvec3, gvec4, gvect, gmat, | ||
29 | Sized(..), | ||
30 | singleV, singleM,GM | ||
31 | ) where | ||
32 | 25 | ||
33 | 26 | ||
34 | import GHC.TypeLits | 27 | import GHC.TypeLits |
@@ -37,17 +30,9 @@ import Data.Packed as D | |||
37 | import Data.Packed.ST | 30 | import Data.Packed.ST |
38 | import Data.Proxy(Proxy) | 31 | import Data.Proxy(Proxy) |
39 | import Foreign.Storable(Storable) | 32 | import Foreign.Storable(Storable) |
33 | import Text.Printf | ||
40 | 34 | ||
41 | 35 | -------------------------------------------------------------------------------- | |
42 | |||
43 | newtype R n = R (Dim n (Vector ℝ)) | ||
44 | deriving (Num,Fractional) | ||
45 | |||
46 | |||
47 | newtype C n = C (Dim n (Vector ℂ)) | ||
48 | deriving (Num,Fractional) | ||
49 | |||
50 | |||
51 | 36 | ||
52 | newtype Dim (n :: Nat) t = Dim t | 37 | newtype Dim (n :: Nat) t = Dim t |
53 | deriving Show | 38 | deriving Show |
@@ -64,36 +49,28 @@ lift2F f (Dim u) (Dim v) = Dim (f u v) | |||
64 | 49 | ||
65 | -------------------------------------------------------------------------------- | 50 | -------------------------------------------------------------------------------- |
66 | 51 | ||
67 | instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) | 52 | newtype R n = R (Dim n (Vector ℝ)) |
68 | where | 53 | deriving (Num,Fractional) |
69 | (+) = lift2F (+) | ||
70 | (*) = lift2F (*) | ||
71 | (-) = lift2F (-) | ||
72 | abs = lift1F abs | ||
73 | signum = lift1F signum | ||
74 | negate = lift1F negate | ||
75 | fromInteger x = Dim (fromInteger x) | ||
76 | 54 | ||
77 | instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t)) | 55 | newtype C n = C (Dim n (Vector ℂ)) |
78 | where | 56 | deriving (Num,Fractional) |
79 | fromRational x = Dim (fromRational x) | ||
80 | (/) = lift2F (/) | ||
81 | 57 | ||
58 | newtype L m n = L (Dim m (Dim n (Matrix ℝ))) | ||
82 | 59 | ||
83 | instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) | 60 | newtype M m n = M (Dim m (Dim n (Matrix ℂ))) |
84 | where | ||
85 | (+) = (lift2F . lift2F) (+) | ||
86 | (*) = (lift2F . lift2F) (*) | ||
87 | (-) = (lift2F . lift2F) (-) | ||
88 | abs = (lift1F . lift1F) abs | ||
89 | signum = (lift1F . lift1F) signum | ||
90 | negate = (lift1F . lift1F) negate | ||
91 | fromInteger x = Dim (Dim (fromInteger x)) | ||
92 | 61 | ||
93 | instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t))) | 62 | |
94 | where | 63 | mkR :: Vector ℝ -> R n |
95 | fromRational x = Dim (Dim (fromRational x)) | 64 | mkR = R . Dim |
96 | (/) = (lift2F.lift2F) (/) | 65 | |
66 | mkC :: Vector ℂ -> C n | ||
67 | mkC = C . Dim | ||
68 | |||
69 | mkL :: Matrix ℝ -> L m n | ||
70 | mkL x = L (Dim (Dim x)) | ||
71 | |||
72 | mkM :: Matrix ℂ -> M m n | ||
73 | mkM x = M (Dim (Dim x)) | ||
97 | 74 | ||
98 | -------------------------------------------------------------------------------- | 75 | -------------------------------------------------------------------------------- |
99 | 76 | ||
@@ -105,14 +82,6 @@ ud (Dim v) = v | |||
105 | mkV :: forall (n :: Nat) t . t -> Dim n t | 82 | mkV :: forall (n :: Nat) t . t -> Dim n t |
106 | mkV = Dim | 83 | mkV = Dim |
107 | 84 | ||
108 | type GM m n t = Dim m (Dim n (Matrix t)) | ||
109 | |||
110 | --ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t | ||
111 | --ud2 (Dim (Dim m)) = m | ||
112 | |||
113 | mkM :: forall (m :: Nat) (n :: Nat) t . t -> Dim m (Dim n t) | ||
114 | mkM = Dim . Dim | ||
115 | |||
116 | 85 | ||
117 | vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) | 86 | vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) |
118 | => V n t -> V m t -> V (n+m) t | 87 | => V n t -> V m t -> V (n+m) t |
@@ -166,9 +135,14 @@ gvect st xs' | |||
166 | abort info = error $ st++" "++show d++" can't be created from elements "++info | 135 | abort info = error $ st++" "++show d++" can't be created from elements "++info |
167 | 136 | ||
168 | 137 | ||
138 | -------------------------------------------------------------------------------- | ||
139 | |||
140 | type GM m n t = Dim m (Dim n (Matrix t)) | ||
141 | |||
142 | |||
169 | gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t | 143 | gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t |
170 | gmat st xs' | 144 | gmat st xs' |
171 | | ok = mkM x | 145 | | ok = Dim (Dim x) |
172 | | not (null rest) && null (tail rest) = abort (show xs') | 146 | | not (null rest) && null (tail rest) = abort (show xs') |
173 | | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") | 147 | | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") |
174 | | otherwise = abort (show xs) | 148 | | otherwise = abort (show xs) |
@@ -181,6 +155,7 @@ gmat st xs' | |||
181 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 155 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
182 | abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info | 156 | abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info |
183 | 157 | ||
158 | -------------------------------------------------------------------------------- | ||
184 | 159 | ||
185 | class Num t => Sized t s d | s -> t, s -> d | 160 | class Num t => Sized t s d | s -> t, s -> d |
186 | where | 161 | where |
@@ -192,3 +167,256 @@ class Num t => Sized t s d | s -> t, s -> d | |||
192 | singleV v = size v == 1 | 167 | singleV v = size v == 1 |
193 | singleM m = rows m == 1 && cols m == 1 | 168 | singleM m = rows m == 1 && cols m == 1 |
194 | 169 | ||
170 | |||
171 | instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ) | ||
172 | where | ||
173 | konst x = mkC (LA.scalar x) | ||
174 | unwrap (C (Dim v)) = v | ||
175 | fromList xs = C (gvect "C" xs) | ||
176 | extract (unwrap -> v) | ||
177 | | singleV v = LA.konst (v!0) d | ||
178 | | otherwise = v | ||
179 | where | ||
180 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
181 | |||
182 | |||
183 | instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) | ||
184 | where | ||
185 | konst x = mkR (LA.scalar x) | ||
186 | unwrap (R (Dim v)) = v | ||
187 | fromList xs = R (gvect "R" xs) | ||
188 | extract (unwrap -> v) | ||
189 | | singleV v = LA.konst (v!0) d | ||
190 | | otherwise = v | ||
191 | where | ||
192 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
193 | |||
194 | |||
195 | |||
196 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) | ||
197 | where | ||
198 | konst x = mkL (LA.scalar x) | ||
199 | fromList xs = L (gmat "L" xs) | ||
200 | unwrap (L (Dim (Dim m))) = m | ||
201 | extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' | ||
202 | extract (unwrap -> a) | ||
203 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | ||
204 | | otherwise = a | ||
205 | where | ||
206 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
207 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
208 | |||
209 | |||
210 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ) | ||
211 | where | ||
212 | konst x = mkM (LA.scalar x) | ||
213 | fromList xs = M (gmat "M" xs) | ||
214 | unwrap (M (Dim (Dim m))) = m | ||
215 | extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' | ||
216 | extract (unwrap -> a) | ||
217 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | ||
218 | | otherwise = a | ||
219 | where | ||
220 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
221 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
222 | |||
223 | -------------------------------------------------------------------------------- | ||
224 | |||
225 | instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m) | ||
226 | where | ||
227 | tr a@(isDiag -> Just _) = mkL (extract a) | ||
228 | tr (extract -> a) = mkL (tr a) | ||
229 | |||
230 | instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m) | ||
231 | where | ||
232 | tr a@(isDiagC -> Just _) = mkM (extract a) | ||
233 | tr (extract -> a) = mkM (tr a) | ||
234 | |||
235 | -------------------------------------------------------------------------------- | ||
236 | |||
237 | isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) | ||
238 | isDiag (L x) = isDiagg x | ||
239 | |||
240 | isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int)) | ||
241 | isDiagC (M x) = isDiagg x | ||
242 | |||
243 | |||
244 | isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int)) | ||
245 | isDiagg (Dim (Dim x)) | ||
246 | | singleM x = Nothing | ||
247 | | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) | ||
248 | | otherwise = Nothing | ||
249 | where | ||
250 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
251 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
252 | v = flatten x | ||
253 | z = v `atIndex` 0 | ||
254 | y = subVector 1 (size v-1) v | ||
255 | ny = size y | ||
256 | zeros = LA.konst 0 (max 0 (min m' n' - ny)) | ||
257 | yz = vjoin [y,zeros] | ||
258 | |||
259 | -------------------------------------------------------------------------------- | ||
260 | |||
261 | instance forall n . KnownNat n => Show (R n) | ||
262 | where | ||
263 | show (R (Dim v)) | ||
264 | | singleV v = "("++show (v!0)++" :: R "++show d++")" | ||
265 | | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" | ||
266 | where | ||
267 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
268 | |||
269 | instance forall n . KnownNat n => Show (C n) | ||
270 | where | ||
271 | show (C (Dim v)) | ||
272 | | singleV v = "("++show (v!0)++" :: C "++show d++")" | ||
273 | | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" | ||
274 | where | ||
275 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
276 | |||
277 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) | ||
278 | where | ||
279 | show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' | ||
280 | show (L (Dim (Dim x))) | ||
281 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' | ||
282 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" | ||
283 | where | ||
284 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
285 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
286 | |||
287 | instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) | ||
288 | where | ||
289 | show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' | ||
290 | show (M (Dim (Dim x))) | ||
291 | | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' | ||
292 | | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" | ||
293 | where | ||
294 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
295 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
296 | |||
297 | -------------------------------------------------------------------------------- | ||
298 | |||
299 | instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) | ||
300 | where | ||
301 | (+) = lift2F (+) | ||
302 | (*) = lift2F (*) | ||
303 | (-) = lift2F (-) | ||
304 | abs = lift1F abs | ||
305 | signum = lift1F signum | ||
306 | negate = lift1F negate | ||
307 | fromInteger x = Dim (fromInteger x) | ||
308 | |||
309 | instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t)) | ||
310 | where | ||
311 | fromRational x = Dim (fromRational x) | ||
312 | (/) = lift2F (/) | ||
313 | |||
314 | |||
315 | instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) | ||
316 | where | ||
317 | (+) = (lift2F . lift2F) (+) | ||
318 | (*) = (lift2F . lift2F) (*) | ||
319 | (-) = (lift2F . lift2F) (-) | ||
320 | abs = (lift1F . lift1F) abs | ||
321 | signum = (lift1F . lift1F) signum | ||
322 | negate = (lift1F . lift1F) negate | ||
323 | fromInteger x = Dim (Dim (fromInteger x)) | ||
324 | |||
325 | instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t))) | ||
326 | where | ||
327 | fromRational x = Dim (Dim (fromRational x)) | ||
328 | (/) = (lift2F.lift2F) (/) | ||
329 | |||
330 | -------------------------------------------------------------------------------- | ||
331 | |||
332 | |||
333 | adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b | ||
334 | adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b)) | ||
335 | adaptDiag f a b = f a b | ||
336 | |||
337 | isFull m = isDiag m == Nothing && not (singleM (unwrap m)) | ||
338 | |||
339 | |||
340 | lift1L f (L v) = L (f v) | ||
341 | lift2L f (L a) (L b) = L (f a b) | ||
342 | lift2LD f = adaptDiag (lift2L f) | ||
343 | |||
344 | |||
345 | instance (KnownNat n, KnownNat m) => Num (L n m) | ||
346 | where | ||
347 | (+) = lift2LD (+) | ||
348 | (*) = lift2LD (*) | ||
349 | (-) = lift2LD (-) | ||
350 | abs = lift1L abs | ||
351 | signum = lift1L signum | ||
352 | negate = lift1L negate | ||
353 | fromInteger = L . Dim . Dim . fromInteger | ||
354 | |||
355 | instance (KnownNat n, KnownNat m) => Fractional (L n m) | ||
356 | where | ||
357 | fromRational = L . Dim . Dim . fromRational | ||
358 | (/) = lift2LD (/) | ||
359 | |||
360 | -------------------------------------------------------------------------------- | ||
361 | |||
362 | adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b | ||
363 | adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b)) | ||
364 | adaptDiagC f a b = f a b | ||
365 | |||
366 | isFullC m = isDiagC m == Nothing && not (singleM (unwrap m)) | ||
367 | |||
368 | lift1M f (M v) = M (f v) | ||
369 | lift2M f (M a) (M b) = M (f a b) | ||
370 | lift2MD f = adaptDiagC (lift2M f) | ||
371 | |||
372 | instance (KnownNat n, KnownNat m) => Num (M n m) | ||
373 | where | ||
374 | (+) = lift2MD (+) | ||
375 | (*) = lift2MD (*) | ||
376 | (-) = lift2MD (-) | ||
377 | abs = lift1M abs | ||
378 | signum = lift1M signum | ||
379 | negate = lift1M negate | ||
380 | fromInteger = M . Dim . Dim . fromInteger | ||
381 | |||
382 | instance (KnownNat n, KnownNat m) => Fractional (M n m) | ||
383 | where | ||
384 | fromRational = M . Dim . Dim . fromRational | ||
385 | (/) = lift2MD (/) | ||
386 | |||
387 | -------------------------------------------------------------------------------- | ||
388 | |||
389 | |||
390 | class Disp t | ||
391 | where | ||
392 | disp :: Int -> t -> IO () | ||
393 | |||
394 | |||
395 | instance (KnownNat m, KnownNat n) => Disp (L m n) | ||
396 | where | ||
397 | disp n x = do | ||
398 | let a = extract x | ||
399 | let su = LA.dispf n a | ||
400 | printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) | ||
401 | |||
402 | instance (KnownNat m, KnownNat n) => Disp (M m n) | ||
403 | where | ||
404 | disp n x = do | ||
405 | let a = extract x | ||
406 | let su = LA.dispcf n a | ||
407 | printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) | ||
408 | |||
409 | |||
410 | instance KnownNat n => Disp (R n) | ||
411 | where | ||
412 | disp n v = do | ||
413 | let su = LA.dispf n (asRow $ extract v) | ||
414 | putStr "R " >> putStr (tail . dropWhile (/='x') $ su) | ||
415 | |||
416 | instance KnownNat n => Disp (C n) | ||
417 | where | ||
418 | disp n v = do | ||
419 | let su = LA.dispcf n (asRow $ extract v) | ||
420 | putStr "C " >> putStr (tail . dropWhile (/='x') $ su) | ||
421 | |||
422 | -------------------------------------------------------------------------------- | ||