summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Internal/Vector.hs18
-rw-r--r--lib/Numeric/ContainerBoot.hs40
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c20
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h3
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs17
5 files changed, 94 insertions, 4 deletions
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index ba68909..70bbae2 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -22,6 +22,7 @@ module Data.Packed.Internal.Vector (
22 foldVector, foldVectorG, foldLoop, foldVectorWithIndex, 22 foldVector, foldVectorG, foldLoop, foldVectorWithIndex,
23 createVector, vec, 23 createVector, vec,
24 asComplex, asReal, float2DoubleV, double2FloatV, 24 asComplex, asReal, float2DoubleV, double2FloatV,
25 stepF, stepD,
25 fwriteVector, freadVector, fprintfVector, fscanfVector, 26 fwriteVector, freadVector, fprintfVector, fscanfVector,
26 cloneVector, 27 cloneVector,
27 unsafeToForeignPtr, 28 unsafeToForeignPtr,
@@ -292,6 +293,23 @@ double2FloatV v = unsafePerformIO $ do
292foreign import ccall "float2double" c_float2double:: TFV 293foreign import ccall "float2double" c_float2double:: TFV
293foreign import ccall "double2float" c_double2float:: TVF 294foreign import ccall "double2float" c_double2float:: TVF
294 295
296---------------------------------------------------------------
297
298stepF :: Vector Float -> Vector Float
299stepF v = unsafePerformIO $ do
300 r <- createVector (dim v)
301 app2 c_stepF vec v vec r "stepF"
302 return r
303
304stepD :: Vector Double -> Vector Double
305stepD v = unsafePerformIO $ do
306 r <- createVector (dim v)
307 app2 c_stepD vec v vec r "stepD"
308 return r
309
310foreign import ccall "stepF" c_stepF :: TFF
311foreign import ccall "stepD" c_stepD :: TVV
312
295---------------------------------------------------------------- 313----------------------------------------------------------------
296 314
297cloneVector :: Storable t => Vector t -> IO (Vector t) 315cloneVector :: Storable t => Vector t -> IO (Vector t)
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index 992a501..5a8e243 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -45,6 +45,7 @@ module Numeric.ContainerBoot (
45) where 45) where
46 46
47import Data.Packed 47import Data.Packed
48import Data.Packed.ST as ST
48import Numeric.Conversion 49import Numeric.Conversion
49import Data.Packed.Internal 50import Data.Packed.Internal
50import Numeric.GSL.Vector 51import Numeric.GSL.Vector
@@ -120,6 +121,12 @@ class (Complexable c, Fractional e, Element e) => Container c e where
120 sumElements :: c e -> e 121 sumElements :: c e -> e
121 -- | the product of elements (faster than using @fold@) 122 -- | the product of elements (faster than using @fold@)
122 prodElements :: c e -> e 123 prodElements :: c e -> e
124 -- | map (if x_i>0 then 1.0 else 0.0)
125 step :: RealFloat e => c e -> c e
126 -- | find index of elements which satisfy a predicate
127 find :: (e -> Bool) -> c e -> [IndexOf c]
128 -- | create a structure from an association list
129 assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e
123 130
124-------------------------------------------------------------------------- 131--------------------------------------------------------------------------
125 132
@@ -145,6 +152,9 @@ instance Container Vector Float where
145 maxElement = toScalarF Max 152 maxElement = toScalarF Max
146 sumElements = sumF 153 sumElements = sumF
147 prodElements = prodF 154 prodElements = prodF
155 step = stepF
156 find = findV
157 assoc = assocV
148 158
149instance Container Vector Double where 159instance Container Vector Double where
150 scale = vectorMapValR Scale 160 scale = vectorMapValR Scale
@@ -168,6 +178,9 @@ instance Container Vector Double where
168 maxElement = toScalarR Max 178 maxElement = toScalarR Max
169 sumElements = sumR 179 sumElements = sumR
170 prodElements = prodR 180 prodElements = prodR
181 step = stepD
182 find = findV
183 assoc = assocV
171 184
172instance Container Vector (Complex Double) where 185instance Container Vector (Complex Double) where
173 scale = vectorMapValC Scale 186 scale = vectorMapValC Scale
@@ -191,6 +204,9 @@ instance Container Vector (Complex Double) where
191 maxElement = ap (@>) maxIndex 204 maxElement = ap (@>) maxIndex
192 sumElements = sumC 205 sumElements = sumC
193 prodElements = prodC 206 prodElements = prodC
207 step = undefined -- cannot match
208 find = findV
209 assoc = assocV
194 210
195instance Container Vector (Complex Float) where 211instance Container Vector (Complex Float) where
196 scale = vectorMapValQ Scale 212 scale = vectorMapValQ Scale
@@ -214,6 +230,9 @@ instance Container Vector (Complex Float) where
214 maxElement = ap (@>) maxIndex 230 maxElement = ap (@>) maxIndex
215 sumElements = sumQ 231 sumElements = sumQ
216 prodElements = prodQ 232 prodElements = prodQ
233 step = undefined -- cannot match
234 find = findV
235 assoc = assocV
217 236
218--------------------------------------------------------------- 237---------------------------------------------------------------
219 238
@@ -243,6 +262,9 @@ instance (Container Vector a) => Container Matrix a where
243 maxElement = ap (@@>) maxIndex 262 maxElement = ap (@@>) maxIndex
244 sumElements = sumElements . flatten 263 sumElements = sumElements . flatten
245 prodElements = prodElements . flatten 264 prodElements = prodElements . flatten
265 step = liftMatrix step
266 find = findM
267 assoc = assocM
246 268
247---------------------------------------------------- 269----------------------------------------------------
248 270
@@ -580,3 +602,21 @@ diag v = diagRect 0 v n n where n = dim v
580-- | creates the identity matrix of given dimension 602-- | creates the identity matrix of given dimension
581ident :: (Num a, Element a) => Int -> Matrix a 603ident :: (Num a, Element a) => Int -> Matrix a
582ident n = diag (constantD 1 n) 604ident n = diag (constantD 1 n)
605
606--------------------------------------------------------
607
608findV p x = foldVectorWithIndex g [] x where
609 g k z l = if p z then k:l else l
610
611findM p x = map ((`divMod` cols x)) $ findV p (flatten x)
612
613assocV n z xs = ST.runSTVector $ do
614 v <- ST.newVector z n
615 mapM_ (\(k,x) -> ST.writeVector v k x) xs
616 return v
617
618assocM (r,c) z xs = ST.runSTMatrix $ do
619 m <- ST.newMatrix z r c
620 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
621 return m
622
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
index e8bbbdb..ae437d2 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -1247,3 +1247,23 @@ int conjugateC(KCVEC(x),CVEC(t)) {
1247 OK 1247 OK
1248} 1248}
1249 1249
1250//////////////////// step /////////////////////////
1251
1252int stepF(FVEC(x),FVEC(y)) {
1253 DEBUGMSG("stepF")
1254 int k;
1255 for(k=0;k<xn;k++) {
1256 yp[k]=xp[k]>0;
1257 }
1258 OK
1259}
1260
1261int stepD(DVEC(x),DVEC(y)) {
1262 DEBUGMSG("stepD")
1263 int k;
1264 for(k=0;k<xn;k++) {
1265 yp[k]=xp[k]>0;
1266 }
1267 OK
1268}
1269
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
index 0543f7a..6207a59 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
@@ -87,6 +87,9 @@ int double2float(DVEC(x),FVEC(y));
87int conjugateQ(KQVEC(x),QVEC(t)); 87int conjugateQ(KQVEC(x),QVEC(t));
88int conjugateC(KCVEC(x),CVEC(t)); 88int conjugateC(KCVEC(x),CVEC(t));
89 89
90int stepF(FVEC(x),FVEC(y));
91int stepD(DVEC(x),DVEC(y));
92
90int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); 93int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v));
91int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); 94int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v));
92int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v)); 95int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v));
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 68e77f1..181cfbf 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -342,8 +342,8 @@ lift_maybe m = MaybeT $ do
342 342
343-- | apply a test to successive elements of a vector, evaluates to true iff test passes for all pairs 343-- | apply a test to successive elements of a vector, evaluates to true iff test passes for all pairs
344--successive_ :: Storable a => (a -> a -> Bool) -> Vector a -> Bool 344--successive_ :: Storable a => (a -> a -> Bool) -> Vector a -> Bool
345successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ step (subVector 1 (dim v - 1) v))) (v @> 0) 345successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ stp (subVector 1 (dim v - 1) v))) (v @> 0)
346 where step e = do 346 where stp e = do
347 ep <- lift_maybe $ state_get 347 ep <- lift_maybe $ state_get
348 if t e ep 348 if t e ep
349 then lift_maybe $ state_put e 349 then lift_maybe $ state_put e
@@ -351,8 +351,8 @@ successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ s
351 351
352-- | operate on successive elements of a vector and return the resulting vector, whose length 1 less than that of the input 352-- | operate on successive elements of a vector and return the resulting vector, whose length 1 less than that of the input
353--successive :: (Storable a, Storable b) => (a -> a -> b) -> Vector a -> Vector b 353--successive :: (Storable a, Storable b) => (a -> a -> b) -> Vector a -> Vector b
354successive f v = evalState (mapVectorM step (subVector 1 (dim v - 1) v)) (v @> 0) 354successive f v = evalState (mapVectorM stp (subVector 1 (dim v - 1) v)) (v @> 0)
355 where step e = do 355 where stp e = do
356 ep <- state_get 356 ep <- state_get
357 state_put e 357 state_put e
358 return $ f ep e 358 return $ f ep e
@@ -365,6 +365,14 @@ succTest = utest "successive" $
365 365
366--------------------------------------------------------------------- 366---------------------------------------------------------------------
367 367
368findAssocTest = utest "findAssoc" ok
369 where
370 ok = m1 == m2
371 m1 = assoc (6,6) 7 $ zip (find (>0) (ident 5 :: Matrix Float)) [10 ..] :: Matrix Double
372 m2 = diagRect 7 (fromList[10..14]) 6 6 :: Matrix Double
373
374---------------------------------------------------------------------
375
368 376
369-- | All tests must pass with a maximum dimension of about 20 377-- | All tests must pass with a maximum dimension of about 20
370-- (some tests may fail with bigger sizes due to precision loss). 378-- (some tests may fail with bigger sizes due to precision loss).
@@ -533,6 +541,7 @@ runTests n = do
533 , sumprodTest 541 , sumprodTest
534 , chainTest 542 , chainTest
535 , succTest 543 , succTest
544 , findAssocTest
536 ] 545 ]
537 return () 546 return ()
538 547