summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Internal/Matrix.hs51
-rw-r--r--packages/base/src/Internal/ST.hs4
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Devel.hs2
3 files changed, 38 insertions, 19 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index c0d1318..f76b9dc 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -79,9 +79,8 @@ The elements are stored in a continuous memory array.
79data Matrix t = Matrix 79data Matrix t = Matrix
80 { irows :: {-# UNPACK #-} !Int 80 { irows :: {-# UNPACK #-} !Int
81 , icols :: {-# UNPACK #-} !Int 81 , icols :: {-# UNPACK #-} !Int
82 , xRow :: {-# UNPACK #-} !CInt 82 , xRow :: {-# UNPACK #-} !Int
83 , xCol :: {-# UNPACK #-} !CInt 83 , xCol :: {-# UNPACK #-} !Int
84-- , rowOrder :: {-# UNPACK #-} !Bool
85 , xdat :: {-# UNPACK #-} !(Vector t) 84 , xdat :: {-# UNPACK #-} !(Vector t)
86 } 85 }
87-- RowMajor: preferred by C, fdat may require a transposition 86-- RowMajor: preferred by C, fdat may require a transposition
@@ -90,13 +89,18 @@ data Matrix t = Matrix
90 89
91rows :: Matrix t -> Int 90rows :: Matrix t -> Int
92rows = irows 91rows = irows
92{-# INLINE rows #-}
93 93
94cols :: Matrix t -> Int 94cols :: Matrix t -> Int
95cols = icols 95cols = icols
96{-# INLINE cols #-}
96 97
97rowOrder m = xRow m > 1 98rowOrder m = xRow m > 1
98{-# INLINE rowOrder #-} 99{-# INLINE rowOrder #-}
99 100
101isSlice m = cols m < xRow m || rows m < xCol m
102{-# INLINE isSlice #-}
103
100orderOf :: Matrix t -> MatrixOrder 104orderOf :: Matrix t -> MatrixOrder
101orderOf m = if rowOrder m then RowMajor else ColumnMajor 105orderOf m = if rowOrder m then RowMajor else ColumnMajor
102 106
@@ -104,19 +108,19 @@ orderOf m = if rowOrder m then RowMajor else ColumnMajor
104-- | Matrix transpose. 108-- | Matrix transpose.
105trans :: Matrix t -> Matrix t 109trans :: Matrix t -> Matrix t
106trans m@Matrix { irows = r, icols = c } | rowOrder m = 110trans m@Matrix { irows = r, icols = c } | rowOrder m =
107 m { irows = c, icols = r, xRow = 1, xCol = fi c } 111 m { irows = c, icols = r, xRow = 1, xCol = c }
108trans m@Matrix { irows = r, icols = c } = 112trans m@Matrix { irows = r, icols = c } =
109 m { irows = c, icols = r, xRow = fi r, xCol = 1 } 113 m { irows = c, icols = r, xRow = r, xCol = 1 }
110 114
111cmat :: (Element t) => Matrix t -> Matrix t 115cmat :: (Element t) => Matrix t -> Matrix t
112cmat m | rowOrder m = m 116cmat m | rowOrder m = m
113cmat m@Matrix { irows = r, icols = c, xdat = d } = 117cmat m@Matrix { irows = r, icols = c, xdat = d } =
114 m { xdat = transdata r d c, xRow = fi c, xCol = 1 } 118 m { xdat = transdata r d c, xRow = c, xCol = 1 }
115 119
116fmat :: (Element t) => Matrix t -> Matrix t 120fmat :: (Element t) => Matrix t -> Matrix t
117fmat m | not (rowOrder m) = m 121fmat m | not (rowOrder m) = m
118fmat m@Matrix { irows = r, icols = c, xdat = d} = 122fmat m@Matrix { irows = r, icols = c, xdat = d} =
119 m { xdat = transdata c d r, xRow = 1, xCol = fi r } 123 m { xdat = transdata c d r, xRow = 1, xCol = r }
120 124
121 125
122-- C-Haskell matrix adapters 126-- C-Haskell matrix adapters
@@ -124,17 +128,17 @@ fmat m@Matrix { irows = r, icols = c, xdat = d} =
124amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b 128amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
125amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) 129amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c))
126 where 130 where
127 r = fromIntegral (rows x) 131 r = fi (rows x)
128 c = fromIntegral (cols x) 132 c = fi (cols x)
129 133
130{-# INLINE amat #-} 134{-# INLINE amat #-}
131amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b 135amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
132amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) 136amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc))
133 where 137 where
134 r = fromIntegral (rows x) 138 r = fi (rows x)
135 c = fromIntegral (cols x) 139 c = fi (cols x)
136 sr = xRow x 140 sr = fi (xRow x)
137 sc = xCol x 141 sc = fi (xCol x)
138 142
139 143
140instance Storable t => TransArray (Matrix t) 144instance Storable t => TransArray (Matrix t)
@@ -227,7 +231,7 @@ m@Matrix {irows = r, icols = c} @@> (i,j)
227{-# INLINE (@@>) #-} 231{-# INLINE (@@>) #-}
228 232
229-- Unsafe matrix access without range checking 233-- Unsafe matrix access without range checking
230atM' m i j = xdat m `at'` (i * (ti $ xRow m) + j * (ti $ xCol m)) 234atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m))
231{-# INLINE atM' #-} 235{-# INLINE atM' #-}
232 236
233------------------------------------------------------------------ 237------------------------------------------------------------------
@@ -236,8 +240,8 @@ matrixFromVector o r c v
236 | r * c == dim v = m 240 | r * c == dim v = m
237 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m 241 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
238 where 242 where
239 m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = fi c, xCol = 1 } 243 m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 }
240 | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1 , xCol = fi r } 244 | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r }
241 245
242-- allocates memory for a new matrix 246-- allocates memory for a new matrix
243createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a) 247createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
@@ -411,6 +415,21 @@ subMatrix (r0,c0) (rt,ct) m
411 | otherwise = error $ "wrong subMatrix "++ 415 | otherwise = error $ "wrong subMatrix "++
412 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) 416 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
413 417
418
419sliceMatrix :: Element a
420 => (Int,Int) -- ^ (r0,c0) starting position
421 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
422 -> Matrix a -- ^ input matrix
423 -> Matrix a -- ^ result
424sliceMatrix (r0,c0) (rt,ct) m
425 | 0 <= r0 && 0 <= rt && r0+rt <= rows m &&
426 0 <= c0 && 0 <= ct && c0+ct <= cols m = res
427 | otherwise = error $ "wrong sliceMatrix "++
428 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
429 where
430 t = r0 * xRow m + c0 * xCol m
431 res = m { irows = rt, icols = ct, xdat = subVector t (rt*ct) (xdat m) }
432
414-------------------------------------------------------------------------- 433--------------------------------------------------------------------------
415 434
416maxZ xs = if minimum xs == 0 then 0 else maximum xs 435maxZ xs = if minimum xs == 0 then 0 else maximum xs
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index c98ff0e..92654e4 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -109,12 +109,12 @@ newVector x n = do
109 109
110{-# INLINE ioReadM #-} 110{-# INLINE ioReadM #-}
111ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t 111ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
112ioReadM m r c = ioReadV (xdat m) (r * (ti $ xRow m) + c * (ti $ xCol m)) 112ioReadM m r c = ioReadV (xdat m) (r * xRow m + c * xCol m)
113 113
114 114
115{-# INLINE ioWriteM #-} 115{-# INLINE ioWriteM #-}
116ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO () 116ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
117ioWriteM m r c val = ioWriteV (xdat m) (r * (ti $ xRow m) + c * (ti $ xCol m)) val 117ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val
118 118
119 119
120newtype STMatrix s t = STMatrix (Matrix t) 120newtype STMatrix s t = STMatrix (Matrix t)
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
index f18d35b..5ca1a7c 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
@@ -62,7 +62,7 @@ module Numeric.LinearAlgebra.Devel(
62 GMatrix(..), 62 GMatrix(..),
63 63
64 -- * Misc 64 -- * Misc
65 toByteString, fromByteString 65 toByteString, fromByteString, sliceMatrix
66 66
67) where 67) where
68 68