summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/HMatrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/HMatrix.hs')
-rw-r--r--packages/base/src/Numeric/HMatrix.hs235
1 files changed, 167 insertions, 68 deletions
diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs
index 421333a..34f4346 100644
--- a/packages/base/src/Numeric/HMatrix.hs
+++ b/packages/base/src/Numeric/HMatrix.hs
@@ -21,7 +21,7 @@ Copyright : (c) Alberto Ruiz 2006-14
21License : BSD3 21License : BSD3
22Stability : experimental 22Stability : experimental
23 23
24Experimental interface for real arrays with statically checked dimensions. 24Experimental interface with statically checked dimensions.
25 25
26-} 26-}
27 27
@@ -37,9 +37,11 @@ module Numeric.HMatrix(
37 unrow, uncol, 37 unrow, uncol,
38 38
39 eye, 39 eye,
40 diagR, diag, 40 diag,
41 blockAt, 41 blockAt,
42 matrix, 42 matrix,
43 -- * Complex
44 C, M, Her, her, 𝑖,
43 -- * Products 45 -- * Products
44 (<>),(#>),(<·>), 46 (<>),(#>),(<·>),
45 -- * Linear Systems 47 -- * Linear Systems
@@ -48,11 +50,11 @@ module Numeric.HMatrix(
48 svd, svdTall, svdFlat, Eigen(..), 50 svd, svdTall, svdFlat, Eigen(..),
49 withNullspace, 51 withNullspace,
50 -- * Misc 52 -- * Misc
51 Disp(..), 53 mean,
54 Disp(..), Domain(..),
52 withVector, withMatrix, 55 withVector, withMatrix,
53 toRows, toColumns, 56 toRows, toColumns,
54 Sized(..), Diag(..), Sym, sym, 57 Sized(..), Diag(..), Sym, sym
55 module Numeric.LinearAlgebra.HMatrix
56) where 58) where
57 59
58 60
@@ -125,15 +127,6 @@ ud2 (L (Dim (Dim x))) = x
125 127
126 128
127-------------------------------------------------------------------------------- 129--------------------------------------------------------------------------------
128--------------------------------------------------------------------------------
129
130diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n
131diagR x v = mkL (asRow (vjoin [scalar x, ev, zeros]))
132 where
133 ev = extract v
134 zeros = LA.konst x (max 0 ((min m' n') - size ev))
135 m' = fromIntegral . natVal $ (undefined :: Proxy m)
136 n' = fromIntegral . natVal $ (undefined :: Proxy n)
137 130
138diag :: KnownNat n => R n -> Sq n 131diag :: KnownNat n => R n -> Sq n
139diag = diagR 0 132diag = diagR 0
@@ -201,65 +194,37 @@ isKonst (unwrap -> x)
201 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int 194 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
202 195
203 196
197isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int))
198isKonstC (unwrap -> x)
199 | singleM x = Just (x `atIndex` (0,0), (m',n'))
200 | otherwise = Nothing
201 where
202 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
203 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
204
204 205
205 206
206infixr 8 <> 207infixr 8 <>
207(<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n 208(<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n
209(<>) = mulR
208 210
209(isKonst -> Just (a,(_,k))) <> (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k)
210
211(isDiag -> Just (0,a,_)) <> (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k)
212 where
213 v = a' * b'
214 n = min (size a) (size b)
215 a' = subVector 0 n a
216 b' = subVector 0 n b
217
218(isDiag -> Just (0,a,_)) <> (extract -> b) = mkL (asColumn a * takeRows (size a) b)
219
220(extract -> a) <> (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b)
221
222a <> b = mkL (extract a LA.<> extract b)
223 211
224infixr 8 #> 212infixr 8 #>
225(#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m 213(#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m
226(isDiag -> Just (0, w, _)) #> v = mkR (w * subVector 0 (size w) (extract v)) 214(#>) = appR
227m #> v = mkR (extract m LA.#> extract v)
228 215
229 216
230infixr 8 <·> 217infixr 8 <·>
231(<·>) :: R n -> R n -> ℝ 218(<·>) :: R n -> R n -> ℝ
232(ud1 -> u) <·> (ud1 -> v) 219(<·>) = dotR
233 | singleV u || singleV v = sumElements (u * v)
234 | otherwise = udot u v
235 220
236-------------------------------------------------------------------------------- 221--------------------------------------------------------------------------------
237 222
238{-
239class Minim (n :: Nat) (m :: Nat)
240 where
241 type Mini n m :: Nat
242
243instance forall (n :: Nat) . Minim n n
244 where
245 type Mini n n = n
246
247
248instance forall (n :: Nat) (m :: Nat) . (n <= m+1) => Minim n m
249 where
250 type Mini n m = n
251
252instance forall (n :: Nat) (m :: Nat) . (m <= n+1) => Minim n m
253 where
254 type Mini n m = m
255-}
256
257class Diag m d | m -> d 223class Diag m d | m -> d
258 where 224 where
259 takeDiag :: m -> d 225 takeDiag :: m -> d
260 226
261 227
262
263instance forall n . (KnownNat n) => Diag (L n n) (R n) 228instance forall n . (KnownNat n) => Diag (L n n) (R n)
264 where 229 where
265 takeDiag m = mkR (LA.takeDiag (extract m)) 230 takeDiag m = mkR (LA.takeDiag (extract m))
@@ -316,6 +281,15 @@ sym :: KnownNat n => Sq n -> Sym n
316sym m = Sym $ (m + tr m)/2 281sym m = Sym $ (m + tr m)/2
317 282
318 283
284𝑖 :: Sized ℂ s c => s
285𝑖 = konst iC
286
287newtype Her n = Her (M n n)
288
289her :: KnownNat n => M n n -> Her n
290her m = Her $ (m + LA.tr m)/2
291
292
319 293
320instance KnownNat n => Eigen (Sym n) (R n) (L n n) 294instance KnownNat n => Eigen (Sym n) (R n) (L n n)
321 where 295 where
@@ -375,21 +349,6 @@ toColumns :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R m]
375toColumns (LA.toColumns . extract -> vs) = map mkR vs 349toColumns (LA.toColumns . extract -> vs) = map mkR vs
376 350
377 351
378splittest
379 = do
380 let v = range :: R 7
381 a = snd (split v) :: R 4
382 print $ a
383 print $ snd . headTail . snd . headTail $ v
384 print $ first (vec3 1 2 3)
385 print $ second (vec3 1 2 3)
386 print $ third (vec3 1 2 3)
387 print $ (snd $ splitRows eye :: L 4 6)
388 where
389 first v = fst . headTail $ v
390 second v = first . snd . headTail $ v
391 third v = first . snd . headTail . snd . headTail $ v
392
393-------------------------------------------------------------------------------- 352--------------------------------------------------------------------------------
394 353
395build 354build
@@ -428,9 +387,133 @@ withMatrix a f =
428 Just (SomeNat (_ :: Proxy n)) -> 387 Just (SomeNat (_ :: Proxy n)) ->
429 f (mkL a :: L m n) 388 f (mkL a :: L m n)
430 389
390--------------------------------------------------------------------------------
391
392class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat vec
393 where
394 mul :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => mat m k -> mat k n -> mat m n
395 app :: forall m n . (KnownNat m, KnownNat n) => mat m n -> vec n -> vec m
396 dot :: forall n . (KnownNat n) => vec n -> vec n -> field
397 cross :: vec 3 -> vec 3 -> vec 3
398 diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n
399
400
401instance Domain ℝ R L
402 where
403 mul = mulR
404 app = appR
405 dot = dotR
406 cross = crossR
407 diagR = diagRectR
408
409instance Domain ℂ C M
410 where
411 mul = mulC
412 app = appC
413 dot = dotC
414 cross = crossC
415 diagR = diagRectC
416
417--------------------------------------------------------------------------------
418
419mulR :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n
420
421mulR (isKonst -> Just (a,(_,k))) (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k)
422
423mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k)
424 where
425 v = a' * b'
426 n = min (size a) (size b)
427 a' = subVector 0 n a
428 b' = subVector 0 n b
429
430mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (size a) b)
431
432mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b)
433
434mulR a b = mkL (extract a LA.<> extract b)
435
436
437appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m
438appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (size w) (extract v))
439appR m v = mkR (extract m LA.#> extract v)
440
441
442dotR :: R n -> R n -> ℝ
443dotR (ud1 -> u) (ud1 -> v)
444 | singleV u || singleV v = sumElements (u * v)
445 | otherwise = udot u v
446
447
448crossR :: R 3 -> R 3 -> R 3
449crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3
450 where
451 z1 = x!1*y!2-x!2*y!1
452 z2 = x!2*y!0-x!0*y!2
453 z3 = x!0*y!1-x!1*y!0
454
455--------------------------------------------------------------------------------
456
457mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n
458
459mulC (isKonstC -> Just (a,(_,k))) (isKonstC -> Just (b,_)) = konst (a * b * fromIntegral k)
460
461mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k)
462 where
463 v = a' * b'
464 n = min (size a) (size b)
465 a' = subVector 0 n a
466 b' = subVector 0 n b
467
468mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (size a) b)
469
470mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (size b) a * asRow b)
471
472mulC a b = mkM (extract a LA.<> extract b)
473
474
475appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m
476appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (size w) (extract v))
477appC m v = mkC (extract m LA.#> extract v)
478
479
480dotC :: KnownNat n => C n -> C n -> ℂ
481dotC (unwrap -> u) (unwrap -> v)
482 | singleV u || singleV v = sumElements (conj u * v)
483 | otherwise = u LA.<·> v
484
485
486crossC :: C 3 -> C 3 -> C 3
487crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3])
488 where
489 z1 = x!1*y!2-x!2*y!1
490 z2 = x!2*y!0-x!0*y!2
491 z3 = x!0*y!1-x!1*y!0
492
493--------------------------------------------------------------------------------
494
495diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n
496diagRectR x v = mkL (asRow (vjoin [scalar x, ev, zeros]))
497 where
498 ev = extract v
499 zeros = LA.konst x (max 0 ((min m' n') - size ev))
500 m' = fromIntegral . natVal $ (undefined :: Proxy m)
501 n' = fromIntegral . natVal $ (undefined :: Proxy n)
502
503
504diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n
505diagRectC x v = mkM (asRow (vjoin [scalar x, ev, zeros]))
506 where
507 ev = extract v
508 zeros = LA.konst x (max 0 ((min m' n') - size ev))
509 m' = fromIntegral . natVal $ (undefined :: Proxy m)
510 n' = fromIntegral . natVal $ (undefined :: Proxy n)
431 511
432-------------------------------------------------------------------------------- 512--------------------------------------------------------------------------------
433 513
514mean :: (KnownNat n, 1<=n) => R n -> ℝ
515mean v = v <·> (1/dim)
516
434test :: (Bool, IO ()) 517test :: (Bool, IO ())
435test = (ok,info) 518test = (ok,info)
436 where 519 where
@@ -490,6 +573,22 @@ test = (ok,info)
490 precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v 573 precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v
491 574
492 575
576splittest
577 = do
578 let v = range :: R 7
579 a = snd (split v) :: R 4
580 print $ a
581 print $ snd . headTail . snd . headTail $ v
582 print $ first (vec3 1 2 3)
583 print $ second (vec3 1 2 3)
584 print $ third (vec3 1 2 3)
585 print $ (snd $ splitRows eye :: L 4 6)
586 where
587 first v = fst . headTail $ v
588 second v = first . snd . headTail $ v
589 third v = first . snd . headTail . snd . headTail $ v
590
591
493instance (KnownNat n', KnownNat m') => Testable (L n' m') 592instance (KnownNat n', KnownNat m') => Testable (L n' m')
494 where 593 where
495 checkT _ = test 594 checkT _ = test