diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 51 |
1 files changed, 48 insertions, 3 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 7790aef..f39e47f 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -52,6 +52,9 @@ module Numeric.LinearAlgebra.Static( | |||
52 | withNullspace, qr, chol, | 52 | withNullspace, qr, chol, |
53 | -- * Norms | 53 | -- * Norms |
54 | Normed(..), | 54 | Normed(..), |
55 | -- * Random arrays | ||
56 | Seed, RandDist(..), | ||
57 | randomVector, rand, randn, gaussianSample, uniformSample, | ||
55 | -- * Misc | 58 | -- * Misc |
56 | mean, | 59 | mean, |
57 | Disp(..), Domain(..), | 60 | Disp(..), Domain(..), |
@@ -67,7 +70,8 @@ import Numeric.LinearAlgebra hiding ( | |||
67 | row,col,vector,matrix,linspace,toRows,toColumns, | 70 | row,col,vector,matrix,linspace,toRows,toColumns, |
68 | (<\>),fromList,takeDiag,svd,eig,eigSH, | 71 | (<\>),fromList,takeDiag,svd,eig,eigSH, |
69 | eigenvalues,eigenvaluesSH,build, | 72 | eigenvalues,eigenvaluesSH,build, |
70 | qr,size,dot,chol,range,R,C,sym,mTm,unSym) | 73 | qr,size,dot,chol,range,R,C,sym,mTm,unSym, |
74 | randomVector, rand, randn, gaussianSample, uniformSample) | ||
71 | import qualified Numeric.LinearAlgebra as LA | 75 | import qualified Numeric.LinearAlgebra as LA |
72 | import qualified Numeric.LinearAlgebra.Devel as LA | 76 | import qualified Numeric.LinearAlgebra.Devel as LA |
73 | import Data.Proxy(Proxy(..)) | 77 | import Data.Proxy(Proxy(..)) |
@@ -398,7 +402,7 @@ withVector v f = | |||
398 | -- | Useful for constraining two dependently typed vectors to match each | 402 | -- | Useful for constraining two dependently typed vectors to match each |
399 | -- other in length when they are unknown at compile-time. | 403 | -- other in length when they are unknown at compile-time. |
400 | exactLength | 404 | exactLength |
401 | :: forall n m. (KnownNat n, KnownNat m) | 405 | :: forall n m . (KnownNat n, KnownNat m) |
402 | => R m | 406 | => R m |
403 | -> Maybe (R n) | 407 | -> Maybe (R n) |
404 | exactLength v | 408 | exactLength v |
@@ -424,7 +428,7 @@ withMatrix a f = | |||
424 | -- | Useful for constraining two dependently typed matrices to match each | 428 | -- | Useful for constraining two dependently typed matrices to match each |
425 | -- other in dimensions when they are unknown at compile-time. | 429 | -- other in dimensions when they are unknown at compile-time. |
426 | exactDims | 430 | exactDims |
427 | :: forall n m j k. (KnownNat n, KnownNat m, KnownNat j, KnownNat k) | 431 | :: forall n m j k . (KnownNat n, KnownNat m, KnownNat j, KnownNat k) |
428 | => L m n | 432 | => L m n |
429 | -> Maybe (L j k) | 433 | -> Maybe (L j k) |
430 | exactDims m | 434 | exactDims m |
@@ -434,6 +438,47 @@ exactDims m | |||
434 | | otherwise | 438 | | otherwise |
435 | = Nothing | 439 | = Nothing |
436 | 440 | ||
441 | randomVector | ||
442 | :: forall n . KnownNat n | ||
443 | => Seed | ||
444 | -> RandDist | ||
445 | -> R n | ||
446 | randomVector s d = mkR (LA.randomVector s d | ||
447 | (fromInteger (natVal (Proxy :: Proxy n))) | ||
448 | ) | ||
449 | |||
450 | rand | ||
451 | :: forall m n . (KnownNat m, KnownNat n) | ||
452 | => IO (L m n) | ||
453 | rand = mkL <$> LA.rand (fromInteger (natVal (Proxy :: Proxy m))) | ||
454 | (fromInteger (natVal (Proxy :: Proxy n))) | ||
455 | |||
456 | randn | ||
457 | :: forall m n . (KnownNat m, KnownNat n) | ||
458 | => IO (L m n) | ||
459 | randn = mkL <$> LA.randn (fromInteger (natVal (Proxy :: Proxy m))) | ||
460 | (fromInteger (natVal (Proxy :: Proxy n))) | ||
461 | |||
462 | gaussianSample | ||
463 | :: forall m n . (KnownNat m, KnownNat n) | ||
464 | => Seed | ||
465 | -> R n | ||
466 | -> Sym n | ||
467 | -> L m n | ||
468 | gaussianSample s (extract -> mu) (Sym (extract -> sigma)) = | ||
469 | mkL $ LA.gaussianSample s (fromInteger (natVal (Proxy :: Proxy m))) | ||
470 | mu (LA.trustSym sigma) | ||
471 | |||
472 | uniformSample | ||
473 | :: forall m n . (KnownNat m, KnownNat n) | ||
474 | => Seed | ||
475 | -> R n -- ^ minimums of each row | ||
476 | -> R n -- ^ maximums of each row | ||
477 | -> L m n | ||
478 | uniformSample s (extract -> mins) (extract -> maxs) = | ||
479 | mkL $ LA.uniformSample s (fromInteger (natVal (Proxy :: Proxy m))) | ||
480 | (zip (LA.toList mins) (LA.toList maxs)) | ||
481 | |||
437 | 482 | ||
438 | -------------------------------------------------------------------------------- | 483 | -------------------------------------------------------------------------------- |
439 | 484 | ||