summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Tensor.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-26 16:57:58 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-26 16:57:58 +0000
commita749785e839d14fadc47ab4c6e94afdd167bdd21 (patch)
tree2b715bf233aa8e82137621a251b0edf0b32cdd67 /lib/Data/Packed/Internal/Tensor.hs
parent3019948b97ba1c177b21ab103823fabe561b3ffe (diff)
tensor refactorization
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs32
1 files changed, 22 insertions, 10 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index c4faf49..8296935 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -18,7 +18,8 @@ import Data.Packed.Internal.Common
18import Data.Packed.Internal.Vector 18import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix 19import Data.Packed.Internal.Matrix
20import Foreign.Storable 20import Foreign.Storable
21import Data.List(sort,elemIndex,nub) 21import Data.List(sort,elemIndex,nub,foldl1')
22import GSL.Vector
22 23
23data IdxType = Covariant | Contravariant deriving (Show,Eq) 24data IdxType = Covariant | Contravariant deriving (Show,Eq)
24 25
@@ -79,10 +80,10 @@ concatRename l1 l2 = l1 ++ map ren l2 where
79 ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx 80 ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx
80 fs = map idxName l1 81 fs = map idxName l1
81 82
82prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t 83--prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
83prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) 84prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2)
84 85
85contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t 86--contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
86contraction t1 n1 t2 n2 = 87contraction t1 n1 t2 n2 =
87 if compatIdx t1 n1 t2 n2 88 if compatIdx t1 n1 t2 n2
88 then T (concatRename (tail d1) (tail d2)) (cdat m) 89 then T (concatRename (tail d1) (tail d2)) (cdat m)
@@ -91,16 +92,27 @@ contraction t1 n1 t2 n2 =
91 (d2,m2) = putFirstIdx n2 t2 92 (d2,m2) = putFirstIdx n2 t2
92 m = multiply RowMajor (trans m1) m2 93 m = multiply RowMajor (trans m1) m2
93 94
94sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] 95--sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t]
95sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) 96--sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
97--addT ts = T (dims (head ts)) (fromList $ sumT ts)
96 98
97contract1 :: (Num t, Enum t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t 99liftTensor f (T d v) = T d (f v)
98contract1 t name1 name2 = T d $ fromList $ sumT y 100
101liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2)
102 | otherwise = error "liftTensor2 with incompatible tensors"
103 where compat a b = length a == length b
104
105
106a |+| b = liftTensor2 add a b
107addT l = foldl1' (|+|) l
108
109--contract1 :: (Num t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t
110contract1 t name1 name2 = addT y
99 where d = dims (head y) 111 where d = dims (head y)
100 x = (map (flip parts name2) (parts t name1)) 112 x = (map (flip parts name2) (parts t name1))
101 y = map head $ zipWith drop [0..] x 113 y = map head $ zipWith drop [0..] x
102 114
103contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t 115--contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
104contraction' t1 n1 t2 n2 = 116contraction' t1 n1 t2 n2 =
105 if compatIdx t1 n1 t2 n2 117 if compatIdx t1 n1 t2 n2
106 then contract1 (prod t1 t2) n1 (n2++"'") 118 then contract1 (prod t1 t2) n1 (n2++"'")
@@ -130,8 +142,8 @@ names t = sort $ map idxName (dims t)
130normal :: (Field t) => Tensor t -> Tensor t 142normal :: (Field t) => Tensor t -> Tensor t
131normal t = tridx (names t) t 143normal t = tridx (names t) t
132 144
133contractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] 145possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t]
134contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] 146possibleContractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
135 147
136-- sent to Haskell-Cafe by Sebastian Sylvan 148-- sent to Haskell-Cafe by Sebastian Sylvan
137perms :: [t] -> [[t]] 149perms :: [t] -> [[t]]