summaryrefslogtreecommitdiff
path: root/lib/Data/Packed
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs131
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs2
-rw-r--r--lib/Data/Packed/Matrix.hs4
3 files changed, 101 insertions, 36 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 48652f3..6ba2d06 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -27,6 +27,10 @@ import Data.Maybe(fromJust)
27 27
28---------------------------------------------------------------- 28----------------------------------------------------------------
29 29
30-- the condition Storable a => Field a means that we can only put
31-- in Field types that are in Storable, and therefore Storable a
32-- is not required in signatures if we have a Field a.
33
30class Storable a => Field a where 34class Storable a => Field a where
31 constant :: a -> Int -> Vector a 35 constant :: a -> Int -> Vector a
32 transdata :: Int -> Vector a -> Int -> Vector a 36 transdata :: Int -> Vector a -> Int -> Vector a
@@ -36,7 +40,6 @@ class Storable a => Field a where
36 -> Matrix a -> Matrix a 40 -> Matrix a -> Matrix a
37 diag :: Vector a -> Matrix a 41 diag :: Vector a -> Matrix a
38 42
39
40instance Field Double where 43instance Field Double where
41 constant = constantR 44 constant = constantR
42 transdata = transdataR 45 transdata = transdataR
@@ -78,12 +81,40 @@ foreign import ccall safe "aux.h transC"
78 81
79transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d 82transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d
80 83
84{- Design considerations for the Matrix Type
85 -----------------------------------------
86
87- we must easily handle both row major and column major order,
88 for bindings to LAPACK and GSL/C
89
90- we'd like to simplify redundant matrix transposes:
91 - Some of them arise from the order requirements of some functions
92 - some functions (matrix product) admit transposed arguments
81 93
94- maybe we don't really need this kind of simplification:
95 - more complex code
96 - some computational overhead
97 - only appreciable gain in code with a lot of redundant transpositions
98 and cheap matrix computations
82 99
100- we could carry both the matrix and its (lazily computed) transpose.
101 This may save some transpositions, but it is necessary to keep track of the
102 data which is actually computed to be used by functions like the matrix product
103 which admit both orders. Therefore, maybe it is better to have something like
104 viewC and viewF, which may actually perform a transpose if required.
83 105
106- but if we need the transposed data and it is not in the structure, we must make
107 sure that we touch the same foreignptr that is used in the computation. Access
108 to such pointer cannot be made by creating a new vector.
109
110-}
84 111
85data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) 112data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
86 113
114{-
115
116
117
87data Matrix t = M { rows :: Int 118data Matrix t = M { rows :: Int
88 , cols :: Int 119 , cols :: Int
89 , dat :: Vector t 120 , dat :: Vector t
@@ -91,28 +122,26 @@ data Matrix t = M { rows :: Int
91 , isTrans :: Bool 122 , isTrans :: Bool
92 , order :: MatrixOrder 123 , order :: MatrixOrder
93 } -- deriving Typeable 124 } -- deriving Typeable
125-}
94 126
127data Matrix t = MC { rows :: Int, cols :: Int, dat :: Vector t } -- row major order
128 | MF { rows :: Int, cols :: Int, dat :: Vector t } -- column major order
95 129
96data NMat t = MC { rws, cls :: Int, dtc :: Vector t} 130-- transposition just changes the data order
97 | MF { rws, cls :: Int, dtf :: Vector t} 131trans :: Matrix t -> Matrix t
98 | Tr (NMat t) 132trans MC {rows = r, cols = c, dat = d} = MF {rows = c, cols = r, dat = d}
99 133trans MF {rows = r, cols = c, dat = d} = MC {rows = c, cols = r, dat = d}
100ntrans (Tr m) = m
101ntrans m = Tr m
102 134
103viewC m@MC{} = m 135viewC m@MC{} = m
104viewF m@MF{} = m 136viewC MF {rows = r, cols = c, dat = d} = MC {rows = r, cols = c, dat = transdata r d c}
105 137
106fortran m = order m == ColumnMajor 138viewF m@MF{} = m
139viewF MC {rows = r, cols = c, dat = d} = MF {rows = r, cols = c, dat = transdata c d r}
107 140
108cdat m = if fortran m `xor` isTrans m then tdat m else dat m 141--fortran m = order m == ColumnMajor
109fdat m = if fortran m `xor` isTrans m then dat m else tdat m
110 142
111trans :: Matrix t -> Matrix t 143cdat m = dat (viewC m)
112trans m = m { rows = cols m 144fdat m = dat (viewF m)
113 , cols = rows m
114 , isTrans = not (isTrans m)
115 }
116 145
117type Mt t s = Int -> Int -> Ptr t -> s 146type Mt t s = Int -> Int -> Ptr t -> s
118-- not yet admitted by my haddock version 147-- not yet admitted by my haddock version
@@ -120,11 +149,14 @@ type Mt t s = Int -> Int -> Ptr t -> s
120-- type t ::> s = Mt t s 149-- type t ::> s = Mt t s
121 150
122mat d m f = f (rows m) (cols m) (ptr (d m)) 151mat d m f = f (rows m) (cols m) (ptr (d m))
152--mat m f = f (rows m) (cols m) (ptr (dat m))
153--matC m f = f (rows m) (cols m) (ptr (cdat m))
154
123 155
124toLists :: (Storable t) => Matrix t -> [[t]] 156--toLists :: (Storable t) => Matrix t -> [[t]]
125toLists m = partit (cols m) . toList . cdat $ m 157toLists m = partit (cols m) . toList . cdat $ m
126 158
127instance (Show a, Storable a) => (Show (Matrix a)) where 159instance (Show a, Field a) => (Show (Matrix a)) where
128 show m = (sizes++) . dsp . map (map show) . toLists $ m 160 show m = (sizes++) . dsp . map (map show) . toLists $ m
129 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" 161 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n"
130 162
@@ -136,6 +168,7 @@ dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unw
136 pad n str = replicate (n - length str) ' ' ++ str 168 pad n str = replicate (n - length str) ' ' ++ str
137 unwords' = concat . intersperse ", " 169 unwords' = concat . intersperse ", "
138 170
171{-
139matrixFromVector RowMajor c v = 172matrixFromVector RowMajor c v =
140 M { rows = r 173 M { rows = r
141 , cols = c 174 , cols = c
@@ -147,8 +180,6 @@ matrixFromVector RowMajor c v =
147 r | m==0 = d 180 r | m==0 = d
148 | otherwise = error "matrixFromVector" 181 | otherwise = error "matrixFromVector"
149 182
150-- r = dim v `div` c -- TODO check mod=0
151
152matrixFromVector ColumnMajor c v = 183matrixFromVector ColumnMajor c v =
153 M { rows = r 184 M { rows = r
154 , cols = c 185 , cols = c
@@ -160,6 +191,23 @@ matrixFromVector ColumnMajor c v =
160 r | m==0 = d 191 r | m==0 = d
161 | otherwise = error "matrixFromVector" 192 | otherwise = error "matrixFromVector"
162 193
194-}
195
196matrixFromVector RowMajor c v = MC { rows = r, cols = c, dat = v}
197 where (d,m) = dim v `divMod` c
198 r | m==0 = d
199 | otherwise = error "matrixFromVector"
200
201matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, dat = v}
202 where (d,m) = dim v `divMod` c
203 r | m==0 = d
204 | otherwise = error "matrixFromVector"
205
206
207
208
209
210
163createMatrix order r c = do 211createMatrix order r c = do
164 p <- createVector (r*c) 212 p <- createVector (r*c)
165 return (matrixFromVector order c p) 213 return (matrixFromVector order c p)
@@ -178,10 +226,10 @@ reshape c v = matrixFromVector RowMajor c v
178 226
179singleton x = reshape 1 (fromList [x]) 227singleton x = reshape 1 (fromList [x])
180 228
181liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 229--liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
182liftMatrix f m = reshape (cols m) (f (cdat m)) 230liftMatrix f m = reshape (cols m) (f (cdat m))
183 231
184liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 232--liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
185liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) 233liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2))
186 | otherwise = error "nonconformant matrices in liftMatrix2" 234 | otherwise = error "nonconformant matrices in liftMatrix2"
187------------------------------------------------------------------ 235------------------------------------------------------------------
@@ -203,6 +251,7 @@ multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multipl
203 251
204------------------------------------------------------------------ 252------------------------------------------------------------------
205 253
254{-
206gmatC m f | fortran m = 255gmatC m f | fortran m =
207 if (isTrans m) 256 if (isTrans m)
208 then f 0 (rows m) (cols m) (ptr (dat m)) 257 then f 0 (rows m) (cols m) (ptr (dat m))
@@ -211,7 +260,11 @@ gmatC m f | fortran m =
211 if isTrans m 260 if isTrans m
212 then f 1 (cols m) (rows m) (ptr (dat m)) 261 then f 1 (cols m) (rows m) (ptr (dat m))
213 else f 0 (rows m) (cols m) (ptr (dat m)) 262 else f 0 (rows m) (cols m) (ptr (dat m))
263-}
214 264
265gmatC MF {rows = r, cols = c, dat = d} f = f 1 c r (ptr d)
266gmatC MC {rows = r, cols = c, dat = d} f = f 0 r c (ptr d)
267{-# INLINE gmatC #-}
215 268
216multiplyAux fun order a b = unsafePerformIO $ do 269multiplyAux fun order a b = unsafePerformIO $ do
217 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ 270 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
@@ -219,6 +272,7 @@ multiplyAux fun order a b = unsafePerformIO $ do
219 r <- createMatrix order (rows a) (cols b) 272 r <- createMatrix order (rows a) (cols b)
220 fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] 273 fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b]
221 return r 274 return r
275{-# INLINE multiplyAux #-}
222 276
223foreign import ccall safe "aux.h multiplyR" 277foreign import ccall safe "aux.h multiplyR"
224 cmultiplyR :: Int -> Int -> Int -> Ptr Double 278 cmultiplyR :: Int -> Int -> Int -> Ptr Double
@@ -234,13 +288,15 @@ foreign import ccall safe "aux.h multiplyC"
234 288
235multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a 289multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
236multiply RowMajor a b = multiplyD RowMajor a b 290multiply RowMajor a b = multiplyD RowMajor a b
237multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} 291multiply ColumnMajor a b = MF {rows = c, cols = r, dat = d}
238 where m = multiplyD RowMajor (trans b) (trans a) 292 where MC {rows = r, cols = c, dat = d } = multiplyD RowMajor (trans b) (trans a)
239 293
240 294
241multiplyR = multiplyAux cmultiplyR 295multiplyR = multiplyAux cmultiplyR'
242multiplyC = multiplyAux cmultiplyC 296multiplyC = multiplyAux cmultiplyC
243 297
298cmultiplyR' p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3 = {-# SCC "mulR" #-} cmultiplyR p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3
299
244---------------------------------------------------------------------- 300----------------------------------------------------------------------
245 301
246-- | extraction of a submatrix of a real matrix 302-- | extraction of a submatrix of a real matrix
@@ -249,7 +305,7 @@ subMatrixR :: (Int,Int) -- ^ (r0,c0) starting position
249 -> Matrix Double -> Matrix Double 305 -> Matrix Double -> Matrix Double
250subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do 306subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do
251 r <- createMatrix RowMajor rt ct 307 r <- createMatrix RowMajor rt ct
252 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r] 308 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r]
253 return r 309 return r
254foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM 310foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM
255 311
@@ -278,8 +334,8 @@ subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0
278 334
279diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do 335diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do
280 m <- createMatrix RowMajor n n 336 m <- createMatrix RowMajor n n
281 fun // vec v // mat dat m // check msg [dat m] 337 fun // vec v // mat cdat m // check msg [dat m]
282 return m {tdat = dat m} 338 return m -- {tdat = dat m}
283 339
284-- | diagonal matrix from a real vector 340-- | diagonal matrix from a real vector
285diagR :: Vector Double -> Matrix Double 341diagR :: Vector Double -> Matrix Double
@@ -305,13 +361,13 @@ diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..
305 | otherwise = 0 361 | otherwise = 0
306 362
307-- | creates a Matrix from a list of vectors 363-- | creates a Matrix from a list of vectors
308fromRows :: Field t => [Vector t] -> Matrix t 364--fromRows :: Field t => [Vector t] -> Matrix t
309fromRows vs = case common dim vs of 365fromRows vs = case common dim vs of
310 Nothing -> error "fromRows applied to [] or to vectors with different sizes" 366 Nothing -> error "fromRows applied to [] or to vectors with different sizes"
311 Just c -> reshape c (join vs) 367 Just c -> reshape c (join vs)
312 368
313-- | extracts the rows of a matrix as a list of vectors 369-- | extracts the rows of a matrix as a list of vectors
314toRows :: Storable t => Matrix t -> [Vector t] 370--toRows :: Storable t => Matrix t -> [Vector t]
315toRows m = toRows' 0 where 371toRows m = toRows' 0 where
316 v = cdat m 372 v = cdat m
317 r = rows m 373 r = rows m
@@ -324,16 +380,25 @@ fromColumns :: Field t => [Vector t] -> Matrix t
324fromColumns m = trans . fromRows $ m 380fromColumns m = trans . fromRows $ m
325 381
326-- | Creates a list of vectors from the columns of a matrix 382-- | Creates a list of vectors from the columns of a matrix
327toColumns :: Storable t => Matrix t -> [Vector t] 383toColumns :: Field t => Matrix t -> [Vector t]
328toColumns m = toRows . trans $ m 384toColumns m = toRows . trans $ m
329 385
330 386
331-- | Reads a matrix position. 387-- | Reads a matrix position.
332(@@>) :: Storable t => Matrix t -> (Int,Int) -> t 388(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
333infixl 9 @@> 389infixl 9 @@>
334m@M {rows = r, cols = c} @@> (i,j) 390--m@M {rows = r, cols = c} @@> (i,j)
391-- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range"
392-- | otherwise = cdat m `at` (i*c+j)
393
394MC {rows = r, cols = c, dat = v} @@> (i,j)
395 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range"
396 | otherwise = v `at` (i*c+j)
397
398MF {rows = r, cols = c, dat = v} @@> (i,j)
335 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" 399 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range"
336 | otherwise = cdat m `at` (i*c+j) 400 | otherwise = v `at` (j*r+i)
401
337 402
338------------------------------------------------------------------ 403------------------------------------------------------------------
339 404
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 6876685..dea1636 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -92,7 +92,7 @@ tensor dssig vec = T d v `withIdx` seqind where
92tensorFromVector :: IdxType -> Vector t -> Tensor t 92tensorFromVector :: IdxType -> Vector t -> Tensor t
93tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v} 93tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v}
94 94
95tensorFromMatrix :: IdxType -> IdxType -> Matrix t -> Tensor t 95tensorFromMatrix :: Field t => IdxType -> IdxType -> Matrix t -> Tensor t
96tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"] 96tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"]
97 , ten = cdat m} 97 , ten = cdat m}
98 98
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index 2e8cb3d..45aaaba 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -77,7 +77,7 @@ diagRect s r c
77 | r > c = joinVert [diag s , zeros (r-c,c)] 77 | r > c = joinVert [diag s , zeros (r-c,c)]
78 where zeros (r,c) = reshape c $ constant 0 (r*c) 78 where zeros (r,c) = reshape c $ constant 0 (r*c)
79 79
80takeDiag :: (Storable t) => Matrix t -> Vector t 80takeDiag :: (Field t) => Matrix t -> Vector t
81takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] 81takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]]
82 82
83ident :: (Num t, Field t) => Int -> Matrix t 83ident :: (Num t, Field t) => Int -> Matrix t
@@ -119,7 +119,7 @@ dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat
119@\> flatten ('ident' 3) 119@\> flatten ('ident' 3)
1209 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ 1209 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@
121-} 121-}
122flatten :: Matrix t -> Vector t 122flatten :: Field t => Matrix t -> Vector t
123flatten = cdat 123flatten = cdat
124 124
125-- | Creates a 'Matrix' from a list of lists (considered as rows). 125-- | Creates a 'Matrix' from a list of lists (considered as rows).