summaryrefslogtreecommitdiff
path: root/packages
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-12 20:58:13 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-12 20:58:13 +0200
commit4b3e29097aa272d429f8005fe17b459cf0c049c8 (patch)
treedf01591ec7bdffe61f68062cc09e95f69e745a90 /packages
parent0396adb9f10f5b337e54d64fec365c9cb01e9745 (diff)
row ops in ST
Diffstat (limited to 'packages')
-rw-r--r--packages/base/hmatrix.cabal1
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c87
-rw-r--r--packages/base/src/Internal/C/lapack-aux.h1
-rw-r--r--packages/base/src/Internal/C/vector-aux.c7
-rw-r--r--packages/base/src/Internal/Element.hs3
-rw-r--r--packages/base/src/Internal/Matrix.hs32
-rw-r--r--packages/base/src/Internal/Modular.hs43
-rw-r--r--packages/base/src/Internal/ST.hs26
-rw-r--r--packages/base/src/Internal/Util.hs80
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs2
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Devel.hs1
11 files changed, 263 insertions, 20 deletions
diff --git a/packages/base/hmatrix.cabal b/packages/base/hmatrix.cabal
index 0ab4821..f725341 100644
--- a/packages/base/hmatrix.cabal
+++ b/packages/base/hmatrix.cabal
@@ -81,6 +81,7 @@ library
81 ghc-options: -Wall 81 ghc-options: -Wall
82 -fno-warn-missing-signatures 82 -fno-warn-missing-signatures
83 -fno-warn-orphans 83 -fno-warn-orphans
84 -fprof-auto
84 85
85 cc-options: -O4 -Wall 86 cc-options: -O4 -Wall
86 87
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index dcce1c5..e42889d 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -4,6 +4,12 @@
4#include <math.h> 4#include <math.h>
5#include <time.h> 5#include <time.h>
6#include <inttypes.h> 6#include <inttypes.h>
7#include <complex.h>
8
9typedef double complex TCD;
10typedef float complex TCF;
11
12#undef complex
7 13
8#include "lapack-aux.h" 14#include "lapack-aux.h"
9 15
@@ -46,6 +52,10 @@
46#define NODEFPOS 2006 52#define NODEFPOS 2006
47#define NOSPRTD 2007 53#define NOSPRTD 2007
48 54
55inline int mod (int a, int b);
56
57inline int64_t mod_l (int64_t a, int64_t b);
58
49//--------------------------------------- 59//---------------------------------------
50void asm_finit() { 60void asm_finit() {
51#ifdef i386 61#ifdef i386
@@ -1310,6 +1320,83 @@ int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) {
1310int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP 1320int multiplyI(int m, KOIMAT(a), KOIMAT(b), OIMAT(r)) MULT_IMP
1311int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP 1321int multiplyL(int64_t m, KOLMAT(a), KOLMAT(b), OLMAT(r)) MULT_IMP
1312 1322
1323/////////////////////////////// inplace row ops ////////////////////////////////
1324
1325#define AXPY_IMP { \
1326 int j; \
1327 for(j=j1; j<=j2; j++) { \
1328 AT(r,i2,j) += a*AT(r,i1,j); \
1329 } OK }
1330
1331#define AXPY_MOD_IMP(M) { \
1332 int j; \
1333 for(j=j1; j<=j2; j++) { \
1334 AT(r,i2,j) = M(AT(r,i2,j) + M(a*AT(r,i1,j), m) , m); \
1335 } OK }
1336
1337
1338#define SCAL_IMP { \
1339 int i,j; \
1340 for(i=i1; i<=i2; i++) { \
1341 for(j=j1; j<=j2; j++) { \
1342 AT(r,i,j) = a*AT(r,i,j); \
1343 } \
1344 } OK }
1345
1346#define SCAL_MOD_IMP(M) { \
1347 int i,j; \
1348 for(i=i1; i<=i2; i++) { \
1349 for(j=j1; j<=j2; j++) { \
1350 AT(r,i,j) = M(a*AT(r,i,j) , m); \
1351 } \
1352 } OK }
1353
1354
1355#define SWAP_IMP(T) { \
1356 T aux; \
1357 int k; \
1358 if (i1 != i2) { \
1359 for (k=j1; k<=j2; k++) { \
1360 aux = AT(r,i1,k); \
1361 AT(r,i1,k) = AT(r,i2,k); \
1362 AT(r,i2,k) = aux; \
1363 } \
1364 } OK }
1365
1366
1367#define ROWOP_IMP(T) { \
1368 T a = *pa; \
1369 switch(code) { \
1370 case 0: AXPY_IMP \
1371 case 1: SCAL_IMP \
1372 case 2: SWAP_IMP(T) \
1373 default: ERROR(BAD_CODE); \
1374 } \
1375}
1376
1377#define ROWOP_MOD_IMP(T,M) { \
1378 T a = *pa; \
1379 switch(code) { \
1380 case 0: AXPY_MOD_IMP(M) \
1381 case 1: SCAL_MOD_IMP(M) \
1382 case 2: SWAP_IMP(T) \
1383 default: ERROR(BAD_CODE); \
1384 } \
1385}
1386
1387
1388#define ROWOP(T) int rowop_##T(int code, T* pa, int i1, int i2, int j1, int j2, MATG(T,r)) ROWOP_IMP(T)
1389
1390#define ROWOP_MOD(T,M) int rowop_mod_##T(T m, int code, T* pa, int i1, int i2, int j1, int j2, MATG(T,r)) ROWOP_MOD_IMP(T,M)
1391
1392ROWOP(double)
1393ROWOP(float)
1394ROWOP(TCD)
1395ROWOP(TCF)
1396ROWOP(int32_t)
1397ROWOP(int64_t)
1398ROWOP_MOD(int32_t,mod)
1399ROWOP_MOD(int64_t,mod_l)
1313 1400
1314////////////////// sparse matrix-product /////////////////////////////////////// 1401////////////////// sparse matrix-product ///////////////////////////////////////
1315 1402
diff --git a/packages/base/src/Internal/C/lapack-aux.h b/packages/base/src/Internal/C/lapack-aux.h
index 1549bb5..e4d95bc 100644
--- a/packages/base/src/Internal/C/lapack-aux.h
+++ b/packages/base/src/Internal/C/lapack-aux.h
@@ -59,6 +59,7 @@ typedef short ftnlen;
59#define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p 59#define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p
60#define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p 60#define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p
61 61
62#define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p
62 63
63#define KIVEC(A) int A##n, const int*A##p 64#define KIVEC(A) int A##n, const int*A##p
64#define KLVEC(A) int A##n, const int64_t*A##p 65#define KLVEC(A) int A##n, const int64_t*A##p
diff --git a/packages/base/src/Internal/C/vector-aux.c b/packages/base/src/Internal/C/vector-aux.c
index c161556..5528a9d 100644
--- a/packages/base/src/Internal/C/vector-aux.c
+++ b/packages/base/src/Internal/C/vector-aux.c
@@ -716,6 +716,7 @@ int mapValF(int code, float* pval, KFVEC(x), FVEC(r)) {
716 } 716 }
717} 717}
718 718
719inline
719int mod (int a, int b) { 720int mod (int a, int b) {
720 int m = a % b; 721 int m = a % b;
721 if (b>0) { 722 if (b>0) {
@@ -741,7 +742,7 @@ int mapValI(int code, int* pval, KIVEC(x), IVEC(r)) {
741 } 742 }
742} 743}
743 744
744 745inline
745int64_t mod_l (int64_t a, int64_t b) { 746int64_t mod_l (int64_t a, int64_t b) {
746 int64_t m = a % b; 747 int64_t m = a % b;
747 if (b>0) { 748 if (b>0) {
@@ -1230,7 +1231,7 @@ int round_vector_i(KDVEC(v),IVEC(r)) {
1230int mod_vector(int m, KIVEC(v), IVEC(r)) { 1231int mod_vector(int m, KIVEC(v), IVEC(r)) {
1231 int k; 1232 int k;
1232 for(k=0; k<vn; k++) { 1233 for(k=0; k<vn; k++) {
1233 rp[k] = vp[k] % m; 1234 rp[k] = mod(vp[k],m);
1234 } 1235 }
1235 OK 1236 OK
1236} 1237}
@@ -1266,7 +1267,7 @@ int round_vector_l(KDVEC(v),LVEC(r)) {
1266int mod_vector_l(int64_t m, KLVEC(v), LVEC(r)) { 1267int mod_vector_l(int64_t m, KLVEC(v), LVEC(r)) {
1267 int k; 1268 int k;
1268 for(k=0; k<vn; k++) { 1269 for(k=0; k<vn; k++) {
1269 rp[k] = vp[k] % m; 1270 rp[k] = mod_l(vp[k],m);
1270 } 1271 }
1271 OK 1272 OK
1272} 1273}
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs
index 55bff67..4007491 100644
--- a/packages/base/src/Internal/Element.hs
+++ b/packages/base/src/Internal/Element.hs
@@ -30,6 +30,7 @@ import Text.Printf
30import Data.List(transpose,intersperse) 30import Data.List(transpose,intersperse)
31import Data.List.Split(chunksOf) 31import Data.List.Split(chunksOf)
32import Foreign.Storable(Storable) 32import Foreign.Storable(Storable)
33import System.IO.Unsafe(unsafePerformIO)
33import Control.Monad(liftM) 34import Control.Monad(liftM)
34 35
35------------------------------------------------------------------- 36-------------------------------------------------------------------
@@ -147,7 +148,7 @@ m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n))
147m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) 148m ?? (DropLast n, e) = m ?? (Take (rows m - n), e)
148m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) 149m ?? (e, DropLast n) = m ?? (e, Take (cols m - n))
149 150
150m ?? (er,ec) = extractR m moder rs modec cs 151m ?? (er,ec) = unsafePerformIO $ extractR m moder rs modec cs
151 where 152 where
152 (moder,rs) = mkExt (rows m) er 153 (moder,rs) = mkExt (rows m) er
153 (modec,cs) = mkExt (cols m) ec 154 (modec,cs) = mkExt (cols m) ec
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 8de06ce..fa1aad6 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -20,6 +20,7 @@ import Internal.Vector
20import Internal.Devel 20import Internal.Devel
21import Internal.Vectorized 21import Internal.Vectorized
22import Foreign.Marshal.Alloc ( free ) 22import Foreign.Marshal.Alloc ( free )
23import Foreign.Marshal.Array(newArray)
23import Foreign.Ptr ( Ptr ) 24import Foreign.Ptr ( Ptr )
24import Foreign.Storable ( Storable ) 25import Foreign.Storable ( Storable )
25import Data.Complex ( Complex ) 26import Data.Complex ( Complex )
@@ -273,12 +274,13 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
273class (Storable a) => Element a where 274class (Storable a) => Element a where
274 transdata :: Int -> Vector a -> Int -> Vector a 275 transdata :: Int -> Vector a -> Int -> Vector a
275 constantD :: a -> Int -> Vector a 276 constantD :: a -> Int -> Vector a
276 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> Matrix a 277 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
277 sortI :: Ord a => Vector a -> Vector CInt 278 sortI :: Ord a => Vector a -> Vector CInt
278 sortV :: Ord a => Vector a -> Vector a 279 sortV :: Ord a => Vector a -> Vector a
279 compareV :: Ord a => Vector a -> Vector a -> Vector CInt 280 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
280 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a 281 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
281 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 282 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
283 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
282 284
283 285
284instance Element Float where 286instance Element Float where
@@ -290,6 +292,7 @@ instance Element Float where
290 compareV = compareF 292 compareV = compareF
291 selectV = selectF 293 selectV = selectF
292 remapM = remapF 294 remapM = remapF
295 rowOp = rowOpAux c_rowOpF
293 296
294instance Element Double where 297instance Element Double where
295 transdata = transdataAux ctransR 298 transdata = transdataAux ctransR
@@ -300,6 +303,7 @@ instance Element Double where
300 compareV = compareD 303 compareV = compareD
301 selectV = selectD 304 selectV = selectD
302 remapM = remapD 305 remapM = remapD
306 rowOp = rowOpAux c_rowOpD
303 307
304 308
305instance Element (Complex Float) where 309instance Element (Complex Float) where
@@ -311,6 +315,7 @@ instance Element (Complex Float) where
311 compareV = undefined 315 compareV = undefined
312 selectV = selectQ 316 selectV = selectQ
313 remapM = remapQ 317 remapM = remapQ
318 rowOp = rowOpAux c_rowOpQ
314 319
315 320
316instance Element (Complex Double) where 321instance Element (Complex Double) where
@@ -322,6 +327,7 @@ instance Element (Complex Double) where
322 compareV = undefined 327 compareV = undefined
323 selectV = selectC 328 selectV = selectC
324 remapM = remapC 329 remapM = remapC
330 rowOp = rowOpAux c_rowOpC
325 331
326instance Element (CInt) where 332instance Element (CInt) where
327 transdata = transdataAux ctransI 333 transdata = transdataAux ctransI
@@ -332,6 +338,7 @@ instance Element (CInt) where
332 compareV = compareI 338 compareV = compareI
333 selectV = selectI 339 selectV = selectI
334 remapM = remapI 340 remapM = remapI
341 rowOp = rowOpAux c_rowOpI
335 342
336instance Element Z where 343instance Element Z where
337 transdata = transdataAux ctransL 344 transdata = transdataAux ctransL
@@ -342,6 +349,7 @@ instance Element Z where
342 compareV = compareL 349 compareV = compareL
343 selectV = selectL 350 selectV = selectL
344 remapM = remapL 351 remapM = remapL
352 rowOp = rowOpAux c_rowOpL
345 353
346------------------------------------------------------------------- 354-------------------------------------------------------------------
347 355
@@ -379,7 +387,7 @@ subMatrix :: Element a
379 -> Matrix a -- ^ result 387 -> Matrix a -- ^ result
380subMatrix (r0,c0) (rt,ct) m 388subMatrix (r0,c0) (rt,ct) m
381 | 0 <= r0 && 0 <= rt && r0+rt <= rows m && 389 | 0 <= r0 && 0 <= rt && r0+rt <= rows m &&
382 0 <= c0 && 0 <= ct && c0+ct <= cols m = extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1]) 390 0 <= c0 && 0 <= ct && c0+ct <= cols m = unsafePerformIO $ extractR m 0 (idxs[r0,r0+rt-1]) 0 (idxs[c0,c0+ct-1])
383 | otherwise = error $ "wrong subMatrix "++ 391 | otherwise = error $ "wrong subMatrix "++
384 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m) 392 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
385 393
@@ -430,7 +438,7 @@ instance (Storable t, NFData t) => NFData (Matrix t)
430 438
431--------------------------------------------------------------- 439---------------------------------------------------------------
432 440
433extractAux f m moder vr modec vc = unsafePerformIO $ do 441extractAux f m moder vr modec vc = do
434 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 442 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
435 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 443 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
436 r <- createMatrix RowMajor nr nc 444 r <- createMatrix RowMajor nr nc
@@ -538,6 +546,24 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
538 546
539-------------------------------------------------------------------------------- 547--------------------------------------------------------------------------------
540 548
549rowOpAux f c x i1 i2 j1 j2 m = do
550 px <- newArray [x]
551 app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp"
552 free px
553
554type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
555
556foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R
557foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float
558foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C
559foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float)
560foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I
561foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z
562foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I
563foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
564
565--------------------------------------------------------------------------------
566
541foreign import ccall unsafe "saveMatrix" c_saveMatrix 567foreign import ccall unsafe "saveMatrix" c_saveMatrix
542 :: CString -> CString -> Double ..> Ok 568 :: CString -> CString -> Double ..> Ok
543 569
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 1289a21..824fc57 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -111,18 +111,46 @@ instance forall n t . (Integral t, KnownNat n) => Num (Mod n t)
111 fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) 111 fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m))
112 112
113 113
114instance KnownNat m => Element (Mod m I)
115 where
116 transdata n v m = i2f (transdata n (f2i v) m)
117 constantD x n = i2f (constantD (unMod x) n)
118 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js
119 sortI = sortI . f2i
120 sortV = i2f . sortV . f2i
121 compareV u v = compareV (f2i u) (f2i v)
122 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
123 remapM i j m = i2fM (remap i j (f2iM m))
124 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x)
125 where
126 m' = fromIntegral . natVal $ (undefined :: Proxy m)
114 127
115instance (Ord t, Element t) => Element (Mod n t) 128instance KnownNat m => Element (Mod m Z)
116 where 129 where
117 transdata n v m = i2f (transdata n (f2i v) m) 130 transdata n v m = i2f (transdata n (f2i v) m)
118 constantD x n = i2f (constantD (unMod x) n) 131 constantD x n = i2f (constantD (unMod x) n)
119 extractR m mi is mj js = i2fM (extractR (f2iM m) mi is mj js) 132 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js
120 sortI = sortI . f2i 133 sortI = sortI . f2i
121 sortV = i2f . sortV . f2i 134 sortV = i2f . sortV . f2i
122 compareV u v = compareV (f2i u) (f2i v) 135 compareV u v = compareV (f2i u) (f2i v)
123 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) 136 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
124 remapM i j m = i2fM (remap i j (f2iM m)) 137 remapM i j m = i2fM (remap i j (f2iM m))
138 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x)
139 where
140 m' = fromIntegral . natVal $ (undefined :: Proxy m)
125 141
142{-
143instance (Ord t, Element t) => Element (Mod m t)
144 where
145 transdata n v m = i2f (transdata n (f2i v) m)
146 constantD x n = i2f (constantD (unMod x) n)
147 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js
148 sortI = sortI . f2i
149 sortV = i2f . sortV . f2i
150 compareV u v = compareV (f2i u) (f2i v)
151 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
152 remapM i j m = i2fM (remap i j (f2iM m))
153-}
126 154
127instance forall m . KnownNat m => Container Vector (Mod m I) 155instance forall m . KnownNat m => Container Vector (Mod m I)
128 where 156 where
@@ -205,12 +233,10 @@ instance forall m . KnownNat m => Container Vector (Mod m Z)
205 toZ' = f2i 233 toZ' = f2i
206 234
207 235
208
209instance (Storable t, Indexable (Vector t) t) => Indexable (Vector (Mod m t)) (Mod m t) 236instance (Storable t, Indexable (Vector t) t) => Indexable (Vector (Mod m t)) (Mod m t)
210 where 237 where
211 (!) = (@>) 238 (!) = (@>)
212 239
213
214type instance RealOf (Mod n I) = I 240type instance RealOf (Mod n I) = I
215type instance RealOf (Mod n Z) = Z 241type instance RealOf (Mod n Z) = Z
216 242
@@ -270,6 +296,15 @@ instance forall m . KnownNat m => Num (Vector (Mod m I))
270 negate = lift1 negate 296 negate = lift1 negate
271 fromInteger x = fromInt (fromInteger x) 297 fromInteger x = fromInt (fromInteger x)
272 298
299instance forall m . KnownNat m => Num (Vector (Mod m Z))
300 where
301 (+) = lift2 (+)
302 (*) = lift2 (*)
303 (-) = lift2 (-)
304 abs = lift1 abs
305 signum = lift1 signum
306 negate = lift1 negate
307 fromInteger x = fromZ (fromInteger x)
273 308
274-------------------------------------------------------------------------------- 309--------------------------------------------------------------------------------
275 310
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index ae75a1b..107d3c3 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -1,5 +1,6 @@
1{-# LANGUAGE Rank2Types #-} 1{-# LANGUAGE Rank2Types #-}
2{-# LANGUAGE BangPatterns #-} 2{-# LANGUAGE BangPatterns #-}
3
3----------------------------------------------------------------------------- 4-----------------------------------------------------------------------------
4-- | 5-- |
5-- Module : Internal.ST 6-- Module : Internal.ST
@@ -20,6 +21,8 @@ module Internal.ST (
20 -- * Mutable Matrices 21 -- * Mutable Matrices
21 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
22 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24 axpy, scal, swap, extractRect,
25 mutable,
23 -- * Unsafe functions 26 -- * Unsafe functions
24 newUndefinedVector, 27 newUndefinedVector,
25 unsafeReadVector, unsafeWriteVector, 28 unsafeReadVector, unsafeWriteVector,
@@ -34,8 +37,6 @@ import Internal.Matrix
34import Internal.Vectorized 37import Internal.Vectorized
35import Control.Monad.ST(ST, runST) 38import Control.Monad.ST(ST, runST)
36import Foreign.Storable(Storable, peekElemOff, pokeElemOff) 39import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
37
38
39import Control.Monad.ST.Unsafe(unsafeIOToST) 40import Control.Monad.ST.Unsafe(unsafeIOToST)
40 41
41{-# INLINE ioReadV #-} 42{-# INLINE ioReadV #-}
@@ -144,6 +145,7 @@ liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
144unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 145unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
145unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 146unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
146 147
148
147freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 149freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
148freezeMatrix m = liftSTMatrix id m 150freezeMatrix m = liftSTMatrix id m
149 151
@@ -171,3 +173,23 @@ newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c
171newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t) 173newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
172newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) 174newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
173 175
176--------------------------------------------------------------------------------
177
178rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s ()
179rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m)
180
181axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m)
182scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m)
183swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m)
184
185extractRect (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
186
187--------------------------------------------------------------------------------
188
189mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
190mutable f a = runST $ do
191 x <- thawMatrix a
192 info <- f (rows a, cols a) x
193 r <- unsafeFreezeMatrix x
194 return (r,info)
195
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index b1fb800..7a556e9 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -54,7 +54,7 @@ module Internal.Util(
54 -- ** 2D 54 -- ** 2D
55 corr2, conv2, separable, 55 corr2, conv2, separable,
56 block2x2,block3x3,view1,unView1,foldMatrix, 56 block2x2,block3x3,view1,unView1,foldMatrix,
57 gaussElim 57 gaussElim_1, gaussElim_2, gaussElim
58) where 58) where
59 59
60import Internal.Vector 60import Internal.Vector
@@ -64,17 +64,19 @@ import Internal.Element
64import Internal.Container 64import Internal.Container
65import Internal.Vectorized 65import Internal.Vectorized
66import Internal.IO 66import Internal.IO
67import Internal.Algorithms hiding (i,Normed) 67import Internal.Algorithms hiding (i,Normed,swap)
68import Numeric.Matrix() 68import Numeric.Matrix()
69import Numeric.Vector() 69import Numeric.Vector()
70import Internal.Random 70import Internal.Random
71import Internal.Convolution 71import Internal.Convolution
72import Control.Monad(when) 72import Control.Monad(when,forM_)
73import Text.Printf 73import Text.Printf
74import Data.List.Split(splitOn) 74import Data.List.Split(splitOn)
75import Data.List(intercalate,) 75import Data.List(intercalate,sortBy)
76import Control.Arrow((&&&)) 76import Control.Arrow((&&&))
77import Data.Complex 77import Data.Complex
78import Data.Function(on)
79import Internal.ST
78 80
79type ℝ = Double 81type ℝ = Double
80type ℕ = Int 82type ℕ = Int
@@ -359,6 +361,10 @@ instance Indexable (Vector I) I
359 where 361 where
360 (!) = (@>) 362 (!) = (@>)
361 363
364instance Indexable (Vector Z) Z
365 where
366 (!) = (@>)
367
362instance Indexable (Vector (Complex Double)) (Complex Double) 368instance Indexable (Vector (Complex Double)) (Complex Double)
363 where 369 where
364 (!) = (@>) 370 (!) = (@>)
@@ -550,11 +556,11 @@ down g a = foldMatrix g f a
550-- 556--
551-- @a <> gaussElim a b = b@ 557-- @a <> gaussElim a b = b@
552-- 558--
553gaussElim 559gaussElim_2
554 :: (Eq t, Fractional t, Num (Vector t), Numeric t) 560 :: (Eq t, Fractional t, Num (Vector t), Numeric t)
555 => Matrix t -> Matrix t -> Matrix t 561 => Matrix t -> Matrix t -> Matrix t
556 562
557gaussElim a b = flipudrl r 563gaussElim_2 a b = flipudrl r
558 where 564 where
559 flipudrl = flipud . fliprl 565 flipudrl = flipud . fliprl
560 splitColsAt n = (takeColumns n &&& dropColumns n) 566 splitColsAt n = (takeColumns n &&& dropColumns n)
@@ -564,6 +570,68 @@ gaussElim a b = flipudrl r
564 570
565-------------------------------------------------------------------------------- 571--------------------------------------------------------------------------------
566 572
573gaussElim_1
574 :: (Fractional t, Num (Vector t), Ord t, Indexable (Vector t) t, Numeric t)
575 => Matrix t -> Matrix t -> Matrix t
576
577gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2)
578 where
579 rs = toRows $ fromBlocks [[x , y]]
580 s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting
581 s2 = pivotUp (rows x-1) (toRows $ flipud s1)
582
583pivotDown t n xs
584 | t == n = []
585 | otherwise = y : pivotDown t (n+1) ys
586 where
587 y:ys = redu (pivot n xs)
588
589 pivot k = (const k &&& id)
590 . sortBy (flip compare `on` (abs. (!k)))
591
592 redu (k,x:zs)
593 | p == 0 = error "gauss: singular!" -- FIXME
594 | otherwise = u : map f zs
595 where
596 p = x!k
597 u = scale (recip (x!k)) x
598 f z = z - scale (z!k) u
599 redu (_,[]) = []
600
601
602pivotUp n xs
603 | n == -1 = []
604 | otherwise = y : pivotUp (n-1) ys
605 where
606 y:ys = redu' (n,xs)
607
608 redu' (k,x:zs) = u : map f zs
609 where
610 u = x
611 f z = z - scale (z!k) u
612 redu' (_,[]) = []
613
614--------------------------------------------------------------------------------
615
616gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]])
617
618gaussST (r,_) x = do
619 let n = r-1
620 forM_ [0..n] $ \i -> do
621 c <- maxIndex . abs . flatten <$> extractRect x i n i i
622 swap x i (i+c)
623 a <- readMatrix x i i
624 scal x (recip a) i
625 forM_ [i+1..n] $ \j -> do
626 b <- readMatrix x j i
627 axpy x (-b) i j
628 forM_ [n,n-1..1] $ \i -> do
629 forM_ [i-1,i-2..0] $ \j -> do
630 b <- readMatrix x j i
631 axpy x (-b) i j
632
633--------------------------------------------------------------------------------
634
567instance Testable (Matrix I) where 635instance Testable (Matrix I) where
568 checkT _ = test 636 checkT _ = test
569 637
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index 56e5053..c97f415 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -134,7 +134,7 @@ module Numeric.LinearAlgebra (
134 Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, 134 Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample,
135 135
136 -- * Misc 136 -- * Misc
137 meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, 137 meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, gaussElim_1, gaussElim_2,
138 ℝ,ℂ,iC, 138 ℝ,ℂ,iC,
139 -- * Auxiliary classes 139 -- * Auxiliary classes
140 Element, Container, Product, Numeric, LSDiv, 140 Element, Container, Product, Numeric, LSDiv,
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
index 1a70663..84763fe 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
@@ -44,6 +44,7 @@ module Numeric.LinearAlgebra.Devel(
44 -- ** Mutable Matrices 44 -- ** Mutable Matrices
45 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 45 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
46 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 46 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
47 axpy,scal,swap, extractRect, mutable,
47 -- ** Unsafe functions 48 -- ** Unsafe functions
48 newUndefinedVector, 49 newUndefinedVector,
49 unsafeReadVector, unsafeWriteVector, 50 unsafeReadVector, unsafeWriteVector,