summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-05 10:09:17 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-05 10:09:17 +0000
commit1fb4ea70c517050d3cbad75357a4fffbf5a40e7b (patch)
tree6107b470be9300297f9f08a4280fbf46faf6a862 /examples
parent7430630fa0504296b796223e01cbd417b88650ef (diff)
working on contraction
Diffstat (limited to 'examples')
-rw-r--r--examples/pru.hs24
1 files changed, 14 insertions, 10 deletions
diff --git a/examples/pru.hs b/examples/pru.hs
index d6dc5d4..bddc08f 100644
--- a/examples/pru.hs
+++ b/examples/pru.hs
@@ -55,24 +55,23 @@ rd = (2><2)
55 55
56main = do 56main = do
57 print $ r |=| rd 57 print $ r |=| rd
58 print $ foldl part t [("p",1),("q",0),("r",2)]
58 print $ foldl part t [("p",1),("r",2),("q",0)] 59 print $ foldl part t [("p",1),("r",2),("q",0)]
60 print $ foldl part t $ reverse [("p",1),("r",2),("q",0)]
59 61
60t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] 62t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double]
61 63
62 64
63findIdx name t = ((d1,d2),m) where 65findIdx name t = ((d1,d2),m) where
64 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) 66 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t)
65 c = product (map fst (tail d2)) 67 c = product (map fst d2)
66 m = matrixFromVector RowMajor c (ten t) 68 m = matrixFromVector RowMajor c (ten t)
67 69
68 70putFirstIdx name t = (nd,m')
69putFirstIdx name t =
70 if null d1
71 then (nd,m)
72 else (nd,m')
73 where ((d1,d2),m) = findIdx name t 71 where ((d1,d2),m) = findIdx name t
74 m' = trans $ matrixFromVector RowMajor (fst $ head d2) $ dat m 72 m' = matrixFromVector RowMajor c $ cdat $ trans m
75 nd = d2++d1 73 nd = d2++d1
74 c = dim (ten t) `div` (fst $ head d2)
76 75
77part t (name,k) = if k<0 || k>=l 76part t (name,k) = if k<0 || k>=l
78 then error $ "part "++show (name,k)++" out of range in "++show t 77 then error $ "part "++show (name,k)++" out of range in "++show t
@@ -88,8 +87,13 @@ parts t name = map f (toRows m)
88t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double] 87t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double]
89t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double] 88t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double]
90 89
91--contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1)) 90contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1))
91
92sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
92 93
93--sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) 94on f g = \x y -> f (g x) (g y)
94 95
95on f g = \x y -> f (g x) (g y) \ No newline at end of file 96contract t1 n1 t2 n2 = T (tail d1++tail d2) (cdat m)
97 where (d1,m1) = putFirstIdx n1 t1
98 (d2,m2) = putFirstIdx n2 t2
99 m = multiply RowMajor (trans m2) m1 \ No newline at end of file