From 2e07762524d0d08fbc2e565529d480dc7fa479b5 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 6 Jun 2014 20:02:15 +0200 Subject: safe split --- packages/base/src/Numeric/LinearAlgebra/Real.hs | 59 +++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 4 deletions(-) (limited to 'packages/base/src/Numeric') diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index 8627084..aa48687 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs @@ -28,12 +28,12 @@ Experimental interface for real arrays with statically checked dimensions. module Numeric.LinearAlgebra.Real( -- * Vector R, C, - vec2, vec3, vec4, (&), (#), + vec2, vec3, vec4, (&), (#), split, headTail, vector, linspace, range, dim, -- * Matrix - L, Sq, M, - row, col, (¦),(——), + L, Sq, M, def, + row, col, (¦),(——), splitRows, splitCols, unrow, uncol, eye, @@ -58,11 +58,12 @@ module Numeric.LinearAlgebra.Real( import GHC.TypeLits import Numeric.HMatrix hiding ( (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace, - (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH') + (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build) import qualified Numeric.HMatrix as LA import Data.Proxy(Proxy) import Numeric.LinearAlgebra.Static import Text.Printf +import Control.Arrow((***)) 𝑖 :: Sized ℂ s c => s @@ -566,6 +567,55 @@ instance KnownNat n => Eigen (Sq n) (C n) (M n n) -------------------------------------------------------------------------------- +split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) +split (extract -> v) = ( mkR (subVector 0 p' v) , + mkR (subVector p' (size v - p') v) ) + where + p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int + + +headTail :: (KnownNat n, 1<=n) => R n -> (ℝ, R (n-1)) +headTail = ((!0) . extract *** id) . split + + +splitRows :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, p<=m) => L m n -> (L p n, L (m-p) n) +splitRows (extract -> x) = ( mkL (takeRows p' x) , + mkL (dropRows p' x) ) + where + p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int + +splitCols :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, KnownNat (n-p), p<=n) => L m n -> (L m p, L m (n-p)) +splitCols = (tr *** tr) . splitRows . tr + + +splittest + = do + let v = range :: R 7 + a = snd (split v) :: R 4 + print $ a + print $ snd . headTail . snd . headTail $ v + print $ first (vec3 1 2 3) + print $ second (vec3 1 2 3) + print $ third (vec3 1 2 3) + print $ (snd $ splitRows eye :: L 4 6) + where + first v = fst . headTail $ v + second v = first . snd . headTail $ v + third v = first . snd . headTail . snd . headTail $ v + +-------------------------------------------------------------------------------- + +def + :: forall m n . (KnownNat n, KnownNat m) + => (ℝ -> ℝ -> ℝ) + -> L m n +def f = mkL $ LA.build (m',n') f + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + +-------------------------------------------------------------------------------- + withVector :: forall z . Vector ℝ @@ -615,6 +665,7 @@ test = (ok,info) print precS print precD print $ withVector (LA.vect [1..15]) sumV + splittest sumV w = w <·> konst 1 -- cgit v1.2.3