diff options
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 398 |
1 files changed, 163 insertions, 235 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 6ba2d06..ba32a67 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-overlapping-instances #-} | 1 | {-# OPTIONS_GHC -fglasgow-exts #-} |
2 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
3 | -- | | 3 | -- | |
4 | -- Module : Data.Packed.Internal.Matrix | 4 | -- Module : Data.Packed.Internal.Matrix |
@@ -22,65 +22,10 @@ import Foreign hiding (xor) | |||
22 | import Complex | 22 | import Complex |
23 | import Control.Monad(when) | 23 | import Control.Monad(when) |
24 | import Data.List(transpose,intersperse) | 24 | import Data.List(transpose,intersperse) |
25 | --import Data.Typeable | ||
26 | import Data.Maybe(fromJust) | 25 | import Data.Maybe(fromJust) |
27 | 26 | ||
28 | ---------------------------------------------------------------- | ||
29 | |||
30 | -- the condition Storable a => Field a means that we can only put | ||
31 | -- in Field types that are in Storable, and therefore Storable a | ||
32 | -- is not required in signatures if we have a Field a. | ||
33 | |||
34 | class Storable a => Field a where | ||
35 | constant :: a -> Int -> Vector a | ||
36 | transdata :: Int -> Vector a -> Int -> Vector a | ||
37 | multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
38 | subMatrix :: (Int,Int) -- ^ (r0,c0) starting position | ||
39 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
40 | -> Matrix a -> Matrix a | ||
41 | diag :: Vector a -> Matrix a | ||
42 | |||
43 | instance Field Double where | ||
44 | constant = constantR | ||
45 | transdata = transdataR | ||
46 | multiplyD = multiplyR | ||
47 | subMatrix = subMatrixR | ||
48 | diag = diagR | ||
49 | |||
50 | instance Field (Complex Double) where | ||
51 | constant = constantC | ||
52 | transdata = transdataC | ||
53 | multiplyD = multiplyC | ||
54 | subMatrix = subMatrixC | ||
55 | diag = diagC | ||
56 | |||
57 | ----------------------------------------------------------------- | 27 | ----------------------------------------------------------------- |
58 | 28 | ||
59 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
60 | transdataR = transdataAux ctransR | ||
61 | |||
62 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
63 | transdataC = transdataAux ctransC | ||
64 | |||
65 | transdataAux fun c1 d c2 = | ||
66 | if noneed | ||
67 | then d | ||
68 | else unsafePerformIO $ do | ||
69 | v <- createVector (dim d) | ||
70 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
71 | --putStrLn "---> transdataAux" | ||
72 | return v | ||
73 | where r1 = dim d `div` c1 | ||
74 | r2 = dim d `div` c2 | ||
75 | noneed = r1 == 1 || c1 == 1 | ||
76 | |||
77 | foreign import ccall safe "aux.h transR" | ||
78 | ctransR :: TMM -- Double ::> Double ::> IO Int | ||
79 | foreign import ccall safe "aux.h transC" | ||
80 | ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int | ||
81 | |||
82 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
83 | |||
84 | {- Design considerations for the Matrix Type | 29 | {- Design considerations for the Matrix Type |
85 | ----------------------------------------- | 30 | ----------------------------------------- |
86 | 31 | ||
@@ -111,103 +56,79 @@ transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | |||
111 | 56 | ||
112 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | 57 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
113 | 58 | ||
114 | {- | 59 | data Matrix t = MC { rows :: Int, cols :: Int, cdat :: Vector t, fdat :: Vector t } |
115 | 60 | | MF { rows :: Int, cols :: Int, fdat :: Vector t, cdat :: Vector t } | |
116 | |||
117 | |||
118 | data Matrix t = M { rows :: Int | ||
119 | , cols :: Int | ||
120 | , dat :: Vector t | ||
121 | , tdat :: Vector t | ||
122 | , isTrans :: Bool | ||
123 | , order :: MatrixOrder | ||
124 | } -- deriving Typeable | ||
125 | -} | ||
126 | |||
127 | data Matrix t = MC { rows :: Int, cols :: Int, dat :: Vector t } -- row major order | ||
128 | | MF { rows :: Int, cols :: Int, dat :: Vector t } -- column major order | ||
129 | 61 | ||
130 | -- transposition just changes the data order | 62 | -- transposition just changes the data order |
131 | trans :: Matrix t -> Matrix t | 63 | trans :: Matrix t -> Matrix t |
132 | trans MC {rows = r, cols = c, dat = d} = MF {rows = c, cols = r, dat = d} | 64 | trans MC {rows = r, cols = c, cdat = d, fdat = dt } = MF {rows = c, cols = r, fdat = d, cdat = dt } |
133 | trans MF {rows = r, cols = c, dat = d} = MC {rows = c, cols = r, dat = d} | 65 | trans MF {rows = r, cols = c, fdat = d, cdat = dt } = MC {rows = c, cols = r, cdat = d, fdat = dt } |
134 | |||
135 | viewC m@MC{} = m | ||
136 | viewC MF {rows = r, cols = c, dat = d} = MC {rows = r, cols = c, dat = transdata r d c} | ||
137 | 66 | ||
138 | viewF m@MF{} = m | 67 | dat MC { cdat = d } = d |
139 | viewF MC {rows = r, cols = c, dat = d} = MF {rows = r, cols = c, dat = transdata c d r} | 68 | dat MF { fdat = d } = d |
140 | 69 | ||
141 | --fortran m = order m == ColumnMajor | 70 | mat d m f = f (rows m) (cols m) (ptr (d m)) |
142 | |||
143 | cdat m = dat (viewC m) | ||
144 | fdat m = dat (viewF m) | ||
145 | 71 | ||
146 | type Mt t s = Int -> Int -> Ptr t -> s | 72 | type Mt t s = Int -> Int -> Ptr t -> s |
147 | -- not yet admitted by my haddock version | 73 | -- not yet admitted by my haddock version |
148 | -- infixr 6 ::> | 74 | -- infixr 6 ::> |
149 | -- type t ::> s = Mt t s | 75 | -- type t ::> s = Mt t s |
150 | 76 | ||
151 | mat d m f = f (rows m) (cols m) (ptr (d m)) | 77 | -- | the inverse of 'fromLists' |
152 | --mat m f = f (rows m) (cols m) (ptr (dat m)) | 78 | toLists :: (Field t) => Matrix t -> [[t]] |
153 | --matC m f = f (rows m) (cols m) (ptr (cdat m)) | 79 | toLists m = partit (cols m) . toList . cdat $ m |
154 | 80 | ||
81 | -- | creates a Matrix from a list of vectors | ||
82 | fromRows :: Field t => [Vector t] -> Matrix t | ||
83 | fromRows vs = case common dim vs of | ||
84 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" | ||
85 | Just c -> reshape c (join vs) | ||
155 | 86 | ||
156 | --toLists :: (Storable t) => Matrix t -> [[t]] | 87 | -- | extracts the rows of a matrix as a list of vectors |
157 | toLists m = partit (cols m) . toList . cdat $ m | 88 | toRows :: Field t => Matrix t -> [Vector t] |
89 | toRows m = toRows' 0 where | ||
90 | v = cdat m | ||
91 | r = rows m | ||
92 | c = cols m | ||
93 | toRows' k | k == r*c = [] | ||
94 | | otherwise = subVector k c v : toRows' (k+c) | ||
158 | 95 | ||
159 | instance (Show a, Field a) => (Show (Matrix a)) where | 96 | -- | Creates a matrix from a list of vectors, as columns |
160 | show m = (sizes++) . dsp . map (map show) . toLists $ m | 97 | fromColumns :: Field t => [Vector t] -> Matrix t |
161 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | 98 | fromColumns m = trans . fromRows $ m |
162 | 99 | ||
163 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | 100 | -- | Creates a list of vectors from the columns of a matrix |
164 | where | 101 | toColumns :: Field t => Matrix t -> [Vector t] |
165 | mt = transpose as | 102 | toColumns m = toRows . trans $ m |
166 | longs = map (maximum . map length) mt | ||
167 | mtp = zipWith (\a b -> map (pad a) b) longs mt | ||
168 | pad n str = replicate (n - length str) ' ' ++ str | ||
169 | unwords' = concat . intersperse ", " | ||
170 | 103 | ||
171 | {- | ||
172 | matrixFromVector RowMajor c v = | ||
173 | M { rows = r | ||
174 | , cols = c | ||
175 | , dat = v | ||
176 | , tdat = transdata c v r | ||
177 | , order = RowMajor | ||
178 | , isTrans = False | ||
179 | } where (d,m) = dim v `divMod` c | ||
180 | r | m==0 = d | ||
181 | | otherwise = error "matrixFromVector" | ||
182 | |||
183 | matrixFromVector ColumnMajor c v = | ||
184 | M { rows = r | ||
185 | , cols = c | ||
186 | , dat = v | ||
187 | , tdat = transdata r v c | ||
188 | , order = ColumnMajor | ||
189 | , isTrans = False | ||
190 | } where (d,m) = dim v `divMod` c | ||
191 | r | m==0 = d | ||
192 | | otherwise = error "matrixFromVector" | ||
193 | 104 | ||
194 | -} | 105 | -- | Reads a matrix position. |
106 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t | ||
107 | infixl 9 @@> | ||
108 | --m@M {rows = r, cols = c} @@> (i,j) | ||
109 | -- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
110 | -- | otherwise = cdat m `at` (i*c+j) | ||
195 | 111 | ||
196 | matrixFromVector RowMajor c v = MC { rows = r, cols = c, dat = v} | 112 | MC {rows = r, cols = c, cdat = v} @@> (i,j) |
113 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
114 | | otherwise = v `at` (i*c+j) | ||
115 | |||
116 | MF {rows = r, cols = c, fdat = v} @@> (i,j) | ||
117 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
118 | | otherwise = v `at` (j*r+i) | ||
119 | |||
120 | ------------------------------------------------------------------ | ||
121 | |||
122 | matrixFromVector RowMajor c v = MC { rows = r, cols = c, cdat = v, fdat = transdata c v r } | ||
197 | where (d,m) = dim v `divMod` c | 123 | where (d,m) = dim v `divMod` c |
198 | r | m==0 = d | 124 | r | m==0 = d |
199 | | otherwise = error "matrixFromVector" | 125 | | otherwise = error "matrixFromVector" |
200 | 126 | ||
201 | matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, dat = v} | 127 | matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, fdat = v, cdat = transdata r v c } |
202 | where (d,m) = dim v `divMod` c | 128 | where (d,m) = dim v `divMod` c |
203 | r | m==0 = d | 129 | r | m==0 = d |
204 | | otherwise = error "matrixFromVector" | 130 | | otherwise = error "matrixFromVector" |
205 | 131 | ||
206 | |||
207 | |||
208 | |||
209 | |||
210 | |||
211 | createMatrix order r c = do | 132 | createMatrix order r c = do |
212 | p <- createVector (r*c) | 133 | p <- createVector (r*c) |
213 | return (matrixFromVector order c p) | 134 | return (matrixFromVector order c p) |
@@ -226,45 +147,94 @@ reshape c v = matrixFromVector RowMajor c v | |||
226 | 147 | ||
227 | singleton x = reshape 1 (fromList [x]) | 148 | singleton x = reshape 1 (fromList [x]) |
228 | 149 | ||
229 | --liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 150 | liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
230 | liftMatrix f m = reshape (cols m) (f (cdat m)) | 151 | liftMatrix f MC { cols = c, cdat = d } = matrixFromVector RowMajor c (f d) |
152 | liftMatrix f MF { cols = c, fdat = d } = matrixFromVector ColumnMajor c (f d) | ||
153 | |||
154 | |||
155 | liftMatrix2 :: (Field t, Field a, Field b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | ||
156 | liftMatrix2 f m1 m2 | ||
157 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" | ||
158 | | otherwise = case m1 of | ||
159 | MC {} -> matrixFromVector RowMajor (cols m1) (f (cdat m1) (cdat m2)) | ||
160 | MF {} -> matrixFromVector ColumnMajor (cols m1) (f (fdat m1) (fdat m2)) | ||
231 | 161 | ||
232 | --liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | ||
233 | liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) | ||
234 | | otherwise = error "nonconformant matrices in liftMatrix2" | ||
235 | ------------------------------------------------------------------ | ||
236 | 162 | ||
237 | compat :: Matrix a -> Matrix b -> Bool | 163 | compat :: Matrix a -> Matrix b -> Bool |
238 | compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 | 164 | compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 |
239 | 165 | ||
240 | dotL a b = sum (zipWith (*) a b) | 166 | ---------------------------------------------------------------- |
241 | 167 | ||
242 | multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] | 168 | -- | element types for which optimized matrix computations are provided |
243 | | otherwise = error "inconsistent dimensions in contraction " | 169 | class Storable a => Field a where |
244 | where ok = case common length a of | 170 | -- | @constant val n@ creates a vector with @n@ elements, all equal to @val@. |
245 | Nothing -> False | 171 | constant :: a -> Int -> Vector a |
246 | Just c -> c == length b | 172 | transdata :: Int -> Vector a -> Int -> Vector a |
173 | multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
174 | -- | extracts a submatrix froma a matrix | ||
175 | subMatrix :: (Int,Int) -- ^ (r0,c0) starting position | ||
176 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
177 | -> Matrix a -> Matrix a | ||
178 | -- | creates a square matrix with the given diagonal | ||
179 | diag :: Vector a -> Matrix a | ||
247 | 180 | ||
248 | transL m = matrixFromVector RowMajor (rows m) $ transdata (cols m) (cdat m) (rows m) | 181 | instance Field Double where |
182 | constant = constantR | ||
183 | transdata = transdataR | ||
184 | multiplyD = multiplyR | ||
185 | subMatrix = subMatrixR | ||
186 | diag = diagR | ||
249 | 187 | ||
250 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | 188 | instance Field (Complex Double) where |
189 | constant = constantC | ||
190 | transdata = transdataC | ||
191 | multiplyD = multiplyC | ||
192 | subMatrix = subMatrixC | ||
193 | diag = diagC | ||
251 | 194 | ||
252 | ------------------------------------------------------------------ | 195 | ------------------------------------------------------------------ |
253 | 196 | ||
254 | {- | 197 | instance (Show a, Field a) => (Show (Matrix a)) where |
255 | gmatC m f | fortran m = | 198 | show m = (sizes++) . dsp . map (map show) . toLists $ m |
256 | if (isTrans m) | 199 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" |
257 | then f 0 (rows m) (cols m) (ptr (dat m)) | 200 | |
258 | else f 1 (cols m) (rows m) (ptr (dat m)) | 201 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp |
259 | | otherwise = | 202 | where |
260 | if isTrans m | 203 | mt = transpose as |
261 | then f 1 (cols m) (rows m) (ptr (dat m)) | 204 | longs = map (maximum . map length) mt |
262 | else f 0 (rows m) (cols m) (ptr (dat m)) | 205 | mtp = zipWith (\a b -> map (pad a) b) longs mt |
263 | -} | 206 | pad n str = replicate (n - length str) ' ' ++ str |
207 | unwords' = concat . intersperse ", " | ||
264 | 208 | ||
265 | gmatC MF {rows = r, cols = c, dat = d} f = f 1 c r (ptr d) | 209 | ------------------------------------------------------------------ |
266 | gmatC MC {rows = r, cols = c, dat = d} f = f 0 r c (ptr d) | 210 | |
267 | {-# INLINE gmatC #-} | 211 | transdataR :: Int -> Vector Double -> Int -> Vector Double |
212 | transdataR = transdataAux ctransR | ||
213 | |||
214 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
215 | transdataC = transdataAux ctransC | ||
216 | |||
217 | transdataAux fun c1 d c2 = | ||
218 | if noneed | ||
219 | then d | ||
220 | else unsafePerformIO $ do | ||
221 | v <- createVector (dim d) | ||
222 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
223 | --putStrLn "---> transdataAux" | ||
224 | return v | ||
225 | where r1 = dim d `div` c1 | ||
226 | r2 = dim d `div` c2 | ||
227 | noneed = r1 == 1 || c1 == 1 | ||
228 | |||
229 | foreign import ccall safe "aux.h transR" | ||
230 | ctransR :: TMM -- Double ::> Double ::> IO Int | ||
231 | foreign import ccall safe "aux.h transC" | ||
232 | ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int | ||
233 | |||
234 | ------------------------------------------------------------------ | ||
235 | |||
236 | gmatC MF {rows = r, cols = c, fdat = d} f = f 1 c r (ptr d) | ||
237 | gmatC MC {rows = r, cols = c, cdat = d} f = f 0 r c (ptr d) | ||
268 | 238 | ||
269 | multiplyAux fun order a b = unsafePerformIO $ do | 239 | multiplyAux fun order a b = unsafePerformIO $ do |
270 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | 240 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ |
@@ -272,14 +242,15 @@ multiplyAux fun order a b = unsafePerformIO $ do | |||
272 | r <- createMatrix order (rows a) (cols b) | 242 | r <- createMatrix order (rows a) (cols b) |
273 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] | 243 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] |
274 | return r | 244 | return r |
275 | {-# INLINE multiplyAux #-} | ||
276 | 245 | ||
246 | multiplyR = multiplyAux cmultiplyR | ||
277 | foreign import ccall safe "aux.h multiplyR" | 247 | foreign import ccall safe "aux.h multiplyR" |
278 | cmultiplyR :: Int -> Int -> Int -> Ptr Double | 248 | cmultiplyR :: Int -> Int -> Int -> Ptr Double |
279 | -> Int -> Int -> Int -> Ptr Double | 249 | -> Int -> Int -> Int -> Ptr Double |
280 | -> Int -> Int -> Ptr Double | 250 | -> Int -> Int -> Ptr Double |
281 | -> IO Int | 251 | -> IO Int |
282 | 252 | ||
253 | multiplyC = multiplyAux cmultiplyC | ||
283 | foreign import ccall safe "aux.h multiplyC" | 254 | foreign import ccall safe "aux.h multiplyC" |
284 | cmultiplyC :: Int -> Int -> Int -> Ptr (Complex Double) | 255 | cmultiplyC :: Int -> Int -> Int -> Ptr (Complex Double) |
285 | -> Int -> Int -> Int -> Ptr (Complex Double) | 256 | -> Int -> Int -> Int -> Ptr (Complex Double) |
@@ -288,14 +259,9 @@ foreign import ccall safe "aux.h multiplyC" | |||
288 | 259 | ||
289 | multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | 260 | multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a |
290 | multiply RowMajor a b = multiplyD RowMajor a b | 261 | multiply RowMajor a b = multiplyD RowMajor a b |
291 | multiply ColumnMajor a b = MF {rows = c, cols = r, dat = d} | 262 | multiply ColumnMajor a b = MF {rows = c, cols = r, fdat = d, cdat = dt } |
292 | where MC {rows = r, cols = c, dat = d } = multiplyD RowMajor (trans b) (trans a) | 263 | where MC {rows = r, cols = c, cdat = d, fdat = dt } = multiplyD RowMajor (trans b) (trans a) |
293 | 264 | -- FIXME using MatrixFromVector | |
294 | |||
295 | multiplyR = multiplyAux cmultiplyR' | ||
296 | multiplyC = multiplyAux cmultiplyC | ||
297 | |||
298 | cmultiplyR' p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3 = {-# SCC "mulR" #-} cmultiplyR p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3 | ||
299 | 265 | ||
300 | ---------------------------------------------------------------------- | 266 | ---------------------------------------------------------------------- |
301 | 267 | ||
@@ -318,18 +284,6 @@ subMatrixC (r0,c0) (rt,ct) x = | |||
318 | subMatrixR (r0,2*c0) (rt,2*ct) . | 284 | subMatrixR (r0,2*c0) (rt,2*ct) . |
319 | reshape (2*cols x) . asReal . cdat $ x | 285 | reshape (2*cols x) . asReal . cdat $ x |
320 | 286 | ||
321 | --subMatrix :: (Field a) | ||
322 | -- => (Int,Int) -- ^ (r0,c0) starting position | ||
323 | -- -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
324 | -- -> Matrix a -> Matrix a | ||
325 | --subMatrix st sz m | ||
326 | -- | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m) | ||
327 | -- | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m) | ||
328 | -- | otherwise = subMatrixG st sz m | ||
329 | |||
330 | subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) | ||
331 | where subList s n = take n . drop s | ||
332 | |||
333 | --------------------------------------------------------------------- | 287 | --------------------------------------------------------------------- |
334 | 288 | ||
335 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do | 289 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do |
@@ -347,66 +301,7 @@ diagC :: Vector (Complex Double) -> Matrix (Complex Double) | |||
347 | diagC = diagAux c_diagC "diagC" | 301 | diagC = diagAux c_diagC "diagC" |
348 | foreign import ccall "aux.h diagC" c_diagC :: TCVCM | 302 | foreign import ccall "aux.h diagC" c_diagC :: TCVCM |
349 | 303 | ||
350 | -- | diagonal matrix from a vector | 304 | ------------------------------------------------------------------------ |
351 | --diag :: (Num a, Field a) => Vector a -> Matrix a | ||
352 | --diag v | ||
353 | -- | isReal (baseOf) v = scast $ diagR (scast v) | ||
354 | -- | isComp (baseOf) v = scast $ diagC (scast v) | ||
355 | -- | otherwise = diagG v | ||
356 | |||
357 | diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] | ||
358 | where c = dim v | ||
359 | l = toList v | ||
360 | delta i j | i==j = 1 | ||
361 | | otherwise = 0 | ||
362 | |||
363 | -- | creates a Matrix from a list of vectors | ||
364 | --fromRows :: Field t => [Vector t] -> Matrix t | ||
365 | fromRows vs = case common dim vs of | ||
366 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" | ||
367 | Just c -> reshape c (join vs) | ||
368 | |||
369 | -- | extracts the rows of a matrix as a list of vectors | ||
370 | --toRows :: Storable t => Matrix t -> [Vector t] | ||
371 | toRows m = toRows' 0 where | ||
372 | v = cdat m | ||
373 | r = rows m | ||
374 | c = cols m | ||
375 | toRows' k | k == r*c = [] | ||
376 | | otherwise = subVector k c v : toRows' (k+c) | ||
377 | |||
378 | -- | Creates a matrix from a list of vectors, as columns | ||
379 | fromColumns :: Field t => [Vector t] -> Matrix t | ||
380 | fromColumns m = trans . fromRows $ m | ||
381 | |||
382 | -- | Creates a list of vectors from the columns of a matrix | ||
383 | toColumns :: Field t => Matrix t -> [Vector t] | ||
384 | toColumns m = toRows . trans $ m | ||
385 | |||
386 | |||
387 | -- | Reads a matrix position. | ||
388 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t | ||
389 | infixl 9 @@> | ||
390 | --m@M {rows = r, cols = c} @@> (i,j) | ||
391 | -- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
392 | -- | otherwise = cdat m `at` (i*c+j) | ||
393 | |||
394 | MC {rows = r, cols = c, dat = v} @@> (i,j) | ||
395 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
396 | | otherwise = v `at` (i*c+j) | ||
397 | |||
398 | MF {rows = r, cols = c, dat = v} @@> (i,j) | ||
399 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
400 | | otherwise = v `at` (j*r+i) | ||
401 | |||
402 | |||
403 | ------------------------------------------------------------------ | ||
404 | |||
405 | constantR :: Double -> Int -> Vector Double | ||
406 | constantR = constantAux cconstantR | ||
407 | |||
408 | constantC :: Complex Double -> Int -> Vector (Complex Double) | ||
409 | constantC = constantAux cconstantC | ||
410 | 305 | ||
411 | constantAux fun x n = unsafePerformIO $ do | 306 | constantAux fun x n = unsafePerformIO $ do |
412 | v <- createVector n | 307 | v <- createVector n |
@@ -415,8 +310,41 @@ constantAux fun x n = unsafePerformIO $ do | |||
415 | free px | 310 | free px |
416 | return v | 311 | return v |
417 | 312 | ||
313 | constantR :: Double -> Int -> Vector Double | ||
314 | constantR = constantAux cconstantR | ||
418 | foreign import ccall safe "aux.h constantR" | 315 | foreign import ccall safe "aux.h constantR" |
419 | cconstantR :: Ptr Double -> TV -- Double :> IO Int | 316 | cconstantR :: Ptr Double -> TV -- Double :> IO Int |
420 | 317 | ||
318 | constantC :: Complex Double -> Int -> Vector (Complex Double) | ||
319 | constantC = constantAux cconstantC | ||
421 | foreign import ccall safe "aux.h constantC" | 320 | foreign import ccall safe "aux.h constantC" |
422 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int | 321 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int |
322 | |||
323 | ------------------------------------------------------------------------- | ||
324 | |||
325 | -- Generic definitions | ||
326 | |||
327 | {- | ||
328 | transL m = matrixFromVector RowMajor (rows m) $ transdata (cols m) (cdat m) (rows m) | ||
329 | |||
330 | subMatrixG (r0,c0) (rt,ct) x = matrixFromVector RowMajor ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) | ||
331 | where subList s n = take n . drop s | ||
332 | |||
333 | diagG v = matrixFromVector RowMajor c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] | ||
334 | where c = dim v | ||
335 | l = toList v | ||
336 | delta i j | i==j = 1 | ||
337 | | otherwise = 0 | ||
338 | -} | ||
339 | |||
340 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
341 | |||
342 | dotL a b = sum (zipWith (*) a b) | ||
343 | |||
344 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | ||
345 | |||
346 | multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] | ||
347 | | otherwise = error "inconsistent dimensions in contraction " | ||
348 | where ok = case common length a of | ||
349 | Nothing -> False | ||
350 | Just c -> c == length b | ||