summaryrefslogtreecommitdiff
path: root/lib/Data
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2007-09-08 09:46:33 +0000
committerAlberto Ruiz <aruiz@um.es>2007-09-08 09:46:33 +0000
commit34380f2b5d7b048a4d68197f16a8db0e53742030 (patch)
tree444aff88cda5c247d49bac0d294d8cfb9ef7bf23 /lib/Data
parent0c38c1b0e122a56ea98c494e60ba90afe2688664 (diff)
type classes
Diffstat (limited to 'lib/Data')
-rw-r--r--lib/Data/Packed/Internal/Common.hs18
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs204
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs37
-rw-r--r--lib/Data/Packed/Internal/Vector.hs42
-rw-r--r--lib/Data/Packed/Matrix.hs15
-rw-r--r--lib/Data/Packed/Vector.hs12
6 files changed, 184 insertions, 144 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs
index 1bfed6d..1212968 100644
--- a/lib/Data/Packed/Internal/Common.hs
+++ b/lib/Data/Packed/Internal/Common.hs
@@ -28,7 +28,7 @@ debug x = trace (show x) x
28data Vector t = V { dim :: Int 28data Vector t = V { dim :: Int
29 , fptr :: ForeignPtr t 29 , fptr :: ForeignPtr t
30 , ptr :: Ptr t 30 , ptr :: Ptr t
31 } deriving Typeable 31 } -- deriving Typeable
32 32
33---------------------------------------------------------------------- 33----------------------------------------------------------------------
34instance (Storable a, RealFloat a) => Storable (Complex a) where -- 34instance (Storable a, RealFloat a) => Storable (Complex a) where --
@@ -78,17 +78,17 @@ check msg ls f = do
78 mapM_ (touchForeignPtr . fptr) ls 78 mapM_ (touchForeignPtr . fptr) ls
79 return () 79 return ()
80 80
81class (Storable a, Typeable a) => Field a 81--class (Storable a, Typeable a) => Field a
82instance (Storable a, Typeable a) => Field a 82--instance (Storable a, Typeable a) => Field a
83 83
84isReal :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool 84--isReal :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool
85isReal w x = typeOf (undefined :: Double) == typeOf (w x) 85--isReal w x = typeOf (undefined :: Double) == typeOf (w x)
86 86
87isComp :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool 87--isComp :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool
88isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) 88--isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x)
89 89
90scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b 90--scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b
91scast = fromJust . cast 91--scast = fromJust . cast
92 92
93{- | conversion of Haskell functions into function pointers that can be used in the C side 93{- | conversion of Haskell functions into function pointers that can be used in the C side
94-} 94-}
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 9895393..48652f3 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -1,4 +1,4 @@
1{-# OPTIONS_GHC -fglasgow-exts #-} 1{-# OPTIONS_GHC -fglasgow-exts -fallow-overlapping-instances #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3-- | 3-- |
4-- Module : Data.Packed.Internal.Matrix 4-- Module : Data.Packed.Internal.Matrix
@@ -22,9 +22,65 @@ import Foreign hiding (xor)
22import Complex 22import Complex
23import Control.Monad(when) 23import Control.Monad(when)
24import Data.List(transpose,intersperse) 24import Data.List(transpose,intersperse)
25import Data.Typeable 25--import Data.Typeable
26import Data.Maybe(fromJust) 26import Data.Maybe(fromJust)
27 27
28----------------------------------------------------------------
29
30class Storable a => Field a where
31 constant :: a -> Int -> Vector a
32 transdata :: Int -> Vector a -> Int -> Vector a
33 multiplyD :: MatrixOrder -> Matrix a -> Matrix a -> Matrix a
34 subMatrix :: (Int,Int) -- ^ (r0,c0) starting position
35 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
36 -> Matrix a -> Matrix a
37 diag :: Vector a -> Matrix a
38
39
40instance Field Double where
41 constant = constantR
42 transdata = transdataR
43 multiplyD = multiplyR
44 subMatrix = subMatrixR
45 diag = diagR
46
47instance Field (Complex Double) where
48 constant = constantC
49 transdata = transdataC
50 multiplyD = multiplyC
51 subMatrix = subMatrixC
52 diag = diagC
53
54-----------------------------------------------------------------
55
56transdataR :: Int -> Vector Double -> Int -> Vector Double
57transdataR = transdataAux ctransR
58
59transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double)
60transdataC = transdataAux ctransC
61
62transdataAux fun c1 d c2 =
63 if noneed
64 then d
65 else unsafePerformIO $ do
66 v <- createVector (dim d)
67 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
68 --putStrLn "---> transdataAux"
69 return v
70 where r1 = dim d `div` c1
71 r2 = dim d `div` c2
72 noneed = r1 == 1 || c1 == 1
73
74foreign import ccall safe "aux.h transR"
75 ctransR :: TMM -- Double ::> Double ::> IO Int
76foreign import ccall safe "aux.h transC"
77 ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int
78
79transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d
80
81
82
83
28 84
29data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) 85data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
30 86
@@ -34,9 +90,18 @@ data Matrix t = M { rows :: Int
34 , tdat :: Vector t 90 , tdat :: Vector t
35 , isTrans :: Bool 91 , isTrans :: Bool
36 , order :: MatrixOrder 92 , order :: MatrixOrder
37 } deriving Typeable 93 } -- deriving Typeable
38 94
39 95
96data NMat t = MC { rws, cls :: Int, dtc :: Vector t}
97 | MF { rws, cls :: Int, dtf :: Vector t}
98 | Tr (NMat t)
99
100ntrans (Tr m) = m
101ntrans m = Tr m
102
103viewC m@MC{} = m
104viewF m@MF{} = m
40 105
41fortran m = order m == ColumnMajor 106fortran m = order m == ColumnMajor
42 107
@@ -78,7 +143,11 @@ matrixFromVector RowMajor c v =
78 , tdat = transdata c v r 143 , tdat = transdata c v r
79 , order = RowMajor 144 , order = RowMajor
80 , isTrans = False 145 , isTrans = False
81 } where r = dim v `div` c -- TODO check mod=0 146 } where (d,m) = dim v `divMod` c
147 r | m==0 = d
148 | otherwise = error "matrixFromVector"
149
150-- r = dim v `div` c -- TODO check mod=0
82 151
83matrixFromVector ColumnMajor c v = 152matrixFromVector ColumnMajor c v =
84 M { rows = r 153 M { rows = r
@@ -87,7 +156,9 @@ matrixFromVector ColumnMajor c v =
87 , tdat = transdata r v c 156 , tdat = transdata r v c
88 , order = ColumnMajor 157 , order = ColumnMajor
89 , isTrans = False 158 , isTrans = False
90 } where r = dim v `div` c -- TODO check mod=0 159 } where (d,m) = dim v `divMod` c
160 r | m==0 = d
161 | otherwise = error "matrixFromVector"
91 162
92createMatrix order r c = do 163createMatrix order r c = do
93 p <- createVector (r*c) 164 p <- createVector (r*c)
@@ -102,48 +173,11 @@ createMatrix order r c = do
102 , 9.0, 10.0, 11.0, 12.0 ]@ 173 , 9.0, 10.0, 11.0, 12.0 ]@
103 174
104-} 175-}
105reshape :: (Field t) => Int -> Vector t -> Matrix t 176reshape :: Field t => Int -> Vector t -> Matrix t
106reshape c v = matrixFromVector RowMajor c v 177reshape c v = matrixFromVector RowMajor c v
107 178
108singleton x = reshape 1 (fromList [x]) 179singleton x = reshape 1 (fromList [x])
109 180
110transdataG :: Storable a => Int -> Vector a -> Int -> Vector a
111transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d
112
113transdataR :: Int -> Vector Double -> Int -> Vector Double
114transdataR = transdataAux ctransR
115
116transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double)
117transdataC = transdataAux ctransC
118
119transdataAux fun c1 d c2 =
120 if noneed
121 then d
122 else unsafePerformIO $ do
123 v <- createVector (dim d)
124 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
125 --putStrLn "---> transdataAux"
126 return v
127 where r1 = dim d `div` c1
128 r2 = dim d `div` c2
129 noneed = r1 == 1 || c1 == 1
130
131foreign import ccall safe "aux.h transR"
132 ctransR :: TMM -- Double ::> Double ::> IO Int
133foreign import ccall safe "aux.h transC"
134 ctransC :: TCMCM -- Complex Double ::> Complex Double ::> IO Int
135
136transdata :: Field a => Int -> Vector a -> Int -> Vector a
137transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
138 | isComp baseOf d = scast $ transdataC c1 (scast d) c2
139 | otherwise = transdataG c1 d c2
140
141--transdata :: Storable a => Int -> Vector a -> Int -> Vector a
142--transdata = transdataG
143--{-# RULES "transdataR" transdata=transdataR #-}
144--{-# RULES "transdataC" transdata=transdataC #-}
145
146-----------------------------------------------------------------
147liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 181liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
148liftMatrix f m = reshape (cols m) (f (cdat m)) 182liftMatrix f m = reshape (cols m) (f (cdat m))
149 183
@@ -163,7 +197,7 @@ multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a]
163 Nothing -> False 197 Nothing -> False
164 Just c -> c == length b 198 Just c -> c == length b
165 199
166transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) 200transL m = matrixFromVector RowMajor (rows m) $ transdata (cols m) (cdat m) (rows m)
167 201
168multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) 202multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b)
169 203
@@ -179,7 +213,7 @@ gmatC m f | fortran m =
179 else f 0 (rows m) (cols m) (ptr (dat m)) 213 else f 0 (rows m) (cols m) (ptr (dat m))
180 214
181 215
182multiplyAux order fun a b = unsafePerformIO $ do 216multiplyAux fun order a b = unsafePerformIO $ do
183 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ 217 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
184 show (rows a,cols a) ++ " x " ++ show (rows b, cols b) 218 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
185 r <- createMatrix order (rows a) (cols b) 219 r <- createMatrix order (rows a) (cols b)
@@ -198,37 +232,14 @@ foreign import ccall safe "aux.h multiplyC"
198 -> Int -> Int -> Ptr (Complex Double) 232 -> Int -> Int -> Ptr (Complex Double)
199 -> IO Int 233 -> IO Int
200 234
201multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a 235multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
202multiply RowMajor a b = multiplyD RowMajor a b 236multiply RowMajor a b = multiplyD RowMajor a b
203multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} 237multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor}
204 where m = multiplyD RowMajor (trans b) (trans a) 238 where m = multiplyD RowMajor (trans b) (trans a)
205 239
206multiplyD order a b
207 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b)
208 | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b)
209 | otherwise = multiplyG a b
210
211----------------------------------------------------------------------
212 240
213outer' u v = dat (outer u v) 241multiplyR = multiplyAux cmultiplyR
214 242multiplyC = multiplyAux cmultiplyC
215{- | Outer product of two vectors.
216
217@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
218(3><3)
219 [ 5.0, 2.0, 3.0
220 , 10.0, 4.0, 6.0
221 , 15.0, 6.0, 9.0 ]@
222-}
223outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t
224outer u v = multiply RowMajor r c
225 where r = matrixFromVector RowMajor 1 u
226 c = matrixFromVector RowMajor (dim v) v
227
228dot :: (Field t, Num t) => Vector t -> Vector t -> t
229dot u v = dat (multiply RowMajor r c) `at` 0
230 where r = matrixFromVector RowMajor (dim u) u
231 c = matrixFromVector RowMajor 1 v
232 243
233---------------------------------------------------------------------- 244----------------------------------------------------------------------
234 245
@@ -251,14 +262,14 @@ subMatrixC (r0,c0) (rt,ct) x =
251 subMatrixR (r0,2*c0) (rt,2*ct) . 262 subMatrixR (r0,2*c0) (rt,2*ct) .
252 reshape (2*cols x) . asReal . cdat $ x 263 reshape (2*cols x) . asReal . cdat $ x
253 264
254subMatrix :: (Field a) 265--subMatrix :: (Field a)
255 => (Int,Int) -- ^ (r0,c0) starting position 266-- => (Int,Int) -- ^ (r0,c0) starting position
256 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 267-- -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
257 -> Matrix a -> Matrix a 268-- -> Matrix a -> Matrix a
258subMatrix st sz m 269--subMatrix st sz m
259 | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m) 270-- | isReal (baseOf.dat) m = scast $ subMatrixR st sz (scast m)
260 | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m) 271-- | isComp (baseOf.dat) m = scast $ subMatrixC st sz (scast m)
261 | otherwise = subMatrixG st sz m 272-- | otherwise = subMatrixG st sz m
262 273
263subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x)) 274subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0 ct) (subList r0 rt (toLists x))
264 where subList s n = take n . drop s 275 where subList s n = take n . drop s
@@ -281,11 +292,11 @@ diagC = diagAux c_diagC "diagC"
281foreign import ccall "aux.h diagC" c_diagC :: TCVCM 292foreign import ccall "aux.h diagC" c_diagC :: TCVCM
282 293
283-- | diagonal matrix from a vector 294-- | diagonal matrix from a vector
284diag :: (Num a, Field a) => Vector a -> Matrix a 295--diag :: (Num a, Field a) => Vector a -> Matrix a
285diag v 296--diag v
286 | isReal (baseOf) v = scast $ diagR (scast v) 297-- | isReal (baseOf) v = scast $ diagR (scast v)
287 | isComp (baseOf) v = scast $ diagC (scast v) 298-- | isComp (baseOf) v = scast $ diagC (scast v)
288 | otherwise = diagG v 299-- | otherwise = diagG v
289 300
290diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]] 301diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..c]]
291 where c = dim v 302 where c = dim v
@@ -313,13 +324,34 @@ fromColumns :: Field t => [Vector t] -> Matrix t
313fromColumns m = trans . fromRows $ m 324fromColumns m = trans . fromRows $ m
314 325
315-- | Creates a list of vectors from the columns of a matrix 326-- | Creates a list of vectors from the columns of a matrix
316toColumns :: Field t => Matrix t -> [Vector t] 327toColumns :: Storable t => Matrix t -> [Vector t]
317toColumns m = toRows . trans $ m 328toColumns m = toRows . trans $ m
318 329
319 330
320-- | Reads a matrix position. 331-- | Reads a matrix position.
321(@@>) :: Field t => Matrix t -> (Int,Int) -> t 332(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
322infixl 9 @@> 333infixl 9 @@>
323m@M {rows = r, cols = c} @@> (i,j) 334m@M {rows = r, cols = c} @@> (i,j)
324 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" 335 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range"
325 | otherwise = cdat m `at` (i*c+j) 336 | otherwise = cdat m `at` (i*c+j)
337
338------------------------------------------------------------------
339
340constantR :: Double -> Int -> Vector Double
341constantR = constantAux cconstantR
342
343constantC :: Complex Double -> Int -> Vector (Complex Double)
344constantC = constantAux cconstantC
345
346constantAux fun x n = unsafePerformIO $ do
347 v <- createVector n
348 px <- newArray [x]
349 fun px // vec v // check "constantAux" []
350 free px
351 return v
352
353foreign import ccall safe "aux.h constantR"
354 cconstantR :: Ptr Double -> TV -- Double :> IO Int
355
356foreign import ccall safe "aux.h constantC"
357 cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 34132d8..6876685 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -1,3 +1,5 @@
1{-# OPTIONS_GHC -fglasgow-exts #-}
2
1----------------------------------------------------------------------------- 3-----------------------------------------------------------------------------
2-- | 4-- |
3-- Module : Data.Packed.Internal.Tensor 5-- Module : Data.Packed.Internal.Tensor
@@ -19,6 +21,8 @@ import Foreign.Storable
19import Data.List(sort,elemIndex,nub,foldl1',foldl') 21import Data.List(sort,elemIndex,nub,foldl1',foldl')
20import GSL.Vector 22import GSL.Vector
21import Data.Packed.Matrix 23import Data.Packed.Matrix
24import Data.Packed.Vector
25import LinearAlgebra.Linear
22 26
23data IdxType = Covariant | Contravariant deriving (Show,Eq) 27data IdxType = Covariant | Contravariant deriving (Show,Eq)
24 28
@@ -171,6 +175,7 @@ compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where
171 = t1 /= t2 && n1 == n2 175 = t1 /= t2 && n1 == n2
172 176
173 177
178outer' u v = dat (outer u v)
174 179
175-- | tensor product without without any contractions 180-- | tensor product without without any contractions
176rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t 181rawProduct :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
@@ -187,7 +192,7 @@ contraction2 t1 n1 t2 n2 =
187 m = multiply RowMajor (trans m1) m2 192 m = multiply RowMajor (trans m1) m2
188 193
189-- | contraction of a tensor along two given indices 194-- | contraction of a tensor along two given indices
190contraction1 :: (Field t, Num t) => Tensor t -> IdxName -> IdxName -> Tensor t 195contraction1 :: (Linear Vector t) => Tensor t -> IdxName -> IdxName -> Tensor t
191contraction1 t name1 name2 = 196contraction1 t name1 name2 =
192 if compatIdx t name1 t name2 197 if compatIdx t name1 t name2
193 then sumT y 198 then sumT y
@@ -197,7 +202,7 @@ contraction1 t name1 name2 =
197 y = map head $ zipWith drop [0..] x 202 y = map head $ zipWith drop [0..] x
198 203
199-- | contraction of a tensor along a repeated index 204-- | contraction of a tensor along a repeated index
200contraction1c :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t 205contraction1c :: (Linear Vector t) => Tensor t -> IdxName -> Tensor t
201contraction1c t n = contraction1 renamed n' n 206contraction1c t n = contraction1 renamed n' n
202 where n' = n++"'" -- hmmm 207 where n' = n++"'" -- hmmm
203 renamed = withIdx t auxnames 208 renamed = withIdx t auxnames
@@ -205,31 +210,31 @@ contraction1c t n = contraction1 renamed n' n
205 (h,_:r) = break (==n) (map idxName (dims t)) 210 (h,_:r) = break (==n) (map idxName (dims t))
206 211
207-- | alternative and inefficient version of contraction2 212-- | alternative and inefficient version of contraction2
208contraction2' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t 213contraction2' :: (Linear Vector t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
209contraction2' t1 n1 t2 n2 = 214contraction2' t1 n1 t2 n2 =
210 if compatIdx t1 n1 t2 n2 215 if compatIdx t1 n1 t2 n2
211 then contraction1 (rawProduct t1 t2) n1 n2 216 then contraction1 (rawProduct t1 t2) n1 n2
212 else error "wrong contraction'" 217 else error "wrong contraction'"
213 218
214-- | applies a sequence of contractions 219-- | applies a sequence of contractions
215contractions :: (Field t, Num t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t 220contractions :: (Linear Vector t) => Tensor t -> [(IdxName, IdxName)] -> Tensor t
216contractions t pairs = foldl' contract1b t pairs 221contractions t pairs = foldl' contract1b t pairs
217 where contract1b t (n1,n2) = contraction1 t n1 n2 222 where contract1b t (n1,n2) = contraction1 t n1 n2
218 223
219-- | applies a sequence of contractions of same index 224-- | applies a sequence of contractions of same index
220contractionsC :: (Field t, Num t) => Tensor t -> [IdxName] -> Tensor t 225contractionsC :: (Linear Vector t) => Tensor t -> [IdxName] -> Tensor t
221contractionsC t is = foldl' contraction1c t is 226contractionsC t is = foldl' contraction1c t is
222 227
223 228
224-- | applies a contraction on the first indices of the tensors 229-- | applies a contraction on the first indices of the tensors
225contractionF :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t 230contractionF :: (Linear Vector t) => Tensor t -> Tensor t -> Tensor t
226contractionF t1 t2 = contraction2 t1 n1 t2 n2 231contractionF t1 t2 = contraction2 t1 n1 t2 n2
227 where n1 = fn t1 232 where n1 = fn t1
228 n2 = fn t2 233 n2 = fn t2
229 fn = idxName . head . dims 234 fn = idxName . head . dims
230 235
231-- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal 236-- | computes all compatible contractions of the product of two tensors that would arise if the index names were equal
232possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] 237possibleContractions :: (Linear Vector t) => Tensor t -> Tensor t -> [Tensor t]
233possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] 238possibleContractions t1 t2 = [ contraction2 t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
234 239
235 240
@@ -242,7 +247,7 @@ desiredContractions1 t = [ n1 | (a,n1) <- x , (b,n2) <- x, a/=b, n1==n2]
242 where x = zip [0..] (names t) 247 where x = zip [0..] (names t)
243 248
244-- | tensor product with the convention that repeated indices are contracted. 249-- | tensor product with the convention that repeated indices are contracted.
245mulT :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t 250mulT :: (Linear Vector t) => Tensor t -> Tensor t -> Tensor t
246mulT t1 t2 = r where 251mulT t1 t2 = r where
247 t1r = contractionsC t1 (desiredContractions1 t1) 252 t1r = contractionsC t1 (desiredContractions1 t1)
248 t2r = contractionsC t2 (desiredContractions1 t2) 253 t2r = contractionsC t2 (desiredContractions1 t2)
@@ -254,10 +259,10 @@ mulT t1 t2 = r where
254----------------------------------------------------------------- 259-----------------------------------------------------------------
255 260
256-- | tensor addition (for tensors with the same structure) 261-- | tensor addition (for tensors with the same structure)
257addT :: (Num a, Field a) => Tensor a -> Tensor a -> Tensor a 262addT :: (Linear Vector a) => Tensor a -> Tensor a -> Tensor a
258addT a b = liftTensor2 add a b 263addT a b = liftTensor2 add a b
259 264
260sumT :: (Field a, Num a) => [Tensor a] -> Tensor a 265sumT :: (Linear Vector a) => [Tensor a] -> Tensor a
261sumT l = foldl1' addT l 266sumT l = foldl1' addT l
262 267
263----------------------------------------------------------------- 268-----------------------------------------------------------------
@@ -281,19 +286,19 @@ signature l | length (nub l) < length l = 0
281 | otherwise = -1 286 | otherwise = -1
282 287
283 288
284sym :: (Field t, Num t) => Tensor t -> Tensor t 289sym :: (Linear Vector t) => Tensor t -> Tensor t
285sym t = T (dims t) (ten (sym' (withIdx t seqind))) 290sym t = T (dims t) (ten (sym' (withIdx t seqind)))
286 where sym' t = sumT $ map (flip tridx t) (perms (names t)) 291 where sym' t = sumT $ map (flip tridx t) (perms (names t))
287 where nms = map idxName . dims 292 where nms = map idxName . dims
288 293
289antisym :: (Field t, Num t) => Tensor t -> Tensor t 294antisym :: (Linear Vector t) => Tensor t -> Tensor t
290antisym t = T (dims t) (ten (antisym' (withIdx t seqind))) 295antisym t = T (dims t) (ten (antisym' (withIdx t seqind)))
291 where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t)) 296 where antisym' t = sumT $ map (scsig . flip tridx t) (perms (names t))
292 scsig t = scalar (signature (nms t)) `rawProduct` t 297 scsig t = scalar (signature (nms t)) `rawProduct` t
293 where nms = map idxName . dims 298 where nms = map idxName . dims
294 299
295-- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product). 300-- | the wedge product of two tensors (implemented as the antisymmetrization of the ordinary tensor product).
296wedge :: (Field t, Fractional t) => Tensor t -> Tensor t -> Tensor t 301wedge :: (Linear Vector t, Fractional t) => Tensor t -> Tensor t -> Tensor t
297wedge a b = antisym (rawProduct (norper a) (norper b)) 302wedge a b = antisym (rawProduct (norper a) (norper b))
298 where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t))) 303 where norper t = rawProduct t (scalar (recip $ fromIntegral $ fact (rank t)))
299 304
@@ -313,19 +318,19 @@ seqind :: [String]
313seqind = map show [1..] 318seqind = map show [1..]
314 319
315-- | completely antisymmetric covariant tensor of dimension n 320-- | completely antisymmetric covariant tensor of dimension n
316leviCivita :: (Field t, Num t) => Int -> Tensor t 321leviCivita :: (Linear Vector t) => Int -> Tensor t
317leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind' 322leviCivita n = antisym $ foldl1 rawProduct $ zipWith withIdx auxbase seqind'
318 where auxbase = map tc (toRows (ident n)) 323 where auxbase = map tc (toRows (ident n))
319 tc = tensorFromVector Covariant 324 tc = tensorFromVector Covariant
320 325
321-- | contraction of leviCivita with a list of vectors (and raise with euclidean metric) 326-- | contraction of leviCivita with a list of vectors (and raise with euclidean metric)
322innerLevi :: (Num t, Field t) => [Tensor t] -> Tensor t 327innerLevi :: (Linear Vector t) => [Tensor t] -> Tensor t
323innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs 328innerLevi vs = raise $ foldl' contractionF (leviCivita n) vs
324 where n = idxDim . head . dims . head $ vs 329 where n = idxDim . head . dims . head $ vs
325 330
326 331
327-- | obtains the dual of a multivector (with euclidean metric) 332-- | obtains the dual of a multivector (with euclidean metric)
328dual :: (Field t, Fractional t) => Tensor t -> Tensor t 333dual :: (Linear Vector t, Fractional t) => Tensor t -> Tensor t
329dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x 334dual t = raise $ leviCivita n `mulT` withIdx t seqind `rawProduct` x
330 where n = idxDim . head . dims $ t 335 where n = idxDim . head . dims $ t
331 x = scalar (recip $ fromIntegral $ fact (rank t)) 336 x = scalar (recip $ fromIntegral $ fact (rank t))
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index ab93577..f2646a4 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -1,4 +1,4 @@
1{-# OPTIONS_GHC -fglasgow-exts #-} 1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3-- | 3-- |
4-- Module : Data.Packed.Internal.Vector 4-- Module : Data.Packed.Internal.Vector
@@ -19,6 +19,8 @@ import Data.Packed.Internal.Common
19import Foreign 19import Foreign
20import Complex 20import Complex
21import Control.Monad(when) 21import Control.Monad(when)
22import Data.List(transpose)
23import Debug.Trace(trace)
22 24
23type Vc t s = Int -> Ptr t -> s 25type Vc t s = Int -> Ptr t -> s
24-- not yet admitted by my haddock version 26-- not yet admitted by my haddock version
@@ -28,7 +30,7 @@ type Vc t s = Int -> Ptr t -> s
28vec :: Vector t -> (Vc t s) -> s 30vec :: Vector t -> (Vc t s) -> s
29vec v f = f (dim v) (ptr v) 31vec v f = f (dim v) (ptr v)
30 32
31baseOf v = (v `at` 0) 33--baseOf v = (v `at` 0)
32 34
33createVector :: Storable a => Int -> IO (Vector a) 35createVector :: Storable a => Int -> IO (Vector a)
34createVector n = do 36createVector n = do
@@ -78,9 +80,16 @@ subVector' k l (v@V {dim=n, ptr=p, fptr=fp})
78 | otherwise = v {dim=l, ptr=advancePtr p k} 80 | otherwise = v {dim=l, ptr=advancePtr p k}
79 81
80 82
83-- | Reads a vector position.
84(@>) :: Storable t => Vector t -> Int -> t
85infixl 9 @>
86(@>) = at
87
88
89
81 90
82-- | creates a new Vector by joining a list of Vectors 91-- | creates a new Vector by joining a list of Vectors
83join :: Field t => [Vector t] -> Vector t 92join :: Storable t => [Vector t] -> Vector t
84join [] = error "joining zero vectors" 93join [] = error "joining zero vectors"
85join as = unsafePerformIO $ do 94join as = unsafePerformIO $ do
86 let tot = sum (map dim as) 95 let tot = sum (map dim as)
@@ -103,34 +112,11 @@ asComplex :: Vector Double -> Vector (Complex Double)
103asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } 112asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) }
104 113
105 114
106constantG x n = fromList (replicate n x) 115----------------------------------------------------------------
107
108constantR :: Double -> Int -> Vector Double
109constantR = constantAux cconstantR
110
111constantC :: Complex Double -> Int -> Vector (Complex Double)
112constantC = constantAux cconstantC
113
114constantAux fun x n = unsafePerformIO $ do
115 v <- createVector n
116 px <- newArray [x]
117 fun px // vec v // check "constantAux" []
118 free px
119 return v
120
121foreign import ccall safe "aux.h constantR"
122 cconstantR :: Ptr Double -> TV -- Double :> IO Int
123
124foreign import ccall safe "aux.h constantC"
125 cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int
126
127constant :: Field a => a -> Int -> Vector a
128constant x n | isReal id x = scast $ constantR (scast x) n
129 | isComp id x = scast $ constantC (scast x) n
130 | otherwise = constantG x n
131 116
132liftVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b 117liftVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b
133liftVector f = fromList . map f . toList 118liftVector f = fromList . map f . toList
134 119
135liftVector2 :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c 120liftVector2 :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c
136liftVector2 f u v = fromList $ zipWith f (toList u) (toList v) 121liftVector2 f u v = fromList $ zipWith f (toList u) (toList v)
122
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index 2033dc7..2e8cb3d 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -134,3 +134,18 @@ asRow v = reshape (dim v) v
134 134
135asColumn :: Field a => Vector a -> Matrix a 135asColumn :: Field a => Vector a -> Matrix a
136asColumn v = reshape 1 v 136asColumn v = reshape 1 v
137
138------------------------------------------------
139
140{- | Outer product of two vectors.
141
142@\> 'fromList' [1,2,3] \`outer\` 'fromList' [5,2,3]
143(3><3)
144 [ 5.0, 2.0, 3.0
145 , 10.0, 4.0, 6.0
146 , 15.0, 6.0, 9.0 ]@
147-}
148outer :: (Num t, Field t) => Vector t -> Vector t -> Matrix t
149outer u v = multiply RowMajor r c
150 where r = matrixFromVector RowMajor 1 u
151 c = matrixFromVector RowMajor (dim v) v
diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs
index 27ba6a3..867b77b 100644
--- a/lib/Data/Packed/Vector.hs
+++ b/lib/Data/Packed/Vector.hs
@@ -27,7 +27,7 @@ module Data.Packed.Vector (
27 27
28import Data.Packed.Internal 28import Data.Packed.Internal
29import Complex 29import Complex
30import GSL.Vector 30--import GSL.Vector
31 31
32-- | creates a complex vector from vectors with real and imaginary parts 32-- | creates a complex vector from vectors with real and imaginary parts
33toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double) 33toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double)
@@ -50,7 +50,9 @@ linspace :: Int -> (Double, Double) -> Vector Double
50linspace n (a,b) = fromList [a::Double,a+delta .. b] 50linspace n (a,b) = fromList [a::Double,a+delta .. b]
51 where delta = (b-a)/(fromIntegral n -1) 51 where delta = (b-a)/(fromIntegral n -1)
52 52
53-- | Reads a vector position. 53
54(@>) :: Field t => Vector t -> Int -> t 54dot :: (Field t) => Vector t -> Vector t -> t
55infixl 9 @> 55dot u v = dat (multiply RowMajor r c) `at` 0
56(@>) = at 56 where r = matrixFromVector RowMajor (dim u) u
57 c = matrixFromVector RowMajor 1 v
58