summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-08 22:43:50 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-08 22:43:50 +0000
commit8050c64f706c027e0446b892ca64814a174013a4 (patch)
treee139fefc28f5d25e18507a949ff9662a3216455b /lib/Data/Packed/Internal
parent92b1ed5e7fcbebbfbcde34206c040a8472d847d9 (diff)
svdR, some quickCheck
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs74
-rw-r--r--lib/Data/Packed/Internal/aux.c37
-rw-r--r--lib/Data/Packed/Internal/aux.h3
3 files changed, 101 insertions, 13 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index db53cd1..bd333d4 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -74,8 +74,7 @@ common f = commonval . map f where
74 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing 74 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
75 75
76 76
77toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m 77toLists m = partit (cols m) . toList . cdat $ m
78 | otherwise = partit (cols m) . toList . dat $ m
79 78
80dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp 79dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
81 where 80 where
@@ -145,6 +144,8 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
145--{-# RULES "transdataR" transdata=transdataR #-} 144--{-# RULES "transdataR" transdata=transdataR #-}
146--{-# RULES "transdataC" transdata=transdataC #-} 145--{-# RULES "transdataC" transdata=transdataC #-}
147 146
147-----------------------------------------------------------------------------
148
148-- | creates a Matrix from a list of vectors 149-- | creates a Matrix from a list of vectors
149fromRows :: Field t => [Vector t] -> Matrix t 150fromRows :: Field t => [Vector t] -> Matrix t
150fromRows vs = case common dim vs of 151fromRows vs = case common dim vs of
@@ -160,6 +161,34 @@ toRows m = toRows' 0 where
160 toRows' k | k == r*c = [] 161 toRows' k | k == r*c = []
161 | otherwise = subVector k c v : toRows' (k+c) 162 | otherwise = subVector k c v : toRows' (k+c)
162 163
164-- | Creates a matrix from a list of vectors, as columns
165fromColumns :: Field t => [Vector t] -> Matrix t
166fromColumns m = trans . fromRows $ m
167
168-- | Creates a list of vectors from the columns of a matrix
169toColumns :: Field t => Matrix t -> [Vector t]
170toColumns m = toRows . trans $ m
171
172-- | creates a matrix from a vertical list of matrices
173joinVert :: Field t => [Matrix t] -> Matrix t
174joinVert ms = case common cols ms of
175 Nothing -> error "joinVert on matrices with different number of columns"
176 Just c -> reshape c $ join (map cdat ms)
177
178-- | creates a matrix from a horizontal list of matrices
179joinHoriz :: Field t => [Matrix t] -> Matrix t
180joinHoriz ms = trans. joinVert . map trans $ ms
181
182------------------------------------------------------------------------------
183
184-- | Reverse rows
185flipud :: Field t => Matrix t -> Matrix t
186flipud m = fromRows . reverse . toRows $ m
187
188-- | Reverse columns
189fliprl :: Field t => Matrix t -> Matrix t
190fliprl m = fromColumns . reverse . toColumns $ m
191
163----------------------------------------------------------------- 192-----------------------------------------------------------------
164 193
165liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes 194liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes
@@ -168,7 +197,11 @@ liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes
168 197
169dotL a b = sum (zipWith (*) a b) 198dotL a b = sum (zipWith (*) a b)
170 199
171multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] 200multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a]
201 | otherwise = error "inconsistent dimensions in contraction "
202 where ok = case common length a of
203 Nothing -> False
204 Just c -> c == length b
172 205
173transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) 206transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m)
174 207
@@ -201,9 +234,8 @@ foreign import ccall safe "aux.h multiplyC"
201 234
202multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a 235multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
203multiply RowMajor a b = multiplyD RowMajor a b 236multiply RowMajor a b = multiplyD RowMajor a b
204multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b 237multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor}
205 238 where m = multiplyD RowMajor (trans b) (trans a)
206multiplyT order a b = multiplyD order (trans b) (trans a)
207 239
208multiplyD order a b 240multiplyD order a b
209 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) 241 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b)
@@ -253,3 +285,33 @@ subMatrix st sz m
253 285
254subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) 286subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x))
255 where subList s n = take n . drop s 287 where subList s n = take n . drop s
288
289---------------------------------------------------------------------
290
291diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do
292 m <- createMatrix RowMajor n n
293 fun // vec v // mat dat m // check msg [dat m]
294 return m
295
296-- | diagonal matrix from a real vector
297diagR :: Vector Double -> Matrix Double
298diagR = diagAux c_diagR "diagR"
299foreign import ccall "aux.h diagR" c_diagR :: Double :> Double ::> IO Int
300
301-- | diagonal matrix from a real vector
302diagC :: Vector (Complex Double) -> Matrix (Complex Double)
303diagC = diagAux c_diagC "diagC"
304foreign import ccall "aux.h diagC" c_diagC :: (Complex Double) :> (Complex Double) ::> IO Int
305
306-- | diagonal matrix from a vector
307diag :: (Num a, Field a) => Vector a -> Matrix a
308diag v
309 | isReal (baseOf) v = scast $ diagR (scast v)
310 | isComp (baseOf) v = scast $ diagC (scast v)
311 | otherwise = diagG v
312
313diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]]
314 where c = dim v
315 l = toList v
316 delta i j | i==j = 1
317 | otherwise = 0
diff --git a/lib/Data/Packed/Internal/aux.c b/lib/Data/Packed/Internal/aux.c
index 01a2bb3..fe611e2 100644
--- a/lib/Data/Packed/Internal/aux.c
+++ b/lib/Data/Packed/Internal/aux.c
@@ -18,19 +18,17 @@
18#define MACRO(B) do {B} while (0) 18#define MACRO(B) do {B} while (0)
19#define ERROR(CODE) MACRO(return CODE;) 19#define ERROR(CODE) MACRO(return CODE;)
20#define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) 20#define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);})
21#define OK return 0;
21 22
22#define MIN(A,B) ((A)<(B)?(A):(B)) 23#define MIN(A,B) ((A)<(B)?(A):(B))
23#define MAX(A,B) ((A)>(B)?(A):(B)) 24#define MAX(A,B) ((A)>(B)?(A):(B))
24 25
25#ifdef DBG 26#ifdef DBG
26#define DEBUGMSG(M) printf("GSL Wrapper "M": "); size_t t0 = time(NULL); 27#define DEBUGMSG(M) printf("*** calling aux C function: %s\n",M);
27#define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;);
28#else 28#else
29#define DEBUGMSG(M) 29#define DEBUGMSG(M)
30#define OK return 0;
31#endif 30#endif
32 31
33
34#define CHECK(RES,CODE) MACRO(if(RES) return CODE;) 32#define CHECK(RES,CODE) MACRO(if(RES) return CODE;)
35 33
36#ifdef DBG 34#ifdef DBG
@@ -45,7 +43,6 @@
45#define DEBUGVEC(MSG,X) 43#define DEBUGVEC(MSG,X)
46#endif 44#endif
47 45
48
49#define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) 46#define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n)
50#define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) 47#define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c)
51#define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n) 48#define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n)
@@ -66,8 +63,6 @@
66#define MEM 1002 63#define MEM 1002
67#define BAD_FILE 1003 64#define BAD_FILE 1003
68 65
69
70
71int transR(KRMAT(x),RMAT(t)) { 66int transR(KRMAT(x),RMAT(t)) {
72 REQUIRES(xr==tc && xc==tr,BAD_SIZE); 67 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
73 DEBUGMSG("transR"); 68 DEBUGMSG("transR");
@@ -122,6 +117,7 @@ int constantC(gsl_complex* pval, CVEC(r)) {
122 OK 117 OK
123} 118}
124 119
120
125int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) { 121int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) {
126 //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc); 122 //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc);
127 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); 123 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
@@ -155,3 +151,30 @@ int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) {
155 CHECK(res,res); 151 CHECK(res,res);
156 OK 152 OK
157} 153}
154
155
156int diagR(KRVEC(d),RMAT(r)) {
157 REQUIRES(dn==rr && rr==rc,BAD_SIZE);
158 DEBUGMSG("diagR");
159 int i,j;
160 for (i=0;i<rr;i++) {
161 for(j=0;j<rc;j++) {
162 rp[i*rc+j] = i==j?dp[i]:0.;
163 }
164 }
165 OK
166}
167
168int diagC(KCVEC(d),CMAT(r)) {
169 REQUIRES(dn==rr && rr==rc,BAD_SIZE);
170 DEBUGMSG("diagC");
171 int i,j;
172 gsl_complex zero;
173 GSL_SET_COMPLEX(&zero,0.,0.);
174 for (i=0;i<rr;i++) {
175 for(j=0;j<rc;j++) {
176 rp[i*rc+j] = i==j?dp[i]:zero;
177 }
178 }
179 OK
180}
diff --git a/lib/Data/Packed/Internal/aux.h b/lib/Data/Packed/Internal/aux.h
index 59d90db..d055d35 100644
--- a/lib/Data/Packed/Internal/aux.h
+++ b/lib/Data/Packed/Internal/aux.h
@@ -21,3 +21,6 @@ int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r));
21int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)); 21int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r));
22 22
23int submatrixR(int r1, int r2, int c1, int c2, KRMAT(x),RMAT(r)); 23int submatrixR(int r1, int r2, int c1, int c2, KRMAT(x),RMAT(r));
24
25int diagR(KRVEC(d),RMAT(r));
26int diagC(KCVEC(d),CMAT(r));