From e503945c666dc28f1a806ba1a2deaa587a836200 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Thu, 30 Dec 2010 18:07:39 +0000 Subject: cond --- lib/Data/Packed/Internal/Vector.hs | 19 ++++++++++++- lib/Numeric/ContainerBoot.hs | 39 +++++++++++++++++++++++++++ lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 22 +++++++++++++++ lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 3 +++ lib/Numeric/LinearAlgebra/Tests.hs | 10 ++++++- 5 files changed, 91 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 70bbae2..b68d16a 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs @@ -22,7 +22,7 @@ module Data.Packed.Internal.Vector ( foldVector, foldVectorG, foldLoop, foldVectorWithIndex, createVector, vec, asComplex, asReal, float2DoubleV, double2FloatV, - stepF, stepD, + stepF, stepD, condF, condD, fwriteVector, freadVector, fprintfVector, fscanfVector, cloneVector, unsafeToForeignPtr, @@ -310,6 +310,23 @@ stepD v = unsafePerformIO $ do foreign import ccall "stepF" c_stepF :: TFF foreign import ccall "stepD" c_stepD :: TVV +--------------------------------------------------------------- + +condF :: Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float +condF x y l e g = unsafePerformIO $ do + r <- createVector (dim x) + app6 c_condF vec x vec y vec l vec e vec g vec r "condF" + return r + +condD :: Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double +condD x y l e g = unsafePerformIO $ do + r <- createVector (dim x) + app6 c_condD vec x vec y vec l vec e vec g vec r "condD" + return r + +foreign import ccall "condF" c_condF :: CInt -> PF -> CInt -> PF -> CInt -> PF -> TFFF +foreign import ccall "condD" c_condD :: CInt -> PD -> CInt -> PD -> CInt -> PD -> TVVV + ---------------------------------------------------------------- cloneVector :: Storable t => Vector t -> IO (Vector t) diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index 5a8e243..e33857a 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs @@ -127,6 +127,8 @@ class (Complexable c, Fractional e, Element e) => Container c e where find :: (e -> Bool) -> c e -> [IndexOf c] -- | create a structure from an association list assoc :: IndexOf c -> e -> [(IndexOf c, e)] -> c e + -- | a vectorized form of case 'compare' a_i b_i of LT -> l_i; EQ -> e_i; GT -> g_i + cond :: RealFloat e => c e -> c e -> c e -> c e -> c e -> c e -------------------------------------------------------------------------- @@ -155,6 +157,7 @@ instance Container Vector Float where step = stepF find = findV assoc = assocV + cond = condV condF instance Container Vector Double where scale = vectorMapValR Scale @@ -181,6 +184,7 @@ instance Container Vector Double where step = stepD find = findV assoc = assocV + cond = condV condD instance Container Vector (Complex Double) where scale = vectorMapValC Scale @@ -207,6 +211,7 @@ instance Container Vector (Complex Double) where step = undefined -- cannot match find = findV assoc = assocV + cond = undefined -- cannot match instance Container Vector (Complex Float) where scale = vectorMapValQ Scale @@ -233,6 +238,7 @@ instance Container Vector (Complex Float) where step = undefined -- cannot match find = findV assoc = assocV + cond = undefined -- cannot match --------------------------------------------------------------- @@ -265,6 +271,7 @@ instance (Container Vector a) => Container Matrix a where step = liftMatrix step find = findM assoc = assocM + cond = condM ---------------------------------------------------- @@ -620,3 +627,35 @@ assocM (r,c) z xs = ST.runSTMatrix $ do mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs return m +---------------------------------------------------------------------- + +conformMTo (r,c) m + | size m == (r,c) = m + | size m == (1,1) = konst (m@@>(0,0)) (r,c) + | size m == (r,1) = repCols c m + | size m == (1,c) = repRows r m + | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" + +conformVTo n v + | dim v == n = v + | dim v == 1 = konst (v@>0) n + | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n + +repRows n x = fromRows (replicate n (flatten x)) +repCols n x = fromColumns (replicate n (flatten x)) + +size m = (rows m, cols m) + +shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" + +condM a b l e t = reshape c $ cond a' b' l' e' t' + where + r = maximum (map rows [a,b,l,e,t]) + c = maximum (map cols [a,b,l,e,t]) + [a', b', l', e', t'] = map (flatten . conformMTo (r,c)) [a,b,l,e,t] + +condV f a b l e t = f a' b' l' e' t' + where + n = maximum (map dim [a,b,l,e,t]) + [a', b', l', e', t'] = map (conformVTo n) [a,b,l,e,t] + diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index ae437d2..f4ae0f6 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c @@ -1267,3 +1267,25 @@ int stepD(DVEC(x),DVEC(y)) { OK } +//////////////////// cond ///////////////////////// + +int condF(FVEC(x),FVEC(y),FVEC(lt),FVEC(eq),FVEC(gt),FVEC(r)) { + REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE); + DEBUGMSG("condF") + int k; + for(k=0;kyp[k]?gtp[k]:eqp[k]); + } + OK +} + +int condD(DVEC(x),DVEC(y),DVEC(lt),DVEC(eq),DVEC(gt),DVEC(r)) { + REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE); + DEBUGMSG("condD") + int k; + for(k=0;kyp[k]?gtp[k]:eqp[k]); + } + OK +} + diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index 6207a59..9526583 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h @@ -90,6 +90,9 @@ int conjugateC(KCVEC(x),CVEC(t)); int stepF(FVEC(x),FVEC(y)); int stepD(DVEC(x),DVEC(y)); +int condF(FVEC(x),FVEC(y),FVEC(lt),FVEC(eq),FVEC(gt),FVEC(r)); +int condD(DVEC(x),DVEC(y),DVEC(lt),DVEC(eq),DVEC(gt),DVEC(r)); + int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v)); diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 181cfbf..76eaaae 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs @@ -369,7 +369,14 @@ findAssocTest = utest "findAssoc" ok where ok = m1 == m2 m1 = assoc (6,6) 7 $ zip (find (>0) (ident 5 :: Matrix Float)) [10 ..] :: Matrix Double - m2 = diagRect 7 (fromList[10..14]) 6 6 :: Matrix Double + m2 = diagRect 7 (fromList[10..14]) 6 6 + +--------------------------------------------------------------------- + +condTest = utest "cond" ok + where + ok = step v * v == cond v 0 0 0 v + v = fromList [-7 .. 7 ] :: Vector Float --------------------------------------------------------------------- @@ -542,6 +549,7 @@ runTests n = do , chainTest , succTest , findAssocTest + , condTest ] return () -- cgit v1.2.3