diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-06 17:40:09 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-06 17:40:09 +0000 |
commit | e7c03c1ab4de85e7a700d2eafaebd37f4607c51f (patch) | |
tree | 4512d18907d88d0390671fcde4e8886d30cd0492 /lib/Data/Packed/Internal/Tensor.hs | |
parent | a4254a0b9bfbd720efbe42b86aa50107a74d56c7 (diff) |
working on tensor contractions
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 73 |
1 files changed, 69 insertions, 4 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 960e3c5..b66d6b8 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | 1 | --{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} |
2 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
3 | -- | | 3 | -- | |
4 | -- Module : Data.Packed.Internal.Tensor | 4 | -- Module : Data.Packed.Internal.Tensor |
@@ -19,7 +19,7 @@ import Data.Packed.Internal.Vector | |||
19 | import Data.Packed.Internal.Matrix | 19 | import Data.Packed.Internal.Matrix |
20 | import Foreign.Storable | 20 | import Foreign.Storable |
21 | 21 | ||
22 | data IdxTp = Covariant | Contravariant deriving Show | 22 | data IdxTp = Covariant | Contravariant deriving (Show,Eq) |
23 | 23 | ||
24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] | 24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] |
25 | , ten :: Vector t | 25 | , ten :: Vector t |
@@ -36,5 +36,70 @@ instance (Show a,Storable a) => Show (Tensor a) where | |||
36 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) | 36 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) |
37 | 37 | ||
38 | 38 | ||
39 | shdims [(n,(t,name))] = name++"["++show n++"]" | 39 | shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]" |
40 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file | 40 | where sym Covariant = "_" |
41 | sym Contravariant = "^" | ||
42 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds | ||
43 | |||
44 | |||
45 | |||
46 | findIdx name t = ((d1,d2),m) where | ||
47 | (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) | ||
48 | c = product (map fst d2) | ||
49 | m = matrixFromVector RowMajor c (ten t) | ||
50 | |||
51 | putFirstIdx name t = (nd,m') | ||
52 | where ((d1,d2),m) = findIdx name t | ||
53 | m' = matrixFromVector RowMajor c $ cdat $ trans m | ||
54 | nd = d2++d1 | ||
55 | c = dim (ten t) `div` (fst $ head d2) | ||
56 | |||
57 | part t (name,k) = if k<0 || k>=l | ||
58 | then error $ "part "++show (name,k)++" out of range in "++show t | ||
59 | else T {dims = ds, ten = toRows m !! k} | ||
60 | where (d:ds,m) = putFirstIdx name t | ||
61 | (l,_) = d | ||
62 | |||
63 | parts t name = map f (toRows m) | ||
64 | where (d:ds,m) = putFirstIdx name t | ||
65 | (l,_) = d | ||
66 | f t = T {dims=ds, ten=t} | ||
67 | |||
68 | concatRename l1 l2 = l1 ++ map ren l2 where | ||
69 | ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s)) | ||
70 | fs = map (snd.snd) l1 | ||
71 | |||
72 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2) | ||
73 | |||
74 | contraction t1 n1 t2 n2 = | ||
75 | if compatIdx t1 n1 t2 n2 | ||
76 | then T (concatRename (tail d1) (tail d2)) (cdat m) | ||
77 | else error "wrong contraction'" | ||
78 | where (d1,m1) = putFirstIdx n1 t1 | ||
79 | (d2,m2) = putFirstIdx n2 t2 | ||
80 | m = multiply RowMajor (trans m1) m2 | ||
81 | |||
82 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | ||
83 | |||
84 | contract1 t name1 name2 = T d $ fromList $ sumT y | ||
85 | where d = dims (head y) | ||
86 | x = (map (flip parts name2) (parts t name1)) | ||
87 | y = map head $ zipWith drop [0..] x | ||
88 | |||
89 | contraction' t1 n1 t2 n2 = | ||
90 | if compatIdx t1 n1 t2 n2 | ||
91 | then contract1 (prod t1 t2) n1 (n2++"'") | ||
92 | else error "wrong contraction'" | ||
93 | |||
94 | tridx [] t = t | ||
95 | tridx (name:rest) t = T (d:ds) (join ts) where | ||
96 | ((_,d:_),_) = findIdx name t | ||
97 | ps = map (tridx rest) (parts t name) | ||
98 | ts = map ten ps | ||
99 | ds = dims (head ps) | ||
100 | |||
101 | compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2 | ||
102 | |||
103 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where | ||
104 | d1 = head $ snd $ fst $ findIdx n1 t1 | ||
105 | d2 = head $ snd $ fst $ findIdx n2 t2 | ||