summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/tmpStatic.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/tmpStatic.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/tmpStatic.hs619
1 files changed, 619 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/tmpStatic.hs b/packages/base/src/Numeric/LinearAlgebra/tmpStatic.hs
new file mode 100644
index 0000000..4258d6b
--- /dev/null
+++ b/packages/base/src/Numeric/LinearAlgebra/tmpStatic.hs
@@ -0,0 +1,619 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE FlexibleContexts #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE EmptyDataDecls #-}
9{-# LANGUAGE Rank2Types #-}
10{-# LANGUAGE FlexibleInstances #-}
11{-# LANGUAGE TypeOperators #-}
12{-# LANGUAGE ViewPatterns #-}
13{-# LANGUAGE GADTs #-}
14{-# LANGUAGE OverlappingInstances #-}
15{-# LANGUAGE TypeFamilies #-}
16
17
18{- |
19Module : Numeric.LinearAlgebra.Static
20Copyright : (c) Alberto Ruiz 2014
21License : BSD3
22Stability : experimental
23
24Experimental interface with statically checked dimensions.
25
26-}
27
28module Numeric.LinearAlgebra.Static(
29 -- * Vector
30 โ„, R,
31 vec2, vec3, vec4, (&), (#), split, headTail,
32 vector,
33 linspace, range, dim,
34 -- * Matrix
35 L, Sq, build,
36 row, col, (ยฆ),(โ€”โ€”), splitRows, splitCols,
37 unrow, uncol,
38 tr,
39 eye,
40 diag,
41 blockAt,
42 matrix,
43 -- * Complex
44 C, M, Her, her, ๐‘–,
45 -- * Products
46 (<>),(#>),(<ยท>),
47 -- * Linear Systems
48 linSolve, (<\>),
49 -- * Factorizations
50 svd, withCompactSVD, svdTall, svdFlat, Eigen(..),
51 withNullspace, qr,
52 -- * Misc
53 mean,
54 Disp(..), Domain(..),
55 withVector, withMatrix,
56 toRows, toColumns,
57 Sized(..), Diag(..), Sym, sym, mTm, unSym
58) where
59
60
61import GHC.TypeLits
62import Numeric.LinearAlgebra.HMatrix hiding (
63 (<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โ€”โ€”),row,col,vector,matrix,linspace,toRows,toColumns,
64 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build,
65 qr)
66import qualified Numeric.LinearAlgebra.HMatrix as LA
67import Data.Proxy(Proxy)
68import Numeric.LinearAlgebra.Static.Internal
69import Control.Arrow((***))
70
71
72
73
74
75ud1 :: R n -> Vector โ„
76ud1 (R (Dim v)) = v
77
78
79infixl 4 &
80(&) :: forall n . (KnownNat n, 1 <= n)
81 => R n -> โ„ -> R (n+1)
82u & x = u # (konst x :: R 1)
83
84infixl 4 #
85(#) :: forall n m . (KnownNat n, KnownNat m)
86 => R n -> R m -> R (n+m)
87(R u) # (R v) = R (vconcat u v)
88
89
90
91vec2 :: โ„ -> โ„ -> R 2
92vec2 a b = R (gvec2 a b)
93
94vec3 :: โ„ -> โ„ -> โ„ -> R 3
95vec3 a b c = R (gvec3 a b c)
96
97
98vec4 :: โ„ -> โ„ -> โ„ -> โ„ -> R 4
99vec4 a b c d = R (gvec4 a b c d)
100
101vector :: KnownNat n => [โ„] -> R n
102vector = fromList
103
104matrix :: (KnownNat m, KnownNat n) => [โ„] -> L m n
105matrix = fromList
106
107linspace :: forall n . KnownNat n => (โ„,โ„) -> R n
108linspace (a,b) = mkR (LA.linspace d (a,b))
109 where
110 d = fromIntegral . natVal $ (undefined :: Proxy n)
111
112range :: forall n . KnownNat n => R n
113range = mkR (LA.linspace d (1,fromIntegral d))
114 where
115 d = fromIntegral . natVal $ (undefined :: Proxy n)
116
117dim :: forall n . KnownNat n => R n
118dim = mkR (scalar d)
119 where
120 d = fromIntegral . natVal $ (undefined :: Proxy n)
121
122
123--------------------------------------------------------------------------------
124
125
126ud2 :: L m n -> Matrix โ„
127ud2 (L (Dim (Dim x))) = x
128
129
130--------------------------------------------------------------------------------
131
132diag :: KnownNat n => R n -> Sq n
133diag = diagR 0
134
135eye :: KnownNat n => Sq n
136eye = diag 1
137
138--------------------------------------------------------------------------------
139
140blockAt :: forall m n . (KnownNat m, KnownNat n) => โ„ -> Int -> Int -> Matrix Double -> L m n
141blockAt x r c a = mkL res
142 where
143 z = scalar x
144 z1 = LA.konst x (r,c)
145 z2 = LA.konst x (max 0 (m'-(ra+r)), max 0 (n'-(ca+c)))
146 ra = min (rows a) . max 0 $ m'-r
147 ca = min (cols a) . max 0 $ n'-c
148 sa = subMatrix (0,0) (ra, ca) a
149 m' = fromIntegral . natVal $ (undefined :: Proxy m)
150 n' = fromIntegral . natVal $ (undefined :: Proxy n)
151 res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]]
152
153
154
155
156
157--------------------------------------------------------------------------------
158
159
160row :: R n -> L 1 n
161row = mkL . asRow . ud1
162
163--col :: R n -> L n 1
164col v = tr . row $ v
165
166unrow :: L 1 n -> R n
167unrow = mkR . head . LA.toRows . ud2
168
169--uncol :: L n 1 -> R n
170uncol v = unrow . tr $ v
171
172
173infixl 2 โ€”โ€”
174(โ€”โ€”) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c
175a โ€”โ€” b = mkL (extract a LA.โ€”โ€” extract b)
176
177
178infixl 3 ยฆ
179-- (ยฆ) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2)
180a ยฆ b = tr (tr a โ€”โ€” tr b)
181
182
183type Sq n = L n n
184--type CSq n = CL n n
185
186type GL = (KnownNat n, KnownNat m) => L m n
187type GSq = KnownNat n => Sq n
188
189isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (โ„,(Int,Int))
190isKonst (unwrap -> x)
191 | singleM x = Just (x `atIndex` (0,0), (m',n'))
192 | otherwise = Nothing
193 where
194 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
195 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
196
197
198isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (โ„‚,(Int,Int))
199isKonstC (unwrap -> x)
200 | singleM x = Just (x `atIndex` (0,0), (m',n'))
201 | otherwise = Nothing
202 where
203 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
204 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
205
206
207
208infixr 8 <>
209(<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n
210(<>) = mulR
211
212
213infixr 8 #>
214(#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m
215(#>) = appR
216
217
218infixr 8 <ยท>
219(<ยท>) :: R n -> R n -> โ„
220(<ยท>) = dotR
221
222--------------------------------------------------------------------------------
223
224class Diag m d | m -> d
225 where
226 takeDiag :: m -> d
227
228
229instance forall n . (KnownNat n) => Diag (L n n) (R n)
230 where
231 takeDiag m = mkR (LA.takeDiag (extract m))
232
233
234instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m)
235 where
236 takeDiag m = mkR (LA.takeDiag (extract m))
237
238
239instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n)
240 where
241 takeDiag m = mkR (LA.takeDiag (extract m))
242
243
244--------------------------------------------------------------------------------
245
246linSolve :: (KnownNat m, KnownNat n) => L m m -> L m n -> Maybe (L m n)
247linSolve (extract -> a) (extract -> b) = fmap mkL (LA.linearSolve a b)
248
249(<\>) :: (KnownNat m, KnownNat n, KnownNat r) => L m n -> L m r -> L n r
250(extract -> a) <\> (extract -> b) = mkL (a LA.<\> b)
251
252svd :: (KnownNat m, KnownNat n) => L m n -> (L m m, R n, L n n)
253svd (extract -> m) = (mkL u, mkR s', mkL v)
254 where
255 (u,s,v) = LA.svd m
256 s' = vjoin [s, z]
257 z = LA.konst 0 (max 0 (cols m - size s))
258
259
260svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n)
261svdTall (extract -> m) = (mkL u, mkR s, mkL v)
262 where
263 (u,s,v) = LA.thinSVD m
264
265
266svdFlat :: (KnownNat m, KnownNat n, m <= n) => L m n -> (L m m, R m, L n m)
267svdFlat (extract -> m) = (mkL u, mkR s, mkL v)
268 where
269 (u,s,v) = LA.thinSVD m
270
271--------------------------------------------------------------------------------
272
273class Eigen m l v | m -> l, m -> v
274 where
275 eigensystem :: m -> (l,v)
276 eigenvalues :: m -> l
277
278newtype Sym n = Sym (Sq n) deriving Show
279
280
281sym :: KnownNat n => Sq n -> Sym n
282sym m = Sym $ (m + tr m)/2
283
284mTm :: (KnownNat m, KnownNat n) => L m n -> Sym n
285mTm x = Sym (tr x <> x)
286
287unSym :: Sym n -> Sq n
288unSym (Sym x) = x
289
290
291๐‘– :: Sized โ„‚ s c => s
292๐‘– = konst iC
293
294newtype Her n = Her (M n n)
295
296her :: KnownNat n => M n n -> Her n
297her m = Her $ (m + LA.tr m)/2
298
299
300
301instance KnownNat n => Eigen (Sym n) (R n) (L n n)
302 where
303 eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m
304 eigensystem (Sym (extract -> m)) = (mkR l, mkL v)
305 where
306 (l,v) = LA.eigSH' m
307
308instance KnownNat n => Eigen (Sq n) (C n) (M n n)
309 where
310 eigenvalues (extract -> m) = mkC . LA.eigenvalues $ m
311 eigensystem (extract -> m) = (mkC l, mkM v)
312 where
313 (l,v) = LA.eig m
314
315--------------------------------------------------------------------------------
316
317withNullspace
318 :: forall m n z . (KnownNat m, KnownNat n)
319 => L m n
320 -> (forall k . (KnownNat k) => L n k -> z)
321 -> z
322withNullspace (LA.nullspace . extract -> a) f =
323 case someNatVal $ fromIntegral $ cols a of
324 Nothing -> error "static/dynamic mismatch"
325 Just (SomeNat (_ :: Proxy k)) -> f (mkL a :: L n k)
326
327
328withCompactSVD
329 :: forall m n z . (KnownNat m, KnownNat n)
330 => L m n
331 -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z)
332 -> z
333withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f =
334 case someNatVal $ fromIntegral $ size s of
335 Nothing -> error "static/dynamic mismatch"
336 Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k)
337
338--------------------------------------------------------------------------------
339
340qr :: (KnownNat m, KnownNat n) => L m n -> (L m m, L m n)
341qr (extract -> x) = (mkL q, mkL r)
342 where
343 (q,r) = LA.qr x
344
345-- use qrRaw?
346
347--------------------------------------------------------------------------------
348
349split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p))
350split (extract -> v) = ( mkR (subVector 0 p' v) ,
351 mkR (subVector p' (size v - p') v) )
352 where
353 p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int
354
355
356headTail :: (KnownNat n, 1<=n) => R n -> (โ„, R (n-1))
357headTail = ((!0) . extract *** id) . split
358
359
360splitRows :: forall p m n . (KnownNat p, KnownNat m, KnownNat n, p<=m) => L m n -> (L p n, L (m-p) n)
361splitRows (extract -> x) = ( mkL (takeRows p' x) ,
362 mkL (dropRows p' x) )
363 where
364 p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int
365
366splitCols :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, KnownNat (n-p), p<=n) => L m n -> (L m p, L m (n-p))
367splitCols = (tr *** tr) . splitRows . tr
368
369
370toRows :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R n]
371toRows (LA.toRows . extract -> vs) = map mkR vs
372
373
374toColumns :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R m]
375toColumns (LA.toColumns . extract -> vs) = map mkR vs
376
377
378--------------------------------------------------------------------------------
379
380build
381 :: forall m n . (KnownNat n, KnownNat m)
382 => (โ„ -> โ„ -> โ„)
383 -> L m n
384build f = mkL $ LA.build (m',n') f
385 where
386 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
387 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
388
389--------------------------------------------------------------------------------
390
391withVector
392 :: forall z
393 . Vector โ„
394 -> (forall n . (KnownNat n) => R n -> z)
395 -> z
396withVector v f =
397 case someNatVal $ fromIntegral $ size v of
398 Nothing -> error "static/dynamic mismatch"
399 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m)
400
401
402withMatrix
403 :: forall z
404 . Matrix โ„
405 -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z)
406 -> z
407withMatrix a f =
408 case someNatVal $ fromIntegral $ rows a of
409 Nothing -> error "static/dynamic mismatch"
410 Just (SomeNat (_ :: Proxy m)) ->
411 case someNatVal $ fromIntegral $ cols a of
412 Nothing -> error "static/dynamic mismatch"
413 Just (SomeNat (_ :: Proxy n)) ->
414 f (mkL a :: L m n)
415
416--------------------------------------------------------------------------------
417
418class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat vec
419 where
420 mul :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => mat m k -> mat k n -> mat m n
421 app :: forall m n . (KnownNat m, KnownNat n) => mat m n -> vec n -> vec m
422 dot :: forall n . (KnownNat n) => vec n -> vec n -> field
423 cross :: vec 3 -> vec 3 -> vec 3
424 diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n
425
426
427instance Domain โ„ R L
428 where
429 mul = mulR
430 app = appR
431 dot = dotR
432 cross = crossR
433 diagR = diagRectR
434
435instance Domain โ„‚ C M
436 where
437 mul = mulC
438 app = appC
439 dot = dotC
440 cross = crossC
441 diagR = diagRectC
442
443--------------------------------------------------------------------------------
444
445mulR :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n
446
447mulR (isKonst -> Just (a,(_,k))) (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k)
448
449mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k)
450 where
451 v = a' * b'
452 n = min (size a) (size b)
453 a' = subVector 0 n a
454 b' = subVector 0 n b
455
456mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (size a) b)
457
458mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b)
459
460mulR a b = mkL (extract a LA.<> extract b)
461
462
463appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m
464appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (size w) (extract v))
465appR m v = mkR (extract m LA.#> extract v)
466
467
468dotR :: R n -> R n -> โ„
469dotR (ud1 -> u) (ud1 -> v)
470 | singleV u || singleV v = sumElements (u * v)
471 | otherwise = udot u v
472
473
474crossR :: R 3 -> R 3 -> R 3
475crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3
476 where
477 z1 = x!1*y!2-x!2*y!1
478 z2 = x!2*y!0-x!0*y!2
479 z3 = x!0*y!1-x!1*y!0
480
481--------------------------------------------------------------------------------
482
483mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n
484
485mulC (isKonstC -> Just (a,(_,k))) (isKonstC -> Just (b,_)) = konst (a * b * fromIntegral k)
486
487mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k)
488 where
489 v = a' * b'
490 n = min (size a) (size b)
491 a' = subVector 0 n a
492 b' = subVector 0 n b
493
494mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (size a) b)
495
496mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (size b) a * asRow b)
497
498mulC a b = mkM (extract a LA.<> extract b)
499
500
501appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m
502appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (size w) (extract v))
503appC m v = mkC (extract m LA.#> extract v)
504
505
506dotC :: KnownNat n => C n -> C n -> โ„‚
507dotC (unwrap -> u) (unwrap -> v)
508 | singleV u || singleV v = sumElements (conj u * v)
509 | otherwise = u LA.<ยท> v
510
511
512crossC :: C 3 -> C 3 -> C 3
513crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3])
514 where
515 z1 = x!1*y!2-x!2*y!1
516 z2 = x!2*y!0-x!0*y!2
517 z3 = x!0*y!1-x!1*y!0
518
519--------------------------------------------------------------------------------
520
521diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ„ -> R k -> L m n
522diagRectR x v = mkL (asRow (vjoin [scalar x, ev, zeros]))
523 where
524 ev = extract v
525 zeros = LA.konst x (max 0 ((min m' n') - size ev))
526 m' = fromIntegral . natVal $ (undefined :: Proxy m)
527 n' = fromIntegral . natVal $ (undefined :: Proxy n)
528
529
530diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ„‚ -> C k -> M m n
531diagRectC x v = mkM (asRow (vjoin [scalar x, ev, zeros]))
532 where
533 ev = extract v
534 zeros = LA.konst x (max 0 ((min m' n') - size ev))
535 m' = fromIntegral . natVal $ (undefined :: Proxy m)
536 n' = fromIntegral . natVal $ (undefined :: Proxy n)
537
538--------------------------------------------------------------------------------
539
540mean :: (KnownNat n, 1<=n) => R n -> โ„
541mean v = v <ยท> (1/dim)
542
543test :: (Bool, IO ())
544test = (ok,info)
545 where
546 ok = extract (eye :: Sq 5) == ident 5
547 && (unwrap .unSym) (mTm sm :: Sym 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..]
548 && unwrap (tm :: L 3 5) == LA.matrix 5 [1..15]
549 && thingS == thingD
550 && precS == precD
551 && withVector (LA.vector [1..15]) sumV == sumElements (LA.fromList [1..15])
552
553 info = do
554 print $ u
555 print $ v
556 print (eye :: Sq 3)
557 print $ ((u & 5) + 1) <ยท> v
558 print (tm :: L 2 5)
559 print (tm <> sm :: L 2 3)
560 print thingS
561 print thingD
562 print precS
563 print precD
564 print $ withVector (LA.vector [1..15]) sumV
565 splittest
566
567 sumV w = w <ยท> konst 1
568
569 u = vec2 3 5
570
571 ๐•ง x = vector [x] :: R 1
572
573 v = ๐•ง 2 & 4 & 7
574
575 tm :: GL
576 tm = lmat 0 [1..]
577
578 lmat :: forall m n . (KnownNat m, KnownNat n) => โ„ -> [โ„] -> L m n
579 lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z
580 where
581 m' = fromIntegral . natVal $ (undefined :: Proxy m)
582 n' = fromIntegral . natVal $ (undefined :: Proxy n)
583
584 sm :: GSq
585 sm = lmat 0 [1..]
586
587 thingS = (u & 1) <ยท> tr q #> q #> v
588 where
589 q = tm :: L 10 3
590
591 thingD = vjoin [ud1 u, 1] LA.<ยท> tr m LA.#> m LA.#> ud1 v
592 where
593 m = LA.matrix 3 [1..30]
594
595 precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <ยท> konst 2 #> v
596 precD = 1 + 2 * vjoin[ud1 u, 6] LA.<ยท> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v
597
598
599splittest
600 = do
601 let v = range :: R 7
602 a = snd (split v) :: R 4
603 print $ a
604 print $ snd . headTail . snd . headTail $ v
605 print $ first (vec3 1 2 3)
606 print $ second (vec3 1 2 3)
607 print $ third (vec3 1 2 3)
608 print $ (snd $ splitRows eye :: L 4 6)
609 where
610 first v = fst . headTail $ v
611 second v = first . snd . headTail $ v
612 third v = first . snd . headTail . snd . headTail $ v
613
614
615instance (KnownNat n', KnownNat m') => Testable (L n' m')
616 where
617 checkT _ = test
618
619