summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed/Internal')
-rw-r--r--lib/Data/Packed/Internal/Common.hs84
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs109
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs12
-rw-r--r--lib/Data/Packed/Internal/Vector.hs62
4 files changed, 124 insertions, 143 deletions
diff --git a/lib/Data/Packed/Internal/Common.hs b/lib/Data/Packed/Internal/Common.hs
new file mode 100644
index 0000000..dddd269
--- /dev/null
+++ b/lib/Data/Packed/Internal/Common.hs
@@ -0,0 +1,84 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Internal.Common
5-- Copyright : (c) Alberto Ruiz 2007
6-- License : GPL-style
7--
8-- Maintainer : Alberto Ruiz <aruiz@um.es>
9-- Stability : provisional
10-- Portability : portable (uses FFI)
11--
12-- Common tools
13--
14-----------------------------------------------------------------------------
15
16module Data.Packed.Internal.Common where
17
18import Foreign
19import Complex
20import Control.Monad(when)
21import Debug.Trace
22import Data.List(transpose,intersperse)
23import Data.Typeable
24import Data.Maybe(fromJust)
25
26debug x = trace (show x) x
27
28data Vector t = V { dim :: Int
29 , fptr :: ForeignPtr t
30 , ptr :: Ptr t
31 } deriving Typeable
32
33----------------------------------------------------------------------
34instance (Storable a, RealFloat a) => Storable (Complex a) where --
35 alignment x = alignment (realPart x) --
36 sizeOf x = 2 * sizeOf (realPart x) --
37 peek p = do --
38 [re,im] <- peekArray 2 (castPtr p) --
39 return (re :+ im) --
40 poke p (a :+ b) = pokeArray (castPtr p) [a,b] --
41----------------------------------------------------------------------
42
43on f g = \x y -> f (g x) (g y)
44
45partit :: Int -> [a] -> [[a]]
46partit _ [] = []
47partit n l = take n l : partit n (drop n l)
48
49-- | obtains the common value of a property of a list
50common :: (Eq a) => (b->a) -> [b] -> Maybe a
51common f = commonval . map f where
52 commonval :: (Eq a) => [a] -> Maybe a
53 commonval [] = Nothing
54 commonval [a] = Just a
55 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
56
57xor a b = a && not b || b && not a
58
59(//) :: x -> (x -> y) -> y
60infixl 0 //
61(//) = flip ($)
62
63errorCode 1000 = "bad size"
64errorCode 1001 = "bad function code"
65errorCode 1002 = "memory problem"
66errorCode 1003 = "bad file"
67errorCode 1004 = "singular"
68errorCode 1005 = "didn't converge"
69errorCode n = "code "++show n
70
71check msg ls f = do
72 err <- f
73 when (err/=0) (error (msg++": "++errorCode err))
74 mapM_ (touchForeignPtr . fptr) ls
75 return ()
76
77class (Storable a, Typeable a) => Field a where
78instance (Storable a, Typeable a) => Field a where
79
80isReal w x = typeOf (undefined :: Double) == typeOf (w x)
81isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x)
82
83scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b
84scast = fromJust . cast
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index bae56f1..2c0acdf 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -1,4 +1,4 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} 1{-# OPTIONS_GHC -fglasgow-exts #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3-- | 3-- |
4-- Module : Data.Packed.Internal.Matrix 4-- Module : Data.Packed.Internal.Matrix
@@ -15,18 +15,16 @@
15 15
16module Data.Packed.Internal.Matrix where 16module Data.Packed.Internal.Matrix where
17 17
18import Data.Packed.Internal.Common
18import Data.Packed.Internal.Vector 19import Data.Packed.Internal.Vector
19 20
20import Foreign hiding (xor) 21import Foreign hiding (xor)
21import Complex 22import Complex
22import Control.Monad(when) 23import Control.Monad(when)
23import Debug.Trace
24import Data.List(transpose,intersperse) 24import Data.List(transpose,intersperse)
25import Data.Typeable 25import Data.Typeable
26import Data.Maybe(fromJust) 26import Data.Maybe(fromJust)
27 27
28debug x = trace (show x) x
29
30 28
31data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) 29data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
32 30
@@ -39,7 +37,7 @@ data Matrix t = M { rows :: Int
39 , order :: MatrixOrder 37 , order :: MatrixOrder
40 } deriving Typeable 38 } deriving Typeable
41 39
42xor a b = a && not b || b && not a 40
43 41
44fortran m = order m == ColumnMajor 42fortran m = order m == ColumnMajor
45 43
@@ -57,25 +55,12 @@ type t ::> s = Mt t s
57 55
58mat d m f = f (rows m) (cols m) (ptr (d m)) 56mat d m f = f (rows m) (cols m) (ptr (d m))
59 57
58toLists m = partit (cols m) . toList . cdat $ m
59
60instance (Show a, Storable a) => (Show (Matrix a)) where 60instance (Show a, Storable a) => (Show (Matrix a)) where
61 show m = (sizes++) . dsp . map (map show) . toLists $ m 61 show m = (sizes++) . dsp . map (map show) . toLists $ m
62 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" 62 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n"
63 63
64partit :: Int -> [a] -> [[a]]
65partit _ [] = []
66partit n l = take n l : partit n (drop n l)
67
68-- | obtains the common value of a property of a list
69common :: (Eq a) => (b->a) -> [b] -> Maybe a
70common f = commonval . map f where
71 commonval :: (Eq a) => [a] -> Maybe a
72 commonval [] = Nothing
73 commonval [a] = Just a
74 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
75
76
77toLists m = partit (cols m) . toList . cdat $ m
78
79dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp 64dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
80 where 65 where
81 mt = transpose as 66 mt = transpose as
@@ -146,62 +131,6 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
146--{-# RULES "transdataR" transdata=transdataR #-} 131--{-# RULES "transdataR" transdata=transdataR #-}
147--{-# RULES "transdataC" transdata=transdataC #-} 132--{-# RULES "transdataC" transdata=transdataC #-}
148 133
149-----------------------------------------------------------------------------
150
151-- | creates a Matrix from a list of vectors
152fromRows :: Field t => [Vector t] -> Matrix t
153fromRows vs = case common dim vs of
154 Nothing -> error "fromRows applied to [] or to vectors with different sizes"
155 Just c -> reshape c (join vs)
156
157-- | extracts the rows of a matrix as a list of vectors
158toRows :: Storable t => Matrix t -> [Vector t]
159toRows m = toRows' 0 where
160 v = cdat m
161 r = rows m
162 c = cols m
163 toRows' k | k == r*c = []
164 | otherwise = subVector k c v : toRows' (k+c)
165
166-- | Creates a matrix from a list of vectors, as columns
167fromColumns :: Field t => [Vector t] -> Matrix t
168fromColumns m = trans . fromRows $ m
169
170-- | Creates a list of vectors from the columns of a matrix
171toColumns :: Field t => Matrix t -> [Vector t]
172toColumns m = toRows . trans $ m
173
174-- | creates a matrix from a vertical list of matrices
175joinVert :: Field t => [Matrix t] -> Matrix t
176joinVert ms = case common cols ms of
177 Nothing -> error "joinVert on matrices with different number of columns"
178 Just c -> reshape c $ join (map cdat ms)
179
180-- | creates a matrix from a horizontal list of matrices
181joinHoriz :: Field t => [Matrix t] -> Matrix t
182joinHoriz ms = trans. joinVert . map trans $ ms
183
184-- | creates a complex vector from vectors with real and imaginary parts
185toComplex :: (Vector Double, Vector Double) -> Vector (Complex Double)
186toComplex (r,i) = asComplex $ cdat $ fromColumns [r,i]
187
188-- | obtains the complex conjugate of a complex vector
189conj :: Vector (Complex Double) -> Vector (Complex Double)
190conj v = asComplex $ cdat $ reshape 2 (asReal v) `mulC` diag (fromList [1,-1])
191 where mulC = multiply RowMajor
192
193comp v = toComplex (v,constant (dim v) 0)
194
195------------------------------------------------------------------------------
196
197-- | Reverse rows
198flipud :: Field t => Matrix t -> Matrix t
199flipud m = fromRows . reverse . toRows $ m
200
201-- | Reverse columns
202fliprl :: Field t => Matrix t -> Matrix t
203fliprl m = fromColumns . reverse . toColumns $ m
204
205----------------------------------------------------------------- 134-----------------------------------------------------------------
206 135
207liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes 136liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes
@@ -330,13 +259,25 @@ diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..
330 delta i j | i==j = 1 259 delta i j | i==j = 1
331 | otherwise = 0 260 | otherwise = 0
332 261
333diagRect s r c 262-- | creates a Matrix from a list of vectors
334 | dim s < min r c = error "diagRect" 263fromRows :: Field t => [Vector t] -> Matrix t
335 | r == c = diag s 264fromRows vs = case common dim vs of
336 | r < c = trans $ diagRect s c r 265 Nothing -> error "fromRows applied to [] or to vectors with different sizes"
337 | r > c = joinVert [diag s , zeros (r-c,c)] 266 Just c -> reshape c (join vs)
338 where zeros (r,c) = reshape c $ constant (r*c) 0
339 267
340takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] 268-- | extracts the rows of a matrix as a list of vectors
269toRows :: Storable t => Matrix t -> [Vector t]
270toRows m = toRows' 0 where
271 v = cdat m
272 r = rows m
273 c = cols m
274 toRows' k | k == r*c = []
275 | otherwise = subVector k c v : toRows' (k+c)
341 276
342ident n = diag (constant n 1) 277-- | Creates a matrix from a list of vectors, as columns
278fromColumns :: Field t => [Vector t] -> Matrix t
279fromColumns m = trans . fromRows $ m
280
281-- | Creates a list of vectors from the columns of a matrix
282toColumns :: Field t => Matrix t -> [Vector t]
283toColumns m = toRows . trans $ m
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 67dcb09..123270d 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -1,4 +1,3 @@
1--{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2----------------------------------------------------------------------------- 1-----------------------------------------------------------------------------
3-- | 2-- |
4-- Module : Data.Packed.Internal.Tensor 3-- Module : Data.Packed.Internal.Tensor
@@ -9,15 +8,17 @@
9-- Stability : provisional 8-- Stability : provisional
10-- Portability : portable (uses FFI) 9-- Portability : portable (uses FFI)
11-- 10--
12-- Fundamental types 11-- basic tensor operations
13-- 12--
14----------------------------------------------------------------------------- 13-----------------------------------------------------------------------------
15 14
16module Data.Packed.Internal.Tensor where 15module Data.Packed.Internal.Tensor where
17 16
17import Data.Packed.Internal
18import Data.Packed.Internal.Vector 18import Data.Packed.Internal.Vector
19import Data.Packed.Internal.Matrix 19import Data.Packed.Internal.Matrix
20import Foreign.Storable 20import Foreign.Storable
21import Data.List(sort)
21 22
22data IdxTp = Covariant | Contravariant deriving (Show,Eq) 23data IdxTp = Covariant | Contravariant deriving (Show,Eq)
23 24
@@ -99,3 +100,10 @@ compatIdxAux (n1,(t1,_)) (n2, (t2,_)) = t1 /= t2 && n1 == n2
99compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where 100compatIdx t1 n1 t2 n2 = compatIdxAux d1 d2 where
100 d1 = head $ snd $ fst $ findIdx n1 t1 101 d1 = head $ snd $ fst $ findIdx n1 t1
101 d2 = head $ snd $ fst $ findIdx n2 t2 102 d2 = head $ snd $ fst $ findIdx n2 t2
103
104names t = sort $ map (snd.snd) (dims t)
105
106normal t = tridx (names t) t
107
108contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ]
109
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index 4836bdb..125df1e 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -1,4 +1,4 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} 1{-# OPTIONS_GHC -fglasgow-exts #-}
2----------------------------------------------------------------------------- 2-----------------------------------------------------------------------------
3-- | 3-- |
4-- Module : Data.Packed.Internal.Vector 4-- Module : Data.Packed.Internal.Vector
@@ -9,70 +9,16 @@
9-- Stability : provisional 9-- Stability : provisional
10-- Portability : portable (uses FFI) 10-- Portability : portable (uses FFI)
11-- 11--
12-- Fundamental types 12-- Vector implementation
13-- 13--
14----------------------------------------------------------------------------- 14-----------------------------------------------------------------------------
15 15
16module Data.Packed.Internal.Vector where 16module Data.Packed.Internal.Vector where
17 17
18import Data.Packed.Internal.Common
18import Foreign 19import Foreign
19import Complex 20import Complex
20import Control.Monad(when) 21import Control.Monad(when)
21import Debug.Trace
22import Data.List(transpose,intersperse)
23import Data.Typeable
24import Data.Maybe(fromJust)
25
26debug x = trace (show x) x
27
28----------------------------------------------------------------------
29instance (Storable a, RealFloat a) => Storable (Complex a) where --
30 alignment x = alignment (realPart x) --
31 sizeOf x = 2 * sizeOf (realPart x) --
32 peek p = do --
33 [re,im] <- peekArray 2 (castPtr p) --
34 return (re :+ im) --
35 poke p (a :+ b) = pokeArray (castPtr p) [a,b] --
36----------------------------------------------------------------------
37
38on f g = \x y -> f (g x) (g y)
39
40(//) :: x -> (x -> y) -> y
41infixl 0 //
42(//) = flip ($)
43
44errorCode 1000 = "bad size"
45errorCode 1001 = "bad function code"
46errorCode 1002 = "memory problem"
47errorCode 1003 = "bad file"
48errorCode 1004 = "singular"
49errorCode 1005 = "didn't converge"
50errorCode n = "code "++show n
51
52check msg ls f = do
53 err <- f
54 when (err/=0) (error (msg++": "++errorCode err))
55 mapM_ (touchForeignPtr . fptr) ls
56 return ()
57
58class (Storable a, Typeable a) => Field a where
59instance (Storable a, Typeable a) => Field a where
60
61isReal w x = typeOf (undefined :: Double) == typeOf (w x)
62isComp w x = typeOf (undefined :: Complex Double) == typeOf (w x)
63baseOf v = (v `at` 0)
64
65scast :: forall a . forall b . (Typeable a, Typeable b) => a -> b
66scast = fromJust . cast
67
68
69
70----------------------------------------------------------------------
71
72data Vector t = V { dim :: Int
73 , fptr :: ForeignPtr t
74 , ptr :: Ptr t
75 } deriving Typeable
76 22
77type Vc t s = Int -> Ptr t -> s 23type Vc t s = Int -> Ptr t -> s
78infixr 5 :> 24infixr 5 :>
@@ -81,6 +27,8 @@ type t :> s = Vc t s
81vec :: Vector t -> (Vc t s) -> s 27vec :: Vector t -> (Vc t s) -> s
82vec v f = f (dim v) (ptr v) 28vec v f = f (dim v) (ptr v)
83 29
30baseOf v = (v `at` 0)
31
84createVector :: Storable a => Int -> IO (Vector a) 32createVector :: Storable a => Int -> IO (Vector a)
85createVector n = do 33createVector n = do
86 when (n <= 0) $ error ("trying to createVector of dim "++show n) 34 when (n <= 0) $ error ("trying to createVector of dim "++show n)