summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs59
1 files changed, 55 insertions, 4 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
index 8627084..aa48687 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Real.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs
@@ -28,12 +28,12 @@ Experimental interface for real arrays with statically checked dimensions.
28module Numeric.LinearAlgebra.Real( 28module Numeric.LinearAlgebra.Real(
29 -- * Vector 29 -- * Vector
30 R, C, 30 R, C,
31 vec2, vec3, vec4, (&), (#), 31 vec2, vec3, vec4, (&), (#), split, headTail,
32 vector, 32 vector,
33 linspace, range, dim, 33 linspace, range, dim,
34 -- * Matrix 34 -- * Matrix
35 L, Sq, M, 35 L, Sq, M, def,
36 row, col, (¦),(——), 36 row, col, (¦),(——), splitRows, splitCols,
37 unrow, uncol, 37 unrow, uncol,
38 38
39 eye, 39 eye,
@@ -58,11 +58,12 @@ module Numeric.LinearAlgebra.Real(
58import GHC.TypeLits 58import GHC.TypeLits
59import Numeric.HMatrix hiding ( 59import Numeric.HMatrix hiding (
60 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace, 60 (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace,
61 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH') 61 (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build)
62import qualified Numeric.HMatrix as LA 62import qualified Numeric.HMatrix as LA
63import Data.Proxy(Proxy) 63import Data.Proxy(Proxy)
64import Numeric.LinearAlgebra.Static 64import Numeric.LinearAlgebra.Static
65import Text.Printf 65import Text.Printf
66import Control.Arrow((***))
66 67
67 68
68𝑖 :: Sized ℂ s c => s 69𝑖 :: Sized ℂ s c => s
@@ -566,6 +567,55 @@ instance KnownNat n => Eigen (Sq n) (C n) (M n n)
566 567
567-------------------------------------------------------------------------------- 568--------------------------------------------------------------------------------
568 569
570split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p))
571split (extract -> v) = ( mkR (subVector 0 p' v) ,
572 mkR (subVector p' (size v - p') v) )
573 where
574 p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int
575
576
577headTail :: (KnownNat n, 1<=n) => R n -> (ℝ, R (n-1))
578headTail = ((!0) . extract *** id) . split
579
580
581splitRows :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, p<=m) => L m n -> (L p n, L (m-p) n)
582splitRows (extract -> x) = ( mkL (takeRows p' x) ,
583 mkL (dropRows p' x) )
584 where
585 p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int
586
587splitCols :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, KnownNat (n-p), p<=n) => L m n -> (L m p, L m (n-p))
588splitCols = (tr *** tr) . splitRows . tr
589
590
591splittest
592 = do
593 let v = range :: R 7
594 a = snd (split v) :: R 4
595 print $ a
596 print $ snd . headTail . snd . headTail $ v
597 print $ first (vec3 1 2 3)
598 print $ second (vec3 1 2 3)
599 print $ third (vec3 1 2 3)
600 print $ (snd $ splitRows eye :: L 4 6)
601 where
602 first v = fst . headTail $ v
603 second v = first . snd . headTail $ v
604 third v = first . snd . headTail . snd . headTail $ v
605
606--------------------------------------------------------------------------------
607
608def
609 :: forall m n . (KnownNat n, KnownNat m)
610 => (ℝ -> ℝ -> ℝ)
611 -> L m n
612def f = mkL $ LA.build (m',n') f
613 where
614 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
615 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
616
617--------------------------------------------------------------------------------
618
569withVector 619withVector
570 :: forall z 620 :: forall z
571 . Vector ℝ 621 . Vector ℝ
@@ -615,6 +665,7 @@ test = (ok,info)
615 print precS 665 print precS
616 print precD 666 print precD
617 print $ withVector (LA.vect [1..15]) sumV 667 print $ withVector (LA.vect [1..15]) sumV
668 splittest
618 669
619 sumV w = w <·> konst 1 670 sumV w = w <·> konst 1
620 671