diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-29 08:09:46 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-29 08:09:46 +0000 |
commit | e36c04dca536caa42b41a4280d3f21375219970d (patch) | |
tree | 18aef1b4607ee19c88cd740dbc2e7a65d783bd2f /lib/Data/Packed/Internal/Tensor.hs | |
parent | a749785e839d14fadc47ab4c6e94afdd167bdd21 (diff) |
tensor refactoring
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 197 |
1 files changed, 160 insertions, 37 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 8296935..dedbb9c 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -14,12 +14,11 @@ | |||
14 | 14 | ||
15 | module Data.Packed.Internal.Tensor where | 15 | module Data.Packed.Internal.Tensor where |
16 | 16 | ||
17 | import Data.Packed.Internal.Common | 17 | import Data.Packed.Internal |
18 | import Data.Packed.Internal.Vector | ||
19 | import Data.Packed.Internal.Matrix | ||
20 | import Foreign.Storable | 18 | import Foreign.Storable |
21 | import Data.List(sort,elemIndex,nub,foldl1') | 19 | import Data.List(sort,elemIndex,nub,foldl1',foldl') |
22 | import GSL.Vector | 20 | import GSL.Vector |
21 | import Data.Packed.Matrix | ||
23 | 22 | ||
24 | data IdxType = Covariant | Contravariant deriving (Show,Eq) | 23 | data IdxType = Covariant | Contravariant deriving (Show,Eq) |
25 | 24 | ||
@@ -27,12 +26,13 @@ type IdxName = String | |||
27 | 26 | ||
28 | data IdxDesc = IdxDesc { idxDim :: Int, | 27 | data IdxDesc = IdxDesc { idxDim :: Int, |
29 | idxType :: IdxType, | 28 | idxType :: IdxType, |
30 | idxName :: IdxName } | 29 | idxName :: IdxName } deriving Show |
31 | 30 | ||
32 | data Tensor t = T { dims :: [IdxDesc] | 31 | data Tensor t = T { dims :: [IdxDesc] |
33 | , ten :: Vector t | 32 | , ten :: Vector t |
34 | } | 33 | } |
35 | 34 | ||
35 | -- | tensor rank (number of indices) | ||
36 | rank :: Tensor t -> Int | 36 | rank :: Tensor t -> Int |
37 | rank = length . dims | 37 | rank = length . dims |
38 | 38 | ||
@@ -40,14 +40,16 @@ instance (Show a,Storable a) => Show (Tensor a) where | |||
40 | show T {dims = [], ten = t} = "scalar "++show (t `at` 0) | 40 | show T {dims = [], ten = t} = "scalar "++show (t `at` 0) |
41 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) | 41 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) |
42 | 42 | ||
43 | 43 | -- | a nice description of the tensor structure | |
44 | shdims :: [IdxDesc] -> String | 44 | shdims :: [IdxDesc] -> String |
45 | shdims [IdxDesc n t name] = name ++ sym t ++"["++show n++"]" | 45 | shdims [] = "" |
46 | shdims [d] = shdim d | ||
47 | shdims (d:ds) = shdim d ++ "><"++ shdims ds | ||
48 | shdim (IdxDesc n t name) = name ++ sym t ++"["++show n++"]" | ||
46 | where sym Covariant = "_" | 49 | where sym Covariant = "_" |
47 | sym Contravariant = "^" | 50 | sym Contravariant = "^" |
48 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds | ||
49 | |||
50 | 51 | ||
52 | -- | express the tensor as a matrix with the given index in columns | ||
51 | findIdx :: (Field t) => IdxName -> Tensor t | 53 | findIdx :: (Field t) => IdxName -> Tensor t |
52 | -> (([IdxDesc], [IdxDesc]), Matrix t) | 54 | -> (([IdxDesc], [IdxDesc]), Matrix t) |
53 | findIdx name t = ((d1,d2),m) where | 55 | findIdx name t = ((d1,d2),m) where |
@@ -55,6 +57,7 @@ findIdx name t = ((d1,d2),m) where | |||
55 | c = product (map idxDim d2) | 57 | c = product (map idxDim d2) |
56 | m = matrixFromVector RowMajor c (ten t) | 58 | m = matrixFromVector RowMajor c (ten t) |
57 | 59 | ||
60 | -- | express the tensor as a matrix with the given index in rows | ||
58 | putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t) | 61 | putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t) |
59 | putFirstIdx name t = (nd,m') | 62 | putFirstIdx name t = (nd,m') |
60 | where ((d1,d2),m) = findIdx name t | 63 | where ((d1,d2),m) = findIdx name t |
@@ -62,6 +65,7 @@ putFirstIdx name t = (nd,m') | |||
62 | nd = d2++d1 | 65 | nd = d2++d1 |
63 | c = dim (ten t) `div` (idxDim $ head d2) | 66 | c = dim (ten t) `div` (idxDim $ head d2) |
64 | 67 | ||
68 | -- | extracts a given part of a tensor | ||
65 | part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t | 69 | part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t |
66 | part t (name,k) = if k<0 || k>=l | 70 | part t (name,k) = if k<0 || k>=l |
67 | then error $ "part "++show (name,k)++" out of range" -- in "++show t | 71 | then error $ "part "++show (name,k)++" out of range" -- in "++show t |
@@ -69,32 +73,70 @@ part t (name,k) = if k<0 || k>=l | |||
69 | where (d:ds,m) = putFirstIdx name t | 73 | where (d:ds,m) = putFirstIdx name t |
70 | l = idxDim d | 74 | l = idxDim d |
71 | 75 | ||
76 | -- | creates a list with all parts of a tensor along a given index | ||
72 | parts :: (Field t) => Tensor t -> IdxName -> [Tensor t] | 77 | parts :: (Field t) => Tensor t -> IdxName -> [Tensor t] |
73 | parts t name = map f (toRows m) | 78 | parts t name = map f (toRows m) |
74 | where (d:ds,m) = putFirstIdx name t | 79 | where (d:ds,m) = putFirstIdx name t |
75 | l = idxDim d | 80 | l = idxDim d |
76 | f t = T {dims=ds, ten=t} | 81 | f t = T {dims=ds, ten=t} |
77 | 82 | ||
78 | concatRename :: [IdxDesc] -> [IdxDesc] -> [IdxDesc] | 83 | -- | tensor product without without any contractions |
79 | concatRename l1 l2 = l1 ++ map ren l2 where | 84 | rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t |
80 | ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx | 85 | rawProduct (T d1 v1) (T d2 v2) = T (d1++d2) (outer' v1 v2) |
81 | fs = map idxName l1 | ||
82 | 86 | ||
83 | --prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t | 87 | -- | contraction of the product of two tensors |
84 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) | 88 | contraction2 :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t |
85 | 89 | contraction2 t1 n1 t2 n2 = | |
86 | --contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | ||
87 | contraction t1 n1 t2 n2 = | ||
88 | if compatIdx t1 n1 t2 n2 | 90 | if compatIdx t1 n1 t2 n2 |
89 | then T (concatRename (tail d1) (tail d2)) (cdat m) | 91 | then T (tail d1 ++ tail d2) (cdat m) |
90 | else error "wrong contraction'" | 92 | else error "wrong contraction2" |
91 | where (d1,m1) = putFirstIdx n1 t1 | 93 | where (d1,m1) = putFirstIdx n1 t1 |
92 | (d2,m2) = putFirstIdx n2 t2 | 94 | (d2,m2) = putFirstIdx n2 t2 |
93 | m = multiply RowMajor (trans m1) m2 | 95 | m = multiply RowMajor (trans m1) m2 |
94 | 96 | ||
95 | --sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] | 97 | -- | contraction of a tensor along two given indices |
96 | --sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | 98 | contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t |
97 | --addT ts = T (dims (head ts)) (fromList $ sumT ts) | 99 | contraction1 t name1 name2 = |
100 | if compatIdx t name1 t name2 | ||
101 | then addT y | ||
102 | else error $ "wrong contraction1: "++(shdims$dims$t)++" "++name1++" "++name2 | ||
103 | where d = dims (head y) | ||
104 | x = (map (flip parts name2) (parts t name1)) | ||
105 | y = map head $ zipWith drop [0..] x | ||
106 | |||
107 | -- | contraction of a tensor along a repeated index | ||
108 | contraction1c :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t | ||
109 | contraction1c t n = contraction1 renamed n' n | ||
110 | where n' = n++"'" -- hmmm | ||
111 | renamed = withIdx t auxnames | ||
112 | auxnames = h ++ (n':r) | ||
113 | (h,_:r) = break (==n) (map idxName (dims t)) | ||
114 | |||
115 | -- | alternative and inefficient version of contraction2 | ||
116 | contraction2' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | ||
117 | contraction2' t1 n1 t2 n2 = | ||
118 | if compatIdx t1 n1 t2 n2 | ||
119 | then contraction1 (rawProduct t1 t2) n1 n2 | ||
120 | else error "wrong contraction'" | ||
121 | |||
122 | -- | applies a sequence of contractions | ||
123 | contractions t pairs = foldl' contract1b t pairs | ||
124 | where contract1b t (n1,n2) = contraction1 t n1 n2 | ||
125 | |||
126 | -- | applies a sequence of contractions of same index | ||
127 | contractionsC t is = foldl' contraction1c t is | ||
128 | |||
129 | |||
130 | -- | applies a contraction on the first indices of the tensors | ||
131 | contractionF t1 t2 = contraction2 t1 n1 t2 n2 | ||
132 | where n1 = fn t1 | ||
133 | n2 = fn t2 | ||
134 | fn = idxName . head . dims | ||
135 | |||
136 | -- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal | ||
137 | possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] | ||
138 | possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] | ||
139 | |||
98 | 140 | ||
99 | liftTensor f (T d v) = T d (f v) | 141 | liftTensor f (T d v) = T d (f v) |
100 | 142 | ||
@@ -106,18 +148,7 @@ liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2) | |||
106 | a |+| b = liftTensor2 add a b | 148 | a |+| b = liftTensor2 add a b |
107 | addT l = foldl1' (|+|) l | 149 | addT l = foldl1' (|+|) l |
108 | 150 | ||
109 | --contract1 :: (Num t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t | 151 | -- | index transposition to a desired order. You can specify only a subset of the indices, which will be moved to the front of indices list |
110 | contract1 t name1 name2 = addT y | ||
111 | where d = dims (head y) | ||
112 | x = (map (flip parts name2) (parts t name1)) | ||
113 | y = map head $ zipWith drop [0..] x | ||
114 | |||
115 | --contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | ||
116 | contraction' t1 n1 t2 n2 = | ||
117 | if compatIdx t1 n1 t2 n2 | ||
118 | then contract1 (prod t1 t2) n1 (n2++"'") | ||
119 | else error "wrong contraction'" | ||
120 | |||
121 | tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t | 152 | tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t |
122 | tridx [] t = t | 153 | tridx [] t = t |
123 | tridx (name:rest) t = T (d:ds) (join ts) where | 154 | tridx (name:rest) t = T (d:ds) (join ts) where |
@@ -142,8 +173,6 @@ names t = sort $ map idxName (dims t) | |||
142 | normal :: (Field t) => Tensor t -> Tensor t | 173 | normal :: (Field t) => Tensor t -> Tensor t |
143 | normal t = tridx (names t) t | 174 | normal t = tridx (names t) t |
144 | 175 | ||
145 | possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] | ||
146 | possibleContractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] | ||
147 | 176 | ||
148 | -- sent to Haskell-Cafe by Sebastian Sylvan | 177 | -- sent to Haskell-Cafe by Sebastian Sylvan |
149 | perms :: [t] -> [[t]] | 178 | perms :: [t] -> [[t]] |
@@ -161,3 +190,97 @@ signature :: (Num t, Ord a) => [a] -> t | |||
161 | signature l | length (nub l) < length l = 0 | 190 | signature l | length (nub l) < length l = 0 |
162 | | even (interchanges l) = 1 | 191 | | even (interchanges l) = 1 |
163 | | otherwise = -1 | 192 | | otherwise = -1 |
193 | |||
194 | scalar x = T [] (fromList [x]) | ||
195 | tensorFromVector (tp,nm) v = T {dims = [IdxDesc (dim v) tp nm] | ||
196 | , ten = v} | ||
197 | tensorFromMatrix (tpr,nmr) (tpc,nmc) m = T {dims = [IdxDesc (rows m) tpr nmr,IdxDesc (cols m) tpc nmc] | ||
198 | , ten = cdat m} | ||
199 | |||
200 | tvector n v = tensorFromVector (Contravariant,n) v | ||
201 | tcovector n v = tensorFromVector (Covariant,n) v | ||
202 | |||
203 | |||
204 | antisym t = T (dims t) (ten (antisym' (auxrename t))) | ||
205 | where | ||
206 | scsig t = scalar (signature (nms t)) `rawProduct` t | ||
207 | where nms = map idxName . dims | ||
208 | |||
209 | antisym' t = addT $ map (scsig . flip tridx t) (perms (names t)) | ||
210 | |||
211 | auxrename (T d v) = T d' v | ||
212 | where d' = [IdxDesc n c (show (pos q)) | IdxDesc n c q <- d] | ||
213 | pos n = i where Just i = elemIndex n nms | ||
214 | nms = map idxName d | ||
215 | |||
216 | |||
217 | norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) | ||
218 | antinorper t = rawProduct t (scalar (fromIntegral $ fact (rank t))) | ||
219 | |||
220 | wedge a b = antisym (rawProduct (norper a) (norper b)) | ||
221 | |||
222 | a /\ b = wedge a b | ||
223 | |||
224 | normAT t = sqrt $ innerAT t t | ||
225 | |||
226 | innerAT t1 t2 = dot (ten t1) (ten t2) / fromIntegral (fact $ rank t1) | ||
227 | |||
228 | fact n = product [1..n] | ||
229 | |||
230 | leviCivita n = antisym $ foldl1 rawProduct $ zipWith tcovector (map show [1,2..]) (toRows (ident n)) | ||
231 | |||
232 | -- | obtains de dual of the exterior product of a list of X? | ||
233 | dualV vs = foldl' contractionF (leviCivita n) vs | ||
234 | where n = idxDim . head . dims . head $ vs | ||
235 | |||
236 | -- | raises or lowers all the indices of a tensor (with euclidean metric) | ||
237 | raise (T d v) = T (map raise' d) v | ||
238 | where raise' idx@IdxDesc {idxType = Covariant } = idx {idxType = Contravariant} | ||
239 | raise' idx@IdxDesc {idxType = Contravariant } = idx {idxType = Covariant} | ||
240 | -- | raises or lowers all the indices of a tensor with a given an (inverse) metric | ||
241 | raiseWith = undefined | ||
242 | |||
243 | -- | obtains the dual of a multivector | ||
244 | dualMV t = rawProduct (contractions lct ds) x | ||
245 | where | ||
246 | lc = leviCivita n | ||
247 | lct = rawProduct lc t | ||
248 | nms1 = map idxName (dims lc) | ||
249 | nms2 = map idxName (dims t) | ||
250 | ds = zip nms1 nms2 | ||
251 | n = idxDim . head . dims $ t | ||
252 | x = scalar (recip $ fromIntegral $ fact (rank t)) | ||
253 | |||
254 | |||
255 | -- | shows only the relevant components of an antisymmetric tensor | ||
256 | niceAS t = filter ((/=0.0).fst) $ zip vals base | ||
257 | where vals = map ((`at` 0).ten.foldl' partF t) (map (map pred) base) | ||
258 | base = asBase r n | ||
259 | r = length (dims t) | ||
260 | n = idxDim . head . dims $ t | ||
261 | partF t i = part t (name,i) where name = idxName . head . dims $ t | ||
262 | asBase r n = filter (\x-> (x==nub x && x==sort x)) $ sequence $ replicate r [1..n] | ||
263 | |||
264 | -- | renames specified indices of a tensor (repeated indices get the same name) | ||
265 | idxRename (T d v) l = T (map (ir l) d) v | ||
266 | where ir l i = case lookup (idxName i) l of | ||
267 | Nothing -> i | ||
268 | Just r -> i {idxName = r} | ||
269 | |||
270 | -- | renames all the indices in the current order (repeated indices may get different names) | ||
271 | withIdx (T d v) l = T d' v | ||
272 | where d' = zipWith f d l | ||
273 | f i n = i {idxName=n} | ||
274 | |||
275 | desiredContractions2 t1 t2 = [ (n1,n2) | n1 <- names t1, n2 <- names t2, n1==n2] | ||
276 | |||
277 | desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2] | ||
278 | where x = zip [0..] (names t) | ||
279 | |||
280 | okContract t1 t2 = r where | ||
281 | t1r = contractionsC t1 (desiredContractions1 t1) | ||
282 | t2r = contractionsC t2 (desiredContractions1 t2) | ||
283 | cs = desiredContractions2 t1r t2r | ||
284 | r = case cs of | ||
285 | [] -> rawProduct t1r t2r | ||
286 | (n1,n2):as -> contractionsC (contraction2 t1r n1 t2r n2) (map fst as) | ||