From 34380f2b5d7b048a4d68197f16a8db0e53742030 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sat, 8 Sep 2007 09:46:33 +0000 Subject: type classes --- lib/Data/Packed/Internal/Common.hs | 18 ++-- lib/Data/Packed/Internal/Matrix.hs | 204 +++++++++++++++++++++---------------- lib/Data/Packed/Internal/Tensor.hs | 37 ++++--- lib/Data/Packed/Internal/Vector.hs | 42 +++----- lib/Data/Packed/Matrix.hs | 15 +++ lib/Data/Packed/Vector.hs | 12 ++- 6 files changed, 184 insertions(+), 144 deletions(-) (limited to 'lib/Data') diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs index 1bfed6d..1212968 100644 --- a/lib/Data/Packed/Internal/Common.hs +++ b/lib/Data/Packed/Internal/Common.hs @@ -28,7 +28,7 @@ debug x = trace (show x) x data Vector t = V { dim :: Int , fptr :: ForeignPtr t , ptr :: Ptr t - } deriving Typeable + } -- deriving Typeable ---------------------------------------------------------------------- instance (Storable a, RealFloat a) => Storable (Complex a) where -- @@ -78,17 +78,17 @@ check msg ls f = do mapM_ (touchForeignPtr . fptr) ls return () -class (Storable a, Typeable a) => Field a -instance (Storable a, Typeable a) => Field a +--class (Storable a, Typeable a) => Field a +--instance (Storable a, Typeable a) => Field a -isReal :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool -isReal w x = typeOf (undefined :: Double) == typeOf (w x) +--isReal :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool +--isReal w x = typeOf (undefined :: Double) == typeOf (w x) -isComp :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool -isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) +--isComp :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool +--isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) -scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b -scast = fromJust . cast +--scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b +--scast = fromJust . cast {- | conversion of Haskell functions into function pointers that can be used in the C side -} diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 9895393..48652f3 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -1,4 +1,4 @@ -{-# OPTIONS_GHC -fglasgow-exts #-} +{-# OPTIONS_GHC -fglasgow-exts -fallow-overlapping-instances #-} ----------------------------------------------------------------------------- -- | -- Module : Data.Packed.Internal.Matrix @@ -22,9 +22,65 @@ import Foreign hiding (xor) import Complex import Control.Monad(when) import Data.List(transpose,intersperse) -import Data.Typeable +--import Data.Typeable import Data.Maybe(fromJust) +---------------------------------------------------------------- + +class Storable a => Field a where + constant :: a -> Int -> Vector a + transdata :: Int -> Vector a -> Int -> Vector a + multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a + subMatrix :: (Int,Int) -- ^ (r0,c0) starting position + -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix + -> Matrix a -> Matrix a + diag :: Vector a -> Matrix a + + +instance Field Double where + constant = constantR + transdata = transdataR + multiplyD = multiplyR + subMatrix = subMatrixR + diag = diagR + +instance Field (Complex Double) where + constant = constantC + transdata = transdataC + multiplyD = multiplyC + subMatrix = subMatrixC + diag = diagC + +----------------------------------------------------------------- + +transdataR :: Int -> Vector Double -> Int -> Vector Double +transdataR = transdataAux ctransR + +transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) +transdataC = transdataAux ctransC + +transdataAux fun c1 d c2 = + if noneed + then d + else unsafePerformIO $ do + v <- createVector (dim d) + fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] + --putStrLn "---> transdataAux" + return v + where r1 = dim d `div` c1 + r2 = dim d `div` c2 + noneed = r1 == 1 || c1 == 1 + +foreign import ccall safe "aux.h transR" + ctransR :: TMM -- Double ::> Double ::> IO Int +foreign import ccall safe "aux.h transC" + ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int + +transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d + + + + data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) @@ -34,9 +90,18 @@ data Matrix t = M { rows :: Int , tdat :: Vector t , isTrans :: Bool , order :: MatrixOrder - } deriving Typeable + } -- deriving Typeable +data NMat t = MC { rws, cls :: Int, dtc :: Vector t} + | MF { rws, cls :: Int, dtf :: Vector t} + | Tr (NMat t) + +ntrans (Tr m) = m +ntrans m = Tr m + +viewC m@MC{} = m +viewF m@MF{} = m fortran m = order m == ColumnMajor @@ -78,7 +143,11 @@ matrixFromVector RowMajor c v = , tdat = transdata c v r , order = RowMajor , isTrans = False - } where r = dim v `div` c -- TODO check mod=0 + } where (d,m) = dim v `divMod` c + r | m==0 = d + | otherwise = error "matrixFromVector" + +-- r = dim v `div` c -- TODO check mod=0 matrixFromVector ColumnMajor c v = M { rows = r @@ -87,7 +156,9 @@ matrixFromVector ColumnMajor c v = , tdat = transdata r v c , order = ColumnMajor , isTrans = False - } where r = dim v `div` c -- TODO check mod=0 + } where (d,m) = dim v `divMod` c + r | m==0 = d + | otherwise = error "matrixFromVector" createMatrix order r c = do p <- createVector (r*c) @@ -102,48 +173,11 @@ createMatrix order r c = do , 9.0, 10.0, 11.0, 12.0 ]@ -} -reshape :: (Field t) => Int -> Vector t -> Matrix t +reshape :: Field t => Int -> Vector t -> Matrix t reshape c v = matrixFromVector RowMajor c v singleton x = reshape 1 (fromList [x]) -transdataG :: Storable a => Int -> Vector a -> Int -> Vector a -transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d - -transdataR :: Int -> Vector Double -> Int -> Vector Double -transdataR = transdataAux ctransR - -transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) -transdataC = transdataAux ctransC - -transdataAux fun c1 d c2 = - if noneed - then d - else unsafePerformIO $ do - v <- createVector (dim d) - fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] - --putStrLn "---> transdataAux" - return v - where r1 = dim d `div` c1 - r2 = dim d `div` c2 - noneed = r1 == 1 || c1 == 1 - -foreign import ccall safe "aux.h transR" - ctransR :: TMM -- Double ::> Double ::> IO Int -foreign import ccall safe "aux.h transC" - ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int - -transdata :: Field a => Int -> Vector a -> Int -> Vector a -transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 - | isComp baseOf d = scast $ transdataC c1 (scast d) c2 - | otherwise = transdataG c1 d c2 - ---transdata :: Storable a => Int -> Vector a -> Int -> Vector a ---transdata = transdataG ---{-# RULES "transdataR" transdata=transdataR #-} ---{-# RULES "transdataC" transdata=transdataC #-} - ------------------------------------------------------------------ liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b liftMatrix f m = reshape (cols m) (f (cdat m)) @@ -163,7 +197,7 @@ multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] Nothing -> False Just c -> c == length b -transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) +transL m = matrixFromVector RowMajor (rows m) $ transdata (cols m) (cdat m) (rows m) multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) @@ -179,7 +213,7 @@ gmatC m f | fortran m = else f 0 (rows m) (cols m) (ptr (dat m)) -multiplyAux order fun a b = unsafePerformIO $ do +multiplyAux fun order a b = unsafePerformIO $ do when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ show (rows a,cols a) ++ " x " ++ show (rows b, cols b) r <- createMatrix order (rows a) (cols b) @@ -198,37 +232,14 @@ foreign import ccall safe "aux.h multiplyC" -> Int -> Int -> Ptr (Complex Double) -> IO Int -multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a +multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a multiply RowMajor a b = multiplyD RowMajor a b multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} where m = multiplyD RowMajor (trans b) (trans a) -multiplyD order a b - | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) - | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) - | otherwise = multiplyG a b - ----------------------------------------------------------------------- -outer' u v = dat (outer u v) - -{- | Outer product of two vectors. - -@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] -(3><3) - [ 5.0, 2.0, 3.0 - , 10.0, 4.0, 6.0 - , 15.0, 6.0, 9.0 ]@ --} -outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t -outer u v = multiply RowMajor r c - where r = matrixFromVector RowMajor 1 u - c = matrixFromVector RowMajor (dim v) v - -dot :: (Field t, Num t) => Vector t -> Vector t -> t -dot u v = dat (multiply RowMajor r c) `at` 0 - where r = matrixFromVector RowMajor (dim u) u - c = matrixFromVector RowMajor 1 v +multiplyR = multiplyAux cmultiplyR +multiplyC = multiplyAux cmultiplyC ---------------------------------------------------------------------- @@ -251,14 +262,14 @@ subMatrixC (r0,c0) (rt,ct) x = subMatrixR (r0,2*c0) (rt,2*ct) . reshape (2*cols x) . asReal . cdat $ x -subMatrix :: (Field a) - => (Int,Int) -- ^ (r0,c0) starting position - -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix - -> Matrix a -> Matrix a -subMatrix st sz m - | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m) - | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m) - | otherwise = subMatrixG st sz m +--subMatrix :: (Field a) +-- => (Int,Int) -- ^ (r0,c0) starting position +-- -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix +-- -> Matrix a -> Matrix a +--subMatrix st sz m +-- | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m) +-- | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m) +-- | otherwise = subMatrixG st sz m subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) where subList s n = take n . drop s @@ -281,11 +292,11 @@ diagC = diagAux c_diagC "diagC" foreign import ccall "aux.h diagC" c_diagC :: TCVCM -- | diagonal matrix from a vector -diag :: (Num a, Field a) => Vector a -> Matrix a -diag v - | isReal (baseOf) v = scast $ diagR (scast v) - | isComp (baseOf) v = scast $ diagC (scast v) - | otherwise = diagG v +--diag :: (Num a, Field a) => Vector a -> Matrix a +--diag v +-- | isReal (baseOf) v = scast $ diagR (scast v) +-- | isComp (baseOf) v = scast $ diagC (scast v) +-- | otherwise = diagG v diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] where c = dim v @@ -313,13 +324,34 @@ fromColumns :: Field t => [Vector t] -> Matrix t fromColumns m = trans . fromRows $ m -- | Creates a list of vectors from the columns of a matrix -toColumns :: Field t => Matrix t -> [Vector t] +toColumns :: Storable t => Matrix t -> [Vector t] toColumns m = toRows . trans $ m -- | Reads a matrix position. -(@@>) :: Field t => Matrix t -> (Int,Int) -> t +(@@>) :: Storable t => Matrix t -> (Int,Int) -> t infixl 9 @@> m@M {rows = r, cols = c} @@> (i,j) | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" | otherwise = cdat m `at` (i*c+j) + +------------------------------------------------------------------ + +constantR :: Double -> Int -> Vector Double +constantR = constantAux cconstantR + +constantC :: Complex Double -> Int -> Vector (Complex Double) +constantC = constantAux cconstantC + +constantAux fun x n = unsafePerformIO $ do + v <- createVector n + px <- newArray [x] + fun px // vec v // check "constantAux" [] + free px + return v + +foreign import ccall safe "aux.h constantR" + cconstantR :: Ptr Double -> TV -- Double :> IO Int + +foreign import ccall safe "aux.h constantC" + cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 34132d8..6876685 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs @@ -1,3 +1,5 @@ +{-# OPTIONS_GHC -fglasgow-exts #-} + ----------------------------------------------------------------------------- -- | -- Module : Data.Packed.Internal.Tensor @@ -19,6 +21,8 @@ import Foreign.Storable import Data.List(sort,elemIndex,nub,foldl1',foldl') import GSL.Vector import Data.Packed.Matrix +import Data.Packed.Vector +import LinearAlgebra.Linear data IdxType = Covariant | Contravariant deriving (Show,Eq) @@ -171,6 +175,7 @@ compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where = t1 /= t2 && n1 == n2 +outer' u v = dat (outer u v) -- | tensor product without without any contractions rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t @@ -187,7 +192,7 @@ contraction2 t1 n1 t2 n2 = m = multiply RowMajor (trans m1) m2 -- | contraction of a tensor along two given indices -contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t +contraction1 :: (Linear Vector t) => Tensor t -> IdxName -> IdxName -> Tensor t contraction1 t name1 name2 = if compatIdx t name1 t name2 then sumT y @@ -197,7 +202,7 @@ contraction1 t name1 name2 = y = map head $ zipWith drop [0..] x -- | contraction of a tensor along a repeated index -contraction1c :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t +contraction1c :: (Linear Vector t) => Tensor t -> IdxName -> Tensor t contraction1c t n = contraction1 renamed n' n where n' = n++"'" -- hmmm renamed = withIdx t auxnames @@ -205,31 +210,31 @@ contraction1c t n = contraction1 renamed n' n (h,_:r) = break (==n) (map idxName (dims t)) -- | alternative and inefficient version of contraction2 -contraction2' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t +contraction2' :: (Linear Vector t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t contraction2' t1 n1 t2 n2 = if compatIdx t1 n1 t2 n2 then contraction1 (rawProduct t1 t2) n1 n2 else error "wrong contraction'" -- | applies a sequence of contractions -contractions :: (Field t, Num t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t +contractions :: (Linear Vector t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t contractions t pairs = foldl' contract1b t pairs where contract1b t (n1,n2) = contraction1 t n1 n2 -- | applies a sequence of contractions of same index -contractionsC :: (Field t, Num t) => Tensor t -> [IdxName] -> Tensor t +contractionsC :: (Linear Vector t) => Tensor t -> [IdxName] -> Tensor t contractionsC t is = foldl' contraction1c t is -- | applies a contraction on the first indices of the tensors -contractionF :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t +contractionF :: (Linear Vector t) => Tensor t -> Tensor t -> Tensor t contractionF t1 t2 = contraction2 t1 n1 t2 n2 where n1 = fn t1 n2 = fn t2 fn = idxName . head . dims -- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal -possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] +possibleContractions :: (Linear Vector t) => Tensor t -> Tensor t -> [Tensor t] possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] @@ -242,7 +247,7 @@ desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2] where x = zip [0..] (names t) -- | tensor product with the convention that repeated indices are contracted. -mulT :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t +mulT :: (Linear Vector t) => Tensor t -> Tensor t -> Tensor t mulT t1 t2 = r where t1r = contractionsC t1 (desiredContractions1 t1) t2r = contractionsC t2 (desiredContractions1 t2) @@ -254,10 +259,10 @@ mulT t1 t2 = r where ----------------------------------------------------------------- -- | tensor addition (for tensors with the same structure) -addT :: (Num a, Field a) => Tensor a -> Tensor a -> Tensor a +addT :: (Linear Vector a) => Tensor a -> Tensor a -> Tensor a addT a b = liftTensor2 add a b -sumT :: (Field a, Num a) => [Tensor a] -> Tensor a +sumT :: (Linear Vector a) => [Tensor a] -> Tensor a sumT l = foldl1' addT l ----------------------------------------------------------------- @@ -281,19 +286,19 @@ signature l | length (nub l) < length l = 0 | otherwise = -1 -sym :: (Field t, Num t) => Tensor t -> Tensor t +sym :: (Linear Vector t) => Tensor t -> Tensor t sym t = T (dims t) (ten (sym' (withIdx t seqind))) where sym' t = sumT $ map (flip tridx t) (perms (names t)) where nms = map idxName . dims -antisym :: (Field t, Num t) => Tensor t -> Tensor t +antisym :: (Linear Vector t) => Tensor t -> Tensor t antisym t = T (dims t) (ten (antisym' (withIdx t seqind))) where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t)) scsig t = scalar (signature (nms t)) `rawProduct` t where nms = map idxName . dims -- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product). -wedge :: (Field t, Fractional t) => Tensor t -> Tensor t -> Tensor t +wedge :: (Linear Vector t, Fractional t) => Tensor t -> Tensor t -> Tensor t wedge a b = antisym (rawProduct (norper a) (norper b)) where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) @@ -313,19 +318,19 @@ seqind :: [String] seqind = map show [1..] -- | completely antisymmetric covariant tensor of dimension n -leviCivita :: (Field t, Num t) => Int -> Tensor t +leviCivita :: (Linear Vector t) => Int -> Tensor t leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind' where auxbase = map tc (toRows (ident n)) tc = tensorFromVector Covariant -- | contraction of leviCivita with a list of vectors (and raise with euclidean metric) -innerLevi :: (Num t, Field t) => [Tensor t] -> Tensor t +innerLevi :: (Linear Vector t) => [Tensor t] -> Tensor t innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs where n = idxDim . head . dims . head $ vs -- | obtains the dual of a multivector (with euclidean metric) -dual :: (Field t, Fractional t) => Tensor t -> Tensor t +dual :: (Linear Vector t, Fractional t) => Tensor t -> Tensor t dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x where n = idxDim . head . dims $ t x = scalar (recip $ fromIntegral $ fact (rank t)) diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index ab93577..f2646a4 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs @@ -1,4 +1,4 @@ -{-# OPTIONS_GHC -fglasgow-exts #-} +{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} ----------------------------------------------------------------------------- -- | -- Module : Data.Packed.Internal.Vector @@ -19,6 +19,8 @@ import Data.Packed.Internal.Common import Foreign import Complex import Control.Monad(when) +import Data.List(transpose) +import Debug.Trace(trace) type Vc t s = Int -> Ptr t -> s -- not yet admitted by my haddock version @@ -28,7 +30,7 @@ type Vc t s = Int -> Ptr t -> s vec :: Vector t -> (Vc t s) -> s vec v f = f (dim v) (ptr v) -baseOf v = (v `at` 0) +--baseOf v = (v `at` 0) createVector :: Storable a => Int -> IO (Vector a) createVector n = do @@ -78,9 +80,16 @@ subVector' k l (v@V {dim=n, ptr=p, fptr=fp}) | otherwise = v {dim=l, ptr=advancePtr p k} +-- | Reads a vector position. +(@>) :: Storable t => Vector t -> Int -> t +infixl 9 @> +(@>) = at + + + -- | creates a new Vector by joining a list of Vectors -join :: Field t => [Vector t] -> Vector t +join :: Storable t => [Vector t] -> Vector t join [] = error "joining zero vectors" join as = unsafePerformIO $ do let tot = sum (map dim as) @@ -103,34 +112,11 @@ asComplex :: Vector Double -> Vector (Complex Double) asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } -constantG x n = fromList (replicate n x) - -constantR :: Double -> Int -> Vector Double -constantR = constantAux cconstantR - -constantC :: Complex Double -> Int -> Vector (Complex Double) -constantC = constantAux cconstantC - -constantAux fun x n = unsafePerformIO $ do - v <- createVector n - px <- newArray [x] - fun px // vec v // check "constantAux" [] - free px - return v - -foreign import ccall safe "aux.h constantR" - cconstantR :: Ptr Double -> TV -- Double :> IO Int - -foreign import ccall safe "aux.h constantC" - cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int - -constant :: Field a => a -> Int -> Vector a -constant x n | isReal id x = scast $ constantR (scast x) n - | isComp id x = scast $ constantC (scast x) n - | otherwise = constantG x n +---------------------------------------------------------------- liftVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b liftVector f = fromList . map f . toList liftVector2 :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c liftVector2 f u v = fromList $ zipWith f (toList u) (toList v) + diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 2033dc7..2e8cb3d 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -134,3 +134,18 @@ asRow v = reshape (dim v) v asColumn :: Field a => Vector a -> Matrix a asColumn v = reshape 1 v + +------------------------------------------------ + +{- | Outer product of two vectors. + +@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3] +(3><3) + [ 5.0, 2.0, 3.0 + , 10.0, 4.0, 6.0 + , 15.0, 6.0, 9.0 ]@ +-} +outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t +outer u v = multiply RowMajor r c + where r = matrixFromVector RowMajor 1 u + c = matrixFromVector RowMajor (dim v) v diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs index 27ba6a3..867b77b 100644 --- a/lib/Data/Packed/Vector.hs +++ b/lib/Data/Packed/Vector.hs @@ -27,7 +27,7 @@ module Data.Packed.Vector ( import Data.Packed.Internal import Complex -import GSL.Vector +--import GSL.Vector -- | creates a complex vector from vectors with real and imaginary parts toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) @@ -50,7 +50,9 @@ linspace :: Int -> (Double, Double) -> Vector Double linspace n (a,b) = fromList [a::Double,a+delta .. b] where delta = (b-a)/(fromIntegral n -1) --- | Reads a vector position. -(@>) :: Field t => Vector t -> Int -> t -infixl 9 @> -(@>) = at + +dot :: (Field t) => Vector t -> Vector t -> t +dot u v = dat (multiply RowMajor r c) `at` 0 + where r = matrixFromVector RowMajor (dim u) u + c = matrixFromVector RowMajor 1 v + -- cgit v1.2.3