summaryrefslogtreecommitdiff
path: root/packages/base/src/Data
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Data')
-rw-r--r--packages/base/src/Data/Packed.hs25
-rw-r--r--packages/base/src/Data/Packed/Development.hs31
-rw-r--r--packages/base/src/Data/Packed/Foreign.hs99
-rw-r--r--packages/base/src/Data/Packed/Internal.hs26
-rw-r--r--packages/base/src/Data/Packed/Internal/Common.hs160
-rw-r--r--packages/base/src/Data/Packed/Internal/Matrix.hs422
-rw-r--r--packages/base/src/Data/Packed/Internal/Signatures.hs70
-rw-r--r--packages/base/src/Data/Packed/Internal/Vector.hs471
-rw-r--r--packages/base/src/Data/Packed/Matrix.hs490
-rw-r--r--packages/base/src/Data/Packed/ST.hs178
-rw-r--r--packages/base/src/Data/Packed/Vector.hs96
11 files changed, 2068 insertions, 0 deletions
diff --git a/packages/base/src/Data/Packed.hs b/packages/base/src/Data/Packed.hs
new file mode 100644
index 0000000..c66718a
--- /dev/null
+++ b/packages/base/src/Data/Packed.hs
@@ -0,0 +1,25 @@
1-----------------------------------------------------------------------------
2{- |
3Module : Data.Packed
4Copyright : (c) Alberto Ruiz 2006-2014
5License : BSD3
6Maintainer : Alberto Ruiz
7Stability : provisional
8
9Types for dense 'Vector' and 'Matrix' of 'Storable' elements.
10
11-}
12-----------------------------------------------------------------------------
13
14module Data.Packed (
15 -- * Vector
16 --
17 -- | Vectors are @Data.Vector.Storable.Vector@ from the \"vector\" package.
18 module Data.Packed.Vector,
19 -- * Matrix
20 module Data.Packed.Matrix,
21) where
22
23import Data.Packed.Vector
24import Data.Packed.Matrix
25
diff --git a/packages/base/src/Data/Packed/Development.hs b/packages/base/src/Data/Packed/Development.hs
new file mode 100644
index 0000000..777b6c5
--- /dev/null
+++ b/packages/base/src/Data/Packed/Development.hs
@@ -0,0 +1,31 @@
1
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Development
5-- Copyright : (c) Alberto Ruiz 2009
6-- License : BSD3
7-- Maintainer : Alberto Ruiz
8-- Stability : provisional
9-- Portability : portable
10--
11-- The library can be easily extended with additional foreign functions
12-- using the tools in this module. Illustrative usage examples can be found
13-- in the @examples\/devel@ folder included in the package.
14--
15-----------------------------------------------------------------------------
16
17module Data.Packed.Development (
18 createVector, createMatrix,
19 vec, mat,
20 app1, app2, app3, app4,
21 app5, app6, app7, app8, app9, app10,
22 MatrixOrder(..), orderOf, cmat, fmat,
23 matrixFromVector,
24 unsafeFromForeignPtr,
25 unsafeToForeignPtr,
26 check, (//),
27 at', atM'
28) where
29
30import Data.Packed.Internal
31
diff --git a/packages/base/src/Data/Packed/Foreign.hs b/packages/base/src/Data/Packed/Foreign.hs
new file mode 100644
index 0000000..efa51ca
--- /dev/null
+++ b/packages/base/src/Data/Packed/Foreign.hs
@@ -0,0 +1,99 @@
1{-# LANGUAGE MagicHash, UnboxedTuples #-}
2-- | FFI and hmatrix helpers.
3--
4-- Sample usage, to upload a perspective matrix to a shader.
5--
6-- @ glUniformMatrix4fv 0 1 (fromIntegral gl_TRUE) \`appMatrix\` perspective 0.01 100 (pi\/2) (4\/3)
7-- @
8--
9module Data.Packed.Foreign
10 ( app
11 , appVector, appVectorLen
12 , appMatrix, appMatrixLen, appMatrixRaw, appMatrixRawLen
13 , unsafeMatrixToVector, unsafeMatrixToForeignPtr
14 ) where
15import Data.Packed.Internal
16import qualified Data.Vector.Storable as S
17import Foreign (Ptr, ForeignPtr, Storable)
18import Foreign.C.Types (CInt)
19import GHC.Base (IO(..), realWorld#)
20
21{-# INLINE unsafeInlinePerformIO #-}
22-- | If we use unsafePerformIO, it may not get inlined, so in a function that returns IO (which are all safe uses of app* in this module), there would be
23-- unecessary calls to unsafePerformIO or its internals.
24unsafeInlinePerformIO :: IO a -> a
25unsafeInlinePerformIO (IO f) = case f realWorld# of
26 (# _, x #) -> x
27
28{-# INLINE app #-}
29-- | Only useful since it is left associated with a precedence of 1, unlike 'Prelude.$', which is right associative.
30-- e.g.
31--
32-- @
33-- someFunction
34-- \`appMatrixLen\` m
35-- \`appVectorLen\` v
36-- \`app\` other
37-- \`app\` arguments
38-- \`app\` go here
39-- @
40--
41-- One could also write:
42--
43-- @
44-- (someFunction
45-- \`appMatrixLen\` m
46-- \`appVectorLen\` v)
47-- other
48-- arguments
49-- (go here)
50-- @
51--
52app :: (a -> b) -> a -> b
53app f = f
54
55{-# INLINE appVector #-}
56appVector :: Storable a => (Ptr a -> b) -> Vector a -> b
57appVector f x = unsafeInlinePerformIO (S.unsafeWith x (return . f))
58
59{-# INLINE appVectorLen #-}
60appVectorLen :: Storable a => (CInt -> Ptr a -> b) -> Vector a -> b
61appVectorLen f x = unsafeInlinePerformIO (S.unsafeWith x (return . f (fromIntegral (S.length x))))
62
63{-# INLINE appMatrix #-}
64appMatrix :: Element a => (Ptr a -> b) -> Matrix a -> b
65appMatrix f x = unsafeInlinePerformIO (S.unsafeWith (flatten x) (return . f))
66
67{-# INLINE appMatrixLen #-}
68appMatrixLen :: Element a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
69appMatrixLen f x = unsafeInlinePerformIO (S.unsafeWith (flatten x) (return . f r c))
70 where
71 r = fromIntegral (rows x)
72 c = fromIntegral (cols x)
73
74{-# INLINE appMatrixRaw #-}
75appMatrixRaw :: Storable a => (Ptr a -> b) -> Matrix a -> b
76appMatrixRaw f x = unsafeInlinePerformIO (S.unsafeWith (xdat x) (return . f))
77
78{-# INLINE appMatrixRawLen #-}
79appMatrixRawLen :: Element a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
80appMatrixRawLen f x = unsafeInlinePerformIO (S.unsafeWith (xdat x) (return . f r c))
81 where
82 r = fromIntegral (rows x)
83 c = fromIntegral (cols x)
84
85infixl 1 `app`
86infixl 1 `appVector`
87infixl 1 `appMatrix`
88infixl 1 `appMatrixRaw`
89
90{-# INLINE unsafeMatrixToVector #-}
91-- | This will disregard the order of the matrix, and simply return it as-is.
92-- If the order of the matrix is RowMajor, this function is identical to 'flatten'.
93unsafeMatrixToVector :: Matrix a -> Vector a
94unsafeMatrixToVector = xdat
95
96{-# INLINE unsafeMatrixToForeignPtr #-}
97unsafeMatrixToForeignPtr :: Storable a => Matrix a -> (ForeignPtr a, Int)
98unsafeMatrixToForeignPtr m = S.unsafeToForeignPtr0 (xdat m)
99
diff --git a/packages/base/src/Data/Packed/Internal.hs b/packages/base/src/Data/Packed/Internal.hs
new file mode 100644
index 0000000..537e51e
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal.hs
@@ -0,0 +1,26 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Data.Packed.Internal
4-- Copyright : (c) Alberto Ruiz 2007
5-- License : GPL-style
6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es>
8-- Stability : provisional
9-- Portability : portable
10--
11-- Reexports all internal modules
12--
13-----------------------------------------------------------------------------
14-- #hide
15
16module Data.Packed.Internal (
17 module Data.Packed.Internal.Common,
18 module Data.Packed.Internal.Signatures,
19 module Data.Packed.Internal.Vector,
20 module Data.Packed.Internal.Matrix,
21) where
22
23import Data.Packed.Internal.Common
24import Data.Packed.Internal.Signatures
25import Data.Packed.Internal.Vector
26import Data.Packed.Internal.Matrix
diff --git a/packages/base/src/Data/Packed/Internal/Common.hs b/packages/base/src/Data/Packed/Internal/Common.hs
new file mode 100644
index 0000000..615bbdf
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Common.hs
@@ -0,0 +1,160 @@
1{-# LANGUAGE CPP #-}
2-- |
3-- Module : Data.Packed.Internal.Common
4-- Copyright : (c) Alberto Ruiz 2007
5-- License : BSD3
6-- Maintainer : Alberto Ruiz
7-- Stability : provisional
8--
9--
10-- Development utilities.
11--
12
13
14module Data.Packed.Internal.Common(
15 Adapt,
16 app1, app2, app3, app4,
17 app5, app6, app7, app8, app9, app10,
18 (//), check, mbCatch,
19 splitEvery, common, compatdim,
20 fi,
21 table,
22 finit
23) where
24
25import Control.Monad(when)
26import Foreign.C.Types
27import Foreign.Storable.Complex()
28import Data.List(transpose,intersperse)
29import Control.Exception as E
30
31-- | @splitEvery 3 [1..9] == [[1,2,3],[4,5,6],[7,8,9]]@
32splitEvery :: Int -> [a] -> [[a]]
33splitEvery _ [] = []
34splitEvery k l = take k l : splitEvery k (drop k l)
35
36-- | obtains the common value of a property of a list
37common :: (Eq a) => (b->a) -> [b] -> Maybe a
38common f = commonval . map f where
39 commonval :: (Eq a) => [a] -> Maybe a
40 commonval [] = Nothing
41 commonval [a] = Just a
42 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
43
44-- | common value with \"adaptable\" 1
45compatdim :: [Int] -> Maybe Int
46compatdim [] = Nothing
47compatdim [a] = Just a
48compatdim (a:b:xs)
49 | a==b = compatdim (b:xs)
50 | a==1 = compatdim (b:xs)
51 | b==1 = compatdim (a:xs)
52 | otherwise = Nothing
53
54-- | Formatting tool
55table :: String -> [[String]] -> String
56table sep as = unlines . map unwords' $ transpose mtp where
57 mt = transpose as
58 longs = map (maximum . map length) mt
59 mtp = zipWith (\a b -> map (pad a) b) longs mt
60 pad n str = replicate (n - length str) ' ' ++ str
61 unwords' = concat . intersperse sep
62
63-- | postfix function application (@flip ($)@)
64(//) :: x -> (x -> y) -> y
65infixl 0 //
66(//) = flip ($)
67
68-- | specialized fromIntegral
69fi :: Int -> CInt
70fi = fromIntegral
71
72-- hmm..
73ww2 w1 o1 w2 o2 f = w1 o1 $ w2 o2 . f
74ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ ww2 w2 o2 w3 o3 . f
75ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ ww3 w2 o2 w3 o3 w4 o4 . f
76ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 f = w1 o1 $ ww4 w2 o2 w3 o3 w4 o4 w5 o5 . f
77ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 f = w1 o1 $ ww5 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 . f
78ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 f = w1 o1 $ ww6 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 . f
79ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 f = w1 o1 $ ww7 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 . f
80ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 f = w1 o1 $ ww8 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 . f
81ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 f = w1 o1 $ ww9 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 . f
82
83type Adapt f t r = t -> ((f -> r) -> IO()) -> IO()
84
85type Adapt1 f t1 = Adapt f t1 (IO CInt) -> t1 -> String -> IO()
86type Adapt2 f t1 r1 t2 = Adapt f t1 r1 -> t1 -> Adapt1 r1 t2
87type Adapt3 f t1 r1 t2 r2 t3 = Adapt f t1 r1 -> t1 -> Adapt2 r1 t2 r2 t3
88type Adapt4 f t1 r1 t2 r2 t3 r3 t4 = Adapt f t1 r1 -> t1 -> Adapt3 r1 t2 r2 t3 r3 t4
89type Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5 = Adapt f t1 r1 -> t1 -> Adapt4 r1 t2 r2 t3 r3 t4 r4 t5
90type Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 = Adapt f t1 r1 -> t1 -> Adapt5 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6
91type Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 = Adapt f t1 r1 -> t1 -> Adapt6 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7
92type Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 = Adapt f t1 r1 -> t1 -> Adapt7 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8
93type Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 = Adapt f t1 r1 -> t1 -> Adapt8 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9
94type Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 = Adapt f t1 r1 -> t1 -> Adapt9 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10
95
96app1 :: f -> Adapt1 f t1
97app2 :: f -> Adapt2 f t1 r1 t2
98app3 :: f -> Adapt3 f t1 r1 t2 r2 t3
99app4 :: f -> Adapt4 f t1 r1 t2 r2 t3 r3 t4
100app5 :: f -> Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5
101app6 :: f -> Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6
102app7 :: f -> Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7
103app8 :: f -> Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8
104app9 :: f -> Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9
105app10 :: f -> Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10
106
107app1 f w1 o1 s = w1 o1 $ \a1 -> f // a1 // check s
108app2 f w1 o1 w2 o2 s = ww2 w1 o1 w2 o2 $ \a1 a2 -> f // a1 // a2 // check s
109app3 f w1 o1 w2 o2 w3 o3 s = ww3 w1 o1 w2 o2 w3 o3 $
110 \a1 a2 a3 -> f // a1 // a2 // a3 // check s
111app4 f w1 o1 w2 o2 w3 o3 w4 o4 s = ww4 w1 o1 w2 o2 w3 o3 w4 o4 $
112 \a1 a2 a3 a4 -> f // a1 // a2 // a3 // a4 // check s
113app5 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 s = ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 $
114 \a1 a2 a3 a4 a5 -> f // a1 // a2 // a3 // a4 // a5 // check s
115app6 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 s = ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 $
116 \a1 a2 a3 a4 a5 a6 -> f // a1 // a2 // a3 // a4 // a5 // a6 // check s
117app7 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 s = ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 $
118 \a1 a2 a3 a4 a5 a6 a7 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // check s
119app8 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 s = ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 $
120 \a1 a2 a3 a4 a5 a6 a7 a8 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // check s
121app9 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 s = ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 $
122 \a1 a2 a3 a4 a5 a6 a7 a8 a9 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // check s
123app10 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 s = ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 $
124 \a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // a10 // check s
125
126
127
128-- GSL error codes are <= 1024
129-- | error codes for the auxiliary functions required by the wrappers
130errorCode :: CInt -> String
131errorCode 2000 = "bad size"
132errorCode 2001 = "bad function code"
133errorCode 2002 = "memory problem"
134errorCode 2003 = "bad file"
135errorCode 2004 = "singular"
136errorCode 2005 = "didn't converge"
137errorCode 2006 = "the input matrix is not positive definite"
138errorCode 2007 = "not yet supported in this OS"
139errorCode n = "code "++show n
140
141
142-- | clear the fpu
143foreign import ccall unsafe "asm_finit" finit :: IO ()
144
145-- | check the error code
146check :: String -> IO CInt -> IO ()
147check msg f = do
148#if FINIT
149 finit
150#endif
151 err <- f
152 when (err/=0) $ error (msg++": "++errorCode err)
153 return ()
154
155-- | Error capture and conversion to Maybe
156mbCatch :: IO x -> IO (Maybe x)
157mbCatch act = E.catch (Just `fmap` act) f
158 where f :: SomeException -> IO (Maybe x)
159 f _ = return Nothing
160
diff --git a/packages/base/src/Data/Packed/Internal/Matrix.hs b/packages/base/src/Data/Packed/Internal/Matrix.hs
new file mode 100644
index 0000000..9b831cc
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Matrix.hs
@@ -0,0 +1,422 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-}
5
6-- |
7-- Module : Data.Packed.Internal.Matrix
8-- Copyright : (c) Alberto Ruiz 2007
9-- License : BSD3
10-- Maintainer : Alberto Ruiz
11-- Stability : provisional
12--
13-- Internal matrix representation
14--
15
16module Data.Packed.Internal.Matrix(
17 Matrix(..), rows, cols, cdat, fdat,
18 MatrixOrder(..), orderOf,
19 createMatrix, mat,
20 cmat, fmat,
21 toLists, flatten, reshape,
22 Element(..),
23 trans,
24 fromRows, toRows, fromColumns, toColumns,
25 matrixFromVector,
26 subMatrix,
27 liftMatrix, liftMatrix2,
28 (@@>), atM',
29 singleton,
30 emptyM,
31 size, shSize, conformVs, conformMs, conformVTo, conformMTo
32) where
33
34import Data.Packed.Internal.Common
35import Data.Packed.Internal.Signatures
36import Data.Packed.Internal.Vector
37
38import Foreign.Marshal.Alloc(alloca, free)
39import Foreign.Marshal.Array(newArray)
40import Foreign.Ptr(Ptr, castPtr)
41import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf)
42import Data.Complex(Complex)
43import Foreign.C.Types
44import System.IO.Unsafe(unsafePerformIO)
45import Control.DeepSeq
46
47-----------------------------------------------------------------
48
49{- Design considerations for the Matrix Type
50 -----------------------------------------
51
52- we must easily handle both row major and column major order,
53 for bindings to LAPACK and GSL/C
54
55- we'd like to simplify redundant matrix transposes:
56 - Some of them arise from the order requirements of some functions
57 - some functions (matrix product) admit transposed arguments
58
59- maybe we don't really need this kind of simplification:
60 - more complex code
61 - some computational overhead
62 - only appreciable gain in code with a lot of redundant transpositions
63 and cheap matrix computations
64
65- we could carry both the matrix and its (lazily computed) transpose.
66 This may save some transpositions, but it is necessary to keep track of the
67 data which is actually computed to be used by functions like the matrix product
68 which admit both orders.
69
70- but if we need the transposed data and it is not in the structure, we must make
71 sure that we touch the same foreignptr that is used in the computation.
72
73- a reasonable solution is using two constructors for a matrix. Transposition just
74 "flips" the constructor. Actual data transposition is not done if followed by a
75 matrix product or another transpose.
76
77-}
78
79data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
80
81transOrder RowMajor = ColumnMajor
82transOrder ColumnMajor = RowMajor
83{- | Matrix representation suitable for GSL and LAPACK computations.
84
85The elements are stored in a continuous memory array.
86
87-}
88
89data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int
90 , icols :: {-# UNPACK #-} !Int
91 , xdat :: {-# UNPACK #-} !(Vector t)
92 , order :: !MatrixOrder }
93-- RowMajor: preferred by C, fdat may require a transposition
94-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
95
96cdat = xdat
97fdat = xdat
98
99rows :: Matrix t -> Int
100rows = irows
101
102cols :: Matrix t -> Int
103cols = icols
104
105orderOf :: Matrix t -> MatrixOrder
106orderOf = order
107
108
109-- | Matrix transpose.
110trans :: Matrix t -> Matrix t
111trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o}
112
113cmat :: (Element t) => Matrix t -> Matrix t
114cmat m@Matrix{order = RowMajor} = m
115cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor}
116
117fmat :: (Element t) => Matrix t -> Matrix t
118fmat m@Matrix{order = ColumnMajor} = m
119fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor}
120
121-- C-Haskell matrix adapter
122-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r
123
124mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
125mat a f =
126 unsafeWith (xdat a) $ \p -> do
127 let m g = do
128 g (fi (rows a)) (fi (cols a)) p
129 f m
130
131{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
132
133>>> flatten (ident 3)
134fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
135
136-}
137flatten :: Element t => Matrix t -> Vector t
138flatten = xdat . cmat
139
140{-
141type Mt t s = Int -> Int -> Ptr t -> s
142
143infixr 6 ::>
144type t ::> s = Mt t s
145-}
146
147-- | the inverse of 'Data.Packed.Matrix.fromLists'
148toLists :: (Element t) => Matrix t -> [[t]]
149toLists m = splitEvery (cols m) . toList . flatten $ m
150
151-- | Create a matrix from a list of vectors.
152-- All vectors must have the same dimension,
153-- or dimension 1, which is are automatically expanded.
154fromRows :: Element t => [Vector t] -> Matrix t
155fromRows [] = emptyM 0 0
156fromRows vs = case compatdim (map dim vs) of
157 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
158 Just 0 -> emptyM r 0
159 Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs
160 where
161 r = length vs
162 adapt c v
163 | c == 0 = fromList[]
164 | dim v == c = v
165 | otherwise = constantD (v@>0) c
166
167-- | extracts the rows of a matrix as a list of vectors
168toRows :: Element t => Matrix t -> [Vector t]
169toRows m
170 | c == 0 = replicate r (fromList[])
171 | otherwise = toRows' 0
172 where
173 v = flatten m
174 r = rows m
175 c = cols m
176 toRows' k | k == r*c = []
177 | otherwise = subVector k c v : toRows' (k+c)
178
179-- | Creates a matrix from a list of vectors, as columns
180fromColumns :: Element t => [Vector t] -> Matrix t
181fromColumns m = trans . fromRows $ m
182
183-- | Creates a list of vectors from the columns of a matrix
184toColumns :: Element t => Matrix t -> [Vector t]
185toColumns m = toRows . trans $ m
186
187-- | Reads a matrix position.
188(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
189infixl 9 @@>
190m@Matrix {irows = r, icols = c} @@> (i,j)
191 | safe = if i<0 || i>=r || j<0 || j>=c
192 then error "matrix indexing out of range"
193 else atM' m i j
194 | otherwise = atM' m i j
195{-# INLINE (@@>) #-}
196
197-- Unsafe matrix access without range checking
198atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j)
199atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
200{-# INLINE atM' #-}
201
202------------------------------------------------------------------
203
204matrixFromVector o r c v
205 | r * c == dim v = m
206 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
207 where
208 m = Matrix { irows = r, icols = c, xdat = v, order = o }
209
210-- allocates memory for a new matrix
211createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
212createMatrix ord r c = do
213 p <- createVector (r*c)
214 return (matrixFromVector ord r c p)
215
216{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@
217where r is the desired number of rows.)
218
219>>> reshape 4 (fromList [1..12])
220(3><4)
221 [ 1.0, 2.0, 3.0, 4.0
222 , 5.0, 6.0, 7.0, 8.0
223 , 9.0, 10.0, 11.0, 12.0 ]
224
225-}
226reshape :: Storable t => Int -> Vector t -> Matrix t
227reshape 0 v = matrixFromVector RowMajor 0 0 v
228reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
229
230singleton x = reshape 1 (fromList [x])
231
232-- | application of a vector function on the flattened matrix elements
233liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
234liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d)
235
236-- | application of a vector function on the flattened matrices elements
237liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
238liftMatrix2 f m1 m2
239 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
240 | otherwise = case orderOf m1 of
241 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2))
242 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
243
244
245compat :: Matrix a -> Matrix b -> Bool
246compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
247
248------------------------------------------------------------------
249
250{- | Supported matrix elements.
251
252 This class provides optimized internal
253 operations for selected element types.
254 It provides unoptimised defaults for any 'Storable' type,
255 so you can create instances simply as:
256 @instance Element Foo@.
257-}
258class (Storable a) => Element a where
259 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
260 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
261 -> Matrix a -> Matrix a
262 subMatrixD = subMatrix'
263 transdata :: Int -> Vector a -> Int -> Vector a
264 transdata = transdataP -- transdata'
265 constantD :: a -> Int -> Vector a
266 constantD = constantP -- constant'
267
268
269instance Element Float where
270 transdata = transdataAux ctransF
271 constantD = constantAux cconstantF
272
273instance Element Double where
274 transdata = transdataAux ctransR
275 constantD = constantAux cconstantR
276
277instance Element (Complex Float) where
278 transdata = transdataAux ctransQ
279 constantD = constantAux cconstantQ
280
281instance Element (Complex Double) where
282 transdata = transdataAux ctransC
283 constantD = constantAux cconstantC
284
285-------------------------------------------------------------------
286
287transdataAux fun c1 d c2 =
288 if noneed
289 then d
290 else unsafePerformIO $ do
291 v <- createVector (dim d)
292 unsafeWith d $ \pd ->
293 unsafeWith v $ \pv ->
294 fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
295 return v
296 where r1 = dim d `div` c1
297 r2 = dim d `div` c2
298 noneed = dim d == 0 || r1 == 1 || c1 == 1
299
300transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
301transdataP c1 d c2 =
302 if noneed
303 then d
304 else unsafePerformIO $ do
305 v <- createVector (dim d)
306 unsafeWith d $ \pd ->
307 unsafeWith v $ \pv ->
308 ctransP (fi r1) (fi c1) (castPtr pd) (fi sz) (fi r2) (fi c2) (castPtr pv) (fi sz) // check "transdataP"
309 return v
310 where r1 = dim d `div` c1
311 r2 = dim d `div` c2
312 sz = sizeOf (d @> 0)
313 noneed = dim d == 0 || r1 == 1 || c1 == 1
314
315foreign import ccall unsafe "transF" ctransF :: TFMFM
316foreign import ccall unsafe "transR" ctransR :: TMM
317foreign import ccall unsafe "transQ" ctransQ :: TQMQM
318foreign import ccall unsafe "transC" ctransC :: TCMCM
319foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
320
321----------------------------------------------------------------------
322
323constantAux fun x n = unsafePerformIO $ do
324 v <- createVector n
325 px <- newArray [x]
326 app1 (fun px) vec v "constantAux"
327 free px
328 return v
329
330foreign import ccall unsafe "constantF" cconstantF :: Ptr Float -> TF
331
332foreign import ccall unsafe "constantR" cconstantR :: Ptr Double -> TV
333
334foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV
335
336foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV
337
338constantP :: Storable a => a -> Int -> Vector a
339constantP a n = unsafePerformIO $ do
340 let sz = sizeOf a
341 v <- createVector n
342 unsafeWith v $ \p -> do
343 alloca $ \k -> do
344 poke k a
345 cconstantP (castPtr k) (fi n) (castPtr p) (fi sz) // check "constantP"
346 return v
347foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt
348
349----------------------------------------------------------------------
350
351-- | Extracts a submatrix from a matrix.
352subMatrix :: Element a
353 => (Int,Int) -- ^ (r0,c0) starting position
354 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
355 -> Matrix a -- ^ input matrix
356 -> Matrix a -- ^ result
357subMatrix (r0,c0) (rt,ct) m
358 | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) &&
359 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m
360 | otherwise = error $ "wrong subMatrix "++
361 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
362
363subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
364 w <- createVector (rt*ct)
365 unsafeWith v $ \p ->
366 unsafeWith w $ \q -> do
367 let go (-1) _ = return ()
368 go !i (-1) = go (i-1) (ct-1)
369 go !i !j = do x <- peekElemOff p ((i+r0)*c+j+c0)
370 pokeElemOff q (i*ct+j) x
371 go i (j-1)
372 go (rt-1) (ct-1)
373 return w
374
375subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor
376subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m)
377
378--------------------------------------------------------------------------
379
380maxZ xs = if minimum xs == 0 then 0 else maximum xs
381
382conformMs ms = map (conformMTo (r,c)) ms
383 where
384 r = maxZ (map rows ms)
385 c = maxZ (map cols ms)
386
387
388conformVs vs = map (conformVTo n) vs
389 where
390 n = maxZ (map dim vs)
391
392conformMTo (r,c) m
393 | size m == (r,c) = m
394 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
395 | size m == (r,1) = repCols c m
396 | size m == (1,c) = repRows r m
397 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
398
399conformVTo n v
400 | dim v == n = v
401 | dim v == 1 = constantD (v@>0) n
402 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
403
404repRows n x = fromRows (replicate n (flatten x))
405repCols n x = fromColumns (replicate n (flatten x))
406
407size m = (rows m, cols m)
408
409shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
410
411emptyM r c = matrixFromVector RowMajor r c (fromList[])
412
413----------------------------------------------------------------------
414
415instance (Storable t, NFData t) => NFData (Matrix t)
416 where
417 rnf m | d > 0 = rnf (v @> 0)
418 | otherwise = ()
419 where
420 d = dim v
421 v = xdat m
422
diff --git a/packages/base/src/Data/Packed/Internal/Signatures.hs b/packages/base/src/Data/Packed/Internal/Signatures.hs
new file mode 100644
index 0000000..acc3070
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Signatures.hs
@@ -0,0 +1,70 @@
1-- |
2-- Module : Data.Packed.Internal.Signatures
3-- Copyright : (c) Alberto Ruiz 2009
4-- License : BSD3
5-- Maintainer : Alberto Ruiz
6-- Stability : provisional
7--
8-- Signatures of the C functions.
9--
10
11
12module Data.Packed.Internal.Signatures where
13
14import Foreign.Ptr(Ptr)
15import Data.Complex(Complex)
16import Foreign.C.Types(CInt)
17
18type PF = Ptr Float --
19type PD = Ptr Double --
20type PQ = Ptr (Complex Float) --
21type PC = Ptr (Complex Double) --
22type TF = CInt -> PF -> IO CInt --
23type TFF = CInt -> PF -> TF --
24type TFV = CInt -> PF -> TV --
25type TVF = CInt -> PD -> TF --
26type TFFF = CInt -> PF -> TFF --
27type TV = CInt -> PD -> IO CInt --
28type TVV = CInt -> PD -> TV --
29type TVVV = CInt -> PD -> TVV --
30type TFM = CInt -> CInt -> PF -> IO CInt --
31type TFMFM = CInt -> CInt -> PF -> TFM --
32type TFMFMFM = CInt -> CInt -> PF -> TFMFM --
33type TM = CInt -> CInt -> PD -> IO CInt --
34type TMM = CInt -> CInt -> PD -> TM --
35type TVMM = CInt -> PD -> TMM --
36type TMVMM = CInt -> CInt -> PD -> TVMM --
37type TMMM = CInt -> CInt -> PD -> TMM --
38type TVM = CInt -> PD -> TM --
39type TVVM = CInt -> PD -> TVM --
40type TMV = CInt -> CInt -> PD -> TV --
41type TMMV = CInt -> CInt -> PD -> TMV --
42type TMVM = CInt -> CInt -> PD -> TVM --
43type TMMVM = CInt -> CInt -> PD -> TMVM --
44type TCM = CInt -> CInt -> PC -> IO CInt --
45type TCVCM = CInt -> PC -> TCM --
46type TCMCVCM = CInt -> CInt -> PC -> TCVCM --
47type TMCMCVCM = CInt -> CInt -> PD -> TCMCVCM --
48type TCMCMCVCM = CInt -> CInt -> PC -> TCMCVCM --
49type TCMCM = CInt -> CInt -> PC -> TCM --
50type TVCM = CInt -> PD -> TCM --
51type TCMVCM = CInt -> CInt -> PC -> TVCM --
52type TCMCMVCM = CInt -> CInt -> PC -> TCMVCM --
53type TCMCMCM = CInt -> CInt -> PC -> TCMCM --
54type TCV = CInt -> PC -> IO CInt --
55type TCVCV = CInt -> PC -> TCV --
56type TCVCVCV = CInt -> PC -> TCVCV --
57type TCVV = CInt -> PC -> TV --
58type TQV = CInt -> PQ -> IO CInt --
59type TQVQV = CInt -> PQ -> TQV --
60type TQVQVQV = CInt -> PQ -> TQVQV --
61type TQVF = CInt -> PQ -> TF --
62type TQM = CInt -> CInt -> PQ -> IO CInt --
63type TQMQM = CInt -> CInt -> PQ -> TQM --
64type TQMQMQM = CInt -> CInt -> PQ -> TQMQM --
65type TCMCV = CInt -> CInt -> PC -> TCV --
66type TVCV = CInt -> PD -> TCV --
67type TCVM = CInt -> PC -> TM --
68type TMCVM = CInt -> CInt -> PD -> TCVM --
69type TMMCVM = CInt -> CInt -> PD -> TMCVM --
70
diff --git a/packages/base/src/Data/Packed/Internal/Vector.hs b/packages/base/src/Data/Packed/Internal/Vector.hs
new file mode 100644
index 0000000..d0bc143
--- /dev/null
+++ b/packages/base/src/Data/Packed/Internal/Vector.hs
@@ -0,0 +1,471 @@
1{-# LANGUAGE MagicHash, CPP, UnboxedTuples, BangPatterns, FlexibleContexts #-}
2-- |
3-- Module : Data.Packed.Internal.Vector
4-- Copyright : (c) Alberto Ruiz 2007
5-- License : BSD3
6-- Maintainer : Alberto Ruiz
7-- Stability : provisional
8--
9-- Vector implementation
10--
11--------------------------------------------------------------------------------
12
13module Data.Packed.Internal.Vector (
14 Vector, dim,
15 fromList, toList, (|>),
16 vjoin, (@>), safe, at, at', subVector, takesV,
17 mapVector, mapVectorWithIndex, zipVectorWith, unzipVectorWith,
18 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_,
19 foldVector, foldVectorG, foldLoop, foldVectorWithIndex,
20 createVector, vec,
21 asComplex, asReal, float2DoubleV, double2FloatV,
22 stepF, stepD, condF, condD,
23 conjugateQ, conjugateC,
24 cloneVector,
25 unsafeToForeignPtr,
26 unsafeFromForeignPtr,
27 unsafeWith
28) where
29
30import Data.Packed.Internal.Common
31import Data.Packed.Internal.Signatures
32import Foreign.Marshal.Array(peekArray, copyArray, advancePtr)
33import Foreign.ForeignPtr(ForeignPtr, castForeignPtr)
34import Foreign.Ptr(Ptr)
35import Foreign.Storable(Storable, peekElemOff, pokeElemOff, sizeOf)
36import Foreign.C.Types
37import Data.Complex
38import Control.Monad(when)
39import System.IO.Unsafe(unsafePerformIO)
40
41#if __GLASGOW_HASKELL__ >= 605
42import GHC.ForeignPtr (mallocPlainForeignPtrBytes)
43#else
44import Foreign.ForeignPtr (mallocForeignPtrBytes)
45#endif
46
47import GHC.Base
48#if __GLASGOW_HASKELL__ < 612
49import GHC.IOBase hiding (liftIO)
50#endif
51
52import qualified Data.Vector.Storable as Vector
53import Data.Vector.Storable(Vector,
54 fromList,
55 unsafeToForeignPtr,
56 unsafeFromForeignPtr,
57 unsafeWith)
58
59
60-- | Number of elements
61dim :: (Storable t) => Vector t -> Int
62dim = Vector.length
63
64
65-- C-Haskell vector adapter
66-- vec :: Adapt (CInt -> Ptr t -> r) (Vector t) r
67vec :: (Storable t) => Vector t -> (((CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
68vec x f = unsafeWith x $ \p -> do
69 let v g = do
70 g (fi $ dim x) p
71 f v
72{-# INLINE vec #-}
73
74
75-- allocates memory for a new vector
76createVector :: Storable a => Int -> IO (Vector a)
77createVector n = do
78 when (n < 0) $ error ("trying to createVector of negative dim: "++show n)
79 fp <- doMalloc undefined
80 return $ unsafeFromForeignPtr fp 0 n
81 where
82 --
83 -- Use the much cheaper Haskell heap allocated storage
84 -- for foreign pointer space we control
85 --
86 doMalloc :: Storable b => b -> IO (ForeignPtr b)
87 doMalloc dummy = do
88#if __GLASGOW_HASKELL__ >= 605
89 mallocPlainForeignPtrBytes (n * sizeOf dummy)
90#else
91 mallocForeignPtrBytes (n * sizeOf dummy)
92#endif
93
94{- | creates a Vector from a list:
95
96@> fromList [2,3,5,7]
974 |> [2.0,3.0,5.0,7.0]@
98
99-}
100
101safeRead v = inlinePerformIO . unsafeWith v
102{-# INLINE safeRead #-}
103
104inlinePerformIO :: IO a -> a
105inlinePerformIO (IO m) = case m realWorld# of (# _, r #) -> r
106{-# INLINE inlinePerformIO #-}
107
108{- | extracts the Vector elements to a list
109
110>>> toList (linspace 5 (1,10))
111[1.0,3.25,5.5,7.75,10.0]
112
113-}
114toList :: Storable a => Vector a -> [a]
115toList v = safeRead v $ peekArray (dim v)
116
117{- | Create a vector from a list of elements and explicit dimension. The input
118 list is explicitly truncated if it is too long, so it may safely
119 be used, for instance, with infinite lists.
120
121>>> 5 |> [1..]
122fromList [1.0,2.0,3.0,4.0,5.0]
123
124-}
125(|>) :: (Storable a) => Int -> [a] -> Vector a
126infixl 9 |>
127n |> l = if length l' == n
128 then fromList l'
129 else error "list too short for |>"
130 where l' = take n l
131
132
133-- | access to Vector elements without range checking
134at' :: Storable a => Vector a -> Int -> a
135at' v n = safeRead v $ flip peekElemOff n
136{-# INLINE at' #-}
137
138--
139-- turn off bounds checking with -funsafe at configure time.
140-- ghc will optimise away the salways true case at compile time.
141--
142#if defined(UNSAFE)
143safe :: Bool
144safe = False
145#else
146safe = True
147#endif
148
149-- | access to Vector elements with range checking.
150at :: Storable a => Vector a -> Int -> a
151at v n
152 | safe = if n >= 0 && n < dim v
153 then at' v n
154 else error "vector index out of range"
155 | otherwise = at' v n
156{-# INLINE at #-}
157
158{- | takes a number of consecutive elements from a Vector
159
160>>> subVector 2 3 (fromList [1..10])
161fromList [3.0,4.0,5.0]
162
163-}
164subVector :: Storable t => Int -- ^ index of the starting element
165 -> Int -- ^ number of elements to extract
166 -> Vector t -- ^ source
167 -> Vector t -- ^ result
168subVector = Vector.slice
169
170
171{- | Reads a vector position:
172
173>>> fromList [0..9] @> 7
1747.0
175
176-}
177(@>) :: Storable t => Vector t -> Int -> t
178infixl 9 @>
179(@>) = at
180
181
182{- | concatenate a list of vectors
183
184>>> vjoin [fromList [1..5::Double], konst 1 3]
185fromList [1.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0]
186
187-}
188vjoin :: Storable t => [Vector t] -> Vector t
189vjoin [] = fromList []
190vjoin [v] = v
191vjoin as = unsafePerformIO $ do
192 let tot = sum (map dim as)
193 r <- createVector tot
194 unsafeWith r $ \ptr ->
195 joiner as tot ptr
196 return r
197 where joiner [] _ _ = return ()
198 joiner (v:cs) _ p = do
199 let n = dim v
200 unsafeWith v $ \pb -> copyArray p pb n
201 joiner cs 0 (advancePtr p n)
202
203
204{- | Extract consecutive subvectors of the given sizes.
205
206>>> takesV [3,4] (linspace 10 (1,10::Double))
207[fromList [1.0,2.0,3.0],fromList [4.0,5.0,6.0,7.0]]
208
209-}
210takesV :: Storable t => [Int] -> Vector t -> [Vector t]
211takesV ms w | sum ms > dim w = error $ "takesV " ++ show ms ++ " on dim = " ++ (show $ dim w)
212 | otherwise = go ms w
213 where go [] _ = []
214 go (n:ns) v = subVector 0 n v
215 : go ns (subVector n (dim v - n) v)
216
217---------------------------------------------------------------
218
219-- | transforms a complex vector into a real vector with alternating real and imaginary parts
220asReal :: (RealFloat a, Storable a) => Vector (Complex a) -> Vector a
221asReal v = unsafeFromForeignPtr (castForeignPtr fp) (2*i) (2*n)
222 where (fp,i,n) = unsafeToForeignPtr v
223
224-- | transforms a real vector into a complex vector with alternating real and imaginary parts
225asComplex :: (RealFloat a, Storable a) => Vector a -> Vector (Complex a)
226asComplex v = unsafeFromForeignPtr (castForeignPtr fp) (i `div` 2) (n `div` 2)
227 where (fp,i,n) = unsafeToForeignPtr v
228
229---------------------------------------------------------------
230
231float2DoubleV :: Vector Float -> Vector Double
232float2DoubleV v = unsafePerformIO $ do
233 r <- createVector (dim v)
234 app2 c_float2double vec v vec r "float2double"
235 return r
236
237double2FloatV :: Vector Double -> Vector Float
238double2FloatV v = unsafePerformIO $ do
239 r <- createVector (dim v)
240 app2 c_double2float vec v vec r "double2float2"
241 return r
242
243
244foreign import ccall unsafe "float2double" c_float2double:: TFV
245foreign import ccall unsafe "double2float" c_double2float:: TVF
246
247---------------------------------------------------------------
248
249stepF :: Vector Float -> Vector Float
250stepF v = unsafePerformIO $ do
251 r <- createVector (dim v)
252 app2 c_stepF vec v vec r "stepF"
253 return r
254
255stepD :: Vector Double -> Vector Double
256stepD v = unsafePerformIO $ do
257 r <- createVector (dim v)
258 app2 c_stepD vec v vec r "stepD"
259 return r
260
261foreign import ccall unsafe "stepF" c_stepF :: TFF
262foreign import ccall unsafe "stepD" c_stepD :: TVV
263
264---------------------------------------------------------------
265
266condF :: Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float
267condF x y l e g = unsafePerformIO $ do
268 r <- createVector (dim x)
269 app6 c_condF vec x vec y vec l vec e vec g vec r "condF"
270 return r
271
272condD :: Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double
273condD x y l e g = unsafePerformIO $ do
274 r <- createVector (dim x)
275 app6 c_condD vec x vec y vec l vec e vec g vec r "condD"
276 return r
277
278foreign import ccall unsafe "condF" c_condF :: CInt -> PF -> CInt -> PF -> CInt -> PF -> TFFF
279foreign import ccall unsafe "condD" c_condD :: CInt -> PD -> CInt -> PD -> CInt -> PD -> TVVV
280
281--------------------------------------------------------------------------------
282
283conjugateAux fun x = unsafePerformIO $ do
284 v <- createVector (dim x)
285 app2 fun vec x vec v "conjugateAux"
286 return v
287
288conjugateQ :: Vector (Complex Float) -> Vector (Complex Float)
289conjugateQ = conjugateAux c_conjugateQ
290foreign import ccall unsafe "conjugateQ" c_conjugateQ :: TQVQV
291
292conjugateC :: Vector (Complex Double) -> Vector (Complex Double)
293conjugateC = conjugateAux c_conjugateC
294foreign import ccall unsafe "conjugateC" c_conjugateC :: TCVCV
295
296--------------------------------------------------------------------------------
297
298cloneVector :: Storable t => Vector t -> IO (Vector t)
299cloneVector v = do
300 let n = dim v
301 r <- createVector n
302 let f _ s _ d = copyArray d s n >> return 0
303 app2 f vec v vec r "cloneVector"
304 return r
305
306------------------------------------------------------------------
307
308-- | map on Vectors
309mapVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b
310mapVector f v = unsafePerformIO $ do
311 w <- createVector (dim v)
312 unsafeWith v $ \p ->
313 unsafeWith w $ \q -> do
314 let go (-1) = return ()
315 go !k = do x <- peekElemOff p k
316 pokeElemOff q k (f x)
317 go (k-1)
318 go (dim v -1)
319 return w
320{-# INLINE mapVector #-}
321
322-- | zipWith for Vectors
323zipVectorWith :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c
324zipVectorWith f u v = unsafePerformIO $ do
325 let n = min (dim u) (dim v)
326 w <- createVector n
327 unsafeWith u $ \pu ->
328 unsafeWith v $ \pv ->
329 unsafeWith w $ \pw -> do
330 let go (-1) = return ()
331 go !k = do x <- peekElemOff pu k
332 y <- peekElemOff pv k
333 pokeElemOff pw k (f x y)
334 go (k-1)
335 go (n -1)
336 return w
337{-# INLINE zipVectorWith #-}
338
339-- | unzipWith for Vectors
340unzipVectorWith :: (Storable (a,b), Storable c, Storable d)
341 => ((a,b) -> (c,d)) -> Vector (a,b) -> (Vector c,Vector d)
342unzipVectorWith f u = unsafePerformIO $ do
343 let n = dim u
344 v <- createVector n
345 w <- createVector n
346 unsafeWith u $ \pu ->
347 unsafeWith v $ \pv ->
348 unsafeWith w $ \pw -> do
349 let go (-1) = return ()
350 go !k = do z <- peekElemOff pu k
351 let (x,y) = f z
352 pokeElemOff pv k x
353 pokeElemOff pw k y
354 go (k-1)
355 go (n-1)
356 return (v,w)
357{-# INLINE unzipVectorWith #-}
358
359foldVector :: Storable a => (a -> b -> b) -> b -> Vector a -> b
360foldVector f x v = unsafePerformIO $
361 unsafeWith v $ \p -> do
362 let go (-1) s = return s
363 go !k !s = do y <- peekElemOff p k
364 go (k-1::Int) (f y s)
365 go (dim v -1) x
366{-# INLINE foldVector #-}
367
368-- the zero-indexed index is passed to the folding function
369foldVectorWithIndex :: Storable a => (Int -> a -> b -> b) -> b -> Vector a -> b
370foldVectorWithIndex f x v = unsafePerformIO $
371 unsafeWith v $ \p -> do
372 let go (-1) s = return s
373 go !k !s = do y <- peekElemOff p k
374 go (k-1::Int) (f k y s)
375 go (dim v -1) x
376{-# INLINE foldVectorWithIndex #-}
377
378foldLoop f s0 d = go (d - 1) s0
379 where
380 go 0 s = f (0::Int) s
381 go !j !s = go (j - 1) (f j s)
382
383foldVectorG f s0 v = foldLoop g s0 (dim v)
384 where g !k !s = f k (at' v) s
385 {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479)
386{-# INLINE foldVectorG #-}
387
388-------------------------------------------------------------------
389
390-- | monadic map over Vectors
391-- the monad @m@ must be strict
392mapVectorM :: (Storable a, Storable b, Monad m) => (a -> m b) -> Vector a -> m (Vector b)
393mapVectorM f v = do
394 w <- return $! unsafePerformIO $! createVector (dim v)
395 mapVectorM' w 0 (dim v -1)
396 return w
397 where mapVectorM' w' !k !t
398 | k == t = do
399 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
400 y <- f x
401 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
402 | otherwise = do
403 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
404 y <- f x
405 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
406 mapVectorM' w' (k+1) t
407{-# INLINE mapVectorM #-}
408
409-- | monadic map over Vectors
410mapVectorM_ :: (Storable a, Monad m) => (a -> m ()) -> Vector a -> m ()
411mapVectorM_ f v = do
412 mapVectorM' 0 (dim v -1)
413 where mapVectorM' !k !t
414 | k == t = do
415 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
416 f x
417 | otherwise = do
418 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
419 _ <- f x
420 mapVectorM' (k+1) t
421{-# INLINE mapVectorM_ #-}
422
423-- | monadic map over Vectors with the zero-indexed index passed to the mapping function
424-- the monad @m@ must be strict
425mapVectorWithIndexM :: (Storable a, Storable b, Monad m) => (Int -> a -> m b) -> Vector a -> m (Vector b)
426mapVectorWithIndexM f v = do
427 w <- return $! unsafePerformIO $! createVector (dim v)
428 mapVectorM' w 0 (dim v -1)
429 return w
430 where mapVectorM' w' !k !t
431 | k == t = do
432 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
433 y <- f k x
434 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
435 | otherwise = do
436 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
437 y <- f k x
438 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
439 mapVectorM' w' (k+1) t
440{-# INLINE mapVectorWithIndexM #-}
441
442-- | monadic map over Vectors with the zero-indexed index passed to the mapping function
443mapVectorWithIndexM_ :: (Storable a, Monad m) => (Int -> a -> m ()) -> Vector a -> m ()
444mapVectorWithIndexM_ f v = do
445 mapVectorM' 0 (dim v -1)
446 where mapVectorM' !k !t
447 | k == t = do
448 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
449 f k x
450 | otherwise = do
451 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
452 _ <- f k x
453 mapVectorM' (k+1) t
454{-# INLINE mapVectorWithIndexM_ #-}
455
456
457mapVectorWithIndex :: (Storable a, Storable b) => (Int -> a -> b) -> Vector a -> Vector b
458--mapVectorWithIndex g = head . mapVectorWithIndexM (\a b -> [g a b])
459mapVectorWithIndex f v = unsafePerformIO $ do
460 w <- createVector (dim v)
461 unsafeWith v $ \p ->
462 unsafeWith w $ \q -> do
463 let go (-1) = return ()
464 go !k = do x <- peekElemOff p k
465 pokeElemOff q k (f k x)
466 go (k-1)
467 go (dim v -1)
468 return w
469{-# INLINE mapVectorWithIndex #-}
470
471
diff --git a/packages/base/src/Data/Packed/Matrix.hs b/packages/base/src/Data/Packed/Matrix.hs
new file mode 100644
index 0000000..d94d167
--- /dev/null
+++ b/packages/base/src/Data/Packed/Matrix.hs
@@ -0,0 +1,490 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE CPP #-}
6
7-----------------------------------------------------------------------------
8-- |
9-- Module : Data.Packed.Matrix
10-- Copyright : (c) Alberto Ruiz 2007-10
11-- License : GPL
12--
13-- Maintainer : Alberto Ruiz <aruiz@um.es>
14-- Stability : provisional
15--
16-- A Matrix representation suitable for numerical computations using LAPACK and GSL.
17--
18-- This module provides basic functions for manipulation of structure.
19
20-----------------------------------------------------------------------------
21{-# OPTIONS_HADDOCK hide #-}
22
23module Data.Packed.Matrix (
24 Matrix,
25 Element,
26 rows,cols,
27 (><),
28 trans,
29 reshape, flatten,
30 fromLists, toLists, buildMatrix,
31 (@@>),
32 asRow, asColumn,
33 fromRows, toRows, fromColumns, toColumns,
34 fromBlocks, diagBlock, toBlocks, toBlocksEvery,
35 repmat,
36 flipud, fliprl,
37 subMatrix, takeRows, dropRows, takeColumns, dropColumns,
38 extractRows, extractColumns,
39 diagRect, takeDiag,
40 mapMatrix, mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_,
41 liftMatrix, liftMatrix2, liftMatrix2Auto,fromArray2D
42) where
43
44import Data.Packed.Internal
45import qualified Data.Packed.ST as ST
46import Data.Array
47
48import Data.List(transpose,intersperse)
49import Foreign.Storable(Storable)
50import Control.Monad(liftM)
51
52-------------------------------------------------------------------
53
54#ifdef BINARY
55
56import Data.Binary
57import Control.Monad(replicateM)
58
59instance (Binary a, Element a, Storable a) => Binary (Matrix a) where
60 put m = do
61 let r = rows m
62 let c = cols m
63 put r
64 put c
65 mapM_ (\i -> mapM_ (\j -> put $ m @@> (i,j)) [0..(c-1)]) [0..(r-1)]
66 get = do
67 r <- get
68 c <- get
69 xs <- replicateM r $ replicateM c get
70 return $ fromLists xs
71
72#endif
73
74-------------------------------------------------------------------
75
76instance (Show a, Element a) => (Show (Matrix a)) where
77 show m | rows m == 0 || cols m == 0 = sizes m ++" []"
78 show m = (sizes m++) . dsp . map (map show) . toLists $ m
79
80sizes m = "("++show (rows m)++"><"++show (cols m)++")\n"
81
82dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
83 where
84 mt = transpose as
85 longs = map (maximum . map length) mt
86 mtp = zipWith (\a b -> map (pad a) b) longs mt
87 pad n str = replicate (n - length str) ' ' ++ str
88 unwords' = concat . intersperse ", "
89
90------------------------------------------------------------------
91
92instance (Element a, Read a) => Read (Matrix a) where
93 readsPrec _ s = [((rs><cs) . read $ listnums, rest)]
94 where (thing,rest) = breakAt ']' s
95 (dims,listnums) = breakAt ')' thing
96 cs = read . init . fst. breakAt ')' . snd . breakAt '<' $ dims
97 rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims
98
99
100breakAt c l = (a++[c],tail b) where
101 (a,b) = break (==c) l
102
103------------------------------------------------------------------
104
105-- | creates a matrix from a vertical list of matrices
106joinVert :: Element t => [Matrix t] -> Matrix t
107joinVert [] = emptyM 0 0
108joinVert ms = case common cols ms of
109 Nothing -> error "(impossible) joinVert on matrices with different number of columns"
110 Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms)
111
112-- | creates a matrix from a horizontal list of matrices
113joinHoriz :: Element t => [Matrix t] -> Matrix t
114joinHoriz ms = trans. joinVert . map trans $ ms
115
116{- | Create a matrix from blocks given as a list of lists of matrices.
117
118Single row-column components are automatically expanded to match the
119corresponding common row and column:
120
121@
122disp = putStr . dispf 2
123@
124
125>>> disp $ fromBlocks [[ident 5, 7, row[10,20]], [3, diagl[1,2,3], 0]]
1268x10
1271 0 0 0 0 7 7 7 10 20
1280 1 0 0 0 7 7 7 10 20
1290 0 1 0 0 7 7 7 10 20
1300 0 0 1 0 7 7 7 10 20
1310 0 0 0 1 7 7 7 10 20
1323 3 3 3 3 1 0 0 0 0
1333 3 3 3 3 0 2 0 0 0
1343 3 3 3 3 0 0 3 0 0
135
136-}
137fromBlocks :: Element t => [[Matrix t]] -> Matrix t
138fromBlocks = fromBlocksRaw . adaptBlocks
139
140fromBlocksRaw mms = joinVert . map joinHoriz $ mms
141
142adaptBlocks ms = ms' where
143 bc = case common length ms of
144 Just c -> c
145 Nothing -> error "fromBlocks requires rectangular [[Matrix]]"
146 rs = map (compatdim . map rows) ms
147 cs = map (compatdim . map cols) (transpose ms)
148 szs = sequence [rs,cs]
149 ms' = splitEvery bc $ zipWith g szs (concat ms)
150
151 g [Just nr,Just nc] m
152 | nr == r && nc == c = m
153 | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc))
154 | r == 1 = fromRows (replicate nr (flatten m))
155 | otherwise = fromColumns (replicate nc (flatten m))
156 where
157 r = rows m
158 c = cols m
159 x = m@@>(0,0)
160 g _ _ = error "inconsistent dimensions in fromBlocks"
161
162
163--------------------------------------------------------------------------------
164
165{- | create a block diagonal matrix
166
167>>> disp 2 $ diagBlock [konst 1 (2,2), konst 2 (3,5), col [5,7]]
1687x8
1691 1 0 0 0 0 0 0
1701 1 0 0 0 0 0 0
1710 0 2 2 2 2 2 0
1720 0 2 2 2 2 2 0
1730 0 2 2 2 2 2 0
1740 0 0 0 0 0 0 5
1750 0 0 0 0 0 0 7
176
177>>> diagBlock [(0><4)[], konst 2 (2,3)] :: Matrix Double
178(2><7)
179 [ 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0
180 , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ]
181
182-}
183diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t
184diagBlock ms = fromBlocks $ zipWith f ms [0..]
185 where
186 f m k = take n $ replicate k z ++ m : repeat z
187 n = length ms
188 z = (1><1) [0]
189
190--------------------------------------------------------------------------------
191
192
193-- | Reverse rows
194flipud :: Element t => Matrix t -> Matrix t
195flipud m = extractRows [r-1,r-2 .. 0] $ m
196 where
197 r = rows m
198
199-- | Reverse columns
200fliprl :: Element t => Matrix t -> Matrix t
201fliprl m = extractColumns [c-1,c-2 .. 0] $ m
202 where
203 c = cols m
204
205------------------------------------------------------------
206
207{- | creates a rectangular diagonal matrix:
208
209>>> diagRect 7 (fromList [10,20,30]) 4 5 :: Matrix Double
210(4><5)
211 [ 10.0, 7.0, 7.0, 7.0, 7.0
212 , 7.0, 20.0, 7.0, 7.0, 7.0
213 , 7.0, 7.0, 30.0, 7.0, 7.0
214 , 7.0, 7.0, 7.0, 7.0, 7.0 ]
215
216-}
217diagRect :: (Storable t) => t -> Vector t -> Int -> Int -> Matrix t
218diagRect z v r c = ST.runSTMatrix $ do
219 m <- ST.newMatrix z r c
220 let d = min r c `min` (dim v)
221 mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1]
222 return m
223
224-- | extracts the diagonal from a rectangular matrix
225takeDiag :: (Element t) => Matrix t -> Vector t
226takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]]
227
228------------------------------------------------------------
229
230{- | An easy way to create a matrix:
231
232>>> (2><3)[2,4,7,-3,11,0]
233(2><3)
234 [ 2.0, 4.0, 7.0
235 , -3.0, 11.0, 0.0 ]
236
237This is the format produced by the instances of Show (Matrix a), which
238can also be used for input.
239
240The input list is explicitly truncated, so that it can
241safely be used with lists that are too long (like infinite lists).
242
243>>> (2><3)[1..]
244(2><3)
245 [ 1.0, 2.0, 3.0
246 , 4.0, 5.0, 6.0 ]
247
248
249-}
250(><) :: (Storable a) => Int -> Int -> [a] -> Matrix a
251r >< c = f where
252 f l | dim v == r*c = matrixFromVector RowMajor r c v
253 | otherwise = error $ "inconsistent list size = "
254 ++show (dim v) ++" in ("++show r++"><"++show c++")"
255 where v = fromList $ take (r*c) l
256
257----------------------------------------------------------------
258
259-- | Creates a matrix with the first n rows of another matrix
260takeRows :: Element t => Int -> Matrix t -> Matrix t
261takeRows n mt = subMatrix (0,0) (n, cols mt) mt
262-- | Creates a copy of a matrix without the first n rows
263dropRows :: Element t => Int -> Matrix t -> Matrix t
264dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt
265-- |Creates a matrix with the first n columns of another matrix
266takeColumns :: Element t => Int -> Matrix t -> Matrix t
267takeColumns n mt = subMatrix (0,0) (rows mt, n) mt
268-- | Creates a copy of a matrix without the first n columns
269dropColumns :: Element t => Int -> Matrix t -> Matrix t
270dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt
271
272----------------------------------------------------------------
273
274{- | Creates a 'Matrix' from a list of lists (considered as rows).
275
276>>> fromLists [[1,2],[3,4],[5,6]]
277(3><2)
278 [ 1.0, 2.0
279 , 3.0, 4.0
280 , 5.0, 6.0 ]
281
282-}
283fromLists :: Element t => [[t]] -> Matrix t
284fromLists = fromRows . map fromList
285
286-- | creates a 1-row matrix from a vector
287--
288-- >>> asRow (fromList [1..5])
289-- (1><5)
290-- [ 1.0, 2.0, 3.0, 4.0, 5.0 ]
291--
292asRow :: Storable a => Vector a -> Matrix a
293asRow v = reshape (dim v) v
294
295-- | creates a 1-column matrix from a vector
296--
297-- >>> asColumn (fromList [1..5])
298-- (5><1)
299-- [ 1.0
300-- , 2.0
301-- , 3.0
302-- , 4.0
303-- , 5.0 ]
304--
305asColumn :: Storable a => Vector a -> Matrix a
306asColumn = trans . asRow
307
308
309
310{- | creates a Matrix of the specified size using the supplied function to
311 to map the row\/column position to the value at that row\/column position.
312
313@> buildMatrix 3 4 (\\(r,c) -> fromIntegral r * fromIntegral c)
314(3><4)
315 [ 0.0, 0.0, 0.0, 0.0, 0.0
316 , 0.0, 1.0, 2.0, 3.0, 4.0
317 , 0.0, 2.0, 4.0, 6.0, 8.0]@
318
319Hilbert matrix of order N:
320
321@hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@
322
323-}
324buildMatrix :: Element a => Int -> Int -> ((Int, Int) -> a) -> Matrix a
325buildMatrix rc cc f =
326 fromLists $ map (map f)
327 $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)]
328
329-----------------------------------------------------
330
331fromArray2D :: (Storable e) => Array (Int, Int) e -> Matrix e
332fromArray2D m = (r><c) (elems m)
333 where ((r0,c0),(r1,c1)) = bounds m
334 r = r1-r0+1
335 c = c1-c0+1
336
337
338-- | rearranges the rows of a matrix according to the order given in a list of integers.
339extractRows :: Element t => [Int] -> Matrix t -> Matrix t
340extractRows [] m = emptyM 0 (cols m)
341extractRows l m = fromRows $ extract (toRows m) l
342 where
343 extract l' is = [l'!!i | i<- map verify is]
344 verify k
345 | k >= 0 && k < rows m = k
346 | otherwise = error $ "can't extract row "
347 ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m
348
349-- | rearranges the rows of a matrix according to the order given in a list of integers.
350extractColumns :: Element t => [Int] -> Matrix t -> Matrix t
351extractColumns l m = trans . extractRows (map verify l) . trans $ m
352 where
353 verify k
354 | k >= 0 && k < cols m = k
355 | otherwise = error $ "can't extract column "
356 ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m
357
358
359
360{- | creates matrix by repetition of a matrix a given number of rows and columns
361
362>>> repmat (ident 2) 2 3
363(4><6)
364 [ 1.0, 0.0, 1.0, 0.0, 1.0, 0.0
365 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
366 , 1.0, 0.0, 1.0, 0.0, 1.0, 0.0
367 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ]
368
369-}
370repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t
371repmat m r c
372 | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m)
373 | otherwise = fromBlocks $ replicate r $ replicate c $ m
374
375-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix.
376liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
377liftMatrix2Auto f m1 m2
378 | compat' m1 m2 = lM f m1 m2
379 | ok = lM f m1' m2'
380 | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2
381 where
382 (r1,c1) = size m1
383 (r2,c2) = size m2
384 r = max r1 r2
385 c = max c1 c2
386 r0 = min r1 r2
387 c0 = min c1 c2
388 ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2
389 m1' = conformMTo (r,c) m1
390 m2' = conformMTo (r,c) m2
391
392-- FIXME do not flatten if equal order
393lM f m1 m2 = matrixFromVector
394 RowMajor
395 (max (rows m1) (rows m2))
396 (max (cols m1) (cols m2))
397 (f (flatten m1) (flatten m2))
398
399compat' :: Matrix a -> Matrix b -> Bool
400compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
401 where
402 s1 = size m1
403 s2 = size m2
404
405------------------------------------------------------------
406
407toBlockRows [r] m | r == rows m = [m]
408toBlockRows rs m = map (reshape (cols m)) (takesV szs (flatten m))
409 where szs = map (* cols m) rs
410
411toBlockCols [c] m | c == cols m = [m]
412toBlockCols cs m = map trans . toBlockRows cs . trans $ m
413
414-- | Partition a matrix into blocks with the given numbers of rows and columns.
415-- The remaining rows and columns are discarded.
416toBlocks :: (Element t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]]
417toBlocks rs cs m = map (toBlockCols cs) . toBlockRows rs $ m
418
419-- | Fully partition a matrix into blocks of the same size. If the dimensions are not
420-- a multiple of the given size the last blocks will be smaller.
421toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]]
422toBlocksEvery r c m
423 | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c
424 | otherwise = toBlocks rs cs m
425 where
426 (qr,rr) = rows m `divMod` r
427 (qc,rc) = cols m `divMod` c
428 rs = replicate qr r ++ if rr > 0 then [rr] else []
429 cs = replicate qc c ++ if rc > 0 then [rc] else []
430
431-------------------------------------------------------------------
432
433-- Given a column number and a function taking matrix indexes, returns
434-- a function which takes vector indexes (that can be used on the
435-- flattened matrix).
436mk :: Int -> ((Int, Int) -> t) -> (Int -> t)
437mk c g = \k -> g (divMod k c)
438
439{- |
440
441>>> mapMatrixWithIndexM_ (\(i,j) v -> printf "m[%d,%d] = %.f\n" i j v :: IO()) ((2><3)[1 :: Double ..])
442m[0,0] = 1
443m[0,1] = 2
444m[0,2] = 3
445m[1,0] = 4
446m[1,1] = 5
447m[1,2] = 6
448
449-}
450mapMatrixWithIndexM_
451 :: (Element a, Num a, Monad m) =>
452 ((Int, Int) -> a -> m ()) -> Matrix a -> m ()
453mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m
454 where
455 c = cols m
456
457{- |
458
459>>> mapMatrixWithIndexM (\(i,j) v -> Just $ 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double)
460Just (3><3)
461 [ 100.0, 1.0, 2.0
462 , 10.0, 111.0, 12.0
463 , 20.0, 21.0, 122.0 ]
464
465-}
466mapMatrixWithIndexM
467 :: (Element a, Storable b, Monad m) =>
468 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b)
469mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m
470 where
471 c = cols m
472
473{- |
474
475>>> mapMatrixWithIndex (\\(i,j) v -> 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double)
476(3><3)
477 [ 100.0, 1.0, 2.0
478 , 10.0, 111.0, 12.0
479 , 20.0, 21.0, 122.0 ]
480
481 -}
482mapMatrixWithIndex
483 :: (Element a, Storable b) =>
484 ((Int, Int) -> a -> b) -> Matrix a -> Matrix b
485mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m
486 where
487 c = cols m
488
489mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b
490mapMatrix f = liftMatrix (mapVector f)
diff --git a/packages/base/src/Data/Packed/ST.hs b/packages/base/src/Data/Packed/ST.hs
new file mode 100644
index 0000000..dae457c
--- /dev/null
+++ b/packages/base/src/Data/Packed/ST.hs
@@ -0,0 +1,178 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeOperators #-}
3{-# LANGUAGE Rank2Types #-}
4{-# LANGUAGE BangPatterns #-}
5-----------------------------------------------------------------------------
6-- |
7-- Module : Data.Packed.ST
8-- Copyright : (c) Alberto Ruiz 2008
9-- License : BSD3
10-- Maintainer : Alberto Ruiz
11-- Stability : provisional
12-- Portability : portable
13--
14-- In-place manipulation inside the ST monad.
15-- See examples/inplace.hs in the distribution.
16--
17-----------------------------------------------------------------------------
18
19module Data.Packed.ST (
20 -- * Mutable Vectors
21 STVector, newVector, thawVector, freezeVector, runSTVector,
22 readVector, writeVector, modifyVector, liftSTVector,
23 -- * Mutable Matrices
24 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
25 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
26 -- * Unsafe functions
27 newUndefinedVector,
28 unsafeReadVector, unsafeWriteVector,
29 unsafeThawVector, unsafeFreezeVector,
30 newUndefinedMatrix,
31 unsafeReadMatrix, unsafeWriteMatrix,
32 unsafeThawMatrix, unsafeFreezeMatrix
33) where
34
35import Data.Packed.Internal
36
37import Control.Monad.ST(ST, runST)
38import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
39
40#if MIN_VERSION_base(4,4,0)
41import Control.Monad.ST.Unsafe(unsafeIOToST)
42#else
43import Control.Monad.ST(unsafeIOToST)
44#endif
45
46{-# INLINE ioReadV #-}
47ioReadV :: Storable t => Vector t -> Int -> IO t
48ioReadV v k = unsafeWith v $ \s -> peekElemOff s k
49
50{-# INLINE ioWriteV #-}
51ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
52ioWriteV v k x = unsafeWith v $ \s -> pokeElemOff s k x
53
54newtype STVector s t = STVector (Vector t)
55
56thawVector :: Storable t => Vector t -> ST s (STVector s t)
57thawVector = unsafeIOToST . fmap STVector . cloneVector
58
59unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
60unsafeThawVector = unsafeIOToST . return . STVector
61
62runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
63runSTVector st = runST (st >>= unsafeFreezeVector)
64
65{-# INLINE unsafeReadVector #-}
66unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
67unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x
68
69{-# INLINE unsafeWriteVector #-}
70unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
71unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
72
73{-# INLINE modifyVector #-}
74modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
75modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
76
77liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a
78liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
79
80freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
81freezeVector v = liftSTVector id v
82
83unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
84unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
85
86{-# INLINE safeIndexV #-}
87safeIndexV f (STVector v) k
88 | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
89 ++show (dim v)++", pos="++show k++")"
90 | otherwise = f (STVector v) k
91
92{-# INLINE readVector #-}
93readVector :: Storable t => STVector s t -> Int -> ST s t
94readVector = safeIndexV unsafeReadVector
95
96{-# INLINE writeVector #-}
97writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
98writeVector = safeIndexV unsafeWriteVector
99
100newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
101newUndefinedVector = unsafeIOToST . fmap STVector . createVector
102
103{-# INLINE newVector #-}
104newVector :: Storable t => t -> Int -> ST s (STVector s t)
105newVector x n = do
106 v <- newUndefinedVector n
107 let go (-1) = return v
108 go !k = unsafeWriteVector v k x >> go (k-1 :: Int)
109 go (n-1)
110
111-------------------------------------------------------------------------
112
113{-# INLINE ioReadM #-}
114ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
115ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c)
116ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r)
117
118{-# INLINE ioWriteM #-}
119ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
120ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val
121ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val
122
123newtype STMatrix s t = STMatrix (Matrix t)
124
125thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
126thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
127
128unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
129unsafeThawMatrix = unsafeIOToST . return . STMatrix
130
131runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
132runSTMatrix st = runST (st >>= unsafeFreezeMatrix)
133
134{-# INLINE unsafeReadMatrix #-}
135unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
136unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r
137
138{-# INLINE unsafeWriteMatrix #-}
139unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
140unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
141
142{-# INLINE modifyMatrix #-}
143modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
144modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
145
146liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a
147liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
148
149unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
150unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
151
152freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
153freezeMatrix m = liftSTMatrix id m
154
155cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
156
157{-# INLINE safeIndexM #-}
158safeIndexM f (STMatrix m) r c
159 | r<0 || r>=rows m ||
160 c<0 || c>=cols m = error $ "out of range error in matrix (size="
161 ++show (rows m,cols m)++", pos="++show (r,c)++")"
162 | otherwise = f (STMatrix m) r c
163
164{-# INLINE readMatrix #-}
165readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
166readMatrix = safeIndexM unsafeReadMatrix
167
168{-# INLINE writeMatrix #-}
169writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
170writeMatrix = safeIndexM unsafeWriteMatrix
171
172newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
173newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c
174
175{-# NOINLINE newMatrix #-}
176newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
177newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
178
diff --git a/packages/base/src/Data/Packed/Vector.hs b/packages/base/src/Data/Packed/Vector.hs
new file mode 100644
index 0000000..b5a4318
--- /dev/null
+++ b/packages/base/src/Data/Packed/Vector.hs
@@ -0,0 +1,96 @@
1{-# LANGUAGE FlexibleContexts #-}
2{-# LANGUAGE CPP #-}
3-----------------------------------------------------------------------------
4-- |
5-- Module : Data.Packed.Vector
6-- Copyright : (c) Alberto Ruiz 2007-10
7-- License : GPL
8--
9-- Maintainer : Alberto Ruiz <aruiz@um.es>
10-- Stability : provisional
11--
12-- 1D arrays suitable for numeric computations using external libraries.
13--
14-- This module provides basic functions for manipulation of structure.
15--
16-----------------------------------------------------------------------------
17{-# OPTIONS_HADDOCK hide #-}
18
19module Data.Packed.Vector (
20 Vector,
21 fromList, (|>), toList, buildVector,
22 dim, (@>),
23 subVector, takesV, vjoin, join,
24 mapVector, mapVectorWithIndex, zipVector, zipVectorWith, unzipVector, unzipVectorWith,
25 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_,
26 foldLoop, foldVector, foldVectorG, foldVectorWithIndex
27) where
28
29import Data.Packed.Internal.Vector
30import Foreign.Storable
31
32-------------------------------------------------------------------
33
34#ifdef BINARY
35
36import Data.Binary
37import Control.Monad(replicateM)
38
39-- a 64K cache, with a Double taking 13 bytes in Bytestring,
40-- implies a chunk size of 5041
41chunk :: Int
42chunk = 5000
43
44chunks :: Int -> [Int]
45chunks d = let c = d `div` chunk
46 m = d `mod` chunk
47 in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk)
48
49putVector v = do
50 let d = dim v
51 mapM_ (\i -> put $ v @> i) [0..(d-1)]
52
53getVector d = do
54 xs <- replicateM d get
55 return $! fromList xs
56
57instance (Binary a, Storable a) => Binary (Vector a) where
58 put v = do
59 let d = dim v
60 put d
61 mapM_ putVector $! takesV (chunks d) v
62 get = do
63 d <- get
64 vs <- mapM getVector $ chunks d
65 return $! vjoin vs
66
67#endif
68
69-------------------------------------------------------------------
70
71{- | creates a Vector of the specified length using the supplied function to
72 to map the index to the value at that index.
73
74@> buildVector 4 fromIntegral
754 |> [0.0,1.0,2.0,3.0]@
76
77-}
78buildVector :: Storable a => Int -> (Int -> a) -> Vector a
79buildVector len f =
80 fromList $ map f [0 .. (len - 1)]
81
82
83-- | zip for Vectors
84zipVector :: (Storable a, Storable b, Storable (a,b)) => Vector a -> Vector b -> Vector (a,b)
85zipVector = zipVectorWith (,)
86
87-- | unzip for Vectors
88unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vector a,Vector b)
89unzipVector = unzipVectorWith id
90
91-------------------------------------------------------------------
92
93{-# DEPRECATED join "use vjoin or Data.Vector.concat" #-}
94join :: Storable t => [Vector t] -> Vector t
95join = vjoin
96