summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Tensor.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs73
1 files changed, 69 insertions, 4 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 960e3c5..b66d6b8 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -1,4 +1,4 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} 1--{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3-- | 3-- |
4-- Module : Data.Packed.Internal.Tensor 4-- Module : Data.Packed.Internal.Tensor
@@ -19,7 +19,7 @@ import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix 19import Data.Packed.Internal.Matrix
20import Foreign.Storable 20import Foreign.Storable
21 21
22data IdxTp = Covariant | Contravariant deriving Show 22data IdxTp = Covariant | Contravariant deriving (Show,Eq)
23 23
24data Tensor t = T { dims :: [(Int,(IdxTp,String))] 24data Tensor t = T { dims :: [(Int,(IdxTp,String))]
25 , ten :: Vector t 25 , ten :: Vector t
@@ -36,5 +36,70 @@ instance (Show a,Storable a) => Show (Tensor a) where
36 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) 36 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t)
37 37
38 38
39shdims [(n,(t,name))] = name++"["++show n++"]" 39shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]"
40shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file 40 where sym Covariant = "_"
41 sym Contravariant = "^"
42shdims (d:ds) = shdims [d] ++ "><"++ shdims ds
43
44
45
46findIdx name t = ((d1,d2),m) where
47 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t)
48 c = product (map fst d2)
49 m = matrixFromVector RowMajor c (ten t)
50
51putFirstIdx name t = (nd,m')
52 where ((d1,d2),m) = findIdx name t
53 m' = matrixFromVector RowMajor c $ cdat $ trans m
54 nd = d2++d1
55 c = dim (ten t) `div` (fst $ head d2)
56
57part t (name,k) = if k<0 || k>=l
58 then error $ "part "++show (name,k)++" out of range in "++show t
59 else T {dims = ds, ten = toRows m !! k}
60 where (d:ds,m) = putFirstIdx name t
61 (l,_) = d
62
63parts t name = map f (toRows m)
64 where (d:ds,m) = putFirstIdx name t
65 (l,_) = d
66 f t = T {dims=ds, ten=t}
67
68concatRename l1 l2 = l1 ++ map ren l2 where
69 ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s))
70 fs = map (snd.snd) l1
71
72prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2)
73
74contraction t1 n1 t2 n2 =
75 if compatIdx t1 n1 t2 n2
76 then T (concatRename (tail d1) (tail d2)) (cdat m)
77 else error "wrong contraction'"
78 where (d1,m1) = putFirstIdx n1 t1
79 (d2,m2) = putFirstIdx n2 t2
80 m = multiply RowMajor (trans m1) m2
81
82sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
83
84contract1 t name1 name2 = T d $ fromList $ sumT y
85 where d = dims (head y)
86 x = (map (flip parts name2) (parts t name1))
87 y = map head $ zipWith drop [0..] x
88
89contraction' t1 n1 t2 n2 =
90 if compatIdx t1 n1 t2 n2
91 then contract1 (prod t1 t2) n1 (n2++"'")
92 else error "wrong contraction'"
93
94tridx [] t = t
95tridx (name:rest) t = T (d:ds) (join ts) where
96 ((_,d:_),_) = findIdx name t
97 ps = map (tridx rest) (parts t name)
98 ts = map ten ps
99 ds = dims (head ps)
100
101compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2
102
103compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where
104 d1 = head $ snd $ fst $ findIdx n1 t1
105 d2 = head $ snd $ fst $ findIdx n2 t2