diff options
Diffstat (limited to 'examples/pru.hs')
-rw-r--r-- | examples/pru.hs | 76 |
1 files changed, 46 insertions, 30 deletions
diff --git a/examples/pru.hs b/examples/pru.hs index bddc08f..a935d93 100644 --- a/examples/pru.hs +++ b/examples/pru.hs | |||
@@ -7,7 +7,7 @@ import Data.Packed.Internal.Tensor | |||
7 | 7 | ||
8 | import Complex | 8 | import Complex |
9 | import Numeric(showGFloat) | 9 | import Numeric(showGFloat) |
10 | import Data.List(transpose,intersperse) | 10 | import Data.List(transpose,intersperse,sort) |
11 | import Foreign.Storable | 11 | import Foreign.Storable |
12 | 12 | ||
13 | r >< c = f where | 13 | r >< c = f where |
@@ -22,8 +22,6 @@ r >|< c = f where | |||
22 | ++show (dim v) ++"in ("++show r++"><"++show c++")" | 22 | ++show (dim v) ++"in ("++show r++"><"++show c++")" |
23 | where v = fromList l | 23 | where v = fromList l |
24 | 24 | ||
25 | |||
26 | |||
27 | vr = fromList [1..15::Double] | 25 | vr = fromList [1..15::Double] |
28 | vc = fromList (map (\x->x :+ (x+1)) [1..15::Double]) | 26 | vc = fromList (map (\x->x :+ (x+1)) [1..15::Double]) |
29 | 27 | ||
@@ -62,38 +60,56 @@ main = do | |||
62 | t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] | 60 | t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] |
63 | 61 | ||
64 | 62 | ||
65 | findIdx name t = ((d1,d2),m) where | ||
66 | (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) | ||
67 | c = product (map fst d2) | ||
68 | m = matrixFromVector RowMajor c (ten t) | ||
69 | 63 | ||
70 | putFirstIdx name t = (nd,m') | 64 | t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double] |
71 | where ((d1,d2),m) = findIdx name t | 65 | t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double] |
72 | m' = matrixFromVector RowMajor c $ cdat $ trans m | ||
73 | nd = d2++d1 | ||
74 | c = dim (ten t) `div` (fst $ head d2) | ||
75 | 66 | ||
76 | part t (name,k) = if k<0 || k>=l | ||
77 | then error $ "part "++show (name,k)++" out of range in "++show t | ||
78 | else T {dims = ds, ten = toRows m !! k} | ||
79 | where (d:ds,m) = putFirstIdx name t | ||
80 | (l,_) = d | ||
81 | 67 | ||
82 | parts t name = map f (toRows m) | ||
83 | where (d:ds,m) = putFirstIdx name t | ||
84 | (l,_) = d | ||
85 | f t = T {dims=ds, ten=t} | ||
86 | 68 | ||
87 | t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double] | ||
88 | t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double] | ||
89 | 69 | ||
90 | contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1)) | ||
91 | 70 | ||
92 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | ||
93 | 71 | ||
94 | on f g = \x y -> f (g x) (g y) | 72 | delta i j | i==j = 1 |
73 | | otherwise = 0 | ||
74 | |||
75 | e i n = fromList [ delta k i | k <- [1..n]] | ||
76 | |||
77 | ident n = fromRows [ e i n | i <- [1..n]] | ||
78 | |||
79 | diag l = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] | ||
80 | where c = length l | ||
81 | |||
82 | tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v} | ||
83 | tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m} | ||
84 | |||
85 | td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diag [1..4] :: Tensor Double | ||
86 | |||
87 | tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double | ||
88 | tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double | ||
89 | |||
90 | tq = T [(3,(Covariant,"p")),(2,(Covariant,"q")),(2,(Covariant,"r"))] $ fromList [11 .. 22] :: Tensor Double | ||
91 | |||
92 | r1 = contraction tt "j" tq "p" | ||
93 | r1' = contraction' tt "j" tq "p" | ||
94 | |||
95 | pru = do | ||
96 | mapM_ (putStrLn.shdims.dims.normal) (contractions t1 t2) | ||
97 | let t1 = contraction tt "i" tq "q" | ||
98 | print $ normal t1 | ||
99 | print $ foldl part t1 [("j",0),("p'",1),("r'",1)] | ||
100 | let t2 = contraction' tt "i" tq "q" | ||
101 | print $ normal t2 | ||
102 | print $ foldl part t2 [("j",0),("p'",1),("r'",1)] | ||
103 | let t1 = contraction tq "q" tt "i" | ||
104 | print $ normal t1 | ||
105 | print $ foldl part t1 [("j'",0),("p",1),("r",1)] | ||
106 | let t2 = contraction' tq "q" tt "i" | ||
107 | print $ normal t2 | ||
108 | print $ foldl part t2 [("j'",0),("p",1),("r",1)] | ||
109 | |||
110 | |||
111 | names t = sort $ map (snd.snd) (dims t) | ||
112 | |||
113 | normal t = tridx (names t) t | ||
95 | 114 | ||
96 | contract t1 n1 t2 n2 = T (tail d1++tail d2) (cdat m) | 115 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] |
97 | where (d1,m1) = putFirstIdx n1 t1 | ||
98 | (d2,m2) = putFirstIdx n2 t2 | ||
99 | m = multiply RowMajor (trans m2) m1 \ No newline at end of file | ||