summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-06 17:40:09 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-06 17:40:09 +0000
commite7c03c1ab4de85e7a700d2eafaebd37f4607c51f (patch)
tree4512d18907d88d0390671fcde4e8886d30cd0492 /examples
parenta4254a0b9bfbd720efbe42b86aa50107a74d56c7 (diff)
working on tensor contractions
Diffstat (limited to 'examples')
-rw-r--r--examples/pru.hs76
1 files changed, 46 insertions, 30 deletions
diff --git a/examples/pru.hs b/examples/pru.hs
index bddc08f..a935d93 100644
--- a/examples/pru.hs
+++ b/examples/pru.hs
@@ -7,7 +7,7 @@ import Data.Packed.Internal.Tensor
7 7
8import Complex 8import Complex
9import Numeric(showGFloat) 9import Numeric(showGFloat)
10import Data.List(transpose,intersperse) 10import Data.List(transpose,intersperse,sort)
11import Foreign.Storable 11import Foreign.Storable
12 12
13r >< c = f where 13r >< c = f where
@@ -22,8 +22,6 @@ r >|< c = f where
22 ++show (dim v) ++"in ("++show r++"><"++show c++")" 22 ++show (dim v) ++"in ("++show r++"><"++show c++")"
23 where v = fromList l 23 where v = fromList l
24 24
25
26
27vr = fromList [1..15::Double] 25vr = fromList [1..15::Double]
28vc = fromList (map (\x->x :+ (x+1)) [1..15::Double]) 26vc = fromList (map (\x->x :+ (x+1)) [1..15::Double])
29 27
@@ -62,38 +60,56 @@ main = do
62t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] 60t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double]
63 61
64 62
65findIdx name t = ((d1,d2),m) where
66 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t)
67 c = product (map fst d2)
68 m = matrixFromVector RowMajor c (ten t)
69 63
70putFirstIdx name t = (nd,m') 64t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double]
71 where ((d1,d2),m) = findIdx name t 65t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double]
72 m' = matrixFromVector RowMajor c $ cdat $ trans m
73 nd = d2++d1
74 c = dim (ten t) `div` (fst $ head d2)
75 66
76part t (name,k) = if k<0 || k>=l
77 then error $ "part "++show (name,k)++" out of range in "++show t
78 else T {dims = ds, ten = toRows m !! k}
79 where (d:ds,m) = putFirstIdx name t
80 (l,_) = d
81 67
82parts t name = map f (toRows m)
83 where (d:ds,m) = putFirstIdx name t
84 (l,_) = d
85 f t = T {dims=ds, ten=t}
86 68
87t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double]
88t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double]
89 69
90contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1))
91 70
92sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
93 71
94on f g = \x y -> f (g x) (g y) 72delta i j | i==j = 1
73 | otherwise = 0
74
75e i n = fromList [ delta k i | k <- [1..n]]
76
77ident n = fromRows [ e i n | i <- [1..n]]
78
79diag l = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]]
80 where c = length l
81
82tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v}
83tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m}
84
85td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diag [1..4] :: Tensor Double
86
87tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double
88tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double
89
90tq = T [(3,(Covariant,"p")),(2,(Covariant,"q")),(2,(Covariant,"r"))] $ fromList [11 .. 22] :: Tensor Double
91
92r1 = contraction tt "j" tq "p"
93r1' = contraction' tt "j" tq "p"
94
95pru = do
96 mapM_ (putStrLn.shdims.dims.normal) (contractions t1 t2)
97 let t1 = contraction tt "i" tq "q"
98 print $ normal t1
99 print $ foldl part t1 [("j",0),("p'",1),("r'",1)]
100 let t2 = contraction' tt "i" tq "q"
101 print $ normal t2
102 print $ foldl part t2 [("j",0),("p'",1),("r'",1)]
103 let t1 = contraction tq "q" tt "i"
104 print $ normal t1
105 print $ foldl part t1 [("j'",0),("p",1),("r",1)]
106 let t2 = contraction' tq "q" tt "i"
107 print $ normal t2
108 print $ foldl part t2 [("j'",0),("p",1),("r",1)]
109
110
111names t = sort $ map (snd.snd) (dims t)
112
113normal t = tridx (names t) t
95 114
96contract t1 n1 t2 n2 = T (tail d1++tail d2) (cdat m) 115contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
97 where (d1,m1) = putFirstIdx n1 t1
98 (d2,m2) = putFirstIdx n2 t2
99 m = multiply RowMajor (trans m2) m1 \ No newline at end of file