summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-04 21:08:51 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-04 21:08:51 +0200
commit4078cf44c98b42960be27843782f6983bb66017f (patch)
treebee20c3c811a98247aab99738991ab4b2bcc2312 /lib/Data/Packed/Internal
parentae104ebd5891c84f9c8b4a40501fefdeeb1280c4 (diff)
allow empty arrays
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs26
-rw-r--r--lib/Data/Packed/Internal/Vector.hs4
2 files changed, 16 insertions, 14 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 8709a00..2004e85 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -198,16 +198,17 @@ atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
198 198
199------------------------------------------------------------------ 199------------------------------------------------------------------
200 200
201matrixFromVector o c v = Matrix { irows = r, icols = c, xdat = v, order = o } 201matrixFromVector o r c v
202 where (d,m) = dim v `quotRem` c 202 | r * c == dim v = m
203 r | m==0 = d 203 | otherwise = error $ "matrixFromVector " ++ shSize m ++ " <- " ++ show (dim v)
204 | otherwise = error "matrixFromVector" 204 where
205 m = Matrix { irows = r, icols = c, xdat = v, order = o }
205 206
206-- allocates memory for a new matrix 207-- allocates memory for a new matrix
207createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) 208createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
208createMatrix ord r c = do 209createMatrix ord r c = do
209 p <- createVector (r*c) 210 p <- createVector (r*c)
210 return (matrixFromVector ord c p) 211 return (matrixFromVector ord r c p)
211 212
212{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@ 213{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@
213where r is the desired number of rows.) 214where r is the desired number of rows.)
@@ -220,21 +221,22 @@ where r is the desired number of rows.)
220 221
221-} 222-}
222reshape :: Storable t => Int -> Vector t -> Matrix t 223reshape :: Storable t => Int -> Vector t -> Matrix t
223reshape c v = matrixFromVector RowMajor c v 224reshape 0 v = matrixFromVector RowMajor 0 0 v
225reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
224 226
225singleton x = reshape 1 (fromList [x]) 227singleton x = reshape 1 (fromList [x])
226 228
227-- | application of a vector function on the flattened matrix elements 229-- | application of a vector function on the flattened matrix elements
228liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 230liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
229liftMatrix f Matrix { icols = c, xdat = d, order = o } = matrixFromVector o c (f d) 231liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d)
230 232
231-- | application of a vector function on the flattened matrices elements 233-- | application of a vector function on the flattened matrices elements
232liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 234liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
233liftMatrix2 f m1 m2 235liftMatrix2 f m1 m2
234 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" 236 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
235 | otherwise = case orderOf m1 of 237 | otherwise = case orderOf m1 of
236 RowMajor -> matrixFromVector RowMajor (cols m1) (f (xdat m1) (flatten m2)) 238 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2))
237 ColumnMajor -> matrixFromVector ColumnMajor (cols m1) (f (xdat m1) ((xdat.fmat) m2)) 239 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
238 240
239 241
240compat :: Matrix a -> Matrix b -> Bool 242compat :: Matrix a -> Matrix b -> Bool
@@ -296,7 +298,7 @@ transdata' c1 v c2 =
296 return w 298 return w
297 where r1 = dim v `div` c1 299 where r1 = dim v `div` c1
298 r2 = dim v `div` c2 300 r2 = dim v `div` c2
299 noneed = r1 == 1 || c1 == 1 301 noneed = dim v == 0 || r1 == 1 || c1 == 1
300 302
301-- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-} 303-- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-}
302-- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-} 304-- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-}
@@ -318,7 +320,7 @@ transdataAux fun c1 d c2 =
318 return v 320 return v
319 where r1 = dim d `div` c1 321 where r1 = dim d `div` c1
320 r2 = dim d `div` c2 322 r2 = dim d `div` c2
321 noneed = r1 == 1 || c1 == 1 323 noneed = dim d == 0 || r1 == 1 || c1 == 1
322 324
323transdataP :: Storable a => Int -> Vector a -> Int -> Vector a 325transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
324transdataP c1 d c2 = 326transdataP c1 d c2 =
@@ -333,7 +335,7 @@ transdataP c1 d c2 =
333 where r1 = dim d `div` c1 335 where r1 = dim d `div` c1
334 r2 = dim d `div` c2 336 r2 = dim d `div` c2
335 sz = sizeOf (d @> 0) 337 sz = sizeOf (d @> 0)
336 noneed = r1 == 1 || c1 == 1 338 noneed = dim d == 0 || r1 == 1 || c1 == 1
337 339
338foreign import ccall unsafe "transF" ctransF :: TFMFM 340foreign import ccall unsafe "transF" ctransF :: TFMFM
339foreign import ccall unsafe "transR" ctransR :: TMM 341foreign import ccall unsafe "transR" ctransR :: TMM
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index 415c972..6d03438 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -81,7 +81,7 @@ vec x f = unsafeWith x $ \p -> do
81-- allocates memory for a new vector 81-- allocates memory for a new vector
82createVector :: Storable a => Int -> IO (Vector a) 82createVector :: Storable a => Int -> IO (Vector a)
83createVector n = do 83createVector n = do
84 when (n <= 0) $ error ("trying to createVector of dim "++show n) 84 when (n < 0) $ error ("trying to createVector of negative dim: "++show n)
85 fp <- doMalloc undefined 85 fp <- doMalloc undefined
86 return $ unsafeFromForeignPtr fp 0 n 86 return $ unsafeFromForeignPtr fp 0 n
87 where 87 where
@@ -192,7 +192,7 @@ fromList [1.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0]
192 192
193-} 193-}
194vjoin :: Storable t => [Vector t] -> Vector t 194vjoin :: Storable t => [Vector t] -> Vector t
195vjoin [] = error "vjoin zero vectors" 195vjoin [] = fromList []
196vjoin [v] = v 196vjoin [v] = v
197vjoin as = unsafePerformIO $ do 197vjoin as = unsafePerformIO $ do
198 let tot = sum (map dim as) 198 let tot = sum (map dim as)