From 70b88736af624b635da192e5967175ae774fc646 Mon Sep 17 00:00:00 2001 From: Justin Le Date: Thu, 7 Jan 2016 01:38:43 -0800 Subject: random generator functions for Static module --- packages/base/src/Numeric/LinearAlgebra/Static.hs | 51 +++++++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) (limited to 'packages/base/src/Numeric') diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 7790aef..f39e47f 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs @@ -52,6 +52,9 @@ module Numeric.LinearAlgebra.Static( withNullspace, qr, chol, -- * Norms Normed(..), + -- * Random arrays + Seed, RandDist(..), + randomVector, rand, randn, gaussianSample, uniformSample, -- * Misc mean, Disp(..), Domain(..), @@ -67,7 +70,8 @@ import Numeric.LinearAlgebra hiding ( row,col,vector,matrix,linspace,toRows,toColumns, (<\>),fromList,takeDiag,svd,eig,eigSH, eigenvalues,eigenvaluesSH,build, - qr,size,dot,chol,range,R,C,sym,mTm,unSym) + qr,size,dot,chol,range,R,C,sym,mTm,unSym, + randomVector, rand, randn, gaussianSample, uniformSample) import qualified Numeric.LinearAlgebra as LA import qualified Numeric.LinearAlgebra.Devel as LA import Data.Proxy(Proxy(..)) @@ -398,7 +402,7 @@ withVector v f = -- | Useful for constraining two dependently typed vectors to match each -- other in length when they are unknown at compile-time. exactLength - :: forall n m. (KnownNat n, KnownNat m) + :: forall n m . (KnownNat n, KnownNat m) => R m -> Maybe (R n) exactLength v @@ -424,7 +428,7 @@ withMatrix a f = -- | Useful for constraining two dependently typed matrices to match each -- other in dimensions when they are unknown at compile-time. exactDims - :: forall n m j k. (KnownNat n, KnownNat m, KnownNat j, KnownNat k) + :: forall n m j k . (KnownNat n, KnownNat m, KnownNat j, KnownNat k) => L m n -> Maybe (L j k) exactDims m @@ -434,6 +438,47 @@ exactDims m | otherwise = Nothing +randomVector + :: forall n . KnownNat n + => Seed + -> RandDist + -> R n +randomVector s d = mkR (LA.randomVector s d + (fromInteger (natVal (Proxy :: Proxy n))) + ) + +rand + :: forall m n . (KnownNat m, KnownNat n) + => IO (L m n) +rand = mkL <$> LA.rand (fromInteger (natVal (Proxy :: Proxy m))) + (fromInteger (natVal (Proxy :: Proxy n))) + +randn + :: forall m n . (KnownNat m, KnownNat n) + => IO (L m n) +randn = mkL <$> LA.randn (fromInteger (natVal (Proxy :: Proxy m))) + (fromInteger (natVal (Proxy :: Proxy n))) + +gaussianSample + :: forall m n . (KnownNat m, KnownNat n) + => Seed + -> R n + -> Sym n + -> L m n +gaussianSample s (extract -> mu) (Sym (extract -> sigma)) = + mkL $ LA.gaussianSample s (fromInteger (natVal (Proxy :: Proxy m))) + mu (LA.trustSym sigma) + +uniformSample + :: forall m n . (KnownNat m, KnownNat n) + => Seed + -> R n -- ^ minimums of each row + -> R n -- ^ maximums of each row + -> L m n +uniformSample s (extract -> mins) (extract -> maxs) = + mkL $ LA.uniformSample s (fromInteger (natVal (Proxy :: Proxy m))) + (zip (LA.toList mins) (LA.toList maxs)) + -------------------------------------------------------------------------------- -- cgit v1.2.3