diff options
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r-- | lib/Data/Packed/Internal/Common.hs | 7 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 7 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 65 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 16 |
4 files changed, 65 insertions, 30 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs index bdd7f34..1bfed6d 100644 --- a/lib/Data/Packed/Internal/Common.hs +++ b/lib/Data/Packed/Internal/Common.hs | |||
@@ -40,6 +40,7 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where -- | |||
40 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | 40 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- |
41 | ---------------------------------------------------------------------- | 41 | ---------------------------------------------------------------------- |
42 | 42 | ||
43 | on :: (a -> a -> b) -> (t -> a) -> t -> t -> b | ||
43 | on f g = \x y -> f (g x) (g y) | 44 | on f g = \x y -> f (g x) (g y) |
44 | 45 | ||
45 | partit :: Int -> [a] -> [[a]] | 46 | partit :: Int -> [a] -> [[a]] |
@@ -54,12 +55,14 @@ common f = commonval . map f where | |||
54 | commonval [a] = Just a | 55 | commonval [a] = Just a |
55 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing | 56 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing |
56 | 57 | ||
58 | xor :: Bool -> Bool -> Bool | ||
57 | xor a b = a && not b || b && not a | 59 | xor a b = a && not b || b && not a |
58 | 60 | ||
59 | (//) :: x -> (x -> y) -> y | 61 | (//) :: x -> (x -> y) -> y |
60 | infixl 0 // | 62 | infixl 0 // |
61 | (//) = flip ($) | 63 | (//) = flip ($) |
62 | 64 | ||
65 | errorCode :: Int -> String | ||
63 | errorCode 1000 = "bad size" | 66 | errorCode 1000 = "bad size" |
64 | errorCode 1001 = "bad function code" | 67 | errorCode 1001 = "bad function code" |
65 | errorCode 1002 = "memory problem" | 68 | errorCode 1002 = "memory problem" |
@@ -68,6 +71,7 @@ errorCode 1004 = "singular" | |||
68 | errorCode 1005 = "didn't converge" | 71 | errorCode 1005 = "didn't converge" |
69 | errorCode n = "code "++show n | 72 | errorCode n = "code "++show n |
70 | 73 | ||
74 | check :: String -> [Vector a] -> IO Int -> IO () | ||
71 | check msg ls f = do | 75 | check msg ls f = do |
72 | err <- f | 76 | err <- f |
73 | when (err/=0) (error (msg++": "++errorCode err)) | 77 | when (err/=0) (error (msg++": "++errorCode err)) |
@@ -77,7 +81,10 @@ check msg ls f = do | |||
77 | class (Storable a, Typeable a) => Field a | 81 | class (Storable a, Typeable a) => Field a |
78 | instance (Storable a, Typeable a) => Field a | 82 | instance (Storable a, Typeable a) => Field a |
79 | 83 | ||
84 | isReal :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool | ||
80 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) | 85 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) |
86 | |||
87 | isComp :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool | ||
81 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) | 88 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) |
82 | 89 | ||
83 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b | 90 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b |
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 2925fc0..32dc603 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -194,7 +194,9 @@ multiplyD order a b | |||
194 | 194 | ||
195 | ---------------------------------------------------------------------- | 195 | ---------------------------------------------------------------------- |
196 | 196 | ||
197 | outer u v = dat (multiply RowMajor r c) | 197 | outer' u v = dat (outer u v) |
198 | |||
199 | outer u v = multiply RowMajor r c | ||
198 | where r = matrixFromVector RowMajor 1 u | 200 | where r = matrixFromVector RowMajor 1 u |
199 | c = matrixFromVector RowMajor (dim v) v | 201 | c = matrixFromVector RowMajor (dim v) v |
200 | 202 | ||
@@ -212,8 +214,7 @@ subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do | |||
212 | r <- createMatrix RowMajor rt ct | 214 | r <- createMatrix RowMajor rt ct |
213 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r] | 215 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r] |
214 | return r | 216 | return r |
215 | foreign import ccall "aux.h submatrixR" | 217 | foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM |
216 | c_submatrixR :: Int -> Int -> Int -> Int -> TMM | ||
217 | 218 | ||
218 | -- | extraction of a submatrix of a complex matrix | 219 | -- | extraction of a submatrix of a complex matrix |
219 | subMatrixC :: (Int,Int) -- ^ (r0,c0) starting position | 220 | subMatrixC :: (Int,Int) -- ^ (r0,c0) starting position |
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 27fce6a..c4faf49 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -14,18 +14,25 @@ | |||
14 | 14 | ||
15 | module Data.Packed.Internal.Tensor where | 15 | module Data.Packed.Internal.Tensor where |
16 | 16 | ||
17 | import Data.Packed.Internal | 17 | import Data.Packed.Internal.Common |
18 | import Data.Packed.Internal.Vector | 18 | import Data.Packed.Internal.Vector |
19 | import Data.Packed.Internal.Matrix | 19 | import Data.Packed.Internal.Matrix |
20 | import Foreign.Storable | 20 | import Foreign.Storable |
21 | import Data.List(sort,elemIndex,nub) | 21 | import Data.List(sort,elemIndex,nub) |
22 | 22 | ||
23 | data IdxTp = Covariant | Contravariant deriving (Show,Eq) | 23 | data IdxType = Covariant | Contravariant deriving (Show,Eq) |
24 | 24 | ||
25 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] | 25 | type IdxName = String |
26 | |||
27 | data IdxDesc = IdxDesc { idxDim :: Int, | ||
28 | idxType :: IdxType, | ||
29 | idxName :: IdxName } | ||
30 | |||
31 | data Tensor t = T { dims :: [IdxDesc] | ||
26 | , ten :: Vector t | 32 | , ten :: Vector t |
27 | } | 33 | } |
28 | 34 | ||
35 | rank :: Tensor t -> Int | ||
29 | rank = length . dims | 36 | rank = length . dims |
30 | 37 | ||
31 | instance (Show a,Storable a) => Show (Tensor a) where | 38 | instance (Show a,Storable a) => Show (Tensor a) where |
@@ -33,41 +40,49 @@ instance (Show a,Storable a) => Show (Tensor a) where | |||
33 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) | 40 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) |
34 | 41 | ||
35 | 42 | ||
36 | shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]" | 43 | shdims :: [IdxDesc] -> String |
44 | shdims [IdxDesc n t name] = name ++ sym t ++"["++show n++"]" | ||
37 | where sym Covariant = "_" | 45 | where sym Covariant = "_" |
38 | sym Contravariant = "^" | 46 | sym Contravariant = "^" |
39 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds | 47 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds |
40 | 48 | ||
41 | 49 | ||
42 | 50 | findIdx :: (Field t) => IdxName -> Tensor t | |
51 | -> (([IdxDesc], [IdxDesc]), Matrix t) | ||
43 | findIdx name t = ((d1,d2),m) where | 52 | findIdx name t = ((d1,d2),m) where |
44 | (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) | 53 | (d1,d2) = span (\d -> idxName d /= name) (dims t) |
45 | c = product (map fst d2) | 54 | c = product (map idxDim d2) |
46 | m = matrixFromVector RowMajor c (ten t) | 55 | m = matrixFromVector RowMajor c (ten t) |
47 | 56 | ||
57 | putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t) | ||
48 | putFirstIdx name t = (nd,m') | 58 | putFirstIdx name t = (nd,m') |
49 | where ((d1,d2),m) = findIdx name t | 59 | where ((d1,d2),m) = findIdx name t |
50 | m' = matrixFromVector RowMajor c $ cdat $ trans m | 60 | m' = matrixFromVector RowMajor c $ cdat $ trans m |
51 | nd = d2++d1 | 61 | nd = d2++d1 |
52 | c = dim (ten t) `div` (fst $ head d2) | 62 | c = dim (ten t) `div` (idxDim $ head d2) |
53 | 63 | ||
64 | part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t | ||
54 | part t (name,k) = if k<0 || k>=l | 65 | part t (name,k) = if k<0 || k>=l |
55 | then error $ "part "++show (name,k)++" out of range in "++show t | 66 | then error $ "part "++show (name,k)++" out of range" -- in "++show t |
56 | else T {dims = ds, ten = toRows m !! k} | 67 | else T {dims = ds, ten = toRows m !! k} |
57 | where (d:ds,m) = putFirstIdx name t | 68 | where (d:ds,m) = putFirstIdx name t |
58 | (l,_) = d | 69 | l = idxDim d |
59 | 70 | ||
71 | parts :: (Field t) => Tensor t -> IdxName -> [Tensor t] | ||
60 | parts t name = map f (toRows m) | 72 | parts t name = map f (toRows m) |
61 | where (d:ds,m) = putFirstIdx name t | 73 | where (d:ds,m) = putFirstIdx name t |
62 | (l,_) = d | 74 | l = idxDim d |
63 | f t = T {dims=ds, ten=t} | 75 | f t = T {dims=ds, ten=t} |
64 | 76 | ||
77 | concatRename :: [IdxDesc] -> [IdxDesc] -> [IdxDesc] | ||
65 | concatRename l1 l2 = l1 ++ map ren l2 where | 78 | concatRename l1 l2 = l1 ++ map ren l2 where |
66 | ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s)) | 79 | ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx |
67 | fs = map (snd.snd) l1 | 80 | fs = map idxName l1 |
68 | 81 | ||
69 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2) | 82 | prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t |
83 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) | ||
70 | 84 | ||
85 | contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | ||
71 | contraction t1 n1 t2 n2 = | 86 | contraction t1 n1 t2 n2 = |
72 | if compatIdx t1 n1 t2 n2 | 87 | if compatIdx t1 n1 t2 n2 |
73 | then T (concatRename (tail d1) (tail d2)) (cdat m) | 88 | then T (concatRename (tail d1) (tail d2)) (cdat m) |
@@ -76,18 +91,22 @@ contraction t1 n1 t2 n2 = | |||
76 | (d2,m2) = putFirstIdx n2 t2 | 91 | (d2,m2) = putFirstIdx n2 t2 |
77 | m = multiply RowMajor (trans m1) m2 | 92 | m = multiply RowMajor (trans m1) m2 |
78 | 93 | ||
94 | sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] | ||
79 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | 95 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) |
80 | 96 | ||
97 | contract1 :: (Num t, Enum t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t | ||
81 | contract1 t name1 name2 = T d $ fromList $ sumT y | 98 | contract1 t name1 name2 = T d $ fromList $ sumT y |
82 | where d = dims (head y) | 99 | where d = dims (head y) |
83 | x = (map (flip parts name2) (parts t name1)) | 100 | x = (map (flip parts name2) (parts t name1)) |
84 | y = map head $ zipWith drop [0..] x | 101 | y = map head $ zipWith drop [0..] x |
85 | 102 | ||
103 | contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | ||
86 | contraction' t1 n1 t2 n2 = | 104 | contraction' t1 n1 t2 n2 = |
87 | if compatIdx t1 n1 t2 n2 | 105 | if compatIdx t1 n1 t2 n2 |
88 | then contract1 (prod t1 t2) n1 (n2++"'") | 106 | then contract1 (prod t1 t2) n1 (n2++"'") |
89 | else error "wrong contraction'" | 107 | else error "wrong contraction'" |
90 | 108 | ||
109 | tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t | ||
91 | tridx [] t = t | 110 | tridx [] t = t |
92 | tridx (name:rest) t = T (d:ds) (join ts) where | 111 | tridx (name:rest) t = T (d:ds) (join ts) where |
93 | ((_,d:_),_) = findIdx name t | 112 | ((_,d:_),_) = findIdx name t |
@@ -95,30 +114,38 @@ tridx (name:rest) t = T (d:ds) (join ts) where | |||
95 | ts = map ten ps | 114 | ts = map ten ps |
96 | ds = dims (head ps) | 115 | ds = dims (head ps) |
97 | 116 | ||
98 | compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2 | 117 | compatIdxAux :: IdxDesc -> IdxDesc -> Bool |
118 | compatIdxAux IdxDesc {idxDim = n1, idxType = t1} | ||
119 | IdxDesc {idxDim = n2, idxType = t2} | ||
120 | = t1 /= t2 && n1 == n2 | ||
99 | 121 | ||
122 | compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool | ||
100 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where | 123 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where |
101 | d1 = head $ snd $ fst $ findIdx n1 t1 | 124 | d1 = head $ snd $ fst $ findIdx n1 t1 |
102 | d2 = head $ snd $ fst $ findIdx n2 t2 | 125 | d2 = head $ snd $ fst $ findIdx n2 t2 |
103 | 126 | ||
104 | names t = sort $ map (snd.snd) (dims t) | 127 | names :: Tensor t -> [IdxName] |
128 | names t = sort $ map idxName (dims t) | ||
105 | 129 | ||
130 | normal :: (Field t) => Tensor t -> Tensor t | ||
106 | normal t = tridx (names t) t | 131 | normal t = tridx (names t) t |
107 | 132 | ||
133 | contractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] | ||
108 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] | 134 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] |
109 | 135 | ||
110 | -- sent to Haskell-Cafe by Sebastian Sylvan | 136 | -- sent to Haskell-Cafe by Sebastian Sylvan |
137 | perms :: [t] -> [[t]] | ||
111 | perms [x] = [[x]] | 138 | perms [x] = [[x]] |
112 | perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] | 139 | perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] |
113 | selections [] = [] | 140 | selections [] = [] |
114 | selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] | 141 | selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] |
115 | 142 | ||
116 | 143 | interchanges :: (Ord a) => [a] -> Int | |
117 | interchanges ls = sum (map (count ls) ls) | 144 | interchanges ls = sum (map (count ls) ls) |
118 | where count l p = n | 145 | where count l p = length $ filter (>p) $ take pel l |
119 | where Just pel = elemIndex p l | 146 | where Just pel = elemIndex p l |
120 | n = length $ filter (>p) $ take pel l | ||
121 | 147 | ||
148 | signature :: (Num t, Ord a) => [a] -> t | ||
122 | signature l | length (nub l) < length l = 0 | 149 | signature l | length (nub l) < length l = 0 |
123 | | even (interchanges l) = 1 | 150 | | even (interchanges l) = 1 |
124 | | otherwise = -1 | 151 | | otherwise = -1 |
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 8848062..25e848d 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -103,15 +103,15 @@ asComplex :: Vector Double -> Vector (Complex Double) | |||
103 | asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } | 103 | asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } |
104 | 104 | ||
105 | 105 | ||
106 | constantG n x = fromList (replicate n x) | 106 | constantG x n = fromList (replicate n x) |
107 | 107 | ||
108 | constantR :: Int -> Double -> Vector Double | 108 | constantR :: Double -> Int -> Vector Double |
109 | constantR = constantAux cconstantR | 109 | constantR = constantAux cconstantR |
110 | 110 | ||
111 | constantC :: Int -> Complex Double -> Vector (Complex Double) | 111 | constantC :: Complex Double -> Int -> Vector (Complex Double) |
112 | constantC = constantAux cconstantC | 112 | constantC = constantAux cconstantC |
113 | 113 | ||
114 | constantAux fun n x = unsafePerformIO $ do | 114 | constantAux fun x n = unsafePerformIO $ do |
115 | v <- createVector n | 115 | v <- createVector n |
116 | px <- newArray [x] | 116 | px <- newArray [x] |
117 | fun px // vec v // check "constantAux" [] | 117 | fun px // vec v // check "constantAux" [] |
@@ -124,8 +124,8 @@ foreign import ccall safe "aux.h constantR" | |||
124 | foreign import ccall safe "aux.h constantC" | 124 | foreign import ccall safe "aux.h constantC" |
125 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int | 125 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int |
126 | 126 | ||
127 | constant :: Field a => Int -> a -> Vector a | 127 | constant :: Field a => a -> Int -> Vector a |
128 | constant n x | isReal id x = scast $ constantR n (scast x) | 128 | constant x n | isReal id x = scast $ constantR (scast x) n |
129 | | isComp id x = scast $ constantC n (scast x) | 129 | | isComp id x = scast $ constantC (scast x) n |
130 | | otherwise = constantG n x | 130 | | otherwise = constantG x n |
131 | 131 | ||