diff options
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 204 |
1 files changed, 118 insertions, 86 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 9895393..48652f3 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts #-} | 1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-overlapping-instances #-} |
2 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
3 | -- | | 3 | -- | |
4 | -- Module : Data.Packed.Internal.Matrix | 4 | -- Module : Data.Packed.Internal.Matrix |
@@ -22,9 +22,65 @@ import Foreign hiding (xor) | |||
22 | import Complex | 22 | import Complex |
23 | import Control.Monad(when) | 23 | import Control.Monad(when) |
24 | import Data.List(transpose,intersperse) | 24 | import Data.List(transpose,intersperse) |
25 | import Data.Typeable | 25 | --import Data.Typeable |
26 | import Data.Maybe(fromJust) | 26 | import Data.Maybe(fromJust) |
27 | 27 | ||
28 | ---------------------------------------------------------------- | ||
29 | |||
30 | class Storable a => Field a where | ||
31 | constant :: a -> Int -> Vector a | ||
32 | transdata :: Int -> Vector a -> Int -> Vector a | ||
33 | multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
34 | subMatrix :: (Int,Int) -- ^ (r0,c0) starting position | ||
35 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | ||
36 | -> Matrix a -> Matrix a | ||
37 | diag :: Vector a -> Matrix a | ||
38 | |||
39 | |||
40 | instance Field Double where | ||
41 | constant = constantR | ||
42 | transdata = transdataR | ||
43 | multiplyD = multiplyR | ||
44 | subMatrix = subMatrixR | ||
45 | diag = diagR | ||
46 | |||
47 | instance Field (Complex Double) where | ||
48 | constant = constantC | ||
49 | transdata = transdataC | ||
50 | multiplyD = multiplyC | ||
51 | subMatrix = subMatrixC | ||
52 | diag = diagC | ||
53 | |||
54 | ----------------------------------------------------------------- | ||
55 | |||
56 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
57 | transdataR = transdataAux ctransR | ||
58 | |||
59 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
60 | transdataC = transdataAux ctransC | ||
61 | |||
62 | transdataAux fun c1 d c2 = | ||
63 | if noneed | ||
64 | then d | ||
65 | else unsafePerformIO $ do | ||
66 | v <- createVector (dim d) | ||
67 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
68 | --putStrLn "---> transdataAux" | ||
69 | return v | ||
70 | where r1 = dim d `div` c1 | ||
71 | r2 = dim d `div` c2 | ||
72 | noneed = r1 == 1 || c1 == 1 | ||
73 | |||
74 | foreign import ccall safe "aux.h transR" | ||
75 | ctransR :: TMM -- Double ::> Double ::> IO Int | ||
76 | foreign import ccall safe "aux.h transC" | ||
77 | ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int | ||
78 | |||
79 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
80 | |||
81 | |||
82 | |||
83 | |||
28 | 84 | ||
29 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | 85 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
30 | 86 | ||
@@ -34,9 +90,18 @@ data Matrix t = M { rows :: Int | |||
34 | , tdat :: Vector t | 90 | , tdat :: Vector t |
35 | , isTrans :: Bool | 91 | , isTrans :: Bool |
36 | , order :: MatrixOrder | 92 | , order :: MatrixOrder |
37 | } deriving Typeable | 93 | } -- deriving Typeable |
38 | 94 | ||
39 | 95 | ||
96 | data NMat t = MC { rws, cls :: Int, dtc :: Vector t} | ||
97 | | MF { rws, cls :: Int, dtf :: Vector t} | ||
98 | | Tr (NMat t) | ||
99 | |||
100 | ntrans (Tr m) = m | ||
101 | ntrans m = Tr m | ||
102 | |||
103 | viewC m@MC{} = m | ||
104 | viewF m@MF{} = m | ||
40 | 105 | ||
41 | fortran m = order m == ColumnMajor | 106 | fortran m = order m == ColumnMajor |
42 | 107 | ||
@@ -78,7 +143,11 @@ matrixFromVector RowMajor c v = | |||
78 | , tdat = transdata c v r | 143 | , tdat = transdata c v r |
79 | , order = RowMajor | 144 | , order = RowMajor |
80 | , isTrans = False | 145 | , isTrans = False |
81 | } where r = dim v `div` c -- TODO check mod=0 | 146 | } where (d,m) = dim v `divMod` c |
147 | r | m==0 = d | ||
148 | | otherwise = error "matrixFromVector" | ||
149 | |||
150 | -- r = dim v `div` c -- TODO check mod=0 | ||
82 | 151 | ||
83 | matrixFromVector ColumnMajor c v = | 152 | matrixFromVector ColumnMajor c v = |
84 | M { rows = r | 153 | M { rows = r |
@@ -87,7 +156,9 @@ matrixFromVector ColumnMajor c v = | |||
87 | , tdat = transdata r v c | 156 | , tdat = transdata r v c |
88 | , order = ColumnMajor | 157 | , order = ColumnMajor |
89 | , isTrans = False | 158 | , isTrans = False |
90 | } where r = dim v `div` c -- TODO check mod=0 | 159 | } where (d,m) = dim v `divMod` c |
160 | r | m==0 = d | ||
161 | | otherwise = error "matrixFromVector" | ||
91 | 162 | ||
92 | createMatrix order r c = do | 163 | createMatrix order r c = do |
93 | p <- createVector (r*c) | 164 | p <- createVector (r*c) |
@@ -102,48 +173,11 @@ createMatrix order r c = do | |||
102 | , 9.0, 10.0, 11.0, 12.0 ]@ | 173 | , 9.0, 10.0, 11.0, 12.0 ]@ |
103 | 174 | ||
104 | -} | 175 | -} |
105 | reshape :: (Field t) => Int -> Vector t -> Matrix t | 176 | reshape :: Field t => Int -> Vector t -> Matrix t |
106 | reshape c v = matrixFromVector RowMajor c v | 177 | reshape c v = matrixFromVector RowMajor c v |
107 | 178 | ||
108 | singleton x = reshape 1 (fromList [x]) | 179 | singleton x = reshape 1 (fromList [x]) |
109 | 180 | ||
110 | transdataG :: Storable a => Int -> Vector a -> Int -> Vector a | ||
111 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
112 | |||
113 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
114 | transdataR = transdataAux ctransR | ||
115 | |||
116 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
117 | transdataC = transdataAux ctransC | ||
118 | |||
119 | transdataAux fun c1 d c2 = | ||
120 | if noneed | ||
121 | then d | ||
122 | else unsafePerformIO $ do | ||
123 | v <- createVector (dim d) | ||
124 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
125 | --putStrLn "---> transdataAux" | ||
126 | return v | ||
127 | where r1 = dim d `div` c1 | ||
128 | r2 = dim d `div` c2 | ||
129 | noneed = r1 == 1 || c1 == 1 | ||
130 | |||
131 | foreign import ccall safe "aux.h transR" | ||
132 | ctransR :: TMM -- Double ::> Double ::> IO Int | ||
133 | foreign import ccall safe "aux.h transC" | ||
134 | ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int | ||
135 | |||
136 | transdata :: Field a => Int -> Vector a -> Int -> Vector a | ||
137 | transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | ||
138 | | isComp baseOf d = scast $ transdataC c1 (scast d) c2 | ||
139 | | otherwise = transdataG c1 d c2 | ||
140 | |||
141 | --transdata :: Storable a => Int -> Vector a -> Int -> Vector a | ||
142 | --transdata = transdataG | ||
143 | --{-# RULES "transdataR" transdata=transdataR #-} | ||
144 | --{-# RULES "transdataC" transdata=transdataC #-} | ||
145 | |||
146 | ----------------------------------------------------------------- | ||
147 | liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 181 | liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
148 | liftMatrix f m = reshape (cols m) (f (cdat m)) | 182 | liftMatrix f m = reshape (cols m) (f (cdat m)) |
149 | 183 | ||
@@ -163,7 +197,7 @@ multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] | |||
163 | Nothing -> False | 197 | Nothing -> False |
164 | Just c -> c == length b | 198 | Just c -> c == length b |
165 | 199 | ||
166 | transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) | 200 | transL m = matrixFromVector RowMajor (rows m) $ transdata (cols m) (cdat m) (rows m) |
167 | 201 | ||
168 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | 202 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) |
169 | 203 | ||
@@ -179,7 +213,7 @@ gmatC m f | fortran m = | |||
179 | else f 0 (rows m) (cols m) (ptr (dat m)) | 213 | else f 0 (rows m) (cols m) (ptr (dat m)) |
180 | 214 | ||
181 | 215 | ||
182 | multiplyAux order fun a b = unsafePerformIO $ do | 216 | multiplyAux fun order a b = unsafePerformIO $ do |
183 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | 217 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ |
184 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | 218 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) |
185 | r <- createMatrix order (rows a) (cols b) | 219 | r <- createMatrix order (rows a) (cols b) |
@@ -198,37 +232,14 @@ foreign import ccall safe "aux.h multiplyC" | |||
198 | -> Int -> Int -> Ptr (Complex Double) | 232 | -> Int -> Int -> Ptr (Complex Double) |
199 | -> IO Int | 233 | -> IO Int |
200 | 234 | ||
201 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | 235 | multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a |
202 | multiply RowMajor a b = multiplyD RowMajor a b | 236 | multiply RowMajor a b = multiplyD RowMajor a b |
203 | multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} | 237 | multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} |
204 | where m = multiplyD RowMajor (trans b) (trans a) | 238 | where m = multiplyD RowMajor (trans b) (trans a) |
205 | 239 | ||
206 | multiplyD order a b | ||
207 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | ||
208 | | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) | ||
209 | | otherwise = multiplyG a b | ||
210 | |||
211 | ---------------------------------------------------------------------- | ||
212 | 240 | ||
213 | outer' u v = dat (outer u v) | 241 | multiplyR = multiplyAux cmultiplyR |
214 | 242 | multiplyC = multiplyAux cmultiplyC | |
215 | {- | Outer product of two vectors. | ||
216 | |||
217 | @\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] | ||
218 | (3><3) | ||
219 | [ 5.0, 2.0, 3.0 | ||
220 | , 10.0, 4.0, 6.0 | ||
221 | , 15.0, 6.0, 9.0 ]@ | ||
222 | -} | ||
223 | outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t | ||
224 | outer u v = multiply RowMajor r c | ||
225 | where r = matrixFromVector RowMajor 1 u | ||
226 | c = matrixFromVector RowMajor (dim v) v | ||
227 | |||
228 | dot :: (Field t, Num t) => Vector t -> Vector t -> t | ||
229 | dot u v = dat (multiply RowMajor r c) `at` 0 | ||
230 | where r = matrixFromVector RowMajor (dim u) u | ||
231 | c = matrixFromVector RowMajor 1 v | ||
232 | 243 | ||
233 | ---------------------------------------------------------------------- | 244 | ---------------------------------------------------------------------- |
234 | 245 | ||
@@ -251,14 +262,14 @@ subMatrixC (r0,c0) (rt,ct) x = | |||
251 | subMatrixR (r0,2*c0) (rt,2*ct) . | 262 | subMatrixR (r0,2*c0) (rt,2*ct) . |
252 | reshape (2*cols x) . asReal . cdat $ x | 263 | reshape (2*cols x) . asReal . cdat $ x |
253 | 264 | ||
254 | subMatrix :: (Field a) | 265 | --subMatrix :: (Field a) |
255 | => (Int,Int) -- ^ (r0,c0) starting position | 266 | -- => (Int,Int) -- ^ (r0,c0) starting position |
256 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | 267 | -- -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix |
257 | -> Matrix a -> Matrix a | 268 | -- -> Matrix a -> Matrix a |
258 | subMatrix st sz m | 269 | --subMatrix st sz m |
259 | | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m) | 270 | -- | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m) |
260 | | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m) | 271 | -- | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m) |
261 | | otherwise = subMatrixG st sz m | 272 | -- | otherwise = subMatrixG st sz m |
262 | 273 | ||
263 | subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) | 274 | subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) |
264 | where subList s n = take n . drop s | 275 | where subList s n = take n . drop s |
@@ -281,11 +292,11 @@ diagC = diagAux c_diagC "diagC" | |||
281 | foreign import ccall "aux.h diagC" c_diagC :: TCVCM | 292 | foreign import ccall "aux.h diagC" c_diagC :: TCVCM |
282 | 293 | ||
283 | -- | diagonal matrix from a vector | 294 | -- | diagonal matrix from a vector |
284 | diag :: (Num a, Field a) => Vector a -> Matrix a | 295 | --diag :: (Num a, Field a) => Vector a -> Matrix a |
285 | diag v | 296 | --diag v |
286 | | isReal (baseOf) v = scast $ diagR (scast v) | 297 | -- | isReal (baseOf) v = scast $ diagR (scast v) |
287 | | isComp (baseOf) v = scast $ diagC (scast v) | 298 | -- | isComp (baseOf) v = scast $ diagC (scast v) |
288 | | otherwise = diagG v | 299 | -- | otherwise = diagG v |
289 | 300 | ||
290 | diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] | 301 | diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] |
291 | where c = dim v | 302 | where c = dim v |
@@ -313,13 +324,34 @@ fromColumns :: Field t => [Vector t] -> Matrix t | |||
313 | fromColumns m = trans . fromRows $ m | 324 | fromColumns m = trans . fromRows $ m |
314 | 325 | ||
315 | -- | Creates a list of vectors from the columns of a matrix | 326 | -- | Creates a list of vectors from the columns of a matrix |
316 | toColumns :: Field t => Matrix t -> [Vector t] | 327 | toColumns :: Storable t => Matrix t -> [Vector t] |
317 | toColumns m = toRows . trans $ m | 328 | toColumns m = toRows . trans $ m |
318 | 329 | ||
319 | 330 | ||
320 | -- | Reads a matrix position. | 331 | -- | Reads a matrix position. |
321 | (@@>) :: Field t => Matrix t -> (Int,Int) -> t | 332 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t |
322 | infixl 9 @@> | 333 | infixl 9 @@> |
323 | m@M {rows = r, cols = c} @@> (i,j) | 334 | m@M {rows = r, cols = c} @@> (i,j) |
324 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | 335 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" |
325 | | otherwise = cdat m `at` (i*c+j) | 336 | | otherwise = cdat m `at` (i*c+j) |
337 | |||
338 | ------------------------------------------------------------------ | ||
339 | |||
340 | constantR :: Double -> Int -> Vector Double | ||
341 | constantR = constantAux cconstantR | ||
342 | |||
343 | constantC :: Complex Double -> Int -> Vector (Complex Double) | ||
344 | constantC = constantAux cconstantC | ||
345 | |||
346 | constantAux fun x n = unsafePerformIO $ do | ||
347 | v <- createVector n | ||
348 | px <- newArray [x] | ||
349 | fun px // vec v // check "constantAux" [] | ||
350 | free px | ||
351 | return v | ||
352 | |||
353 | foreign import ccall safe "aux.h constantR" | ||
354 | cconstantR :: Ptr Double -> TV -- Double :> IO Int | ||
355 | |||
356 | foreign import ccall safe "aux.h constantC" | ||
357 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int | ||