summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs148
1 files changed, 132 insertions, 16 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index bd0e593..9a2bdc8 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -49,14 +49,17 @@ module Numeric.LinearAlgebra.Static(
49 linSolve, (<\>), 49 linSolve, (<\>),
50 -- * Factorizations 50 -- * Factorizations
51 svd, withCompactSVD, svdTall, svdFlat, Eigen(..), 51 svd, withCompactSVD, svdTall, svdFlat, Eigen(..),
52 withNullspace, qr, chol, 52 withNullspace, withOrth, qr, chol,
53 -- * Norms 53 -- * Norms
54 Normed(..), 54 Normed(..),
55 -- * Random arrays
56 Seed, RandDist(..),
57 randomVector, rand, randn, gaussianSample, uniformSample,
55 -- * Misc 58 -- * Misc
56 mean, 59 mean, meanCov,
57 Disp(..), Domain(..), 60 Disp(..), Domain(..),
58 withVector, withMatrix, 61 withVector, withMatrix, exactLength, exactDims,
59 toRows, toColumns, 62 toRows, toColumns, withRows, withColumns,
60 Sized(..), Diag(..), Sym, sym, mTm, unSym, (<·>) 63 Sized(..), Diag(..), Sym, sym, mTm, unSym, (<·>)
61) where 64) where
62 65
@@ -67,9 +70,11 @@ import Numeric.LinearAlgebra hiding (
67 row,col,vector,matrix,linspace,toRows,toColumns, 70 row,col,vector,matrix,linspace,toRows,toColumns,
68 (<\>),fromList,takeDiag,svd,eig,eigSH, 71 (<\>),fromList,takeDiag,svd,eig,eigSH,
69 eigenvalues,eigenvaluesSH,build, 72 eigenvalues,eigenvaluesSH,build,
70 qr,size,dot,chol,range,R,C,sym,mTm,unSym) 73 qr,size,dot,chol,range,R,C,sym,mTm,unSym,
74 randomVector,rand,randn,gaussianSample,uniformSample,meanCov)
71import qualified Numeric.LinearAlgebra as LA 75import qualified Numeric.LinearAlgebra as LA
72import Data.Proxy(Proxy) 76import qualified Numeric.LinearAlgebra.Devel as LA
77import Data.Proxy(Proxy(..))
73import Internal.Static 78import Internal.Static
74import Control.Arrow((***)) 79import Control.Arrow((***))
75 80
@@ -321,6 +326,15 @@ withNullspace (LA.nullspace . extract -> a) f =
321 Nothing -> error "static/dynamic mismatch" 326 Nothing -> error "static/dynamic mismatch"
322 Just (SomeNat (_ :: Proxy k)) -> f (mkL a :: L n k) 327 Just (SomeNat (_ :: Proxy k)) -> f (mkL a :: L n k)
323 328
329withOrth
330 :: forall m n z . (KnownNat m, KnownNat n)
331 => L m n
332 -> (forall k. (KnownNat k) => L n k -> z)
333 -> z
334withOrth (LA.orth . extract -> a) f =
335 case someNatVal $ fromIntegral $ cols a of
336 Nothing -> error "static/dynamic mismatch"
337 Just (SomeNat (_ :: Proxy k)) -> f (mkL a :: L n k)
324 338
325withCompactSVD 339withCompactSVD
326 :: forall m n z . (KnownNat m, KnownNat n) 340 :: forall m n z . (KnownNat m, KnownNat n)
@@ -367,10 +381,30 @@ splitCols = (tr *** tr) . splitRows . tr
367toRows :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R n] 381toRows :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R n]
368toRows (LA.toRows . extract -> vs) = map mkR vs 382toRows (LA.toRows . extract -> vs) = map mkR vs
369 383
384withRows
385 :: forall n z . KnownNat n
386 => [R n]
387 -> (forall m . KnownNat m => L m n -> z)
388 -> z
389withRows (LA.fromRows . map extract -> m) f =
390 case someNatVal $ fromIntegral $ LA.rows m of
391 Nothing -> error "static/dynamic mismatch"
392 Just (SomeNat (_ :: Proxy m)) -> f (mkL m :: L m n)
370 393
371toColumns :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R m] 394toColumns :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R m]
372toColumns (LA.toColumns . extract -> vs) = map mkR vs 395toColumns (LA.toColumns . extract -> vs) = map mkR vs
373 396
397withColumns
398 :: forall m z . KnownNat m
399 => [R m]
400 -> (forall n . KnownNat n => L m n -> z)
401 -> z
402withColumns (LA.fromColumns . map extract -> m) f =
403 case someNatVal $ fromIntegral $ LA.cols m of
404 Nothing -> error "static/dynamic mismatch"
405 Just (SomeNat (_ :: Proxy n)) -> f (mkL m :: L m n)
406
407
374 408
375-------------------------------------------------------------------------------- 409--------------------------------------------------------------------------------
376 410
@@ -394,6 +428,17 @@ withVector v f =
394 Nothing -> error "static/dynamic mismatch" 428 Nothing -> error "static/dynamic mismatch"
395 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) 429 Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m)
396 430
431-- | Useful for constraining two dependently typed vectors to match each
432-- other in length when they are unknown at compile-time.
433exactLength
434 :: forall n m . (KnownNat n, KnownNat m)
435 => R m
436 -> Maybe (R n)
437exactLength v
438 | natVal (Proxy :: Proxy n) == natVal (Proxy :: Proxy m)
439 = Just (mkR (unwrap v))
440 | otherwise
441 = Nothing
397 442
398withMatrix 443withMatrix
399 :: forall z 444 :: forall z
@@ -409,6 +454,66 @@ withMatrix a f =
409 Just (SomeNat (_ :: Proxy n)) -> 454 Just (SomeNat (_ :: Proxy n)) ->
410 f (mkL a :: L m n) 455 f (mkL a :: L m n)
411 456
457-- | Useful for constraining two dependently typed matrices to match each
458-- other in dimensions when they are unknown at compile-time.
459exactDims
460 :: forall n m j k . (KnownNat n, KnownNat m, KnownNat j, KnownNat k)
461 => L m n
462 -> Maybe (L j k)
463exactDims m
464 | natVal (Proxy :: Proxy m) == natVal (Proxy :: Proxy j)
465 && natVal (Proxy :: Proxy n) == natVal (Proxy :: Proxy k)
466 = Just (mkL (unwrap m))
467 | otherwise
468 = Nothing
469
470randomVector
471 :: forall n . KnownNat n
472 => Seed
473 -> RandDist
474 -> R n
475randomVector s d = mkR (LA.randomVector s d
476 (fromInteger (natVal (Proxy :: Proxy n)))
477 )
478
479rand
480 :: forall m n . (KnownNat m, KnownNat n)
481 => IO (L m n)
482rand = mkL <$> LA.rand (fromInteger (natVal (Proxy :: Proxy m)))
483 (fromInteger (natVal (Proxy :: Proxy n)))
484
485randn
486 :: forall m n . (KnownNat m, KnownNat n)
487 => IO (L m n)
488randn = mkL <$> LA.randn (fromInteger (natVal (Proxy :: Proxy m)))
489 (fromInteger (natVal (Proxy :: Proxy n)))
490
491gaussianSample
492 :: forall m n . (KnownNat m, KnownNat n)
493 => Seed
494 -> R n
495 -> Sym n
496 -> L m n
497gaussianSample s (extract -> mu) (Sym (extract -> sigma)) =
498 mkL $ LA.gaussianSample s (fromInteger (natVal (Proxy :: Proxy m)))
499 mu (LA.trustSym sigma)
500
501uniformSample
502 :: forall m n . (KnownNat m, KnownNat n)
503 => Seed
504 -> R n -- ^ minimums of each row
505 -> R n -- ^ maximums of each row
506 -> L m n
507uniformSample s (extract -> mins) (extract -> maxs) =
508 mkL $ LA.uniformSample s (fromInteger (natVal (Proxy :: Proxy m)))
509 (zip (LA.toList mins) (LA.toList maxs))
510
511meanCov
512 :: forall m n . (KnownNat m, KnownNat n)
513 => L m n
514 -> (R n, Sym n)
515meanCov (extract -> vs) = mkR *** (Sym . mkL . LA.unSym) $ LA.meanCov vs
516
412-------------------------------------------------------------------------------- 517--------------------------------------------------------------------------------
413 518
414class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat vec 519class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat vec
@@ -418,9 +523,10 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve
418 dot :: forall n . (KnownNat n) => vec n -> vec n -> field 523 dot :: forall n . (KnownNat n) => vec n -> vec n -> field
419 cross :: vec 3 -> vec 3 -> vec 3 524 cross :: vec 3 -> vec 3 -> vec 3
420 diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n 525 diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => field -> vec k -> mat m n
421 dvmap :: forall n. (field -> field) -> vec n -> vec n 526 dvmap :: forall n. KnownNat n => (field -> field) -> vec n -> vec n
422 dmmap :: forall n m. (field -> field) -> mat n m -> mat n m 527 dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m
423 outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m 528 outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m
529 zipWith :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n
424 530
425 531
426instance Domain ℝ R L 532instance Domain ℝ R L
@@ -433,6 +539,7 @@ instance Domain ℝ R L
433 dvmap = mapR 539 dvmap = mapR
434 dmmap = mapL 540 dmmap = mapL
435 outer = outerR 541 outer = outerR
542 zipWith = zipWithR
436 543
437instance Domain ℂ C M 544instance Domain ℂ C M
438 where 545 where
@@ -444,6 +551,7 @@ instance Domain ℂ C M
444 dvmap = mapC 551 dvmap = mapC
445 dmmap = mapM' 552 dmmap = mapM'
446 outer = outerC 553 outer = outerC
554 zipWith = zipWithC
447 555
448-------------------------------------------------------------------------------- 556--------------------------------------------------------------------------------
449 557
@@ -486,11 +594,15 @@ crossR (extract -> x) (extract -> y) = vec3 z1 z2 z3
486outerR :: (KnownNat m, KnownNat n) => R n -> R m -> L n m 594outerR :: (KnownNat m, KnownNat n) => R n -> R m -> L n m
487outerR (extract -> x) (extract -> y) = mkL (LA.outer x y) 595outerR (extract -> x) (extract -> y) = mkL (LA.outer x y)
488 596
489mapR :: (ℝ -> ℝ) -> R n -> R n 597mapR :: KnownNat n => (ℝ -> ℝ) -> R n -> R n
490mapR f (R (Dim v)) = R (Dim (LA.cmap f v)) 598mapR f (unwrap -> v) = mkR (LA.cmap f v)
599
600zipWithR :: KnownNat n => (ℝ -> ℝ -> ℝ) -> R n -> R n -> R n
601zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y)
602
603mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m
604mapL f (unwrap -> m) = mkL (LA.cmap f m)
491 605
492mapM' :: (ℂ -> ℂ) -> M n m -> M n m
493mapM' f (M (Dim (Dim m))) = M (Dim (Dim (LA.cmap f m)))
494 606
495-------------------------------------------------------------------------------- 607--------------------------------------------------------------------------------
496 608
@@ -533,11 +645,15 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3])
533outerC :: (KnownNat m, KnownNat n) => C n -> C m -> M n m 645outerC :: (KnownNat m, KnownNat n) => C n -> C m -> M n m
534outerC (extract -> x) (extract -> y) = mkM (LA.outer x y) 646outerC (extract -> x) (extract -> y) = mkM (LA.outer x y)
535 647
536mapC :: (ℂ -> ℂ) -> C n -> C n 648mapC :: KnownNat n => (ℂ -> ℂ) -> C n -> C n
537mapC f (C (Dim v)) = C (Dim (LA.cmap f v)) 649mapC f (unwrap -> v) = mkC (LA.cmap f v)
650
651zipWithC :: KnownNat n => (ℂ -> ℂ -> ℂ) -> C n -> C n -> C n
652zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y)
653
654mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m
655mapM' f (unwrap -> m) = mkM (LA.cmap f m)
538 656
539mapL :: (ℝ -> ℝ) -> L n m -> L n m
540mapL f (L (Dim (Dim m))) = L (Dim (Dim (LA.cmap f m)))
541 657
542-------------------------------------------------------------------------------- 658--------------------------------------------------------------------------------
543 659