From 8050c64f706c027e0446b892ca64814a174013a4 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 8 Jun 2007 22:43:50 +0000 Subject: svdR, some quickCheck --- lib/Data/Packed/Internal/Matrix.hs | 74 ++++++++++++++++++++++++++++++++++---- lib/Data/Packed/Internal/aux.c | 37 +++++++++++++++---- lib/Data/Packed/Internal/aux.h | 3 ++ 3 files changed, 101 insertions(+), 13 deletions(-) (limited to 'lib/Data/Packed/Internal') diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index db53cd1..bd333d4 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -74,8 +74,7 @@ common f = commonval . map f where commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing -toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m - | otherwise = partit (cols m) . toList . dat $ m +toLists m = partit (cols m) . toList . cdat $ m dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp where @@ -145,6 +144,8 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 --{-# RULES "transdataR" transdata=transdataR #-} --{-# RULES "transdataC" transdata=transdataC #-} +----------------------------------------------------------------------------- + -- | creates a Matrix from a list of vectors fromRows :: Field t => [Vector t] -> Matrix t fromRows vs = case common dim vs of @@ -160,6 +161,34 @@ toRows m = toRows' 0 where toRows' k | k == r*c = [] | otherwise = subVector k c v : toRows' (k+c) +-- | Creates a matrix from a list of vectors, as columns +fromColumns :: Field t => [Vector t] -> Matrix t +fromColumns m = trans . fromRows $ m + +-- | Creates a list of vectors from the columns of a matrix +toColumns :: Field t => Matrix t -> [Vector t] +toColumns m = toRows . trans $ m + +-- | creates a matrix from a vertical list of matrices +joinVert :: Field t => [Matrix t] -> Matrix t +joinVert ms = case common cols ms of + Nothing -> error "joinVert on matrices with different number of columns" + Just c -> reshape c $ join (map cdat ms) + +-- | creates a matrix from a horizontal list of matrices +joinHoriz :: Field t => [Matrix t] -> Matrix t +joinHoriz ms = trans. joinVert . map trans $ ms + +------------------------------------------------------------------------------ + +-- | Reverse rows +flipud :: Field t => Matrix t -> Matrix t +flipud m = fromRows . reverse . toRows $ m + +-- | Reverse columns +fliprl :: Field t => Matrix t -> Matrix t +fliprl m = fromColumns . reverse . toColumns $ m + ----------------------------------------------------------------- liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes @@ -168,7 +197,11 @@ liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes dotL a b = sum (zipWith (*) a b) -multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] +multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] + | otherwise = error "inconsistent dimensions in contraction " + where ok = case common length a of + Nothing -> False + Just c -> c == length b transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) @@ -201,9 +234,8 @@ foreign import ccall safe "aux.h multiplyC" multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a multiply RowMajor a b = multiplyD RowMajor a b -multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b - -multiplyT order a b = multiplyD order (trans b) (trans a) +multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} + where m = multiplyD RowMajor (trans b) (trans a) multiplyD order a b | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) @@ -253,3 +285,33 @@ subMatrix st sz m subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) where subList s n = take n . drop s + +--------------------------------------------------------------------- + +diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do + m <- createMatrix RowMajor n n + fun // vec v // mat dat m // check msg [dat m] + return m + +-- | diagonal matrix from a real vector +diagR :: Vector Double -> Matrix Double +diagR = diagAux c_diagR "diagR" +foreign import ccall "aux.h diagR" c_diagR :: Double :> Double ::> IO Int + +-- | diagonal matrix from a real vector +diagC :: Vector (Complex Double) -> Matrix (Complex Double) +diagC = diagAux c_diagC "diagC" +foreign import ccall "aux.h diagC" c_diagC :: (Complex Double) :> (Complex Double) ::> IO Int + +-- | diagonal matrix from a vector +diag :: (Num a, Field a) => Vector a -> Matrix a +diag v + | isReal (baseOf) v = scast $ diagR (scast v) + | isComp (baseOf) v = scast $ diagC (scast v) + | otherwise = diagG v + +diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] + where c = dim v + l = toList v + delta i j | i==j = 1 + | otherwise = 0 diff --git a/lib/Data/Packed/Internal/aux.c b/lib/Data/Packed/Internal/aux.c index 01a2bb3..fe611e2 100644 --- a/lib/Data/Packed/Internal/aux.c +++ b/lib/Data/Packed/Internal/aux.c @@ -18,19 +18,17 @@ #define MACRO(B) do {B} while (0) #define ERROR(CODE) MACRO(return CODE;) #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) +#define OK return 0; #define MIN(A,B) ((A)<(B)?(A):(B)) #define MAX(A,B) ((A)>(B)?(A):(B)) #ifdef DBG -#define DEBUGMSG(M) printf("GSL Wrapper "M": "); size_t t0 = time(NULL); -#define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;); +#define DEBUGMSG(M) printf("*** calling aux C function: %s\n",M); #else #define DEBUGMSG(M) -#define OK return 0; #endif - #define CHECK(RES,CODE) MACRO(if(RES) return CODE;) #ifdef DBG @@ -45,7 +43,6 @@ #define DEBUGVEC(MSG,X) #endif - #define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) #define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) #define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n) @@ -66,8 +63,6 @@ #define MEM 1002 #define BAD_FILE 1003 - - int transR(KRMAT(x),RMAT(t)) { REQUIRES(xr==tc && xc==tr,BAD_SIZE); DEBUGMSG("transR"); @@ -122,6 +117,7 @@ int constantC(gsl_complex* pval, CVEC(r)) { OK } + int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) { //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc); //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); @@ -155,3 +151,30 @@ int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) { CHECK(res,res); OK } + + +int diagR(KRVEC(d),RMAT(r)) { + REQUIRES(dn==rr && rr==rc,BAD_SIZE); + DEBUGMSG("diagR"); + int i,j; + for (i=0;i