diff options
Diffstat (limited to 'lib/Data/Packed')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 131 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 2 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 4 |
3 files changed, 101 insertions, 36 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 48652f3..6ba2d06 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -27,6 +27,10 @@ import Data.Maybe(fromJust) | |||
27 | 27 | ||
28 | ---------------------------------------------------------------- | 28 | ---------------------------------------------------------------- |
29 | 29 | ||
30 | -- the condition Storable a => Field a means that we can only put | ||
31 | -- in Field types that are in Storable, and therefore Storable a | ||
32 | -- is not required in signatures if we have a Field a. | ||
33 | |||
30 | class Storable a => Field a where | 34 | class Storable a => Field a where |
31 | constant :: a -> Int -> Vector a | 35 | constant :: a -> Int -> Vector a |
32 | transdata :: Int -> Vector a -> Int -> Vector a | 36 | transdata :: Int -> Vector a -> Int -> Vector a |
@@ -36,7 +40,6 @@ class Storable a => Field a where | |||
36 | -> Matrix a -> Matrix a | 40 | -> Matrix a -> Matrix a |
37 | diag :: Vector a -> Matrix a | 41 | diag :: Vector a -> Matrix a |
38 | 42 | ||
39 | |||
40 | instance Field Double where | 43 | instance Field Double where |
41 | constant = constantR | 44 | constant = constantR |
42 | transdata = transdataR | 45 | transdata = transdataR |
@@ -78,12 +81,40 @@ foreign import ccall safe "aux.h transC" | |||
78 | 81 | ||
79 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | 82 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d |
80 | 83 | ||
84 | {- Design considerations for the Matrix Type | ||
85 | ----------------------------------------- | ||
86 | |||
87 | - we must easily handle both row major and column major order, | ||
88 | for bindings to LAPACK and GSL/C | ||
89 | |||
90 | - we'd like to simplify redundant matrix transposes: | ||
91 | - Some of them arise from the order requirements of some functions | ||
92 | - some functions (matrix product) admit transposed arguments | ||
81 | 93 | ||
94 | - maybe we don't really need this kind of simplification: | ||
95 | - more complex code | ||
96 | - some computational overhead | ||
97 | - only appreciable gain in code with a lot of redundant transpositions | ||
98 | and cheap matrix computations | ||
82 | 99 | ||
100 | - we could carry both the matrix and its (lazily computed) transpose. | ||
101 | This may save some transpositions, but it is necessary to keep track of the | ||
102 | data which is actually computed to be used by functions like the matrix product | ||
103 | which admit both orders. Therefore, maybe it is better to have something like | ||
104 | viewC and viewF, which may actually perform a transpose if required. | ||
83 | 105 | ||
106 | - but if we need the transposed data and it is not in the structure, we must make | ||
107 | sure that we touch the same foreignptr that is used in the computation. Access | ||
108 | to such pointer cannot be made by creating a new vector. | ||
109 | |||
110 | -} | ||
84 | 111 | ||
85 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | 112 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
86 | 113 | ||
114 | {- | ||
115 | |||
116 | |||
117 | |||
87 | data Matrix t = M { rows :: Int | 118 | data Matrix t = M { rows :: Int |
88 | , cols :: Int | 119 | , cols :: Int |
89 | , dat :: Vector t | 120 | , dat :: Vector t |
@@ -91,28 +122,26 @@ data Matrix t = M { rows :: Int | |||
91 | , isTrans :: Bool | 122 | , isTrans :: Bool |
92 | , order :: MatrixOrder | 123 | , order :: MatrixOrder |
93 | } -- deriving Typeable | 124 | } -- deriving Typeable |
125 | -} | ||
94 | 126 | ||
127 | data Matrix t = MC { rows :: Int, cols :: Int, dat :: Vector t } -- row major order | ||
128 | | MF { rows :: Int, cols :: Int, dat :: Vector t } -- column major order | ||
95 | 129 | ||
96 | data NMat t = MC { rws, cls :: Int, dtc :: Vector t} | 130 | -- transposition just changes the data order |
97 | | MF { rws, cls :: Int, dtf :: Vector t} | 131 | trans :: Matrix t -> Matrix t |
98 | | Tr (NMat t) | 132 | trans MC {rows = r, cols = c, dat = d} = MF {rows = c, cols = r, dat = d} |
99 | 133 | trans MF {rows = r, cols = c, dat = d} = MC {rows = c, cols = r, dat = d} | |
100 | ntrans (Tr m) = m | ||
101 | ntrans m = Tr m | ||
102 | 134 | ||
103 | viewC m@MC{} = m | 135 | viewC m@MC{} = m |
104 | viewF m@MF{} = m | 136 | viewC MF {rows = r, cols = c, dat = d} = MC {rows = r, cols = c, dat = transdata r d c} |
105 | 137 | ||
106 | fortran m = order m == ColumnMajor | 138 | viewF m@MF{} = m |
139 | viewF MC {rows = r, cols = c, dat = d} = MF {rows = r, cols = c, dat = transdata c d r} | ||
107 | 140 | ||
108 | cdat m = if fortran m `xor` isTrans m then tdat m else dat m | 141 | --fortran m = order m == ColumnMajor |
109 | fdat m = if fortran m `xor` isTrans m then dat m else tdat m | ||
110 | 142 | ||
111 | trans :: Matrix t -> Matrix t | 143 | cdat m = dat (viewC m) |
112 | trans m = m { rows = cols m | 144 | fdat m = dat (viewF m) |
113 | , cols = rows m | ||
114 | , isTrans = not (isTrans m) | ||
115 | } | ||
116 | 145 | ||
117 | type Mt t s = Int -> Int -> Ptr t -> s | 146 | type Mt t s = Int -> Int -> Ptr t -> s |
118 | -- not yet admitted by my haddock version | 147 | -- not yet admitted by my haddock version |
@@ -120,11 +149,14 @@ type Mt t s = Int -> Int -> Ptr t -> s | |||
120 | -- type t ::> s = Mt t s | 149 | -- type t ::> s = Mt t s |
121 | 150 | ||
122 | mat d m f = f (rows m) (cols m) (ptr (d m)) | 151 | mat d m f = f (rows m) (cols m) (ptr (d m)) |
152 | --mat m f = f (rows m) (cols m) (ptr (dat m)) | ||
153 | --matC m f = f (rows m) (cols m) (ptr (cdat m)) | ||
154 | |||
123 | 155 | ||
124 | toLists :: (Storable t) => Matrix t -> [[t]] | 156 | --toLists :: (Storable t) => Matrix t -> [[t]] |
125 | toLists m = partit (cols m) . toList . cdat $ m | 157 | toLists m = partit (cols m) . toList . cdat $ m |
126 | 158 | ||
127 | instance (Show a, Storable a) => (Show (Matrix a)) where | 159 | instance (Show a, Field a) => (Show (Matrix a)) where |
128 | show m = (sizes++) . dsp . map (map show) . toLists $ m | 160 | show m = (sizes++) . dsp . map (map show) . toLists $ m |
129 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | 161 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" |
130 | 162 | ||
@@ -136,6 +168,7 @@ dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unw | |||
136 | pad n str = replicate (n - length str) ' ' ++ str | 168 | pad n str = replicate (n - length str) ' ' ++ str |
137 | unwords' = concat . intersperse ", " | 169 | unwords' = concat . intersperse ", " |
138 | 170 | ||
171 | {- | ||
139 | matrixFromVector RowMajor c v = | 172 | matrixFromVector RowMajor c v = |
140 | M { rows = r | 173 | M { rows = r |
141 | , cols = c | 174 | , cols = c |
@@ -147,8 +180,6 @@ matrixFromVector RowMajor c v = | |||
147 | r | m==0 = d | 180 | r | m==0 = d |
148 | | otherwise = error "matrixFromVector" | 181 | | otherwise = error "matrixFromVector" |
149 | 182 | ||
150 | -- r = dim v `div` c -- TODO check mod=0 | ||
151 | |||
152 | matrixFromVector ColumnMajor c v = | 183 | matrixFromVector ColumnMajor c v = |
153 | M { rows = r | 184 | M { rows = r |
154 | , cols = c | 185 | , cols = c |
@@ -160,6 +191,23 @@ matrixFromVector ColumnMajor c v = | |||
160 | r | m==0 = d | 191 | r | m==0 = d |
161 | | otherwise = error "matrixFromVector" | 192 | | otherwise = error "matrixFromVector" |
162 | 193 | ||
194 | -} | ||
195 | |||
196 | matrixFromVector RowMajor c v = MC { rows = r, cols = c, dat = v} | ||
197 | where (d,m) = dim v `divMod` c | ||
198 | r | m==0 = d | ||
199 | | otherwise = error "matrixFromVector" | ||
200 | |||
201 | matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, dat = v} | ||
202 | where (d,m) = dim v `divMod` c | ||
203 | r | m==0 = d | ||
204 | | otherwise = error "matrixFromVector" | ||
205 | |||
206 | |||
207 | |||
208 | |||
209 | |||
210 | |||
163 | createMatrix order r c = do | 211 | createMatrix order r c = do |
164 | p <- createVector (r*c) | 212 | p <- createVector (r*c) |
165 | return (matrixFromVector order c p) | 213 | return (matrixFromVector order c p) |
@@ -178,10 +226,10 @@ reshape c v = matrixFromVector RowMajor c v | |||
178 | 226 | ||
179 | singleton x = reshape 1 (fromList [x]) | 227 | singleton x = reshape 1 (fromList [x]) |
180 | 228 | ||
181 | liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 229 | --liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
182 | liftMatrix f m = reshape (cols m) (f (cdat m)) | 230 | liftMatrix f m = reshape (cols m) (f (cdat m)) |
183 | 231 | ||
184 | liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 232 | --liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
185 | liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) | 233 | liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) |
186 | | otherwise = error "nonconformant matrices in liftMatrix2" | 234 | | otherwise = error "nonconformant matrices in liftMatrix2" |
187 | ------------------------------------------------------------------ | 235 | ------------------------------------------------------------------ |
@@ -203,6 +251,7 @@ multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multipl | |||
203 | 251 | ||
204 | ------------------------------------------------------------------ | 252 | ------------------------------------------------------------------ |
205 | 253 | ||
254 | {- | ||
206 | gmatC m f | fortran m = | 255 | gmatC m f | fortran m = |
207 | if (isTrans m) | 256 | if (isTrans m) |
208 | then f 0 (rows m) (cols m) (ptr (dat m)) | 257 | then f 0 (rows m) (cols m) (ptr (dat m)) |
@@ -211,7 +260,11 @@ gmatC m f | fortran m = | |||
211 | if isTrans m | 260 | if isTrans m |
212 | then f 1 (cols m) (rows m) (ptr (dat m)) | 261 | then f 1 (cols m) (rows m) (ptr (dat m)) |
213 | else f 0 (rows m) (cols m) (ptr (dat m)) | 262 | else f 0 (rows m) (cols m) (ptr (dat m)) |
263 | -} | ||
214 | 264 | ||
265 | gmatC MF {rows = r, cols = c, dat = d} f = f 1 c r (ptr d) | ||
266 | gmatC MC {rows = r, cols = c, dat = d} f = f 0 r c (ptr d) | ||
267 | {-# INLINE gmatC #-} | ||
215 | 268 | ||
216 | multiplyAux fun order a b = unsafePerformIO $ do | 269 | multiplyAux fun order a b = unsafePerformIO $ do |
217 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | 270 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ |
@@ -219,6 +272,7 @@ multiplyAux fun order a b = unsafePerformIO $ do | |||
219 | r <- createMatrix order (rows a) (cols b) | 272 | r <- createMatrix order (rows a) (cols b) |
220 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] | 273 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] |
221 | return r | 274 | return r |
275 | {-# INLINE multiplyAux #-} | ||
222 | 276 | ||
223 | foreign import ccall safe "aux.h multiplyR" | 277 | foreign import ccall safe "aux.h multiplyR" |
224 | cmultiplyR :: Int -> Int -> Int -> Ptr Double | 278 | cmultiplyR :: Int -> Int -> Int -> Ptr Double |
@@ -234,13 +288,15 @@ foreign import ccall safe "aux.h multiplyC" | |||
234 | 288 | ||
235 | multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | 289 | multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a |
236 | multiply RowMajor a b = multiplyD RowMajor a b | 290 | multiply RowMajor a b = multiplyD RowMajor a b |
237 | multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} | 291 | multiply ColumnMajor a b = MF {rows = c, cols = r, dat = d} |
238 | where m = multiplyD RowMajor (trans b) (trans a) | 292 | where MC {rows = r, cols = c, dat = d } = multiplyD RowMajor (trans b) (trans a) |
239 | 293 | ||
240 | 294 | ||
241 | multiplyR = multiplyAux cmultiplyR | 295 | multiplyR = multiplyAux cmultiplyR' |
242 | multiplyC = multiplyAux cmultiplyC | 296 | multiplyC = multiplyAux cmultiplyC |
243 | 297 | ||
298 | cmultiplyR' p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3 = {-# SCC "mulR" #-} cmultiplyR p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3 | ||
299 | |||
244 | ---------------------------------------------------------------------- | 300 | ---------------------------------------------------------------------- |
245 | 301 | ||
246 | -- | extraction of a submatrix of a real matrix | 302 | -- | extraction of a submatrix of a real matrix |
@@ -249,7 +305,7 @@ subMatrixR :: (Int,Int) -- ^ (r0,c0) starting position | |||
249 | -> Matrix Double -> Matrix Double | 305 | -> Matrix Double -> Matrix Double |
250 | subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do | 306 | subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do |
251 | r <- createMatrix RowMajor rt ct | 307 | r <- createMatrix RowMajor rt ct |
252 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r] | 308 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r] |
253 | return r | 309 | return r |
254 | foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM | 310 | foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM |
255 | 311 | ||
@@ -278,8 +334,8 @@ subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 | |||
278 | 334 | ||
279 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do | 335 | diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do |
280 | m <- createMatrix RowMajor n n | 336 | m <- createMatrix RowMajor n n |
281 | fun // vec v // mat dat m // check msg [dat m] | 337 | fun // vec v // mat cdat m // check msg [dat m] |
282 | return m {tdat = dat m} | 338 | return m -- {tdat = dat m} |
283 | 339 | ||
284 | -- | diagonal matrix from a real vector | 340 | -- | diagonal matrix from a real vector |
285 | diagR :: Vector Double -> Matrix Double | 341 | diagR :: Vector Double -> Matrix Double |
@@ -305,13 +361,13 @@ diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1.. | |||
305 | | otherwise = 0 | 361 | | otherwise = 0 |
306 | 362 | ||
307 | -- | creates a Matrix from a list of vectors | 363 | -- | creates a Matrix from a list of vectors |
308 | fromRows :: Field t => [Vector t] -> Matrix t | 364 | --fromRows :: Field t => [Vector t] -> Matrix t |
309 | fromRows vs = case common dim vs of | 365 | fromRows vs = case common dim vs of |
310 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" | 366 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" |
311 | Just c -> reshape c (join vs) | 367 | Just c -> reshape c (join vs) |
312 | 368 | ||
313 | -- | extracts the rows of a matrix as a list of vectors | 369 | -- | extracts the rows of a matrix as a list of vectors |
314 | toRows :: Storable t => Matrix t -> [Vector t] | 370 | --toRows :: Storable t => Matrix t -> [Vector t] |
315 | toRows m = toRows' 0 where | 371 | toRows m = toRows' 0 where |
316 | v = cdat m | 372 | v = cdat m |
317 | r = rows m | 373 | r = rows m |
@@ -324,16 +380,25 @@ fromColumns :: Field t => [Vector t] -> Matrix t | |||
324 | fromColumns m = trans . fromRows $ m | 380 | fromColumns m = trans . fromRows $ m |
325 | 381 | ||
326 | -- | Creates a list of vectors from the columns of a matrix | 382 | -- | Creates a list of vectors from the columns of a matrix |
327 | toColumns :: Storable t => Matrix t -> [Vector t] | 383 | toColumns :: Field t => Matrix t -> [Vector t] |
328 | toColumns m = toRows . trans $ m | 384 | toColumns m = toRows . trans $ m |
329 | 385 | ||
330 | 386 | ||
331 | -- | Reads a matrix position. | 387 | -- | Reads a matrix position. |
332 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t | 388 | (@@>) :: Storable t => Matrix t -> (Int,Int) -> t |
333 | infixl 9 @@> | 389 | infixl 9 @@> |
334 | m@M {rows = r, cols = c} @@> (i,j) | 390 | --m@M {rows = r, cols = c} @@> (i,j) |
391 | -- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
392 | -- | otherwise = cdat m `at` (i*c+j) | ||
393 | |||
394 | MC {rows = r, cols = c, dat = v} @@> (i,j) | ||
395 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | ||
396 | | otherwise = v `at` (i*c+j) | ||
397 | |||
398 | MF {rows = r, cols = c, dat = v} @@> (i,j) | ||
335 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | 399 | | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" |
336 | | otherwise = cdat m `at` (i*c+j) | 400 | | otherwise = v `at` (j*r+i) |
401 | |||
337 | 402 | ||
338 | ------------------------------------------------------------------ | 403 | ------------------------------------------------------------------ |
339 | 404 | ||
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 6876685..dea1636 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -92,7 +92,7 @@ tensor dssig vec = T d v `withIdx` seqind where | |||
92 | tensorFromVector :: IdxType -> Vector t -> Tensor t | 92 | tensorFromVector :: IdxType -> Vector t -> Tensor t |
93 | tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v} | 93 | tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v} |
94 | 94 | ||
95 | tensorFromMatrix :: IdxType -> IdxType -> Matrix t -> Tensor t | 95 | tensorFromMatrix :: Field t => IdxType -> IdxType -> Matrix t -> Tensor t |
96 | tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"] | 96 | tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"] |
97 | , ten = cdat m} | 97 | , ten = cdat m} |
98 | 98 | ||
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 2e8cb3d..45aaaba 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -77,7 +77,7 @@ diagRect s r c | |||
77 | | r > c = joinVert [diag s , zeros (r-c,c)] | 77 | | r > c = joinVert [diag s , zeros (r-c,c)] |
78 | where zeros (r,c) = reshape c $ constant 0 (r*c) | 78 | where zeros (r,c) = reshape c $ constant 0 (r*c) |
79 | 79 | ||
80 | takeDiag :: (Storable t) => Matrix t -> Vector t | 80 | takeDiag :: (Field t) => Matrix t -> Vector t |
81 | takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | 81 | takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] |
82 | 82 | ||
83 | ident :: (Num t, Field t) => Int -> Matrix t | 83 | ident :: (Num t, Field t) => Int -> Matrix t |
@@ -119,7 +119,7 @@ dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat | |||
119 | @\> flatten ('ident' 3) | 119 | @\> flatten ('ident' 3) |
120 | 9 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ | 120 | 9 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ |
121 | -} | 121 | -} |
122 | flatten :: Matrix t -> Vector t | 122 | flatten :: Field t => Matrix t -> Vector t |
123 | flatten = cdat | 123 | flatten = cdat |
124 | 124 | ||
125 | -- | Creates a 'Matrix' from a list of lists (considered as rows). | 125 | -- | Creates a 'Matrix' from a list of lists (considered as rows). |