summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c20
-rw-r--r--packages/base/src/Internal/Element.hs12
-rw-r--r--packages/base/src/Internal/LAPACK.hs2
-rw-r--r--packages/base/src/Internal/Matrix.hs190
-rw-r--r--packages/base/src/Internal/Modular.hs22
-rw-r--r--packages/base/src/Internal/ST.hs2
-rw-r--r--packages/base/src/Internal/Util.hs43
7 files changed, 143 insertions, 148 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index 4d48594..cdbaab9 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1485,26 +1485,6 @@ int smTXv(KDVEC(vals),KIVEC(cols),KIVEC(rows),KDVEC(x),DVEC(r)) {
1485} 1485}
1486 1486
1487 1487
1488//////////////////// transpose /////////////////////////
1489
1490#define TRANS_IMP { \
1491 REQUIRES(xr==tc && xc==tr,BAD_SIZE); \
1492 DEBUGMSG("trans"); \
1493 int i,j; \
1494 for (i=0; i<tr; i++) { \
1495 for (j=0; j<tc; j++) { \
1496 tp[i*tc+j] = xp[j*xc+i]; \
1497 } \
1498 } \
1499 OK }
1500
1501int transF(KFMAT(x),FMAT(t)) TRANS_IMP
1502int transR(KDMAT(x),DMAT(t)) TRANS_IMP
1503int transQ(KQMAT(x),QMAT(t)) TRANS_IMP
1504int transC(KCMAT(x),CMAT(t)) TRANS_IMP
1505int transI(KIMAT(x),IMAT(t)) TRANS_IMP
1506int transL(KLMAT(x),LMAT(t)) TRANS_IMP
1507
1508//////////////////////// extract ///////////////////////////////// 1488//////////////////////// extract /////////////////////////////////
1509 1489
1510#define EXTRACT_IMP { \ 1490#define EXTRACT_IMP { \
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs
index 51d5686..6d86f3d 100644
--- a/packages/base/src/Internal/Element.hs
+++ b/packages/base/src/Internal/Element.hs
@@ -173,7 +173,7 @@ m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n))
173m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) 173m ?? (DropLast n, e) = m ?? (Take (rows m - n), e)
174m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) 174m ?? (e, DropLast n) = m ?? (e, Take (cols m - n))
175 175
176m ?? (er,ec) = unsafePerformIO $ extractR m moder rs modec cs 176m ?? (er,ec) = unsafePerformIO $ extractR (orderOf m) m moder rs modec cs
177 where 177 where
178 (moder,rs) = mkExt (rows m) er 178 (moder,rs) = mkExt (rows m) er
179 (modec,cs) = mkExt (cols m) ec 179 (modec,cs) = mkExt (cols m) ec
@@ -491,9 +491,13 @@ liftMatrix2Auto f m1 m2
491-- FIXME do not flatten if equal order 491-- FIXME do not flatten if equal order
492lM f m1 m2 = matrixFromVector 492lM f m1 m2 = matrixFromVector
493 RowMajor 493 RowMajor
494 (max (rows m1) (rows m2)) 494 (max' (rows m1) (rows m2))
495 (max (cols m1) (cols m2)) 495 (max' (cols m1) (cols m2))
496 (f (flatten m1) (flatten m2)) 496 (f (flatten m1) (flatten m2))
497 where
498 max' 1 b = b
499 max' a 1 = a
500 max' a b = max a b
497 501
498compat' :: Matrix a -> Matrix b -> Bool 502compat' :: Matrix a -> Matrix b -> Bool
499compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 503compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
@@ -595,6 +599,6 @@ mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m
595 where 599 where
596 c = cols m 600 c = cols m
597 601
598mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b 602mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b
599mapMatrix f = liftMatrix (mapVector f) 603mapMatrix f = liftMatrix (mapVector f)
600 604
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index fc9e3ad..5319e95 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -418,6 +418,8 @@ linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b)
418 418
419----------------------------------------------------------------------------------- 419-----------------------------------------------------------------------------------
420 420
421type TMM t = t ..> t ..> Ok
422
421foreign import ccall unsafe "chol_l_H" zpotrf :: TMM C 423foreign import ccall unsafe "chol_l_H" zpotrf :: TMM C
422foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R 424foreign import ccall unsafe "chol_l_S" dpotrf :: TMM R
423 425
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index f76b9dc..bdf2785 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -32,49 +32,13 @@ import Foreign.C.Types ( CInt(..) )
32import Foreign.C.String ( CString, newCString ) 32import Foreign.C.String ( CString, newCString )
33import System.IO.Unsafe ( unsafePerformIO ) 33import System.IO.Unsafe ( unsafePerformIO )
34import Control.DeepSeq ( NFData(..) ) 34import Control.DeepSeq ( NFData(..) )
35import Data.List.Split(chunksOf) 35import Text.Printf
36 36
37----------------------------------------------------------------- 37-----------------------------------------------------------------
38 38
39{- Design considerations for the Matrix Type
40 -----------------------------------------
41
42- we must easily handle both row major and column major order,
43 for bindings to LAPACK and GSL/C
44
45- we'd like to simplify redundant matrix transposes:
46 - Some of them arise from the order requirements of some functions
47 - some functions (matrix product) admit transposed arguments
48
49- maybe we don't really need this kind of simplification:
50 - more complex code
51 - some computational overhead
52 - only appreciable gain in code with a lot of redundant transpositions
53 and cheap matrix computations
54
55- we could carry both the matrix and its (lazily computed) transpose.
56 This may save some transpositions, but it is necessary to keep track of the
57 data which is actually computed to be used by functions like the matrix product
58 which admit both orders.
59
60- but if we need the transposed data and it is not in the structure, we must make
61 sure that we touch the same foreignptr that is used in the computation.
62
63- a reasonable solution is using two constructors for a matrix. Transposition just
64 "flips" the constructor. Actual data transposition is not done if followed by a
65 matrix product or another transpose.
66
67-}
68
69data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) 39data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
70 40
71transOrder RowMajor = ColumnMajor 41-- | Matrix representation suitable for BLAS\/LAPACK computations.
72transOrder ColumnMajor = RowMajor
73{- | Matrix representation suitable for BLAS\/LAPACK computations.
74
75The elements are stored in a continuous memory array.
76
77-}
78 42
79data Matrix t = Matrix 43data Matrix t = Matrix
80 { irows :: {-# UNPACK #-} !Int 44 { irows :: {-# UNPACK #-} !Int
@@ -83,8 +47,6 @@ data Matrix t = Matrix
83 , xCol :: {-# UNPACK #-} !Int 47 , xCol :: {-# UNPACK #-} !Int
84 , xdat :: {-# UNPACK #-} !(Vector t) 48 , xdat :: {-# UNPACK #-} !(Vector t)
85 } 49 }
86-- RowMajor: preferred by C, fdat may require a transposition
87-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
88 50
89 51
90rows :: Matrix t -> Int 52rows :: Matrix t -> Int
@@ -95,32 +57,55 @@ cols :: Matrix t -> Int
95cols = icols 57cols = icols
96{-# INLINE cols #-} 58{-# INLINE cols #-}
97 59
98rowOrder m = xRow m > 1 60size m = (irows m, icols m)
61{-# INLINE size #-}
62
63rowOrder m = xCol m == 1 || cols m == 1
99{-# INLINE rowOrder #-} 64{-# INLINE rowOrder #-}
100 65
101isSlice m = cols m < xRow m || rows m < xCol m 66colOrder m = xRow m == 1 || rows m == 1
67{-# INLINE colOrder #-}
68
69is1d (size->(r,c)) = r==1 || c==1
70{-# INLINE is1d #-}
71
72-- data is not contiguous
73isSlice m@(size->(r,c)) = (c < xRow m || r < xCol m) && min r c > 1
102{-# INLINE isSlice #-} 74{-# INLINE isSlice #-}
103 75
104orderOf :: Matrix t -> MatrixOrder 76orderOf :: Matrix t -> MatrixOrder
105orderOf m = if rowOrder m then RowMajor else ColumnMajor 77orderOf m = if rowOrder m then RowMajor else ColumnMajor
106 78
107 79
80showInternal :: Storable t => Matrix t -> IO ()
81showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv
82 where
83 r = rows m
84 c = cols m
85 xr = xRow m
86 xc = xCol m
87 slc = if isSlice m then "slice" else "full"
88 ord = if is1d m then "1d" else if rowOrder m then "rows" else "cols"
89 dv = dim (xdat m)
90
91--------------------------------------------------------------------------------
92
108-- | Matrix transpose. 93-- | Matrix transpose.
109trans :: Matrix t -> Matrix t 94trans :: Matrix t -> Matrix t
110trans m@Matrix { irows = r, icols = c } | rowOrder m = 95trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } =
111 m { irows = c, icols = r, xRow = 1, xCol = c } 96 m { irows = c, icols = r, xRow = xc, xCol = xr }
112trans m@Matrix { irows = r, icols = c } = 97
113 m { irows = c, icols = r, xRow = r, xCol = 1 }
114 98
115cmat :: (Element t) => Matrix t -> Matrix t 99cmat :: (Element t) => Matrix t -> Matrix t
116cmat m | rowOrder m = m 100cmat m
117cmat m@Matrix { irows = r, icols = c, xdat = d } = 101 | rowOrder m = m
118 m { xdat = transdata r d c, xRow = c, xCol = 1 } 102 | otherwise = extractAll RowMajor m
103
119 104
120fmat :: (Element t) => Matrix t -> Matrix t 105fmat :: (Element t) => Matrix t -> Matrix t
121fmat m | not (rowOrder m) = m 106fmat m
122fmat m@Matrix { irows = r, icols = c, xdat = d} = 107 | colOrder m = m
123 m { xdat = transdata c d r, xRow = 1, xCol = r } 108 | otherwise = extractAll ColumnMajor m
124 109
125 110
126-- C-Haskell matrix adapters 111-- C-Haskell matrix adapters
@@ -157,6 +142,11 @@ a # b = apply a b
157 142
158-------------------------------------------------------------------------------- 143--------------------------------------------------------------------------------
159 144
145extractAll ord m = unsafePerformIO $
146 extractR ord m
147 0 (idxs[0,rows m-1])
148 0 (idxs[0,cols m-1])
149
160{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 150{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
161 151
162>>> flatten (ident 3) 152>>> flatten (ident 3)
@@ -164,12 +154,14 @@ fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
164 154
165-} 155-}
166flatten :: Element t => Matrix t -> Vector t 156flatten :: Element t => Matrix t -> Vector t
167flatten = xdat . cmat 157flatten m
158 | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m)
159 | otherwise = xdat m
168 160
169 161
170-- | the inverse of 'Data.Packed.Matrix.fromLists' 162-- | the inverse of 'Data.Packed.Matrix.fromLists'
171toLists :: (Element t) => Matrix t -> [[t]] 163toLists :: (Element t) => Matrix t -> [[t]]
172toLists m = chunksOf (cols m) . toList . flatten $ m 164toLists = map toList . toRows
173 165
174 166
175 167
@@ -205,6 +197,14 @@ fromRows vs = case compatdim (map dim vs) of
205-- | extracts the rows of a matrix as a list of vectors 197-- | extracts the rows of a matrix as a list of vectors
206toRows :: Element t => Matrix t -> [Vector t] 198toRows :: Element t => Matrix t -> [Vector t]
207toRows m 199toRows m
200 | rowOrder m = map sub rowRange
201 | otherwise = map ext rowRange
202 where
203 rowRange = [0..rows m-1]
204 sub k = subVector (k*xRow m) (cols m) (xdat m)
205 ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1])
206
207{-
208 | c == 0 = replicate r (fromList[]) 208 | c == 0 = replicate r (fromList[])
209 | otherwise = toRows' 0 209 | otherwise = toRows' 0
210 where 210 where
@@ -213,6 +213,7 @@ toRows m
213 c = cols m 213 c = cols m
214 toRows' k | k == r*c = [] 214 toRows' k | k == r*c = []
215 | otherwise = subVector k c v : toRows' (k+c) 215 | otherwise = subVector k c v : toRows' (k+c)
216-}
216 217
217-- | Creates a matrix from a list of vectors, as columns 218-- | Creates a matrix from a list of vectors, as columns
218fromColumns :: Element t => [Vector t] -> Matrix t 219fromColumns :: Element t => [Vector t] -> Matrix t
@@ -240,7 +241,7 @@ matrixFromVector o r c v
240 | r * c == dim v = m 241 | r * c == dim v = m
241 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m 242 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
242 where 243 where
243 m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 } 244 m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 }
244 | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r } 245 | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r }
245 246
246-- allocates memory for a new matrix 247-- allocates memory for a new matrix
@@ -263,31 +264,26 @@ reshape :: Storable t => Int -> Vector t -> Matrix t
263reshape 0 v = matrixFromVector RowMajor 0 0 v 264reshape 0 v = matrixFromVector RowMajor 0 0 v
264reshape c v = matrixFromVector RowMajor (dim v `div` c) c v 265reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
265 266
266--singleton x = reshape 1 (fromList [x])
267 267
268-- | application of a vector function on the flattened matrix elements 268-- | application of a vector function on the flattened matrix elements
269liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 269liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
270liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} = matrixFromVector (orderOf m) r c (f d) 270liftMatrix f m@Matrix { irows = r, icols = c, xdat = d}
271 | isSlice m = matrixFromVector RowMajor r c (f (flatten m))
272 | otherwise = matrixFromVector (orderOf m) r c (f d)
271 273
272-- | application of a vector function on the flattened matrices elements 274-- | application of a vector function on the flattened matrices elements
273liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 275liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
274liftMatrix2 f m1 m2 276liftMatrix2 f m1@(size->(r,c)) m2
275 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2" 277 | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2"
276 | otherwise = case orderOf m1 of 278 | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2))
277 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2)) 279 | otherwise = matrixFromVector ColumnMajor r c (f (flatten (trans m1)) (flatten (trans m2)))
278 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
279
280
281compat :: Matrix a -> Matrix b -> Bool
282compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
283 280
284------------------------------------------------------------------ 281------------------------------------------------------------------
285 282
286-- | Supported matrix elements. 283-- | Supported matrix elements.
287class (Storable a) => Element a where 284class (Storable a) => Element a where
288 transdata :: Int -> Vector a -> Int -> Vector a
289 constantD :: a -> Int -> Vector a 285 constantD :: a -> Int -> Vector a
290 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) 286 extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
291 setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () 287 setRect :: Int -> Int -> Matrix a -> Matrix a -> IO ()
292 sortI :: Ord a => Vector a -> Vector CInt 288 sortI :: Ord a => Vector a -> Vector CInt
293 sortV :: Ord a => Vector a -> Vector a 289 sortV :: Ord a => Vector a -> Vector a
@@ -299,7 +295,6 @@ class (Storable a) => Element a where
299 295
300 296
301instance Element Float where 297instance Element Float where
302 transdata = transdataAux ctransF
303 constantD = constantAux cconstantF 298 constantD = constantAux cconstantF
304 extractR = extractAux c_extractF 299 extractR = extractAux c_extractF
305 setRect = setRectAux c_setRectF 300 setRect = setRectAux c_setRectF
@@ -312,7 +307,6 @@ instance Element Float where
312 gemm = gemmg c_gemmF 307 gemm = gemmg c_gemmF
313 308
314instance Element Double where 309instance Element Double where
315 transdata = transdataAux ctransR
316 constantD = constantAux cconstantR 310 constantD = constantAux cconstantR
317 extractR = extractAux c_extractD 311 extractR = extractAux c_extractD
318 setRect = setRectAux c_setRectD 312 setRect = setRectAux c_setRectD
@@ -325,7 +319,6 @@ instance Element Double where
325 gemm = gemmg c_gemmD 319 gemm = gemmg c_gemmD
326 320
327instance Element (Complex Float) where 321instance Element (Complex Float) where
328 transdata = transdataAux ctransQ
329 constantD = constantAux cconstantQ 322 constantD = constantAux cconstantQ
330 extractR = extractAux c_extractQ 323 extractR = extractAux c_extractQ
331 setRect = setRectAux c_setRectQ 324 setRect = setRectAux c_setRectQ
@@ -338,7 +331,6 @@ instance Element (Complex Float) where
338 gemm = gemmg c_gemmQ 331 gemm = gemmg c_gemmQ
339 332
340instance Element (Complex Double) where 333instance Element (Complex Double) where
341 transdata = transdataAux ctransC
342 constantD = constantAux cconstantC 334 constantD = constantAux cconstantC
343 extractR = extractAux c_extractC 335 extractR = extractAux c_extractC
344 setRect = setRectAux c_setRectC 336 setRect = setRectAux c_setRectC
@@ -351,7 +343,6 @@ instance Element (Complex Double) where
351 gemm = gemmg c_gemmC 343 gemm = gemmg c_gemmC
352 344
353instance Element (CInt) where 345instance Element (CInt) where
354 transdata = transdataAux ctransI
355 constantD = constantAux cconstantI 346 constantD = constantAux cconstantI
356 extractR = extractAux c_extractI 347 extractR = extractAux c_extractI
357 setRect = setRectAux c_setRectI 348 setRect = setRectAux c_setRectI
@@ -364,7 +355,6 @@ instance Element (CInt) where
364 gemm = gemmg c_gemmI 355 gemm = gemmg c_gemmI
365 356
366instance Element Z where 357instance Element Z where
367 transdata = transdataAux ctransL
368 constantD = constantAux cconstantL 358 constantD = constantAux cconstantL
369 extractR = extractAux c_extractL 359 extractR = extractAux c_extractL
370 setRect = setRectAux c_setRectL 360 setRect = setRectAux c_setRectL
@@ -378,32 +368,6 @@ instance Element Z where
378 368
379------------------------------------------------------------------- 369-------------------------------------------------------------------
380 370
381transdataAux fun c1 d c2 =
382 if noneed
383 then d
384 else unsafePerformIO $ do
385 -- putStrLn "T"
386 v <- createVector (dim d)
387 unsafeWith d $ \pd ->
388 unsafeWith v $ \pv ->
389 fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
390 return v
391 where r1 = dim d `div` c1
392 r2 = dim d `div` c2
393 noneed = dim d == 0 || r1 == 1 || c1 == 1
394
395
396type TMM t = t ..> t ..> Ok
397
398foreign import ccall unsafe "transF" ctransF :: TMM Float
399foreign import ccall unsafe "transR" ctransR :: TMM Double
400foreign import ccall unsafe "transQ" ctransQ :: TMM (Complex Float)
401foreign import ccall unsafe "transC" ctransC :: TMM (Complex Double)
402foreign import ccall unsafe "transI" ctransI :: TMM CInt
403foreign import ccall unsafe "transL" ctransL :: TMM Z
404
405----------------------------------------------------------------------
406
407subMatrix :: Element a 371subMatrix :: Element a
408 => (Int,Int) -- ^ (r0,c0) starting position 372 => (Int,Int) -- ^ (r0,c0) starting position
409 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 373 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
@@ -411,9 +375,8 @@ subMatrix :: Element a
411 -> Matrix a -- ^ result 375 -> Matrix a -- ^ result
412subMatrix (r0,c0) (rt,ct) m 376subMatrix (r0,c0) (rt,ct) m
413 | 0 <= r0 && 0 <= rt && r0+rt <= rows m && 377 | 0 <= r0 && 0 <= rt && r0+rt <= rows m &&
414 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) 378 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR RowMajor m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1])
415 | otherwise = error $ "wrong subMatrix "++ 379 | otherwise = error $ "wrong subMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m
416 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
417 380
418 381
419sliceMatrix :: Element a 382sliceMatrix :: Element a
@@ -424,11 +387,12 @@ sliceMatrix :: Element a
424sliceMatrix (r0,c0) (rt,ct) m 387sliceMatrix (r0,c0) (rt,ct) m
425 | 0 <= r0 && 0 <= rt && r0+rt <= rows m && 388 | 0 <= r0 && 0 <= rt && r0+rt <= rows m &&
426 0 <= c0 && 0 <= ct && c0+ct <= cols m = res 389 0 <= c0 && 0 <= ct && c0+ct <= cols m = res
427 | otherwise = error $ "wrong sliceMatrix "++ 390 | otherwise = error $ "wrong sliceMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m
428 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
429 where 391 where
430 t = r0 * xRow m + c0 * xCol m 392 p = r0 * xRow m + c0 * xCol m
431 res = m { irows = rt, icols = ct, xdat = subVector t (rt*ct) (xdat m) } 393 tot | rowOrder m = ct + (rt-1) * xRow m
394 | otherwise = rt + (ct-1) * xCol m
395 res = m { irows = rt, icols = ct, xdat = subVector p tot (xdat m) }
432 396
433-------------------------------------------------------------------------- 397--------------------------------------------------------------------------
434 398
@@ -449,7 +413,7 @@ conformMTo (r,c) m
449 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) 413 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
450 | size m == (r,1) = repCols c m 414 | size m == (r,1) = repCols c m
451 | size m == (1,c) = repRows r m 415 | size m == (1,c) = repRows r m
452 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")" 416 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c)
453 417
454conformVTo n v 418conformVTo n v
455 | dim v == n = v 419 | dim v == n = v
@@ -459,9 +423,9 @@ conformVTo n v
459repRows n x = fromRows (replicate n (flatten x)) 423repRows n x = fromRows (replicate n (flatten x))
460repCols n x = fromColumns (replicate n (flatten x)) 424repCols n x = fromColumns (replicate n (flatten x))
461 425
462size m = (rows m, cols m) 426shSize = shDim . size
463 427
464shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")" 428shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"
465 429
466emptyM r c = matrixFromVector RowMajor r c (fromList[]) 430emptyM r c = matrixFromVector RowMajor r c (fromList[])
467 431
@@ -477,10 +441,10 @@ instance (Storable t, NFData t) => NFData (Matrix t)
477 441
478--------------------------------------------------------------- 442---------------------------------------------------------------
479 443
480extractAux f m moder vr modec vc = do 444extractAux f ord m moder vr modec vc = do
481 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 445 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
482 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 446 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
483 r <- createMatrix RowMajor nr nc 447 r <- createMatrix ord nr nc
484 f moder modec # vr # vc # m # r #|"extract" 448 f moder modec # vr # vc # m # r #|"extract"
485 return r 449 return r
486 450
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 37f6e9b..c4f95d8 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -36,7 +36,7 @@ import Internal.LAPACK (multiplyI, multiplyL)
36import Internal.Algorithms(luFact) 36import Internal.Algorithms(luFact)
37import Internal.Util(Normed(..),Indexable(..), 37import Internal.Util(Normed(..),Indexable(..),
38 gaussElim, gaussElim_1, gaussElim_2, 38 gaussElim, gaussElim_1, gaussElim_2,
39 luST, luSolve', luPacked', magnit) 39 luST, luSolve', luPacked', magnit, invershur)
40import Internal.ST(mutable) 40import Internal.ST(mutable)
41import GHC.TypeLits 41import GHC.TypeLits
42import Data.Proxy(Proxy) 42import Data.Proxy(Proxy)
@@ -126,9 +126,8 @@ instance forall n t . (Integral t, KnownNat n) => Num (Mod n t)
126 126
127instance KnownNat m => Element (Mod m I) 127instance KnownNat m => Element (Mod m I)
128 where 128 where
129 transdata n v m = i2f (transdata n (f2i v) m)
130 constantD x n = i2f (constantD (unMod x) n) 129 constantD x n = i2f (constantD (unMod x) n)
131 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js 130 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js
132 setRect i j m x = setRect i j (f2iM m) (f2iM x) 131 setRect i j m x = setRect i j (f2iM m) (f2iM x)
133 sortI = sortI . f2i 132 sortI = sortI . f2i
134 sortV = i2f . sortV . f2i 133 sortV = i2f . sortV . f2i
@@ -144,9 +143,8 @@ instance KnownNat m => Element (Mod m I)
144 143
145instance KnownNat m => Element (Mod m Z) 144instance KnownNat m => Element (Mod m Z)
146 where 145 where
147 transdata n v m = i2f (transdata n (f2i v) m)
148 constantD x n = i2f (constantD (unMod x) n) 146 constantD x n = i2f (constantD (unMod x) n)
149 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js 147 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js
150 setRect i j m x = setRect i j (f2iM m) (f2iM x) 148 setRect i j m x = setRect i j (f2iM m) (f2iM x)
151 sortI = sortI . f2i 149 sortI = sortI . f2i
152 sortV = i2f . sortV . f2i 150 sortV = i2f . sortV . f2i
@@ -293,11 +291,11 @@ f2i :: Storable t => Vector (Mod n t) -> Vector t
293f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) 291f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
294 where (fp,i,n) = unsafeToForeignPtr v 292 where (fp,i,n) = unsafeToForeignPtr v
295 293
296f2iM :: Storable t => Matrix (Mod n t) -> Matrix t 294f2iM :: (Element t, Element (Mod n t)) => Matrix (Mod n t) -> Matrix t
297f2iM = liftMatrix f2i 295f2iM m = m { xdat = f2i (xdat m) }
298 296
299i2fM :: Storable t => Matrix t -> Matrix (Mod n t) 297i2fM :: (Element t, Element (Mod n t)) => Matrix t -> Matrix (Mod n t)
300i2fM = liftMatrix i2f 298i2fM m = m { xdat = i2f (xdat m) }
301 299
302vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) 300vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t)
303vmod = i2f . cmod' m' 301vmod = i2f . cmod' m'
@@ -376,6 +374,8 @@ test = (ok, info)
376 bb = flipud aa 374 bb = flipud aa
377 x = luSolve' (luPacked' aa) bb 375 x = luSolve' (luPacked' aa) bb
378 376
377 tmm = diagRect 1 (fromList [2..6]) 5 5 :: Matrix (Mod 19 I)
378
379 info = do 379 info = do
380 print v 380 print v
381 print m 381 print m
@@ -421,6 +421,9 @@ test = (ok, info)
421 print $ checkSolve (sgen 5 :: Matrix (Complex Float)) 421 print $ checkSolve (sgen 5 :: Matrix (Complex Float))
422 print $ checkSolve (gen 5 :: Matrix (Mod 7 I)) 422 print $ checkSolve (gen 5 :: Matrix (Mod 7 I))
423 print $ checkSolve (gen 5 :: Matrix (Mod 7 Z)) 423 print $ checkSolve (gen 5 :: Matrix (Mod 7 Z))
424
425 print $ luSolve' (luPacked' tmm) (ident (rows tmm))
426 print $ invershur tmm
424 427
425 428
426 ok = and 429 ok = and
@@ -449,6 +452,7 @@ test = (ok, info)
449 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) 452 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I))
450 , gm <> gm == konst 0 (3,3) 453 , gm <> gm == konst 0 (3,3)
451 , lgm <> lgm == konst 0 (3,3) 454 , lgm <> lgm == konst 0 (3,3)
455 , invershur tmm == luSolve' (luPacked' tmm) (ident (rows tmm))
452 ] 456 ]
453 457
454 458
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 92654e4..73cdf0c 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -223,7 +223,7 @@ rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
223 i2' = i2 `mod` (rows m) 223 i2' = i2 `mod` (rows m)
224 224
225 225
226extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) 226extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR (orderOf m) m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
227 where 227 where
228 (i1,i2) = getRowRange (rows m) rr 228 (i1,i2) = getRowRange (rows m) rr
229 (j1,j2) = getColRange (cols m) rc 229 (j1,j2) = getColRange (cols m) rc
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index bf6c8b6..98eb4ef 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -54,7 +54,9 @@ module Internal.Util(
54 -- ** 2D 54 -- ** 2D
55 corr2, conv2, separable, 55 corr2, conv2, separable,
56 block2x2,block3x3,view1,unView1,foldMatrix, 56 block2x2,block3x3,view1,unView1,foldMatrix,
57 gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked', luPacked'' 57 gaussElim_1, gaussElim_2, gaussElim,
58 luST, luSolve', luSolve'', luPacked', luPacked'',
59 invershur
58) where 60) where
59 61
60import Internal.Vector 62import Internal.Vector
@@ -829,6 +831,45 @@ luSolve' (lup,p) b = backSust lup (forwSust lup pb)
829 where 831 where
830 pb = b ?? (Pos (fixPerm' p), All) 832 pb = b ?? (Pos (fixPerm' p), All)
831 833
834
835--------------------------------------------------------------------------------
836
837data MatrixView t b
838 = Elem t
839 | Block b b b b
840 deriving Show
841
842
843viewBlock' r c m
844 | (rt,ct) == (1,1) = Elem (atM' m 0 0)
845 | otherwise = Block m11 m12 m21 m22
846 where
847 (rt,ct) = size m
848 m11 = sliceMatrix (0,0) (r,c) m
849 m12 = sliceMatrix (0,c) (r,ct-c) m
850 m21 = sliceMatrix (r,0) (rt-r,c) m
851 m22 = sliceMatrix (r,c) (rt-r,ct-c) m
852
853viewBlock m = viewBlock' n n m
854 where
855 n = rows m `div` 2
856
857invershur (viewBlock -> Block a b c d) = fromBlocks [[a',b'],[c',d']]
858 where
859 r1 = invershur a
860 r2 = c <> r1
861 r3 = r1 <> b
862 r4 = c <> r3
863 r5 = r4-d
864 r6 = invershur r5
865 b' = r3 <> r6
866 c' = r6 <> r2
867 r7 = r3 <> c'
868 a' = r1-r7
869 d' = -r6
870
871invershur x = recip x
872
832-------------------------------------------------------------------------------- 873--------------------------------------------------------------------------------
833 874
834instance Testable (Matrix I) where 875instance Testable (Matrix I) where