summaryrefslogtreecommitdiff
path: root/lib/Data
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs17
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs73
-rw-r--r--lib/Data/Packed/Internal/Vector.hs2
3 files changed, 87 insertions, 5 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index a2a70dd..ec6657a 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -65,6 +65,15 @@ partit :: Int -> [a] -> [[a]]
65partit _ [] = [] 65partit _ [] = []
66partit n l = take n l : partit n (drop n l) 66partit n l = take n l : partit n (drop n l)
67 67
68-- | obtains the common value of a property of a list
69common :: (Eq a) => (b->a) -> [b] -> Maybe a
70common f = commonval . map f where
71 commonval :: (Eq a) => [a] -> Maybe a
72 commonval [] = Nothing
73 commonval [a] = Just a
74 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
75
76
68toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m 77toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m
69 | otherwise = partit (cols m) . toList . dat $ m 78 | otherwise = partit (cols m) . toList . dat $ m
70 79
@@ -115,7 +124,7 @@ transdataAux fun c1 d c2 =
115 else unsafePerformIO $ do 124 else unsafePerformIO $ do
116 v <- createVector (dim d) 125 v <- createVector (dim d)
117 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] 126 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
118 putStrLn "---> transdataAux" 127 --putStrLn "---> transdataAux"
119 return v 128 return v
120 where r1 = dim d `div` c1 129 where r1 = dim d `div` c1
121 r2 = dim d `div` c2 130 r2 = dim d `div` c2
@@ -136,6 +145,12 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
136--{-# RULES "transdataR" transdata=transdataR #-} 145--{-# RULES "transdataR" transdata=transdataR #-}
137--{-# RULES "transdataC" transdata=transdataC #-} 146--{-# RULES "transdataC" transdata=transdataC #-}
138 147
148-- | creates a Matrix from a list of vectors
149fromRows :: Field t => [Vector t] -> Matrix t
150fromRows vs = case common dim vs of
151 Nothing -> error "fromRows applied to [] or to vectors with different sizes"
152 Just c -> reshape c (join vs)
153
139-- | extracts the rows of a matrix as a list of vectors 154-- | extracts the rows of a matrix as a list of vectors
140toRows :: Storable t => Matrix t -> [Vector t] 155toRows :: Storable t => Matrix t -> [Vector t]
141toRows m = toRows' 0 where 156toRows m = toRows' 0 where
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 960e3c5..b66d6b8 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -1,4 +1,4 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} 1--{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3-- | 3-- |
4-- Module : Data.Packed.Internal.Tensor 4-- Module : Data.Packed.Internal.Tensor
@@ -19,7 +19,7 @@ import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix 19import Data.Packed.Internal.Matrix
20import Foreign.Storable 20import Foreign.Storable
21 21
22data IdxTp = Covariant | Contravariant deriving Show 22data IdxTp = Covariant | Contravariant deriving (Show,Eq)
23 23
24data Tensor t = T { dims :: [(Int,(IdxTp,String))] 24data Tensor t = T { dims :: [(Int,(IdxTp,String))]
25 , ten :: Vector t 25 , ten :: Vector t
@@ -36,5 +36,70 @@ instance (Show a,Storable a) => Show (Tensor a) where
36 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) 36 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t)
37 37
38 38
39shdims [(n,(t,name))] = name++"["++show n++"]" 39shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]"
40shdims (d:ds) = shdims [d] ++ "><"++ shdims ds \ No newline at end of file 40 where sym Covariant = "_"
41 sym Contravariant = "^"
42shdims (d:ds) = shdims [d] ++ "><"++ shdims ds
43
44
45
46findIdx name t = ((d1,d2),m) where
47 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t)
48 c = product (map fst d2)
49 m = matrixFromVector RowMajor c (ten t)
50
51putFirstIdx name t = (nd,m')
52 where ((d1,d2),m) = findIdx name t
53 m' = matrixFromVector RowMajor c $ cdat $ trans m
54 nd = d2++d1
55 c = dim (ten t) `div` (fst $ head d2)
56
57part t (name,k) = if k<0 || k>=l
58 then error $ "part "++show (name,k)++" out of range in "++show t
59 else T {dims = ds, ten = toRows m !! k}
60 where (d:ds,m) = putFirstIdx name t
61 (l,_) = d
62
63parts t name = map f (toRows m)
64 where (d:ds,m) = putFirstIdx name t
65 (l,_) = d
66 f t = T {dims=ds, ten=t}
67
68concatRename l1 l2 = l1 ++ map ren l2 where
69 ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s))
70 fs = map (snd.snd) l1
71
72prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2)
73
74contraction t1 n1 t2 n2 =
75 if compatIdx t1 n1 t2 n2
76 then T (concatRename (tail d1) (tail d2)) (cdat m)
77 else error "wrong contraction'"
78 where (d1,m1) = putFirstIdx n1 t1
79 (d2,m2) = putFirstIdx n2 t2
80 m = multiply RowMajor (trans m1) m2
81
82sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
83
84contract1 t name1 name2 = T d $ fromList $ sumT y
85 where d = dims (head y)
86 x = (map (flip parts name2) (parts t name1))
87 y = map head $ zipWith drop [0..] x
88
89contraction' t1 n1 t2 n2 =
90 if compatIdx t1 n1 t2 n2
91 then contract1 (prod t1 t2) n1 (n2++"'")
92 else error "wrong contraction'"
93
94tridx [] t = t
95tridx (name:rest) t = T (d:ds) (join ts) where
96 ((_,d:_),_) = findIdx name t
97 ps = map (tridx rest) (parts t name)
98 ts = map ten ps
99 ds = dims (head ps)
100
101compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2
102
103compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where
104 d1 = head $ snd $ fst $ findIdx n1 t1
105 d2 = head $ snd $ fst $ findIdx n2 t2
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index 6ed9339..36d5df7 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -35,6 +35,8 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where --
35 poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- 35 poke p (a :+ b) = pokeArray (castPtr p) [a,b] --
36---------------------------------------------------------------------- 36----------------------------------------------------------------------
37 37
38on f g = \x y -> f (g x) (g y)
39
38(//) :: x -> (x -> y) -> y 40(//) :: x -> (x -> y) -> y
39infixl 0 // 41infixl 0 //
40(//) = flip ($) 42(//) = flip ($)