diff options
author | Alberto Ruiz <aruiz@um.es> | 2007-06-22 17:33:17 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2007-06-22 17:33:17 +0000 |
commit | 978e6d038239af50d70bae2c303f4e45b1879b7a (patch) | |
tree | 571b2060f388d0693820f808b40089acb100a5d9 /lib/Data/Packed | |
parent | 989bdf7e88c13500bd1986dcde36f6cc4f467efb (diff) |
refactoring
Diffstat (limited to 'lib/Data/Packed')
-rw-r--r-- | lib/Data/Packed/Internal.hs | 4 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Common.hs | 7 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 7 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Tensor.hs | 65 | ||||
-rw-r--r-- | lib/Data/Packed/Internal/Vector.hs | 16 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 9 | ||||
-rw-r--r-- | lib/Data/Packed/Plot.hs | 167 | ||||
-rw-r--r-- | lib/Data/Packed/Tensor.hs | 21 | ||||
-rw-r--r-- | lib/Data/Packed/Vector.hs | 13 |
9 files changed, 272 insertions, 37 deletions
diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs index a7fca1a..a5a77c5 100644 --- a/lib/Data/Packed/Internal.hs +++ b/lib/Data/Packed/Internal.hs | |||
@@ -15,9 +15,11 @@ | |||
15 | module Data.Packed.Internal ( | 15 | module Data.Packed.Internal ( |
16 | module Data.Packed.Internal.Common, | 16 | module Data.Packed.Internal.Common, |
17 | module Data.Packed.Internal.Vector, | 17 | module Data.Packed.Internal.Vector, |
18 | module Data.Packed.Internal.Matrix | 18 | module Data.Packed.Internal.Matrix, |
19 | module Data.Packed.Internal.Tensor | ||
19 | ) where | 20 | ) where |
20 | 21 | ||
21 | import Data.Packed.Internal.Common | 22 | import Data.Packed.Internal.Common |
22 | import Data.Packed.Internal.Vector | 23 | import Data.Packed.Internal.Vector |
23 | import Data.Packed.Internal.Matrix | 24 | import Data.Packed.Internal.Matrix |
25 | import Data.Packed.Internal.Tensor \ No newline at end of file | ||
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs index bdd7f34..1bfed6d 100644 --- a/lib/Data/Packed/Internal/Common.hs +++ b/lib/Data/Packed/Internal/Common.hs | |||
@@ -40,6 +40,7 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where -- | |||
40 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- | 40 | poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- |
41 | ---------------------------------------------------------------------- | 41 | ---------------------------------------------------------------------- |
42 | 42 | ||
43 | on :: (a -> a -> b) -> (t -> a) -> t -> t -> b | ||
43 | on f g = \x y -> f (g x) (g y) | 44 | on f g = \x y -> f (g x) (g y) |
44 | 45 | ||
45 | partit :: Int -> [a] -> [[a]] | 46 | partit :: Int -> [a] -> [[a]] |
@@ -54,12 +55,14 @@ common f = commonval . map f where | |||
54 | commonval [a] = Just a | 55 | commonval [a] = Just a |
55 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing | 56 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing |
56 | 57 | ||
58 | xor :: Bool -> Bool -> Bool | ||
57 | xor a b = a && not b || b && not a | 59 | xor a b = a && not b || b && not a |
58 | 60 | ||
59 | (//) :: x -> (x -> y) -> y | 61 | (//) :: x -> (x -> y) -> y |
60 | infixl 0 // | 62 | infixl 0 // |
61 | (//) = flip ($) | 63 | (//) = flip ($) |
62 | 64 | ||
65 | errorCode :: Int -> String | ||
63 | errorCode 1000 = "bad size" | 66 | errorCode 1000 = "bad size" |
64 | errorCode 1001 = "bad function code" | 67 | errorCode 1001 = "bad function code" |
65 | errorCode 1002 = "memory problem" | 68 | errorCode 1002 = "memory problem" |
@@ -68,6 +71,7 @@ errorCode 1004 = "singular" | |||
68 | errorCode 1005 = "didn't converge" | 71 | errorCode 1005 = "didn't converge" |
69 | errorCode n = "code "++show n | 72 | errorCode n = "code "++show n |
70 | 73 | ||
74 | check :: String -> [Vector a] -> IO Int -> IO () | ||
71 | check msg ls f = do | 75 | check msg ls f = do |
72 | err <- f | 76 | err <- f |
73 | when (err/=0) (error (msg++": "++errorCode err)) | 77 | when (err/=0) (error (msg++": "++errorCode err)) |
@@ -77,7 +81,10 @@ check msg ls f = do | |||
77 | class (Storable a, Typeable a) => Field a | 81 | class (Storable a, Typeable a) => Field a |
78 | instance (Storable a, Typeable a) => Field a | 82 | instance (Storable a, Typeable a) => Field a |
79 | 83 | ||
84 | isReal :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool | ||
80 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) | 85 | isReal w x = typeOf (undefined :: Double) == typeOf (w x) |
86 | |||
87 | isComp :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool | ||
81 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) | 88 | isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) |
82 | 89 | ||
83 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b | 90 | scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b |
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 2925fc0..32dc603 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -194,7 +194,9 @@ multiplyD order a b | |||
194 | 194 | ||
195 | ---------------------------------------------------------------------- | 195 | ---------------------------------------------------------------------- |
196 | 196 | ||
197 | outer u v = dat (multiply RowMajor r c) | 197 | outer' u v = dat (outer u v) |
198 | |||
199 | outer u v = multiply RowMajor r c | ||
198 | where r = matrixFromVector RowMajor 1 u | 200 | where r = matrixFromVector RowMajor 1 u |
199 | c = matrixFromVector RowMajor (dim v) v | 201 | c = matrixFromVector RowMajor (dim v) v |
200 | 202 | ||
@@ -212,8 +214,7 @@ subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do | |||
212 | r <- createMatrix RowMajor rt ct | 214 | r <- createMatrix RowMajor rt ct |
213 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r] | 215 | c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r] |
214 | return r | 216 | return r |
215 | foreign import ccall "aux.h submatrixR" | 217 | foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM |
216 | c_submatrixR :: Int -> Int -> Int -> Int -> TMM | ||
217 | 218 | ||
218 | -- | extraction of a submatrix of a complex matrix | 219 | -- | extraction of a submatrix of a complex matrix |
219 | subMatrixC :: (Int,Int) -- ^ (r0,c0) starting position | 220 | subMatrixC :: (Int,Int) -- ^ (r0,c0) starting position |
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 27fce6a..c4faf49 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs | |||
@@ -14,18 +14,25 @@ | |||
14 | 14 | ||
15 | module Data.Packed.Internal.Tensor where | 15 | module Data.Packed.Internal.Tensor where |
16 | 16 | ||
17 | import Data.Packed.Internal | 17 | import Data.Packed.Internal.Common |
18 | import Data.Packed.Internal.Vector | 18 | import Data.Packed.Internal.Vector |
19 | import Data.Packed.Internal.Matrix | 19 | import Data.Packed.Internal.Matrix |
20 | import Foreign.Storable | 20 | import Foreign.Storable |
21 | import Data.List(sort,elemIndex,nub) | 21 | import Data.List(sort,elemIndex,nub) |
22 | 22 | ||
23 | data IdxTp = Covariant | Contravariant deriving (Show,Eq) | 23 | data IdxType = Covariant | Contravariant deriving (Show,Eq) |
24 | 24 | ||
25 | data Tensor t = T { dims :: [(Int,(IdxTp,String))] | 25 | type IdxName = String |
26 | |||
27 | data IdxDesc = IdxDesc { idxDim :: Int, | ||
28 | idxType :: IdxType, | ||
29 | idxName :: IdxName } | ||
30 | |||
31 | data Tensor t = T { dims :: [IdxDesc] | ||
26 | , ten :: Vector t | 32 | , ten :: Vector t |
27 | } | 33 | } |
28 | 34 | ||
35 | rank :: Tensor t -> Int | ||
29 | rank = length . dims | 36 | rank = length . dims |
30 | 37 | ||
31 | instance (Show a,Storable a) => Show (Tensor a) where | 38 | instance (Show a,Storable a) => Show (Tensor a) where |
@@ -33,41 +40,49 @@ instance (Show a,Storable a) => Show (Tensor a) where | |||
33 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) | 40 | show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) |
34 | 41 | ||
35 | 42 | ||
36 | shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]" | 43 | shdims :: [IdxDesc] -> String |
44 | shdims [IdxDesc n t name] = name ++ sym t ++"["++show n++"]" | ||
37 | where sym Covariant = "_" | 45 | where sym Covariant = "_" |
38 | sym Contravariant = "^" | 46 | sym Contravariant = "^" |
39 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds | 47 | shdims (d:ds) = shdims [d] ++ "><"++ shdims ds |
40 | 48 | ||
41 | 49 | ||
42 | 50 | findIdx :: (Field t) => IdxName -> Tensor t | |
51 | -> (([IdxDesc], [IdxDesc]), Matrix t) | ||
43 | findIdx name t = ((d1,d2),m) where | 52 | findIdx name t = ((d1,d2),m) where |
44 | (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) | 53 | (d1,d2) = span (\d -> idxName d /= name) (dims t) |
45 | c = product (map fst d2) | 54 | c = product (map idxDim d2) |
46 | m = matrixFromVector RowMajor c (ten t) | 55 | m = matrixFromVector RowMajor c (ten t) |
47 | 56 | ||
57 | putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t) | ||
48 | putFirstIdx name t = (nd,m') | 58 | putFirstIdx name t = (nd,m') |
49 | where ((d1,d2),m) = findIdx name t | 59 | where ((d1,d2),m) = findIdx name t |
50 | m' = matrixFromVector RowMajor c $ cdat $ trans m | 60 | m' = matrixFromVector RowMajor c $ cdat $ trans m |
51 | nd = d2++d1 | 61 | nd = d2++d1 |
52 | c = dim (ten t) `div` (fst $ head d2) | 62 | c = dim (ten t) `div` (idxDim $ head d2) |
53 | 63 | ||
64 | part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t | ||
54 | part t (name,k) = if k<0 || k>=l | 65 | part t (name,k) = if k<0 || k>=l |
55 | then error $ "part "++show (name,k)++" out of range in "++show t | 66 | then error $ "part "++show (name,k)++" out of range" -- in "++show t |
56 | else T {dims = ds, ten = toRows m !! k} | 67 | else T {dims = ds, ten = toRows m !! k} |
57 | where (d:ds,m) = putFirstIdx name t | 68 | where (d:ds,m) = putFirstIdx name t |
58 | (l,_) = d | 69 | l = idxDim d |
59 | 70 | ||
71 | parts :: (Field t) => Tensor t -> IdxName -> [Tensor t] | ||
60 | parts t name = map f (toRows m) | 72 | parts t name = map f (toRows m) |
61 | where (d:ds,m) = putFirstIdx name t | 73 | where (d:ds,m) = putFirstIdx name t |
62 | (l,_) = d | 74 | l = idxDim d |
63 | f t = T {dims=ds, ten=t} | 75 | f t = T {dims=ds, ten=t} |
64 | 76 | ||
77 | concatRename :: [IdxDesc] -> [IdxDesc] -> [IdxDesc] | ||
65 | concatRename l1 l2 = l1 ++ map ren l2 where | 78 | concatRename l1 l2 = l1 ++ map ren l2 where |
66 | ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s)) | 79 | ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx |
67 | fs = map (snd.snd) l1 | 80 | fs = map idxName l1 |
68 | 81 | ||
69 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2) | 82 | prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t |
83 | prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) | ||
70 | 84 | ||
85 | contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | ||
71 | contraction t1 n1 t2 n2 = | 86 | contraction t1 n1 t2 n2 = |
72 | if compatIdx t1 n1 t2 n2 | 87 | if compatIdx t1 n1 t2 n2 |
73 | then T (concatRename (tail d1) (tail d2)) (cdat m) | 88 | then T (concatRename (tail d1) (tail d2)) (cdat m) |
@@ -76,18 +91,22 @@ contraction t1 n1 t2 n2 = | |||
76 | (d2,m2) = putFirstIdx n2 t2 | 91 | (d2,m2) = putFirstIdx n2 t2 |
77 | m = multiply RowMajor (trans m1) m2 | 92 | m = multiply RowMajor (trans m1) m2 |
78 | 93 | ||
94 | sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] | ||
79 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) | 95 | sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) |
80 | 96 | ||
97 | contract1 :: (Num t, Enum t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t | ||
81 | contract1 t name1 name2 = T d $ fromList $ sumT y | 98 | contract1 t name1 name2 = T d $ fromList $ sumT y |
82 | where d = dims (head y) | 99 | where d = dims (head y) |
83 | x = (map (flip parts name2) (parts t name1)) | 100 | x = (map (flip parts name2) (parts t name1)) |
84 | y = map head $ zipWith drop [0..] x | 101 | y = map head $ zipWith drop [0..] x |
85 | 102 | ||
103 | contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t | ||
86 | contraction' t1 n1 t2 n2 = | 104 | contraction' t1 n1 t2 n2 = |
87 | if compatIdx t1 n1 t2 n2 | 105 | if compatIdx t1 n1 t2 n2 |
88 | then contract1 (prod t1 t2) n1 (n2++"'") | 106 | then contract1 (prod t1 t2) n1 (n2++"'") |
89 | else error "wrong contraction'" | 107 | else error "wrong contraction'" |
90 | 108 | ||
109 | tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t | ||
91 | tridx [] t = t | 110 | tridx [] t = t |
92 | tridx (name:rest) t = T (d:ds) (join ts) where | 111 | tridx (name:rest) t = T (d:ds) (join ts) where |
93 | ((_,d:_),_) = findIdx name t | 112 | ((_,d:_),_) = findIdx name t |
@@ -95,30 +114,38 @@ tridx (name:rest) t = T (d:ds) (join ts) where | |||
95 | ts = map ten ps | 114 | ts = map ten ps |
96 | ds = dims (head ps) | 115 | ds = dims (head ps) |
97 | 116 | ||
98 | compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2 | 117 | compatIdxAux :: IdxDesc -> IdxDesc -> Bool |
118 | compatIdxAux IdxDesc {idxDim = n1, idxType = t1} | ||
119 | IdxDesc {idxDim = n2, idxType = t2} | ||
120 | = t1 /= t2 && n1 == n2 | ||
99 | 121 | ||
122 | compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool | ||
100 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where | 123 | compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where |
101 | d1 = head $ snd $ fst $ findIdx n1 t1 | 124 | d1 = head $ snd $ fst $ findIdx n1 t1 |
102 | d2 = head $ snd $ fst $ findIdx n2 t2 | 125 | d2 = head $ snd $ fst $ findIdx n2 t2 |
103 | 126 | ||
104 | names t = sort $ map (snd.snd) (dims t) | 127 | names :: Tensor t -> [IdxName] |
128 | names t = sort $ map idxName (dims t) | ||
105 | 129 | ||
130 | normal :: (Field t) => Tensor t -> Tensor t | ||
106 | normal t = tridx (names t) t | 131 | normal t = tridx (names t) t |
107 | 132 | ||
133 | contractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] | ||
108 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] | 134 | contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] |
109 | 135 | ||
110 | -- sent to Haskell-Cafe by Sebastian Sylvan | 136 | -- sent to Haskell-Cafe by Sebastian Sylvan |
137 | perms :: [t] -> [[t]] | ||
111 | perms [x] = [[x]] | 138 | perms [x] = [[x]] |
112 | perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] | 139 | perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] |
113 | selections [] = [] | 140 | selections [] = [] |
114 | selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] | 141 | selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] |
115 | 142 | ||
116 | 143 | interchanges :: (Ord a) => [a] -> Int | |
117 | interchanges ls = sum (map (count ls) ls) | 144 | interchanges ls = sum (map (count ls) ls) |
118 | where count l p = n | 145 | where count l p = length $ filter (>p) $ take pel l |
119 | where Just pel = elemIndex p l | 146 | where Just pel = elemIndex p l |
120 | n = length $ filter (>p) $ take pel l | ||
121 | 147 | ||
148 | signature :: (Num t, Ord a) => [a] -> t | ||
122 | signature l | length (nub l) < length l = 0 | 149 | signature l | length (nub l) < length l = 0 |
123 | | even (interchanges l) = 1 | 150 | | even (interchanges l) = 1 |
124 | | otherwise = -1 | 151 | | otherwise = -1 |
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 8848062..25e848d 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs | |||
@@ -103,15 +103,15 @@ asComplex :: Vector Double -> Vector (Complex Double) | |||
103 | asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } | 103 | asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } |
104 | 104 | ||
105 | 105 | ||
106 | constantG n x = fromList (replicate n x) | 106 | constantG x n = fromList (replicate n x) |
107 | 107 | ||
108 | constantR :: Int -> Double -> Vector Double | 108 | constantR :: Double -> Int -> Vector Double |
109 | constantR = constantAux cconstantR | 109 | constantR = constantAux cconstantR |
110 | 110 | ||
111 | constantC :: Int -> Complex Double -> Vector (Complex Double) | 111 | constantC :: Complex Double -> Int -> Vector (Complex Double) |
112 | constantC = constantAux cconstantC | 112 | constantC = constantAux cconstantC |
113 | 113 | ||
114 | constantAux fun n x = unsafePerformIO $ do | 114 | constantAux fun x n = unsafePerformIO $ do |
115 | v <- createVector n | 115 | v <- createVector n |
116 | px <- newArray [x] | 116 | px <- newArray [x] |
117 | fun px // vec v // check "constantAux" [] | 117 | fun px // vec v // check "constantAux" [] |
@@ -124,8 +124,8 @@ foreign import ccall safe "aux.h constantR" | |||
124 | foreign import ccall safe "aux.h constantC" | 124 | foreign import ccall safe "aux.h constantC" |
125 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int | 125 | cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int |
126 | 126 | ||
127 | constant :: Field a => Int -> a -> Vector a | 127 | constant :: Field a => a -> Int -> Vector a |
128 | constant n x | isReal id x = scast $ constantR n (scast x) | 128 | constant x n | isReal id x = scast $ constantR (scast x) n |
129 | | isComp id x = scast $ constantC n (scast x) | 129 | | isComp id x = scast $ constantC (scast x) n |
130 | | otherwise = constantG n x | 130 | | otherwise = constantG x n |
131 | 131 | ||
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index ec5744d..c7d5cfa 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -16,12 +16,13 @@ module Data.Packed.Matrix ( | |||
16 | Matrix(rows,cols), Field, | 16 | Matrix(rows,cols), Field, |
17 | toLists, (><), (>|<), (@@>), | 17 | toLists, (><), (>|<), (@@>), |
18 | trans, | 18 | trans, |
19 | reshape, | 19 | reshape, flatten, |
20 | fromRows, toRows, fromColumns, toColumns, | 20 | fromRows, toRows, fromColumns, toColumns, |
21 | joinVert, joinHoriz, | 21 | joinVert, joinHoriz, |
22 | flipud, fliprl, | 22 | flipud, fliprl, |
23 | liftMatrix, liftMatrix2, | 23 | liftMatrix, liftMatrix2, |
24 | multiply, | 24 | multiply, |
25 | outer, | ||
25 | subMatrix, | 26 | subMatrix, |
26 | takeRows, dropRows, takeColumns, dropColumns, | 27 | takeRows, dropRows, takeColumns, dropColumns, |
27 | diag, takeDiag, diagRect, ident | 28 | diag, takeDiag, diagRect, ident |
@@ -54,11 +55,11 @@ diagRect s r c | |||
54 | | r == c = diag s | 55 | | r == c = diag s |
55 | | r < c = trans $ diagRect s c r | 56 | | r < c = trans $ diagRect s c r |
56 | | r > c = joinVert [diag s , zeros (r-c,c)] | 57 | | r > c = joinVert [diag s , zeros (r-c,c)] |
57 | where zeros (r,c) = reshape c $ constant (r*c) 0 | 58 | where zeros (r,c) = reshape c $ constant 0 (r*c) |
58 | 59 | ||
59 | takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | 60 | takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] |
60 | 61 | ||
61 | ident n = diag (constant n 1) | 62 | ident n = diag (constant 1 n) |
62 | 63 | ||
63 | r >< c = f where | 64 | r >< c = f where |
64 | f l | dim v == r*c = matrixFromVector RowMajor c v | 65 | f l | dim v == r*c = matrixFromVector RowMajor c v |
@@ -88,3 +89,5 @@ dropColumns :: Field t => Int -> Matrix t -> Matrix t | |||
88 | dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat | 89 | dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat |
89 | 90 | ||
90 | ---------------------------------------------------------------- | 91 | ---------------------------------------------------------------- |
92 | |||
93 | flatten = cdat | ||
diff --git a/lib/Data/Packed/Plot.hs b/lib/Data/Packed/Plot.hs new file mode 100644 index 0000000..9eddc9f --- /dev/null +++ b/lib/Data/Packed/Plot.hs | |||
@@ -0,0 +1,167 @@ | |||
1 | ----------------------------------------------------------------------------- | ||
2 | -- | | ||
3 | -- Module : Data.Packed.Plot | ||
4 | -- Copyright : (c) Alberto Ruiz 2005 | ||
5 | -- License : GPL-style | ||
6 | -- | ||
7 | -- Maintainer : Alberto Ruiz (aruiz at um dot es) | ||
8 | -- Stability : provisional | ||
9 | -- Portability : uses gnuplot and ImageMagick | ||
10 | -- | ||
11 | -- Very basic (and provisional) drawing tools. | ||
12 | -- | ||
13 | ----------------------------------------------------------------------------- | ||
14 | |||
15 | module Data.Packed.Plot( | ||
16 | |||
17 | gnuplotX, mplot, | ||
18 | |||
19 | plot, parametricPlot, | ||
20 | |||
21 | splot, mesh, mesh', meshdom, | ||
22 | |||
23 | matrixToPGM, imshow, | ||
24 | |||
25 | ) where | ||
26 | |||
27 | import Data.Packed.Vector | ||
28 | import Data.Packed.Matrix | ||
29 | import GSL.Vector(FunCodeS(Max,Min),toScalarR) | ||
30 | import Data.List(intersperse) | ||
31 | import System | ||
32 | import Data.IORef | ||
33 | import System.Exit | ||
34 | import Foreign hiding (rotate) | ||
35 | |||
36 | |||
37 | size = dim | ||
38 | |||
39 | -- | Loads a real matrix from a formatted ASCII text file | ||
40 | --fromFile :: FilePath -> IO Matrix | ||
41 | --fromFile filename = readFile filename >>= return . readMatrix read | ||
42 | |||
43 | -- | Saves a real matrix to a formatted ascii text file | ||
44 | toFile :: FilePath -> Matrix Double -> IO () | ||
45 | toFile filename matrix = writeFile filename (unlines . map unwords. map (map show) . toLists $ matrix) | ||
46 | |||
47 | ------------------------------------------------------------------------ | ||
48 | |||
49 | |||
50 | -- | From vectors x and y, it generates a pair of matrices to be used as x and y arguments for matrix functions. | ||
51 | meshdom :: Vector Double -> Vector Double -> (Matrix Double , Matrix Double) | ||
52 | meshdom r1 r2 = (outer r1 (constant 1 (size r2)), outer (constant 1 (size r1)) r2) | ||
53 | |||
54 | |||
55 | gnuplotX command = do {system cmdstr; return()} where | ||
56 | cmdstr = "echo \""++command++"\" | gnuplot -persist" | ||
57 | |||
58 | datafollows = "\\\"-\\\"" | ||
59 | |||
60 | prep = (++"e\n\n") . unlines . map (unwords . (map show)) | ||
61 | |||
62 | |||
63 | {- | Draws a 3D surface representation of a real matrix. | ||
64 | |||
65 | > > mesh (hilb 20) | ||
66 | |||
67 | In certain versions you can interactively rotate the graphic using the mouse. | ||
68 | |||
69 | -} | ||
70 | mesh :: Matrix Double -> IO () | ||
71 | mesh m = gnuplotX (command++dat) where | ||
72 | command = "splot "++datafollows++" matrix with lines\n" | ||
73 | dat = prep $ toLists $ m | ||
74 | |||
75 | mesh' m = do | ||
76 | writeFile "splot-gnu-command" "splot \"splot-tmp.txt\" matrix with lines; pause -1"; | ||
77 | toFile "splot-tmp.txt" m | ||
78 | putStr "Press [Return] to close the graphic and continue... " | ||
79 | system "gnuplot -persist splot-gnu-command" | ||
80 | system "rm splot-tmp.txt splot-gnu-command" | ||
81 | return () | ||
82 | |||
83 | {- | Draws the surface represented by the function f in the desired ranges and number of points, internally using 'mesh'. | ||
84 | |||
85 | > > let f x y = cos (x + y) | ||
86 | > > splot f (0,pi) (0,2*pi) 50 | ||
87 | |||
88 | -} | ||
89 | splot :: (Matrix Double->Matrix Double->Matrix Double) -> (Double,Double) -> (Double,Double) -> Int -> IO () | ||
90 | splot f rx ry n = mesh' z where | ||
91 | (x,y) = meshdom (linspace n rx) (linspace n ry) | ||
92 | z = f x y | ||
93 | |||
94 | {- | plots several vectors against the first one -} | ||
95 | mplot :: [Vector Double] -> IO () | ||
96 | mplot m = gnuplotX (commands++dats) where | ||
97 | commands = if length m == 1 then command1 else commandmore | ||
98 | command1 = "plot "++datafollows++" with lines\n" ++ dat | ||
99 | commandmore = "plot " ++ plots ++ "\n" | ||
100 | plots = concat $ intersperse ", " (map cmd [2 .. length m]) | ||
101 | cmd k = datafollows++" using 1:"++show k++" with lines" | ||
102 | dat = prep $ toLists $ fromColumns m | ||
103 | dats = concat (replicate (length m-1) dat) | ||
104 | |||
105 | |||
106 | |||
107 | |||
108 | |||
109 | |||
110 | mplot' m = do | ||
111 | writeFile "plot-gnu-command" (commands++endcmd) | ||
112 | toFile "plot-tmp.txt" (fromColumns m) | ||
113 | putStr "Press [Return] to close the graphic and continue... " | ||
114 | system "gnuplot plot-gnu-command" | ||
115 | system "rm plot-tmp.txt plot-gnu-command" | ||
116 | return () | ||
117 | where | ||
118 | commands = if length m == 1 then command1 else commandmore | ||
119 | command1 = "plot \"plot-tmp.txt\" with lines\n" | ||
120 | commandmore = "plot " ++ plots ++ "\n" | ||
121 | plots = concat $ intersperse ", " (map cmd [2 .. length m]) | ||
122 | cmd k = "\"plot-tmp.txt\" using 1:"++show k++" with lines" | ||
123 | endcmd = "pause -1" | ||
124 | |||
125 | -- apply several functions to one object | ||
126 | mapf fs x = map ($ x) fs | ||
127 | |||
128 | {- | Draws a list of functions over a desired range and with a desired number of points | ||
129 | |||
130 | > > plot [sin, cos, sin.(3*)] (0,2*pi) 1000 | ||
131 | |||
132 | -} | ||
133 | plot :: [Vector Double->Vector Double] -> (Double,Double) -> Int -> IO () | ||
134 | plot fs rx n = mplot (x: mapf fs x) | ||
135 | where x = linspace n rx | ||
136 | |||
137 | {- | Draws a parametric curve. For instance, to draw a spiral we can do something like: | ||
138 | |||
139 | > > parametricPlot (\t->(t * sin t, t * cos t)) (0,10*pi) 1000 | ||
140 | |||
141 | -} | ||
142 | parametricPlot :: (Vector Double->(Vector Double,Vector Double)) -> (Double, Double) -> Int -> IO () | ||
143 | parametricPlot f rt n = mplot [fx, fy] | ||
144 | where t = linspace n rt | ||
145 | (fx,fy) = f t | ||
146 | |||
147 | |||
148 | -- | writes a matrix to pgm image file | ||
149 | matrixToPGM :: Matrix Double -> String | ||
150 | matrixToPGM m = header ++ unlines (map unwords ll) where | ||
151 | c = cols m | ||
152 | r = rows m | ||
153 | header = "P2 "++show c++" "++show r++" "++show (round maxgray :: Int)++"\n" | ||
154 | maxgray = 255.0 | ||
155 | maxval = toScalarR Max $ flatten $ m | ||
156 | minval = toScalarR Min $ flatten $ m | ||
157 | scale = if (maxval == minval) | ||
158 | then 0.0 | ||
159 | else maxgray / (maxval - minval) | ||
160 | f x = show ( round ( scale *(x - minval) ) :: Int ) | ||
161 | ll = map (map f) (toLists m) | ||
162 | |||
163 | -- | imshow shows a representation of a matrix as a gray level image using ImageMagick's display. | ||
164 | imshow :: Matrix Double -> IO () | ||
165 | imshow m = do | ||
166 | system $ "echo \""++ matrixToPGM m ++"\"| display -antialias -resize 300 - &" | ||
167 | return () | ||
diff --git a/lib/Data/Packed/Tensor.hs b/lib/Data/Packed/Tensor.hs index 8d1c8b6..75a9288 100644 --- a/lib/Data/Packed/Tensor.hs +++ b/lib/Data/Packed/Tensor.hs | |||
@@ -1 +1,20 @@ | |||
1 | 1 | ----------------------------------------------------------------------------- | |
2 | -- | | ||
3 | -- Module : Data.Packed.Tensor | ||
4 | -- Copyright : (c) Alberto Ruiz 2007 | ||
5 | -- License : GPL-style | ||
6 | -- | ||
7 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
8 | -- Stability : provisional | ||
9 | -- Portability : portable | ||
10 | -- | ||
11 | -- Tensors | ||
12 | -- | ||
13 | ----------------------------------------------------------------------------- | ||
14 | |||
15 | module Data.Packed.Tensor ( | ||
16 | |||
17 | ) where | ||
18 | |||
19 | import Data.Packed.Internal | ||
20 | import Complex | ||
diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs index 992301a..aa1b489 100644 --- a/lib/Data/Packed/Vector.hs +++ b/lib/Data/Packed/Vector.hs | |||
@@ -20,7 +20,8 @@ module Data.Packed.Vector ( | |||
20 | constant, | 20 | constant, |
21 | toComplex, comp, | 21 | toComplex, comp, |
22 | conj, | 22 | conj, |
23 | dot | 23 | dot, |
24 | linspace | ||
24 | ) where | 25 | ) where |
25 | 26 | ||
26 | import Data.Packed.Internal | 27 | import Data.Packed.Internal |
@@ -35,6 +36,14 @@ conj :: Vector (Complex Double) -> Vector (Complex Double) | |||
35 | conj v = asComplex $ cdat $ reshape 2 (asReal v) `mulC` diag (fromList [1,-1]) | 36 | conj v = asComplex $ cdat $ reshape 2 (asReal v) `mulC` diag (fromList [1,-1]) |
36 | where mulC = multiply RowMajor | 37 | where mulC = multiply RowMajor |
37 | 38 | ||
38 | comp v = toComplex (v,constant (dim v) 0) | 39 | comp v = toComplex (v,constant 0 (dim v)) |
39 | 40 | ||
41 | {- | Creates a real vector containing a range of values: | ||
40 | 42 | ||
43 | > > linspace 10 (-2,2) | ||
44 | >-2. -1.556 -1.111 -0.667 -0.222 0.222 0.667 1.111 1.556 2. | ||
45 | |||
46 | -} | ||
47 | linspace :: Int -> (Double, Double) -> Vector Double | ||
48 | linspace n (a,b) = fromList [a::Double,a+delta .. b] | ||
49 | where delta = (b-a)/(fromIntegral n -1) | ||