diff options
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 249 |
1 files changed, 159 insertions, 90 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 4430ebc..34132d8 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -8,7 +8,7 @@ | |||
8 | -- Stability : provisional | 8 | -- Stability : provisional |
9 | -- Portability : portable (uses FFI) | 9 | -- Portability : portable (uses FFI) |
10 | -- | 10 | -- |
11 | -- basic tensor operations | 11 | -- support for basic tensor operations |
12 | -- | 12 | -- |
13 | ----------------------------------------------------------------------------- | 13 | ----------------------------------------------------------------------------- |
14 | 14 | ||
@@ -26,28 +26,83 @@ type IdxName = String | |||
26 | 26 | ||
27 | data IdxDesc = IdxDesc { idxDim :: Int, | 27 | data IdxDesc = IdxDesc { idxDim :: Int, |
28 | idxType :: IdxType, | 28 | idxType :: IdxType, |
29 | idxName :: IdxName } deriving Show | 29 | idxName :: IdxName } deriving Eq |
30 | |||
31 | instance Show IdxDesc where | ||
32 | show (IdxDesc n t name) = name ++ sym t ++"["++show n++"]" | ||
33 | where sym Covariant = "_" | ||
34 | sym Contravariant = "^" | ||
35 | |||
30 | 36 | ||
31 | data Tensor t = T { dims :: [IdxDesc] | 37 | data Tensor t = T { dims :: [IdxDesc] |
32 | , ten :: Vector t | 38 | , ten :: Vector t |
33 | } | 39 | } |
34 | 40 | ||
35 | -- | tensor rank (number of indices) | 41 | -- | returns the coordinates of a tensor in row - major order |
36 | rank :: Tensor t -> Int | 42 | coords :: Tensor t -> Vector t |
37 | rank = length . dims | 43 | coords = ten |
38 | 44 | ||
39 | instance (Show a,Storable a) => Show (Tensor a) where | 45 | instance (Show a, Field a) => Show (Tensor a) where |
40 | show T {dims = [], ten = t} = "scalar "++show (t `at` 0) | 46 | show T {dims = [], ten = t} = "scalar "++show (t `at` 0) |
41 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) | 47 | show t = "("++shdims (dims t) ++") "++ showdt t |
48 | |||
49 | asMatrix t = reshape (idxDim $ dims t!!1) (ten t) | ||
50 | |||
51 | showdt t | rank t == 1 = show (toList (ten t)) | ||
52 | | rank t == 2 = ('\n':) . dsp . map (map show) . toLists $ asMatrix $ t | ||
53 | | otherwise = concatMap showdt $ parts t (head (names t)) | ||
42 | 54 | ||
43 | -- | a nice description of the tensor structure | 55 | -- | a nice description of the tensor structure |
44 | shdims :: [IdxDesc] -> String | 56 | shdims :: [IdxDesc] -> String |
45 | shdims [] = "" | 57 | shdims [] = "" |
46 | shdims [d] = shdim d | 58 | shdims [d] = show d |
47 | shdims (d:ds) = shdim d ++ "><"++ shdims ds | 59 | shdims (d:ds) = show d ++ "><"++ shdims ds |
48 | shdim (IdxDesc n t name) = name ++ sym t ++"["++show n++"]" | 60 | |
49 | where sym Covariant = "_" | 61 | -- | tensor rank (number of indices) |
50 | sym Contravariant = "^" | 62 | rank :: Tensor t -> Int |
63 | rank = length . dims | ||
64 | |||
65 | names :: Tensor t -> [IdxName] | ||
66 | names t = map idxName (dims t) | ||
67 | |||
68 | -- | number of contravariant and covariant indices | ||
69 | structure :: Tensor t -> (Int,Int) | ||
70 | structure t = (rank t - n, n) where | ||
71 | n = length $ filter isCov (dims t) | ||
72 | isCov d = idxType d == Covariant | ||
73 | |||
74 | -- | creates a rank-zero tensor from a scalar | ||
75 | scalar :: Storable t => t -> Tensor t | ||
76 | scalar x = T [] (fromList [x]) | ||
77 | |||
78 | -- | Creates a tensor from a signed list of dimensions (positive = contravariant, negative = covariant) and a Vector containing the coordinates in row major order. | ||
79 | tensor :: [Int] -> Vector a -> Tensor a | ||
80 | tensor dssig vec = T d v `withIdx` seqind where | ||
81 | n = product (map abs dssig) | ||
82 | v = if dim vec == n then vec else error "wrong arguments for tensor" | ||
83 | d = map cr dssig | ||
84 | cr n | n > 0 = IdxDesc {idxName = "", idxDim = n, idxType = Contravariant} | ||
85 | | n < 0 = IdxDesc {idxName = "", idxDim = -n, idxType = Covariant } | ||
86 | |||
87 | |||
88 | tensorFromVector :: IdxType -> Vector t -> Tensor t | ||
89 | tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v} | ||
90 | |||
91 | tensorFromMatrix :: IdxType -> IdxType -> Matrix t -> Tensor t | ||
92 | tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"] | ||
93 | , ten = cdat m} | ||
94 | |||
95 | |||
96 | liftTensor :: (Vector a -> Vector b) -> Tensor a -> Tensor b | ||
97 | liftTensor f (T d v) = T d (f v) | ||
98 | |||
99 | liftTensor2 :: (Vector a -> Vector b -> Vector c) -> Tensor a -> Tensor b -> Tensor c | ||
100 | liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2) | ||
101 | | otherwise = error "liftTensor2 with incompatible tensors" | ||
102 | where compat a b = length a == length b | ||
103 | |||
104 | |||
105 | |||
51 | 106 | ||
52 | -- | express the tensor as a matrix with the given index in columns | 107 | -- | express the tensor as a matrix with the given index in columns |
53 | findIdx :: (Field t) => IdxName -> Tensor t | 108 | findIdx :: (Field t) => IdxName -> Tensor t |
@@ -65,6 +120,32 @@ putFirstIdx name t = (nd,m') | |||
65 | nd = d2++d1 | 120 | nd = d2++d1 |
66 | c = dim (ten t) `div` (idxDim $ head d2) | 121 | c = dim (ten t) `div` (idxDim $ head d2) |
67 | 122 | ||
123 | |||
124 | -- | renames all the indices in the current order (repeated indices may get different names) | ||
125 | withIdx :: Tensor t -> [IdxName] -> Tensor t | ||
126 | withIdx (T d v) l = T d' v | ||
127 | where d' = zipWith f d l | ||
128 | f i n = i {idxName=n} | ||
129 | |||
130 | |||
131 | -- | raises or lowers all the indices of a tensor (with euclidean metric) | ||
132 | raise :: Tensor t -> Tensor t | ||
133 | raise (T d v) = T (map raise' d) v | ||
134 | where raise' idx@IdxDesc {idxType = Covariant } = idx {idxType = Contravariant} | ||
135 | raise' idx@IdxDesc {idxType = Contravariant } = idx {idxType = Covariant} | ||
136 | |||
137 | |||
138 | -- | index transposition to a desired order. You can specify only a subset of the indices, which will be moved to the front of indices list | ||
139 | tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t | ||
140 | tridx [] t = t | ||
141 | tridx (name:rest) t = T (d:ds) (join ts) where | ||
142 | ((_,d:_),_) = findIdx name t | ||
143 | ps = map (tridx rest) (parts t name) | ||
144 | ts = map ten ps | ||
145 | ds = dims (head ps) | ||
146 | |||
147 | |||
148 | |||
68 | -- | extracts a given part of a tensor | 149 | -- | extracts a given part of a tensor |
69 | part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t | 150 | part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t |
70 | part t (name,k) = if k<0 || k>=l | 151 | part t (name,k) = if k<0 || k>=l |
@@ -80,6 +161,17 @@ parts t name = map f (toRows m) | |||
80 | l = idxDim d | 161 | l = idxDim d |
81 | f t = T {dims=ds, ten=t} | 162 | f t = T {dims=ds, ten=t} |
82 | 163 | ||
164 | |||
165 | compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool | ||
166 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where | ||
167 | d1 = head $ snd $ fst $ findIdx n1 t1 | ||
168 | d2 = head $ snd $ fst $ findIdx n2 t2 | ||
169 | compatIdxAux IdxDesc {idxDim = n1, idxType = t1} | ||
170 | IdxDesc {idxDim = n2, idxType = t2} | ||
171 | = t1 /= t2 && n1 == n2 | ||
172 | |||
173 | |||
174 | |||
83 | -- | tensor product without without any contractions | 175 | -- | tensor product without without any contractions |
84 | rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t | 176 | rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t |
85 | rawProduct (T d1 v1) (T d2 v2) = T (d1++d2) (outer' v1 v2) | 177 | rawProduct (T d1 v1) (T d2 v2) = T (d1++d2) (outer' v1 v2) |
@@ -98,7 +190,7 @@ contraction2 t1 n1 t2 n2 = | |||
98 | contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t | 190 | contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t |
99 | contraction1 t name1 name2 = | 191 | contraction1 t name1 name2 = |
100 | if compatIdx t name1 t name2 | 192 | if compatIdx t name1 t name2 |
101 | then addT y | 193 | then sumT y |
102 | else error $ "wrong contraction1: "++(shdims$dims$t)++" "++name1++" "++name2 | 194 | else error $ "wrong contraction1: "++(shdims$dims$t)++" "++name1++" "++name2 |
103 | where d = dims (head y) | 195 | where d = dims (head y) |
104 | x = (map (flip parts name2) (parts t name1)) | 196 | x = (map (flip parts name2) (parts t name1)) |
@@ -120,14 +212,17 @@ contraction2' t1 n1 t2 n2 = | |||
120 | else error "wrong contraction'" | 212 | else error "wrong contraction'" |
121 | 213 | ||
122 | -- | applies a sequence of contractions | 214 | -- | applies a sequence of contractions |
215 | contractions :: (Field t, Num t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t | ||
123 | contractions t pairs = foldl' contract1b t pairs | 216 | contractions t pairs = foldl' contract1b t pairs |
124 | where contract1b t (n1,n2) = contraction1 t n1 n2 | 217 | where contract1b t (n1,n2) = contraction1 t n1 n2 |
125 | 218 | ||
126 | -- | applies a sequence of contractions of same index | 219 | -- | applies a sequence of contractions of same index |
220 | contractionsC :: (Field t, Num t) => Tensor t -> [IdxName] -> Tensor t | ||
127 | contractionsC t is = foldl' contraction1c t is | 221 | contractionsC t is = foldl' contraction1c t is |
128 | 222 | ||
129 | 223 | ||
130 | -- | applies a contraction on the first indices of the tensors | 224 | -- | applies a contraction on the first indices of the tensors |
225 | contractionF :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t | ||
131 | contractionF t1 t2 = contraction2 t1 n1 t2 n2 | 226 | contractionF t1 t2 = contraction2 t1 n1 t2 n2 |
132 | where n1 = fn t1 | 227 | where n1 = fn t1 |
133 | n2 = fn t2 | 228 | n2 = fn t2 |
@@ -138,45 +233,42 @@ possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] | |||
138 | possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] | 233 | possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] |
139 | 234 | ||
140 | 235 | ||
141 | liftTensor f (T d v) = T d (f v) | ||
142 | |||
143 | liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2) | ||
144 | | otherwise = error "liftTensor2 with incompatible tensors" | ||
145 | where compat a b = length a == length b | ||
146 | 236 | ||
237 | desiredContractions2 :: Tensor t -> Tensor t1 -> [(IdxName, IdxName)] | ||
238 | desiredContractions2 t1 t2 = [ (n1,n2) | n1 <- names t1, n2 <- names t2, n1==n2] | ||
147 | 239 | ||
148 | a |+| b = liftTensor2 add a b | 240 | desiredContractions1 :: Tensor t -> [IdxName] |
149 | addT l = foldl1' (|+|) l | 241 | desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2] |
242 | where x = zip [0..] (names t) | ||
150 | 243 | ||
151 | -- | index transposition to a desired order. You can specify only a subset of the indices, which will be moved to the front of indices list | 244 | -- | tensor product with the convention that repeated indices are contracted. |
152 | tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t | 245 | mulT :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t |
153 | tridx [] t = t | 246 | mulT t1 t2 = r where |
154 | tridx (name:rest) t = T (d:ds) (join ts) where | 247 | t1r = contractionsC t1 (desiredContractions1 t1) |
155 | ((_,d:_),_) = findIdx name t | 248 | t2r = contractionsC t2 (desiredContractions1 t2) |
156 | ps = map (tridx rest) (parts t name) | 249 | cs = desiredContractions2 t1r t2r |
157 | ts = map ten ps | 250 | r = case cs of |
158 | ds = dims (head ps) | 251 | [] -> rawProduct t1r t2r |
252 | (n1,n2):as -> contractionsC (contraction2 t1r n1 t2r n2) (map fst as) | ||
159 | 253 | ||
160 | compatIdxAux :: IdxDesc -> IdxDesc -> Bool | 254 | ----------------------------------------------------------------- |
161 | compatIdxAux IdxDesc {idxDim = n1, idxType = t1} | ||
162 | IdxDesc {idxDim = n2, idxType = t2} | ||
163 | = t1 /= t2 && n1 == n2 | ||
164 | 255 | ||
165 | compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool | 256 | -- | tensor addition (for tensors with the same structure) |
166 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where | 257 | addT :: (Num a, Field a) => Tensor a -> Tensor a -> Tensor a |
167 | d1 = head $ snd $ fst $ findIdx n1 t1 | 258 | addT a b = liftTensor2 add a b |
168 | d2 = head $ snd $ fst $ findIdx n2 t2 | ||
169 | 259 | ||
170 | names :: Tensor t -> [IdxName] | 260 | sumT :: (Field a, Num a) => [Tensor a] -> Tensor a |
171 | names t = map idxName (dims t) | 261 | sumT l = foldl1' addT l |
172 | 262 | ||
263 | ----------------------------------------------------------------- | ||
173 | 264 | ||
174 | -- sent to Haskell-Cafe by Sebastian Sylvan | 265 | -- sent to Haskell-Cafe by Sebastian Sylvan |
175 | perms :: [t] -> [[t]] | 266 | perms :: [t] -> [[t]] |
176 | perms [x] = [[x]] | 267 | perms [x] = [[x]] |
177 | perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] | 268 | perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] |
178 | selections [] = [] | 269 | where |
179 | selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] | 270 | selections [] = [] |
271 | selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] | ||
180 | 272 | ||
181 | interchanges :: (Ord a) => [a] -> Int | 273 | interchanges :: (Ord a) => [a] -> Int |
182 | interchanges ls = sum (map (count ls) ls) | 274 | interchanges ls = sum (map (count ls) ls) |
@@ -188,58 +280,59 @@ signature l | length (nub l) < length l = 0 | |||
188 | | even (interchanges l) = 1 | 280 | | even (interchanges l) = 1 |
189 | | otherwise = -1 | 281 | | otherwise = -1 |
190 | 282 | ||
191 | scalar x = T [] (fromList [x]) | ||
192 | tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v} | ||
193 | tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"] | ||
194 | , ten = cdat m} | ||
195 | 283 | ||
196 | tvector v = tensorFromVector Contravariant v | 284 | sym :: (Field t, Num t) => Tensor t -> Tensor t |
197 | tcovector v = tensorFromVector Covariant v | 285 | sym t = T (dims t) (ten (sym' (withIdx t seqind))) |
286 | where sym' t = sumT $ map (flip tridx t) (perms (names t)) | ||
287 | where nms = map idxName . dims | ||
198 | 288 | ||
289 | antisym :: (Field t, Num t) => Tensor t -> Tensor t | ||
199 | antisym t = T (dims t) (ten (antisym' (withIdx t seqind))) | 290 | antisym t = T (dims t) (ten (antisym' (withIdx t seqind))) |
200 | where antisym' t = addT $ map (scsig . flip tridx t) (perms (names t)) | 291 | where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t)) |
201 | scsig t = scalar (signature (nms t)) `rawProduct` t | 292 | scsig t = scalar (signature (nms t)) `rawProduct` t |
202 | where nms = map idxName . dims | 293 | where nms = map idxName . dims |
203 | 294 | ||
204 | norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) | 295 | -- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product). |
205 | antinorper t = rawProduct t (scalar (fromIntegral $ fact (rank t))) | 296 | wedge :: (Field t, Fractional t) => Tensor t -> Tensor t -> Tensor t |
206 | |||
207 | wedge a b = antisym (rawProduct (norper a) (norper b)) | 297 | wedge a b = antisym (rawProduct (norper a) (norper b)) |
298 | where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) | ||
208 | 299 | ||
209 | normAT t = sqrt $ innerAT t t | 300 | -- antinorper t = rawProduct t (scalar (fromIntegral $ fact (rank t))) |
210 | 301 | ||
302 | -- | The euclidean inner product of two completely antisymmetric tensors | ||
303 | innerAT :: (Fractional t, Field t) => Tensor t -> Tensor t -> t | ||
211 | innerAT t1 t2 = dot (ten t1) (ten t2) / fromIntegral (fact $ rank t1) | 304 | innerAT t1 t2 = dot (ten t1) (ten t2) / fromIntegral (fact $ rank t1) |
212 | 305 | ||
306 | fact :: (Num t, Enum t) => t -> t | ||
213 | fact n = product [1..n] | 307 | fact n = product [1..n] |
214 | 308 | ||
309 | seqind' :: [[String]] | ||
215 | seqind' = map return seqind | 310 | seqind' = map return seqind |
311 | |||
312 | seqind :: [String] | ||
216 | seqind = map show [1..] | 313 | seqind = map show [1..] |
217 | 314 | ||
315 | -- | completely antisymmetric covariant tensor of dimension n | ||
316 | leviCivita :: (Field t, Num t) => Int -> Tensor t | ||
218 | leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind' | 317 | leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind' |
219 | where auxbase = map tcovector (toRows (ident n)) | 318 | where auxbase = map tc (toRows (ident n)) |
319 | tc = tensorFromVector Covariant | ||
220 | 320 | ||
221 | -- | obtains de dual of the exterior product of a list of X? | 321 | -- | contraction of leviCivita with a list of vectors (and raise with euclidean metric) |
222 | dualV vs = foldl' contractionF (leviCivita n) vs | 322 | innerLevi :: (Num t, Field t) => [Tensor t] -> Tensor t |
323 | innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs | ||
223 | where n = idxDim . head . dims . head $ vs | 324 | where n = idxDim . head . dims . head $ vs |
224 | 325 | ||
225 | -- | raises or lowers all the indices of a tensor (with euclidean metric) | ||
226 | raise (T d v) = T (map raise' d) v | ||
227 | where raise' idx@IdxDesc {idxType = Covariant } = idx {idxType = Contravariant} | ||
228 | raise' idx@IdxDesc {idxType = Contravariant } = idx {idxType = Covariant} | ||
229 | -- | raises or lowers all the indices of a tensor with a given an (inverse) metric | ||
230 | raiseWith = undefined | ||
231 | 326 | ||
232 | dualg f t = f (leviCivita n) `okContract` withIdx t seqind `rawProduct` x | 327 | -- | obtains the dual of a multivector (with euclidean metric) |
328 | dual :: (Field t, Fractional t) => Tensor t -> Tensor t | ||
329 | dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x | ||
233 | where n = idxDim . head . dims $ t | 330 | where n = idxDim . head . dims $ t |
234 | x = scalar (recip $ fromIntegral $ fact (rank t)) | 331 | x = scalar (recip $ fromIntegral $ fact (rank t)) |
235 | 332 | ||
236 | -- | obtains the dual of a multivector | ||
237 | dual t = dualg id t | ||
238 | |||
239 | -- | obtains the dual of a multicovector (with euclidean metric) | ||
240 | codual t = dualg raise t | ||
241 | 333 | ||
242 | -- | shows only the relevant components of an antisymmetric tensor | 334 | -- | shows only the relevant components of an antisymmetric tensor |
335 | niceAS :: (Field t, Fractional t) => Tensor t -> [(t, [Int])] | ||
243 | niceAS t = filter ((/=0.0).fst) $ zip vals base | 336 | niceAS t = filter ((/=0.0).fst) $ zip vals base |
244 | where vals = map ((`at` 0).ten.foldl' partF t) (map (map pred) base) | 337 | where vals = map ((`at` 0).ten.foldl' partF t) (map (map pred) base) |
245 | base = asBase r n | 338 | base = asBase r n |
@@ -247,27 +340,3 @@ niceAS t = filter ((/=0.0).fst) $ zip vals base | |||
247 | n = idxDim . head . dims $ t | 340 | n = idxDim . head . dims $ t |
248 | partF t i = part t (name,i) where name = idxName . head . dims $ t | 341 | partF t i = part t (name,i) where name = idxName . head . dims $ t |
249 | asBase r n = filter (\x-> (x==nub x && x==sort x)) $ sequence $ replicate r [1..n] | 342 | asBase r n = filter (\x-> (x==nub x && x==sort x)) $ sequence $ replicate r [1..n] |
250 | |||
251 | -- | renames specified indices of a tensor (repeated indices get the same name) | ||
252 | idxRename (T d v) l = T (map (ir l) d) v | ||
253 | where ir l i = case lookup (idxName i) l of | ||
254 | Nothing -> i | ||
255 | Just r -> i {idxName = r} | ||
256 | |||
257 | -- | renames all the indices in the current order (repeated indices may get different names) | ||
258 | withIdx (T d v) l = T d' v | ||
259 | where d' = zipWith f d l | ||
260 | f i n = i {idxName=n} | ||
261 | |||
262 | desiredContractions2 t1 t2 = [ (n1,n2) | n1 <- names t1, n2 <- names t2, n1==n2] | ||
263 | |||
264 | desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2] | ||
265 | where x = zip [0..] (names t) | ||
266 | |||
267 | okContract t1 t2 = r where | ||
268 | t1r = contractionsC t1 (desiredContractions1 t1) | ||
269 | t2r = contractionsC t2 (desiredContractions1 t2) | ||
270 | cs = desiredContractions2 t1r t2r | ||
271 | r = case cs of | ||
272 | [] -> rawProduct t1r t2r | ||
273 | (n1,n2):as -> contractionsC (contraction2 t1r n1 t2r n2) (map fst as) | ||