diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs (renamed from packages/base/src/Numeric/LinearAlgebra/Static.hs) | 4 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/tmpStatic.hs | 619 |
2 files changed, 621 insertions, 2 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs index daf8d80..c9641d5 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs | |||
@@ -14,14 +14,14 @@ | |||
14 | 14 | ||
15 | 15 | ||
16 | {- | | 16 | {- | |
17 | Module : Numeric.LinearAlgebra.Static | 17 | Module : Numeric.LinearAlgebra.Static.Internal |
18 | Copyright : (c) Alberto Ruiz 2006-14 | 18 | Copyright : (c) Alberto Ruiz 2006-14 |
19 | License : BSD3 | 19 | License : BSD3 |
20 | Stability : provisional | 20 | Stability : provisional |
21 | 21 | ||
22 | -} | 22 | -} |
23 | 23 | ||
24 | module Numeric.LinearAlgebra.Static where | 24 | module Numeric.LinearAlgebra.Static.Internal where |
25 | 25 | ||
26 | 26 | ||
27 | import GHC.TypeLits | 27 | import GHC.TypeLits |
diff --git a/packages/base/src/Numeric/LinearAlgebra/tmpStatic.hs b/packages/base/src/Numeric/LinearAlgebra/tmpStatic.hs new file mode 100644 index 0000000..4258d6b --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/tmpStatic.hs | |||
@@ -0,0 +1,619 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE FlexibleContexts #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE EmptyDataDecls #-} | ||
9 | {-# LANGUAGE Rank2Types #-} | ||
10 | {-# LANGUAGE FlexibleInstances #-} | ||
11 | {-# LANGUAGE TypeOperators #-} | ||
12 | {-# LANGUAGE ViewPatterns #-} | ||
13 | {-# LANGUAGE GADTs #-} | ||
14 | {-# LANGUAGE OverlappingInstances #-} | ||
15 | {-# LANGUAGE TypeFamilies #-} | ||
16 | |||
17 | |||
18 | {- | | ||
19 | Module : Numeric.LinearAlgebra.Static | ||
20 | Copyright : (c) Alberto Ruiz 2014 | ||
21 | License : BSD3 | ||
22 | Stability : experimental | ||
23 | |||
24 | Experimental interface with statically checked dimensions. | ||
25 | |||
26 | -} | ||
27 | |||
28 | module Numeric.LinearAlgebra.Static( | ||
29 | -- * Vector | ||
30 | โ, R, | ||
31 | vec2, vec3, vec4, (&), (#), split, headTail, | ||
32 | vector, | ||
33 | linspace, range, dim, | ||
34 | -- * Matrix | ||
35 | L, Sq, build, | ||
36 | row, col, (ยฆ),(โโ), splitRows, splitCols, | ||
37 | unrow, uncol, | ||
38 | tr, | ||
39 | eye, | ||
40 | diag, | ||
41 | blockAt, | ||
42 | matrix, | ||
43 | -- * Complex | ||
44 | C, M, Her, her, ๐, | ||
45 | -- * Products | ||
46 | (<>),(#>),(<ยท>), | ||
47 | -- * Linear Systems | ||
48 | linSolve, (<\>), | ||
49 | -- * Factorizations | ||
50 | svd, withCompactSVD, svdTall, svdFlat, Eigen(..), | ||
51 | withNullspace, qr, | ||
52 | -- * Misc | ||
53 | mean, | ||
54 | Disp(..), Domain(..), | ||
55 | withVector, withMatrix, | ||
56 | toRows, toColumns, | ||
57 | Sized(..), Diag(..), Sym, sym, mTm, unSym | ||
58 | ) where | ||
59 | |||
60 | |||
61 | import GHC.TypeLits | ||
62 | import Numeric.LinearAlgebra.HMatrix hiding ( | ||
63 | (<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โโ),row,col,vector,matrix,linspace,toRows,toColumns, | ||
64 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build, | ||
65 | qr) | ||
66 | import qualified Numeric.LinearAlgebra.HMatrix as LA | ||
67 | import Data.Proxy(Proxy) | ||
68 | import Numeric.LinearAlgebra.Static.Internal | ||
69 | import Control.Arrow((***)) | ||
70 | |||
71 | |||
72 | |||
73 | |||
74 | |||
75 | ud1 :: R n -> Vector โ | ||
76 | ud1 (R (Dim v)) = v | ||
77 | |||
78 | |||
79 | infixl 4 & | ||
80 | (&) :: forall n . (KnownNat n, 1 <= n) | ||
81 | => R n -> โ -> R (n+1) | ||
82 | u & x = u # (konst x :: R 1) | ||
83 | |||
84 | infixl 4 # | ||
85 | (#) :: forall n m . (KnownNat n, KnownNat m) | ||
86 | => R n -> R m -> R (n+m) | ||
87 | (R u) # (R v) = R (vconcat u v) | ||
88 | |||
89 | |||
90 | |||
91 | vec2 :: โ -> โ -> R 2 | ||
92 | vec2 a b = R (gvec2 a b) | ||
93 | |||
94 | vec3 :: โ -> โ -> โ -> R 3 | ||
95 | vec3 a b c = R (gvec3 a b c) | ||
96 | |||
97 | |||
98 | vec4 :: โ -> โ -> โ -> โ -> R 4 | ||
99 | vec4 a b c d = R (gvec4 a b c d) | ||
100 | |||
101 | vector :: KnownNat n => [โ] -> R n | ||
102 | vector = fromList | ||
103 | |||
104 | matrix :: (KnownNat m, KnownNat n) => [โ] -> L m n | ||
105 | matrix = fromList | ||
106 | |||
107 | linspace :: forall n . KnownNat n => (โ,โ) -> R n | ||
108 | linspace (a,b) = mkR (LA.linspace d (a,b)) | ||
109 | where | ||
110 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
111 | |||
112 | range :: forall n . KnownNat n => R n | ||
113 | range = mkR (LA.linspace d (1,fromIntegral d)) | ||
114 | where | ||
115 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
116 | |||
117 | dim :: forall n . KnownNat n => R n | ||
118 | dim = mkR (scalar d) | ||
119 | where | ||
120 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
121 | |||
122 | |||
123 | -------------------------------------------------------------------------------- | ||
124 | |||
125 | |||
126 | ud2 :: L m n -> Matrix โ | ||
127 | ud2 (L (Dim (Dim x))) = x | ||
128 | |||
129 | |||
130 | -------------------------------------------------------------------------------- | ||
131 | |||
132 | diag :: KnownNat n => R n -> Sq n | ||
133 | diag = diagR 0 | ||
134 | |||
135 | eye :: KnownNat n => Sq n | ||
136 | eye = diag 1 | ||
137 | |||
138 | -------------------------------------------------------------------------------- | ||
139 | |||
140 | blockAt :: forall m n . (KnownNat m, KnownNat n) => โ -> Int -> Int -> Matrix Double -> L m n | ||
141 | blockAt x r c a = mkL res | ||
142 | where | ||
143 | z = scalar x | ||
144 | z1 = LA.konst x (r,c) | ||
145 | z2 = LA.konst x (max 0 (m'-(ra+r)), max 0 (n'-(ca+c))) | ||
146 | ra = min (rows a) . max 0 $ m'-r | ||
147 | ca = min (cols a) . max 0 $ n'-c | ||
148 | sa = subMatrix (0,0) (ra, ca) a | ||
149 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
150 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
151 | res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] | ||
152 | |||
153 | |||
154 | |||
155 | |||
156 | |||
157 | -------------------------------------------------------------------------------- | ||
158 | |||
159 | |||
160 | row :: R n -> L 1 n | ||
161 | row = mkL . asRow . ud1 | ||
162 | |||
163 | --col :: R n -> L n 1 | ||
164 | col v = tr . row $ v | ||
165 | |||
166 | unrow :: L 1 n -> R n | ||
167 | unrow = mkR . head . LA.toRows . ud2 | ||
168 | |||
169 | --uncol :: L n 1 -> R n | ||
170 | uncol v = unrow . tr $ v | ||
171 | |||
172 | |||
173 | infixl 2 โโ | ||
174 | (โโ) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c | ||
175 | a โโ b = mkL (extract a LA.โโ extract b) | ||
176 | |||
177 | |||
178 | infixl 3 ยฆ | ||
179 | -- (ยฆ) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) | ||
180 | a ยฆ b = tr (tr a โโ tr b) | ||
181 | |||
182 | |||
183 | type Sq n = L n n | ||
184 | --type CSq n = CL n n | ||
185 | |||
186 | type GL = (KnownNat n, KnownNat m) => L m n | ||
187 | type GSq = KnownNat n => Sq n | ||
188 | |||
189 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (โ,(Int,Int)) | ||
190 | isKonst (unwrap -> x) | ||
191 | | singleM x = Just (x `atIndex` (0,0), (m',n')) | ||
192 | | otherwise = Nothing | ||
193 | where | ||
194 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
195 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
196 | |||
197 | |||
198 | isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (โ,(Int,Int)) | ||
199 | isKonstC (unwrap -> x) | ||
200 | | singleM x = Just (x `atIndex` (0,0), (m',n')) | ||
201 | | otherwise = Nothing | ||
202 | where | ||
203 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
204 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
205 | |||
206 | |||
207 | |||
208 | infixr 8 <> | ||
209 | (<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n | ||
210 | (<>) = mulR | ||
211 | |||
212 | |||
213 | infixr 8 #> | ||
214 | (#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m | ||
215 | (#>) = appR | ||
216 | |||
217 | |||
218 | infixr 8 <ยท> | ||
219 | (<ยท>) :: R n -> R n -> โ | ||
220 | (<ยท>) = dotR | ||
221 | |||
222 | -------------------------------------------------------------------------------- | ||
223 | |||
224 | class Diag m d | m -> d | ||
225 | where | ||
226 | takeDiag :: m -> d | ||
227 | |||
228 | |||
229 | instance forall n . (KnownNat n) => Diag (L n n) (R n) | ||
230 | where | ||
231 | takeDiag m = mkR (LA.takeDiag (extract m)) | ||
232 | |||
233 | |||
234 | instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m) | ||
235 | where | ||
236 | takeDiag m = mkR (LA.takeDiag (extract m)) | ||
237 | |||
238 | |||
239 | instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n) | ||
240 | where | ||
241 | takeDiag m = mkR (LA.takeDiag (extract m)) | ||
242 | |||
243 | |||
244 | -------------------------------------------------------------------------------- | ||
245 | |||
246 | linSolve :: (KnownNat m, KnownNat n) => L m m -> L m n -> Maybe (L m n) | ||
247 | linSolve (extract -> a) (extract -> b) = fmap mkL (LA.linearSolve a b) | ||
248 | |||
249 | (<\>) :: (KnownNat m, KnownNat n, KnownNat r) => L m n -> L m r -> L n r | ||
250 | (extract -> a) <\> (extract -> b) = mkL (a LA.<\> b) | ||
251 | |||
252 | svd :: (KnownNat m, KnownNat n) => L m n -> (L m m, R n, L n n) | ||
253 | svd (extract -> m) = (mkL u, mkR s', mkL v) | ||
254 | where | ||
255 | (u,s,v) = LA.svd m | ||
256 | s' = vjoin [s, z] | ||
257 | z = LA.konst 0 (max 0 (cols m - size s)) | ||
258 | |||
259 | |||
260 | svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) | ||
261 | svdTall (extract -> m) = (mkL u, mkR s, mkL v) | ||
262 | where | ||
263 | (u,s,v) = LA.thinSVD m | ||
264 | |||
265 | |||
266 | svdFlat :: (KnownNat m, KnownNat n, m <= n) => L m n -> (L m m, R m, L n m) | ||
267 | svdFlat (extract -> m) = (mkL u, mkR s, mkL v) | ||
268 | where | ||
269 | (u,s,v) = LA.thinSVD m | ||
270 | |||
271 | -------------------------------------------------------------------------------- | ||
272 | |||
273 | class Eigen m l v | m -> l, m -> v | ||
274 | where | ||
275 | eigensystem :: m -> (l,v) | ||
276 | eigenvalues :: m -> l | ||
277 | |||
278 | newtype Sym n = Sym (Sq n) deriving Show | ||
279 | |||
280 | |||
281 | sym :: KnownNat n => Sq n -> Sym n | ||
282 | sym m = Sym $ (m + tr m)/2 | ||
283 | |||
284 | mTm :: (KnownNat m, KnownNat n) => L m n -> Sym n | ||
285 | mTm x = Sym (tr x <> x) | ||
286 | |||
287 | unSym :: Sym n -> Sq n | ||
288 | unSym (Sym x) = x | ||
289 | |||
290 | |||
291 | ๐ :: Sized โ s c => s | ||
292 | ๐ = konst iC | ||
293 | |||
294 | newtype Her n = Her (M n n) | ||
295 | |||
296 | her :: KnownNat n => M n n -> Her n | ||
297 | her m = Her $ (m + LA.tr m)/2 | ||
298 | |||
299 | |||
300 | |||
301 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) | ||
302 | where | ||
303 | eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m | ||
304 | eigensystem (Sym (extract -> m)) = (mkR l, mkL v) | ||
305 | where | ||
306 | (l,v) = LA.eigSH' m | ||
307 | |||
308 | instance KnownNat n => Eigen (Sq n) (C n) (M n n) | ||
309 | where | ||
310 | eigenvalues (extract -> m) = mkC . LA.eigenvalues $ m | ||
311 | eigensystem (extract -> m) = (mkC l, mkM v) | ||
312 | where | ||
313 | (l,v) = LA.eig m | ||
314 | |||
315 | -------------------------------------------------------------------------------- | ||
316 | |||
317 | withNullspace | ||
318 | :: forall m n z . (KnownNat m, KnownNat n) | ||
319 | => L m n | ||
320 | -> (forall k . (KnownNat k) => L n k -> z) | ||
321 | -> z | ||
322 | withNullspace (LA.nullspace . extract -> a) f = | ||
323 | case someNatVal $ fromIntegral $ cols a of | ||
324 | Nothing -> error "static/dynamic mismatch" | ||
325 | Just (SomeNat (_ :: Proxy k)) -> f (mkL a :: L n k) | ||
326 | |||
327 | |||
328 | withCompactSVD | ||
329 | :: forall m n z . (KnownNat m, KnownNat n) | ||
330 | => L m n | ||
331 | -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z) | ||
332 | -> z | ||
333 | withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f = | ||
334 | case someNatVal $ fromIntegral $ size s of | ||
335 | Nothing -> error "static/dynamic mismatch" | ||
336 | Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k) | ||
337 | |||
338 | -------------------------------------------------------------------------------- | ||
339 | |||
340 | qr :: (KnownNat m, KnownNat n) => L m n -> (L m m, L m n) | ||
341 | qr (extract -> x) = (mkL q, mkL r) | ||
342 | where | ||
343 | (q,r) = LA.qr x | ||
344 | |||
345 | -- use qrRaw? | ||
346 | |||
347 | -------------------------------------------------------------------------------- | ||
348 | |||
349 | split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) | ||
350 | split (extract -> v) = ( mkR (subVector 0 p' v) , | ||
351 | mkR (subVector p' (size v - p') v) ) | ||
352 | where | ||
353 | p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int | ||
354 | |||
355 | |||
356 | headTail :: (KnownNat n, 1<=n) => R n -> (โ, R (n-1)) | ||
357 | headTail = ((!0) . extract *** id) . split | ||
358 | |||
359 | |||
360 | splitRows :: forall p m n . (KnownNat p, KnownNat m, KnownNat n, p<=m) => L m n -> (L p n, L (m-p) n) | ||
361 | splitRows (extract -> x) = ( mkL (takeRows p' x) , | ||
362 | mkL (dropRows p' x) ) | ||
363 | where | ||
364 | p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int | ||
365 | |||
366 | splitCols :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, KnownNat (n-p), p<=n) => L m n -> (L m p, L m (n-p)) | ||
367 | splitCols = (tr *** tr) . splitRows . tr | ||
368 | |||
369 | |||
370 | toRows :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R n] | ||
371 | toRows (LA.toRows . extract -> vs) = map mkR vs | ||
372 | |||
373 | |||
374 | toColumns :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R m] | ||
375 | toColumns (LA.toColumns . extract -> vs) = map mkR vs | ||
376 | |||
377 | |||
378 | -------------------------------------------------------------------------------- | ||
379 | |||
380 | build | ||
381 | :: forall m n . (KnownNat n, KnownNat m) | ||
382 | => (โ -> โ -> โ) | ||
383 | -> L m n | ||
384 | build f = mkL $ LA.build (m',n') f | ||
385 | where | ||
386 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
387 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
388 | |||
389 | -------------------------------------------------------------------------------- | ||
390 | |||
391 | withVector | ||
392 | :: forall z | ||
393 | . Vector โ | ||
394 | -> (forall n . (KnownNat n) => R n -> z) | ||
395 | -> z | ||
396 | withVector v f = | ||
397 | case someNatVal $ fromIntegral $ size v of | ||
398 | Nothing -> error "static/dynamic mismatch" | ||
399 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) | ||
400 | |||
401 | |||
402 | withMatrix | ||
403 | :: forall z | ||
404 | . Matrix โ | ||
405 | -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z) | ||
406 | -> z | ||
407 | withMatrix a f = | ||
408 | case someNatVal $ fromIntegral $ rows a of | ||
409 | Nothing -> error "static/dynamic mismatch" | ||
410 | Just (SomeNat (_ :: Proxy m)) -> | ||
411 | case someNatVal $ fromIntegral $ cols a of | ||
412 | Nothing -> error "static/dynamic mismatch" | ||
413 | Just (SomeNat (_ :: Proxy n)) -> | ||
414 | f (mkL a :: L m n) | ||
415 | |||
416 | -------------------------------------------------------------------------------- | ||
417 | |||
418 | class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat vec | ||
419 | where | ||
420 | mul :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => mat m k -> mat k n -> mat m n | ||
421 | app :: forall m n . (KnownNat m, KnownNat n) => mat m n -> vec n -> vec m | ||
422 | dot :: forall n . (KnownNat n) => vec n -> vec n -> field | ||
423 | cross :: vec 3 -> vec 3 -> vec 3 | ||
424 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n | ||
425 | |||
426 | |||
427 | instance Domain โ R L | ||
428 | where | ||
429 | mul = mulR | ||
430 | app = appR | ||
431 | dot = dotR | ||
432 | cross = crossR | ||
433 | diagR = diagRectR | ||
434 | |||
435 | instance Domain โ C M | ||
436 | where | ||
437 | mul = mulC | ||
438 | app = appC | ||
439 | dot = dotC | ||
440 | cross = crossC | ||
441 | diagR = diagRectC | ||
442 | |||
443 | -------------------------------------------------------------------------------- | ||
444 | |||
445 | mulR :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n | ||
446 | |||
447 | mulR (isKonst -> Just (a,(_,k))) (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k) | ||
448 | |||
449 | mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) | ||
450 | where | ||
451 | v = a' * b' | ||
452 | n = min (size a) (size b) | ||
453 | a' = subVector 0 n a | ||
454 | b' = subVector 0 n b | ||
455 | |||
456 | mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (size a) b) | ||
457 | |||
458 | mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) | ||
459 | |||
460 | mulR a b = mkL (extract a LA.<> extract b) | ||
461 | |||
462 | |||
463 | appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m | ||
464 | appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (size w) (extract v)) | ||
465 | appR m v = mkR (extract m LA.#> extract v) | ||
466 | |||
467 | |||
468 | dotR :: R n -> R n -> โ | ||
469 | dotR (ud1 -> u) (ud1 -> v) | ||
470 | | singleV u || singleV v = sumElements (u * v) | ||
471 | | otherwise = udot u v | ||
472 | |||
473 | |||
474 | crossR :: R 3 -> R 3 -> R 3 | ||
475 | crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3 | ||
476 | where | ||
477 | z1 = x!1*y!2-x!2*y!1 | ||
478 | z2 = x!2*y!0-x!0*y!2 | ||
479 | z3 = x!0*y!1-x!1*y!0 | ||
480 | |||
481 | -------------------------------------------------------------------------------- | ||
482 | |||
483 | mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n | ||
484 | |||
485 | mulC (isKonstC -> Just (a,(_,k))) (isKonstC -> Just (b,_)) = konst (a * b * fromIntegral k) | ||
486 | |||
487 | mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k) | ||
488 | where | ||
489 | v = a' * b' | ||
490 | n = min (size a) (size b) | ||
491 | a' = subVector 0 n a | ||
492 | b' = subVector 0 n b | ||
493 | |||
494 | mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (size a) b) | ||
495 | |||
496 | mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (size b) a * asRow b) | ||
497 | |||
498 | mulC a b = mkM (extract a LA.<> extract b) | ||
499 | |||
500 | |||
501 | appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m | ||
502 | appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (size w) (extract v)) | ||
503 | appC m v = mkC (extract m LA.#> extract v) | ||
504 | |||
505 | |||
506 | dotC :: KnownNat n => C n -> C n -> โ | ||
507 | dotC (unwrap -> u) (unwrap -> v) | ||
508 | | singleV u || singleV v = sumElements (conj u * v) | ||
509 | | otherwise = u LA.<ยท> v | ||
510 | |||
511 | |||
512 | crossC :: C 3 -> C 3 -> C 3 | ||
513 | crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) | ||
514 | where | ||
515 | z1 = x!1*y!2-x!2*y!1 | ||
516 | z2 = x!2*y!0-x!0*y!2 | ||
517 | z3 = x!0*y!1-x!1*y!0 | ||
518 | |||
519 | -------------------------------------------------------------------------------- | ||
520 | |||
521 | diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ -> R k -> L m n | ||
522 | diagRectR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) | ||
523 | where | ||
524 | ev = extract v | ||
525 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | ||
526 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
527 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
528 | |||
529 | |||
530 | diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ -> C k -> M m n | ||
531 | diagRectC x v = mkM (asRow (vjoin [scalar x, ev, zeros])) | ||
532 | where | ||
533 | ev = extract v | ||
534 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | ||
535 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
536 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
537 | |||
538 | -------------------------------------------------------------------------------- | ||
539 | |||
540 | mean :: (KnownNat n, 1<=n) => R n -> โ | ||
541 | mean v = v <ยท> (1/dim) | ||
542 | |||
543 | test :: (Bool, IO ()) | ||
544 | test = (ok,info) | ||
545 | where | ||
546 | ok = extract (eye :: Sq 5) == ident 5 | ||
547 | && (unwrap .unSym) (mTm sm :: Sym 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] | ||
548 | && unwrap (tm :: L 3 5) == LA.matrix 5 [1..15] | ||
549 | && thingS == thingD | ||
550 | && precS == precD | ||
551 | && withVector (LA.vector [1..15]) sumV == sumElements (LA.fromList [1..15]) | ||
552 | |||
553 | info = do | ||
554 | print $ u | ||
555 | print $ v | ||
556 | print (eye :: Sq 3) | ||
557 | print $ ((u & 5) + 1) <ยท> v | ||
558 | print (tm :: L 2 5) | ||
559 | print (tm <> sm :: L 2 3) | ||
560 | print thingS | ||
561 | print thingD | ||
562 | print precS | ||
563 | print precD | ||
564 | print $ withVector (LA.vector [1..15]) sumV | ||
565 | splittest | ||
566 | |||
567 | sumV w = w <ยท> konst 1 | ||
568 | |||
569 | u = vec2 3 5 | ||
570 | |||
571 | ๐ง x = vector [x] :: R 1 | ||
572 | |||
573 | v = ๐ง 2 & 4 & 7 | ||
574 | |||
575 | tm :: GL | ||
576 | tm = lmat 0 [1..] | ||
577 | |||
578 | lmat :: forall m n . (KnownNat m, KnownNat n) => โ -> [โ] -> L m n | ||
579 | lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z | ||
580 | where | ||
581 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
582 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
583 | |||
584 | sm :: GSq | ||
585 | sm = lmat 0 [1..] | ||
586 | |||
587 | thingS = (u & 1) <ยท> tr q #> q #> v | ||
588 | where | ||
589 | q = tm :: L 10 3 | ||
590 | |||
591 | thingD = vjoin [ud1 u, 1] LA.<ยท> tr m LA.#> m LA.#> ud1 v | ||
592 | where | ||
593 | m = LA.matrix 3 [1..30] | ||
594 | |||
595 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <ยท> konst 2 #> v | ||
596 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<ยท> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v | ||
597 | |||
598 | |||
599 | splittest | ||
600 | = do | ||
601 | let v = range :: R 7 | ||
602 | a = snd (split v) :: R 4 | ||
603 | print $ a | ||
604 | print $ snd . headTail . snd . headTail $ v | ||
605 | print $ first (vec3 1 2 3) | ||
606 | print $ second (vec3 1 2 3) | ||
607 | print $ third (vec3 1 2 3) | ||
608 | print $ (snd $ splitRows eye :: L 4 6) | ||
609 | where | ||
610 | first v = fst . headTail $ v | ||
611 | second v = first . snd . headTail $ v | ||
612 | third v = first . snd . headTail . snd . headTail $ v | ||
613 | |||
614 | |||
615 | instance (KnownNat n', KnownNat m') => Testable (L n' m') | ||
616 | where | ||
617 | checkT _ = test | ||
618 | |||
619 | |||