diff options
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 32 |
1 files changed, 22 insertions, 10 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index c4faf49..8296935 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -18,7 +18,8 @@ import Data.Packed.Internal.Common | |||
18 | import Data.Packed.Internal.Vector | 18 | import Data.Packed.Internal.Vector |
19 | import Data.Packed.Internal.Matrix | 19 | import Data.Packed.Internal.Matrix |
20 | import Foreign.Storable | 20 | import Foreign.Storable |
21 | import Data.List(sort,elemIndex,nub) | 21 | import Data.List(sort,elemIndex,nub,foldl1') |
22 | import GSL.Vector | ||
22 | 23 | ||
23 | data IdxType = Covariant | Contravariant deriving (Show,Eq) | 24 | data IdxType = Covariant | Contravariant deriving (Show,Eq) |
24 | 25 | ||
@@ -79,10 +80,10 @@ concatRename l1 l2 = l1 ++ map ren l2 where | |||
79 | ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx | 80 | ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx |
80 | fs = map idxName l1 | 81 | fs = map idxName l1 |
81 | 82 | ||
82 | prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t | 83 | --prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t |
83 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) | 84 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) |
84 | 85 | ||
85 | contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | 86 | --contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t |
86 | contraction t1 n1 t2 n2 = | 87 | contraction t1 n1 t2 n2 = |
87 | if compatIdx t1 n1 t2 n2 | 88 | if compatIdx t1 n1 t2 n2 |
88 | then T (concatRename (tail d1) (tail d2)) (cdat m) | 89 | then T (concatRename (tail d1) (tail d2)) (cdat m) |
@@ -91,16 +92,27 @@ contraction t1 n1 t2 n2 = | |||
91 | (d2,m2) = putFirstIdx n2 t2 | 92 | (d2,m2) = putFirstIdx n2 t2 |
92 | m = multiply RowMajor (trans m1) m2 | 93 | m = multiply RowMajor (trans m1) m2 |
93 | 94 | ||
94 | sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] | 95 | --sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] |
95 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | 96 | --sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) |
97 | --addT ts = T (dims (head ts)) (fromList $ sumT ts) | ||
96 | 98 | ||
97 | contract1 :: (Num t, Enum t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t | 99 | liftTensor f (T d v) = T d (f v) |
98 | contract1 t name1 name2 = T d $ fromList $ sumT y | 100 | |
101 | liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2) | ||
102 | | otherwise = error "liftTensor2 with incompatible tensors" | ||
103 | where compat a b = length a == length b | ||
104 | |||
105 | |||
106 | a |+| b = liftTensor2 add a b | ||
107 | addT l = foldl1' (|+|) l | ||
108 | |||
109 | --contract1 :: (Num t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t | ||
110 | contract1 t name1 name2 = addT y | ||
99 | where d = dims (head y) | 111 | where d = dims (head y) |
100 | x = (map (flip parts name2) (parts t name1)) | 112 | x = (map (flip parts name2) (parts t name1)) |
101 | y = map head $ zipWith drop [0..] x | 113 | y = map head $ zipWith drop [0..] x |
102 | 114 | ||
103 | contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | 115 | --contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t |
104 | contraction' t1 n1 t2 n2 = | 116 | contraction' t1 n1 t2 n2 = |
105 | if compatIdx t1 n1 t2 n2 | 117 | if compatIdx t1 n1 t2 n2 |
106 | then contract1 (prod t1 t2) n1 (n2++"'") | 118 | then contract1 (prod t1 t2) n1 (n2++"'") |
@@ -130,8 +142,8 @@ names t = sort $ map idxName (dims t) | |||
130 | normal :: (Field t) => Tensor t -> Tensor t | 142 | normal :: (Field t) => Tensor t -> Tensor t |
131 | normal t = tridx (names t) t | 143 | normal t = tridx (names t) t |
132 | 144 | ||
133 | contractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] | 145 | possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] |
134 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] | 146 | possibleContractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] |
135 | 147 | ||
136 | -- sent to Haskell-Cafe by Sebastian Sylvan | 148 | -- sent to Haskell-Cafe by Sebastian Sylvan |
137 | perms :: [t] -> [[t]] | 149 | perms :: [t] -> [[t]] |