summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Internal/Vector.hs19
-rw-r--r--lib/Numeric/ContainerBoot.hs39
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c22
-rw-r--r--lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h3
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs10
5 files changed, 91 insertions, 2 deletions
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index 70bbae2..b68d16a 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -22,7 +22,7 @@ module Data.Packed.Internal.Vector (
22 foldVector, foldVectorG, foldLoop, foldVectorWithIndex, 22 foldVector, foldVectorG, foldLoop, foldVectorWithIndex,
23 createVector, vec, 23 createVector, vec,
24 asComplex, asReal, float2DoubleV, double2FloatV, 24 asComplex, asReal, float2DoubleV, double2FloatV,
25 stepF, stepD, 25 stepF, stepD, condF, condD,
26 fwriteVector, freadVector, fprintfVector, fscanfVector, 26 fwriteVector, freadVector, fprintfVector, fscanfVector,
27 cloneVector, 27 cloneVector,
28 unsafeToForeignPtr, 28 unsafeToForeignPtr,
@@ -310,6 +310,23 @@ stepD v = unsafePerformIO $ do
310foreign import ccall "stepF" c_stepF :: TFF 310foreign import ccall "stepF" c_stepF :: TFF
311foreign import ccall "stepD" c_stepD :: TVV 311foreign import ccall "stepD" c_stepD :: TVV
312 312
313---------------------------------------------------------------
314
315condF :: Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float
316condF x y l e g = unsafePerformIO $ do
317 r <- createVector (dim x)
318 app6 c_condF vec x vec y vec l vec e vec g vec r "condF"
319 return r
320
321condD :: Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double
322condD x y l e g = unsafePerformIO $ do
323 r <- createVector (dim x)
324 app6 c_condD vec x vec y vec l vec e vec g vec r "condD"
325 return r
326
327foreign import ccall "condF" c_condF :: CInt -> PF -> CInt -> PF -> CInt -> PF -> TFFF
328foreign import ccall "condD" c_condD :: CInt -> PD -> CInt -> PD -> CInt -> PD -> TVVV
329
313---------------------------------------------------------------- 330----------------------------------------------------------------
314 331
315cloneVector :: Storable t => Vector t -> IO (Vector t) 332cloneVector :: Storable t => Vector t -> IO (Vector t)
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index 5a8e243..e33857a 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -127,6 +127,8 @@ class (Complexable c, Fractional e, Element e) => Container c e where
127 find :: (e -> Bool) -> c e -> [IndexOf c] 127 find :: (e -> Bool) -> c e -> [IndexOf c]
128 -- | create a structure from an association list 128 -- | create a structure from an association list
129 assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e 129 assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e
130 -- | a vectorized form of case 'compare' a_i b_i of LT -> l_i; EQ -> e_i; GT -> g_i
131 cond :: RealFloat e => c e -> c e -> c e -> c e -> c e -> c e
130 132
131-------------------------------------------------------------------------- 133--------------------------------------------------------------------------
132 134
@@ -155,6 +157,7 @@ instance Container Vector Float where
155 step = stepF 157 step = stepF
156 find = findV 158 find = findV
157 assoc = assocV 159 assoc = assocV
160 cond = condV condF
158 161
159instance Container Vector Double where 162instance Container Vector Double where
160 scale = vectorMapValR Scale 163 scale = vectorMapValR Scale
@@ -181,6 +184,7 @@ instance Container Vector Double where
181 step = stepD 184 step = stepD
182 find = findV 185 find = findV
183 assoc = assocV 186 assoc = assocV
187 cond = condV condD
184 188
185instance Container Vector (Complex Double) where 189instance Container Vector (Complex Double) where
186 scale = vectorMapValC Scale 190 scale = vectorMapValC Scale
@@ -207,6 +211,7 @@ instance Container Vector (Complex Double) where
207 step = undefined -- cannot match 211 step = undefined -- cannot match
208 find = findV 212 find = findV
209 assoc = assocV 213 assoc = assocV
214 cond = undefined -- cannot match
210 215
211instance Container Vector (Complex Float) where 216instance Container Vector (Complex Float) where
212 scale = vectorMapValQ Scale 217 scale = vectorMapValQ Scale
@@ -233,6 +238,7 @@ instance Container Vector (Complex Float) where
233 step = undefined -- cannot match 238 step = undefined -- cannot match
234 find = findV 239 find = findV
235 assoc = assocV 240 assoc = assocV
241 cond = undefined -- cannot match
236 242
237--------------------------------------------------------------- 243---------------------------------------------------------------
238 244
@@ -265,6 +271,7 @@ instance (Container Vector a) => Container Matrix a where
265 step = liftMatrix step 271 step = liftMatrix step
266 find = findM 272 find = findM
267 assoc = assocM 273 assoc = assocM
274 cond = condM
268 275
269---------------------------------------------------- 276----------------------------------------------------
270 277
@@ -620,3 +627,35 @@ assocM (r,c) z xs = ST.runSTMatrix $ do
620 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs 627 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
621 return m 628 return m
622 629
630----------------------------------------------------------------------
631
632conformMTo (r,c) m
633 | size m == (r,c) = m
634 | size m == (1,1) = konst (m@@>(0,0)) (r,c)
635 | size m == (r,1) = repCols c m
636 | size m == (1,c) = repRows r m
637 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
638
639conformVTo n v
640 | dim v == n = v
641 | dim v == 1 = konst (v@>0) n
642 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
643
644repRows n x = fromRows (replicate n (flatten x))
645repCols n x = fromColumns (replicate n (flatten x))
646
647size m = (rows m, cols m)
648
649shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
650
651condM a b l e t = reshape c $ cond a' b' l' e' t'
652 where
653 r = maximum (map rows [a,b,l,e,t])
654 c = maximum (map cols [a,b,l,e,t])
655 [a', b', l', e', t'] = map (flatten . conformMTo (r,c)) [a,b,l,e,t]
656
657condV f a b l e t = f a' b' l' e' t'
658 where
659 n = maximum (map dim [a,b,l,e,t])
660 [a', b', l', e', t'] = map (conformVTo n) [a,b,l,e,t]
661
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
index ae437d2..f4ae0f6 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -1267,3 +1267,25 @@ int stepD(DVEC(x),DVEC(y)) {
1267 OK 1267 OK
1268} 1268}
1269 1269
1270//////////////////// cond /////////////////////////
1271
1272int condF(FVEC(x),FVEC(y),FVEC(lt),FVEC(eq),FVEC(gt),FVEC(r)) {
1273 REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE);
1274 DEBUGMSG("condF")
1275 int k;
1276 for(k=0;k<xn;k++) {
1277 rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]);
1278 }
1279 OK
1280}
1281
1282int condD(DVEC(x),DVEC(y),DVEC(lt),DVEC(eq),DVEC(gt),DVEC(r)) {
1283 REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE);
1284 DEBUGMSG("condD")
1285 int k;
1286 for(k=0;k<xn;k++) {
1287 rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]);
1288 }
1289 OK
1290}
1291
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
index 6207a59..9526583 100644
--- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
+++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
@@ -90,6 +90,9 @@ int conjugateC(KCVEC(x),CVEC(t));
90int stepF(FVEC(x),FVEC(y)); 90int stepF(FVEC(x),FVEC(y));
91int stepD(DVEC(x),DVEC(y)); 91int stepD(DVEC(x),DVEC(y));
92 92
93int condF(FVEC(x),FVEC(y),FVEC(lt),FVEC(eq),FVEC(gt),FVEC(r));
94int condD(DVEC(x),DVEC(y),DVEC(lt),DVEC(eq),DVEC(gt),DVEC(r));
95
93int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); 96int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v));
94int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); 97int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v));
95int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v)); 98int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v));
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 181cfbf..76eaaae 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -369,7 +369,14 @@ findAssocTest = utest "findAssoc" ok
369 where 369 where
370 ok = m1 == m2 370 ok = m1 == m2
371 m1 = assoc (6,6) 7 $ zip (find (>0) (ident 5 :: Matrix Float)) [10 ..] :: Matrix Double 371 m1 = assoc (6,6) 7 $ zip (find (>0) (ident 5 :: Matrix Float)) [10 ..] :: Matrix Double
372 m2 = diagRect 7 (fromList[10..14]) 6 6 :: Matrix Double 372 m2 = diagRect 7 (fromList[10..14]) 6 6
373
374---------------------------------------------------------------------
375
376condTest = utest "cond" ok
377 where
378 ok = step v * v == cond v 0 0 0 v
379 v = fromList [-7 .. 7 ] :: Vector Float
373 380
374--------------------------------------------------------------------- 381---------------------------------------------------------------------
375 382
@@ -542,6 +549,7 @@ runTests n = do
542 , chainTest 549 , chainTest
543 , succTest 550 , succTest
544 , findAssocTest 551 , findAssocTest
552 , condTest
545 ] 553 ]
546 return () 554 return ()
547 555