diff options
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 148 |
1 files changed, 132 insertions, 16 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index bd0e593..9a2bdc8 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -49,14 +49,17 @@ module Numeric.LinearAlgebra.Static( | |||
49 | linSolve, (<\>), | 49 | linSolve, (<\>), |
50 | -- * Factorizations | 50 | -- * Factorizations |
51 | svd, withCompactSVD, svdTall, svdFlat, Eigen(..), | 51 | svd, withCompactSVD, svdTall, svdFlat, Eigen(..), |
52 | withNullspace, qr, chol, | 52 | withNullspace, withOrth, qr, chol, |
53 | -- * Norms | 53 | -- * Norms |
54 | Normed(..), | 54 | Normed(..), |
55 | -- * Random arrays | ||
56 | Seed, RandDist(..), | ||
57 | randomVector, rand, randn, gaussianSample, uniformSample, | ||
55 | -- * Misc | 58 | -- * Misc |
56 | mean, | 59 | mean, meanCov, |
57 | Disp(..), Domain(..), | 60 | Disp(..), Domain(..), |
58 | withVector, withMatrix, | 61 | withVector, withMatrix, exactLength, exactDims, |
59 | toRows, toColumns, | 62 | toRows, toColumns, withRows, withColumns, |
60 | Sized(..), Diag(..), Sym, sym, mTm, unSym, (<·>) | 63 | Sized(..), Diag(..), Sym, sym, mTm, unSym, (<·>) |
61 | ) where | 64 | ) where |
62 | 65 | ||
@@ -67,9 +70,11 @@ import Numeric.LinearAlgebra hiding ( | |||
67 | row,col,vector,matrix,linspace,toRows,toColumns, | 70 | row,col,vector,matrix,linspace,toRows,toColumns, |
68 | (<\>),fromList,takeDiag,svd,eig,eigSH, | 71 | (<\>),fromList,takeDiag,svd,eig,eigSH, |
69 | eigenvalues,eigenvaluesSH,build, | 72 | eigenvalues,eigenvaluesSH,build, |
70 | qr,size,dot,chol,range,R,C,sym,mTm,unSym) | 73 | qr,size,dot,chol,range,R,C,sym,mTm,unSym, |
74 | randomVector,rand,randn,gaussianSample,uniformSample,meanCov) | ||
71 | import qualified Numeric.LinearAlgebra as LA | 75 | import qualified Numeric.LinearAlgebra as LA |
72 | import Data.Proxy(Proxy) | 76 | import qualified Numeric.LinearAlgebra.Devel as LA |
77 | import Data.Proxy(Proxy(..)) | ||
73 | import Internal.Static | 78 | import Internal.Static |
74 | import Control.Arrow((***)) | 79 | import Control.Arrow((***)) |
75 | 80 | ||
@@ -321,6 +326,15 @@ withNullspace (LA.nullspace . extract -> a) f = | |||
321 | Nothing -> error "static/dynamic mismatch" | 326 | Nothing -> error "static/dynamic mismatch" |
322 | Just (SomeNat (_ :: Proxy k)) -> f (mkL a :: L n k) | 327 | Just (SomeNat (_ :: Proxy k)) -> f (mkL a :: L n k) |
323 | 328 | ||
329 | withOrth | ||
330 | :: forall m n z . (KnownNat m, KnownNat n) | ||
331 | => L m n | ||
332 | -> (forall k. (KnownNat k) => L n k -> z) | ||
333 | -> z | ||
334 | withOrth (LA.orth . extract -> a) f = | ||
335 | case someNatVal $ fromIntegral $ cols a of | ||
336 | Nothing -> error "static/dynamic mismatch" | ||
337 | Just (SomeNat (_ :: Proxy k)) -> f (mkL a :: L n k) | ||
324 | 338 | ||
325 | withCompactSVD | 339 | withCompactSVD |
326 | :: forall m n z . (KnownNat m, KnownNat n) | 340 | :: forall m n z . (KnownNat m, KnownNat n) |
@@ -367,10 +381,30 @@ splitCols = (tr *** tr) . splitRows . tr | |||
367 | toRows :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R n] | 381 | toRows :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R n] |
368 | toRows (LA.toRows . extract -> vs) = map mkR vs | 382 | toRows (LA.toRows . extract -> vs) = map mkR vs |
369 | 383 | ||
384 | withRows | ||
385 | :: forall n z . KnownNat n | ||
386 | => [R n] | ||
387 | -> (forall m . KnownNat m => L m n -> z) | ||
388 | -> z | ||
389 | withRows (LA.fromRows . map extract -> m) f = | ||
390 | case someNatVal $ fromIntegral $ LA.rows m of | ||
391 | Nothing -> error "static/dynamic mismatch" | ||
392 | Just (SomeNat (_ :: Proxy m)) -> f (mkL m :: L m n) | ||
370 | 393 | ||
371 | toColumns :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R m] | 394 | toColumns :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R m] |
372 | toColumns (LA.toColumns . extract -> vs) = map mkR vs | 395 | toColumns (LA.toColumns . extract -> vs) = map mkR vs |
373 | 396 | ||
397 | withColumns | ||
398 | :: forall m z . KnownNat m | ||
399 | => [R m] | ||
400 | -> (forall n . KnownNat n => L m n -> z) | ||
401 | -> z | ||
402 | withColumns (LA.fromColumns . map extract -> m) f = | ||
403 | case someNatVal $ fromIntegral $ LA.cols m of | ||
404 | Nothing -> error "static/dynamic mismatch" | ||
405 | Just (SomeNat (_ :: Proxy n)) -> f (mkL m :: L m n) | ||
406 | |||
407 | |||
374 | 408 | ||
375 | -------------------------------------------------------------------------------- | 409 | -------------------------------------------------------------------------------- |
376 | 410 | ||
@@ -394,6 +428,17 @@ withVector v f = | |||
394 | Nothing -> error "static/dynamic mismatch" | 428 | Nothing -> error "static/dynamic mismatch" |
395 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) | 429 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) |
396 | 430 | ||
431 | -- | Useful for constraining two dependently typed vectors to match each | ||
432 | -- other in length when they are unknown at compile-time. | ||
433 | exactLength | ||
434 | :: forall n m . (KnownNat n, KnownNat m) | ||
435 | => R m | ||
436 | -> Maybe (R n) | ||
437 | exactLength v | ||
438 | | natVal (Proxy :: Proxy n) == natVal (Proxy :: Proxy m) | ||
439 | = Just (mkR (unwrap v)) | ||
440 | | otherwise | ||
441 | = Nothing | ||
397 | 442 | ||
398 | withMatrix | 443 | withMatrix |
399 | :: forall z | 444 | :: forall z |
@@ -409,6 +454,66 @@ withMatrix a f = | |||
409 | Just (SomeNat (_ :: Proxy n)) -> | 454 | Just (SomeNat (_ :: Proxy n)) -> |
410 | f (mkL a :: L m n) | 455 | f (mkL a :: L m n) |
411 | 456 | ||
457 | -- | Useful for constraining two dependently typed matrices to match each | ||
458 | -- other in dimensions when they are unknown at compile-time. | ||
459 | exactDims | ||
460 | :: forall n m j k . (KnownNat n, KnownNat m, KnownNat j, KnownNat k) | ||
461 | => L m n | ||
462 | -> Maybe (L j k) | ||
463 | exactDims m | ||
464 | | natVal (Proxy :: Proxy m) == natVal (Proxy :: Proxy j) | ||
465 | && natVal (Proxy :: Proxy n) == natVal (Proxy :: Proxy k) | ||
466 | = Just (mkL (unwrap m)) | ||
467 | | otherwise | ||
468 | = Nothing | ||
469 | |||
470 | randomVector | ||
471 | :: forall n . KnownNat n | ||
472 | => Seed | ||
473 | -> RandDist | ||
474 | -> R n | ||
475 | randomVector s d = mkR (LA.randomVector s d | ||
476 | (fromInteger (natVal (Proxy :: Proxy n))) | ||
477 | ) | ||
478 | |||
479 | rand | ||
480 | :: forall m n . (KnownNat m, KnownNat n) | ||
481 | => IO (L m n) | ||
482 | rand = mkL <$> LA.rand (fromInteger (natVal (Proxy :: Proxy m))) | ||
483 | (fromInteger (natVal (Proxy :: Proxy n))) | ||
484 | |||
485 | randn | ||
486 | :: forall m n . (KnownNat m, KnownNat n) | ||
487 | => IO (L m n) | ||
488 | randn = mkL <$> LA.randn (fromInteger (natVal (Proxy :: Proxy m))) | ||
489 | (fromInteger (natVal (Proxy :: Proxy n))) | ||
490 | |||
491 | gaussianSample | ||
492 | :: forall m n . (KnownNat m, KnownNat n) | ||
493 | => Seed | ||
494 | -> R n | ||
495 | -> Sym n | ||
496 | -> L m n | ||
497 | gaussianSample s (extract -> mu) (Sym (extract -> sigma)) = | ||
498 | mkL $ LA.gaussianSample s (fromInteger (natVal (Proxy :: Proxy m))) | ||
499 | mu (LA.trustSym sigma) | ||
500 | |||
501 | uniformSample | ||
502 | :: forall m n . (KnownNat m, KnownNat n) | ||
503 | => Seed | ||
504 | -> R n -- ^ minimums of each row | ||
505 | -> R n -- ^ maximums of each row | ||
506 | -> L m n | ||
507 | uniformSample s (extract -> mins) (extract -> maxs) = | ||
508 | mkL $ LA.uniformSample s (fromInteger (natVal (Proxy :: Proxy m))) | ||
509 | (zip (LA.toList mins) (LA.toList maxs)) | ||
510 | |||
511 | meanCov | ||
512 | :: forall m n . (KnownNat m, KnownNat n) | ||
513 | => L m n | ||
514 | -> (R n, Sym n) | ||
515 | meanCov (extract -> vs) = mkR *** (Sym . mkL . LA.unSym) $ LA.meanCov vs | ||
516 | |||
412 | -------------------------------------------------------------------------------- | 517 | -------------------------------------------------------------------------------- |
413 | 518 | ||
414 | class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat vec | 519 | class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat vec |
@@ -418,9 +523,10 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve | |||
418 | dot :: forall n . (KnownNat n) => vec n -> vec n -> field | 523 | dot :: forall n . (KnownNat n) => vec n -> vec n -> field |
419 | cross :: vec 3 -> vec 3 -> vec 3 | 524 | cross :: vec 3 -> vec 3 -> vec 3 |
420 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n | 525 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n |
421 | dvmap :: forall n. (field -> field) -> vec n -> vec n | 526 | dvmap :: forall n. KnownNat n => (field -> field) -> vec n -> vec n |
422 | dmmap :: forall n m. (field -> field) -> mat n m -> mat n m | 527 | dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m |
423 | outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m | 528 | outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m |
529 | zipWith :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n | ||
424 | 530 | ||
425 | 531 | ||
426 | instance Domain ℝ R L | 532 | instance Domain ℝ R L |
@@ -433,6 +539,7 @@ instance Domain ℝ R L | |||
433 | dvmap = mapR | 539 | dvmap = mapR |
434 | dmmap = mapL | 540 | dmmap = mapL |
435 | outer = outerR | 541 | outer = outerR |
542 | zipWith = zipWithR | ||
436 | 543 | ||
437 | instance Domain ℂ C M | 544 | instance Domain ℂ C M |
438 | where | 545 | where |
@@ -444,6 +551,7 @@ instance Domain ℂ C M | |||
444 | dvmap = mapC | 551 | dvmap = mapC |
445 | dmmap = mapM' | 552 | dmmap = mapM' |
446 | outer = outerC | 553 | outer = outerC |
554 | zipWith = zipWithC | ||
447 | 555 | ||
448 | -------------------------------------------------------------------------------- | 556 | -------------------------------------------------------------------------------- |
449 | 557 | ||
@@ -486,11 +594,15 @@ crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3 | |||
486 | outerR :: (KnownNat m, KnownNat n) => R n -> R m -> L n m | 594 | outerR :: (KnownNat m, KnownNat n) => R n -> R m -> L n m |
487 | outerR (extract -> x) (extract -> y) = mkL (LA.outer x y) | 595 | outerR (extract -> x) (extract -> y) = mkL (LA.outer x y) |
488 | 596 | ||
489 | mapR :: (ℝ -> ℝ) -> R n -> R n | 597 | mapR :: KnownNat n => (ℝ -> ℝ) -> R n -> R n |
490 | mapR f (R (Dim v)) = R (Dim (LA.cmap f v)) | 598 | mapR f (unwrap -> v) = mkR (LA.cmap f v) |
599 | |||
600 | zipWithR :: KnownNat n => (ℝ -> ℝ -> ℝ) -> R n -> R n -> R n | ||
601 | zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y) | ||
602 | |||
603 | mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m | ||
604 | mapL f (unwrap -> m) = mkL (LA.cmap f m) | ||
491 | 605 | ||
492 | mapM' :: (ℂ -> ℂ) -> M n m -> M n m | ||
493 | mapM' f (M (Dim (Dim m))) = M (Dim (Dim (LA.cmap f m))) | ||
494 | 606 | ||
495 | -------------------------------------------------------------------------------- | 607 | -------------------------------------------------------------------------------- |
496 | 608 | ||
@@ -533,11 +645,15 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) | |||
533 | outerC :: (KnownNat m, KnownNat n) => C n -> C m -> M n m | 645 | outerC :: (KnownNat m, KnownNat n) => C n -> C m -> M n m |
534 | outerC (extract -> x) (extract -> y) = mkM (LA.outer x y) | 646 | outerC (extract -> x) (extract -> y) = mkM (LA.outer x y) |
535 | 647 | ||
536 | mapC :: (ℂ -> ℂ) -> C n -> C n | 648 | mapC :: KnownNat n => (ℂ -> ℂ) -> C n -> C n |
537 | mapC f (C (Dim v)) = C (Dim (LA.cmap f v)) | 649 | mapC f (unwrap -> v) = mkC (LA.cmap f v) |
650 | |||
651 | zipWithC :: KnownNat n => (ℂ -> ℂ -> ℂ) -> C n -> C n -> C n | ||
652 | zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y) | ||
653 | |||
654 | mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m | ||
655 | mapM' f (unwrap -> m) = mkM (LA.cmap f m) | ||
538 | 656 | ||
539 | mapL :: (ℝ -> ℝ) -> L n m -> L n m | ||
540 | mapL f (L (Dim (Dim m))) = L (Dim (Dim (LA.cmap f m))) | ||
541 | 657 | ||
542 | -------------------------------------------------------------------------------- | 658 | -------------------------------------------------------------------------------- |
543 | 659 | ||