summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r--lib/Data/Packed/Internal/Common.hs7
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs7
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs65
-rw-r--r--lib/Data/Packed/Internal/Vector.hs16
4 files changed, 65 insertions, 30 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs
index bdd7f34..1bfed6d 100644
--- a/lib/Data/Packed/Internal/Common.hs
+++ b/lib/Data/Packed/Internal/Common.hs
@@ -40,6 +40,7 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where --
40 poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- 40 poke p (a :+ b) = pokeArray (castPtr p) [a,b] --
41---------------------------------------------------------------------- 41----------------------------------------------------------------------
42 42
43on :: (a -> a -> b) -> (t -> a) -> t -> t -> b
43on f g = \x y -> f (g x) (g y) 44on f g = \x y -> f (g x) (g y)
44 45
45partit :: Int -> [a] -> [[a]] 46partit :: Int -> [a] -> [[a]]
@@ -54,12 +55,14 @@ common f = commonval . map f where
54 commonval [a] = Just a 55 commonval [a] = Just a
55 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing 56 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
56 57
58xor :: Bool -> Bool -> Bool
57xor a b = a && not b || b && not a 59xor a b = a && not b || b && not a
58 60
59(//) :: x -> (x -> y) -> y 61(//) :: x -> (x -> y) -> y
60infixl 0 // 62infixl 0 //
61(//) = flip ($) 63(//) = flip ($)
62 64
65errorCode :: Int -> String
63errorCode 1000 = "bad size" 66errorCode 1000 = "bad size"
64errorCode 1001 = "bad function code" 67errorCode 1001 = "bad function code"
65errorCode 1002 = "memory problem" 68errorCode 1002 = "memory problem"
@@ -68,6 +71,7 @@ errorCode 1004 = "singular"
68errorCode 1005 = "didn't converge" 71errorCode 1005 = "didn't converge"
69errorCode n = "code "++show n 72errorCode n = "code "++show n
70 73
74check :: String -> [Vector a] -> IO Int -> IO ()
71check msg ls f = do 75check msg ls f = do
72 err <- f 76 err <- f
73 when (err/=0) (error (msg++": "++errorCode err)) 77 when (err/=0) (error (msg++": "++errorCode err))
@@ -77,7 +81,10 @@ check msg ls f = do
77class (Storable a, Typeable a) => Field a 81class (Storable a, Typeable a) => Field a
78instance (Storable a, Typeable a) => Field a 82instance (Storable a, Typeable a) => Field a
79 83
84isReal :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool
80isReal w x = typeOf (undefined :: Double) == typeOf (w x) 85isReal w x = typeOf (undefined :: Double) == typeOf (w x)
86
87isComp :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool
81isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) 88isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x)
82 89
83scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b 90scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 2925fc0..32dc603 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -194,7 +194,9 @@ multiplyD order a b
194 194
195---------------------------------------------------------------------- 195----------------------------------------------------------------------
196 196
197outer u v = dat (multiply RowMajor r c) 197outer' u v = dat (outer u v)
198
199outer u v = multiply RowMajor r c
198 where r = matrixFromVector RowMajor 1 u 200 where r = matrixFromVector RowMajor 1 u
199 c = matrixFromVector RowMajor (dim v) v 201 c = matrixFromVector RowMajor (dim v) v
200 202
@@ -212,8 +214,7 @@ subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do
212 r <- createMatrix RowMajor rt ct 214 r <- createMatrix RowMajor rt ct
213 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r] 215 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r]
214 return r 216 return r
215foreign import ccall "aux.h submatrixR" 217foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM
216 c_submatrixR :: Int -> Int -> Int -> Int -> TMM
217 218
218-- | extraction of a submatrix of a complex matrix 219-- | extraction of a submatrix of a complex matrix
219subMatrixC :: (Int,Int) -- ^ (r0,c0) starting position 220subMatrixC :: (Int,Int) -- ^ (r0,c0) starting position
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 27fce6a..c4faf49 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -14,18 +14,25 @@
14 14
15module Data.Packed.Internal.Tensor where 15module Data.Packed.Internal.Tensor where
16 16
17import Data.Packed.Internal 17import Data.Packed.Internal.Common
18import Data.Packed.Internal.Vector 18import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix 19import Data.Packed.Internal.Matrix
20import Foreign.Storable 20import Foreign.Storable
21import Data.List(sort,elemIndex,nub) 21import Data.List(sort,elemIndex,nub)
22 22
23data IdxTp = Covariant | Contravariant deriving (Show,Eq) 23data IdxType = Covariant | Contravariant deriving (Show,Eq)
24 24
25data Tensor t = T { dims :: [(Int,(IdxTp,String))] 25type IdxName = String
26
27data IdxDesc = IdxDesc { idxDim :: Int,
28 idxType :: IdxType,
29 idxName :: IdxName }
30
31data Tensor t = T { dims :: [IdxDesc]
26 , ten :: Vector t 32 , ten :: Vector t
27 } 33 }
28 34
35rank :: Tensor t -> Int
29rank = length . dims 36rank = length . dims
30 37
31instance (Show a,Storable a) => Show (Tensor a) where 38instance (Show a,Storable a) => Show (Tensor a) where
@@ -33,41 +40,49 @@ instance (Show a,Storable a) => Show (Tensor a) where
33 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) 40 show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t)
34 41
35 42
36shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]" 43shdims :: [IdxDesc] -> String
44shdims [IdxDesc n t name] = name ++ sym t ++"["++show n++"]"
37 where sym Covariant = "_" 45 where sym Covariant = "_"
38 sym Contravariant = "^" 46 sym Contravariant = "^"
39shdims (d:ds) = shdims [d] ++ "><"++ shdims ds 47shdims (d:ds) = shdims [d] ++ "><"++ shdims ds
40 48
41 49
42 50findIdx :: (Field t) => IdxName -> Tensor t
51 -> (([IdxDesc], [IdxDesc]), Matrix t)
43findIdx name t = ((d1,d2),m) where 52findIdx name t = ((d1,d2),m) where
44 (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) 53 (d1,d2) = span (\d -> idxName d /= name) (dims t)
45 c = product (map fst d2) 54 c = product (map idxDim d2)
46 m = matrixFromVector RowMajor c (ten t) 55 m = matrixFromVector RowMajor c (ten t)
47 56
57putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t)
48putFirstIdx name t = (nd,m') 58putFirstIdx name t = (nd,m')
49 where ((d1,d2),m) = findIdx name t 59 where ((d1,d2),m) = findIdx name t
50 m' = matrixFromVector RowMajor c $ cdat $ trans m 60 m' = matrixFromVector RowMajor c $ cdat $ trans m
51 nd = d2++d1 61 nd = d2++d1
52 c = dim (ten t) `div` (fst $ head d2) 62 c = dim (ten t) `div` (idxDim $ head d2)
53 63
64part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t
54part t (name,k) = if k<0 || k>=l 65part t (name,k) = if k<0 || k>=l
55 then error $ "part "++show (name,k)++" out of range in "++show t 66 then error $ "part "++show (name,k)++" out of range" -- in "++show t
56 else T {dims = ds, ten = toRows m !! k} 67 else T {dims = ds, ten = toRows m !! k}
57 where (d:ds,m) = putFirstIdx name t 68 where (d:ds,m) = putFirstIdx name t
58 (l,_) = d 69 l = idxDim d
59 70
71parts :: (Field t) => Tensor t -> IdxName -> [Tensor t]
60parts t name = map f (toRows m) 72parts t name = map f (toRows m)
61 where (d:ds,m) = putFirstIdx name t 73 where (d:ds,m) = putFirstIdx name t
62 (l,_) = d 74 l = idxDim d
63 f t = T {dims=ds, ten=t} 75 f t = T {dims=ds, ten=t}
64 76
77concatRename :: [IdxDesc] -> [IdxDesc] -> [IdxDesc]
65concatRename l1 l2 = l1 ++ map ren l2 where 78concatRename l1 l2 = l1 ++ map ren l2 where
66 ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s)) 79 ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx
67 fs = map (snd.snd) l1 80 fs = map idxName l1
68 81
69prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2) 82prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t
83prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2)
70 84
85contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
71contraction t1 n1 t2 n2 = 86contraction t1 n1 t2 n2 =
72 if compatIdx t1 n1 t2 n2 87 if compatIdx t1 n1 t2 n2
73 then T (concatRename (tail d1) (tail d2)) (cdat m) 88 then T (concatRename (tail d1) (tail d2)) (cdat m)
@@ -76,18 +91,22 @@ contraction t1 n1 t2 n2 =
76 (d2,m2) = putFirstIdx n2 t2 91 (d2,m2) = putFirstIdx n2 t2
77 m = multiply RowMajor (trans m1) m2 92 m = multiply RowMajor (trans m1) m2
78 93
94sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t]
79sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) 95sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls)
80 96
97contract1 :: (Num t, Enum t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t
81contract1 t name1 name2 = T d $ fromList $ sumT y 98contract1 t name1 name2 = T d $ fromList $ sumT y
82 where d = dims (head y) 99 where d = dims (head y)
83 x = (map (flip parts name2) (parts t name1)) 100 x = (map (flip parts name2) (parts t name1))
84 y = map head $ zipWith drop [0..] x 101 y = map head $ zipWith drop [0..] x
85 102
103contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t
86contraction' t1 n1 t2 n2 = 104contraction' t1 n1 t2 n2 =
87 if compatIdx t1 n1 t2 n2 105 if compatIdx t1 n1 t2 n2
88 then contract1 (prod t1 t2) n1 (n2++"'") 106 then contract1 (prod t1 t2) n1 (n2++"'")
89 else error "wrong contraction'" 107 else error "wrong contraction'"
90 108
109tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t
91tridx [] t = t 110tridx [] t = t
92tridx (name:rest) t = T (d:ds) (join ts) where 111tridx (name:rest) t = T (d:ds) (join ts) where
93 ((_,d:_),_) = findIdx name t 112 ((_,d:_),_) = findIdx name t
@@ -95,30 +114,38 @@ tridx (name:rest) t = T (d:ds) (join ts) where
95 ts = map ten ps 114 ts = map ten ps
96 ds = dims (head ps) 115 ds = dims (head ps)
97 116
98compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2 117compatIdxAux :: IdxDesc -> IdxDesc -> Bool
118compatIdxAux IdxDesc {idxDim = n1, idxType = t1}
119 IdxDesc {idxDim = n2, idxType = t2}
120 = t1 /= t2 && n1 == n2
99 121
122compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool
100compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where 123compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where
101 d1 = head $ snd $ fst $ findIdx n1 t1 124 d1 = head $ snd $ fst $ findIdx n1 t1
102 d2 = head $ snd $ fst $ findIdx n2 t2 125 d2 = head $ snd $ fst $ findIdx n2 t2
103 126
104names t = sort $ map (snd.snd) (dims t) 127names :: Tensor t -> [IdxName]
128names t = sort $ map idxName (dims t)
105 129
130normal :: (Field t) => Tensor t -> Tensor t
106normal t = tridx (names t) t 131normal t = tridx (names t) t
107 132
133contractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t]
108contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] 134contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
109 135
110-- sent to Haskell-Cafe by Sebastian Sylvan 136-- sent to Haskell-Cafe by Sebastian Sylvan
137perms :: [t] -> [[t]]
111perms [x] = [[x]] 138perms [x] = [[x]]
112perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] 139perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys]
113selections [] = [] 140selections [] = []
114selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] 141selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs]
115 142
116 143interchanges :: (Ord a) => [a] -> Int
117interchanges ls = sum (map (count ls) ls) 144interchanges ls = sum (map (count ls) ls)
118 where count l p = n 145 where count l p = length $ filter (>p) $ take pel l
119 where Just pel = elemIndex p l 146 where Just pel = elemIndex p l
120 n = length $ filter (>p) $ take pel l
121 147
148signature :: (Num t, Ord a) => [a] -> t
122signature l | length (nub l) < length l = 0 149signature l | length (nub l) < length l = 0
123 | even (interchanges l) = 1 150 | even (interchanges l) = 1
124 | otherwise = -1 151 | otherwise = -1
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index 8848062..25e848d 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -103,15 +103,15 @@ asComplex :: Vector Double -> Vector (Complex Double)
103asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } 103asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) }
104 104
105 105
106constantG n x = fromList (replicate n x) 106constantG x n = fromList (replicate n x)
107 107
108constantR :: Int -> Double -> Vector Double 108constantR :: Double -> Int -> Vector Double
109constantR = constantAux cconstantR 109constantR = constantAux cconstantR
110 110
111constantC :: Int -> Complex Double -> Vector (Complex Double) 111constantC :: Complex Double -> Int -> Vector (Complex Double)
112constantC = constantAux cconstantC 112constantC = constantAux cconstantC
113 113
114constantAux fun n x = unsafePerformIO $ do 114constantAux fun x n = unsafePerformIO $ do
115 v <- createVector n 115 v <- createVector n
116 px <- newArray [x] 116 px <- newArray [x]
117 fun px // vec v // check "constantAux" [] 117 fun px // vec v // check "constantAux" []
@@ -124,8 +124,8 @@ foreign import ccall safe "aux.h constantR"
124foreign import ccall safe "aux.h constantC" 124foreign import ccall safe "aux.h constantC"
125 cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int 125 cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int
126 126
127constant :: Field a => Int -> a -> Vector a 127constant :: Field a => a -> Int -> Vector a
128constant n x | isReal id x = scast $ constantR n (scast x) 128constant x n | isReal id x = scast $ constantR (scast x) n
129 | isComp id x = scast $ constantC n (scast x) 129 | isComp id x = scast $ constantC (scast x) n
130 | otherwise = constantG n x 130 | otherwise = constantG x n
131 131