diff options
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 20 | ||||
-rw-r--r-- | packages/base/src/Internal/Element.hs | 12 | ||||
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 190 | ||||
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 22 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 43 |
7 files changed, 143 insertions, 148 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index 4d48594..cdbaab9 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -1485,26 +1485,6 @@ int smTXv(KDVEC(vals),KIVEC(cols),KIVEC(rows),KDVEC(x),DVEC(r)) { | |||
1485 | } | 1485 | } |
1486 | 1486 | ||
1487 | 1487 | ||
1488 | //////////////////// transpose ///////////////////////// | ||
1489 | |||
1490 | #define TRANS_IMP { \ | ||
1491 | REQUIRES(xr==tc && xc==tr,BAD_SIZE); \ | ||
1492 | DEBUGMSG("trans"); \ | ||
1493 | int i,j; \ | ||
1494 | for (i=0; i<tr; i++) { \ | ||
1495 | for (j=0; j<tc; j++) { \ | ||
1496 | tp[i*tc+j] = xp[j*xc+i]; \ | ||
1497 | } \ | ||
1498 | } \ | ||
1499 | OK } | ||
1500 | |||
1501 | int transF(KFMAT(x),FMAT(t)) TRANS_IMP | ||
1502 | int transR(KDMAT(x),DMAT(t)) TRANS_IMP | ||
1503 | int transQ(KQMAT(x),QMAT(t)) TRANS_IMP | ||
1504 | int transC(KCMAT(x),CMAT(t)) TRANS_IMP | ||
1505 | int transI(KIMAT(x),IMAT(t)) TRANS_IMP | ||
1506 | int transL(KLMAT(x),LMAT(t)) TRANS_IMP | ||
1507 | |||
1508 | //////////////////////// extract ///////////////////////////////// | 1488 | //////////////////////// extract ///////////////////////////////// |
1509 | 1489 | ||
1510 | #define EXTRACT_IMP { \ | 1490 | #define EXTRACT_IMP { \ |
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs index 51d5686..6d86f3d 100644 --- a/packages/base/src/Internal/Element.hs +++ b/packages/base/src/Internal/Element.hs | |||
@@ -173,7 +173,7 @@ m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n)) | |||
173 | m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) | 173 | m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) |
174 | m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) | 174 | m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) |
175 | 175 | ||
176 | m ?? (er,ec) = unsafePerformIO $ extractR m moder rs modec cs | 176 | m ?? (er,ec) = unsafePerformIO $ extractR (orderOf m) m moder rs modec cs |
177 | where | 177 | where |
178 | (moder,rs) = mkExt (rows m) er | 178 | (moder,rs) = mkExt (rows m) er |
179 | (modec,cs) = mkExt (cols m) ec | 179 | (modec,cs) = mkExt (cols m) ec |
@@ -491,9 +491,13 @@ liftMatrix2Auto f m1 m2 | |||
491 | -- FIXME do not flatten if equal order | 491 | -- FIXME do not flatten if equal order |
492 | lM f m1 m2 = matrixFromVector | 492 | lM f m1 m2 = matrixFromVector |
493 | RowMajor | 493 | RowMajor |
494 | (max (rows m1) (rows m2)) | 494 | (max' (rows m1) (rows m2)) |
495 | (max (cols m1) (cols m2)) | 495 | (max' (cols m1) (cols m2)) |
496 | (f (flatten m1) (flatten m2)) | 496 | (f (flatten m1) (flatten m2)) |
497 | where | ||
498 | max' 1 b = b | ||
499 | max' a 1 = a | ||
500 | max' a b = max a b | ||
497 | 501 | ||
498 | compat' :: Matrix a -> Matrix b -> Bool | 502 | compat' :: Matrix a -> Matrix b -> Bool |
499 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | 503 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 |
@@ -595,6 +599,6 @@ mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m | |||
595 | where | 599 | where |
596 | c = cols m | 600 | c = cols m |
597 | 601 | ||
598 | mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b | 602 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b |
599 | mapMatrix f = liftMatrix (mapVector f) | 603 | mapMatrix f = liftMatrix (mapVector f) |
600 | 604 | ||
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index fc9e3ad..5319e95 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -418,6 +418,8 @@ linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b) | |||
418 | 418 | ||
419 | ----------------------------------------------------------------------------------- | 419 | ----------------------------------------------------------------------------------- |
420 | 420 | ||
421 | type TMM t = t ..> t ..> Ok | ||
422 | |||
421 | foreign import ccall unsafe "chol_l_H" zpotrf :: TMM C | 423 | foreign import ccall unsafe "chol_l_H" zpotrf :: TMM C |
422 | foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R | 424 | foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R |
423 | 425 | ||
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index f76b9dc..bdf2785 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -32,49 +32,13 @@ import Foreign.C.Types ( CInt(..) ) | |||
32 | import Foreign.C.String ( CString, newCString ) | 32 | import Foreign.C.String ( CString, newCString ) |
33 | import System.IO.Unsafe ( unsafePerformIO ) | 33 | import System.IO.Unsafe ( unsafePerformIO ) |
34 | import Control.DeepSeq ( NFData(..) ) | 34 | import Control.DeepSeq ( NFData(..) ) |
35 | import Data.List.Split(chunksOf) | 35 | import Text.Printf |
36 | 36 | ||
37 | ----------------------------------------------------------------- | 37 | ----------------------------------------------------------------- |
38 | 38 | ||
39 | {- Design considerations for the Matrix Type | ||
40 | ----------------------------------------- | ||
41 | |||
42 | - we must easily handle both row major and column major order, | ||
43 | for bindings to LAPACK and GSL/C | ||
44 | |||
45 | - we'd like to simplify redundant matrix transposes: | ||
46 | - Some of them arise from the order requirements of some functions | ||
47 | - some functions (matrix product) admit transposed arguments | ||
48 | |||
49 | - maybe we don't really need this kind of simplification: | ||
50 | - more complex code | ||
51 | - some computational overhead | ||
52 | - only appreciable gain in code with a lot of redundant transpositions | ||
53 | and cheap matrix computations | ||
54 | |||
55 | - we could carry both the matrix and its (lazily computed) transpose. | ||
56 | This may save some transpositions, but it is necessary to keep track of the | ||
57 | data which is actually computed to be used by functions like the matrix product | ||
58 | which admit both orders. | ||
59 | |||
60 | - but if we need the transposed data and it is not in the structure, we must make | ||
61 | sure that we touch the same foreignptr that is used in the computation. | ||
62 | |||
63 | - a reasonable solution is using two constructors for a matrix. Transposition just | ||
64 | "flips" the constructor. Actual data transposition is not done if followed by a | ||
65 | matrix product or another transpose. | ||
66 | |||
67 | -} | ||
68 | |||
69 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | 39 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) |
70 | 40 | ||
71 | transOrder RowMajor = ColumnMajor | 41 | -- | Matrix representation suitable for BLAS\/LAPACK computations. |
72 | transOrder ColumnMajor = RowMajor | ||
73 | {- | Matrix representation suitable for BLAS\/LAPACK computations. | ||
74 | |||
75 | The elements are stored in a continuous memory array. | ||
76 | |||
77 | -} | ||
78 | 42 | ||
79 | data Matrix t = Matrix | 43 | data Matrix t = Matrix |
80 | { irows :: {-# UNPACK #-} !Int | 44 | { irows :: {-# UNPACK #-} !Int |
@@ -83,8 +47,6 @@ data Matrix t = Matrix | |||
83 | , xCol :: {-# UNPACK #-} !Int | 47 | , xCol :: {-# UNPACK #-} !Int |
84 | , xdat :: {-# UNPACK #-} !(Vector t) | 48 | , xdat :: {-# UNPACK #-} !(Vector t) |
85 | } | 49 | } |
86 | -- RowMajor: preferred by C, fdat may require a transposition | ||
87 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition | ||
88 | 50 | ||
89 | 51 | ||
90 | rows :: Matrix t -> Int | 52 | rows :: Matrix t -> Int |
@@ -95,32 +57,55 @@ cols :: Matrix t -> Int | |||
95 | cols = icols | 57 | cols = icols |
96 | {-# INLINE cols #-} | 58 | {-# INLINE cols #-} |
97 | 59 | ||
98 | rowOrder m = xRow m > 1 | 60 | size m = (irows m, icols m) |
61 | {-# INLINE size #-} | ||
62 | |||
63 | rowOrder m = xCol m == 1 || cols m == 1 | ||
99 | {-# INLINE rowOrder #-} | 64 | {-# INLINE rowOrder #-} |
100 | 65 | ||
101 | isSlice m = cols m < xRow m || rows m < xCol m | 66 | colOrder m = xRow m == 1 || rows m == 1 |
67 | {-# INLINE colOrder #-} | ||
68 | |||
69 | is1d (size->(r,c)) = r==1 || c==1 | ||
70 | {-# INLINE is1d #-} | ||
71 | |||
72 | -- data is not contiguous | ||
73 | isSlice m@(size->(r,c)) = (c < xRow m || r < xCol m) && min r c > 1 | ||
102 | {-# INLINE isSlice #-} | 74 | {-# INLINE isSlice #-} |
103 | 75 | ||
104 | orderOf :: Matrix t -> MatrixOrder | 76 | orderOf :: Matrix t -> MatrixOrder |
105 | orderOf m = if rowOrder m then RowMajor else ColumnMajor | 77 | orderOf m = if rowOrder m then RowMajor else ColumnMajor |
106 | 78 | ||
107 | 79 | ||
80 | showInternal :: Storable t => Matrix t -> IO () | ||
81 | showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv | ||
82 | where | ||
83 | r = rows m | ||
84 | c = cols m | ||
85 | xr = xRow m | ||
86 | xc = xCol m | ||
87 | slc = if isSlice m then "slice" else "full" | ||
88 | ord = if is1d m then "1d" else if rowOrder m then "rows" else "cols" | ||
89 | dv = dim (xdat m) | ||
90 | |||
91 | -------------------------------------------------------------------------------- | ||
92 | |||
108 | -- | Matrix transpose. | 93 | -- | Matrix transpose. |
109 | trans :: Matrix t -> Matrix t | 94 | trans :: Matrix t -> Matrix t |
110 | trans m@Matrix { irows = r, icols = c } | rowOrder m = | 95 | trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = |
111 | m { irows = c, icols = r, xRow = 1, xCol = c } | 96 | m { irows = c, icols = r, xRow = xc, xCol = xr } |
112 | trans m@Matrix { irows = r, icols = c } = | 97 | |
113 | m { irows = c, icols = r, xRow = r, xCol = 1 } | ||
114 | 98 | ||
115 | cmat :: (Element t) => Matrix t -> Matrix t | 99 | cmat :: (Element t) => Matrix t -> Matrix t |
116 | cmat m | rowOrder m = m | 100 | cmat m |
117 | cmat m@Matrix { irows = r, icols = c, xdat = d } = | 101 | | rowOrder m = m |
118 | m { xdat = transdata r d c, xRow = c, xCol = 1 } | 102 | | otherwise = extractAll RowMajor m |
103 | |||
119 | 104 | ||
120 | fmat :: (Element t) => Matrix t -> Matrix t | 105 | fmat :: (Element t) => Matrix t -> Matrix t |
121 | fmat m | not (rowOrder m) = m | 106 | fmat m |
122 | fmat m@Matrix { irows = r, icols = c, xdat = d} = | 107 | | colOrder m = m |
123 | m { xdat = transdata c d r, xRow = 1, xCol = r } | 108 | | otherwise = extractAll ColumnMajor m |
124 | 109 | ||
125 | 110 | ||
126 | -- C-Haskell matrix adapters | 111 | -- C-Haskell matrix adapters |
@@ -157,6 +142,11 @@ a # b = apply a b | |||
157 | 142 | ||
158 | -------------------------------------------------------------------------------- | 143 | -------------------------------------------------------------------------------- |
159 | 144 | ||
145 | extractAll ord m = unsafePerformIO $ | ||
146 | extractR ord m | ||
147 | 0 (idxs[0,rows m-1]) | ||
148 | 0 (idxs[0,cols m-1]) | ||
149 | |||
160 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 150 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
161 | 151 | ||
162 | >>> flatten (ident 3) | 152 | >>> flatten (ident 3) |
@@ -164,12 +154,14 @@ fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] | |||
164 | 154 | ||
165 | -} | 155 | -} |
166 | flatten :: Element t => Matrix t -> Vector t | 156 | flatten :: Element t => Matrix t -> Vector t |
167 | flatten = xdat . cmat | 157 | flatten m |
158 | | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) | ||
159 | | otherwise = xdat m | ||
168 | 160 | ||
169 | 161 | ||
170 | -- | the inverse of 'Data.Packed.Matrix.fromLists' | 162 | -- | the inverse of 'Data.Packed.Matrix.fromLists' |
171 | toLists :: (Element t) => Matrix t -> [[t]] | 163 | toLists :: (Element t) => Matrix t -> [[t]] |
172 | toLists m = chunksOf (cols m) . toList . flatten $ m | 164 | toLists = map toList . toRows |
173 | 165 | ||
174 | 166 | ||
175 | 167 | ||
@@ -205,6 +197,14 @@ fromRows vs = case compatdim (map dim vs) of | |||
205 | -- | extracts the rows of a matrix as a list of vectors | 197 | -- | extracts the rows of a matrix as a list of vectors |
206 | toRows :: Element t => Matrix t -> [Vector t] | 198 | toRows :: Element t => Matrix t -> [Vector t] |
207 | toRows m | 199 | toRows m |
200 | | rowOrder m = map sub rowRange | ||
201 | | otherwise = map ext rowRange | ||
202 | where | ||
203 | rowRange = [0..rows m-1] | ||
204 | sub k = subVector (k*xRow m) (cols m) (xdat m) | ||
205 | ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) | ||
206 | |||
207 | {- | ||
208 | | c == 0 = replicate r (fromList[]) | 208 | | c == 0 = replicate r (fromList[]) |
209 | | otherwise = toRows' 0 | 209 | | otherwise = toRows' 0 |
210 | where | 210 | where |
@@ -213,6 +213,7 @@ toRows m | |||
213 | c = cols m | 213 | c = cols m |
214 | toRows' k | k == r*c = [] | 214 | toRows' k | k == r*c = [] |
215 | | otherwise = subVector k c v : toRows' (k+c) | 215 | | otherwise = subVector k c v : toRows' (k+c) |
216 | -} | ||
216 | 217 | ||
217 | -- | Creates a matrix from a list of vectors, as columns | 218 | -- | Creates a matrix from a list of vectors, as columns |
218 | fromColumns :: Element t => [Vector t] -> Matrix t | 219 | fromColumns :: Element t => [Vector t] -> Matrix t |
@@ -240,7 +241,7 @@ matrixFromVector o r c v | |||
240 | | r * c == dim v = m | 241 | | r * c == dim v = m |
241 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m | 242 | | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m |
242 | where | 243 | where |
243 | m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } | 244 | m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } |
244 | | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } | 245 | | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } |
245 | 246 | ||
246 | -- allocates memory for a new matrix | 247 | -- allocates memory for a new matrix |
@@ -263,31 +264,26 @@ reshape :: Storable t => Int -> Vector t -> Matrix t | |||
263 | reshape 0 v = matrixFromVector RowMajor 0 0 v | 264 | reshape 0 v = matrixFromVector RowMajor 0 0 v |
264 | reshape c v = matrixFromVector RowMajor (dim v `div` c) c v | 265 | reshape c v = matrixFromVector RowMajor (dim v `div` c) c v |
265 | 266 | ||
266 | --singleton x = reshape 1 (fromList [x]) | ||
267 | 267 | ||
268 | -- | application of a vector function on the flattened matrix elements | 268 | -- | application of a vector function on the flattened matrix elements |
269 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 269 | liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
270 | liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} = matrixFromVector (orderOf m) r c (f d) | 270 | liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} |
271 | | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) | ||
272 | | otherwise = matrixFromVector (orderOf m) r c (f d) | ||
271 | 273 | ||
272 | -- | application of a vector function on the flattened matrices elements | 274 | -- | application of a vector function on the flattened matrices elements |
273 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 275 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
274 | liftMatrix2 f m1 m2 | 276 | liftMatrix2 f m1@(size->(r,c)) m2 |
275 | | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" | 277 | | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" |
276 | | otherwise = case orderOf m1 of | 278 | | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) |
277 | RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2)) | 279 | | otherwise = matrixFromVector ColumnMajor r c (f (flatten (trans m1)) (flatten (trans m2))) |
278 | ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2)) | ||
279 | |||
280 | |||
281 | compat :: Matrix a -> Matrix b -> Bool | ||
282 | compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 | ||
283 | 280 | ||
284 | ------------------------------------------------------------------ | 281 | ------------------------------------------------------------------ |
285 | 282 | ||
286 | -- | Supported matrix elements. | 283 | -- | Supported matrix elements. |
287 | class (Storable a) => Element a where | 284 | class (Storable a) => Element a where |
288 | transdata :: Int -> Vector a -> Int -> Vector a | ||
289 | constantD :: a -> Int -> Vector a | 285 | constantD :: a -> Int -> Vector a |
290 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) | 286 | extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) |
291 | setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () | 287 | setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () |
292 | sortI :: Ord a => Vector a -> Vector CInt | 288 | sortI :: Ord a => Vector a -> Vector CInt |
293 | sortV :: Ord a => Vector a -> Vector a | 289 | sortV :: Ord a => Vector a -> Vector a |
@@ -299,7 +295,6 @@ class (Storable a) => Element a where | |||
299 | 295 | ||
300 | 296 | ||
301 | instance Element Float where | 297 | instance Element Float where |
302 | transdata = transdataAux ctransF | ||
303 | constantD = constantAux cconstantF | 298 | constantD = constantAux cconstantF |
304 | extractR = extractAux c_extractF | 299 | extractR = extractAux c_extractF |
305 | setRect = setRectAux c_setRectF | 300 | setRect = setRectAux c_setRectF |
@@ -312,7 +307,6 @@ instance Element Float where | |||
312 | gemm = gemmg c_gemmF | 307 | gemm = gemmg c_gemmF |
313 | 308 | ||
314 | instance Element Double where | 309 | instance Element Double where |
315 | transdata = transdataAux ctransR | ||
316 | constantD = constantAux cconstantR | 310 | constantD = constantAux cconstantR |
317 | extractR = extractAux c_extractD | 311 | extractR = extractAux c_extractD |
318 | setRect = setRectAux c_setRectD | 312 | setRect = setRectAux c_setRectD |
@@ -325,7 +319,6 @@ instance Element Double where | |||
325 | gemm = gemmg c_gemmD | 319 | gemm = gemmg c_gemmD |
326 | 320 | ||
327 | instance Element (Complex Float) where | 321 | instance Element (Complex Float) where |
328 | transdata = transdataAux ctransQ | ||
329 | constantD = constantAux cconstantQ | 322 | constantD = constantAux cconstantQ |
330 | extractR = extractAux c_extractQ | 323 | extractR = extractAux c_extractQ |
331 | setRect = setRectAux c_setRectQ | 324 | setRect = setRectAux c_setRectQ |
@@ -338,7 +331,6 @@ instance Element (Complex Float) where | |||
338 | gemm = gemmg c_gemmQ | 331 | gemm = gemmg c_gemmQ |
339 | 332 | ||
340 | instance Element (Complex Double) where | 333 | instance Element (Complex Double) where |
341 | transdata = transdataAux ctransC | ||
342 | constantD = constantAux cconstantC | 334 | constantD = constantAux cconstantC |
343 | extractR = extractAux c_extractC | 335 | extractR = extractAux c_extractC |
344 | setRect = setRectAux c_setRectC | 336 | setRect = setRectAux c_setRectC |
@@ -351,7 +343,6 @@ instance Element (Complex Double) where | |||
351 | gemm = gemmg c_gemmC | 343 | gemm = gemmg c_gemmC |
352 | 344 | ||
353 | instance Element (CInt) where | 345 | instance Element (CInt) where |
354 | transdata = transdataAux ctransI | ||
355 | constantD = constantAux cconstantI | 346 | constantD = constantAux cconstantI |
356 | extractR = extractAux c_extractI | 347 | extractR = extractAux c_extractI |
357 | setRect = setRectAux c_setRectI | 348 | setRect = setRectAux c_setRectI |
@@ -364,7 +355,6 @@ instance Element (CInt) where | |||
364 | gemm = gemmg c_gemmI | 355 | gemm = gemmg c_gemmI |
365 | 356 | ||
366 | instance Element Z where | 357 | instance Element Z where |
367 | transdata = transdataAux ctransL | ||
368 | constantD = constantAux cconstantL | 358 | constantD = constantAux cconstantL |
369 | extractR = extractAux c_extractL | 359 | extractR = extractAux c_extractL |
370 | setRect = setRectAux c_setRectL | 360 | setRect = setRectAux c_setRectL |
@@ -378,32 +368,6 @@ instance Element Z where | |||
378 | 368 | ||
379 | ------------------------------------------------------------------- | 369 | ------------------------------------------------------------------- |
380 | 370 | ||
381 | transdataAux fun c1 d c2 = | ||
382 | if noneed | ||
383 | then d | ||
384 | else unsafePerformIO $ do | ||
385 | -- putStrLn "T" | ||
386 | v <- createVector (dim d) | ||
387 | unsafeWith d $ \pd -> | ||
388 | unsafeWith v $ \pv -> | ||
389 | fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux" | ||
390 | return v | ||
391 | where r1 = dim d `div` c1 | ||
392 | r2 = dim d `div` c2 | ||
393 | noneed = dim d == 0 || r1 == 1 || c1 == 1 | ||
394 | |||
395 | |||
396 | type TMM t = t ..> t ..> Ok | ||
397 | |||
398 | foreign import ccall unsafe "transF" ctransF :: TMM Float | ||
399 | foreign import ccall unsafe "transR" ctransR :: TMM Double | ||
400 | foreign import ccall unsafe "transQ" ctransQ :: TMM (Complex Float) | ||
401 | foreign import ccall unsafe "transC" ctransC :: TMM (Complex Double) | ||
402 | foreign import ccall unsafe "transI" ctransI :: TMM CInt | ||
403 | foreign import ccall unsafe "transL" ctransL :: TMM Z | ||
404 | |||
405 | ---------------------------------------------------------------------- | ||
406 | |||
407 | subMatrix :: Element a | 371 | subMatrix :: Element a |
408 | => (Int,Int) -- ^ (r0,c0) starting position | 372 | => (Int,Int) -- ^ (r0,c0) starting position |
409 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | 373 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix |
@@ -411,9 +375,8 @@ subMatrix :: Element a | |||
411 | -> Matrix a -- ^ result | 375 | -> Matrix a -- ^ result |
412 | subMatrix (r0,c0) (rt,ct) m | 376 | subMatrix (r0,c0) (rt,ct) m |
413 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && | 377 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && |
414 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) | 378 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR RowMajor m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) |
415 | | otherwise = error $ "wrong subMatrix "++ | 379 | | otherwise = error $ "wrong subMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m |
416 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | ||
417 | 380 | ||
418 | 381 | ||
419 | sliceMatrix :: Element a | 382 | sliceMatrix :: Element a |
@@ -424,11 +387,12 @@ sliceMatrix :: Element a | |||
424 | sliceMatrix (r0,c0) (rt,ct) m | 387 | sliceMatrix (r0,c0) (rt,ct) m |
425 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && | 388 | | 0 <= r0 && 0 <= rt && r0+rt <= rows m && |
426 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = res | 389 | 0 <= c0 && 0 <= ct && c0+ct <= cols m = res |
427 | | otherwise = error $ "wrong sliceMatrix "++ | 390 | | otherwise = error $ "wrong sliceMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m |
428 | show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) | ||
429 | where | 391 | where |
430 | t = r0 * xRow m + c0 * xCol m | 392 | p = r0 * xRow m + c0 * xCol m |
431 | res = m { irows = rt, icols = ct, xdat = subVector t (rt*ct) (xdat m) } | 393 | tot | rowOrder m = ct + (rt-1) * xRow m |
394 | | otherwise = rt + (ct-1) * xCol m | ||
395 | res = m { irows = rt, icols = ct, xdat = subVector p tot (xdat m) } | ||
432 | 396 | ||
433 | -------------------------------------------------------------------------- | 397 | -------------------------------------------------------------------------- |
434 | 398 | ||
@@ -449,7 +413,7 @@ conformMTo (r,c) m | |||
449 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) | 413 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) |
450 | | size m == (r,1) = repCols c m | 414 | | size m == (r,1) = repCols c m |
451 | | size m == (1,c) = repRows r m | 415 | | size m == (1,c) = repRows r m |
452 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" | 416 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) |
453 | 417 | ||
454 | conformVTo n v | 418 | conformVTo n v |
455 | | dim v == n = v | 419 | | dim v == n = v |
@@ -459,9 +423,9 @@ conformVTo n v | |||
459 | repRows n x = fromRows (replicate n (flatten x)) | 423 | repRows n x = fromRows (replicate n (flatten x)) |
460 | repCols n x = fromColumns (replicate n (flatten x)) | 424 | repCols n x = fromColumns (replicate n (flatten x)) |
461 | 425 | ||
462 | size m = (rows m, cols m) | 426 | shSize = shDim . size |
463 | 427 | ||
464 | shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" | 428 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" |
465 | 429 | ||
466 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | 430 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) |
467 | 431 | ||
@@ -477,10 +441,10 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
477 | 441 | ||
478 | --------------------------------------------------------------- | 442 | --------------------------------------------------------------- |
479 | 443 | ||
480 | extractAux f m moder vr modec vc = do | 444 | extractAux f ord m moder vr modec vc = do |
481 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 445 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
482 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 446 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
483 | r <- createMatrix RowMajor nr nc | 447 | r <- createMatrix ord nr nc |
484 | f moder modec # vr # vc # m # r #|"extract" | 448 | f moder modec # vr # vc # m # r #|"extract" |
485 | return r | 449 | return r |
486 | 450 | ||
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 37f6e9b..c4f95d8 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -36,7 +36,7 @@ import Internal.LAPACK (multiplyI, multiplyL) | |||
36 | import Internal.Algorithms(luFact) | 36 | import Internal.Algorithms(luFact) |
37 | import Internal.Util(Normed(..),Indexable(..), | 37 | import Internal.Util(Normed(..),Indexable(..), |
38 | gaussElim, gaussElim_1, gaussElim_2, | 38 | gaussElim, gaussElim_1, gaussElim_2, |
39 | luST, luSolve', luPacked', magnit) | 39 | luST, luSolve', luPacked', magnit, invershur) |
40 | import Internal.ST(mutable) | 40 | import Internal.ST(mutable) |
41 | import GHC.TypeLits | 41 | import GHC.TypeLits |
42 | import Data.Proxy(Proxy) | 42 | import Data.Proxy(Proxy) |
@@ -126,9 +126,8 @@ instance forall n t . (Integral t, KnownNat n) => Num (Mod n t) | |||
126 | 126 | ||
127 | instance KnownNat m => Element (Mod m I) | 127 | instance KnownNat m => Element (Mod m I) |
128 | where | 128 | where |
129 | transdata n v m = i2f (transdata n (f2i v) m) | ||
130 | constantD x n = i2f (constantD (unMod x) n) | 129 | constantD x n = i2f (constantD (unMod x) n) |
131 | extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js | 130 | extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js |
132 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | 131 | setRect i j m x = setRect i j (f2iM m) (f2iM x) |
133 | sortI = sortI . f2i | 132 | sortI = sortI . f2i |
134 | sortV = i2f . sortV . f2i | 133 | sortV = i2f . sortV . f2i |
@@ -144,9 +143,8 @@ instance KnownNat m => Element (Mod m I) | |||
144 | 143 | ||
145 | instance KnownNat m => Element (Mod m Z) | 144 | instance KnownNat m => Element (Mod m Z) |
146 | where | 145 | where |
147 | transdata n v m = i2f (transdata n (f2i v) m) | ||
148 | constantD x n = i2f (constantD (unMod x) n) | 146 | constantD x n = i2f (constantD (unMod x) n) |
149 | extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js | 147 | extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js |
150 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | 148 | setRect i j m x = setRect i j (f2iM m) (f2iM x) |
151 | sortI = sortI . f2i | 149 | sortI = sortI . f2i |
152 | sortV = i2f . sortV . f2i | 150 | sortV = i2f . sortV . f2i |
@@ -293,11 +291,11 @@ f2i :: Storable t => Vector (Mod n t) -> Vector t | |||
293 | f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | 291 | f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) |
294 | where (fp,i,n) = unsafeToForeignPtr v | 292 | where (fp,i,n) = unsafeToForeignPtr v |
295 | 293 | ||
296 | f2iM :: Storable t => Matrix (Mod n t) -> Matrix t | 294 | f2iM :: (Element t, Element (Mod n t)) => Matrix (Mod n t) -> Matrix t |
297 | f2iM = liftMatrix f2i | 295 | f2iM m = m { xdat = f2i (xdat m) } |
298 | 296 | ||
299 | i2fM :: Storable t => Matrix t -> Matrix (Mod n t) | 297 | i2fM :: (Element t, Element (Mod n t)) => Matrix t -> Matrix (Mod n t) |
300 | i2fM = liftMatrix i2f | 298 | i2fM m = m { xdat = i2f (xdat m) } |
301 | 299 | ||
302 | vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) | 300 | vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) |
303 | vmod = i2f . cmod' m' | 301 | vmod = i2f . cmod' m' |
@@ -376,6 +374,8 @@ test = (ok, info) | |||
376 | bb = flipud aa | 374 | bb = flipud aa |
377 | x = luSolve' (luPacked' aa) bb | 375 | x = luSolve' (luPacked' aa) bb |
378 | 376 | ||
377 | tmm = diagRect 1 (fromList [2..6]) 5 5 :: Matrix (Mod 19 I) | ||
378 | |||
379 | info = do | 379 | info = do |
380 | print v | 380 | print v |
381 | print m | 381 | print m |
@@ -421,6 +421,9 @@ test = (ok, info) | |||
421 | print $ checkSolve (sgen 5 :: Matrix (Complex Float)) | 421 | print $ checkSolve (sgen 5 :: Matrix (Complex Float)) |
422 | print $ checkSolve (gen 5 :: Matrix (Mod 7 I)) | 422 | print $ checkSolve (gen 5 :: Matrix (Mod 7 I)) |
423 | print $ checkSolve (gen 5 :: Matrix (Mod 7 Z)) | 423 | print $ checkSolve (gen 5 :: Matrix (Mod 7 Z)) |
424 | |||
425 | print $ luSolve' (luPacked' tmm) (ident (rows tmm)) | ||
426 | print $ invershur tmm | ||
424 | 427 | ||
425 | 428 | ||
426 | ok = and | 429 | ok = and |
@@ -449,6 +452,7 @@ test = (ok, info) | |||
449 | , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) | 452 | , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) |
450 | , gm <> gm == konst 0 (3,3) | 453 | , gm <> gm == konst 0 (3,3) |
451 | , lgm <> lgm == konst 0 (3,3) | 454 | , lgm <> lgm == konst 0 (3,3) |
455 | , invershur tmm == luSolve' (luPacked' tmm) (ident (rows tmm)) | ||
452 | ] | 456 | ] |
453 | 457 | ||
454 | 458 | ||
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 92654e4..73cdf0c 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -223,7 +223,7 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m | |||
223 | i2' = i2 `mod` (rows m) | 223 | i2' = i2 `mod` (rows m) |
224 | 224 | ||
225 | 225 | ||
226 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | 226 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) |
227 | where | 227 | where |
228 | (i1,i2) = getRowRange (rows m) rr | 228 | (i1,i2) = getRowRange (rows m) rr |
229 | (j1,j2) = getColRange (cols m) rc | 229 | (j1,j2) = getColRange (cols m) rc |
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index bf6c8b6..98eb4ef 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -54,7 +54,9 @@ module Internal.Util( | |||
54 | -- ** 2D | 54 | -- ** 2D |
55 | corr2, conv2, separable, | 55 | corr2, conv2, separable, |
56 | block2x2,block3x3,view1,unView1,foldMatrix, | 56 | block2x2,block3x3,view1,unView1,foldMatrix, |
57 | gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked', luPacked'' | 57 | gaussElim_1, gaussElim_2, gaussElim, |
58 | luST, luSolve', luSolve'', luPacked', luPacked'', | ||
59 | invershur | ||
58 | ) where | 60 | ) where |
59 | 61 | ||
60 | import Internal.Vector | 62 | import Internal.Vector |
@@ -829,6 +831,45 @@ luSolve' (lup,p) b = backSust lup (forwSust lup pb) | |||
829 | where | 831 | where |
830 | pb = b ?? (Pos (fixPerm' p), All) | 832 | pb = b ?? (Pos (fixPerm' p), All) |
831 | 833 | ||
834 | |||
835 | -------------------------------------------------------------------------------- | ||
836 | |||
837 | data MatrixView t b | ||
838 | = Elem t | ||
839 | | Block b b b b | ||
840 | deriving Show | ||
841 | |||
842 | |||
843 | viewBlock' r c m | ||
844 | | (rt,ct) == (1,1) = Elem (atM' m 0 0) | ||
845 | | otherwise = Block m11 m12 m21 m22 | ||
846 | where | ||
847 | (rt,ct) = size m | ||
848 | m11 = sliceMatrix (0,0) (r,c) m | ||
849 | m12 = sliceMatrix (0,c) (r,ct-c) m | ||
850 | m21 = sliceMatrix (r,0) (rt-r,c) m | ||
851 | m22 = sliceMatrix (r,c) (rt-r,ct-c) m | ||
852 | |||
853 | viewBlock m = viewBlock' n n m | ||
854 | where | ||
855 | n = rows m `div` 2 | ||
856 | |||
857 | invershur (viewBlock -> Block a b c d) = fromBlocks [[a',b'],[c',d']] | ||
858 | where | ||
859 | r1 = invershur a | ||
860 | r2 = c <> r1 | ||
861 | r3 = r1 <> b | ||
862 | r4 = c <> r3 | ||
863 | r5 = r4-d | ||
864 | r6 = invershur r5 | ||
865 | b' = r3 <> r6 | ||
866 | c' = r6 <> r2 | ||
867 | r7 = r3 <> c' | ||
868 | a' = r1-r7 | ||
869 | d' = -r6 | ||
870 | |||
871 | invershur x = recip x | ||
872 | |||
832 | -------------------------------------------------------------------------------- | 873 | -------------------------------------------------------------------------------- |
833 | 874 | ||
834 | instance Testable (Matrix I) where | 875 | instance Testable (Matrix I) where |