diff options
-rw-r--r-- | examples/pru.hs | 76 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 17 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 73 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 2 |
4 files changed, 133 insertions, 35 deletions
diff --git a/examples/pru.hs b/examples/pru.hs index bddc08f..a935d93 100644 --- a/examples/pru.hs +++ b/examples/pru.hs | |||
@@ -7,7 +7,7 @@ import Data.Packed.Internal.Tensor | |||
7 | 7 | ||
8 | import Complex | 8 | import Complex |
9 | import Numeric(showGFloat) | 9 | import Numeric(showGFloat) |
10 | import Data.List(transpose,intersperse) | 10 | import Data.List(transpose,intersperse,sort) |
11 | import Foreign.Storable | 11 | import Foreign.Storable |
12 | 12 | ||
13 | r >< c = f where | 13 | r >< c = f where |
@@ -22,8 +22,6 @@ r >|< c = f where | |||
22 | ++show (dim v) ++"in ("++show r++"><"++show c++")" | 22 | ++show (dim v) ++"in ("++show r++"><"++show c++")" |
23 | where v = fromList l | 23 | where v = fromList l |
24 | 24 | ||
25 | |||
26 | |||
27 | vr = fromList [1..15::Double] | 25 | vr = fromList [1..15::Double] |
28 | vc = fromList (map (\x->x :+ (x+1)) [1..15::Double]) | 26 | vc = fromList (map (\x->x :+ (x+1)) [1..15::Double]) |
29 | 27 | ||
@@ -62,38 +60,56 @@ main = do | |||
62 | t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] | 60 | t = T [(4,(Covariant,"p")),(2,(Covariant,"q")),(3,(Contravariant,"r"))] $ fromList [1..24::Double] |
63 | 61 | ||
64 | 62 | ||
65 | findIdx name t = ((d1,d2),m) where | ||
66 | (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) | ||
67 | c = product (map fst d2) | ||
68 | m = matrixFromVector RowMajor c (ten t) | ||
69 | 63 | ||
70 | putFirstIdx name t = (nd,m') | 64 | t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double] |
71 | where ((d1,d2),m) = findIdx name t | 65 | t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double] |
72 | m' = matrixFromVector RowMajor c $ cdat $ trans m | ||
73 | nd = d2++d1 | ||
74 | c = dim (ten t) `div` (fst $ head d2) | ||
75 | 66 | ||
76 | part t (name,k) = if k<0 || k>=l | ||
77 | then error $ "part "++show (name,k)++" out of range in "++show t | ||
78 | else T {dims = ds, ten = toRows m !! k} | ||
79 | where (d:ds,m) = putFirstIdx name t | ||
80 | (l,_) = d | ||
81 | 67 | ||
82 | parts t name = map f (toRows m) | ||
83 | where (d:ds,m) = putFirstIdx name t | ||
84 | (l,_) = d | ||
85 | f t = T {dims=ds, ten=t} | ||
86 | 68 | ||
87 | t1 = T [(4,(Covariant,"p")),(4,(Contravariant,"q")),(2,(Covariant,"r"))] $ fromList [1..32::Double] | ||
88 | t2 = T [(4,(Covariant,"p")),(4,(Contravariant,"q"))] $ fromList [1..16::Double] | ||
89 | 69 | ||
90 | contract1 t name1 name2 = map head $ zipWith drop [0..] (map (flip parts name2) (parts t name1)) | ||
91 | 70 | ||
92 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | ||
93 | 71 | ||
94 | on f g = \x y -> f (g x) (g y) | 72 | delta i j | i==j = 1 |
73 | | otherwise = 0 | ||
74 | |||
75 | e i n = fromList [ delta k i | k <- [1..n]] | ||
76 | |||
77 | ident n = fromRows [ e i n | i <- [1..n]] | ||
78 | |||
79 | diag l = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] | ||
80 | where c = length l | ||
81 | |||
82 | tensorFromVector idx v = T {dims = [(dim v,idx)], ten = v} | ||
83 | tensorFromMatrix idxr idxc m = T {dims = [(rows m,idxr),(cols m,idxc)], ten = cdat m} | ||
84 | |||
85 | td = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ diag [1..4] :: Tensor Double | ||
86 | |||
87 | tn = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double | ||
88 | tt = tensorFromMatrix (Contravariant,"i") (Covariant,"j") $ (2><3) [1..6] :: Tensor Double | ||
89 | |||
90 | tq = T [(3,(Covariant,"p")),(2,(Covariant,"q")),(2,(Covariant,"r"))] $ fromList [11 .. 22] :: Tensor Double | ||
91 | |||
92 | r1 = contraction tt "j" tq "p" | ||
93 | r1' = contraction' tt "j" tq "p" | ||
94 | |||
95 | pru = do | ||
96 | mapM_ (putStrLn.shdims.dims.normal) (contractions t1 t2) | ||
97 | let t1 = contraction tt "i" tq "q" | ||
98 | print $ normal t1 | ||
99 | print $ foldl part t1 [("j",0),("p'",1),("r'",1)] | ||
100 | let t2 = contraction' tt "i" tq "q" | ||
101 | print $ normal t2 | ||
102 | print $ foldl part t2 [("j",0),("p'",1),("r'",1)] | ||
103 | let t1 = contraction tq "q" tt "i" | ||
104 | print $ normal t1 | ||
105 | print $ foldl part t1 [("j'",0),("p",1),("r",1)] | ||
106 | let t2 = contraction' tq "q" tt "i" | ||
107 | print $ normal t2 | ||
108 | print $ foldl part t2 [("j'",0),("p",1),("r",1)] | ||
109 | |||
110 | |||
111 | names t = sort $ map (snd.snd) (dims t) | ||
112 | |||
113 | normal t = tridx (names t) t | ||
95 | 114 | ||
96 | contract t1 n1 t2 n2 = T (tail d1++tail d2) (cdat m) | 115 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] |
97 | where (d1,m1) = putFirstIdx n1 t1 | ||
98 | (d2,m2) = putFirstIdx n2 t2 | ||
99 | m = multiply RowMajor (trans m2) m1 \ No newline at end of file | ||
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index a2a70dd..ec6657a 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -65,6 +65,15 @@ partit :: Int -> [a] -> [[a]] | |||
65 | partit _ [] = [] | 65 | partit _ [] = [] |
66 | partit n l = take n l : partit n (drop n l) | 66 | partit n l = take n l : partit n (drop n l) |
67 | 67 | ||
68 | -- | obtains the common value of a property of a list | ||
69 | common :: (Eq a) => (b->a) -> [b] -> Maybe a | ||
70 | common f = commonval . map f where | ||
71 | commonval :: (Eq a) => [a] -> Maybe a | ||
72 | commonval [] = Nothing | ||
73 | commonval [a] = Just a | ||
74 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing | ||
75 | |||
76 | |||
68 | toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m | 77 | toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m |
69 | | otherwise = partit (cols m) . toList . dat $ m | 78 | | otherwise = partit (cols m) . toList . dat $ m |
70 | 79 | ||
@@ -115,7 +124,7 @@ transdataAux fun c1 d c2 = | |||
115 | else unsafePerformIO $ do | 124 | else unsafePerformIO $ do |
116 | v <- createVector (dim d) | 125 | v <- createVector (dim d) |
117 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | 126 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] |
118 | putStrLn "---> transdataAux" | 127 | --putStrLn "---> transdataAux" |
119 | return v | 128 | return v |
120 | where r1 = dim d `div` c1 | 129 | where r1 = dim d `div` c1 |
121 | r2 = dim d `div` c2 | 130 | r2 = dim d `div` c2 |
@@ -136,6 +145,12 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | |||
136 | --{-# RULES "transdataR" transdata=transdataR #-} | 145 | --{-# RULES "transdataR" transdata=transdataR #-} |
137 | --{-# RULES "transdataC" transdata=transdataC #-} | 146 | --{-# RULES "transdataC" transdata=transdataC #-} |
138 | 147 | ||
148 | -- | creates a Matrix from a list of vectors | ||
149 | fromRows :: Field t => [Vector t] -> Matrix t | ||
150 | fromRows vs = case common dim vs of | ||
151 | Nothing -> error "fromRows applied to [] or to vectors with different sizes" | ||
152 | Just c -> reshape c (join vs) | ||
153 | |||
139 | -- | extracts the rows of a matrix as a list of vectors | 154 | -- | extracts the rows of a matrix as a list of vectors |
140 | toRows :: Storable t => Matrix t -> [Vector t] | 155 | toRows :: Storable t => Matrix t -> [Vector t] |
141 | toRows m = toRows' 0 where | 156 | toRows m = toRows' 0 where |
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 960e3c5..b66d6b8 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -1,4 +1,4 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | 1 | --{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} |
2 | ----------------------------------------------------------------------------- | 2 | ----------------------------------------------------------------------------- |
3 | -- | | 3 | -- | |
4 | -- Module : Data.Packed.Internal.Tensor | 4 | -- Module : Data.Packed.Internal.Tensor |
@@ -19,7 +19,7 @@ import Data.Packed.Internal.Vector | |||
19 | import Data.Packed.Internal.Matrix | 19 | import Data.Packed.Internal.Matrix |
20 | import Foreign.Storable | 20 | import Foreign.Storable |
21 | 21 | ||
22 | data IdxTp = Covariant | Contravariant deriving Show | 22 | data IdxTp = Covariant | Contravariant deriving (Show,Eq) |
23 | 23 | ||
24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] | 24 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] |
25 | , ten :: Vector t | 25 | , ten :: Vector t |
@@ -36,5 +36,70 @@ instance (Show a,Storable a) => Show (Tensor a) where | |||
36 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) | 36 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) |
37 | 37 | ||
38 | 38 | ||
39 | shdims [(n,(t,name))] = name++"["++show n++"]" | 39 | shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]" |
40 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file | 40 | where sym Covariant = "_" |
41 | sym Contravariant = "^" | ||
42 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds | ||
43 | |||
44 | |||
45 | |||
46 | findIdx name t = ((d1,d2),m) where | ||
47 | (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) | ||
48 | c = product (map fst d2) | ||
49 | m = matrixFromVector RowMajor c (ten t) | ||
50 | |||
51 | putFirstIdx name t = (nd,m') | ||
52 | where ((d1,d2),m) = findIdx name t | ||
53 | m' = matrixFromVector RowMajor c $ cdat $ trans m | ||
54 | nd = d2++d1 | ||
55 | c = dim (ten t) `div` (fst $ head d2) | ||
56 | |||
57 | part t (name,k) = if k<0 || k>=l | ||
58 | then error $ "part "++show (name,k)++" out of range in "++show t | ||
59 | else T {dims = ds, ten = toRows m !! k} | ||
60 | where (d:ds,m) = putFirstIdx name t | ||
61 | (l,_) = d | ||
62 | |||
63 | parts t name = map f (toRows m) | ||
64 | where (d:ds,m) = putFirstIdx name t | ||
65 | (l,_) = d | ||
66 | f t = T {dims=ds, ten=t} | ||
67 | |||
68 | concatRename l1 l2 = l1 ++ map ren l2 where | ||
69 | ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s)) | ||
70 | fs = map (snd.snd) l1 | ||
71 | |||
72 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2) | ||
73 | |||
74 | contraction t1 n1 t2 n2 = | ||
75 | if compatIdx t1 n1 t2 n2 | ||
76 | then T (concatRename (tail d1) (tail d2)) (cdat m) | ||
77 | else error "wrong contraction'" | ||
78 | where (d1,m1) = putFirstIdx n1 t1 | ||
79 | (d2,m2) = putFirstIdx n2 t2 | ||
80 | m = multiply RowMajor (trans m1) m2 | ||
81 | |||
82 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | ||
83 | |||
84 | contract1 t name1 name2 = T d $ fromList $ sumT y | ||
85 | where d = dims (head y) | ||
86 | x = (map (flip parts name2) (parts t name1)) | ||
87 | y = map head $ zipWith drop [0..] x | ||
88 | |||
89 | contraction' t1 n1 t2 n2 = | ||
90 | if compatIdx t1 n1 t2 n2 | ||
91 | then contract1 (prod t1 t2) n1 (n2++"'") | ||
92 | else error "wrong contraction'" | ||
93 | |||
94 | tridx [] t = t | ||
95 | tridx (name:rest) t = T (d:ds) (join ts) where | ||
96 | ((_,d:_),_) = findIdx name t | ||
97 | ps = map (tridx rest) (parts t name) | ||
98 | ts = map ten ps | ||
99 | ds = dims (head ps) | ||
100 | |||
101 | compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2 | ||
102 | |||
103 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where | ||
104 | d1 = head $ snd $ fst $ findIdx n1 t1 | ||
105 | d2 = head $ snd $ fst $ findIdx n2 t2 | ||
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 6ed9339..36d5df7 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -35,6 +35,8 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where -- | |||
35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | 35 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- |
36 | ---------------------------------------------------------------------- | 36 | ---------------------------------------------------------------------- |
37 | 37 | ||
38 | on f g = \x y -> f (g x) (g y) | ||
39 | |||
38 | (//) :: x -> (x -> y) -> y | 40 | (//) :: x -> (x -> y) -> y |
39 | infixl 0 // | 41 | infixl 0 // |
40 | (//) = flip ($) | 42 | (//) = flip ($) |