diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 19 | ||||
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 39 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 22 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 3 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 10 |
5 files changed, 91 insertions, 2 deletions
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 70bbae2..b68d16a 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -22,7 +22,7 @@ module Data.Packed.Internal.Vector ( | |||
22 | foldVector, foldVectorG, foldLoop, foldVectorWithIndex, | 22 | foldVector, foldVectorG, foldLoop, foldVectorWithIndex, |
23 | createVector, vec, | 23 | createVector, vec, |
24 | asComplex, asReal, float2DoubleV, double2FloatV, | 24 | asComplex, asReal, float2DoubleV, double2FloatV, |
25 | stepF, stepD, | 25 | stepF, stepD, condF, condD, |
26 | fwriteVector, freadVector, fprintfVector, fscanfVector, | 26 | fwriteVector, freadVector, fprintfVector, fscanfVector, |
27 | cloneVector, | 27 | cloneVector, |
28 | unsafeToForeignPtr, | 28 | unsafeToForeignPtr, |
@@ -310,6 +310,23 @@ stepD v = unsafePerformIO $ do | |||
310 | foreign import ccall "stepF" c_stepF :: TFF | 310 | foreign import ccall "stepF" c_stepF :: TFF |
311 | foreign import ccall "stepD" c_stepD :: TVV | 311 | foreign import ccall "stepD" c_stepD :: TVV |
312 | 312 | ||
313 | --------------------------------------------------------------- | ||
314 | |||
315 | condF :: Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float | ||
316 | condF x y l e g = unsafePerformIO $ do | ||
317 | r <- createVector (dim x) | ||
318 | app6 c_condF vec x vec y vec l vec e vec g vec r "condF" | ||
319 | return r | ||
320 | |||
321 | condD :: Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double | ||
322 | condD x y l e g = unsafePerformIO $ do | ||
323 | r <- createVector (dim x) | ||
324 | app6 c_condD vec x vec y vec l vec e vec g vec r "condD" | ||
325 | return r | ||
326 | |||
327 | foreign import ccall "condF" c_condF :: CInt -> PF -> CInt -> PF -> CInt -> PF -> TFFF | ||
328 | foreign import ccall "condD" c_condD :: CInt -> PD -> CInt -> PD -> CInt -> PD -> TVVV | ||
329 | |||
313 | ---------------------------------------------------------------- | 330 | ---------------------------------------------------------------- |
314 | 331 | ||
315 | cloneVector :: Storable t => Vector t -> IO (Vector t) | 332 | cloneVector :: Storable t => Vector t -> IO (Vector t) |
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index 5a8e243..e33857a 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -127,6 +127,8 @@ class (Complexable c, Fractional e, Element e) => Container c e where | |||
127 | find :: (e -> Bool) -> c e -> [IndexOf c] | 127 | find :: (e -> Bool) -> c e -> [IndexOf c] |
128 | -- | create a structure from an association list | 128 | -- | create a structure from an association list |
129 | assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e | 129 | assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e |
130 | -- | a vectorized form of case 'compare' a_i b_i of LT -> l_i; EQ -> e_i; GT -> g_i | ||
131 | cond :: RealFloat e => c e -> c e -> c e -> c e -> c e -> c e | ||
130 | 132 | ||
131 | -------------------------------------------------------------------------- | 133 | -------------------------------------------------------------------------- |
132 | 134 | ||
@@ -155,6 +157,7 @@ instance Container Vector Float where | |||
155 | step = stepF | 157 | step = stepF |
156 | find = findV | 158 | find = findV |
157 | assoc = assocV | 159 | assoc = assocV |
160 | cond = condV condF | ||
158 | 161 | ||
159 | instance Container Vector Double where | 162 | instance Container Vector Double where |
160 | scale = vectorMapValR Scale | 163 | scale = vectorMapValR Scale |
@@ -181,6 +184,7 @@ instance Container Vector Double where | |||
181 | step = stepD | 184 | step = stepD |
182 | find = findV | 185 | find = findV |
183 | assoc = assocV | 186 | assoc = assocV |
187 | cond = condV condD | ||
184 | 188 | ||
185 | instance Container Vector (Complex Double) where | 189 | instance Container Vector (Complex Double) where |
186 | scale = vectorMapValC Scale | 190 | scale = vectorMapValC Scale |
@@ -207,6 +211,7 @@ instance Container Vector (Complex Double) where | |||
207 | step = undefined -- cannot match | 211 | step = undefined -- cannot match |
208 | find = findV | 212 | find = findV |
209 | assoc = assocV | 213 | assoc = assocV |
214 | cond = undefined -- cannot match | ||
210 | 215 | ||
211 | instance Container Vector (Complex Float) where | 216 | instance Container Vector (Complex Float) where |
212 | scale = vectorMapValQ Scale | 217 | scale = vectorMapValQ Scale |
@@ -233,6 +238,7 @@ instance Container Vector (Complex Float) where | |||
233 | step = undefined -- cannot match | 238 | step = undefined -- cannot match |
234 | find = findV | 239 | find = findV |
235 | assoc = assocV | 240 | assoc = assocV |
241 | cond = undefined -- cannot match | ||
236 | 242 | ||
237 | --------------------------------------------------------------- | 243 | --------------------------------------------------------------- |
238 | 244 | ||
@@ -265,6 +271,7 @@ instance (Container Vector a) => Container Matrix a where | |||
265 | step = liftMatrix step | 271 | step = liftMatrix step |
266 | find = findM | 272 | find = findM |
267 | assoc = assocM | 273 | assoc = assocM |
274 | cond = condM | ||
268 | 275 | ||
269 | ---------------------------------------------------- | 276 | ---------------------------------------------------- |
270 | 277 | ||
@@ -620,3 +627,35 @@ assocM (r,c) z xs = ST.runSTMatrix $ do | |||
620 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs | 627 | mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs |
621 | return m | 628 | return m |
622 | 629 | ||
630 | ---------------------------------------------------------------------- | ||
631 | |||
632 | conformMTo (r,c) m | ||
633 | | size m == (r,c) = m | ||
634 | | size m == (1,1) = konst (m@@>(0,0)) (r,c) | ||
635 | | size m == (r,1) = repCols c m | ||
636 | | size m == (1,c) = repRows r m | ||
637 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | ||
638 | |||
639 | conformVTo n v | ||
640 | | dim v == n = v | ||
641 | | dim v == 1 = konst (v@>0) n | ||
642 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | ||
643 | |||
644 | repRows n x = fromRows (replicate n (flatten x)) | ||
645 | repCols n x = fromColumns (replicate n (flatten x)) | ||
646 | |||
647 | size m = (rows m, cols m) | ||
648 | |||
649 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | ||
650 | |||
651 | condM a b l e t = reshape c $ cond a' b' l' e' t' | ||
652 | where | ||
653 | r = maximum (map rows [a,b,l,e,t]) | ||
654 | c = maximum (map cols [a,b,l,e,t]) | ||
655 | [a', b', l', e', t'] = map (flatten . conformMTo (r,c)) [a,b,l,e,t] | ||
656 | |||
657 | condV f a b l e t = f a' b' l' e' t' | ||
658 | where | ||
659 | n = maximum (map dim [a,b,l,e,t]) | ||
660 | [a', b', l', e', t'] = map (conformVTo n) [a,b,l,e,t] | ||
661 | |||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index ae437d2..f4ae0f6 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -1267,3 +1267,25 @@ int stepD(DVEC(x),DVEC(y)) { | |||
1267 | OK | 1267 | OK |
1268 | } | 1268 | } |
1269 | 1269 | ||
1270 | //////////////////// cond ///////////////////////// | ||
1271 | |||
1272 | int condF(FVEC(x),FVEC(y),FVEC(lt),FVEC(eq),FVEC(gt),FVEC(r)) { | ||
1273 | REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE); | ||
1274 | DEBUGMSG("condF") | ||
1275 | int k; | ||
1276 | for(k=0;k<xn;k++) { | ||
1277 | rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]); | ||
1278 | } | ||
1279 | OK | ||
1280 | } | ||
1281 | |||
1282 | int condD(DVEC(x),DVEC(y),DVEC(lt),DVEC(eq),DVEC(gt),DVEC(r)) { | ||
1283 | REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE); | ||
1284 | DEBUGMSG("condD") | ||
1285 | int k; | ||
1286 | for(k=0;k<xn;k++) { | ||
1287 | rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]); | ||
1288 | } | ||
1289 | OK | ||
1290 | } | ||
1291 | |||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index 6207a59..9526583 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | |||
@@ -90,6 +90,9 @@ int conjugateC(KCVEC(x),CVEC(t)); | |||
90 | int stepF(FVEC(x),FVEC(y)); | 90 | int stepF(FVEC(x),FVEC(y)); |
91 | int stepD(DVEC(x),DVEC(y)); | 91 | int stepD(DVEC(x),DVEC(y)); |
92 | 92 | ||
93 | int condF(FVEC(x),FVEC(y),FVEC(lt),FVEC(eq),FVEC(gt),FVEC(r)); | ||
94 | int condD(DVEC(x),DVEC(y),DVEC(lt),DVEC(eq),DVEC(gt),DVEC(r)); | ||
95 | |||
93 | int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); | 96 | int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); |
94 | int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); | 97 | int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); |
95 | int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v)); | 98 | int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v)); |
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 181cfbf..76eaaae 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -369,7 +369,14 @@ findAssocTest = utest "findAssoc" ok | |||
369 | where | 369 | where |
370 | ok = m1 == m2 | 370 | ok = m1 == m2 |
371 | m1 = assoc (6,6) 7 $ zip (find (>0) (ident 5 :: Matrix Float)) [10 ..] :: Matrix Double | 371 | m1 = assoc (6,6) 7 $ zip (find (>0) (ident 5 :: Matrix Float)) [10 ..] :: Matrix Double |
372 | m2 = diagRect 7 (fromList[10..14]) 6 6 :: Matrix Double | 372 | m2 = diagRect 7 (fromList[10..14]) 6 6 |
373 | |||
374 | --------------------------------------------------------------------- | ||
375 | |||
376 | condTest = utest "cond" ok | ||
377 | where | ||
378 | ok = step v * v == cond v 0 0 0 v | ||
379 | v = fromList [-7 .. 7 ] :: Vector Float | ||
373 | 380 | ||
374 | --------------------------------------------------------------------- | 381 | --------------------------------------------------------------------- |
375 | 382 | ||
@@ -542,6 +549,7 @@ runTests n = do | |||
542 | , chainTest | 549 | , chainTest |
543 | , succTest | 550 | , succTest |
544 | , findAssocTest | 551 | , findAssocTest |
552 | , condTest | ||
545 | ] | 553 | ] |
546 | return () | 554 | return () |
547 | 555 | ||