From 978e6d038239af50d70bae2c303f4e45b1879b7a Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 22 Jun 2007 17:33:17 +0000 Subject: refactoring --- lib/Data/Packed/Internal.hs | 4 +- lib/Data/Packed/Internal/Common.hs | 7 ++ lib/Data/Packed/Internal/Matrix.hs | 7 +- lib/Data/Packed/Internal/Tensor.hs | 65 ++++++++++----- lib/Data/Packed/Internal/Vector.hs | 16 ++-- lib/Data/Packed/Matrix.hs | 9 +- lib/Data/Packed/Plot.hs | 167 +++++++++++++++++++++++++++++++++++++ lib/Data/Packed/Tensor.hs | 21 ++++- lib/Data/Packed/Vector.hs | 13 ++- 9 files changed, 272 insertions(+), 37 deletions(-) create mode 100644 lib/Data/Packed/Plot.hs (limited to 'lib/Data/Packed') diff --git a/lib/Data/Packed/Internal.hs b/lib/Data/Packed/Internal.hs index a7fca1a..a5a77c5 100644 --- a/lib/Data/Packed/Internal.hs +++ b/lib/Data/Packed/Internal.hs @@ -15,9 +15,11 @@ module Data.Packed.Internal ( module Data.Packed.Internal.Common, module Data.Packed.Internal.Vector, - module Data.Packed.Internal.Matrix + module Data.Packed.Internal.Matrix, + module Data.Packed.Internal.Tensor ) where import Data.Packed.Internal.Common import Data.Packed.Internal.Vector import Data.Packed.Internal.Matrix +import Data.Packed.Internal.Tensor \ No newline at end of file diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs index bdd7f34..1bfed6d 100644 --- a/lib/Data/Packed/Internal/Common.hs +++ b/lib/Data/Packed/Internal/Common.hs @@ -40,6 +40,7 @@ instance (Storable a, RealFloat a) => Storable (Complex a) where -- poke p (a :+ b) = pokeArray (castPtr p) [a,b] -- ---------------------------------------------------------------------- +on :: (a -> a -> b) -> (t -> a) -> t -> t -> b on f g = \x y -> f (g x) (g y) partit :: Int -> [a] -> [[a]] @@ -54,12 +55,14 @@ common f = commonval . map f where commonval [a] = Just a commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing +xor :: Bool -> Bool -> Bool xor a b = a && not b || b && not a (//) :: x -> (x -> y) -> y infixl 0 // (//) = flip ($) +errorCode :: Int -> String errorCode 1000 = "bad size" errorCode 1001 = "bad function code" errorCode 1002 = "memory problem" @@ -68,6 +71,7 @@ errorCode 1004 = "singular" errorCode 1005 = "didn't converge" errorCode n = "code "++show n +check :: String -> [Vector a] -> IO Int -> IO () check msg ls f = do err <- f when (err/=0) (error (msg++": "++errorCode err)) @@ -77,7 +81,10 @@ check msg ls f = do class (Storable a, Typeable a) => Field a instance (Storable a, Typeable a) => Field a +isReal :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool isReal w x = typeOf (undefined :: Double) == typeOf (w x) + +isComp :: (Data.Typeable.Typeable a) => (t -> a) -> t -> Bool isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x) scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 2925fc0..32dc603 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs @@ -194,7 +194,9 @@ multiplyD order a b ---------------------------------------------------------------------- -outer u v = dat (multiply RowMajor r c) +outer' u v = dat (outer u v) + +outer u v = multiply RowMajor r c where r = matrixFromVector RowMajor 1 u c = matrixFromVector RowMajor (dim v) v @@ -212,8 +214,7 @@ subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do r <- createMatrix RowMajor rt ct c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r] return r -foreign import ccall "aux.h submatrixR" - c_submatrixR :: Int -> Int -> Int -> Int -> TMM +foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM -- | extraction of a submatrix of a complex matrix subMatrixC :: (Int,Int) -- ^ (r0,c0) starting position diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index 27fce6a..c4faf49 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs @@ -14,18 +14,25 @@ module Data.Packed.Internal.Tensor where -import Data.Packed.Internal +import Data.Packed.Internal.Common import Data.Packed.Internal.Vector import Data.Packed.Internal.Matrix import Foreign.Storable import Data.List(sort,elemIndex,nub) -data IdxTp = Covariant | Contravariant deriving (Show,Eq) +data IdxType = Covariant | Contravariant deriving (Show,Eq) -data Tensor t = T { dims :: [(Int,(IdxTp,String))] +type IdxName = String + +data IdxDesc = IdxDesc { idxDim :: Int, + idxType :: IdxType, + idxName :: IdxName } + +data Tensor t = T { dims :: [IdxDesc] , ten :: Vector t } +rank :: Tensor t -> Int rank = length . dims instance (Show a,Storable a) => Show (Tensor a) where @@ -33,41 +40,49 @@ instance (Show a,Storable a) => Show (Tensor a) where show T {dims = ds, ten = t} = "("++shdims ds ++") "++ show (toList t) -shdims [(n,(t,name))] = name ++ sym t ++"["++show n++"]" +shdims :: [IdxDesc] -> String +shdims [IdxDesc n t name] = name ++ sym t ++"["++show n++"]" where sym Covariant = "_" sym Contravariant = "^" shdims (d:ds) = shdims [d] ++ "><"++ shdims ds - +findIdx :: (Field t) => IdxName -> Tensor t + -> (([IdxDesc], [IdxDesc]), Matrix t) findIdx name t = ((d1,d2),m) where - (d1,d2) = span (\(_,(_,n)) -> n /=name) (dims t) - c = product (map fst d2) + (d1,d2) = span (\d -> idxName d /= name) (dims t) + c = product (map idxDim d2) m = matrixFromVector RowMajor c (ten t) +putFirstIdx :: (Field t) => String -> Tensor t -> ([IdxDesc], Matrix t) putFirstIdx name t = (nd,m') where ((d1,d2),m) = findIdx name t m' = matrixFromVector RowMajor c $ cdat $ trans m nd = d2++d1 - c = dim (ten t) `div` (fst $ head d2) + c = dim (ten t) `div` (idxDim $ head d2) +part :: (Field t) => Tensor t -> (IdxName, Int) -> Tensor t part t (name,k) = if k<0 || k>=l - then error $ "part "++show (name,k)++" out of range in "++show t + then error $ "part "++show (name,k)++" out of range" -- in "++show t else T {dims = ds, ten = toRows m !! k} where (d:ds,m) = putFirstIdx name t - (l,_) = d + l = idxDim d +parts :: (Field t) => Tensor t -> IdxName -> [Tensor t] parts t name = map f (toRows m) where (d:ds,m) = putFirstIdx name t - (l,_) = d + l = idxDim d f t = T {dims=ds, ten=t} +concatRename :: [IdxDesc] -> [IdxDesc] -> [IdxDesc] concatRename l1 l2 = l1 ++ map ren l2 where - ren (n,(t,s)) = if {- s `elem` fs -} True then (n,(t,s++"'")) else (n,(t,s)) - fs = map (snd.snd) l1 + ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx + fs = map idxName l1 -prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer v1 v2) +prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t +prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) +contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t contraction t1 n1 t2 n2 = if compatIdx t1 n1 t2 n2 then T (concatRename (tail d1) (tail d2)) (cdat m) @@ -76,18 +91,22 @@ contraction t1 n1 t2 n2 = (d2,m2) = putFirstIdx n2 t2 m = multiply RowMajor (trans m1) m2 +sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) +contract1 :: (Num t, Enum t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t contract1 t name1 name2 = T d $ fromList $ sumT y where d = dims (head y) x = (map (flip parts name2) (parts t name1)) y = map head $ zipWith drop [0..] x +contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t contraction' t1 n1 t2 n2 = if compatIdx t1 n1 t2 n2 then contract1 (prod t1 t2) n1 (n2++"'") else error "wrong contraction'" +tridx :: (Field t) => [IdxName] -> Tensor t -> Tensor t tridx [] t = t tridx (name:rest) t = T (d:ds) (join ts) where ((_,d:_),_) = findIdx name t @@ -95,30 +114,38 @@ tridx (name:rest) t = T (d:ds) (join ts) where ts = map ten ps ds = dims (head ps) -compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2 +compatIdxAux :: IdxDesc -> IdxDesc -> Bool +compatIdxAux IdxDesc {idxDim = n1, idxType = t1} + IdxDesc {idxDim = n2, idxType = t2} + = t1 /= t2 && n1 == n2 +compatIdx :: (Field t1, Field t) => Tensor t1 -> IdxName -> Tensor t -> IdxName -> Bool compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where d1 = head $ snd $ fst $ findIdx n1 t1 d2 = head $ snd $ fst $ findIdx n2 t2 -names t = sort $ map (snd.snd) (dims t) +names :: Tensor t -> [IdxName] +names t = sort $ map idxName (dims t) +normal :: (Field t) => Tensor t -> Tensor t normal t = tridx (names t) t +contractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] -- sent to Haskell-Cafe by Sebastian Sylvan +perms :: [t] -> [[t]] perms [x] = [[x]] perms xs = [y:ps | (y,ys) <- selections xs , ps <- perms ys] selections [] = [] selections (x:xs) = (x,xs) : [(y,x:ys) | (y,ys) <- selections xs] - +interchanges :: (Ord a) => [a] -> Int interchanges ls = sum (map (count ls) ls) - where count l p = n + where count l p = length $ filter (>p) $ take pel l where Just pel = elemIndex p l - n = length $ filter (>p) $ take pel l +signature :: (Num t, Ord a) => [a] -> t signature l | length (nub l) < length l = 0 | even (interchanges l) = 1 | otherwise = -1 diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 8848062..25e848d 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs @@ -103,15 +103,15 @@ asComplex :: Vector Double -> Vector (Complex Double) asComplex v = V { dim = dim v `div` 2, fptr = castForeignPtr (fptr v), ptr = castPtr (ptr v) } -constantG n x = fromList (replicate n x) +constantG x n = fromList (replicate n x) -constantR :: Int -> Double -> Vector Double +constantR :: Double -> Int -> Vector Double constantR = constantAux cconstantR -constantC :: Int -> Complex Double -> Vector (Complex Double) +constantC :: Complex Double -> Int -> Vector (Complex Double) constantC = constantAux cconstantC -constantAux fun n x = unsafePerformIO $ do +constantAux fun x n = unsafePerformIO $ do v <- createVector n px <- newArray [x] fun px // vec v // check "constantAux" [] @@ -124,8 +124,8 @@ foreign import ccall safe "aux.h constantR" foreign import ccall safe "aux.h constantC" cconstantC :: Ptr (Complex Double) -> TCV -- Complex Double :> IO Int -constant :: Field a => Int -> a -> Vector a -constant n x | isReal id x = scast $ constantR n (scast x) - | isComp id x = scast $ constantC n (scast x) - | otherwise = constantG n x +constant :: Field a => a -> Int -> Vector a +constant x n | isReal id x = scast $ constantR (scast x) n + | isComp id x = scast $ constantC (scast x) n + | otherwise = constantG x n diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index ec5744d..c7d5cfa 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -16,12 +16,13 @@ module Data.Packed.Matrix ( Matrix(rows,cols), Field, toLists, (><), (>|<), (@@>), trans, - reshape, + reshape, flatten, fromRows, toRows, fromColumns, toColumns, joinVert, joinHoriz, flipud, fliprl, liftMatrix, liftMatrix2, multiply, + outer, subMatrix, takeRows, dropRows, takeColumns, dropColumns, diag, takeDiag, diagRect, ident @@ -54,11 +55,11 @@ diagRect s r c | r == c = diag s | r < c = trans $ diagRect s c r | r > c = joinVert [diag s , zeros (r-c,c)] - where zeros (r,c) = reshape c $ constant (r*c) 0 + where zeros (r,c) = reshape c $ constant 0 (r*c) takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] -ident n = diag (constant n 1) +ident n = diag (constant 1 n) r >< c = f where f l | dim v == r*c = matrixFromVector RowMajor c v @@ -88,3 +89,5 @@ dropColumns :: Field t => Int -> Matrix t -> Matrix t dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat ---------------------------------------------------------------- + +flatten = cdat diff --git a/lib/Data/Packed/Plot.hs b/lib/Data/Packed/Plot.hs new file mode 100644 index 0000000..9eddc9f --- /dev/null +++ b/lib/Data/Packed/Plot.hs @@ -0,0 +1,167 @@ +----------------------------------------------------------------------------- +-- | +-- Module : Data.Packed.Plot +-- Copyright : (c) Alberto Ruiz 2005 +-- License : GPL-style +-- +-- Maintainer : Alberto Ruiz (aruiz at um dot es) +-- Stability : provisional +-- Portability : uses gnuplot and ImageMagick +-- +-- Very basic (and provisional) drawing tools. +-- +----------------------------------------------------------------------------- + +module Data.Packed.Plot( + + gnuplotX, mplot, + + plot, parametricPlot, + + splot, mesh, mesh', meshdom, + + matrixToPGM, imshow, + +) where + +import Data.Packed.Vector +import Data.Packed.Matrix +import GSL.Vector(FunCodeS(Max,Min),toScalarR) +import Data.List(intersperse) +import System +import Data.IORef +import System.Exit +import Foreign hiding (rotate) + + +size = dim + +-- | Loads a real matrix from a formatted ASCII text file +--fromFile :: FilePath -> IO Matrix +--fromFile filename = readFile filename >>= return . readMatrix read + +-- | Saves a real matrix to a formatted ascii text file +toFile :: FilePath -> Matrix Double -> IO () +toFile filename matrix = writeFile filename (unlines . map unwords. map (map show) . toLists $ matrix) + +------------------------------------------------------------------------ + + +-- | From vectors x and y, it generates a pair of matrices to be used as x and y arguments for matrix functions. +meshdom :: Vector Double -> Vector Double -> (Matrix Double , Matrix Double) +meshdom r1 r2 = (outer r1 (constant 1 (size r2)), outer (constant 1 (size r1)) r2) + + +gnuplotX command = do {system cmdstr; return()} where + cmdstr = "echo \""++command++"\" | gnuplot -persist" + +datafollows = "\\\"-\\\"" + +prep = (++"e\n\n") . unlines . map (unwords . (map show)) + + +{- | Draws a 3D surface representation of a real matrix. + +> > mesh (hilb 20) + +In certain versions you can interactively rotate the graphic using the mouse. + +-} +mesh :: Matrix Double -> IO () +mesh m = gnuplotX (command++dat) where + command = "splot "++datafollows++" matrix with lines\n" + dat = prep $ toLists $ m + +mesh' m = do + writeFile "splot-gnu-command" "splot \"splot-tmp.txt\" matrix with lines; pause -1"; + toFile "splot-tmp.txt" m + putStr "Press [Return] to close the graphic and continue... " + system "gnuplot -persist splot-gnu-command" + system "rm splot-tmp.txt splot-gnu-command" + return () + +{- | Draws the surface represented by the function f in the desired ranges and number of points, internally using 'mesh'. + +> > let f x y = cos (x + y) +> > splot f (0,pi) (0,2*pi) 50 + +-} +splot :: (Matrix Double->Matrix Double->Matrix Double) -> (Double,Double) -> (Double,Double) -> Int -> IO () +splot f rx ry n = mesh' z where + (x,y) = meshdom (linspace n rx) (linspace n ry) + z = f x y + +{- | plots several vectors against the first one -} +mplot :: [Vector Double] -> IO () +mplot m = gnuplotX (commands++dats) where + commands = if length m == 1 then command1 else commandmore + command1 = "plot "++datafollows++" with lines\n" ++ dat + commandmore = "plot " ++ plots ++ "\n" + plots = concat $ intersperse ", " (map cmd [2 .. length m]) + cmd k = datafollows++" using 1:"++show k++" with lines" + dat = prep $ toLists $ fromColumns m + dats = concat (replicate (length m-1) dat) + + + + + + +mplot' m = do + writeFile "plot-gnu-command" (commands++endcmd) + toFile "plot-tmp.txt" (fromColumns m) + putStr "Press [Return] to close the graphic and continue... " + system "gnuplot plot-gnu-command" + system "rm plot-tmp.txt plot-gnu-command" + return () + where + commands = if length m == 1 then command1 else commandmore + command1 = "plot \"plot-tmp.txt\" with lines\n" + commandmore = "plot " ++ plots ++ "\n" + plots = concat $ intersperse ", " (map cmd [2 .. length m]) + cmd k = "\"plot-tmp.txt\" using 1:"++show k++" with lines" + endcmd = "pause -1" + +-- apply several functions to one object +mapf fs x = map ($ x) fs + +{- | Draws a list of functions over a desired range and with a desired number of points + +> > plot [sin, cos, sin.(3*)] (0,2*pi) 1000 + +-} +plot :: [Vector Double->Vector Double] -> (Double,Double) -> Int -> IO () +plot fs rx n = mplot (x: mapf fs x) + where x = linspace n rx + +{- | Draws a parametric curve. For instance, to draw a spiral we can do something like: + +> > parametricPlot (\t->(t * sin t, t * cos t)) (0,10*pi) 1000 + +-} +parametricPlot :: (Vector Double->(Vector Double,Vector Double)) -> (Double, Double) -> Int -> IO () +parametricPlot f rt n = mplot [fx, fy] + where t = linspace n rt + (fx,fy) = f t + + +-- | writes a matrix to pgm image file +matrixToPGM :: Matrix Double -> String +matrixToPGM m = header ++ unlines (map unwords ll) where + c = cols m + r = rows m + header = "P2 "++show c++" "++show r++" "++show (round maxgray :: Int)++"\n" + maxgray = 255.0 + maxval = toScalarR Max $ flatten $ m + minval = toScalarR Min $ flatten $ m + scale = if (maxval == minval) + then 0.0 + else maxgray / (maxval - minval) + f x = show ( round ( scale *(x - minval) ) :: Int ) + ll = map (map f) (toLists m) + +-- | imshow shows a representation of a matrix as a gray level image using ImageMagick's display. +imshow :: Matrix Double -> IO () +imshow m = do + system $ "echo \""++ matrixToPGM m ++"\"| display -antialias -resize 300 - &" + return () diff --git a/lib/Data/Packed/Tensor.hs b/lib/Data/Packed/Tensor.hs index 8d1c8b6..75a9288 100644 --- a/lib/Data/Packed/Tensor.hs +++ b/lib/Data/Packed/Tensor.hs @@ -1 +1,20 @@ - +----------------------------------------------------------------------------- +-- | +-- Module : Data.Packed.Tensor +-- Copyright : (c) Alberto Ruiz 2007 +-- License : GPL-style +-- +-- Maintainer : Alberto Ruiz +-- Stability : provisional +-- Portability : portable +-- +-- Tensors +-- +----------------------------------------------------------------------------- + +module Data.Packed.Tensor ( + +) where + +import Data.Packed.Internal +import Complex diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs index 992301a..aa1b489 100644 --- a/lib/Data/Packed/Vector.hs +++ b/lib/Data/Packed/Vector.hs @@ -20,7 +20,8 @@ module Data.Packed.Vector ( constant, toComplex, comp, conj, - dot + dot, + linspace ) where import Data.Packed.Internal @@ -35,6 +36,14 @@ conj :: Vector (Complex Double) -> Vector (Complex Double) conj v = asComplex $ cdat $ reshape 2 (asReal v) `mulC` diag (fromList [1,-1]) where mulC = multiply RowMajor -comp v = toComplex (v,constant (dim v) 0) +comp v = toComplex (v,constant 0 (dim v)) +{- | Creates a real vector containing a range of values: +> > linspace 10 (-2,2) +>-2. -1.556 -1.111 -0.667 -0.222 0.222 0.667 1.111 1.556 2. + +-} +linspace :: Int -> (Double, Double) -> Vector Double +linspace n (a,b) = fromList [a::Double,a+delta .. b] + where delta = (b-a)/(fromIntegral n -1) -- cgit v1.2.3