summaryrefslogtreecommitdiff
path: root/examples/pru.hs
diff options
context:
space:
mode:
Diffstat (limited to 'examples/pru.hs')
-rw-r--r--examples/pru.hs10
1 files changed, 5 insertions, 5 deletions
diff --git a/examples/pru.hs b/examples/pru.hs
index 8b25780..10789d2 100644
--- a/examples/pru.hs
+++ b/examples/pru.hs
@@ -38,7 +38,7 @@ bf = (3>|<4) [7,11,15,8,12,16,9,13,17,10,14,18::Double]
38 38
39a |=| b = rows a == rows b && 39a |=| b = rows a == rows b &&
40 cols a == cols b && 40 cols a == cols b &&
41 toList (dat a) == toList (dat b) 41 toList (cdat a) == toList (cdat b)
42 42
43mulC a b = multiply RowMajor a b 43mulC a b = multiply RowMajor a b
44mulF a b = multiply ColumnMajor a b 44mulF a b = multiply ColumnMajor a b
@@ -75,15 +75,14 @@ delta i j | i==j = 1
75 75
76e i n = fromList [ delta k i | k <- [1..n]] 76e i n = fromList [ delta k i | k <- [1..n]]
77 77
78ident n = fromRows [ e i n | i <- [1..n]] 78diagl = diag.fromList
79 79
80diag l = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] 80ident n = diag (constant n 1)
81 where c = length l
82 81
83tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v} 82tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v}
84tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m} 83tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m}
85 84
86td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diag [1..4] :: Tensor Double 85td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diagl [1..4] :: Tensor Double
87 86
88tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double 87tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double
89tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double 88tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double
@@ -114,3 +113,4 @@ names t = sort $ map (snd.snd) (dims t)
114normal t = tridx (names t) t 113normal t = tridx (names t) t
115 114
116contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] 115contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
116