diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-05-04 21:08:51 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-05-04 21:08:51 +0200 |
commit | 4078cf44c98b42960be27843782f6983bb66017f (patch) | |
tree | bee20c3c811a98247aab99738991ab4b2bcc2312 /lib/Data/Packed/Internal | |
parent | ae104ebd5891c84f9c8b4a40501fefdeeb1280c4 (diff) |
allow empty arrays
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 26 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 4 |
2 files changed, 16 insertions, 14 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 8709a00..2004e85 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -198,16 +198,17 @@ atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) | |||
198 | 198 | ||
199 | ------------------------------------------------------------------ | 199 | ------------------------------------------------------------------ |
200 | 200 | ||
201 | matrixFromVector o c v = Matrix { irows = r, icols = c, xdat = v, order = o } | 201 | matrixFromVector o r c v |
202 | where (d,m) = dim v `quotRem` c | 202 | | r * c == dim v = m |
203 | r | m==0 = d | 203 | | otherwise = error $ "matrixFromVector " ++ shSize m ++ " <- " ++ show (dim v) |
204 | | otherwise = error "matrixFromVector" | 204 | where |
205 | m = Matrix { irows = r, icols = c, xdat = v, order = o } | ||
205 | 206 | ||
206 | -- allocates memory for a new matrix | 207 | -- allocates memory for a new matrix |
207 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) | 208 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) |
208 | createMatrix ord r c = do | 209 | createMatrix ord r c = do |
209 | p <- createVector (r*c) | 210 | p <- createVector (r*c) |
210 | return (matrixFromVector ord c p) | 211 | return (matrixFromVector ord r c p) |
211 | 212 | ||
212 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@ | 213 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@ |
213 | where r is the desired number of rows.) | 214 | where r is the desired number of rows.) |
@@ -220,21 +221,22 @@ where r is the desired number of rows.) | |||
220 | 221 | ||
221 | -} | 222 | -} |
222 | reshape :: Storable t => Int -> Vector t -> Matrix t | 223 | reshape :: Storable t => Int -> Vector t -> Matrix t |
223 | reshape c v = matrixFromVector RowMajor c v | 224 | reshape 0 v = matrixFromVector RowMajor 0 0 v |
225 | reshape c v = matrixFromVector RowMajor (dim v `div` c) c v | ||
224 | 226 | ||
225 | singleton x = reshape 1 (fromList [x]) | 227 | singleton x = reshape 1 (fromList [x]) |
226 | 228 | ||
227 | -- | application of a vector function on the flattened matrix elements | 229 | -- | application of a vector function on the flattened matrix elements |
228 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 230 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
229 | liftMatrix f Matrix { icols = c, xdat = d, order = o } = matrixFromVector o c (f d) | 231 | liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d) |
230 | 232 | ||
231 | -- | application of a vector function on the flattened matrices elements | 233 | -- | application of a vector function on the flattened matrices elements |
232 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 234 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
233 | liftMatrix2 f m1 m2 | 235 | liftMatrix2 f m1 m2 |
234 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" | 236 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" |
235 | | otherwise = case orderOf m1 of | 237 | | otherwise = case orderOf m1 of |
236 | RowMajor -> matrixFromVector RowMajor (cols m1) (f (xdat m1) (flatten m2)) | 238 | RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2)) |
237 | ColumnMajor -> matrixFromVector ColumnMajor (cols m1) (f (xdat m1) ((xdat.fmat) m2)) | 239 | ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2)) |
238 | 240 | ||
239 | 241 | ||
240 | compat :: Matrix a -> Matrix b -> Bool | 242 | compat :: Matrix a -> Matrix b -> Bool |
@@ -296,7 +298,7 @@ transdata' c1 v c2 = | |||
296 | return w | 298 | return w |
297 | where r1 = dim v `div` c1 | 299 | where r1 = dim v `div` c1 |
298 | r2 = dim v `div` c2 | 300 | r2 = dim v `div` c2 |
299 | noneed = r1 == 1 || c1 == 1 | 301 | noneed = dim v == 0 || r1 == 1 || c1 == 1 |
300 | 302 | ||
301 | -- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-} | 303 | -- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-} |
302 | -- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-} | 304 | -- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-} |
@@ -318,7 +320,7 @@ transdataAux fun c1 d c2 = | |||
318 | return v | 320 | return v |
319 | where r1 = dim d `div` c1 | 321 | where r1 = dim d `div` c1 |
320 | r2 = dim d `div` c2 | 322 | r2 = dim d `div` c2 |
321 | noneed = r1 == 1 || c1 == 1 | 323 | noneed = dim d == 0 || r1 == 1 || c1 == 1 |
322 | 324 | ||
323 | transdataP :: Storable a => Int -> Vector a -> Int -> Vector a | 325 | transdataP :: Storable a => Int -> Vector a -> Int -> Vector a |
324 | transdataP c1 d c2 = | 326 | transdataP c1 d c2 = |
@@ -333,7 +335,7 @@ transdataP c1 d c2 = | |||
333 | where r1 = dim d `div` c1 | 335 | where r1 = dim d `div` c1 |
334 | r2 = dim d `div` c2 | 336 | r2 = dim d `div` c2 |
335 | sz = sizeOf (d @> 0) | 337 | sz = sizeOf (d @> 0) |
336 | noneed = r1 == 1 || c1 == 1 | 338 | noneed = dim d == 0 || r1 == 1 || c1 == 1 |
337 | 339 | ||
338 | foreign import ccall unsafe "transF" ctransF :: TFMFM | 340 | foreign import ccall unsafe "transF" ctransF :: TFMFM |
339 | foreign import ccall unsafe "transR" ctransR :: TMM | 341 | foreign import ccall unsafe "transR" ctransR :: TMM |
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 415c972..6d03438 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -81,7 +81,7 @@ vec x f = unsafeWith x $ \p -> do | |||
81 | -- allocates memory for a new vector | 81 | -- allocates memory for a new vector |
82 | createVector :: Storable a => Int -> IO (Vector a) | 82 | createVector :: Storable a => Int -> IO (Vector a) |
83 | createVector n = do | 83 | createVector n = do |
84 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | 84 | when (n < 0) $ error ("trying to createVector of negative dim: "++show n) |
85 | fp <- doMalloc undefined | 85 | fp <- doMalloc undefined |
86 | return $ unsafeFromForeignPtr fp 0 n | 86 | return $ unsafeFromForeignPtr fp 0 n |
87 | where | 87 | where |
@@ -192,7 +192,7 @@ fromList [1.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0] | |||
192 | 192 | ||
193 | -} | 193 | -} |
194 | vjoin :: Storable t => [Vector t] -> Vector t | 194 | vjoin :: Storable t => [Vector t] -> Vector t |
195 | vjoin [] = error "vjoin zero vectors" | 195 | vjoin [] = fromList [] |
196 | vjoin [v] = v | 196 | vjoin [v] = v |
197 | vjoin as = unsafePerformIO $ do | 197 | vjoin as = unsafePerformIO $ do |
198 | let tot = sum (map dim as) | 198 | let tot = sum (map dim as) |