diff options
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 26 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 4 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 21 | ||||
-rw-r--r-- | lib/Numeric/Container.hs | 2 | ||||
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 77 | ||||
-rw-r--r-- | lib/Numeric/GSL/Vector.hs | 3 |
6 files changed, 81 insertions, 52 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 8709a00..2004e85 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -198,16 +198,17 @@ atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) | |||
198 | 198 | ||
199 | ------------------------------------------------------------------ | 199 | ------------------------------------------------------------------ |
200 | 200 | ||
201 | matrixFromVector o c v = Matrix { irows = r, icols = c, xdat = v, order = o } | 201 | matrixFromVector o r c v |
202 | where (d,m) = dim v `quotRem` c | 202 | | r * c == dim v = m |
203 | r | m==0 = d | 203 | | otherwise = error $ "matrixFromVector " ++ shSize m ++ " <- " ++ show (dim v) |
204 | | otherwise = error "matrixFromVector" | 204 | where |
205 | m = Matrix { irows = r, icols = c, xdat = v, order = o } | ||
205 | 206 | ||
206 | -- allocates memory for a new matrix | 207 | -- allocates memory for a new matrix |
207 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) | 208 | createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) |
208 | createMatrix ord r c = do | 209 | createMatrix ord r c = do |
209 | p <- createVector (r*c) | 210 | p <- createVector (r*c) |
210 | return (matrixFromVector ord c p) | 211 | return (matrixFromVector ord r c p) |
211 | 212 | ||
212 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@ | 213 | {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@ |
213 | where r is the desired number of rows.) | 214 | where r is the desired number of rows.) |
@@ -220,21 +221,22 @@ where r is the desired number of rows.) | |||
220 | 221 | ||
221 | -} | 222 | -} |
222 | reshape :: Storable t => Int -> Vector t -> Matrix t | 223 | reshape :: Storable t => Int -> Vector t -> Matrix t |
223 | reshape c v = matrixFromVector RowMajor c v | 224 | reshape 0 v = matrixFromVector RowMajor 0 0 v |
225 | reshape c v = matrixFromVector RowMajor (dim v `div` c) c v | ||
224 | 226 | ||
225 | singleton x = reshape 1 (fromList [x]) | 227 | singleton x = reshape 1 (fromList [x]) |
226 | 228 | ||
227 | -- | application of a vector function on the flattened matrix elements | 229 | -- | application of a vector function on the flattened matrix elements |
228 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 230 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
229 | liftMatrix f Matrix { icols = c, xdat = d, order = o } = matrixFromVector o c (f d) | 231 | liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d) |
230 | 232 | ||
231 | -- | application of a vector function on the flattened matrices elements | 233 | -- | application of a vector function on the flattened matrices elements |
232 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 234 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
233 | liftMatrix2 f m1 m2 | 235 | liftMatrix2 f m1 m2 |
234 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" | 236 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" |
235 | | otherwise = case orderOf m1 of | 237 | | otherwise = case orderOf m1 of |
236 | RowMajor -> matrixFromVector RowMajor (cols m1) (f (xdat m1) (flatten m2)) | 238 | RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2)) |
237 | ColumnMajor -> matrixFromVector ColumnMajor (cols m1) (f (xdat m1) ((xdat.fmat) m2)) | 239 | ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2)) |
238 | 240 | ||
239 | 241 | ||
240 | compat :: Matrix a -> Matrix b -> Bool | 242 | compat :: Matrix a -> Matrix b -> Bool |
@@ -296,7 +298,7 @@ transdata' c1 v c2 = | |||
296 | return w | 298 | return w |
297 | where r1 = dim v `div` c1 | 299 | where r1 = dim v `div` c1 |
298 | r2 = dim v `div` c2 | 300 | r2 = dim v `div` c2 |
299 | noneed = r1 == 1 || c1 == 1 | 301 | noneed = dim v == 0 || r1 == 1 || c1 == 1 |
300 | 302 | ||
301 | -- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-} | 303 | -- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-} |
302 | -- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-} | 304 | -- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-} |
@@ -318,7 +320,7 @@ transdataAux fun c1 d c2 = | |||
318 | return v | 320 | return v |
319 | where r1 = dim d `div` c1 | 321 | where r1 = dim d `div` c1 |
320 | r2 = dim d `div` c2 | 322 | r2 = dim d `div` c2 |
321 | noneed = r1 == 1 || c1 == 1 | 323 | noneed = dim d == 0 || r1 == 1 || c1 == 1 |
322 | 324 | ||
323 | transdataP :: Storable a => Int -> Vector a -> Int -> Vector a | 325 | transdataP :: Storable a => Int -> Vector a -> Int -> Vector a |
324 | transdataP c1 d c2 = | 326 | transdataP c1 d c2 = |
@@ -333,7 +335,7 @@ transdataP c1 d c2 = | |||
333 | where r1 = dim d `div` c1 | 335 | where r1 = dim d `div` c1 |
334 | r2 = dim d `div` c2 | 336 | r2 = dim d `div` c2 |
335 | sz = sizeOf (d @> 0) | 337 | sz = sizeOf (d @> 0) |
336 | noneed = r1 == 1 || c1 == 1 | 338 | noneed = dim d == 0 || r1 == 1 || c1 == 1 |
337 | 339 | ||
338 | foreign import ccall unsafe "transF" ctransF :: TFMFM | 340 | foreign import ccall unsafe "transF" ctransF :: TFMFM |
339 | foreign import ccall unsafe "transR" ctransR :: TMM | 341 | foreign import ccall unsafe "transR" ctransR :: TMM |
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 415c972..6d03438 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -81,7 +81,7 @@ vec x f = unsafeWith x $ \p -> do | |||
81 | -- allocates memory for a new vector | 81 | -- allocates memory for a new vector |
82 | createVector :: Storable a => Int -> IO (Vector a) | 82 | createVector :: Storable a => Int -> IO (Vector a) |
83 | createVector n = do | 83 | createVector n = do |
84 | when (n <= 0) $ error ("trying to createVector of dim "++show n) | 84 | when (n < 0) $ error ("trying to createVector of negative dim: "++show n) |
85 | fp <- doMalloc undefined | 85 | fp <- doMalloc undefined |
86 | return $ unsafeFromForeignPtr fp 0 n | 86 | return $ unsafeFromForeignPtr fp 0 n |
87 | where | 87 | where |
@@ -192,7 +192,7 @@ fromList [1.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0] | |||
192 | 192 | ||
193 | -} | 193 | -} |
194 | vjoin :: Storable t => [Vector t] -> Vector t | 194 | vjoin :: Storable t => [Vector t] -> Vector t |
195 | vjoin [] = error "vjoin zero vectors" | 195 | vjoin [] = fromList [] |
196 | vjoin [v] = v | 196 | vjoin [v] = v |
197 | vjoin as = unsafePerformIO $ do | 197 | vjoin as = unsafePerformIO $ do |
198 | let tot = sum (map dim as) | 198 | let tot = sum (map dim as) |
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index f72bd15..b92d60f 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -74,8 +74,10 @@ instance (Binary a, Element a, Storable a) => Binary (Matrix a) where | |||
74 | ------------------------------------------------------------------- | 74 | ------------------------------------------------------------------- |
75 | 75 | ||
76 | instance (Show a, Element a) => (Show (Matrix a)) where | 76 | instance (Show a, Element a) => (Show (Matrix a)) where |
77 | show m = (sizes++) . dsp . map (map show) . toLists $ m | 77 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" |
78 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | 78 | show m = (sizes m++) . dsp . map (map show) . toLists $ m |
79 | |||
80 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" | ||
79 | 81 | ||
80 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | 82 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp |
81 | where | 83 | where |
@@ -104,7 +106,7 @@ breakAt c l = (a++[c],tail b) where | |||
104 | joinVert :: Element t => [Matrix t] -> Matrix t | 106 | joinVert :: Element t => [Matrix t] -> Matrix t |
105 | joinVert ms = case common cols ms of | 107 | joinVert ms = case common cols ms of |
106 | Nothing -> error "(impossible) joinVert on matrices with different number of columns" | 108 | Nothing -> error "(impossible) joinVert on matrices with different number of columns" |
107 | Just c -> reshape c $ vjoin (map flatten ms) | 109 | Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) |
108 | 110 | ||
109 | -- | creates a matrix from a horizontal list of matrices | 111 | -- | creates a matrix from a horizontal list of matrices |
110 | joinHoriz :: Element t => [Matrix t] -> Matrix t | 112 | joinHoriz :: Element t => [Matrix t] -> Matrix t |
@@ -147,7 +149,7 @@ adaptBlocks ms = ms' where | |||
147 | 149 | ||
148 | g [Just nr,Just nc] m | 150 | g [Just nr,Just nc] m |
149 | | nr == r && nc == c = m | 151 | | nr == r && nc == c = m |
150 | | r == 1 && c == 1 = reshape nc (constantD x (nr*nc)) | 152 | | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc)) |
151 | | r == 1 = fromRows (replicate nr (flatten m)) | 153 | | r == 1 = fromRows (replicate nr (flatten m)) |
152 | | otherwise = fromColumns (replicate nc (flatten m)) | 154 | | otherwise = fromColumns (replicate nc (flatten m)) |
153 | where | 155 | where |
@@ -237,7 +239,7 @@ safely be used with lists that are too long (like infinite lists). | |||
237 | -} | 239 | -} |
238 | (><) :: (Storable a) => Int -> Int -> [a] -> Matrix a | 240 | (><) :: (Storable a) => Int -> Int -> [a] -> Matrix a |
239 | r >< c = f where | 241 | r >< c = f where |
240 | f l | dim v == r*c = matrixFromVector RowMajor c v | 242 | f l | dim v == r*c = matrixFromVector RowMajor r c v |
241 | | otherwise = error $ "inconsistent list size = " | 243 | | otherwise = error $ "inconsistent list size = " |
242 | ++show (dim v) ++" in ("++show r++"><"++show c++")" | 244 | ++show (dim v) ++" in ("++show r++"><"++show c++")" |
243 | where v = fromList $ take (r*c) l | 245 | where v = fromList $ take (r*c) l |
@@ -291,7 +293,7 @@ asRow v = reshape (dim v) v | |||
291 | -- , 5.0 ] | 293 | -- , 5.0 ] |
292 | -- | 294 | -- |
293 | asColumn :: Storable a => Vector a -> Matrix a | 295 | asColumn :: Storable a => Vector a -> Matrix a |
294 | asColumn v = reshape 1 v | 296 | asColumn = trans . asRow |
295 | 297 | ||
296 | 298 | ||
297 | 299 | ||
@@ -358,7 +360,12 @@ liftMatrix2Auto f m1 m2 | |||
358 | m1' = conformMTo (r,c) m1 | 360 | m1' = conformMTo (r,c) m1 |
359 | m2' = conformMTo (r,c) m2 | 361 | m2' = conformMTo (r,c) m2 |
360 | 362 | ||
361 | lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) | 363 | -- FIXME do not flatten if equal order |
364 | lM f m1 m2 = matrixFromVector | ||
365 | RowMajor | ||
366 | (max (rows m1) (rows m2)) | ||
367 | (max (cols m1) (cols m2)) | ||
368 | (f (flatten m1) (flatten m2)) | ||
362 | 369 | ||
363 | compat' :: Matrix a -> Matrix b -> Bool | 370 | compat' :: Matrix a -> Matrix b -> Bool |
364 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | 371 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 |
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs index a71fdfe..b145a26 100644 --- a/lib/Numeric/Container.hs +++ b/lib/Numeric/Container.hs | |||
@@ -36,7 +36,7 @@ module Numeric.Container ( | |||
36 | -- * Generic operations | 36 | -- * Generic operations |
37 | Container(..), | 37 | Container(..), |
38 | -- * Matrix product | 38 | -- * Matrix product |
39 | Product(..), | 39 | Product(..), udot, |
40 | Mul(..), | 40 | Mul(..), |
41 | Contraction(..), mmul, | 41 | Contraction(..), mmul, |
42 | optimiseMult, | 42 | optimiseMult, |
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index a333489..6445e04 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -25,7 +25,7 @@ module Numeric.ContainerBoot ( | |||
25 | -- * Generic operations | 25 | -- * Generic operations |
26 | Container(..), | 26 | Container(..), |
27 | -- * Matrix product and related functions | 27 | -- * Matrix product and related functions |
28 | Product(..), | 28 | Product(..), udot, |
29 | mXm,mXv,vXm, | 29 | mXm,mXv,vXm, |
30 | outer, kronecker, | 30 | outer, kronecker, |
31 | -- * Element conversion | 31 | -- * Element conversion |
@@ -315,7 +315,7 @@ instance (Container Vector a) => Container Matrix a where | |||
315 | equal a b = cols a == cols b && flatten a `equal` flatten b | 315 | equal a b = cols a == cols b && flatten a `equal` flatten b |
316 | arctan2 = liftMatrix2 arctan2 | 316 | arctan2 = liftMatrix2 arctan2 |
317 | scalar x = (1><1) [x] | 317 | scalar x = (1><1) [x] |
318 | konst' v (r,c) = reshape c (konst' v (r*c)) | 318 | konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) |
319 | build' = buildM | 319 | build' = buildM |
320 | conj = liftMatrix conj | 320 | conj = liftMatrix conj |
321 | cmap f = liftMatrix (mapVector f) | 321 | cmap f = liftMatrix (mapVector f) |
@@ -339,11 +339,9 @@ instance (Container Vector a) => Container Matrix a where | |||
339 | ---------------------------------------------------- | 339 | ---------------------------------------------------- |
340 | 340 | ||
341 | -- | Matrix product and related functions | 341 | -- | Matrix product and related functions |
342 | class Element e => Product e where | 342 | class (Num e, Element e) => Product e where |
343 | -- | matrix product | 343 | -- | matrix product |
344 | multiply :: Matrix e -> Matrix e -> Matrix e | 344 | multiply :: Matrix e -> Matrix e -> Matrix e |
345 | -- | (unconjugated) dot product | ||
346 | udot :: Vector e -> Vector e -> e | ||
347 | -- | sum of absolute value of elements (differs in complex case from @norm1@) | 345 | -- | sum of absolute value of elements (differs in complex case from @norm1@) |
348 | absSum :: Vector e -> RealOf e | 346 | absSum :: Vector e -> RealOf e |
349 | -- | sum of absolute value of elements | 347 | -- | sum of absolute value of elements |
@@ -354,36 +352,57 @@ class Element e => Product e where | |||
354 | normInf :: Vector e -> RealOf e | 352 | normInf :: Vector e -> RealOf e |
355 | 353 | ||
356 | instance Product Float where | 354 | instance Product Float where |
357 | norm2 = toScalarF Norm2 | 355 | norm2 = emptyVal (toScalarF Norm2) |
358 | absSum = toScalarF AbsSum | 356 | absSum = emptyVal (toScalarF AbsSum) |
359 | udot = dotF | 357 | norm1 = emptyVal (toScalarF AbsSum) |
360 | norm1 = toScalarF AbsSum | 358 | normInf = emptyVal (maxElement . vectorMapF Abs) |
361 | normInf = maxElement . vectorMapF Abs | 359 | multiply = emptyMul multiplyF |
362 | multiply = multiplyF | ||
363 | 360 | ||
364 | instance Product Double where | 361 | instance Product Double where |
365 | norm2 = toScalarR Norm2 | 362 | norm2 = emptyVal (toScalarR Norm2) |
366 | absSum = toScalarR AbsSum | 363 | absSum = emptyVal (toScalarR AbsSum) |
367 | udot = dotR | 364 | norm1 = emptyVal (toScalarR AbsSum) |
368 | norm1 = toScalarR AbsSum | 365 | normInf = emptyVal (maxElement . vectorMapR Abs) |
369 | normInf = maxElement . vectorMapR Abs | 366 | multiply = emptyMul multiplyR |
370 | multiply = multiplyR | ||
371 | 367 | ||
372 | instance Product (Complex Float) where | 368 | instance Product (Complex Float) where |
373 | norm2 = toScalarQ Norm2 | 369 | norm2 = emptyVal (toScalarQ Norm2) |
374 | absSum = toScalarQ AbsSum | 370 | absSum = emptyVal (toScalarQ AbsSum) |
375 | udot = dotQ | 371 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs) |
376 | norm1 = sumElements . fst . fromComplex . vectorMapQ Abs | 372 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs) |
377 | normInf = maxElement . fst . fromComplex . vectorMapQ Abs | 373 | multiply = emptyMul multiplyQ |
378 | multiply = multiplyQ | ||
379 | 374 | ||
380 | instance Product (Complex Double) where | 375 | instance Product (Complex Double) where |
381 | norm2 = toScalarC Norm2 | 376 | norm2 = emptyVal (toScalarC Norm2) |
382 | absSum = toScalarC AbsSum | 377 | absSum = emptyVal (toScalarC AbsSum) |
383 | udot = dotC | 378 | norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs) |
384 | norm1 = sumElements . fst . fromComplex . vectorMapC Abs | 379 | normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs) |
385 | normInf = maxElement . fst . fromComplex . vectorMapC Abs | 380 | multiply = emptyMul multiplyC |
386 | multiply = multiplyC | 381 | |
382 | emptyMul m a b | ||
383 | | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c) | ||
384 | | otherwise = m a b | ||
385 | where | ||
386 | r = rows a | ||
387 | x1 = cols a | ||
388 | x2 = rows b | ||
389 | c = cols b | ||
390 | |||
391 | emptyVal f v = | ||
392 | if dim v > 0 | ||
393 | then f v | ||
394 | else 0 | ||
395 | |||
396 | |||
397 | -- FIXME remove unused C wrappers | ||
398 | -- | (unconjugated) dot product | ||
399 | udot :: Product e => Vector e -> Vector e -> e | ||
400 | udot u v | ||
401 | | dim u == dim v = val (asRow u `multiply` asColumn v) | ||
402 | | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product" | ||
403 | where | ||
404 | val m | dim u > 0 = m@@>(0,0) | ||
405 | | otherwise = 0 | ||
387 | 406 | ||
388 | ---------------------------------------------------------- | 407 | ---------------------------------------------------------- |
389 | 408 | ||
diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs index db34041..6204b8e 100644 --- a/lib/Numeric/GSL/Vector.hs +++ b/lib/Numeric/GSL/Vector.hs | |||
@@ -33,6 +33,7 @@ import Foreign.Marshal.Array(newArray) | |||
33 | import Foreign.Ptr(Ptr) | 33 | import Foreign.Ptr(Ptr) |
34 | import Foreign.C.Types | 34 | import Foreign.C.Types |
35 | import System.IO.Unsafe(unsafePerformIO) | 35 | import System.IO.Unsafe(unsafePerformIO) |
36 | import Control.Monad(when) | ||
36 | 37 | ||
37 | fromei x = fromIntegral (fromEnum x) :: CInt | 38 | fromei x = fromIntegral (fromEnum x) :: CInt |
38 | 39 | ||
@@ -201,7 +202,7 @@ vectorMapValAux fun code val v = unsafePerformIO $ do | |||
201 | 202 | ||
202 | vectorZipAux fun code u v = unsafePerformIO $ do | 203 | vectorZipAux fun code u v = unsafePerformIO $ do |
203 | r <- createVector (dim u) | 204 | r <- createVector (dim u) |
204 | app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" | 205 | when (dim u > 0) $ app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" |
205 | return r | 206 | return r |
206 | 207 | ||
207 | --------------------------------------------------------------------- | 208 | --------------------------------------------------------------------- |