summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Real.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Real.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs480
1 files changed, 0 insertions, 480 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
deleted file mode 100644
index 97c462e..0000000
--- a/packages/base/src/Numeric/LinearAlgebra/Real.hs
+++ /dev/null
@@ -1,480 +0,0 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE FlexibleContexts #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE EmptyDataDecls #-}
9{-# LANGUAGE Rank2Types #-}
10{-# LANGUAGE FlexibleInstances #-}
11{-# LANGUAGE TypeOperators #-}
12{-# LANGUAGE ViewPatterns #-}
13{-# LANGUAGE GADTs #-}
14{-# LANGUAGE OverlappingInstances #-}
15{-# LANGUAGE TypeFamilies #-}
16
17
18{- |
19Module : Numeric.LinearAlgebra.Real
20Copyright : (c) Alberto Ruiz 2006-14
21License : BSD3
22Stability : experimental
23
24Experimental interface for real arrays with statically checked dimensions.
25
26-}
27
28module Numeric.LinearAlgebra.Real(
29 -- * Vector
30 R, C,
31 vec2, vec3, vec4, (&), (#), split, headTail,
32 vector,
33 linspace, range, dim,
34 -- * Matrix
35 L, Sq, M, def,
36 row, col, (ยฆ),(โ€”โ€”), splitRows, splitCols,
37 unrow, uncol,
38
39 eye,
40 diagR, diag,
41 blockAt,
42 matrix,
43 -- * Products
44 (<>),(#>),(<ยท>),
45 -- * Linear Systems
46 linSolve, (<\>),
47 -- * Factorizations
48 svd, svdTall, svdFlat, Eigen(..),
49 -- * Pretty printing
50 Disp(..),
51 -- * Misc
52 withVector, withMatrix,
53 Sized(..), Diag(..), Sym, sym, Her, her, ๐‘–,
54 module Numeric.HMatrix
55) where
56
57
58import GHC.TypeLits
59import Numeric.HMatrix hiding (
60 (<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โ€”โ€”),row,col,vect,mat,linspace,
61 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build)
62import qualified Numeric.HMatrix as LA
63import Data.Proxy(Proxy)
64import Numeric.LinearAlgebra.Static
65import Control.Arrow((***))
66
67
68๐‘– :: Sized โ„‚ s c => s
69๐‘– = konst iC
70
71
72
73
74
75ud1 :: R n -> Vector โ„
76ud1 (R (Dim v)) = v
77
78
79infixl 4 &
80(&) :: forall n . KnownNat n
81 => R n -> โ„ -> R (n+1)
82u & x = u # (konst x :: R 1)
83
84infixl 4 #
85(#) :: forall n m . (KnownNat n, KnownNat m)
86 => R n -> R m -> R (n+m)
87(R u) # (R v) = R (vconcat u v)
88
89
90
91vec2 :: โ„ -> โ„ -> R 2
92vec2 a b = R (gvec2 a b)
93
94vec3 :: โ„ -> โ„ -> โ„ -> R 3
95vec3 a b c = R (gvec3 a b c)
96
97
98vec4 :: โ„ -> โ„ -> โ„ -> โ„ -> R 4
99vec4 a b c d = R (gvec4 a b c d)
100
101vector :: KnownNat n => [โ„] -> R n
102vector = fromList
103
104matrix :: (KnownNat m, KnownNat n) => [โ„] -> L m n
105matrix = fromList
106
107linspace :: forall n . KnownNat n => (โ„,โ„) -> R n
108linspace (a,b) = mkR (LA.linspace d (a,b))
109 where
110 d = fromIntegral . natVal $ (undefined :: Proxy n)
111
112range :: forall n . KnownNat n => R n
113range = mkR (LA.linspace d (1,fromIntegral d))
114 where
115 d = fromIntegral . natVal $ (undefined :: Proxy n)
116
117dim :: forall n . KnownNat n => R n
118dim = mkR (scalar d)
119 where
120 d = fromIntegral . natVal $ (undefined :: Proxy n)
121
122
123--------------------------------------------------------------------------------
124
125
126ud2 :: L m n -> Matrix โ„
127ud2 (L (Dim (Dim x))) = x
128
129
130--------------------------------------------------------------------------------
131--------------------------------------------------------------------------------
132
133diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ„ -> R k -> L m n
134diagR x v = mkL (asRow (vjoin [scalar x, ev, zeros]))
135 where
136 ev = extract v
137 zeros = LA.konst x (max 0 ((min m' n') - size ev))
138 m' = fromIntegral . natVal $ (undefined :: Proxy m)
139 n' = fromIntegral . natVal $ (undefined :: Proxy n)
140
141diag :: KnownNat n => R n -> Sq n
142diag = diagR 0
143
144eye :: KnownNat n => Sq n
145eye = diag 1
146
147--------------------------------------------------------------------------------
148
149blockAt :: forall m n . (KnownNat m, KnownNat n) => โ„ -> Int -> Int -> Matrix Double -> L m n
150blockAt x r c a = mkL res
151 where
152 z = scalar x
153 z1 = LA.konst x (r,c)
154 z2 = LA.konst x (max 0 (m'-(ra+r)), max 0 (n'-(ca+c)))
155 ra = min (rows a) . max 0 $ m'-r
156 ca = min (cols a) . max 0 $ n'-c
157 sa = subMatrix (0,0) (ra, ca) a
158 m' = fromIntegral . natVal $ (undefined :: Proxy m)
159 n' = fromIntegral . natVal $ (undefined :: Proxy n)
160 res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]]
161
162
163
164
165
166--------------------------------------------------------------------------------
167
168
169row :: R n -> L 1 n
170row = mkL . asRow . ud1
171
172--col :: R n -> L n 1
173col v = tr . row $ v
174
175unrow :: L 1 n -> R n
176unrow = mkR . head . toRows . ud2
177
178--uncol :: L n 1 -> R n
179uncol v = unrow . tr $ v
180
181
182infixl 2 โ€”โ€”
183(โ€”โ€”) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c
184a โ€”โ€” b = mkL (extract a LA.โ€”โ€” extract b)
185
186
187infixl 3 ยฆ
188-- (ยฆ) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2)
189a ยฆ b = tr (tr a โ€”โ€” tr b)
190
191
192type Sq n = L n n
193--type CSq n = CL n n
194
195type GL = (KnownNat n, KnownNat m) => L m n
196type GSq = KnownNat n => Sq n
197
198isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (โ„,(Int,Int))
199isKonst (unwrap -> x)
200 | singleM x = Just (x `atIndex` (0,0), (m',n'))
201 | otherwise = Nothing
202 where
203 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
204 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
205
206
207
208
209infixr 8 <>
210(<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n
211
212(isKonst -> Just (a,(_,k))) <> (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k)
213
214(isDiag -> Just (0,a,_)) <> (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k)
215 where
216 v = a' * b'
217 n = min (size a) (size b)
218 a' = subVector 0 n a
219 b' = subVector 0 n b
220
221(isDiag -> Just (0,a,_)) <> (extract -> b) = mkL (asColumn a * takeRows (size a) b)
222
223(extract -> a) <> (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b)
224
225a <> b = mkL (extract a LA.<> extract b)
226
227infixr 8 #>
228(#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m
229(isDiag -> Just (0, w, _)) #> v = mkR (w * subVector 0 (size w) (extract v))
230m #> v = mkR (extract m LA.#> extract v)
231
232
233infixr 8 <ยท>
234(<ยท>) :: R n -> R n -> โ„
235(ud1 -> u) <ยท> (ud1 -> v)
236 | singleV u || singleV v = sumElements (u * v)
237 | otherwise = udot u v
238
239--------------------------------------------------------------------------------
240
241{-
242class Minim (n :: Nat) (m :: Nat)
243 where
244 type Mini n m :: Nat
245
246instance forall (n :: Nat) . Minim n n
247 where
248 type Mini n n = n
249
250
251instance forall (n :: Nat) (m :: Nat) . (n <= m+1) => Minim n m
252 where
253 type Mini n m = n
254
255instance forall (n :: Nat) (m :: Nat) . (m <= n+1) => Minim n m
256 where
257 type Mini n m = m
258-}
259
260class Diag m d | m -> d
261 where
262 takeDiag :: m -> d
263
264
265
266instance forall n . (KnownNat n) => Diag (L n n) (R n)
267 where
268 takeDiag m = mkR (LA.takeDiag (extract m))
269
270
271instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m)
272 where
273 takeDiag m = mkR (LA.takeDiag (extract m))
274
275
276instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n)
277 where
278 takeDiag m = mkR (LA.takeDiag (extract m))
279
280
281--------------------------------------------------------------------------------
282
283linSolve :: (KnownNat m, KnownNat n) => L m m -> L m n -> Maybe (L m n)
284linSolve (extract -> a) (extract -> b) = fmap mkL (LA.linearSolve a b)
285
286(<\>) :: (KnownNat m, KnownNat n, KnownNat r) => L m n -> L m r -> L n r
287(extract -> a) <\> (extract -> b) = mkL (a LA.<\> b)
288
289svd :: (KnownNat m, KnownNat n) => L m n -> (L m m, R n, L n n)
290svd (extract -> m) = (mkL u, mkR s', mkL v)
291 where
292 (u,s,v) = LA.svd m
293 s' = vjoin [s, z]
294 z = LA.konst 0 (max 0 (cols m - size s))
295
296
297svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n)
298svdTall (extract -> m) = (mkL u, mkR s, mkL v)
299 where
300 (u,s,v) = LA.thinSVD m
301
302
303svdFlat :: (KnownNat m, KnownNat n, m <= n) => L m n -> (L m m, R m, L m n)
304svdFlat (extract -> m) = (mkL u, mkR s, mkL v)
305 where
306 (u,s,v) = LA.thinSVD m
307
308--------------------------------------------------------------------------------
309
310class Eigen m l v | m -> l, m -> v
311 where
312 eigensystem :: m -> (l,v)
313 eigenvalues :: m -> l
314
315newtype Sym n = Sym (Sq n) deriving Show
316
317newtype Her n = Her (M n n)
318
319sym :: KnownNat n => Sq n -> Sym n
320sym m = Sym $ (m + tr m)/2
321
322her :: KnownNat n => M n n -> Her n
323her m = Her $ (m + tr m)/2
324
325instance KnownNat n => Eigen (Sym n) (R n) (L n n)
326 where
327 eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m
328 eigensystem (Sym (extract -> m)) = (mkR l, mkL v)
329 where
330 (l,v) = LA.eigSH' m
331
332instance KnownNat n => Eigen (Sq n) (C n) (M n n)
333 where
334 eigenvalues (extract -> m) = mkC . LA.eigenvalues $ m
335 eigensystem (extract -> m) = (mkC l, mkM v)
336 where
337 (l,v) = LA.eig m
338
339--------------------------------------------------------------------------------
340
341split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p))
342split (extract -> v) = ( mkR (subVector 0 p' v) ,
343 mkR (subVector p' (size v - p') v) )
344 where
345 p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int
346
347
348headTail :: (KnownNat n, 1<=n) => R n -> (โ„, R (n-1))
349headTail = ((!0) . extract *** id) . split
350
351
352splitRows :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, p<=m) => L m n -> (L p n, L (m-p) n)
353splitRows (extract -> x) = ( mkL (takeRows p' x) ,
354 mkL (dropRows p' x) )
355 where
356 p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int
357
358splitCols :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, KnownNat (n-p), p<=n) => L m n -> (L m p, L m (n-p))
359splitCols = (tr *** tr) . splitRows . tr
360
361
362splittest
363 = do
364 let v = range :: R 7
365 a = snd (split v) :: R 4
366 print $ a
367 print $ snd . headTail . snd . headTail $ v
368 print $ first (vec3 1 2 3)
369 print $ second (vec3 1 2 3)
370 print $ third (vec3 1 2 3)
371 print $ (snd $ splitRows eye :: L 4 6)
372 where
373 first v = fst . headTail $ v
374 second v = first . snd . headTail $ v
375 third v = first . snd . headTail . snd . headTail $ v
376
377--------------------------------------------------------------------------------
378
379def
380 :: forall m n . (KnownNat n, KnownNat m)
381 => (โ„ -> โ„ -> โ„)
382 -> L m n
383def f = mkL $ LA.build (m',n') f
384 where
385 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
386 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
387
388--------------------------------------------------------------------------------
389
390withVector
391 :: forall z
392 . Vector โ„
393 -> (forall n . (KnownNat n) => R n -> z)
394 -> z
395withVector v f =
396 case someNatVal $ fromIntegral $ size v of
397 Nothing -> error "static/dynamic mismatch"
398 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m)
399
400
401withMatrix
402 :: forall z
403 . Matrix โ„
404 -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z)
405 -> z
406withMatrix a f =
407 case someNatVal $ fromIntegral $ rows a of
408 Nothing -> error "static/dynamic mismatch"
409 Just (SomeNat (_ :: Proxy m)) ->
410 case someNatVal $ fromIntegral $ cols a of
411 Nothing -> error "static/dynamic mismatch"
412 Just (SomeNat (_ :: Proxy n)) ->
413 f (mkL a :: L n m)
414
415--------------------------------------------------------------------------------
416
417test :: (Bool, IO ())
418test = (ok,info)
419 where
420 ok = extract (eye :: Sq 5) == ident 5
421 && unwrap (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..]
422 && unwrap (tm :: L 3 5) == LA.mat 5 [1..15]
423 && thingS == thingD
424 && precS == precD
425 && withVector (LA.vect [1..15]) sumV == sumElements (LA.fromList [1..15])
426
427 info = do
428 print $ u
429 print $ v
430 print (eye :: Sq 3)
431 print $ ((u & 5) + 1) <ยท> v
432 print (tm :: L 2 5)
433 print (tm <> sm :: L 2 3)
434 print thingS
435 print thingD
436 print precS
437 print precD
438 print $ withVector (LA.vect [1..15]) sumV
439 splittest
440
441 sumV w = w <ยท> konst 1
442
443 u = vec2 3 5
444
445 ๐•ง x = vector [x] :: R 1
446
447 v = ๐•ง 2 & 4 & 7
448
449-- mTm :: L n m -> Sq m
450 mTm a = tr a <> a
451
452 tm :: GL
453 tm = lmat 0 [1..]
454
455 lmat :: forall m n . (KnownNat m, KnownNat n) => โ„ -> [โ„] -> L m n
456 lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z
457 where
458 m' = fromIntegral . natVal $ (undefined :: Proxy m)
459 n' = fromIntegral . natVal $ (undefined :: Proxy n)
460
461 sm :: GSq
462 sm = lmat 0 [1..]
463
464 thingS = (u & 1) <ยท> tr q #> q #> v
465 where
466 q = tm :: L 10 3
467
468 thingD = vjoin [ud1 u, 1] LA.<ยท> tr m LA.#> m LA.#> ud1 v
469 where
470 m = LA.mat 3 [1..30]
471
472 precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <ยท> konst 2 #> v
473 precD = 1 + 2 * vjoin[ud1 u, 6] LA.<ยท> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v
474
475
476instance (KnownNat n', KnownNat m') => Testable (L n' m')
477 where
478 checkT _ = test
479
480