diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 18 | ||||
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 40 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 20 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 3 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 17 |
5 files changed, 94 insertions, 4 deletions
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index ba68909..70bbae2 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -22,6 +22,7 @@ module Data.Packed.Internal.Vector ( | |||
22 | foldVector, foldVectorG, foldLoop, foldVectorWithIndex, | 22 | foldVector, foldVectorG, foldLoop, foldVectorWithIndex, |
23 | createVector, vec, | 23 | createVector, vec, |
24 | asComplex, asReal, float2DoubleV, double2FloatV, | 24 | asComplex, asReal, float2DoubleV, double2FloatV, |
25 | stepF, stepD, | ||
25 | fwriteVector, freadVector, fprintfVector, fscanfVector, | 26 | fwriteVector, freadVector, fprintfVector, fscanfVector, |
26 | cloneVector, | 27 | cloneVector, |
27 | unsafeToForeignPtr, | 28 | unsafeToForeignPtr, |
@@ -292,6 +293,23 @@ double2FloatV v = unsafePerformIO $ do | |||
292 | foreign import ccall "float2double" c_float2double:: TFV | 293 | foreign import ccall "float2double" c_float2double:: TFV |
293 | foreign import ccall "double2float" c_double2float:: TVF | 294 | foreign import ccall "double2float" c_double2float:: TVF |
294 | 295 | ||
296 | --------------------------------------------------------------- | ||
297 | |||
298 | stepF :: Vector Float -> Vector Float | ||
299 | stepF v = unsafePerformIO $ do | ||
300 | r <- createVector (dim v) | ||
301 | app2 c_stepF vec v vec r "stepF" | ||
302 | return r | ||
303 | |||
304 | stepD :: Vector Double -> Vector Double | ||
305 | stepD v = unsafePerformIO $ do | ||
306 | r <- createVector (dim v) | ||
307 | app2 c_stepD vec v vec r "stepD" | ||
308 | return r | ||
309 | |||
310 | foreign import ccall "stepF" c_stepF :: TFF | ||
311 | foreign import ccall "stepD" c_stepD :: TVV | ||
312 | |||
295 | ---------------------------------------------------------------- | 313 | ---------------------------------------------------------------- |
296 | 314 | ||
297 | cloneVector :: Storable t => Vector t -> IO (Vector t) | 315 | cloneVector :: Storable t => Vector t -> IO (Vector t) |
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index 992a501..5a8e243 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -45,6 +45,7 @@ module Numeric.ContainerBoot ( | |||
45 | ) where | 45 | ) where |
46 | 46 | ||
47 | import Data.Packed | 47 | import Data.Packed |
48 | import Data.Packed.ST as ST | ||
48 | import Numeric.Conversion | 49 | import Numeric.Conversion |
49 | import Data.Packed.Internal | 50 | import Data.Packed.Internal |
50 | import Numeric.GSL.Vector | 51 | import Numeric.GSL.Vector |
@@ -120,6 +121,12 @@ class (Complexable c, Fractional e, Element e) => Container c e where | |||
120 | sumElements :: c e -> e | 121 | sumElements :: c e -> e |
121 | -- | the product of elements (faster than using @fold@) | 122 | -- | the product of elements (faster than using @fold@) |
122 | prodElements :: c e -> e | 123 | prodElements :: c e -> e |
124 | -- | map (if x_i>0 then 1.0 else 0.0) | ||
125 | step :: RealFloat e => c e -> c e | ||
126 | -- | find index of elements which satisfy a predicate | ||
127 | find :: (e -> Bool) -> c e -> [IndexOf c] | ||
128 | -- | create a structure from an association list | ||
129 | assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e | ||
123 | 130 | ||
124 | -------------------------------------------------------------------------- | 131 | -------------------------------------------------------------------------- |
125 | 132 | ||
@@ -145,6 +152,9 @@ instance Container Vector Float where | |||
145 | maxElement = toScalarF Max | 152 | maxElement = toScalarF Max |
146 | sumElements = sumF | 153 | sumElements = sumF |
147 | prodElements = prodF | 154 | prodElements = prodF |
155 | step = stepF | ||
156 | find = findV | ||
157 | assoc = assocV | ||
148 | 158 | ||
149 | instance Container Vector Double where | 159 | instance Container Vector Double where |
150 | scale = vectorMapValR Scale | 160 | scale = vectorMapValR Scale |
@@ -168,6 +178,9 @@ instance Container Vector Double where | |||
168 | maxElement = toScalarR Max | 178 | maxElement = toScalarR Max |
169 | sumElements = sumR | 179 | sumElements = sumR |
170 | prodElements = prodR | 180 | prodElements = prodR |
181 | step = stepD | ||
182 | find = findV | ||
183 | assoc = assocV | ||
171 | 184 | ||
172 | instance Container Vector (Complex Double) where | 185 | instance Container Vector (Complex Double) where |
173 | scale = vectorMapValC Scale | 186 | scale = vectorMapValC Scale |
@@ -191,6 +204,9 @@ instance Container Vector (Complex Double) where | |||
191 | maxElement = ap (@>) maxIndex | 204 | maxElement = ap (@>) maxIndex |
192 | sumElements = sumC | 205 | sumElements = sumC |
193 | prodElements = prodC | 206 | prodElements = prodC |
207 | step = undefined -- cannot match | ||
208 | find = findV | ||
209 | assoc = assocV | ||
194 | 210 | ||
195 | instance Container Vector (Complex Float) where | 211 | instance Container Vector (Complex Float) where |
196 | scale = vectorMapValQ Scale | 212 | scale = vectorMapValQ Scale |
@@ -214,6 +230,9 @@ instance Container Vector (Complex Float) where | |||
214 | maxElement = ap (@>) maxIndex | 230 | maxElement = ap (@>) maxIndex |
215 | sumElements = sumQ | 231 | sumElements = sumQ |
216 | prodElements = prodQ | 232 | prodElements = prodQ |
233 | step = undefined -- cannot match | ||
234 | find = findV | ||
235 | assoc = assocV | ||
217 | 236 | ||
218 | --------------------------------------------------------------- | 237 | --------------------------------------------------------------- |
219 | 238 | ||
@@ -243,6 +262,9 @@ instance (Container Vector a) => Container Matrix a where | |||
243 | maxElement = ap (@@>) maxIndex | 262 | maxElement = ap (@@>) maxIndex |
244 | sumElements = sumElements . flatten | 263 | sumElements = sumElements . flatten |
245 | prodElements = prodElements . flatten | 264 | prodElements = prodElements . flatten |
265 | step = liftMatrix step | ||
266 | find = findM | ||
267 | assoc = assocM | ||
246 | 268 | ||
247 | ---------------------------------------------------- | 269 | ---------------------------------------------------- |
248 | 270 | ||
@@ -580,3 +602,21 @@ diag v = diagRect 0 v n n where n = dim v | |||
580 | -- | creates the identity matrix of given dimension | 602 | -- | creates the identity matrix of given dimension |
581 | ident :: (Num a, Element a) => Int -> Matrix a | 603 | ident :: (Num a, Element a) => Int -> Matrix a |
582 | ident n = diag (constantD 1 n) | 604 | ident n = diag (constantD 1 n) |
605 | |||
606 | -------------------------------------------------------- | ||
607 | |||
608 | findV p x = foldVectorWithIndex g [] x where | ||
609 | g k z l = if p z then k:l else l | ||
610 | |||
611 | findM p x = map ((`divMod` cols x)) $ findV p (flatten x) | ||
612 | |||
613 | assocV n z xs = ST.runSTVector $ do | ||
614 | v <- ST.newVector z n | ||
615 | mapM_ (\(k,x) -> ST.writeVector v k x) xs | ||
616 | return v | ||
617 | |||
618 | assocM (r,c) z xs = ST.runSTMatrix $ do | ||
619 | m <- ST.newMatrix z r c | ||
620 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs | ||
621 | return m | ||
622 | |||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index e8bbbdb..ae437d2 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -1247,3 +1247,23 @@ int conjugateC(KCVEC(x),CVEC(t)) { | |||
1247 | OK | 1247 | OK |
1248 | } | 1248 | } |
1249 | 1249 | ||
1250 | //////////////////// step ///////////////////////// | ||
1251 | |||
1252 | int stepF(FVEC(x),FVEC(y)) { | ||
1253 | DEBUGMSG("stepF") | ||
1254 | int k; | ||
1255 | for(k=0;k<xn;k++) { | ||
1256 | yp[k]=xp[k]>0; | ||
1257 | } | ||
1258 | OK | ||
1259 | } | ||
1260 | |||
1261 | int stepD(DVEC(x),DVEC(y)) { | ||
1262 | DEBUGMSG("stepD") | ||
1263 | int k; | ||
1264 | for(k=0;k<xn;k++) { | ||
1265 | yp[k]=xp[k]>0; | ||
1266 | } | ||
1267 | OK | ||
1268 | } | ||
1269 | |||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index 0543f7a..6207a59 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | |||
@@ -87,6 +87,9 @@ int double2float(DVEC(x),FVEC(y)); | |||
87 | int conjugateQ(KQVEC(x),QVEC(t)); | 87 | int conjugateQ(KQVEC(x),QVEC(t)); |
88 | int conjugateC(KCVEC(x),CVEC(t)); | 88 | int conjugateC(KCVEC(x),CVEC(t)); |
89 | 89 | ||
90 | int stepF(FVEC(x),FVEC(y)); | ||
91 | int stepD(DVEC(x),DVEC(y)); | ||
92 | |||
90 | int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); | 93 | int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); |
91 | int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); | 94 | int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); |
92 | int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v)); | 95 | int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v)); |
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 68e77f1..181cfbf 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -342,8 +342,8 @@ lift_maybe m = MaybeT $ do | |||
342 | 342 | ||
343 | -- | apply a test to successive elements of a vector, evaluates to true iff test passes for all pairs | 343 | -- | apply a test to successive elements of a vector, evaluates to true iff test passes for all pairs |
344 | --successive_ :: Storable a => (a -> a -> Bool) -> Vector a -> Bool | 344 | --successive_ :: Storable a => (a -> a -> Bool) -> Vector a -> Bool |
345 | successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ step (subVector 1 (dim v - 1) v))) (v @> 0) | 345 | successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ stp (subVector 1 (dim v - 1) v))) (v @> 0) |
346 | where step e = do | 346 | where stp e = do |
347 | ep <- lift_maybe $ state_get | 347 | ep <- lift_maybe $ state_get |
348 | if t e ep | 348 | if t e ep |
349 | then lift_maybe $ state_put e | 349 | then lift_maybe $ state_put e |
@@ -351,8 +351,8 @@ successive_ t v = maybe False (\_ -> True) $ evalState (runMaybeT (mapVectorM_ s | |||
351 | 351 | ||
352 | -- | operate on successive elements of a vector and return the resulting vector, whose length 1 less than that of the input | 352 | -- | operate on successive elements of a vector and return the resulting vector, whose length 1 less than that of the input |
353 | --successive :: (Storable a, Storable b) => (a -> a -> b) -> Vector a -> Vector b | 353 | --successive :: (Storable a, Storable b) => (a -> a -> b) -> Vector a -> Vector b |
354 | successive f v = evalState (mapVectorM step (subVector 1 (dim v - 1) v)) (v @> 0) | 354 | successive f v = evalState (mapVectorM stp (subVector 1 (dim v - 1) v)) (v @> 0) |
355 | where step e = do | 355 | where stp e = do |
356 | ep <- state_get | 356 | ep <- state_get |
357 | state_put e | 357 | state_put e |
358 | return $ f ep e | 358 | return $ f ep e |
@@ -365,6 +365,14 @@ succTest = utest "successive" $ | |||
365 | 365 | ||
366 | --------------------------------------------------------------------- | 366 | --------------------------------------------------------------------- |
367 | 367 | ||
368 | findAssocTest = utest "findAssoc" ok | ||
369 | where | ||
370 | ok = m1 == m2 | ||
371 | m1 = assoc (6,6) 7 $ zip (find (>0) (ident 5 :: Matrix Float)) [10 ..] :: Matrix Double | ||
372 | m2 = diagRect 7 (fromList[10..14]) 6 6 :: Matrix Double | ||
373 | |||
374 | --------------------------------------------------------------------- | ||
375 | |||
368 | 376 | ||
369 | -- | All tests must pass with a maximum dimension of about 20 | 377 | -- | All tests must pass with a maximum dimension of about 20 |
370 | -- (some tests may fail with bigger sizes due to precision loss). | 378 | -- (some tests may fail with bigger sizes due to precision loss). |
@@ -533,6 +541,7 @@ runTests n = do | |||
533 | , sumprodTest | 541 | , sumprodTest |
534 | , chainTest | 542 | , chainTest |
535 | , succTest | 543 | , succTest |
544 | , findAssocTest | ||
536 | ] | 545 | ] |
537 | return () | 546 | return () |
538 | 547 | ||