summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-08 22:43:50 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-08 22:43:50 +0000
commit8050c64f706c027e0446b892ca64814a174013a4 (patch)
treee139fefc28f5d25e18507a949ff9662a3216455b
parent92b1ed5e7fcbebbfbcde34206c040a8472d847d9 (diff)
svdR, some quickCheck
-rw-r--r--examples/pru.hs10
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs74
-rw-r--r--lib/Data/Packed/Internal/aux.c37
-rw-r--r--lib/Data/Packed/Internal/aux.h3
-rw-r--r--lib/LAPACK.hs1
-rw-r--r--lib/LAPACK/Internal.hs36
6 files changed, 142 insertions, 19 deletions
diff --git a/examples/pru.hs b/examples/pru.hs
index 8b25780..10789d2 100644
--- a/examples/pru.hs
+++ b/examples/pru.hs
@@ -38,7 +38,7 @@ bf = (3>|<4) [7,11,15,8,12,16,9,13,17,10,14,18::Double]
38 38
39a |=| b = rows a == rows b && 39a |=| b = rows a == rows b &&
40 cols a == cols b && 40 cols a == cols b &&
41 toList (dat a) == toList (dat b) 41 toList (cdat a) == toList (cdat b)
42 42
43mulC a b = multiply RowMajor a b 43mulC a b = multiply RowMajor a b
44mulF a b = multiply ColumnMajor a b 44mulF a b = multiply ColumnMajor a b
@@ -75,15 +75,14 @@ delta i j | i==j = 1
75 75
76e i n = fromList [ delta k i | k <- [1..n]] 76e i n = fromList [ delta k i | k <- [1..n]]
77 77
78ident n = fromRows [ e i n | i <- [1..n]] 78diagl = diag.fromList
79 79
80diag l = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] 80ident n = diag (constant n 1)
81 where c = length l
82 81
83tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v} 82tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v}
84tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m} 83tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m}
85 84
86td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diag [1..4] :: Tensor Double 85td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diagl [1..4] :: Tensor Double
87 86
88tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double 87tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double
89tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double 88tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double
@@ -114,3 +113,4 @@ names t = sort $ map (snd.snd) (dims t)
114normal t = tridx (names t) t 113normal t = tridx (names t) t
115 114
116contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] 115contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
116
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index db53cd1..bd333d4 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -74,8 +74,7 @@ common f = commonval . map f where
74 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing 74 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
75 75
76 76
77toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m 77toLists m = partit (cols m) . toList . cdat $ m
78 | otherwise = partit (cols m) . toList . dat $ m
79 78
80dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp 79dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
81 where 80 where
@@ -145,6 +144,8 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
145--{-# RULES "transdataR" transdata=transdataR #-} 144--{-# RULES "transdataR" transdata=transdataR #-}
146--{-# RULES "transdataC" transdata=transdataC #-} 145--{-# RULES "transdataC" transdata=transdataC #-}
147 146
147-----------------------------------------------------------------------------
148
148-- | creates a Matrix from a list of vectors 149-- | creates a Matrix from a list of vectors
149fromRows :: Field t => [Vector t] -> Matrix t 150fromRows :: Field t => [Vector t] -> Matrix t
150fromRows vs = case common dim vs of 151fromRows vs = case common dim vs of
@@ -160,6 +161,34 @@ toRows m = toRows' 0 where
160 toRows' k | k == r*c = [] 161 toRows' k | k == r*c = []
161 | otherwise = subVector k c v : toRows' (k+c) 162 | otherwise = subVector k c v : toRows' (k+c)
162 163
164-- | Creates a matrix from a list of vectors, as columns
165fromColumns :: Field t => [Vector t] -> Matrix t
166fromColumns m = trans . fromRows $ m
167
168-- | Creates a list of vectors from the columns of a matrix
169toColumns :: Field t => Matrix t -> [Vector t]
170toColumns m = toRows . trans $ m
171
172-- | creates a matrix from a vertical list of matrices
173joinVert :: Field t => [Matrix t] -> Matrix t
174joinVert ms = case common cols ms of
175 Nothing -> error "joinVert on matrices with different number of columns"
176 Just c -> reshape c $ join (map cdat ms)
177
178-- | creates a matrix from a horizontal list of matrices
179joinHoriz :: Field t => [Matrix t] -> Matrix t
180joinHoriz ms = trans. joinVert . map trans $ ms
181
182------------------------------------------------------------------------------
183
184-- | Reverse rows
185flipud :: Field t => Matrix t -> Matrix t
186flipud m = fromRows . reverse . toRows $ m
187
188-- | Reverse columns
189fliprl :: Field t => Matrix t -> Matrix t
190fliprl m = fromColumns . reverse . toColumns $ m
191
163----------------------------------------------------------------- 192-----------------------------------------------------------------
164 193
165liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes 194liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes
@@ -168,7 +197,11 @@ liftMatrix f m = m { dat = f dat, tdat = f tdat } -- check sizes
168 197
169dotL a b = sum (zipWith (*) a b) 198dotL a b = sum (zipWith (*) a b)
170 199
171multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] 200multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a]
201 | otherwise = error "inconsistent dimensions in contraction "
202 where ok = case common length a of
203 Nothing -> False
204 Just c -> c == length b
172 205
173transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) 206transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m)
174 207
@@ -201,9 +234,8 @@ foreign import ccall safe "aux.h multiplyC"
201 234
202multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a 235multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
203multiply RowMajor a b = multiplyD RowMajor a b 236multiply RowMajor a b = multiplyD RowMajor a b
204multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b 237multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor}
205 238 where m = multiplyD RowMajor (trans b) (trans a)
206multiplyT order a b = multiplyD order (trans b) (trans a)
207 239
208multiplyD order a b 240multiplyD order a b
209 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) 241 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b)
@@ -253,3 +285,33 @@ subMatrix st sz m
253 285
254subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) 286subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x))
255 where subList s n = take n . drop s 287 where subList s n = take n . drop s
288
289---------------------------------------------------------------------
290
291diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do
292 m <- createMatrix RowMajor n n
293 fun // vec v // mat dat m // check msg [dat m]
294 return m
295
296-- | diagonal matrix from a real vector
297diagR :: Vector Double -> Matrix Double
298diagR = diagAux c_diagR "diagR"
299foreign import ccall "aux.h diagR" c_diagR :: Double :> Double ::> IO Int
300
301-- | diagonal matrix from a real vector
302diagC :: Vector (Complex Double) -> Matrix (Complex Double)
303diagC = diagAux c_diagC "diagC"
304foreign import ccall "aux.h diagC" c_diagC :: (Complex Double) :> (Complex Double) ::> IO Int
305
306-- | diagonal matrix from a vector
307diag :: (Num a, Field a) => Vector a -> Matrix a
308diag v
309 | isReal (baseOf) v = scast $ diagR (scast v)
310 | isComp (baseOf) v = scast $ diagC (scast v)
311 | otherwise = diagG v
312
313diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]]
314 where c = dim v
315 l = toList v
316 delta i j | i==j = 1
317 | otherwise = 0
diff --git a/lib/Data/Packed/Internal/aux.c b/lib/Data/Packed/Internal/aux.c
index 01a2bb3..fe611e2 100644
--- a/lib/Data/Packed/Internal/aux.c
+++ b/lib/Data/Packed/Internal/aux.c
@@ -18,19 +18,17 @@
18#define MACRO(B) do {B} while (0) 18#define MACRO(B) do {B} while (0)
19#define ERROR(CODE) MACRO(return CODE;) 19#define ERROR(CODE) MACRO(return CODE;)
20#define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);}) 20#define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);})
21#define OK return 0;
21 22
22#define MIN(A,B) ((A)<(B)?(A):(B)) 23#define MIN(A,B) ((A)<(B)?(A):(B))
23#define MAX(A,B) ((A)>(B)?(A):(B)) 24#define MAX(A,B) ((A)>(B)?(A):(B))
24 25
25#ifdef DBG 26#ifdef DBG
26#define DEBUGMSG(M) printf("GSL Wrapper "M": "); size_t t0 = time(NULL); 27#define DEBUGMSG(M) printf("*** calling aux C function: %s\n",M);
27#define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;);
28#else 28#else
29#define DEBUGMSG(M) 29#define DEBUGMSG(M)
30#define OK return 0;
31#endif 30#endif
32 31
33
34#define CHECK(RES,CODE) MACRO(if(RES) return CODE;) 32#define CHECK(RES,CODE) MACRO(if(RES) return CODE;)
35 33
36#ifdef DBG 34#ifdef DBG
@@ -45,7 +43,6 @@
45#define DEBUGVEC(MSG,X) 43#define DEBUGVEC(MSG,X)
46#endif 44#endif
47 45
48
49#define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n) 46#define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n)
50#define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c) 47#define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c)
51#define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n) 48#define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n)
@@ -66,8 +63,6 @@
66#define MEM 1002 63#define MEM 1002
67#define BAD_FILE 1003 64#define BAD_FILE 1003
68 65
69
70
71int transR(KRMAT(x),RMAT(t)) { 66int transR(KRMAT(x),RMAT(t)) {
72 REQUIRES(xr==tc && xc==tr,BAD_SIZE); 67 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
73 DEBUGMSG("transR"); 68 DEBUGMSG("transR");
@@ -122,6 +117,7 @@ int constantC(gsl_complex* pval, CVEC(r)) {
122 OK 117 OK
123} 118}
124 119
120
125int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) { 121int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) {
126 //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc); 122 //printf("%d %d %d %d %d %d\n",ar,ac,br,bc,rr,rc);
127 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE); 123 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
@@ -155,3 +151,30 @@ int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) {
155 CHECK(res,res); 151 CHECK(res,res);
156 OK 152 OK
157} 153}
154
155
156int diagR(KRVEC(d),RMAT(r)) {
157 REQUIRES(dn==rr && rr==rc,BAD_SIZE);
158 DEBUGMSG("diagR");
159 int i,j;
160 for (i=0;i<rr;i++) {
161 for(j=0;j<rc;j++) {
162 rp[i*rc+j] = i==j?dp[i]:0.;
163 }
164 }
165 OK
166}
167
168int diagC(KCVEC(d),CMAT(r)) {
169 REQUIRES(dn==rr && rr==rc,BAD_SIZE);
170 DEBUGMSG("diagC");
171 int i,j;
172 gsl_complex zero;
173 GSL_SET_COMPLEX(&zero,0.,0.);
174 for (i=0;i<rr;i++) {
175 for(j=0;j<rc;j++) {
176 rp[i*rc+j] = i==j?dp[i]:zero;
177 }
178 }
179 OK
180}
diff --git a/lib/Data/Packed/Internal/aux.h b/lib/Data/Packed/Internal/aux.h
index 59d90db..d055d35 100644
--- a/lib/Data/Packed/Internal/aux.h
+++ b/lib/Data/Packed/Internal/aux.h
@@ -21,3 +21,6 @@ int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r));
21int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)); 21int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r));
22 22
23int submatrixR(int r1, int r2, int c1, int c2, KRMAT(x),RMAT(r)); 23int submatrixR(int r1, int r2, int c1, int c2, KRMAT(x),RMAT(r));
24
25int diagR(KRVEC(d),RMAT(r));
26int diagC(KCVEC(d),CMAT(r));
diff --git a/lib/LAPACK.hs b/lib/LAPACK.hs
index cab437c..67e49af 100644
--- a/lib/LAPACK.hs
+++ b/lib/LAPACK.hs
@@ -14,6 +14,7 @@
14 14
15module LAPACK ( 15module LAPACK (
16 --module LAPACK.Internal 16 --module LAPACK.Internal
17 svdR, svdR',
17 eigC, 18 eigC,
18 linearSolveLSR 19 linearSolveLSR
19) where 20) where
diff --git a/lib/LAPACK/Internal.hs b/lib/LAPACK/Internal.hs
index 4c755bc..2569215 100644
--- a/lib/LAPACK/Internal.hs
+++ b/lib/LAPACK/Internal.hs
@@ -27,6 +27,24 @@ import Foreign.C.String
27foreign import ccall "lapack-aux.h svd_l_R" 27foreign import ccall "lapack-aux.h svd_l_R"
28 dgesvd :: Double ::> Double ::> (Double :> Double ::> IO Int) 28 dgesvd :: Double ::> Double ::> (Double :> Double ::> IO Int)
29 29
30-- | Wrapper for LAPACK's /dgesvd/, which computes the full svd decomposition of a real matrix.
31--
32-- @(u,s,v)=svdR m@ so that @m=u \<\> s \<\> 'trans' v@.
33svdR :: Matrix Double -> (Matrix Double, Matrix Double , Matrix Double)
34svdR x@M {rows = r, cols = c} = (u, s, v)
35 where (u,s',v) = svdR' x
36 s | r == c = diag s'
37 | r < c = joinHoriz [diag s' , zeros (r,c-r)]
38 | otherwise = joinVert [diag s' , zeros (r-c,c)]
39 zeros (r,c) = reshape c $ constant (r*c) 0
40
41svdR' x@M {rows = r, cols = c} = unsafePerformIO $ do
42 u <- createMatrix ColumnMajor r r
43 s <- createVector (min r c)
44 v <- createMatrix ColumnMajor c c
45 dgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdR" [fdat x]
46 return (u,s,trans v)
47
30----------------------------------------------------------------------------- 48-----------------------------------------------------------------------------
31-- dgesdd 49-- dgesdd
32foreign import ccall "lapack-aux.h svd_l_Rdd" 50foreign import ccall "lapack-aux.h svd_l_Rdd"
@@ -62,11 +80,27 @@ foreign import ccall "lapack-aux.h eig_l_R"
62 dgeev :: Double ::> Double ::> ((Complex Double) :> Double ::> IO Int) 80 dgeev :: Double ::> Double ::> ((Complex Double) :> Double ::> IO Int)
63 81
64----------------------------------------------------------------------------- 82-----------------------------------------------------------------------------
65
66-- dsyev 83-- dsyev
67foreign import ccall "lapack-aux.h eig_l_S" 84foreign import ccall "lapack-aux.h eig_l_S"
68 dsyev :: Double ::> (Double :> Double ::> IO Int) 85 dsyev :: Double ::> (Double :> Double ::> IO Int)
69 86
87-- | Wrapper for LAPACK's /dsyev/, which computes the eigenvalues and right eigenvectors of a symmetric real matrix:
88--
89-- if @(l,v)=eigSl m@ then @m \<\> v = v \<\> diag l@.
90--
91-- The eigenvectors are the columns of v.
92-- The eigenvalues are sorted in descending order (use eigS' for ascending order).
93eigS :: Matrix Double -> (Vector Double, Matrix Double)
94eigS m = (s', fliprl v)
95 where (s,v) = eigS' m
96 s' = fromList . reverse . toList $ s
97
98eigS' (m@M {rows = r}) = unsafePerformIO $ do
99 l <- createVector r
100 v <- createMatrix ColumnMajor r r
101 dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m]
102 return (l,v)
103
70----------------------------------------------------------------------------- 104-----------------------------------------------------------------------------
71-- zheev 105-- zheev
72foreign import ccall "lapack-aux.h eig_l_H" 106foreign import ccall "lapack-aux.h eig_l_H"