diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-22 17:33:17 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-22 17:33:17 +0000 |
commit | 978e6d038239af50d70bae2c303f4e45b1879b7a (patch) | |
tree | 571b2060f388d0693820f808b40089acb100a5d9 /lib/Data/Packed/Internal/Tensor.hs | |
parent | 989bdf7e88c13500bd1986dcde36f6cc4f467efb (diff) |
refactoring
Diffstat (limited to 'lib/Data/Packed/Internal/Tensor.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 65 |
1 files changed, 46 insertions, 19 deletions
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 27fce6a..c4faf49 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -14,18 +14,25 @@ | |||
14 | 14 | ||
15 | module Data.Packed.Internal.Tensor where | 15 | module Data.Packed.Internal.Tensor where |
16 | 16 | ||
17 | import Data.Packed.Internal | 17 | import Data.Packed.Internal.Common |
18 | import Data.Packed.Internal.Vector | 18 | import Data.Packed.Internal.Vector |
19 | import Data.Packed.Internal.Matrix | 19 | import Data.Packed.Internal.Matrix |
20 | import Foreign.Storable | 20 | import Foreign.Storable |
21 | import Data.List(sort,elemIndex,nub) | 21 | import Data.List(sort,elemIndex,nub) |
22 | 22 | ||
23 | data IdxTp = Covariant | Contravariant deriving (Show,Eq) | 23 | data IdxType = Covariant | Contravariant deriving (Show,Eq) |
24 | 24 | ||
25 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] | 25 | type IdxName = String |
26 | |||
27 | data IdxDesc = IdxDesc { idxDim :: Int, | ||
28 | idxType :: IdxType, | ||
29 | idxName :: IdxName } | ||
30 | |||
31 | data Tensor t = T { dims :: [IdxDesc] | ||
26 | , ten :: Vector t | 32 | , ten :: Vector t |
27 | } | 33 | } |
28 | 34 | ||
35 | rank :: Tensor t -> Int | ||
29 | rank = length . dims | 36 | rank = length . dims |
30 | 37 | ||
31 | instance (Show a,Storable a) => Show (Tensor a) where | 38 | instance (Show a,Storable a) => Show (Tensor a) where |
@@ -33,41 +40,49 @@ instance (Show a,Storable a) => Show (Tensor a) where | |||
33 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) | 40 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) |
34 | 41 | ||
35 | 42 | ||
36 | shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]" | 43 | shdims :: [IdxDesc] -> String |
44 | shdims [IdxDesc n t name] = name ++ sym t ++"["++show n++"]" | ||
37 | where sym Covariant = "_" | 45 | where sym Covariant = "_" |
38 | sym Contravariant = "^" | 46 | sym Contravariant = "^" |
39 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds | 47 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds |
40 | 48 | ||
41 | 49 | ||
42 | 50 | findIdx :: (Field t) => IdxName -> Tensor t | |
51 | -> (([IdxDesc], [IdxDesc]), Matrix t) | ||
43 | findIdx name t = ((d1,d2),m) where | 52 | findIdx name t = ((d1,d2),m) where |
44 | (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) | 53 | (d1,d2) = span (\d -> idxName d /= name) (dims t) |
45 | c = product (map fst d2) | 54 | c = product (map idxDim d2) |
46 | m = matrixFromVector RowMajor c (ten t) | 55 | m = matrixFromVector RowMajor c (ten t) |
47 | 56 | ||
57 | putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t) | ||
48 | putFirstIdx name t = (nd,m') | 58 | putFirstIdx name t = (nd,m') |
49 | where ((d1,d2),m) = findIdx name t | 59 | where ((d1,d2),m) = findIdx name t |
50 | m' = matrixFromVector RowMajor c $ cdat $ trans m | 60 | m' = matrixFromVector RowMajor c $ cdat $ trans m |
51 | nd = d2++d1 | 61 | nd = d2++d1 |
52 | c = dim (ten t) `div` (fst $ head d2) | 62 | c = dim (ten t) `div` (idxDim $ head d2) |
53 | 63 | ||
64 | part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t | ||
54 | part t (name,k) = if k<0 || k>=l | 65 | part t (name,k) = if k<0 || k>=l |
55 | then error $ "part "++show (name,k)++" out of range in "++show t | 66 | then error $ "part "++show (name,k)++" out of range" -- in "++show t |
56 | else T {dims = ds, ten = toRows m !! k} | 67 | else T {dims = ds, ten = toRows m !! k} |
57 | where (d:ds,m) = putFirstIdx name t | 68 | where (d:ds,m) = putFirstIdx name t |
58 | (l,_) = d | 69 | l = idxDim d |
59 | 70 | ||
71 | parts :: (Field t) => Tensor t -> IdxName -> [Tensor t] | ||
60 | parts t name = map f (toRows m) | 72 | parts t name = map f (toRows m) |
61 | where (d:ds,m) = putFirstIdx name t | 73 | where (d:ds,m) = putFirstIdx name t |
62 | (l,_) = d | 74 | l = idxDim d |
63 | f t = T {dims=ds, ten=t} | 75 | f t = T {dims=ds, ten=t} |
64 | 76 | ||
77 | concatRename :: [IdxDesc] -> [IdxDesc] -> [IdxDesc] | ||
65 | concatRename l1 l2 = l1 ++ map ren l2 where | 78 | concatRename l1 l2 = l1 ++ map ren l2 where |
66 | ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s)) | 79 | ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx |
67 | fs = map (snd.snd) l1 | 80 | fs = map idxName l1 |
68 | 81 | ||
69 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2) | 82 | prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t |
83 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) | ||
70 | 84 | ||
85 | contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | ||
71 | contraction t1 n1 t2 n2 = | 86 | contraction t1 n1 t2 n2 = |
72 | if compatIdx t1 n1 t2 n2 | 87 | if compatIdx t1 n1 t2 n2 |
73 | then T (concatRename (tail d1) (tail d2)) (cdat m) | 88 | then T (concatRename (tail d1) (tail d2)) (cdat m) |
@@ -76,18 +91,22 @@ contraction t1 n1 t2 n2 = | |||
76 | (d2,m2) = putFirstIdx n2 t2 | 91 | (d2,m2) = putFirstIdx n2 t2 |
77 | m = multiply RowMajor (trans m1) m2 | 92 | m = multiply RowMajor (trans m1) m2 |
78 | 93 | ||
94 | sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] | ||
79 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | 95 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) |
80 | 96 | ||
97 | contract1 :: (Num t, Enum t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t | ||
81 | contract1 t name1 name2 = T d $ fromList $ sumT y | 98 | contract1 t name1 name2 = T d $ fromList $ sumT y |
82 | where d = dims (head y) | 99 | where d = dims (head y) |
83 | x = (map (flip parts name2) (parts t name1)) | 100 | x = (map (flip parts name2) (parts t name1)) |
84 | y = map head $ zipWith drop [0..] x | 101 | y = map head $ zipWith drop [0..] x |
85 | 102 | ||
103 | contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | ||
86 | contraction' t1 n1 t2 n2 = | 104 | contraction' t1 n1 t2 n2 = |
87 | if compatIdx t1 n1 t2 n2 | 105 | if compatIdx t1 n1 t2 n2 |
88 | then contract1 (prod t1 t2) n1 (n2++"'") | 106 | then contract1 (prod t1 t2) n1 (n2++"'") |
89 | else error "wrong contraction'" | 107 | else error "wrong contraction'" |
90 | 108 | ||
109 | tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t | ||
91 | tridx [] t = t | 110 | tridx [] t = t |
92 | tridx (name:rest) t = T (d:ds) (join ts) where | 111 | tridx (name:rest) t = T (d:ds) (join ts) where |
93 | ((_,d:_),_) = findIdx name t | 112 | ((_,d:_),_) = findIdx name t |
@@ -95,30 +114,38 @@ tridx (name:rest) t = T (d:ds) (join ts) where | |||
95 | ts = map ten ps | 114 | ts = map ten ps |
96 | ds = dims (head ps) | 115 | ds = dims (head ps) |
97 | 116 | ||
98 | compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2 | 117 | compatIdxAux :: IdxDesc -> IdxDesc -> Bool |
118 | compatIdxAux IdxDesc {idxDim = n1, idxType = t1} | ||
119 | IdxDesc {idxDim = n2, idxType = t2} | ||
120 | = t1 /= t2 && n1 == n2 | ||
99 | 121 | ||
122 | compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool | ||
100 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where | 123 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where |
101 | d1 = head $ snd $ fst $ findIdx n1 t1 | 124 | d1 = head $ snd $ fst $ findIdx n1 t1 |
102 | d2 = head $ snd $ fst $ findIdx n2 t2 | 125 | d2 = head $ snd $ fst $ findIdx n2 t2 |
103 | 126 | ||
104 | names t = sort $ map (snd.snd) (dims t) | 127 | names :: Tensor t -> [IdxName] |
128 | names t = sort $ map idxName (dims t) | ||
105 | 129 | ||
130 | normal :: (Field t) => Tensor t -> Tensor t | ||
106 | normal t = tridx (names t) t | 131 | normal t = tridx (names t) t |
107 | 132 | ||
133 | contractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] | ||
108 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] | 134 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] |
109 | 135 | ||
110 | -- sent to Haskell-Cafe by Sebastian Sylvan | 136 | -- sent to Haskell-Cafe by Sebastian Sylvan |
137 | perms :: [t] -> [[t]] | ||
111 | perms [x] = [[x]] | 138 | perms [x] = [[x]] |
112 | perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] | 139 | perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] |
113 | selections [] = [] | 140 | selections [] = [] |
114 | selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] | 141 | selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] |
115 | 142 | ||
116 | 143 | interchanges :: (Ord a) => [a] -> Int | |
117 | interchanges ls = sum (map (count ls) ls) | 144 | interchanges ls = sum (map (count ls) ls) |
118 | where count l p = n | 145 | where count l p = length $ filter (>p) $ take pel l |
119 | where Just pel = elemIndex p l | 146 | where Just pel = elemIndex p l |
120 | n = length $ filter (>p) $ take pel l | ||
121 | 147 | ||
148 | signature :: (Num t, Ord a) => [a] -> t | ||
122 | signature l | length (nub l) < length l = 0 | 149 | signature l | length (nub l) < length l = 0 |
123 | | even (interchanges l) = 1 | 150 | | even (interchanges l) = 1 |
124 | | otherwise = -1 | 151 | | otherwise = -1 |