From e7c03c1ab4de85e7a700d2eafaebd37f4607c51f Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 6 Jun 2007 17:40:09 +0000 Subject: working on tensor contractions --- lib/Data/Packed/Internal/Matrix.hs | 17 ++++++++- lib/Data/Packed/Internal/Tensor.hs | 73 +++++++++++++++++++++++++++++++++++--- lib/Data/Packed/Internal/Vector.hs | 2 ++ 3 files changed, 87 insertions(+), 5 deletions(-) (limited to 'lib/Data/Packed') diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index a2a70dd..ec6657a 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -65,6 +65,15 @@ partit :: Int -> [a] -> [[a]] partit _ [] = [] partit n l = take n l : partit n (drop n l) +-- | obtains the common value of a property of a list +common :: (Eq a) => (b->a) -> [b] -> Maybe a +common f = commonval . map f where + commonval :: (Eq a) => [a] -> Maybe a + commonval [] = Nothing + commonval [a] = Just a + commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing + + toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m | otherwise = partit (cols m) . toList . dat $ m @@ -115,7 +124,7 @@ transdataAux fun c1 d c2 = else unsafePerformIO $ do v <- createVector (dim d) fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] - putStrLn "---> transdataAux" + --putStrLn "---> transdataAux" return v where r1 = dim d `div` c1 r2 = dim d `div` c2 @@ -136,6 +145,12 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 --{-# RULES "transdataR" transdata=transdataR #-} --{-# RULES "transdataC" transdata=transdataC #-} +-- | creates a Matrix from a list of vectors +fromRows :: Field t => [Vector t] -> Matrix t +fromRows vs = case common dim vs of + Nothing -> error "fromRows applied to [] or to vectors with different sizes" + Just c -> reshape c (join vs) + -- | extracts the rows of a matrix as a list of vectors toRows :: Storable t => Matrix t -> [Vector t] toRows m = toRows' 0 where diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 960e3c5..b66d6b8 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs @@ -1,4 +1,4 @@ -{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} +--{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} ----------------------------------------------------------------------------- -- | -- Module : Data.Packed.Internal.Tensor @@ -19,7 +19,7 @@ import Data.Packed.Internal.Vector import Data.Packed.Internal.Matrix import Foreign.Storable -data IdxTp = Covariant | Contravariant deriving Show +data IdxTp = Covariant | Contravariant deriving (Show,Eq) data Tensor t = T { dims :: [(Int,(IdxTp,String))] , ten :: Vector t @@ -36,5 +36,70 @@ instance (Show a,Storable a) => Show (Tensor a) where show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) -shdims [(n,(t,name))] = name++"["++show n++"]" -shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file +shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]" + where sym Covariant = "_" + sym Contravariant = "^" +shdims (d:ds) = shdims [d] ++ "><"++ shdims ds + + + +findIdx name t = ((d1,d2),m) where + (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) + c = product (map fst d2) + m = matrixFromVector RowMajor c (ten t) + +putFirstIdx name t = (nd,m') + where ((d1,d2),m) = findIdx name t + m' = matrixFromVector RowMajor c $ cdat $ trans m + nd = d2++d1 + c = dim (ten t) `div` (fst $ head d2) + +part t (name,k) = if k<0 || k>=l + then error $ "part "++show (name,k)++" out of range in "++show t + else T {dims = ds, ten = toRows m !! k} + where (d:ds,m) = putFirstIdx name t + (l,_) = d + +parts t name = map f (toRows m) + where (d:ds,m) = putFirstIdx name t + (l,_) = d + f t = T {dims=ds, ten=t} + +concatRename l1 l2 = l1 ++ map ren l2 where + ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s)) + fs = map (snd.snd) l1 + +prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2) + +contraction t1 n1 t2 n2 = + if compatIdx t1 n1 t2 n2 + then T (concatRename (tail d1) (tail d2)) (cdat m) + else error "wrong contraction'" + where (d1,m1) = putFirstIdx n1 t1 + (d2,m2) = putFirstIdx n2 t2 + m = multiply RowMajor (trans m1) m2 + +sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) + +contract1 t name1 name2 = T d $ fromList $ sumT y + where d = dims (head y) + x = (map (flip parts name2) (parts t name1)) + y = map head $ zipWith drop [0..] x + +contraction' t1 n1 t2 n2 = + if compatIdx t1 n1 t2 n2 + then contract1 (prod t1 t2) n1 (n2++"'") + else error "wrong contraction'" + +tridx [] t = t +tridx (name:rest) t = T (d:ds) (join ts) where + ((_,d:_),_) = findIdx name t + ps = map (tridx rest) (parts t name) + ts = map ten ps + ds = dims (head ps) + +compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2 + +compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where + d1 = head $ snd $ fst $ findIdx n1 t1 + d2 = head $ snd $ fst $ findIdx n2 t2 diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 6ed9339..36d5df7 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs @@ -35,6 +35,8 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where -- poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- ---------------------------------------------------------------------- +on f g = \x y -> f (g x) (g y) + (//) :: x -> (x -> y) -> y infixl 0 // (//) = flip ($) -- cgit v1.2.3