From e7c03c1ab4de85e7a700d2eafaebd37f4607c51f Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 6 Jun 2007 17:40:09 +0000 Subject: working on tensor contractions --- examples/pru.hs | 76 ++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 46 insertions(+), 30 deletions(-) (limited to 'examples') diff --git a/examples/pru.hs b/examples/pru.hs index bddc08f..a935d93 100644 --- a/examples/pru.hs +++ b/examples/pru.hs @@ -7,7 +7,7 @@ import Data.Packed.Internal.Tensor import Complex import Numeric(showGFloat) -import Data.List(transpose,intersperse) +import Data.List(transpose,intersperse,sort) import Foreign.Storable r >< c = f where @@ -22,8 +22,6 @@ r >|< c = f where ++show (dim v) ++"in ("++show r++"><"++show c++")" where v = fromList l - - vr = fromList [1..15::Double] vc = fromList (map (\x->x :+ (x+1)) [1..15::Double]) @@ -62,38 +60,56 @@ main = do t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] -findIdx name t = ((d1,d2),m) where - (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) - c = product (map fst d2) - m = matrixFromVector RowMajor c (ten t) -putFirstIdx name t = (nd,m') - where ((d1,d2),m) = findIdx name t - m' = matrixFromVector RowMajor c $ cdat $ trans m - nd = d2++d1 - c = dim (ten t) `div` (fst $ head d2) +t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double] +t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double] -part t (name,k) = if k<0 || k>=l - then error $ "part "++show (name,k)++" out of range in "++show t - else T {dims = ds, ten = toRows m !! k} - where (d:ds,m) = putFirstIdx name t - (l,_) = d -parts t name = map f (toRows m) - where (d:ds,m) = putFirstIdx name t - (l,_) = d - f t = T {dims=ds, ten=t} -t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double] -t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double] -contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1)) -sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) -on f g = \x y -> f (g x) (g y) +delta i j | i==j = 1 + | otherwise = 0 + +e i n = fromList [ delta k i | k <- [1..n]] + +ident n = fromRows [ e i n | i <- [1..n]] + +diag l = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] + where c = length l + +tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v} +tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m} + +td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diag [1..4] :: Tensor Double + +tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double +tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double + +tq = T [(3,(Covariant,"p")),(2,(Covariant,"q")),(2,(Covariant,"r"))] $ fromList [11 .. 22] :: Tensor Double + +r1 = contraction tt "j" tq "p" +r1' = contraction' tt "j" tq "p" + +pru = do + mapM_ (putStrLn.shdims.dims.normal) (contractions t1 t2) + let t1 = contraction tt "i" tq "q" + print $ normal t1 + print $ foldl part t1 [("j",0),("p'",1),("r'",1)] + let t2 = contraction' tt "i" tq "q" + print $ normal t2 + print $ foldl part t2 [("j",0),("p'",1),("r'",1)] + let t1 = contraction tq "q" tt "i" + print $ normal t1 + print $ foldl part t1 [("j'",0),("p",1),("r",1)] + let t2 = contraction' tq "q" tt "i" + print $ normal t2 + print $ foldl part t2 [("j'",0),("p",1),("r",1)] + + +names t = sort $ map (snd.snd) (dims t) + +normal t = tridx (names t) t -contract t1 n1 t2 n2 = T (tail d1++tail d2) (cdat m) - where (d1,m1) = putFirstIdx n1 t1 - (d2,m2) = putFirstIdx n2 t2 - m = multiply RowMajor (trans m2) m1 \ No newline at end of file +contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] -- cgit v1.2.3