diff options
Diffstat (limited to 'examples/pru.hs')
-rw-r--r-- | examples/pru.hs | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/examples/pru.hs b/examples/pru.hs index 8b25780..10789d2 100644 --- a/examples/pru.hs +++ b/examples/pru.hs | |||
@@ -38,7 +38,7 @@ bf = (3>|<4) [7,11,15,8,12,16,9,13,17,10,14,18::Double] | |||
38 | 38 | ||
39 | a |=| b = rows a == rows b && | 39 | a |=| b = rows a == rows b && |
40 | cols a == cols b && | 40 | cols a == cols b && |
41 | toList (dat a) == toList (dat b) | 41 | toList (cdat a) == toList (cdat b) |
42 | 42 | ||
43 | mulC a b = multiply RowMajor a b | 43 | mulC a b = multiply RowMajor a b |
44 | mulF a b = multiply ColumnMajor a b | 44 | mulF a b = multiply ColumnMajor a b |
@@ -75,15 +75,14 @@ delta i j | i==j = 1 | |||
75 | 75 | ||
76 | e i n = fromList [ delta k i | k <- [1..n]] | 76 | e i n = fromList [ delta k i | k <- [1..n]] |
77 | 77 | ||
78 | ident n = fromRows [ e i n | i <- [1..n]] | 78 | diagl = diag.fromList |
79 | 79 | ||
80 | diag l = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] | 80 | ident n = diag (constant n 1) |
81 | where c = length l | ||
82 | 81 | ||
83 | tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v} | 82 | tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v} |
84 | tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m} | 83 | tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m} |
85 | 84 | ||
86 | td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diag [1..4] :: Tensor Double | 85 | td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diagl [1..4] :: Tensor Double |
87 | 86 | ||
88 | tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double | 87 | tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double |
89 | tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double | 88 | tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double |
@@ -114,3 +113,4 @@ names t = sort $ map (snd.snd) (dims t) | |||
114 | normal t = tridx (names t) t | 113 | normal t = tridx (names t) t |
115 | 114 | ||
116 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] | 115 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] |
116 | |||