diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-11-12 10:01:39 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-11-12 10:01:39 +0000 |
commit | c41d21fefa04c66039a0b218daaa53c2577ef838 (patch) | |
tree | 3dd182457a89edbf52688fb43fdc8b2a130829e8 /lib/Data/Packed | |
parent | 9e6500bf8e925b363e74e01834f81dde0810f808 (diff) |
data structures simplification
Diffstat (limited to 'lib/Data/Packed')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 66 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 36 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/auxi.c | 39 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 12 |
4 files changed, 97 insertions, 56 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index fbab33c..010c40b 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -62,21 +62,35 @@ import Data.List(transpose) | |||
62 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | 62 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
63 | 63 | ||
64 | -- | Matrix representation suitable for GSL and LAPACK computations. | 64 | -- | Matrix representation suitable for GSL and LAPACK computations. |
65 | data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t, fdat :: Vector t } | 65 | data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t } |
66 | | MF { rows :: Int, cols :: Int, fdat :: Vector t, cdat :: Vector t } | 66 | | MF { rows :: Int, cols :: Int, fdat :: Vector t } |
67 | 67 | ||
68 | -- MC: preferred by C, fdat may require a transposition | 68 | -- MC: preferred by C, fdat may require a transposition |
69 | -- MF: preferred by LAPACK, cdat may require a transposition | 69 | -- MF: preferred by LAPACK, cdat may require a transposition |
70 | 70 | ||
71 | -- | Matrix transpose. | 71 | -- | Matrix transpose. |
72 | trans :: Matrix t -> Matrix t | 72 | trans :: Matrix t -> Matrix t |
73 | trans MC {rows = r, cols = c, cdat = d, fdat = dt } = MF {rows = c, cols = r, fdat = d, cdat = dt } | 73 | trans MC {rows = r, cols = c, cdat = d } = MF {rows = c, cols = r, fdat = d } |
74 | trans MF {rows = r, cols = c, fdat = d, cdat = dt } = MC {rows = c, cols = r, cdat = d, fdat = dt } | 74 | trans MF {rows = r, cols = c, fdat = d } = MC {rows = c, cols = r, cdat = d } |
75 | 75 | ||
76 | dat MC { cdat = d } = d | 76 | cmat m@MC{} = m |
77 | dat MF { fdat = d } = d | 77 | cmat MF {rows = r, cols = c, fdat = d } = MC {rows = r, cols = c, cdat = transdata r d c} |
78 | |||
79 | fmat m@MF{} = m | ||
80 | fmat MC {rows = r, cols = c, cdat = d } = MF {rows = r, cols = c, fdat = transdata c d r} | ||
81 | |||
82 | matc m f = f (rows m) (cols m) (ptr (cdat m)) | ||
83 | matf m f = f (rows m) (cols m) (ptr (fdat m)) | ||
84 | |||
85 | |||
86 | {- | Creates a vector by concatenation of rows | ||
87 | |||
88 | @\> flatten ('ident' 3) | ||
89 | 9 |> [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ | ||
90 | -} | ||
91 | flatten :: Element t => Matrix t -> Vector t | ||
92 | flatten = cdat . cmat | ||
78 | 93 | ||
79 | mat d m f = f (rows m) (cols m) (ptr (d m)) | ||
80 | 94 | ||
81 | type Mt t s = Int -> Int -> Ptr t -> s | 95 | type Mt t s = Int -> Int -> Ptr t -> s |
82 | -- not yet admitted by my haddock version | 96 | -- not yet admitted by my haddock version |
@@ -85,7 +99,7 @@ type Mt t s = Int -> Int -> Ptr t -> s | |||
85 | 99 | ||
86 | -- | the inverse of 'Data.Packed.Matrix.fromLists' | 100 | -- | the inverse of 'Data.Packed.Matrix.fromLists' |
87 | toLists :: (Element t) => Matrix t -> [[t]] | 101 | toLists :: (Element t) => Matrix t -> [[t]] |
88 | toLists m = partit (cols m) . toList . cdat $ m | 102 | toLists m = partit (cols m) . toList . flatten $ m |
89 | 103 | ||
90 | -- | creates a Matrix from a list of vectors | 104 | -- | creates a Matrix from a list of vectors |
91 | fromRows :: Element t => [Vector t] -> Matrix t | 105 | fromRows :: Element t => [Vector t] -> Matrix t |
@@ -96,7 +110,7 @@ fromRows vs = case common dim vs of | |||
96 | -- | extracts the rows of a matrix as a list of vectors | 110 | -- | extracts the rows of a matrix as a list of vectors |
97 | toRows :: Element t => Matrix t -> [Vector t] | 111 | toRows :: Element t => Matrix t -> [Vector t] |
98 | toRows m = toRows' 0 where | 112 | toRows m = toRows' 0 where |
99 | v = cdat m | 113 | v = flatten $ m |
100 | r = rows m | 114 | r = rows m |
101 | c = cols m | 115 | c = cols m |
102 | toRows' k | k == r*c = [] | 116 | toRows' k | k == r*c = [] |
@@ -128,12 +142,12 @@ MF {rows = r, cols = c, fdat = v} @@> (i,j) | |||
128 | 142 | ||
129 | ------------------------------------------------------------------ | 143 | ------------------------------------------------------------------ |
130 | 144 | ||
131 | matrixFromVector RowMajor c v = MC { rows = r, cols = c, cdat = v, fdat = transdata c v r } | 145 | matrixFromVector RowMajor c v = MC { rows = r, cols = c, cdat = v } |
132 | where (d,m) = dim v `divMod` c | 146 | where (d,m) = dim v `divMod` c |
133 | r | m==0 = d | 147 | r | m==0 = d |
134 | | otherwise = error "matrixFromVector" | 148 | | otherwise = error "matrixFromVector" |
135 | 149 | ||
136 | matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, fdat = v, cdat = transdata r v c } | 150 | matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, fdat = v } |
137 | where (d,m) = dim v `divMod` c | 151 | where (d,m) = dim v `divMod` c |
138 | r | m==0 = d | 152 | r | m==0 = d |
139 | | otherwise = error "matrixFromVector" | 153 | | otherwise = error "matrixFromVector" |
@@ -167,8 +181,8 @@ liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vec | |||
167 | liftMatrix2 f m1 m2 | 181 | liftMatrix2 f m1 m2 |
168 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" | 182 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" |
169 | | otherwise = case m1 of | 183 | | otherwise = case m1 of |
170 | MC {} -> matrixFromVector RowMajor (cols m1) (f (cdat m1) (cdat m2)) | 184 | MC {} -> matrixFromVector RowMajor (cols m1) (f (cdat m1) (flatten m2)) |
171 | MF {} -> matrixFromVector ColumnMajor (cols m1) (f (fdat m1) (fdat m2)) | 185 | MF {} -> matrixFromVector ColumnMajor (cols m1) (f (fdat m1) ((fdat.fmat) m2)) |
172 | 186 | ||
173 | 187 | ||
174 | compat :: Matrix a -> Matrix b -> Bool | 188 | compat :: Matrix a -> Matrix b -> Bool |
@@ -222,8 +236,8 @@ transdataAux fun c1 d c2 = | |||
222 | then d | 236 | then d |
223 | else unsafePerformIO $ do | 237 | else unsafePerformIO $ do |
224 | v <- createVector (dim d) | 238 | v <- createVector (dim d) |
225 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | 239 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d,v] |
226 | --putStrLn "---> transdataAux" | 240 | -- putStrLn $ "---> transdataAux" ++ show (toList d) ++ show (toList v) |
227 | return v | 241 | return v |
228 | where r1 = dim d `div` c1 | 242 | where r1 = dim d `div` c1 |
229 | r2 = dim d `div` c2 | 243 | r2 = dim d `div` c2 |
@@ -239,11 +253,14 @@ foreign import ccall safe "auxi.h transC" | |||
239 | gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) | 253 | gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) |
240 | gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) | 254 | gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) |
241 | 255 | ||
256 | dtt MC { cdat = d } = d | ||
257 | dtt MF { fdat = d } = d | ||
258 | |||
242 | multiplyAux fun a b = unsafePerformIO $ do | 259 | multiplyAux fun a b = unsafePerformIO $ do |
243 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | 260 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ |
244 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | 261 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) |
245 | r <- createMatrix RowMajor (rows a) (cols b) | 262 | r <- createMatrix RowMajor (rows a) (cols b) |
246 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] | 263 | fun // gmatC a // gmatC b // matc r // check "multiplyAux" [dtt a, dtt b, cdat r] |
247 | return r | 264 | return r |
248 | 265 | ||
249 | multiplyR = multiplyAux cmultiplyR | 266 | multiplyR = multiplyAux cmultiplyR |
@@ -273,18 +290,19 @@ multiply = multiplyD | |||
273 | 290 | ||
274 | -- | extraction of a submatrix from a real matrix | 291 | -- | extraction of a submatrix from a real matrix |
275 | subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double | 292 | subMatrixR :: (Int,Int) -> (Int,Int) -> Matrix Double -> Matrix Double |
276 | subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do | 293 | subMatrixR (r0,c0) (rt,ct) x' = unsafePerformIO $ do |
277 | r <- createMatrix RowMajor rt ct | 294 | r <- createMatrix RowMajor rt ct |
278 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r] | 295 | let x = cmat x' |
296 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // matc x // matc r // check "subMatrixR" [cdat x] | ||
279 | return r | 297 | return r |
280 | foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM | 298 | foreign import ccall "auxi.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM |
281 | 299 | ||
282 | -- | extraction of a submatrix from a complex matrix | 300 | -- | extraction of a submatrix from a complex matrix |
283 | subMatrixC :: (Int,Int) -> (Int,Int) -> Matrix (Complex Double) -> Matrix (Complex Double) | 301 | subMatrixC :: (Int,Int) -> (Int,Int) -> Matrix (Complex Double) -> Matrix (Complex Double) |
284 | subMatrixC (r0,c0) (rt,ct) x = | 302 | subMatrixC (r0,c0) (rt,ct) x = |
285 | reshape ct . asComplex . cdat . | 303 | reshape ct . asComplex . flatten . |
286 | subMatrixR (r0,2*c0) (rt,2*ct) . | 304 | subMatrixR (r0,2*c0) (rt,2*ct) . |
287 | reshape (2*cols x) . asReal . cdat $ x | 305 | reshape (2*cols x) . asReal . flatten $ x |
288 | 306 | ||
289 | -- | Extracts a submatrix from a matrix. | 307 | -- | Extracts a submatrix from a matrix. |
290 | subMatrix :: Element a | 308 | subMatrix :: Element a |
@@ -299,7 +317,7 @@ subMatrix = subMatrixD | |||
299 | 317 | ||
300 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do | 318 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do |
301 | m <- createMatrix RowMajor n n | 319 | m <- createMatrix RowMajor n n |
302 | fun // vec v // mat cdat m // check msg [dat m] | 320 | fun // vec v // matc m // check msg [cdat m] |
303 | return m -- {tdat = dat m} | 321 | return m -- {tdat = dat m} |
304 | 322 | ||
305 | -- | diagonal matrix from a real vector | 323 | -- | diagonal matrix from a real vector |
@@ -347,11 +365,11 @@ constant = constantD | |||
347 | 365 | ||
348 | -- | obtains the complex conjugate of a complex vector | 366 | -- | obtains the complex conjugate of a complex vector |
349 | conj :: Vector (Complex Double) -> Vector (Complex Double) | 367 | conj :: Vector (Complex Double) -> Vector (Complex Double) |
350 | conj v = asComplex $ cdat $ reshape 2 (asReal v) `multiply` diag (fromList [1,-1]) | 368 | conj v = asComplex $ flatten $ reshape 2 (asReal v) `multiply` diag (fromList [1,-1]) |
351 | 369 | ||
352 | -- | creates a complex vector from vectors with real and imaginary parts | 370 | -- | creates a complex vector from vectors with real and imaginary parts |
353 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) | 371 | toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) |
354 | toComplex (r,i) = asComplex $ cdat $ fromColumns [r,i] | 372 | toComplex (r,i) = asComplex $ flatten $ fromColumns [r,i] |
355 | 373 | ||
356 | -- | the inverse of 'toComplex' | 374 | -- | the inverse of 'toComplex' |
357 | fromComplex :: Vector (Complex Double) -> (Vector Double, Vector Double) | 375 | fromComplex :: Vector (Complex Double) -> (Vector Double, Vector Double) |
@@ -367,7 +385,7 @@ fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double) | |||
367 | fromFile filename (r,c) = do | 385 | fromFile filename (r,c) = do |
368 | charname <- newCString filename | 386 | charname <- newCString filename |
369 | res <- createMatrix RowMajor r c | 387 | res <- createMatrix RowMajor r c |
370 | c_gslReadMatrix charname // mat dat res // check "gslReadMatrix" [] | 388 | c_gslReadMatrix charname // matc res // check "gslReadMatrix" [] |
371 | --free charname -- TO DO: free the auxiliary CString | 389 | --free charname -- TO DO: free the auxiliary CString |
372 | return res | 390 | return res |
373 | foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM | 391 | foreign import ccall "auxi.h matrix_fscanf" c_gslReadMatrix:: Ptr CChar -> TM |
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 082e09d..dc86484 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -27,11 +27,12 @@ import Foreign.C.Types | |||
27 | import Data.Monoid | 27 | import Data.Monoid |
28 | 28 | ||
29 | -- | A one-dimensional array of objects stored in a contiguous memory block. | 29 | -- | A one-dimensional array of objects stored in a contiguous memory block. |
30 | data Vector t = V { dim :: Int -- ^ number of elements | 30 | data Vector t = V { dim :: Int -- ^ number of elements |
31 | , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block | 31 | , fptr :: ForeignPtr t -- ^ foreign pointer to the memory block |
32 | , ptr :: Ptr t -- ^ ordinary pointer to the actual starting address (usually the same) | ||
33 | } | 32 | } |
34 | 33 | ||
34 | ptr (V _ fptr) = unsafeForeignPtrToPtr fptr | ||
35 | |||
35 | -- | check the error code and touch foreign ptr of vector arguments (if any) | 36 | -- | check the error code and touch foreign ptr of vector arguments (if any) |
36 | check :: String -> [Vector a] -> IO Int -> IO () | 37 | check :: String -> [Vector a] -> IO Int -> IO () |
37 | check msg ls f = do | 38 | check msg ls f = do |
@@ -63,9 +64,7 @@ createVector :: Storable a => Int -> IO (Vector a) | |||
63 | createVector n = do | 64 | createVector n = do |
64 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | 65 | when (n <= 0) $ error ("trying to createVector of dim "++show n) |
65 | fp <- mallocForeignPtrArray n | 66 | fp <- mallocForeignPtrArray n |
66 | let p = unsafeForeignPtrToPtr fp | 67 | return $ V n fp |
67 | --putStrLn ("\n---------> V"++show n) | ||
68 | return $ V n fp p | ||
69 | 68 | ||
70 | {- | creates a Vector from a list: | 69 | {- | creates a Vector from a list: |
71 | 70 | ||
@@ -80,7 +79,7 @@ fromList l = unsafePerformIO $ do | |||
80 | f // vec v // check "fromList" [] | 79 | f // vec v // check "fromList" [] |
81 | return v | 80 | return v |
82 | 81 | ||
83 | safeRead v f = unsafePerformIO $ withForeignPtr (fptr v) $ const $ f (ptr v) | 82 | safeRead v = unsafePerformIO . withForeignPtr (fptr v) |
84 | 83 | ||
85 | {- | extracts the Vector elements to a list | 84 | {- | extracts the Vector elements to a list |
86 | 85 | ||
@@ -115,19 +114,14 @@ subVector :: Storable t => Int -- ^ index of the starting element | |||
115 | -> Int -- ^ number of elements to extract | 114 | -> Int -- ^ number of elements to extract |
116 | -> Vector t -- ^ source | 115 | -> Vector t -- ^ source |
117 | -> Vector t -- ^ result | 116 | -> Vector t -- ^ result |
118 | subVector k l (v@V {dim=n, ptr=p, fptr=fp}) | 117 | subVector k l (v@V {dim=n}) |
119 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | 118 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" |
120 | | otherwise = unsafePerformIO $ do | 119 | | otherwise = unsafePerformIO $ do |
121 | r <- createVector l | 120 | r <- createVector l |
122 | let f = copyArray (ptr r) (advancePtr p k) l >> return 0 | 121 | let f = copyArray (ptr r) (advancePtr (ptr v) k) l >> return 0 |
123 | f // check "subVector" [v] | 122 | f // check "subVector" [v,r] |
124 | return r | 123 | return r |
125 | 124 | ||
126 | subVector' k l (v@V {dim=n, ptr=p, fptr=fp}) | ||
127 | | k<0 || k >= n || k+l > n || l < 0 = error "subVector out of range" | ||
128 | | otherwise = v {dim=l, ptr=advancePtr p k} | ||
129 | |||
130 | |||
131 | {- | Reads a vector position: | 125 | {- | Reads a vector position: |
132 | 126 | ||
133 | @> fromList [0..9] \@\> 7 | 127 | @> fromList [0..9] \@\> 7 |
@@ -149,23 +143,23 @@ join :: Storable t => [Vector t] -> Vector t | |||
149 | join [] = error "joining zero vectors" | 143 | join [] = error "joining zero vectors" |
150 | join as = unsafePerformIO $ do | 144 | join as = unsafePerformIO $ do |
151 | let tot = sum (map dim as) | 145 | let tot = sum (map dim as) |
152 | r@V {fptr = p, ptr = p'} <- createVector tot | 146 | r@V {fptr = p} <- createVector tot |
153 | withForeignPtr p $ \_ -> | 147 | withForeignPtr p $ \_ -> |
154 | joiner as tot p' | 148 | joiner as tot (ptr r) |
155 | return r | 149 | return r |
156 | where joiner [] _ _ = return () | 150 | where joiner [] _ _ = return () |
157 | joiner (V {dim = n, fptr = b, ptr = q} : cs) _ p = do | 151 | joiner (r@V {dim = n, fptr = b} : cs) _ p = do |
158 | withForeignPtr b $ \_ -> copyArray p q n | 152 | withForeignPtr b $ \_ -> copyArray p (ptr r) n |
159 | joiner cs 0 (advancePtr p n) | 153 | joiner cs 0 (advancePtr p n) |
160 | 154 | ||
161 | 155 | ||
162 | -- | transforms a complex vector into a real vector with alternating real and imaginary parts | 156 | -- | transforms a complex vector into a real vector with alternating real and imaginary parts |
163 | asReal :: Vector (Complex Double) -> Vector Double | 157 | asReal :: Vector (Complex Double) -> Vector Double |
164 | asReal v = V { dim = 2*dim v, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } | 158 | asReal v = V { dim = 2*dim v, fptr = castForeignPtr (fptr v) } |
165 | 159 | ||
166 | -- | transforms a real vector into a complex vector with alternating real and imaginary parts | 160 | -- | transforms a real vector into a complex vector with alternating real and imaginary parts |
167 | asComplex :: Vector Double -> Vector (Complex Double) | 161 | asComplex :: Vector Double -> Vector (Complex Double) |
168 | asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } | 162 | asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v) } |
169 | 163 | ||
170 | ---------------------------------------------------------------- | 164 | ---------------------------------------------------------------- |
171 | 165 | ||
diff --git a/lib/Data/Packed/Internal/auxi.c b/lib/Data/Packed/Internal/auxi.c index b53d9b7..7f83bcf 100644 --- a/lib/Data/Packed/Internal/auxi.c +++ b/lib/Data/Packed/Internal/auxi.c | |||
@@ -125,11 +125,26 @@ int multiplyR(int ta, KRMAT(a), int tb, KRMAT(b),RMAT(r)) { | |||
125 | KDMVIEW(a); | 125 | KDMVIEW(a); |
126 | KDMVIEW(b); | 126 | KDMVIEW(b); |
127 | DMVIEW(r); | 127 | DMVIEW(r); |
128 | int k; | ||
129 | for(k=0;k<rr*rc;k++) rp[k]=0; | ||
130 | int debug = 0; | ||
131 | if(debug) { | ||
132 | printf("---------------------------\n"); | ||
133 | printf("%p: ",ap); for(k=0;k<ar*ac;k++) printf("%f ",ap[k]); printf("\n"); | ||
134 | printf("%p: ",bp); for(k=0;k<br*bc;k++) printf("%f ",bp[k]); printf("\n"); | ||
135 | printf("%p: ",rp); for(k=0;k<rr*rc;k++) printf("%f ",rp[k]); printf("\n"); | ||
136 | } | ||
128 | int res = gsl_blas_dgemm( | 137 | int res = gsl_blas_dgemm( |
129 | ta?CblasTrans:CblasNoTrans, | 138 | ta?CblasTrans:CblasNoTrans, |
130 | tb?CblasTrans:CblasNoTrans, | 139 | tb?CblasTrans:CblasNoTrans, |
131 | 1.0, M(a), M(b), | 140 | 1.0, M(a), M(b), |
132 | 0.0, M(r)); | 141 | 0.0, M(r)); |
142 | if(debug) { | ||
143 | printf("--------------\n"); | ||
144 | printf("%p: ",ap); for(k=0;k<ar*ac;k++) printf("%f ",ap[k]); printf("\n"); | ||
145 | printf("%p: ",bp); for(k=0;k<br*bc;k++) printf("%f ",bp[k]); printf("\n"); | ||
146 | printf("%p: ",rp); for(k=0;k<rr*rc;k++) printf("%f ",rp[k]); printf("\n"); | ||
147 | } | ||
133 | CHECK(res,res); | 148 | CHECK(res,res); |
134 | OK | 149 | OK |
135 | } | 150 | } |
@@ -140,14 +155,36 @@ int multiplyC(int ta, KCMAT(a), int tb, KCMAT(b),CMAT(r)) { | |||
140 | KCMVIEW(a); | 155 | KCMVIEW(a); |
141 | KCMVIEW(b); | 156 | KCMVIEW(b); |
142 | CMVIEW(r); | 157 | CMVIEW(r); |
158 | int k; | ||
143 | gsl_complex alpha, beta; | 159 | gsl_complex alpha, beta; |
144 | GSL_SET_COMPLEX(&alpha,1.,0.); | 160 | GSL_SET_COMPLEX(&alpha,1.,0.); |
145 | GSL_SET_COMPLEX(&beta,0.,0.); | 161 | GSL_SET_COMPLEX(&beta,0.,0.); |
162 | //double *TEMP = (double*)malloc(rr*rc*2*sizeof(double)); | ||
163 | //gsl_matrix_complex_view T = gsl_matrix_complex_view_array(TEMP,rr,rc); | ||
164 | for(k=0;k<rr*rc;k++) rp[k]=beta; | ||
165 | //for(k=0;k<2*rr*rc;k++) TEMP[k]=0; | ||
166 | int debug = 0; | ||
167 | if(debug) { | ||
168 | printf("---------------------------\n"); | ||
169 | printf("%p: ",ap); for(k=0;k<2*ar*ac;k++) printf("%f ",((double*)ap)[k]); printf("\n"); | ||
170 | printf("%p: ",bp); for(k=0;k<2*br*bc;k++) printf("%f ",((double*)bp)[k]); printf("\n"); | ||
171 | printf("%p: ",rp); for(k=0;k<2*rr*rc;k++) printf("%f ",((double*)rp)[k]); printf("\n"); | ||
172 | //printf("%p: ",T); for(k=0;k<2*rr*rc;k++) printf("%f ",TEMP[k]); printf("\n"); | ||
173 | } | ||
146 | int res = gsl_blas_zgemm( | 174 | int res = gsl_blas_zgemm( |
147 | ta?CblasTrans:CblasNoTrans, | 175 | ta?CblasTrans:CblasNoTrans, |
148 | tb?CblasTrans:CblasNoTrans, | 176 | tb?CblasTrans:CblasNoTrans, |
149 | alpha, M(a), M(b), | 177 | alpha, M(a), M(b), |
150 | beta, M(r)); | 178 | beta, M(r)); |
179 | //&T.matrix); | ||
180 | //memcpy(rp,TEMP,2*rr*rc*sizeof(double)); | ||
181 | if(debug) { | ||
182 | printf("--------------\n"); | ||
183 | printf("%p: ",ap); for(k=0;k<2*ar*ac;k++) printf("%f ",((double*)ap)[k]); printf("\n"); | ||
184 | printf("%p: ",bp); for(k=0;k<2*br*bc;k++) printf("%f ",((double*)bp)[k]); printf("\n"); | ||
185 | printf("%p: ",rp); for(k=0;k<2*rr*rc;k++) printf("%f ",((double*)rp)[k]); printf("\n"); | ||
186 | //printf("%p: ",T); for(k=0;k<2*rr*rc;k++) printf("%f ",TEMP[k]); printf("\n"); | ||
187 | } | ||
151 | CHECK(res,res); | 188 | CHECK(res,res); |
152 | OK | 189 | OK |
153 | } | 190 | } |
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index e96500f..7b6bf29 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -44,7 +44,7 @@ import Data.Array | |||
44 | joinVert :: Element t => [Matrix t] -> Matrix t | 44 | joinVert :: Element t => [Matrix t] -> Matrix t |
45 | joinVert ms = case common cols ms of | 45 | joinVert ms = case common cols ms of |
46 | Nothing -> error "joinVert on matrices with different number of columns" | 46 | Nothing -> error "joinVert on matrices with different number of columns" |
47 | Just c -> reshape c $ join (map cdat ms) | 47 | Just c -> reshape c $ join (map flatten ms) |
48 | 48 | ||
49 | -- | creates a matrix from a horizontal list of matrices | 49 | -- | creates a matrix from a horizontal list of matrices |
50 | joinHoriz :: Element t => [Matrix t] -> Matrix t | 50 | joinHoriz :: Element t => [Matrix t] -> Matrix t |
@@ -94,7 +94,7 @@ diagRect s r c | |||
94 | 94 | ||
95 | -- | extracts the diagonal from a rectangular matrix | 95 | -- | extracts the diagonal from a rectangular matrix |
96 | takeDiag :: (Element t) => Matrix t -> Vector t | 96 | takeDiag :: (Element t) => Matrix t -> Vector t |
97 | takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | 97 | takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] |
98 | 98 | ||
99 | -- | creates the identity matrix of given dimension | 99 | -- | creates the identity matrix of given dimension |
100 | ident :: Element a => Int -> Matrix a | 100 | ident :: Element a => Int -> Matrix a |
@@ -136,14 +136,6 @@ dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat | |||
136 | 136 | ||
137 | ---------------------------------------------------------------- | 137 | ---------------------------------------------------------------- |
138 | 138 | ||
139 | {- | Creates a vector by concatenation of rows | ||
140 | |||
141 | @\> flatten ('ident' 3) | ||
142 | 9 |> [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ | ||
143 | -} | ||
144 | flatten :: Element t => Matrix t -> Vector t | ||
145 | flatten = cdat | ||
146 | |||
147 | {- | Creates a 'Matrix' from a list of lists (considered as rows). | 139 | {- | Creates a 'Matrix' from a list of lists (considered as rows). |
148 | 140 | ||
149 | @\> fromLists [[1,2],[3,4],[5,6]] | 141 | @\> fromLists [[1,2],[3,4],[5,6]] |