diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-06-10 16:10:14 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-06-10 16:10:14 +0200 |
commit | 05e40db4fdc85b73f38ae5e105db0d523176debe (patch) | |
tree | ef0e87d3c7f6ca6d65cbaec6d6783db72df177f4 /packages/base/src/Numeric/HMatrix.hs | |
parent | a928a3a1713704cf3d5148bedc7ff8acb1347599 (diff) |
Domain class
Diffstat (limited to 'packages/base/src/Numeric/HMatrix.hs')
-rw-r--r-- | packages/base/src/Numeric/HMatrix.hs | 235 |
1 files changed, 167 insertions, 68 deletions
diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs index 421333a..34f4346 100644 --- a/packages/base/src/Numeric/HMatrix.hs +++ b/packages/base/src/Numeric/HMatrix.hs | |||
@@ -21,7 +21,7 @@ Copyright : (c) Alberto Ruiz 2006-14 | |||
21 | License : BSD3 | 21 | License : BSD3 |
22 | Stability : experimental | 22 | Stability : experimental |
23 | 23 | ||
24 | Experimental interface for real arrays with statically checked dimensions. | 24 | Experimental interface with statically checked dimensions. |
25 | 25 | ||
26 | -} | 26 | -} |
27 | 27 | ||
@@ -37,9 +37,11 @@ module Numeric.HMatrix( | |||
37 | unrow, uncol, | 37 | unrow, uncol, |
38 | 38 | ||
39 | eye, | 39 | eye, |
40 | diagR, diag, | 40 | diag, |
41 | blockAt, | 41 | blockAt, |
42 | matrix, | 42 | matrix, |
43 | -- * Complex | ||
44 | C, M, Her, her, 𝑖, | ||
43 | -- * Products | 45 | -- * Products |
44 | (<>),(#>),(<·>), | 46 | (<>),(#>),(<·>), |
45 | -- * Linear Systems | 47 | -- * Linear Systems |
@@ -48,11 +50,11 @@ module Numeric.HMatrix( | |||
48 | svd, svdTall, svdFlat, Eigen(..), | 50 | svd, svdTall, svdFlat, Eigen(..), |
49 | withNullspace, | 51 | withNullspace, |
50 | -- * Misc | 52 | -- * Misc |
51 | Disp(..), | 53 | mean, |
54 | Disp(..), Domain(..), | ||
52 | withVector, withMatrix, | 55 | withVector, withMatrix, |
53 | toRows, toColumns, | 56 | toRows, toColumns, |
54 | Sized(..), Diag(..), Sym, sym, | 57 | Sized(..), Diag(..), Sym, sym |
55 | module Numeric.LinearAlgebra.HMatrix | ||
56 | ) where | 58 | ) where |
57 | 59 | ||
58 | 60 | ||
@@ -125,15 +127,6 @@ ud2 (L (Dim (Dim x))) = x | |||
125 | 127 | ||
126 | 128 | ||
127 | -------------------------------------------------------------------------------- | 129 | -------------------------------------------------------------------------------- |
128 | -------------------------------------------------------------------------------- | ||
129 | |||
130 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n | ||
131 | diagR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) | ||
132 | where | ||
133 | ev = extract v | ||
134 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | ||
135 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
136 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
137 | 130 | ||
138 | diag :: KnownNat n => R n -> Sq n | 131 | diag :: KnownNat n => R n -> Sq n |
139 | diag = diagR 0 | 132 | diag = diagR 0 |
@@ -201,65 +194,37 @@ isKonst (unwrap -> x) | |||
201 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 194 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
202 | 195 | ||
203 | 196 | ||
197 | isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int)) | ||
198 | isKonstC (unwrap -> x) | ||
199 | | singleM x = Just (x `atIndex` (0,0), (m',n')) | ||
200 | | otherwise = Nothing | ||
201 | where | ||
202 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
203 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
204 | |||
204 | 205 | ||
205 | 206 | ||
206 | infixr 8 <> | 207 | infixr 8 <> |
207 | (<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n | 208 | (<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n |
209 | (<>) = mulR | ||
208 | 210 | ||
209 | (isKonst -> Just (a,(_,k))) <> (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k) | ||
210 | |||
211 | (isDiag -> Just (0,a,_)) <> (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) | ||
212 | where | ||
213 | v = a' * b' | ||
214 | n = min (size a) (size b) | ||
215 | a' = subVector 0 n a | ||
216 | b' = subVector 0 n b | ||
217 | |||
218 | (isDiag -> Just (0,a,_)) <> (extract -> b) = mkL (asColumn a * takeRows (size a) b) | ||
219 | |||
220 | (extract -> a) <> (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) | ||
221 | |||
222 | a <> b = mkL (extract a LA.<> extract b) | ||
223 | 211 | ||
224 | infixr 8 #> | 212 | infixr 8 #> |
225 | (#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m | 213 | (#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m |
226 | (isDiag -> Just (0, w, _)) #> v = mkR (w * subVector 0 (size w) (extract v)) | 214 | (#>) = appR |
227 | m #> v = mkR (extract m LA.#> extract v) | ||
228 | 215 | ||
229 | 216 | ||
230 | infixr 8 <·> | 217 | infixr 8 <·> |
231 | (<·>) :: R n -> R n -> ℝ | 218 | (<·>) :: R n -> R n -> ℝ |
232 | (ud1 -> u) <·> (ud1 -> v) | 219 | (<·>) = dotR |
233 | | singleV u || singleV v = sumElements (u * v) | ||
234 | | otherwise = udot u v | ||
235 | 220 | ||
236 | -------------------------------------------------------------------------------- | 221 | -------------------------------------------------------------------------------- |
237 | 222 | ||
238 | {- | ||
239 | class Minim (n :: Nat) (m :: Nat) | ||
240 | where | ||
241 | type Mini n m :: Nat | ||
242 | |||
243 | instance forall (n :: Nat) . Minim n n | ||
244 | where | ||
245 | type Mini n n = n | ||
246 | |||
247 | |||
248 | instance forall (n :: Nat) (m :: Nat) . (n <= m+1) => Minim n m | ||
249 | where | ||
250 | type Mini n m = n | ||
251 | |||
252 | instance forall (n :: Nat) (m :: Nat) . (m <= n+1) => Minim n m | ||
253 | where | ||
254 | type Mini n m = m | ||
255 | -} | ||
256 | |||
257 | class Diag m d | m -> d | 223 | class Diag m d | m -> d |
258 | where | 224 | where |
259 | takeDiag :: m -> d | 225 | takeDiag :: m -> d |
260 | 226 | ||
261 | 227 | ||
262 | |||
263 | instance forall n . (KnownNat n) => Diag (L n n) (R n) | 228 | instance forall n . (KnownNat n) => Diag (L n n) (R n) |
264 | where | 229 | where |
265 | takeDiag m = mkR (LA.takeDiag (extract m)) | 230 | takeDiag m = mkR (LA.takeDiag (extract m)) |
@@ -316,6 +281,15 @@ sym :: KnownNat n => Sq n -> Sym n | |||
316 | sym m = Sym $ (m + tr m)/2 | 281 | sym m = Sym $ (m + tr m)/2 |
317 | 282 | ||
318 | 283 | ||
284 | 𝑖 :: Sized ℂ s c => s | ||
285 | 𝑖 = konst iC | ||
286 | |||
287 | newtype Her n = Her (M n n) | ||
288 | |||
289 | her :: KnownNat n => M n n -> Her n | ||
290 | her m = Her $ (m + LA.tr m)/2 | ||
291 | |||
292 | |||
319 | 293 | ||
320 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) | 294 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) |
321 | where | 295 | where |
@@ -375,21 +349,6 @@ toColumns :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R m] | |||
375 | toColumns (LA.toColumns . extract -> vs) = map mkR vs | 349 | toColumns (LA.toColumns . extract -> vs) = map mkR vs |
376 | 350 | ||
377 | 351 | ||
378 | splittest | ||
379 | = do | ||
380 | let v = range :: R 7 | ||
381 | a = snd (split v) :: R 4 | ||
382 | print $ a | ||
383 | print $ snd . headTail . snd . headTail $ v | ||
384 | print $ first (vec3 1 2 3) | ||
385 | print $ second (vec3 1 2 3) | ||
386 | print $ third (vec3 1 2 3) | ||
387 | print $ (snd $ splitRows eye :: L 4 6) | ||
388 | where | ||
389 | first v = fst . headTail $ v | ||
390 | second v = first . snd . headTail $ v | ||
391 | third v = first . snd . headTail . snd . headTail $ v | ||
392 | |||
393 | -------------------------------------------------------------------------------- | 352 | -------------------------------------------------------------------------------- |
394 | 353 | ||
395 | build | 354 | build |
@@ -428,9 +387,133 @@ withMatrix a f = | |||
428 | Just (SomeNat (_ :: Proxy n)) -> | 387 | Just (SomeNat (_ :: Proxy n)) -> |
429 | f (mkL a :: L m n) | 388 | f (mkL a :: L m n) |
430 | 389 | ||
390 | -------------------------------------------------------------------------------- | ||
391 | |||
392 | class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat vec | ||
393 | where | ||
394 | mul :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => mat m k -> mat k n -> mat m n | ||
395 | app :: forall m n . (KnownNat m, KnownNat n) => mat m n -> vec n -> vec m | ||
396 | dot :: forall n . (KnownNat n) => vec n -> vec n -> field | ||
397 | cross :: vec 3 -> vec 3 -> vec 3 | ||
398 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n | ||
399 | |||
400 | |||
401 | instance Domain ℝ R L | ||
402 | where | ||
403 | mul = mulR | ||
404 | app = appR | ||
405 | dot = dotR | ||
406 | cross = crossR | ||
407 | diagR = diagRectR | ||
408 | |||
409 | instance Domain ℂ C M | ||
410 | where | ||
411 | mul = mulC | ||
412 | app = appC | ||
413 | dot = dotC | ||
414 | cross = crossC | ||
415 | diagR = diagRectC | ||
416 | |||
417 | -------------------------------------------------------------------------------- | ||
418 | |||
419 | mulR :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n | ||
420 | |||
421 | mulR (isKonst -> Just (a,(_,k))) (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k) | ||
422 | |||
423 | mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) | ||
424 | where | ||
425 | v = a' * b' | ||
426 | n = min (size a) (size b) | ||
427 | a' = subVector 0 n a | ||
428 | b' = subVector 0 n b | ||
429 | |||
430 | mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (size a) b) | ||
431 | |||
432 | mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) | ||
433 | |||
434 | mulR a b = mkL (extract a LA.<> extract b) | ||
435 | |||
436 | |||
437 | appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m | ||
438 | appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (size w) (extract v)) | ||
439 | appR m v = mkR (extract m LA.#> extract v) | ||
440 | |||
441 | |||
442 | dotR :: R n -> R n -> ℝ | ||
443 | dotR (ud1 -> u) (ud1 -> v) | ||
444 | | singleV u || singleV v = sumElements (u * v) | ||
445 | | otherwise = udot u v | ||
446 | |||
447 | |||
448 | crossR :: R 3 -> R 3 -> R 3 | ||
449 | crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3 | ||
450 | where | ||
451 | z1 = x!1*y!2-x!2*y!1 | ||
452 | z2 = x!2*y!0-x!0*y!2 | ||
453 | z3 = x!0*y!1-x!1*y!0 | ||
454 | |||
455 | -------------------------------------------------------------------------------- | ||
456 | |||
457 | mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n | ||
458 | |||
459 | mulC (isKonstC -> Just (a,(_,k))) (isKonstC -> Just (b,_)) = konst (a * b * fromIntegral k) | ||
460 | |||
461 | mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k) | ||
462 | where | ||
463 | v = a' * b' | ||
464 | n = min (size a) (size b) | ||
465 | a' = subVector 0 n a | ||
466 | b' = subVector 0 n b | ||
467 | |||
468 | mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (size a) b) | ||
469 | |||
470 | mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (size b) a * asRow b) | ||
471 | |||
472 | mulC a b = mkM (extract a LA.<> extract b) | ||
473 | |||
474 | |||
475 | appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m | ||
476 | appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (size w) (extract v)) | ||
477 | appC m v = mkC (extract m LA.#> extract v) | ||
478 | |||
479 | |||
480 | dotC :: KnownNat n => C n -> C n -> ℂ | ||
481 | dotC (unwrap -> u) (unwrap -> v) | ||
482 | | singleV u || singleV v = sumElements (conj u * v) | ||
483 | | otherwise = u LA.<·> v | ||
484 | |||
485 | |||
486 | crossC :: C 3 -> C 3 -> C 3 | ||
487 | crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) | ||
488 | where | ||
489 | z1 = x!1*y!2-x!2*y!1 | ||
490 | z2 = x!2*y!0-x!0*y!2 | ||
491 | z3 = x!0*y!1-x!1*y!0 | ||
492 | |||
493 | -------------------------------------------------------------------------------- | ||
494 | |||
495 | diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n | ||
496 | diagRectR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) | ||
497 | where | ||
498 | ev = extract v | ||
499 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | ||
500 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
501 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
502 | |||
503 | |||
504 | diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n | ||
505 | diagRectC x v = mkM (asRow (vjoin [scalar x, ev, zeros])) | ||
506 | where | ||
507 | ev = extract v | ||
508 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | ||
509 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
510 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
431 | 511 | ||
432 | -------------------------------------------------------------------------------- | 512 | -------------------------------------------------------------------------------- |
433 | 513 | ||
514 | mean :: (KnownNat n, 1<=n) => R n -> ℝ | ||
515 | mean v = v <·> (1/dim) | ||
516 | |||
434 | test :: (Bool, IO ()) | 517 | test :: (Bool, IO ()) |
435 | test = (ok,info) | 518 | test = (ok,info) |
436 | where | 519 | where |
@@ -490,6 +573,22 @@ test = (ok,info) | |||
490 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v | 573 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v |
491 | 574 | ||
492 | 575 | ||
576 | splittest | ||
577 | = do | ||
578 | let v = range :: R 7 | ||
579 | a = snd (split v) :: R 4 | ||
580 | print $ a | ||
581 | print $ snd . headTail . snd . headTail $ v | ||
582 | print $ first (vec3 1 2 3) | ||
583 | print $ second (vec3 1 2 3) | ||
584 | print $ third (vec3 1 2 3) | ||
585 | print $ (snd $ splitRows eye :: L 4 6) | ||
586 | where | ||
587 | first v = fst . headTail $ v | ||
588 | second v = first . snd . headTail $ v | ||
589 | third v = first . snd . headTail . snd . headTail $ v | ||
590 | |||
591 | |||
493 | instance (KnownNat n', KnownNat m') => Testable (L n' m') | 592 | instance (KnownNat n', KnownNat m') => Testable (L n' m') |
494 | where | 593 | where |
495 | checkT _ = test | 594 | checkT _ = test |