From 4078cf44c98b42960be27843782f6983bb66017f Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sun, 4 May 2014 21:08:51 +0200 Subject: allow empty arrays --- lib/Data/Packed/Internal/Matrix.hs | 26 +++++++------ lib/Data/Packed/Internal/Vector.hs | 4 +- lib/Data/Packed/Matrix.hs | 21 +++++++---- lib/Numeric/Container.hs | 2 +- lib/Numeric/ContainerBoot.hs | 77 ++++++++++++++++++++++++-------------- lib/Numeric/GSL/Vector.hs | 3 +- 6 files changed, 81 insertions(+), 52 deletions(-) (limited to 'lib') diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 8709a00..2004e85 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -198,16 +198,17 @@ atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i) ------------------------------------------------------------------ -matrixFromVector o c v = Matrix { irows = r, icols = c, xdat = v, order = o } - where (d,m) = dim v `quotRem` c - r | m==0 = d - | otherwise = error "matrixFromVector" +matrixFromVector o r c v + | r * c == dim v = m + | otherwise = error $ "matrixFromVector " ++ shSize m ++ " <- " ++ show (dim v) + where + m = Matrix { irows = r, icols = c, xdat = v, order = o } -- allocates memory for a new matrix createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) createMatrix ord r c = do p <- createVector (r*c) - return (matrixFromVector ord c p) + return (matrixFromVector ord r c p) {- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@ where r is the desired number of rows.) @@ -220,21 +221,22 @@ where r is the desired number of rows.) -} reshape :: Storable t => Int -> Vector t -> Matrix t -reshape c v = matrixFromVector RowMajor c v +reshape 0 v = matrixFromVector RowMajor 0 0 v +reshape c v = matrixFromVector RowMajor (dim v `div` c) c v singleton x = reshape 1 (fromList [x]) -- | application of a vector function on the flattened matrix elements liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b -liftMatrix f Matrix { icols = c, xdat = d, order = o } = matrixFromVector o c (f d) +liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d) -- | application of a vector function on the flattened matrices elements liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t liftMatrix2 f m1 m2 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" | otherwise = case orderOf m1 of - RowMajor -> matrixFromVector RowMajor (cols m1) (f (xdat m1) (flatten m2)) - ColumnMajor -> matrixFromVector ColumnMajor (cols m1) (f (xdat m1) ((xdat.fmat) m2)) + RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2)) + ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2)) compat :: Matrix a -> Matrix b -> Bool @@ -296,7 +298,7 @@ transdata' c1 v c2 = return w where r1 = dim v `div` c1 r2 = dim v `div` c2 - noneed = r1 == 1 || c1 == 1 + noneed = dim v == 0 || r1 == 1 || c1 == 1 -- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-} -- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-} @@ -318,7 +320,7 @@ transdataAux fun c1 d c2 = return v where r1 = dim d `div` c1 r2 = dim d `div` c2 - noneed = r1 == 1 || c1 == 1 + noneed = dim d == 0 || r1 == 1 || c1 == 1 transdataP :: Storable a => Int -> Vector a -> Int -> Vector a transdataP c1 d c2 = @@ -333,7 +335,7 @@ transdataP c1 d c2 = where r1 = dim d `div` c1 r2 = dim d `div` c2 sz = sizeOf (d @> 0) - noneed = r1 == 1 || c1 == 1 + noneed = dim d == 0 || r1 == 1 || c1 == 1 foreign import ccall unsafe "transF" ctransF :: TFMFM foreign import ccall unsafe "transR" ctransR :: TMM diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 415c972..6d03438 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs @@ -81,7 +81,7 @@ vec x f = unsafeWith x $ \p -> do -- allocates memory for a new vector createVector :: Storable a => Int -> IO (Vector a) createVector n = do - when (n <= 0) $ error ("trying to createVector of dim "++show n) + when (n < 0) $ error ("trying to createVector of negative dim: "++show n) fp <- doMalloc undefined return $ unsafeFromForeignPtr fp 0 n where @@ -192,7 +192,7 @@ fromList [1.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0] -} vjoin :: Storable t => [Vector t] -> Vector t -vjoin [] = error "vjoin zero vectors" +vjoin [] = fromList [] vjoin [v] = v vjoin as = unsafePerformIO $ do let tot = sum (map dim as) diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index f72bd15..b92d60f 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -74,8 +74,10 @@ instance (Binary a, Element a, Storable a) => Binary (Matrix a) where ------------------------------------------------------------------- instance (Show a, Element a) => (Show (Matrix a)) where - show m = (sizes++) . dsp . map (map show) . toLists $ m - where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" + show m | rows m == 0 || cols m == 0 = sizes m ++" []" + show m = (sizes m++) . dsp . map (map show) . toLists $ m + +sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp where @@ -104,7 +106,7 @@ breakAt c l = (a++[c],tail b) where joinVert :: Element t => [Matrix t] -> Matrix t joinVert ms = case common cols ms of Nothing -> error "(impossible) joinVert on matrices with different number of columns" - Just c -> reshape c $ vjoin (map flatten ms) + Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) -- | creates a matrix from a horizontal list of matrices joinHoriz :: Element t => [Matrix t] -> Matrix t @@ -147,7 +149,7 @@ adaptBlocks ms = ms' where g [Just nr,Just nc] m | nr == r && nc == c = m - | r == 1 && c == 1 = reshape nc (constantD x (nr*nc)) + | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc)) | r == 1 = fromRows (replicate nr (flatten m)) | otherwise = fromColumns (replicate nc (flatten m)) where @@ -237,7 +239,7 @@ safely be used with lists that are too long (like infinite lists). -} (><) :: (Storable a) => Int -> Int -> [a] -> Matrix a r >< c = f where - f l | dim v == r*c = matrixFromVector RowMajor c v + f l | dim v == r*c = matrixFromVector RowMajor r c v | otherwise = error $ "inconsistent list size = " ++show (dim v) ++" in ("++show r++"><"++show c++")" where v = fromList $ take (r*c) l @@ -291,7 +293,7 @@ asRow v = reshape (dim v) v -- , 5.0 ] -- asColumn :: Storable a => Vector a -> Matrix a -asColumn v = reshape 1 v +asColumn = trans . asRow @@ -358,7 +360,12 @@ liftMatrix2Auto f m1 m2 m1' = conformMTo (r,c) m1 m2' = conformMTo (r,c) m2 -lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) +-- FIXME do not flatten if equal order +lM f m1 m2 = matrixFromVector + RowMajor + (max (rows m1) (rows m2)) + (max (cols m1) (cols m2)) + (f (flatten m1) (flatten m2)) compat' :: Matrix a -> Matrix b -> Bool compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs index a71fdfe..b145a26 100644 --- a/lib/Numeric/Container.hs +++ b/lib/Numeric/Container.hs @@ -36,7 +36,7 @@ module Numeric.Container ( -- * Generic operations Container(..), -- * Matrix product - Product(..), + Product(..), udot, Mul(..), Contraction(..), mmul, optimiseMult, diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index a333489..6445e04 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs @@ -25,7 +25,7 @@ module Numeric.ContainerBoot ( -- * Generic operations Container(..), -- * Matrix product and related functions - Product(..), + Product(..), udot, mXm,mXv,vXm, outer, kronecker, -- * Element conversion @@ -315,7 +315,7 @@ instance (Container Vector a) => Container Matrix a where equal a b = cols a == cols b && flatten a `equal` flatten b arctan2 = liftMatrix2 arctan2 scalar x = (1><1) [x] - konst' v (r,c) = reshape c (konst' v (r*c)) + konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c)) build' = buildM conj = liftMatrix conj cmap f = liftMatrix (mapVector f) @@ -339,11 +339,9 @@ instance (Container Vector a) => Container Matrix a where ---------------------------------------------------- -- | Matrix product and related functions -class Element e => Product e where +class (Num e, Element e) => Product e where -- | matrix product multiply :: Matrix e -> Matrix e -> Matrix e - -- | (unconjugated) dot product - udot :: Vector e -> Vector e -> e -- | sum of absolute value of elements (differs in complex case from @norm1@) absSum :: Vector e -> RealOf e -- | sum of absolute value of elements @@ -354,36 +352,57 @@ class Element e => Product e where normInf :: Vector e -> RealOf e instance Product Float where - norm2 = toScalarF Norm2 - absSum = toScalarF AbsSum - udot = dotF - norm1 = toScalarF AbsSum - normInf = maxElement . vectorMapF Abs - multiply = multiplyF + norm2 = emptyVal (toScalarF Norm2) + absSum = emptyVal (toScalarF AbsSum) + norm1 = emptyVal (toScalarF AbsSum) + normInf = emptyVal (maxElement . vectorMapF Abs) + multiply = emptyMul multiplyF instance Product Double where - norm2 = toScalarR Norm2 - absSum = toScalarR AbsSum - udot = dotR - norm1 = toScalarR AbsSum - normInf = maxElement . vectorMapR Abs - multiply = multiplyR + norm2 = emptyVal (toScalarR Norm2) + absSum = emptyVal (toScalarR AbsSum) + norm1 = emptyVal (toScalarR AbsSum) + normInf = emptyVal (maxElement . vectorMapR Abs) + multiply = emptyMul multiplyR instance Product (Complex Float) where - norm2 = toScalarQ Norm2 - absSum = toScalarQ AbsSum - udot = dotQ - norm1 = sumElements . fst . fromComplex . vectorMapQ Abs - normInf = maxElement . fst . fromComplex . vectorMapQ Abs - multiply = multiplyQ + norm2 = emptyVal (toScalarQ Norm2) + absSum = emptyVal (toScalarQ AbsSum) + norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs) + normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs) + multiply = emptyMul multiplyQ instance Product (Complex Double) where - norm2 = toScalarC Norm2 - absSum = toScalarC AbsSum - udot = dotC - norm1 = sumElements . fst . fromComplex . vectorMapC Abs - normInf = maxElement . fst . fromComplex . vectorMapC Abs - multiply = multiplyC + norm2 = emptyVal (toScalarC Norm2) + absSum = emptyVal (toScalarC AbsSum) + norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs) + normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs) + multiply = emptyMul multiplyC + +emptyMul m a b + | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c) + | otherwise = m a b + where + r = rows a + x1 = cols a + x2 = rows b + c = cols b + +emptyVal f v = + if dim v > 0 + then f v + else 0 + + +-- FIXME remove unused C wrappers +-- | (unconjugated) dot product +udot :: Product e => Vector e -> Vector e -> e +udot u v + | dim u == dim v = val (asRow u `multiply` asColumn v) + | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product" + where + val m | dim u > 0 = m@@>(0,0) + | otherwise = 0 ---------------------------------------------------------- diff --git a/lib/Numeric/GSL/Vector.hs b/lib/Numeric/GSL/Vector.hs index db34041..6204b8e 100644 --- a/lib/Numeric/GSL/Vector.hs +++ b/lib/Numeric/GSL/Vector.hs @@ -33,6 +33,7 @@ import Foreign.Marshal.Array(newArray) import Foreign.Ptr(Ptr) import Foreign.C.Types import System.IO.Unsafe(unsafePerformIO) +import Control.Monad(when) fromei x = fromIntegral (fromEnum x) :: CInt @@ -201,7 +202,7 @@ vectorMapValAux fun code val v = unsafePerformIO $ do vectorZipAux fun code u v = unsafePerformIO $ do r <- createVector (dim u) - app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" + when (dim u > 0) $ app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux" return r --------------------------------------------------------------------- -- cgit v1.2.3