diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-06 17:40:09 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-06 17:40:09 +0000 |
commit | e7c03c1ab4de85e7a700d2eafaebd37f4607c51f (patch) | |
tree | 4512d18907d88d0390671fcde4e8886d30cd0492 /lib/Data/Packed/Internal | |
parent | a4254a0b9bfbd720efbe42b86aa50107a74d56c7 (diff) |
working on tensor contractions
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 17 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 73 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 2 |
3 files changed, 87 insertions, 5 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index a2a70dd..ec6657a 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -65,6 +65,15 @@ partit :: Int -> [a] -> [[a]] | |||
65 | partit _ [] = [] | 65 | partit _ [] = [] |
66 | partit n l = take n l : partit n (drop n l) | 66 | partit n l = take n l : partit n (drop n l) |
67 | 67 | ||
68 | -- | obtains the common value of a property of a list | ||
69 | common :: (Eq a) => (b->a) -> [b] -> Maybe a | ||
70 | common f = commonval . map f where | ||
71 | commonval :: (Eq a) => [a] -> Maybe a | ||
72 | commonval [] = Nothing | ||
73 | commonval [a] = Just a | ||
74 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing | ||
75 | |||
76 | |||
68 | toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m | 77 | toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m |
69 | | otherwise = partit (cols m) . toList . dat $ m | 78 | | otherwise = partit (cols m) . toList . dat $ m |
70 | 79 | ||
@@ -115,7 +124,7 @@ transdataAux fun c1 d c2 = | |||
115 | else unsafePerformIO $ do | 124 | else unsafePerformIO $ do |
116 | v <- createVector (dim d) | 125 | v <- createVector (dim d) |
117 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | 126 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] |
118 | putStrLn "---> transdataAux" | 127 | --putStrLn "---> transdataAux" |
119 | return v | 128 | return v |
120 | where r1 = dim d `div` c1 | 129 | where r1 = dim d `div` c1 |
121 | r2 = dim d `div` c2 | 130 | r2 = dim d `div` c2 |
@@ -136,6 +145,12 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | |||
136 | --{-# RULES "transdataR" transdata=transdataR #-} | 145 | --{-# RULES "transdataR" transdata=transdataR #-} |
137 | --{-# RULES "transdataC" transdata=transdataC #-} | 146 | --{-# RULES "transdataC" transdata=transdataC #-} |
138 | 147 | ||
148 | -- | creates a Matrix from a list of vectors | ||
149 | fromRows :: Field t => [Vector t] -> Matrix t | ||
150 | fromRows vs = case common dim vs of | ||
151 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" | ||
152 | Just c -> reshape c (join vs) | ||
153 | |||
139 | -- | extracts the rows of a matrix as a list of vectors | 154 | -- | extracts the rows of a matrix as a list of vectors |
140 | toRows :: Storable t => Matrix t -> [Vector t] | 155 | toRows :: Storable t => Matrix t -> [Vector t] |
141 | toRows m = toRows' 0 where | 156 | toRows m = toRows' 0 where |
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 960e3c5..b66d6b8 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | 1 | --{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} |
2 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
3 | -- | | 3 | -- | |
4 | -- Module : Data.Packed.Internal.Tensor | 4 | -- Module : Data.Packed.Internal.Tensor |
@@ -19,7 +19,7 @@ import Data.Packed.Internal.Vector | |||
19 | import Data.Packed.Internal.Matrix | 19 | import Data.Packed.Internal.Matrix |
20 | import Foreign.Storable | 20 | import Foreign.Storable |
21 | 21 | ||
22 | data IdxTp = Covariant | Contravariant deriving Show | 22 | data IdxTp = Covariant | Contravariant deriving (Show,Eq) |
23 | 23 | ||
24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] | 24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] |
25 | , ten :: Vector t | 25 | , ten :: Vector t |
@@ -36,5 +36,70 @@ instance (Show a,Storable a) => Show (Tensor a) where | |||
36 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) | 36 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) |
37 | 37 | ||
38 | 38 | ||
39 | shdims [(n,(t,name))] = name++"["++show n++"]" | 39 | shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]" |
40 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file | 40 | where sym Covariant = "_" |
41 | sym Contravariant = "^" | ||
42 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds | ||
43 | |||
44 | |||
45 | |||
46 | findIdx name t = ((d1,d2),m) where | ||
47 | (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) | ||
48 | c = product (map fst d2) | ||
49 | m = matrixFromVector RowMajor c (ten t) | ||
50 | |||
51 | putFirstIdx name t = (nd,m') | ||
52 | where ((d1,d2),m) = findIdx name t | ||
53 | m' = matrixFromVector RowMajor c $ cdat $ trans m | ||
54 | nd = d2++d1 | ||
55 | c = dim (ten t) `div` (fst $ head d2) | ||
56 | |||
57 | part t (name,k) = if k<0 || k>=l | ||
58 | then error $ "part "++show (name,k)++" out of range in "++show t | ||
59 | else T {dims = ds, ten = toRows m !! k} | ||
60 | where (d:ds,m) = putFirstIdx name t | ||
61 | (l,_) = d | ||
62 | |||
63 | parts t name = map f (toRows m) | ||
64 | where (d:ds,m) = putFirstIdx name t | ||
65 | (l,_) = d | ||
66 | f t = T {dims=ds, ten=t} | ||
67 | |||
68 | concatRename l1 l2 = l1 ++ map ren l2 where | ||
69 | ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s)) | ||
70 | fs = map (snd.snd) l1 | ||
71 | |||
72 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2) | ||
73 | |||
74 | contraction t1 n1 t2 n2 = | ||
75 | if compatIdx t1 n1 t2 n2 | ||
76 | then T (concatRename (tail d1) (tail d2)) (cdat m) | ||
77 | else error "wrong contraction'" | ||
78 | where (d1,m1) = putFirstIdx n1 t1 | ||
79 | (d2,m2) = putFirstIdx n2 t2 | ||
80 | m = multiply RowMajor (trans m1) m2 | ||
81 | |||
82 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | ||
83 | |||
84 | contract1 t name1 name2 = T d $ fromList $ sumT y | ||
85 | where d = dims (head y) | ||
86 | x = (map (flip parts name2) (parts t name1)) | ||
87 | y = map head $ zipWith drop [0..] x | ||
88 | |||
89 | contraction' t1 n1 t2 n2 = | ||
90 | if compatIdx t1 n1 t2 n2 | ||
91 | then contract1 (prod t1 t2) n1 (n2++"'") | ||
92 | else error "wrong contraction'" | ||
93 | |||
94 | tridx [] t = t | ||
95 | tridx (name:rest) t = T (d:ds) (join ts) where | ||
96 | ((_,d:_),_) = findIdx name t | ||
97 | ps = map (tridx rest) (parts t name) | ||
98 | ts = map ten ps | ||
99 | ds = dims (head ps) | ||
100 | |||
101 | compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2 | ||
102 | |||
103 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where | ||
104 | d1 = head $ snd $ fst $ findIdx n1 t1 | ||
105 | d2 = head $ snd $ fst $ findIdx n2 t2 | ||
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 6ed9339..36d5df7 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -35,6 +35,8 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where -- | |||
35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | 35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- |
36 | ---------------------------------------------------------------------- | 36 | ---------------------------------------------------------------------- |
37 | 37 | ||
38 | on f g = \x y -> f (g x) (g y) | ||
39 | |||
38 | (//) :: x -> (x -> y) -> y | 40 | (//) :: x -> (x -> y) -> y |
39 | infixl 0 // | 41 | infixl 0 // |
40 | (//) = flip ($) | 42 | (//) = flip ($) |