summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric
diff options
context:
space:
mode:
authorJustin Le <justin@jle.im>2016-01-07 01:38:43 -0800
committerJustin Le <justin@jle.im>2016-01-07 01:38:43 -0800
commit70b88736af624b635da192e5967175ae774fc646 (patch)
treef5432c2d968e1159fd9278aeb5fbcdf657469ce2 /packages/base/src/Numeric
parentcde00d2c7a660cb8f6edf7b4b308b66232b9b230 (diff)
random generator functions for Static module
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs51
1 files changed, 48 insertions, 3 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index 7790aef..f39e47f 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -52,6 +52,9 @@ module Numeric.LinearAlgebra.Static(
52 withNullspace, qr, chol, 52 withNullspace, qr, chol,
53 -- * Norms 53 -- * Norms
54 Normed(..), 54 Normed(..),
55 -- * Random arrays
56 Seed, RandDist(..),
57 randomVector, rand, randn, gaussianSample, uniformSample,
55 -- * Misc 58 -- * Misc
56 mean, 59 mean,
57 Disp(..), Domain(..), 60 Disp(..), Domain(..),
@@ -67,7 +70,8 @@ import Numeric.LinearAlgebra hiding (
67 row,col,vector,matrix,linspace,toRows,toColumns, 70 row,col,vector,matrix,linspace,toRows,toColumns,
68 (<\>),fromList,takeDiag,svd,eig,eigSH, 71 (<\>),fromList,takeDiag,svd,eig,eigSH,
69 eigenvalues,eigenvaluesSH,build, 72 eigenvalues,eigenvaluesSH,build,
70 qr,size,dot,chol,range,R,C,sym,mTm,unSym) 73 qr,size,dot,chol,range,R,C,sym,mTm,unSym,
74 randomVector, rand, randn, gaussianSample, uniformSample)
71import qualified Numeric.LinearAlgebra as LA 75import qualified Numeric.LinearAlgebra as LA
72import qualified Numeric.LinearAlgebra.Devel as LA 76import qualified Numeric.LinearAlgebra.Devel as LA
73import Data.Proxy(Proxy(..)) 77import Data.Proxy(Proxy(..))
@@ -398,7 +402,7 @@ withVector v f =
398-- | Useful for constraining two dependently typed vectors to match each 402-- | Useful for constraining two dependently typed vectors to match each
399-- other in length when they are unknown at compile-time. 403-- other in length when they are unknown at compile-time.
400exactLength 404exactLength
401 :: forall n m. (KnownNat n, KnownNat m) 405 :: forall n m . (KnownNat n, KnownNat m)
402 => R m 406 => R m
403 -> Maybe (R n) 407 -> Maybe (R n)
404exactLength v 408exactLength v
@@ -424,7 +428,7 @@ withMatrix a f =
424-- | Useful for constraining two dependently typed matrices to match each 428-- | Useful for constraining two dependently typed matrices to match each
425-- other in dimensions when they are unknown at compile-time. 429-- other in dimensions when they are unknown at compile-time.
426exactDims 430exactDims
427 :: forall n m j k. (KnownNat n, KnownNat m, KnownNat j, KnownNat k) 431 :: forall n m j k . (KnownNat n, KnownNat m, KnownNat j, KnownNat k)
428 => L m n 432 => L m n
429 -> Maybe (L j k) 433 -> Maybe (L j k)
430exactDims m 434exactDims m
@@ -434,6 +438,47 @@ exactDims m
434 | otherwise 438 | otherwise
435 = Nothing 439 = Nothing
436 440
441randomVector
442 :: forall n . KnownNat n
443 => Seed
444 -> RandDist
445 -> R n
446randomVector s d = mkR (LA.randomVector s d
447 (fromInteger (natVal (Proxy :: Proxy n)))
448 )
449
450rand
451 :: forall m n . (KnownNat m, KnownNat n)
452 => IO (L m n)
453rand = mkL <$> LA.rand (fromInteger (natVal (Proxy :: Proxy m)))
454 (fromInteger (natVal (Proxy :: Proxy n)))
455
456randn
457 :: forall m n . (KnownNat m, KnownNat n)
458 => IO (L m n)
459randn = mkL <$> LA.randn (fromInteger (natVal (Proxy :: Proxy m)))
460 (fromInteger (natVal (Proxy :: Proxy n)))
461
462gaussianSample
463 :: forall m n . (KnownNat m, KnownNat n)
464 => Seed
465 -> R n
466 -> Sym n
467 -> L m n
468gaussianSample s (extract -> mu) (Sym (extract -> sigma)) =
469 mkL $ LA.gaussianSample s (fromInteger (natVal (Proxy :: Proxy m)))
470 mu (LA.trustSym sigma)
471
472uniformSample
473 :: forall m n . (KnownNat m, KnownNat n)
474 => Seed
475 -> R n -- ^ minimums of each row
476 -> R n -- ^ maximums of each row
477 -> L m n
478uniformSample s (extract -> mins) (extract -> maxs) =
479 mkL $ LA.uniformSample s (fromInteger (natVal (Proxy :: Proxy m)))
480 (zip (LA.toList mins) (LA.toList maxs))
481
437 482
438-------------------------------------------------------------------------------- 483--------------------------------------------------------------------------------
439 484