summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs26
-rw-r--r--lib/Data/Packed/Internal/Vector.hs4
-rw-r--r--lib/Data/Packed/Matrix.hs21
-rw-r--r--lib/Numeric/Container.hs2
-rw-r--r--lib/Numeric/ContainerBoot.hs77
-rw-r--r--lib/Numeric/GSL/Vector.hs3
6 files changed, 81 insertions, 52 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 8709a00..2004e85 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -198,16 +198,17 @@ atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
198 198
199------------------------------------------------------------------ 199------------------------------------------------------------------
200 200
201matrixFromVector o c v = Matrix { irows = r, icols = c, xdat = v, order = o } 201matrixFromVector o r c v
202 where (d,m) = dim v `quotRem` c 202 | r * c == dim v = m
203 r | m==0 = d 203 | otherwise = error $ "matrixFromVector " ++ shSize m ++ " <- " ++ show (dim v)
204 | otherwise = error "matrixFromVector" 204 where
205 m = Matrix { irows = r, icols = c, xdat = v, order = o }
205 206
206-- allocates memory for a new matrix 207-- allocates memory for a new matrix
207createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) 208createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
208createMatrix ord r c = do 209createMatrix ord r c = do
209 p <- createVector (r*c) 210 p <- createVector (r*c)
210 return (matrixFromVector ord c p) 211 return (matrixFromVector ord r c p)
211 212
212{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@ 213{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@
213where r is the desired number of rows.) 214where r is the desired number of rows.)
@@ -220,21 +221,22 @@ where r is the desired number of rows.)
220 221
221-} 222-}
222reshape :: Storable t => Int -> Vector t -> Matrix t 223reshape :: Storable t => Int -> Vector t -> Matrix t
223reshape c v = matrixFromVector RowMajor c v 224reshape 0 v = matrixFromVector RowMajor 0 0 v
225reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
224 226
225singleton x = reshape 1 (fromList [x]) 227singleton x = reshape 1 (fromList [x])
226 228
227-- | application of a vector function on the flattened matrix elements 229-- | application of a vector function on the flattened matrix elements
228liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 230liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
229liftMatrix f Matrix { icols = c, xdat = d, order = o } = matrixFromVector o c (f d) 231liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d)
230 232
231-- | application of a vector function on the flattened matrices elements 233-- | application of a vector function on the flattened matrices elements
232liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 234liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
233liftMatrix2 f m1 m2 235liftMatrix2 f m1 m2
234 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" 236 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
235 | otherwise = case orderOf m1 of 237 | otherwise = case orderOf m1 of
236 RowMajor -> matrixFromVector RowMajor (cols m1) (f (xdat m1) (flatten m2)) 238 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2))
237 ColumnMajor -> matrixFromVector ColumnMajor (cols m1) (f (xdat m1) ((xdat.fmat) m2)) 239 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
238 240
239 241
240compat :: Matrix a -> Matrix b -> Bool 242compat :: Matrix a -> Matrix b -> Bool
@@ -296,7 +298,7 @@ transdata' c1 v c2 =
296 return w 298 return w
297 where r1 = dim v `div` c1 299 where r1 = dim v `div` c1
298 r2 = dim v `div` c2 300 r2 = dim v `div` c2
299 noneed = r1 == 1 || c1 == 1 301 noneed = dim v == 0 || r1 == 1 || c1 == 1
300 302
301-- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-} 303-- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-}
302-- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-} 304-- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-}
@@ -318,7 +320,7 @@ transdataAux fun c1 d c2 =
318 return v 320 return v
319 where r1 = dim d `div` c1 321 where r1 = dim d `div` c1
320 r2 = dim d `div` c2 322 r2 = dim d `div` c2
321 noneed = r1 == 1 || c1 == 1 323 noneed = dim d == 0 || r1 == 1 || c1 == 1
322 324
323transdataP :: Storable a => Int -> Vector a -> Int -> Vector a 325transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
324transdataP c1 d c2 = 326transdataP c1 d c2 =
@@ -333,7 +335,7 @@ transdataP c1 d c2 =
333 where r1 = dim d `div` c1 335 where r1 = dim d `div` c1
334 r2 = dim d `div` c2 336 r2 = dim d `div` c2
335 sz = sizeOf (d @> 0) 337 sz = sizeOf (d @> 0)
336 noneed = r1 == 1 || c1 == 1 338 noneed = dim d == 0 || r1 == 1 || c1 == 1
337 339
338foreign import ccall unsafe "transF" ctransF :: TFMFM 340foreign import ccall unsafe "transF" ctransF :: TFMFM
339foreign import ccall unsafe "transR" ctransR :: TMM 341foreign import ccall unsafe "transR" ctransR :: TMM
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index 415c972..6d03438 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -81,7 +81,7 @@ vec x f = unsafeWith x $ \p -> do
81-- allocates memory for a new vector 81-- allocates memory for a new vector
82createVector :: Storable a => Int -> IO (Vector a) 82createVector :: Storable a => Int -> IO (Vector a)
83createVector n = do 83createVector n = do
84 when (n <= 0) $ error ("trying to createVector of dim "++show n) 84 when (n < 0) $ error ("trying to createVector of negative dim: "++show n)
85 fp <- doMalloc undefined 85 fp <- doMalloc undefined
86 return $ unsafeFromForeignPtr fp 0 n 86 return $ unsafeFromForeignPtr fp 0 n
87 where 87 where
@@ -192,7 +192,7 @@ fromList [1.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0]
192 192
193-} 193-}
194vjoin :: Storable t => [Vector t] -> Vector t 194vjoin :: Storable t => [Vector t] -> Vector t
195vjoin [] = error "vjoin zero vectors" 195vjoin [] = fromList []
196vjoin [v] = v 196vjoin [v] = v
197vjoin as = unsafePerformIO $ do 197vjoin as = unsafePerformIO $ do
198 let tot = sum (map dim as) 198 let tot = sum (map dim as)
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index f72bd15..b92d60f 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -74,8 +74,10 @@ instance (Binary a, Element a, Storable a) => Binary (Matrix a) where
74------------------------------------------------------------------- 74-------------------------------------------------------------------
75 75
76instance (Show a, Element a) => (Show (Matrix a)) where 76instance (Show a, Element a) => (Show (Matrix a)) where
77 show m = (sizes++) . dsp . map (map show) . toLists $ m 77 show m | rows m == 0 || cols m == 0 = sizes m ++" []"
78 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" 78 show m = (sizes m++) . dsp . map (map show) . toLists $ m
79
80sizes m = "("++show (rows m)++"><"++show (cols m)++")\n"
79 81
80dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp 82dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
81 where 83 where
@@ -104,7 +106,7 @@ breakAt c l = (a++[c],tail b) where
104joinVert :: Element t => [Matrix t] -> Matrix t 106joinVert :: Element t => [Matrix t] -> Matrix t
105joinVert ms = case common cols ms of 107joinVert ms = case common cols ms of
106 Nothing -> error "(impossible) joinVert on matrices with different number of columns" 108 Nothing -> error "(impossible) joinVert on matrices with different number of columns"
107 Just c -> reshape c $ vjoin (map flatten ms) 109 Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms)
108 110
109-- | creates a matrix from a horizontal list of matrices 111-- | creates a matrix from a horizontal list of matrices
110joinHoriz :: Element t => [Matrix t] -> Matrix t 112joinHoriz :: Element t => [Matrix t] -> Matrix t
@@ -147,7 +149,7 @@ adaptBlocks ms = ms' where
147 149
148 g [Just nr,Just nc] m 150 g [Just nr,Just nc] m
149 | nr == r && nc == c = m 151 | nr == r && nc == c = m
150 | r == 1 && c == 1 = reshape nc (constantD x (nr*nc)) 152 | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc))
151 | r == 1 = fromRows (replicate nr (flatten m)) 153 | r == 1 = fromRows (replicate nr (flatten m))
152 | otherwise = fromColumns (replicate nc (flatten m)) 154 | otherwise = fromColumns (replicate nc (flatten m))
153 where 155 where
@@ -237,7 +239,7 @@ safely be used with lists that are too long (like infinite lists).
237-} 239-}
238(><) :: (Storable a) => Int -> Int -> [a] -> Matrix a 240(><) :: (Storable a) => Int -> Int -> [a] -> Matrix a
239r >< c = f where 241r >< c = f where
240 f l | dim v == r*c = matrixFromVector RowMajor c v 242 f l | dim v == r*c = matrixFromVector RowMajor r c v
241 | otherwise = error $ "inconsistent list size = " 243 | otherwise = error $ "inconsistent list size = "
242 ++show (dim v) ++" in ("++show r++"><"++show c++")" 244 ++show (dim v) ++" in ("++show r++"><"++show c++")"
243 where v = fromList $ take (r*c) l 245 where v = fromList $ take (r*c) l
@@ -291,7 +293,7 @@ asRow v = reshape (dim v) v
291-- , 5.0 ] 293-- , 5.0 ]
292-- 294--
293asColumn :: Storable a => Vector a -> Matrix a 295asColumn :: Storable a => Vector a -> Matrix a
294asColumn v = reshape 1 v 296asColumn = trans . asRow
295 297
296 298
297 299
@@ -358,7 +360,12 @@ liftMatrix2Auto f m1 m2
358 m1' = conformMTo (r,c) m1 360 m1' = conformMTo (r,c) m1
359 m2' = conformMTo (r,c) m2 361 m2' = conformMTo (r,c) m2
360 362
361lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) 363-- FIXME do not flatten if equal order
364lM f m1 m2 = matrixFromVector
365 RowMajor
366 (max (rows m1) (rows m2))
367 (max (cols m1) (cols m2))
368 (f (flatten m1) (flatten m2))
362 369
363compat' :: Matrix a -> Matrix b -> Bool 370compat' :: Matrix a -> Matrix b -> Bool
364compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 371compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs
index a71fdfe..b145a26 100644
--- a/lib/Numeric/Container.hs
+++ b/lib/Numeric/Container.hs
@@ -36,7 +36,7 @@ module Numeric.Container (
36 -- * Generic operations 36 -- * Generic operations
37 Container(..), 37 Container(..),
38 -- * Matrix product 38 -- * Matrix product
39 Product(..), 39 Product(..), udot,
40 Mul(..), 40 Mul(..),
41 Contraction(..), mmul, 41 Contraction(..), mmul,
42 optimiseMult, 42 optimiseMult,
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index a333489..6445e04 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -25,7 +25,7 @@ module Numeric.ContainerBoot (
25 -- * Generic operations 25 -- * Generic operations
26 Container(..), 26 Container(..),
27 -- * Matrix product and related functions 27 -- * Matrix product and related functions
28 Product(..), 28 Product(..), udot,
29 mXm,mXv,vXm, 29 mXm,mXv,vXm,
30 outer, kronecker, 30 outer, kronecker,
31 -- * Element conversion 31 -- * Element conversion
@@ -315,7 +315,7 @@ instance (Container Vector a) => Container Matrix a where
315 equal a b = cols a == cols b && flatten a `equal` flatten b 315 equal a b = cols a == cols b && flatten a `equal` flatten b
316 arctan2 = liftMatrix2 arctan2 316 arctan2 = liftMatrix2 arctan2
317 scalar x = (1><1) [x] 317 scalar x = (1><1) [x]
318 konst' v (r,c) = reshape c (konst' v (r*c)) 318 konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c))
319 build' = buildM 319 build' = buildM
320 conj = liftMatrix conj 320 conj = liftMatrix conj
321 cmap f = liftMatrix (mapVector f) 321 cmap f = liftMatrix (mapVector f)
@@ -339,11 +339,9 @@ instance (Container Vector a) => Container Matrix a where
339---------------------------------------------------- 339----------------------------------------------------
340 340
341-- | Matrix product and related functions 341-- | Matrix product and related functions
342class Element e => Product e where 342class (Num e, Element e) => Product e where
343 -- | matrix product 343 -- | matrix product
344 multiply :: Matrix e -> Matrix e -> Matrix e 344 multiply :: Matrix e -> Matrix e -> Matrix e
345 -- | (unconjugated) dot product
346 udot :: Vector e -> Vector e -> e
347 -- | sum of absolute value of elements (differs in complex case from @norm1@) 345 -- | sum of absolute value of elements (differs in complex case from @norm1@)
348 absSum :: Vector e -> RealOf e 346 absSum :: Vector e -> RealOf e
349 -- | sum of absolute value of elements 347 -- | sum of absolute value of elements
@@ -354,36 +352,57 @@ class Element e => Product e where
354 normInf :: Vector e -> RealOf e 352 normInf :: Vector e -> RealOf e
355 353
356instance Product Float where 354instance Product Float where
357 norm2 = toScalarF Norm2 355 norm2 = emptyVal (toScalarF Norm2)
358 absSum = toScalarF AbsSum 356 absSum = emptyVal (toScalarF AbsSum)
359 udot = dotF 357 norm1 = emptyVal (toScalarF AbsSum)
360 norm1 = toScalarF AbsSum 358 normInf = emptyVal (maxElement . vectorMapF Abs)
361 normInf = maxElement . vectorMapF Abs 359 multiply = emptyMul multiplyF
362 multiply = multiplyF
363 360
364instance Product Double where 361instance Product Double where
365 norm2 = toScalarR Norm2 362 norm2 = emptyVal (toScalarR Norm2)
366 absSum = toScalarR AbsSum 363 absSum = emptyVal (toScalarR AbsSum)
367 udot = dotR 364 norm1 = emptyVal (toScalarR AbsSum)
368 norm1 = toScalarR AbsSum 365 normInf = emptyVal (maxElement . vectorMapR Abs)
369 normInf = maxElement . vectorMapR Abs 366 multiply = emptyMul multiplyR
370 multiply = multiplyR
371 367
372instance Product (Complex Float) where 368instance Product (Complex Float) where
373 norm2 = toScalarQ Norm2 369 norm2 = emptyVal (toScalarQ Norm2)
374 absSum = toScalarQ AbsSum 370 absSum = emptyVal (toScalarQ AbsSum)
375 udot = dotQ 371 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs)
376 norm1 = sumElements . fst . fromComplex . vectorMapQ Abs 372 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs)
377 normInf = maxElement . fst . fromComplex . vectorMapQ Abs 373 multiply = emptyMul multiplyQ
378 multiply = multiplyQ
379 374
380instance Product (Complex Double) where 375instance Product (Complex Double) where
381 norm2 = toScalarC Norm2 376 norm2 = emptyVal (toScalarC Norm2)
382 absSum = toScalarC AbsSum 377 absSum = emptyVal (toScalarC AbsSum)
383 udot = dotC 378 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs)
384 norm1 = sumElements . fst . fromComplex . vectorMapC Abs 379 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs)
385 normInf = maxElement . fst . fromComplex . vectorMapC Abs 380 multiply = emptyMul multiplyC
386 multiply = multiplyC 381
382emptyMul m a b
383 | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c)
384 | otherwise = m a b
385 where
386 r = rows a
387 x1 = cols a
388 x2 = rows b
389 c = cols b
390
391emptyVal f v =
392 if dim v > 0
393 then f v
394 else 0
395
396
397-- FIXME remove unused C wrappers
398-- | (unconjugated) dot product
399udot :: Product e => Vector e -> Vector e -> e
400udot u v
401 | dim u == dim v = val (asRow u `multiply` asColumn v)
402 | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product"
403 where
404 val m | dim u > 0 = m@@>(0,0)
405 | otherwise = 0
387 406
388---------------------------------------------------------- 407----------------------------------------------------------
389 408
diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs
index db34041..6204b8e 100644
--- a/lib/Numeric/GSL/Vector.hs
+++ b/lib/Numeric/GSL/Vector.hs
@@ -33,6 +33,7 @@ import Foreign.Marshal.Array(newArray)
33import Foreign.Ptr(Ptr) 33import Foreign.Ptr(Ptr)
34import Foreign.C.Types 34import Foreign.C.Types
35import System.IO.Unsafe(unsafePerformIO) 35import System.IO.Unsafe(unsafePerformIO)
36import Control.Monad(when)
36 37
37fromei x = fromIntegral (fromEnum x) :: CInt 38fromei x = fromIntegral (fromEnum x) :: CInt
38 39
@@ -201,7 +202,7 @@ vectorMapValAux fun code val v = unsafePerformIO $ do
201 202
202vectorZipAux fun code u v = unsafePerformIO $ do 203vectorZipAux fun code u v = unsafePerformIO $ do
203 r <- createVector (dim u) 204 r <- createVector (dim u)
204 app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" 205 when (dim u > 0) $ app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux"
205 return r 206 return r
206 207
207--------------------------------------------------------------------- 208---------------------------------------------------------------------