diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Real.hs')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Real.hs | 480 |
1 files changed, 0 insertions, 480 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs deleted file mode 100644 index 97c462e..0000000 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ /dev/null | |||
@@ -1,480 +0,0 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE FlexibleContexts #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE EmptyDataDecls #-} | ||
9 | {-# LANGUAGE Rank2Types #-} | ||
10 | {-# LANGUAGE FlexibleInstances #-} | ||
11 | {-# LANGUAGE TypeOperators #-} | ||
12 | {-# LANGUAGE ViewPatterns #-} | ||
13 | {-# LANGUAGE GADTs #-} | ||
14 | {-# LANGUAGE OverlappingInstances #-} | ||
15 | {-# LANGUAGE TypeFamilies #-} | ||
16 | |||
17 | |||
18 | {- | | ||
19 | Module : Numeric.LinearAlgebra.Real | ||
20 | Copyright : (c) Alberto Ruiz 2006-14 | ||
21 | License : BSD3 | ||
22 | Stability : experimental | ||
23 | |||
24 | Experimental interface for real arrays with statically checked dimensions. | ||
25 | |||
26 | -} | ||
27 | |||
28 | module Numeric.LinearAlgebra.Real( | ||
29 | -- * Vector | ||
30 | R, C, | ||
31 | vec2, vec3, vec4, (&), (#), split, headTail, | ||
32 | vector, | ||
33 | linspace, range, dim, | ||
34 | -- * Matrix | ||
35 | L, Sq, M, def, | ||
36 | row, col, (ยฆ),(โโ), splitRows, splitCols, | ||
37 | unrow, uncol, | ||
38 | |||
39 | eye, | ||
40 | diagR, diag, | ||
41 | blockAt, | ||
42 | matrix, | ||
43 | -- * Products | ||
44 | (<>),(#>),(<ยท>), | ||
45 | -- * Linear Systems | ||
46 | linSolve, (<\>), | ||
47 | -- * Factorizations | ||
48 | svd, svdTall, svdFlat, Eigen(..), | ||
49 | -- * Pretty printing | ||
50 | Disp(..), | ||
51 | -- * Misc | ||
52 | withVector, withMatrix, | ||
53 | Sized(..), Diag(..), Sym, sym, Her, her, ๐, | ||
54 | module Numeric.HMatrix | ||
55 | ) where | ||
56 | |||
57 | |||
58 | import GHC.TypeLits | ||
59 | import Numeric.HMatrix hiding ( | ||
60 | (<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โโ),row,col,vect,mat,linspace, | ||
61 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build) | ||
62 | import qualified Numeric.HMatrix as LA | ||
63 | import Data.Proxy(Proxy) | ||
64 | import Numeric.LinearAlgebra.Static | ||
65 | import Control.Arrow((***)) | ||
66 | |||
67 | |||
68 | ๐ :: Sized โ s c => s | ||
69 | ๐ = konst iC | ||
70 | |||
71 | |||
72 | |||
73 | |||
74 | |||
75 | ud1 :: R n -> Vector โ | ||
76 | ud1 (R (Dim v)) = v | ||
77 | |||
78 | |||
79 | infixl 4 & | ||
80 | (&) :: forall n . KnownNat n | ||
81 | => R n -> โ -> R (n+1) | ||
82 | u & x = u # (konst x :: R 1) | ||
83 | |||
84 | infixl 4 # | ||
85 | (#) :: forall n m . (KnownNat n, KnownNat m) | ||
86 | => R n -> R m -> R (n+m) | ||
87 | (R u) # (R v) = R (vconcat u v) | ||
88 | |||
89 | |||
90 | |||
91 | vec2 :: โ -> โ -> R 2 | ||
92 | vec2 a b = R (gvec2 a b) | ||
93 | |||
94 | vec3 :: โ -> โ -> โ -> R 3 | ||
95 | vec3 a b c = R (gvec3 a b c) | ||
96 | |||
97 | |||
98 | vec4 :: โ -> โ -> โ -> โ -> R 4 | ||
99 | vec4 a b c d = R (gvec4 a b c d) | ||
100 | |||
101 | vector :: KnownNat n => [โ] -> R n | ||
102 | vector = fromList | ||
103 | |||
104 | matrix :: (KnownNat m, KnownNat n) => [โ] -> L m n | ||
105 | matrix = fromList | ||
106 | |||
107 | linspace :: forall n . KnownNat n => (โ,โ) -> R n | ||
108 | linspace (a,b) = mkR (LA.linspace d (a,b)) | ||
109 | where | ||
110 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
111 | |||
112 | range :: forall n . KnownNat n => R n | ||
113 | range = mkR (LA.linspace d (1,fromIntegral d)) | ||
114 | where | ||
115 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
116 | |||
117 | dim :: forall n . KnownNat n => R n | ||
118 | dim = mkR (scalar d) | ||
119 | where | ||
120 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
121 | |||
122 | |||
123 | -------------------------------------------------------------------------------- | ||
124 | |||
125 | |||
126 | ud2 :: L m n -> Matrix โ | ||
127 | ud2 (L (Dim (Dim x))) = x | ||
128 | |||
129 | |||
130 | -------------------------------------------------------------------------------- | ||
131 | -------------------------------------------------------------------------------- | ||
132 | |||
133 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ -> R k -> L m n | ||
134 | diagR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) | ||
135 | where | ||
136 | ev = extract v | ||
137 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | ||
138 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
139 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
140 | |||
141 | diag :: KnownNat n => R n -> Sq n | ||
142 | diag = diagR 0 | ||
143 | |||
144 | eye :: KnownNat n => Sq n | ||
145 | eye = diag 1 | ||
146 | |||
147 | -------------------------------------------------------------------------------- | ||
148 | |||
149 | blockAt :: forall m n . (KnownNat m, KnownNat n) => โ -> Int -> Int -> Matrix Double -> L m n | ||
150 | blockAt x r c a = mkL res | ||
151 | where | ||
152 | z = scalar x | ||
153 | z1 = LA.konst x (r,c) | ||
154 | z2 = LA.konst x (max 0 (m'-(ra+r)), max 0 (n'-(ca+c))) | ||
155 | ra = min (rows a) . max 0 $ m'-r | ||
156 | ca = min (cols a) . max 0 $ n'-c | ||
157 | sa = subMatrix (0,0) (ra, ca) a | ||
158 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
159 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
160 | res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] | ||
161 | |||
162 | |||
163 | |||
164 | |||
165 | |||
166 | -------------------------------------------------------------------------------- | ||
167 | |||
168 | |||
169 | row :: R n -> L 1 n | ||
170 | row = mkL . asRow . ud1 | ||
171 | |||
172 | --col :: R n -> L n 1 | ||
173 | col v = tr . row $ v | ||
174 | |||
175 | unrow :: L 1 n -> R n | ||
176 | unrow = mkR . head . toRows . ud2 | ||
177 | |||
178 | --uncol :: L n 1 -> R n | ||
179 | uncol v = unrow . tr $ v | ||
180 | |||
181 | |||
182 | infixl 2 โโ | ||
183 | (โโ) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c | ||
184 | a โโ b = mkL (extract a LA.โโ extract b) | ||
185 | |||
186 | |||
187 | infixl 3 ยฆ | ||
188 | -- (ยฆ) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) | ||
189 | a ยฆ b = tr (tr a โโ tr b) | ||
190 | |||
191 | |||
192 | type Sq n = L n n | ||
193 | --type CSq n = CL n n | ||
194 | |||
195 | type GL = (KnownNat n, KnownNat m) => L m n | ||
196 | type GSq = KnownNat n => Sq n | ||
197 | |||
198 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (โ,(Int,Int)) | ||
199 | isKonst (unwrap -> x) | ||
200 | | singleM x = Just (x `atIndex` (0,0), (m',n')) | ||
201 | | otherwise = Nothing | ||
202 | where | ||
203 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
204 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
205 | |||
206 | |||
207 | |||
208 | |||
209 | infixr 8 <> | ||
210 | (<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n | ||
211 | |||
212 | (isKonst -> Just (a,(_,k))) <> (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k) | ||
213 | |||
214 | (isDiag -> Just (0,a,_)) <> (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) | ||
215 | where | ||
216 | v = a' * b' | ||
217 | n = min (size a) (size b) | ||
218 | a' = subVector 0 n a | ||
219 | b' = subVector 0 n b | ||
220 | |||
221 | (isDiag -> Just (0,a,_)) <> (extract -> b) = mkL (asColumn a * takeRows (size a) b) | ||
222 | |||
223 | (extract -> a) <> (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) | ||
224 | |||
225 | a <> b = mkL (extract a LA.<> extract b) | ||
226 | |||
227 | infixr 8 #> | ||
228 | (#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m | ||
229 | (isDiag -> Just (0, w, _)) #> v = mkR (w * subVector 0 (size w) (extract v)) | ||
230 | m #> v = mkR (extract m LA.#> extract v) | ||
231 | |||
232 | |||
233 | infixr 8 <ยท> | ||
234 | (<ยท>) :: R n -> R n -> โ | ||
235 | (ud1 -> u) <ยท> (ud1 -> v) | ||
236 | | singleV u || singleV v = sumElements (u * v) | ||
237 | | otherwise = udot u v | ||
238 | |||
239 | -------------------------------------------------------------------------------- | ||
240 | |||
241 | {- | ||
242 | class Minim (n :: Nat) (m :: Nat) | ||
243 | where | ||
244 | type Mini n m :: Nat | ||
245 | |||
246 | instance forall (n :: Nat) . Minim n n | ||
247 | where | ||
248 | type Mini n n = n | ||
249 | |||
250 | |||
251 | instance forall (n :: Nat) (m :: Nat) . (n <= m+1) => Minim n m | ||
252 | where | ||
253 | type Mini n m = n | ||
254 | |||
255 | instance forall (n :: Nat) (m :: Nat) . (m <= n+1) => Minim n m | ||
256 | where | ||
257 | type Mini n m = m | ||
258 | -} | ||
259 | |||
260 | class Diag m d | m -> d | ||
261 | where | ||
262 | takeDiag :: m -> d | ||
263 | |||
264 | |||
265 | |||
266 | instance forall n . (KnownNat n) => Diag (L n n) (R n) | ||
267 | where | ||
268 | takeDiag m = mkR (LA.takeDiag (extract m)) | ||
269 | |||
270 | |||
271 | instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m) | ||
272 | where | ||
273 | takeDiag m = mkR (LA.takeDiag (extract m)) | ||
274 | |||
275 | |||
276 | instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n) | ||
277 | where | ||
278 | takeDiag m = mkR (LA.takeDiag (extract m)) | ||
279 | |||
280 | |||
281 | -------------------------------------------------------------------------------- | ||
282 | |||
283 | linSolve :: (KnownNat m, KnownNat n) => L m m -> L m n -> Maybe (L m n) | ||
284 | linSolve (extract -> a) (extract -> b) = fmap mkL (LA.linearSolve a b) | ||
285 | |||
286 | (<\>) :: (KnownNat m, KnownNat n, KnownNat r) => L m n -> L m r -> L n r | ||
287 | (extract -> a) <\> (extract -> b) = mkL (a LA.<\> b) | ||
288 | |||
289 | svd :: (KnownNat m, KnownNat n) => L m n -> (L m m, R n, L n n) | ||
290 | svd (extract -> m) = (mkL u, mkR s', mkL v) | ||
291 | where | ||
292 | (u,s,v) = LA.svd m | ||
293 | s' = vjoin [s, z] | ||
294 | z = LA.konst 0 (max 0 (cols m - size s)) | ||
295 | |||
296 | |||
297 | svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) | ||
298 | svdTall (extract -> m) = (mkL u, mkR s, mkL v) | ||
299 | where | ||
300 | (u,s,v) = LA.thinSVD m | ||
301 | |||
302 | |||
303 | svdFlat :: (KnownNat m, KnownNat n, m <= n) => L m n -> (L m m, R m, L m n) | ||
304 | svdFlat (extract -> m) = (mkL u, mkR s, mkL v) | ||
305 | where | ||
306 | (u,s,v) = LA.thinSVD m | ||
307 | |||
308 | -------------------------------------------------------------------------------- | ||
309 | |||
310 | class Eigen m l v | m -> l, m -> v | ||
311 | where | ||
312 | eigensystem :: m -> (l,v) | ||
313 | eigenvalues :: m -> l | ||
314 | |||
315 | newtype Sym n = Sym (Sq n) deriving Show | ||
316 | |||
317 | newtype Her n = Her (M n n) | ||
318 | |||
319 | sym :: KnownNat n => Sq n -> Sym n | ||
320 | sym m = Sym $ (m + tr m)/2 | ||
321 | |||
322 | her :: KnownNat n => M n n -> Her n | ||
323 | her m = Her $ (m + tr m)/2 | ||
324 | |||
325 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) | ||
326 | where | ||
327 | eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m | ||
328 | eigensystem (Sym (extract -> m)) = (mkR l, mkL v) | ||
329 | where | ||
330 | (l,v) = LA.eigSH' m | ||
331 | |||
332 | instance KnownNat n => Eigen (Sq n) (C n) (M n n) | ||
333 | where | ||
334 | eigenvalues (extract -> m) = mkC . LA.eigenvalues $ m | ||
335 | eigensystem (extract -> m) = (mkC l, mkM v) | ||
336 | where | ||
337 | (l,v) = LA.eig m | ||
338 | |||
339 | -------------------------------------------------------------------------------- | ||
340 | |||
341 | split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) | ||
342 | split (extract -> v) = ( mkR (subVector 0 p' v) , | ||
343 | mkR (subVector p' (size v - p') v) ) | ||
344 | where | ||
345 | p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int | ||
346 | |||
347 | |||
348 | headTail :: (KnownNat n, 1<=n) => R n -> (โ, R (n-1)) | ||
349 | headTail = ((!0) . extract *** id) . split | ||
350 | |||
351 | |||
352 | splitRows :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, p<=m) => L m n -> (L p n, L (m-p) n) | ||
353 | splitRows (extract -> x) = ( mkL (takeRows p' x) , | ||
354 | mkL (dropRows p' x) ) | ||
355 | where | ||
356 | p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int | ||
357 | |||
358 | splitCols :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, KnownNat (n-p), p<=n) => L m n -> (L m p, L m (n-p)) | ||
359 | splitCols = (tr *** tr) . splitRows . tr | ||
360 | |||
361 | |||
362 | splittest | ||
363 | = do | ||
364 | let v = range :: R 7 | ||
365 | a = snd (split v) :: R 4 | ||
366 | print $ a | ||
367 | print $ snd . headTail . snd . headTail $ v | ||
368 | print $ first (vec3 1 2 3) | ||
369 | print $ second (vec3 1 2 3) | ||
370 | print $ third (vec3 1 2 3) | ||
371 | print $ (snd $ splitRows eye :: L 4 6) | ||
372 | where | ||
373 | first v = fst . headTail $ v | ||
374 | second v = first . snd . headTail $ v | ||
375 | third v = first . snd . headTail . snd . headTail $ v | ||
376 | |||
377 | -------------------------------------------------------------------------------- | ||
378 | |||
379 | def | ||
380 | :: forall m n . (KnownNat n, KnownNat m) | ||
381 | => (โ -> โ -> โ) | ||
382 | -> L m n | ||
383 | def f = mkL $ LA.build (m',n') f | ||
384 | where | ||
385 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
386 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
387 | |||
388 | -------------------------------------------------------------------------------- | ||
389 | |||
390 | withVector | ||
391 | :: forall z | ||
392 | . Vector โ | ||
393 | -> (forall n . (KnownNat n) => R n -> z) | ||
394 | -> z | ||
395 | withVector v f = | ||
396 | case someNatVal $ fromIntegral $ size v of | ||
397 | Nothing -> error "static/dynamic mismatch" | ||
398 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) | ||
399 | |||
400 | |||
401 | withMatrix | ||
402 | :: forall z | ||
403 | . Matrix โ | ||
404 | -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z) | ||
405 | -> z | ||
406 | withMatrix a f = | ||
407 | case someNatVal $ fromIntegral $ rows a of | ||
408 | Nothing -> error "static/dynamic mismatch" | ||
409 | Just (SomeNat (_ :: Proxy m)) -> | ||
410 | case someNatVal $ fromIntegral $ cols a of | ||
411 | Nothing -> error "static/dynamic mismatch" | ||
412 | Just (SomeNat (_ :: Proxy n)) -> | ||
413 | f (mkL a :: L n m) | ||
414 | |||
415 | -------------------------------------------------------------------------------- | ||
416 | |||
417 | test :: (Bool, IO ()) | ||
418 | test = (ok,info) | ||
419 | where | ||
420 | ok = extract (eye :: Sq 5) == ident 5 | ||
421 | && unwrap (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] | ||
422 | && unwrap (tm :: L 3 5) == LA.mat 5 [1..15] | ||
423 | && thingS == thingD | ||
424 | && precS == precD | ||
425 | && withVector (LA.vect [1..15]) sumV == sumElements (LA.fromList [1..15]) | ||
426 | |||
427 | info = do | ||
428 | print $ u | ||
429 | print $ v | ||
430 | print (eye :: Sq 3) | ||
431 | print $ ((u & 5) + 1) <ยท> v | ||
432 | print (tm :: L 2 5) | ||
433 | print (tm <> sm :: L 2 3) | ||
434 | print thingS | ||
435 | print thingD | ||
436 | print precS | ||
437 | print precD | ||
438 | print $ withVector (LA.vect [1..15]) sumV | ||
439 | splittest | ||
440 | |||
441 | sumV w = w <ยท> konst 1 | ||
442 | |||
443 | u = vec2 3 5 | ||
444 | |||
445 | ๐ง x = vector [x] :: R 1 | ||
446 | |||
447 | v = ๐ง 2 & 4 & 7 | ||
448 | |||
449 | -- mTm :: L n m -> Sq m | ||
450 | mTm a = tr a <> a | ||
451 | |||
452 | tm :: GL | ||
453 | tm = lmat 0 [1..] | ||
454 | |||
455 | lmat :: forall m n . (KnownNat m, KnownNat n) => โ -> [โ] -> L m n | ||
456 | lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z | ||
457 | where | ||
458 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
459 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
460 | |||
461 | sm :: GSq | ||
462 | sm = lmat 0 [1..] | ||
463 | |||
464 | thingS = (u & 1) <ยท> tr q #> q #> v | ||
465 | where | ||
466 | q = tm :: L 10 3 | ||
467 | |||
468 | thingD = vjoin [ud1 u, 1] LA.<ยท> tr m LA.#> m LA.#> ud1 v | ||
469 | where | ||
470 | m = LA.mat 3 [1..30] | ||
471 | |||
472 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <ยท> konst 2 #> v | ||
473 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<ยท> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v | ||
474 | |||
475 | |||
476 | instance (KnownNat n', KnownNat m') => Testable (L n' m') | ||
477 | where | ||
478 | checkT _ = test | ||
479 | |||
480 | |||