diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-08 22:43:50 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-08 22:43:50 +0000 |
commit | 8050c64f706c027e0446b892ca64814a174013a4 (patch) | |
tree | e139fefc28f5d25e18507a949ff9662a3216455b /lib/Data | |
parent | 92b1ed5e7fcbebbfbcde34206c040a8472d847d9 (diff) |
svdR, some quickCheck
Diffstat (limited to 'lib/Data')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 74 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/aux.c | 37 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/aux.h | 3 |
3 files changed, 101 insertions, 13 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index db53cd1..bd333d4 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -74,8 +74,7 @@ common f = commonval . map f where | |||
74 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing | 74 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing |
75 | 75 | ||
76 | 76 | ||
77 | toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m | 77 | toLists m = partit (cols m) . toList . cdat $ m |
78 | | otherwise = partit (cols m) . toList . dat $ m | ||
79 | 78 | ||
80 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | 79 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp |
81 | where | 80 | where |
@@ -145,6 +144,8 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | |||
145 | --{-# RULES "transdataR" transdata=transdataR #-} | 144 | --{-# RULES "transdataR" transdata=transdataR #-} |
146 | --{-# RULES "transdataC" transdata=transdataC #-} | 145 | --{-# RULES "transdataC" transdata=transdataC #-} |
147 | 146 | ||
147 | ----------------------------------------------------------------------------- | ||
148 | |||
148 | -- | creates a Matrix from a list of vectors | 149 | -- | creates a Matrix from a list of vectors |
149 | fromRows :: Field t => [Vector t] -> Matrix t | 150 | fromRows :: Field t => [Vector t] -> Matrix t |
150 | fromRows vs = case common dim vs of | 151 | fromRows vs = case common dim vs of |
@@ -160,6 +161,34 @@ toRows m = toRows' 0 where | |||
160 | toRows' k | k == r*c = [] | 161 | toRows' k | k == r*c = [] |
161 | | otherwise = subVector k c v : toRows' (k+c) | 162 | | otherwise = subVector k c v : toRows' (k+c) |
162 | 163 | ||
164 | -- | Creates a matrix from a list of vectors, as columns | ||
165 | fromColumns :: Field t => [Vector t] -> Matrix t | ||
166 | fromColumns m = trans . fromRows $ m | ||
167 | |||
168 | -- | Creates a list of vectors from the columns of a matrix | ||
169 | toColumns :: Field t => Matrix t -> [Vector t] | ||
170 | toColumns m = toRows . trans $ m | ||
171 | |||
172 | -- | creates a matrix from a vertical list of matrices | ||
173 | joinVert :: Field t => [Matrix t] -> Matrix t | ||
174 | joinVert ms = case common cols ms of | ||
175 | Nothing -> error "joinVert on matrices with different number of columns" | ||
176 | Just c -> reshape c $ join (map cdat ms) | ||
177 | |||
178 | -- | creates a matrix from a horizontal list of matrices | ||
179 | joinHoriz :: Field t => [Matrix t] -> Matrix t | ||
180 | joinHoriz ms = trans. joinVert . map trans $ ms | ||
181 | |||
182 | ------------------------------------------------------------------------------ | ||
183 | |||
184 | -- | Reverse rows | ||
185 | flipud :: Field t => Matrix t -> Matrix t | ||
186 | flipud m = fromRows . reverse . toRows $ m | ||
187 | |||
188 | -- | Reverse columns | ||
189 | fliprl :: Field t => Matrix t -> Matrix t | ||
190 | fliprl m = fromColumns . reverse . toColumns $ m | ||
191 | |||
163 | ----------------------------------------------------------------- | 192 | ----------------------------------------------------------------- |
164 | 193 | ||
165 | liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes | 194 | liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes |
@@ -168,7 +197,11 @@ liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes | |||
168 | 197 | ||
169 | dotL a b = sum (zipWith (*) a b) | 198 | dotL a b = sum (zipWith (*) a b) |
170 | 199 | ||
171 | multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] | 200 | multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] |
201 | | otherwise = error "inconsistent dimensions in contraction " | ||
202 | where ok = case common length a of | ||
203 | Nothing -> False | ||
204 | Just c -> c == length b | ||
172 | 205 | ||
173 | transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) | 206 | transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) |
174 | 207 | ||
@@ -201,9 +234,8 @@ foreign import ccall safe "aux.h multiplyC" | |||
201 | 234 | ||
202 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | 235 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a |
203 | multiply RowMajor a b = multiplyD RowMajor a b | 236 | multiply RowMajor a b = multiplyD RowMajor a b |
204 | multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b | 237 | multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} |
205 | 238 | where m = multiplyD RowMajor (trans b) (trans a) | |
206 | multiplyT order a b = multiplyD order (trans b) (trans a) | ||
207 | 239 | ||
208 | multiplyD order a b | 240 | multiplyD order a b |
209 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | 241 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) |
@@ -253,3 +285,33 @@ subMatrix st sz m | |||
253 | 285 | ||
254 | subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) | 286 | subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) |
255 | where subList s n = take n . drop s | 287 | where subList s n = take n . drop s |
288 | |||
289 | --------------------------------------------------------------------- | ||
290 | |||
291 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do | ||
292 | m <- createMatrix RowMajor n n | ||
293 | fun // vec v // mat dat m // check msg [dat m] | ||
294 | return m | ||
295 | |||
296 | -- | diagonal matrix from a real vector | ||
297 | diagR :: Vector Double -> Matrix Double | ||
298 | diagR = diagAux c_diagR "diagR" | ||
299 | foreign import ccall "aux.h diagR" c_diagR :: Double :> Double ::> IO Int | ||
300 | |||
301 | -- | diagonal matrix from a real vector | ||
302 | diagC :: Vector (Complex Double) -> Matrix (Complex Double) | ||
303 | diagC = diagAux c_diagC "diagC" | ||
304 | foreign import ccall "aux.h diagC" c_diagC :: (Complex Double) :> (Complex Double) ::> IO Int | ||
305 | |||
306 | -- | diagonal matrix from a vector | ||
307 | diag :: (Num a, Field a) => Vector a -> Matrix a | ||
308 | diag v | ||
309 | | isReal (baseOf) v = scast $ diagR (scast v) | ||
310 | | isComp (baseOf) v = scast $ diagC (scast v) | ||
311 | | otherwise = diagG v | ||
312 | |||
313 | diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] | ||
314 | where c = dim v | ||
315 | l = toList v | ||
316 | delta i j | i==j = 1 | ||
317 | | otherwise = 0 | ||
diff --git a/lib/Data/Packed/Internal/aux.c b/lib/Data/Packed/Internal/aux.c index 01a2bb3..fe611e2 100644 --- a/lib/Data/Packed/Internal/aux.c +++ b/lib/Data/Packed/Internal/aux.c | |||
@@ -18,19 +18,17 @@ | |||
18 | #define MACRO(B) do {B} while (0) | 18 | #define MACRO(B) do {B} while (0) |
19 | #define ERROR(CODE) MACRO(return CODE;) | 19 | #define ERROR(CODE) MACRO(return CODE;) |
20 | #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) | 20 | #define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) |
21 | #define OK return 0; | ||
21 | 22 | ||
22 | #define MIN(A,B) ((A)<(B)?(A):(B)) | 23 | #define MIN(A,B) ((A)<(B)?(A):(B)) |
23 | #define MAX(A,B) ((A)>(B)?(A):(B)) | 24 | #define MAX(A,B) ((A)>(B)?(A):(B)) |
24 | 25 | ||
25 | #ifdef DBG | 26 | #ifdef DBG |
26 | #define DEBUGMSG(M) printf("GSL Wrapper "M": "); size_t t0 = time(NULL); | 27 | #define DEBUGMSG(M) printf("*** calling aux C function: %s\n",M); |
27 | #define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;); | ||
28 | #else | 28 | #else |
29 | #define DEBUGMSG(M) | 29 | #define DEBUGMSG(M) |
30 | #define OK return 0; | ||
31 | #endif | 30 | #endif |
32 | 31 | ||
33 | |||
34 | #define CHECK(RES,CODE) MACRO(if(RES) return CODE;) | 32 | #define CHECK(RES,CODE) MACRO(if(RES) return CODE;) |
35 | 33 | ||
36 | #ifdef DBG | 34 | #ifdef DBG |
@@ -45,7 +43,6 @@ | |||
45 | #define DEBUGVEC(MSG,X) | 43 | #define DEBUGVEC(MSG,X) |
46 | #endif | 44 | #endif |
47 | 45 | ||
48 | |||
49 | #define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) | 46 | #define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) |
50 | #define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) | 47 | #define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) |
51 | #define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n) | 48 | #define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n) |
@@ -66,8 +63,6 @@ | |||
66 | #define MEM 1002 | 63 | #define MEM 1002 |
67 | #define BAD_FILE 1003 | 64 | #define BAD_FILE 1003 |
68 | 65 | ||
69 | |||
70 | |||
71 | int transR(KRMAT(x),RMAT(t)) { | 66 | int transR(KRMAT(x),RMAT(t)) { |
72 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); | 67 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); |
73 | DEBUGMSG("transR"); | 68 | DEBUGMSG("transR"); |
@@ -122,6 +117,7 @@ int constantC(gsl_complex* pval, CVEC(r)) { | |||
122 | OK | 117 | OK |
123 | } | 118 | } |
124 | 119 | ||
120 | |||
125 | int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) { | 121 | int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) { |
126 | //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc); | 122 | //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc); |
127 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); | 123 | //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); |
@@ -155,3 +151,30 @@ int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) { | |||
155 | CHECK(res,res); | 151 | CHECK(res,res); |
156 | OK | 152 | OK |
157 | } | 153 | } |
154 | |||
155 | |||
156 | int diagR(KRVEC(d),RMAT(r)) { | ||
157 | REQUIRES(dn==rr && rr==rc,BAD_SIZE); | ||
158 | DEBUGMSG("diagR"); | ||
159 | int i,j; | ||
160 | for (i=0;i<rr;i++) { | ||
161 | for(j=0;j<rc;j++) { | ||
162 | rp[i*rc+j] = i==j?dp[i]:0.; | ||
163 | } | ||
164 | } | ||
165 | OK | ||
166 | } | ||
167 | |||
168 | int diagC(KCVEC(d),CMAT(r)) { | ||
169 | REQUIRES(dn==rr && rr==rc,BAD_SIZE); | ||
170 | DEBUGMSG("diagC"); | ||
171 | int i,j; | ||
172 | gsl_complex zero; | ||
173 | GSL_SET_COMPLEX(&zero,0.,0.); | ||
174 | for (i=0;i<rr;i++) { | ||
175 | for(j=0;j<rc;j++) { | ||
176 | rp[i*rc+j] = i==j?dp[i]:zero; | ||
177 | } | ||
178 | } | ||
179 | OK | ||
180 | } | ||
diff --git a/lib/Data/Packed/Internal/aux.h b/lib/Data/Packed/Internal/aux.h index 59d90db..d055d35 100644 --- a/lib/Data/Packed/Internal/aux.h +++ b/lib/Data/Packed/Internal/aux.h | |||
@@ -21,3 +21,6 @@ int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)); | |||
21 | int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)); | 21 | int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)); |
22 | 22 | ||
23 | int submatrixR(int r1, int r2, int c1, int c2, KRMAT(x),RMAT(r)); | 23 | int submatrixR(int r1, int r2, int c1, int c2, KRMAT(x),RMAT(r)); |
24 | |||
25 | int diagR(KRVEC(d),RMAT(r)); | ||
26 | int diagC(KCVEC(d),CMAT(r)); | ||