summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Tensor.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-06-22 17:33:17 +0000
committerAlberto Ruiz <aruiz@um.es>2007-06-22 17:33:17 +0000
commit978e6d038239af50d70bae2c303f4e45b1879b7a (patch)
tree571b2060f388d0693820f808b40089acb100a5d9 /lib/Data/Packed/Internal/Tensor.hs
parent989bdf7e88c13500bd1986dcde36f6cc4f467efb (diff)
refactoring
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs65
1 files changed, 46 insertions, 19 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 27fce6a..c4faf49 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -14,18 +14,25 @@
14 14
15module Data.Packed.Internal.Tensor where 15module Data.Packed.Internal.Tensor where
16 16
17import Data.Packed.Internal 17import Data.Packed.Internal.Common
18import Data.Packed.Internal.Vector 18import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix 19import Data.Packed.Internal.Matrix
20import Foreign.Storable 20import Foreign.Storable
21import Data.List(sort,elemIndex,nub) 21import Data.List(sort,elemIndex,nub)
22 22
23data IdxTp = Covariant | Contravariant deriving (Show,Eq) 23data IdxType = Covariant | Contravariant deriving (Show,Eq)
24 24
25data Tensor t = T { dims :: [(Int,(IdxTp,String))] 25type IdxName = String
26
27data IdxDesc = IdxDesc { idxDim :: Int,
28 idxType :: IdxType,
29 idxName :: IdxName }
30
31data Tensor t = T { dims :: [IdxDesc]
26 , ten :: Vector t 32 , ten :: Vector t
27 } 33 }
28 34
35rank :: Tensor t -> Int
29rank = length . dims 36rank = length . dims
30 37
31instance (Show a,Storable a) => Show (Tensor a) where 38instance (Show a,Storable a) => Show (Tensor a) where
@@ -33,41 +40,49 @@ instance (Show a,Storable a) => Show (Tensor a) where
33 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) 40 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t)
34 41
35 42
36shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]" 43shdims :: [IdxDesc] -> String
44shdims [IdxDesc n t name] = name ++ sym t ++"["++show n++"]"
37 where sym Covariant = "_" 45 where sym Covariant = "_"
38 sym Contravariant = "^" 46 sym Contravariant = "^"
39shdims (d:ds) = shdims [d] ++ "><"++ shdims ds 47shdims (d:ds) = shdims [d] ++ "><"++ shdims ds
40 48
41 49
42 50findIdx :: (Field t) => IdxName -> Tensor t
51 -> (([IdxDesc], [IdxDesc]), Matrix t)
43findIdx name t = ((d1,d2),m) where 52findIdx name t = ((d1,d2),m) where
44 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) 53 (d1,d2) = span (\d -> idxName d /= name) (dims t)
45 c = product (map fst d2) 54 c = product (map idxDim d2)
46 m = matrixFromVector RowMajor c (ten t) 55 m = matrixFromVector RowMajor c (ten t)
47 56
57putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t)
48putFirstIdx name t = (nd,m') 58putFirstIdx name t = (nd,m')
49 where ((d1,d2),m) = findIdx name t 59 where ((d1,d2),m) = findIdx name t
50 m' = matrixFromVector RowMajor c $ cdat $ trans m 60 m' = matrixFromVector RowMajor c $ cdat $ trans m
51 nd = d2++d1 61 nd = d2++d1
52 c = dim (ten t) `div` (fst $ head d2) 62 c = dim (ten t) `div` (idxDim $ head d2)
53 63
64part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t
54part t (name,k) = if k<0 || k>=l 65part t (name,k) = if k<0 || k>=l
55 then error $ "part "++show (name,k)++" out of range in "++show t 66 then error $ "part "++show (name,k)++" out of range" -- in "++show t
56 else T {dims = ds, ten = toRows m !! k} 67 else T {dims = ds, ten = toRows m !! k}
57 where (d:ds,m) = putFirstIdx name t 68 where (d:ds,m) = putFirstIdx name t
58 (l,_) = d 69 l = idxDim d
59 70
71parts :: (Field t) => Tensor t -> IdxName -> [Tensor t]
60parts t name = map f (toRows m) 72parts t name = map f (toRows m)
61 where (d:ds,m) = putFirstIdx name t 73 where (d:ds,m) = putFirstIdx name t
62 (l,_) = d 74 l = idxDim d
63 f t = T {dims=ds, ten=t} 75 f t = T {dims=ds, ten=t}
64 76
77concatRename :: [IdxDesc] -> [IdxDesc] -> [IdxDesc]
65concatRename l1 l2 = l1 ++ map ren l2 where 78concatRename l1 l2 = l1 ++ map ren l2 where
66 ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s)) 79 ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx
67 fs = map (snd.snd) l1 80 fs = map idxName l1
68 81
69prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2) 82prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
83prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2)
70 84
85contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
71contraction t1 n1 t2 n2 = 86contraction t1 n1 t2 n2 =
72 if compatIdx t1 n1 t2 n2 87 if compatIdx t1 n1 t2 n2
73 then T (concatRename (tail d1) (tail d2)) (cdat m) 88 then T (concatRename (tail d1) (tail d2)) (cdat m)
@@ -76,18 +91,22 @@ contraction t1 n1 t2 n2 =
76 (d2,m2) = putFirstIdx n2 t2 91 (d2,m2) = putFirstIdx n2 t2
77 m = multiply RowMajor (trans m1) m2 92 m = multiply RowMajor (trans m1) m2
78 93
94sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t]
79sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) 95sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
80 96
97contract1 :: (Num t, Enum t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t
81contract1 t name1 name2 = T d $ fromList $ sumT y 98contract1 t name1 name2 = T d $ fromList $ sumT y
82 where d = dims (head y) 99 where d = dims (head y)
83 x = (map (flip parts name2) (parts t name1)) 100 x = (map (flip parts name2) (parts t name1))
84 y = map head $ zipWith drop [0..] x 101 y = map head $ zipWith drop [0..] x
85 102
103contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
86contraction' t1 n1 t2 n2 = 104contraction' t1 n1 t2 n2 =
87 if compatIdx t1 n1 t2 n2 105 if compatIdx t1 n1 t2 n2
88 then contract1 (prod t1 t2) n1 (n2++"'") 106 then contract1 (prod t1 t2) n1 (n2++"'")
89 else error "wrong contraction'" 107 else error "wrong contraction'"
90 108
109tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t
91tridx [] t = t 110tridx [] t = t
92tridx (name:rest) t = T (d:ds) (join ts) where 111tridx (name:rest) t = T (d:ds) (join ts) where
93 ((_,d:_),_) = findIdx name t 112 ((_,d:_),_) = findIdx name t
@@ -95,30 +114,38 @@ tridx (name:rest) t = T (d:ds) (join ts) where
95 ts = map ten ps 114 ts = map ten ps
96 ds = dims (head ps) 115 ds = dims (head ps)
97 116
98compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2 117compatIdxAux :: IdxDesc -> IdxDesc -> Bool
118compatIdxAux IdxDesc {idxDim = n1, idxType = t1}
119 IdxDesc {idxDim = n2, idxType = t2}
120 = t1 /= t2 && n1 == n2
99 121
122compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool
100compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where 123compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where
101 d1 = head $ snd $ fst $ findIdx n1 t1 124 d1 = head $ snd $ fst $ findIdx n1 t1
102 d2 = head $ snd $ fst $ findIdx n2 t2 125 d2 = head $ snd $ fst $ findIdx n2 t2
103 126
104names t = sort $ map (snd.snd) (dims t) 127names :: Tensor t -> [IdxName]
128names t = sort $ map idxName (dims t)
105 129
130normal :: (Field t) => Tensor t -> Tensor t
106normal t = tridx (names t) t 131normal t = tridx (names t) t
107 132
133contractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t]
108contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] 134contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
109 135
110-- sent to Haskell-Cafe by Sebastian Sylvan 136-- sent to Haskell-Cafe by Sebastian Sylvan
137perms :: [t] -> [[t]]
111perms [x] = [[x]] 138perms [x] = [[x]]
112perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] 139perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys]
113selections [] = [] 140selections [] = []
114selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] 141selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs]
115 142
116 143interchanges :: (Ord a) => [a] -> Int
117interchanges ls = sum (map (count ls) ls) 144interchanges ls = sum (map (count ls) ls)
118 where count l p = n 145 where count l p = length $ filter (>p) $ take pel l
119 where Just pel = elemIndex p l 146 where Just pel = elemIndex p l
120 n = length $ filter (>p) $ take pel l
121 147
148signature :: (Num t, Ord a) => [a] -> t
122signature l | length (nub l) < length l = 0 149signature l | length (nub l) < length l = 0
123 | even (interchanges l) = 1 150 | even (interchanges l) = 1
124 | otherwise = -1 151 | otherwise = -1