summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Internal/Algorithms.hs33
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c59
-rw-r--r--packages/base/src/Internal/C/lapack-aux.h1
-rw-r--r--packages/base/src/Internal/Matrix.hs24
-rw-r--r--packages/base/src/Internal/Modular.hs52
-rw-r--r--packages/base/src/Internal/ST.hs29
-rw-r--r--packages/base/src/Internal/Util.hs116
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs56
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Devel.hs2
9 files changed, 280 insertions, 92 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs
index aaf6fbb..1235da3 100644
--- a/packages/base/src/Internal/Algorithms.hs
+++ b/packages/base/src/Internal/Algorithms.hs
@@ -29,7 +29,9 @@ import Internal.Conversion
29import Internal.LAPACK as LAPACK 29import Internal.LAPACK as LAPACK
30import Internal.Numeric 30import Internal.Numeric
31import Data.List(foldl1') 31import Data.List(foldl1')
32import Data.Array 32import qualified Data.Array as A
33import Internal.ST
34import Internal.Vectorized(range)
33 35
34{- | Generic linear algebra functions for double precision real and complex matrices. 36{- | Generic linear algebra functions for double precision real and complex matrices.
35 37
@@ -578,11 +580,6 @@ eps = 2.22044604925031e-16
578peps :: RealFloat x => x 580peps :: RealFloat x => x
579peps = x where x = 2.0 ** fromIntegral (1 - floatDigits x) 581peps = x where x = 2.0 ** fromIntegral (1 - floatDigits x)
580 582
581
582-- | The imaginary unit: @i = 0.0 :+ 1.0@
583i :: Complex Double
584i = 0:+1
585
586----------------------------------------------------------------------- 583-----------------------------------------------------------------------
587 584
588-- | The nullspace of a matrix from its precomputed SVD decomposition. 585-- | The nullspace of a matrix from its precomputed SVD decomposition.
@@ -796,13 +793,23 @@ signlp r vals = foldl f 1 (zip [0..r-1] vals)
796 where f s (a,b) | a /= b = -s 793 where f s (a,b) | a /= b = -s
797 | otherwise = s 794 | otherwise = s
798 795
799swap (arr,s) (a,b) | a /= b = (arr // [(a, arr!b),(b,arr!a)],-s) 796fixPerm r vals = (fromColumns $ A.elems res, sign)
800 | otherwise = (arr,s) 797 where
801 798 v = [0..r-1]
802fixPerm r vals = (fromColumns $ elems res, sign) 799 t = toColumns (ident r)
803 where v = [0..r-1] 800 (res,sign) = foldl swap (A.listArray (0,r-1) t, 1) (zip v vals)
804 s = toColumns (ident r) 801 swap (arr,s) (a,b)
805 (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) 802 | a /= b = (arr A.// [(a, arr A.! b),(b,arr A.! a)],-s)
803 | otherwise = (arr,s)
804
805fixPerm' :: [Int] -> Vector I
806fixPerm' s = res $ mutable f s0
807 where
808 s0 = reshape 1 (range (length s))
809 res = flatten . fst
810 swap m i j = rowOper (SWAP i j AllCols) m
811 f :: (Num t, Element t) => (Int, Int) -> STMatrix s t -> ST s () -- needed because of TypeFamilies
812 f _ p = sequence_ $ zipWith (swap p) [0..] s
806 813
807triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]] 814triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]]
808 where el p q = if q-p>=h then v else 1 - v 815 where el p q = if q-p>=h then v else 1 - v
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index 2843ab5..4d48594 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1398,6 +1398,65 @@ ROWOP(int64_t)
1398ROWOP_MOD(int32_t,mod) 1398ROWOP_MOD(int32_t,mod)
1399ROWOP_MOD(int64_t,mod_l) 1399ROWOP_MOD(int64_t,mod_l)
1400 1400
1401/////////////////////////////// inplace GEMM ////////////////////////////////
1402
1403#define GEMM(T) int gemm_##T(VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1404 T a = cp[0], b = cp[1]; \
1405 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \
1406 int r1b = pp[4], c1b = pp[6] ; \
1407 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \
1408 int dra = r1a - r1r; \
1409 int dcb = c1b-c1r; \
1410 int nk = c2a-c1a+1; \
1411 int i,j,k; \
1412 T t; \
1413 for (i=r1r; i<=r2r; i++) { \
1414 for (j=c1r; j<=c2r; j++) { \
1415 t = 0; \
1416 for(k=0; k<nk; k++) { \
1417 t += AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb); \
1418 } \
1419 AT(r,i,j) = b*AT(r,i,j) + a*t; \
1420 } \
1421 } \
1422 OK \
1423}
1424
1425GEMM(double)
1426GEMM(float)
1427GEMM(TCD)
1428GEMM(TCF)
1429GEMM(int32_t)
1430GEMM(int64_t)
1431
1432#define GEMM_MOD(T,M) int gemm_mod_##T(T m, VECG(T,c),VECG(int,p),MATG(T,a),MATG(T,b),MATG(T,r)) { \
1433 T a = cp[0], b = cp[1]; \
1434 int r1a = pp[0], c1a = pp[2], c2a = pp[3] ; \
1435 int r1b = pp[4], c1b = pp[6] ; \
1436 int r1r = pp[8], r2r = pp[9], c1r = pp[10], c2r = pp[11]; \
1437 int dra = r1a - r1r; \
1438 int dcb = c1b-c1r; \
1439 int nk = c2a-c1a+1; \
1440 int i,j,k; \
1441 T t; \
1442 for (i=r1r; i<=r2r; i++) { \
1443 for (j=c1r; j<=c2r; j++) { \
1444 t = 0; \
1445 for(k=0; k<nk; k++) { \
1446 t = M(t+M(AT(a,i+dra,k+c1a) * AT(b,k+r1b,j+dcb))); \
1447 } \
1448 AT(r,i,j) = M(M(b*AT(r,i,j)) + M(a*t)); \
1449 } \
1450 } \
1451 OK \
1452}
1453
1454#define MOD32(X) mod(X,m)
1455#define MOD64(X) mod_l(X,m)
1456
1457GEMM_MOD(int32_t,MOD32)
1458GEMM_MOD(int64_t,MOD64)
1459
1401////////////////// sparse matrix-product /////////////////////////////////////// 1460////////////////// sparse matrix-product ///////////////////////////////////////
1402 1461
1403 1462
diff --git a/packages/base/src/Internal/C/lapack-aux.h b/packages/base/src/Internal/C/lapack-aux.h
index e4d95bc..bf8c5e9 100644
--- a/packages/base/src/Internal/C/lapack-aux.h
+++ b/packages/base/src/Internal/C/lapack-aux.h
@@ -59,6 +59,7 @@ typedef short ftnlen;
59#define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p 59#define OQMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, complex* A##p
60#define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p 60#define OCMAT(A) int A##r, int A##c, int A##Xr, int A##Xc, doublecomplex* A##p
61 61
62#define VECG(T,A) int A##n, T* A##p
62#define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p 63#define MATG(T,A) int A##r, int A##c, int A##Xr, int A##Xc, T* A##p
63 64
64#define KIVEC(A) int A##n, const int*A##p 65#define KIVEC(A) int A##n, const int*A##p
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 75e92a5..8f8c219 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -274,6 +274,7 @@ class (Storable a) => Element a where
274 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a 274 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
275 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a 275 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
276 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () 276 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
277 gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO ()
277 278
278 279
279instance Element Float where 280instance Element Float where
@@ -287,6 +288,7 @@ instance Element Float where
287 selectV = selectF 288 selectV = selectF
288 remapM = remapF 289 remapM = remapF
289 rowOp = rowOpAux c_rowOpF 290 rowOp = rowOpAux c_rowOpF
291 gemm = gemmg c_gemmF
290 292
291instance Element Double where 293instance Element Double where
292 transdata = transdataAux ctransR 294 transdata = transdataAux ctransR
@@ -299,7 +301,7 @@ instance Element Double where
299 selectV = selectD 301 selectV = selectD
300 remapM = remapD 302 remapM = remapD
301 rowOp = rowOpAux c_rowOpD 303 rowOp = rowOpAux c_rowOpD
302 304 gemm = gemmg c_gemmD
303 305
304instance Element (Complex Float) where 306instance Element (Complex Float) where
305 transdata = transdataAux ctransQ 307 transdata = transdataAux ctransQ
@@ -312,7 +314,7 @@ instance Element (Complex Float) where
312 selectV = selectQ 314 selectV = selectQ
313 remapM = remapQ 315 remapM = remapQ
314 rowOp = rowOpAux c_rowOpQ 316 rowOp = rowOpAux c_rowOpQ
315 317 gemm = gemmg c_gemmQ
316 318
317instance Element (Complex Double) where 319instance Element (Complex Double) where
318 transdata = transdataAux ctransC 320 transdata = transdataAux ctransC
@@ -325,6 +327,7 @@ instance Element (Complex Double) where
325 selectV = selectC 327 selectV = selectC
326 remapM = remapC 328 remapM = remapC
327 rowOp = rowOpAux c_rowOpC 329 rowOp = rowOpAux c_rowOpC
330 gemm = gemmg c_gemmC
328 331
329instance Element (CInt) where 332instance Element (CInt) where
330 transdata = transdataAux ctransI 333 transdata = transdataAux ctransI
@@ -337,6 +340,7 @@ instance Element (CInt) where
337 selectV = selectI 340 selectV = selectI
338 remapM = remapI 341 remapM = remapI
339 rowOp = rowOpAux c_rowOpI 342 rowOp = rowOpAux c_rowOpI
343 gemm = gemmg c_gemmI
340 344
341instance Element Z where 345instance Element Z where
342 transdata = transdataAux ctransL 346 transdata = transdataAux ctransL
@@ -349,6 +353,7 @@ instance Element Z where
349 selectV = selectL 353 selectV = selectL
350 remapM = remapL 354 remapM = remapL
351 rowOp = rowOpAux c_rowOpL 355 rowOp = rowOpAux c_rowOpL
356 gemm = gemmg c_gemmL
352 357
353------------------------------------------------------------------- 358-------------------------------------------------------------------
354 359
@@ -575,6 +580,21 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
575 580
576-------------------------------------------------------------------------------- 581--------------------------------------------------------------------------------
577 582
583gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg"
584
585type Tgemm x = x :> I :> x ::> x ::> x ::> Ok
586
587foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
588foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
589foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C
590foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float)
591foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I
592foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z
593foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I
594foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
595
596--------------------------------------------------------------------------------
597
578foreign import ccall unsafe "saveMatrix" c_saveMatrix 598foreign import ccall unsafe "saveMatrix" c_saveMatrix
579 :: CString -> CString -> Double ..> Ok 599 :: CString -> CString -> Double ..> Ok
580 600
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 6c6d5c5..1cae1ac 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -34,7 +34,9 @@ import Internal.Container
34import Internal.Vectorized (prodI,sumI,prodL,sumL) 34import Internal.Vectorized (prodI,sumI,prodL,sumL)
35import Internal.LAPACK (multiplyI, multiplyL) 35import Internal.LAPACK (multiplyI, multiplyL)
36import Internal.Algorithms(luFact) 36import Internal.Algorithms(luFact)
37import Internal.Util(Normed(..),Indexable(..),gaussElim, gaussElim_1, gaussElim_2,luST, magnit) 37import Internal.Util(Normed(..),Indexable(..),
38 gaussElim, gaussElim_1, gaussElim_2,
39 luST, luSolve', luPacked', magnit)
38import Internal.ST(mutable) 40import Internal.ST(mutable)
39import GHC.TypeLits 41import GHC.TypeLits
40import Data.Proxy(Proxy) 42import Data.Proxy(Proxy)
@@ -131,6 +133,9 @@ instance KnownNat m => Element (Mod m I)
131 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x) 133 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpMI m') c (unMod a) i1 i2 j1 j2 (f2iM x)
132 where 134 where
133 m' = fromIntegral . natVal $ (undefined :: Proxy m) 135 m' = fromIntegral . natVal $ (undefined :: Proxy m)
136 gemm u p a b c = gemmg (c_gemmMI m') (f2i u) p (f2iM a) (f2iM b) (f2iM c)
137 where
138 m' = fromIntegral . natVal $ (undefined :: Proxy m)
134 139
135instance KnownNat m => Element (Mod m Z) 140instance KnownNat m => Element (Mod m Z)
136 where 141 where
@@ -146,6 +151,9 @@ instance KnownNat m => Element (Mod m Z)
146 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x) 151 rowOp c a i1 i2 j1 j2 x = rowOpAux (c_rowOpML m') c (unMod a) i1 i2 j1 j2 (f2iM x)
147 where 152 where
148 m' = fromIntegral . natVal $ (undefined :: Proxy m) 153 m' = fromIntegral . natVal $ (undefined :: Proxy m)
154 gemm u p a b c = gemmg (c_gemmML m') (f2i u) p (f2iM a) (f2iM b) (f2iM c)
155 where
156 m' = fromIntegral . natVal $ (undefined :: Proxy m)
149 157
150 158
151instance forall m . KnownNat m => Container Vector (Mod m I) 159instance forall m . KnownNat m => Container Vector (Mod m I)
@@ -344,7 +352,11 @@ test = (ok, info)
344 lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z 352 lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z
345 lgm = fromZ lg :: Matrix (Mod 10000000000 Z) 353 lgm = fromZ lg :: Matrix (Mod 10000000000 Z)
346 354
347 gen n = diagRect 1 (konst 5 n) n n :: Numeric t => Matrix t 355 gen n = diagRect 1 (konst 5 n) n n :: Numeric t => Matrix t
356
357 rgen n = gen n :: Matrix R
358 cgen n = complex (rgen n) + fliprl (complex (rgen n)) * scalar (0:+1) :: Matrix C
359 sgen n = single (cgen n)
348 360
349 checkGen x = norm_Inf $ flatten $ invg x <> x - ident (rows x) 361 checkGen x = norm_Inf $ flatten $ invg x <> x - ident (rows x)
350 362
@@ -354,6 +366,11 @@ test = (ok, info)
354 where 366 where
355 (l,u,p,_ :: Int) = luFact $ mutable (luST okf) t 367 (l,u,p,_ :: Int) = luFact $ mutable (luST okf) t
356 368
369 checkSolve aa = norm_Inf $ flatten (aa <> x - bb)
370 where
371 bb = flipud aa
372 x = luSolve' (luPacked' aa) bb
373
357 info = do 374 info = do
358 print v 375 print v
359 print m 376 print m
@@ -377,9 +394,9 @@ test = (ok, info)
377 print $ lgm <> lgm 394 print $ lgm <> lgm
378 395
379 print (checkGen (gen 5 :: Matrix R)) 396 print (checkGen (gen 5 :: Matrix R))
380 print (checkGen (gen 5 :: Matrix C))
381 print (checkGen (gen 5 :: Matrix Float)) 397 print (checkGen (gen 5 :: Matrix Float))
382 print (checkGen (gen 5 :: Matrix (Complex Float))) 398 print (checkGen (cgen 5 :: Matrix C))
399 print (checkGen (sgen 5 :: Matrix (Complex Float)))
383 print (invg (gen 5) :: Matrix (Mod 7 I)) 400 print (invg (gen 5) :: Matrix (Mod 7 I))
384 print (invg (gen 5) :: Matrix (Mod 7 Z)) 401 print (invg (gen 5) :: Matrix (Mod 7 Z))
385 402
@@ -388,11 +405,18 @@ test = (ok, info)
388 405
389 print $ checkLU (magnit 0) (gen 5 :: Matrix R) 406 print $ checkLU (magnit 0) (gen 5 :: Matrix R)
390 print $ checkLU (magnit 0) (gen 5 :: Matrix Float) 407 print $ checkLU (magnit 0) (gen 5 :: Matrix Float)
391 print $ checkLU (magnit 0) (gen 5 :: Matrix C) 408 print $ checkLU (magnit 0) (cgen 5 :: Matrix C)
392 print $ checkLU (magnit 0) (gen 5 :: Matrix (Complex Float)) 409 print $ checkLU (magnit 0) (sgen 5 :: Matrix (Complex Float))
393 print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I)) 410 print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))
394 print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z)) 411 print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))
395 412
413 print $ checkSolve (gen 5 :: Matrix R)
414 print $ checkSolve (gen 5 :: Matrix Float)
415 print $ checkSolve (cgen 5 :: Matrix C)
416 print $ checkSolve (sgen 5 :: Matrix (Complex Float))
417 print $ checkSolve (gen 5 :: Matrix (Mod 7 I))
418 print $ checkSolve (gen 5 :: Matrix (Mod 7 Z))
419
396 420
397 ok = and 421 ok = and
398 [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) 422 [ toInt (m #> v) == cmod 11 (toInt m #> toInt v )
@@ -401,16 +425,22 @@ test = (ok, info)
401 , am <> gaussElim am bm == bm 425 , am <> gaussElim am bm == bm
402 , (checkGen (gen 5 :: Matrix R)) < 1E-15 426 , (checkGen (gen 5 :: Matrix R)) < 1E-15
403 , (checkGen (gen 5 :: Matrix Float)) < 2E-7 427 , (checkGen (gen 5 :: Matrix Float)) < 2E-7
404 , (checkGen (gen 5 :: Matrix C)) < 1E-15 428 , (checkGen (cgen 5 :: Matrix C)) < 1E-15
405 , (checkGen (gen 5 :: Matrix (Complex Float))) < 2E-7 429 , (checkGen (sgen 5 :: Matrix (Complex Float))) < 2E-7
406 , (checkGen (gen 5 :: Matrix (Mod 7 I))) == 0 430 , (checkGen (gen 5 :: Matrix (Mod 7 I))) == 0
407 , (checkGen (gen 5 :: Matrix (Mod 7 Z))) == 0 431 , (checkGen (gen 5 :: Matrix (Mod 7 Z))) == 0
408 , (checkLU (magnit 1E-10) (gen 5 :: Matrix R)) < 1E-15 432 , (checkLU (magnit 1E-10) (gen 5 :: Matrix R)) < 2E-15
409 , (checkLU (magnit 1E-5) (gen 5 :: Matrix Float)) < 1E-6 433 , (checkLU (magnit 1E-5) (gen 5 :: Matrix Float)) < 1E-6
410 , (checkLU (magnit 1E-10) (gen 5 :: Matrix C)) < 1E-15 434 , (checkLU (magnit 1E-10) (cgen 5 :: Matrix C)) < 5E-15
411 , (checkLU (magnit 1E-5) (gen 5 :: Matrix (Complex Float))) < 1E-6 435 , (checkLU (magnit 1E-5) (sgen 5 :: Matrix (Complex Float))) < 1E-6
412 , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))) == 0 436 , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))) == 0
413 , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))) == 0 437 , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))) == 0
438 , checkSolve (gen 5 :: Matrix R) < 2E-15
439 , checkSolve (gen 5 :: Matrix Float) < 1E-6
440 , checkSolve (cgen 5 :: Matrix C) < 4E-15
441 , checkSolve (sgen 5 :: Matrix (Complex Float)) < 1E-6
442 , checkSolve (gen 5 :: Matrix (Mod 7 I)) == 0
443 , checkSolve (gen 5 :: Matrix (Mod 7 Z)) == 0
414 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) 444 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I))
415 , gm <> gm == konst 0 (3,3) 445 , gm <> gm == konst 0 (3,3)
416 , lgm <> lgm == konst 0 (3,3) 446 , lgm <> lgm == konst 0 (3,3)
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 434fe63..d1defda 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -1,5 +1,6 @@
1{-# LANGUAGE Rank2Types #-} 1{-# LANGUAGE Rank2Types #-}
2{-# LANGUAGE BangPatterns #-} 2{-# LANGUAGE BangPatterns #-}
3{-# LANGUAGE ViewPatterns #-}
3 4
4----------------------------------------------------------------------------- 5-----------------------------------------------------------------------------
5-- | 6-- |
@@ -15,14 +16,14 @@
15----------------------------------------------------------------------------- 16-----------------------------------------------------------------------------
16 17
17module Internal.ST ( 18module Internal.ST (
19 ST, runST,
18 -- * Mutable Vectors 20 -- * Mutable Vectors
19 STVector, newVector, thawVector, freezeVector, runSTVector, 21 STVector, newVector, thawVector, freezeVector, runSTVector,
20 readVector, writeVector, modifyVector, liftSTVector, 22 readVector, writeVector, modifyVector, liftSTVector,
21 -- * Mutable Matrices 23 -- * Mutable Matrices
22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 24 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 25 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24-- axpy, scal, swap, rowOp, 26 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..),
25 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..),
26 -- * Unsafe functions 27 -- * Unsafe functions
27 newUndefinedVector, 28 newUndefinedVector,
28 unsafeReadVector, unsafeWriteVector, 29 unsafeReadVector, unsafeWriteVector,
@@ -70,13 +71,13 @@ unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
70modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s () 71modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
71modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k 72modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
72 73
73liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a 74liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s t -> ST s a
74liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x 75liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
75 76
76freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) 77freezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
77freezeVector v = liftSTVector id v 78freezeVector v = liftSTVector id v
78 79
79unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t) 80unsafeFreezeVector :: (Storable t) => STVector s t -> ST s (Vector t)
80unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x 81unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
81 82
82{-# INLINE safeIndexV #-} 83{-# INLINE safeIndexV #-}
@@ -139,14 +140,14 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
139modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () 140modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
140modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c 141modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
141 142
142liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a 143liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a
143liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x 144liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
144 145
145unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 146unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
146unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 147unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
147 148
148 149
149freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t) 150freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
150freezeMatrix m = liftSTMatrix id m 151freezeMatrix m = liftSTMatrix id m
151 152
152cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o) 153cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
@@ -227,6 +228,18 @@ extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (i
227 (i1,i2) = getRowRange (rows m) rr 228 (i1,i2) = getRowRange (rows m) rr
228 (j1,j2) = getColRange (cols m) rc 229 (j1,j2) = getColRange (cols m) rc
229 230
231-- | r0 c0 height width
232data Slice s t = Slice (STMatrix s t) Int Int Int Int
233
234slice (Slice (STMatrix m) r0 c0 nr nc) = (m, idxs[r0,r0+nr-1,c0,c0+nc-1])
235
236gemmm :: Element t => t -> Slice s t -> t -> Slice s t -> Slice s t -> ST s ()
237gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res
238 where
239 res = unsafeIOToST (gemm u v a b r)
240 u = fromList [alpha,beta]
241 v = vjoin[pa,pb,pr]
242
230 243
231mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 244mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
232mutable f a = runST $ do 245mutable f a = runST $ do
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index f08f710..079663d 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -1,6 +1,5 @@
1{-# LANGUAGE FlexibleContexts #-} 1{-# LANGUAGE FlexibleContexts #-}
2{-# LANGUAGE FlexibleInstances #-} 2{-# LANGUAGE FlexibleInstances #-}
3{-# LANGUAGE TypeFamilies #-}
4{-# LANGUAGE MultiParamTypeClasses #-} 3{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-} 4{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE ViewPatterns #-} 5{-# LANGUAGE ViewPatterns #-}
@@ -55,7 +54,7 @@ module Internal.Util(
55 -- ** 2D 54 -- ** 2D
56 corr2, conv2, separable, 55 corr2, conv2, separable,
57 block2x2,block3x3,view1,unView1,foldMatrix, 56 block2x2,block3x3,view1,unView1,foldMatrix,
58 gaussElim_1, gaussElim_2, gaussElim, luST 57 gaussElim_1, gaussElim_2, gaussElim, luST, luSolve', luSolve'', luPacked'
59) where 58) where
60 59
61import Internal.Vector 60import Internal.Vector
@@ -65,7 +64,7 @@ import Internal.Element
65import Internal.Container 64import Internal.Container
66import Internal.Vectorized 65import Internal.Vectorized
67import Internal.IO 66import Internal.IO
68import Internal.Algorithms hiding (i,Normed,swap,linearSolve') 67import Internal.Algorithms hiding (Normed,linearSolve',luSolve', luPacked')
69import Numeric.Matrix() 68import Numeric.Matrix()
70import Numeric.Vector() 69import Numeric.Vector()
71import Internal.Random 70import Internal.Random
@@ -73,7 +72,7 @@ import Internal.Convolution
73import Control.Monad(when,forM_) 72import Control.Monad(when,forM_)
74import Text.Printf 73import Text.Printf
75import Data.List.Split(splitOn) 74import Data.List.Split(splitOn)
76import Data.List(intercalate,sortBy) 75import Data.List(intercalate,sortBy,foldl')
77import Control.Arrow((&&&)) 76import Control.Arrow((&&&))
78import Data.Complex 77import Data.Complex
79import Data.Function(on) 78import Data.Function(on)
@@ -687,6 +686,115 @@ luST ok (r,_) x = do
687 v <- unsafeFreezeVector p 686 v <- unsafeFreezeVector p
688 return (toList v) 687 return (toList v)
689 688
689{- | Experimental implementation of 'luPacked'
690 for any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'.
691
692>>> let m = ident 5 + (5><5) [0..] :: Matrix (Z ./. 17)
693(5><5)
694 [ 1, 1, 2, 3, 4
695 , 5, 7, 7, 8, 9
696 , 10, 11, 13, 13, 14
697 , 15, 16, 0, 2, 2
698 , 3, 4, 5, 6, 8 ]
699
700>>> let (l,u,p,s) = luFact $ luPacked' m
701>>> l
702(5><5)
703 [ 1, 0, 0, 0, 0
704 , 6, 1, 0, 0, 0
705 , 12, 7, 1, 0, 0
706 , 7, 10, 7, 1, 0
707 , 8, 2, 6, 11, 1 ]
708>>> u
709(5><5)
710 [ 15, 16, 0, 2, 2
711 , 0, 13, 7, 13, 14
712 , 0, 0, 15, 0, 11
713 , 0, 0, 0, 15, 15
714 , 0, 0, 0, 0, 1 ]
715
716-}
717luPacked' x = mutable (luST (magnit 0)) x
718
719--------------------------------------------------------------------------------
720
721rowRange m = [0..rows m -1]
722
723at k = Pos (idxs[k])
724
725backSust' lup rhs = foldl' f (rhs?[]) (reverse ls)
726 where
727 ls = [ (d k , u k , b k) | k <- rowRange lup ]
728 where
729 d k = lup ?? (at k, at k)
730 u k = lup ?? (at k, Drop (k+1))
731 b k = rhs ?? (at k, All)
732
733 f x (d,u,b) = (b - u<>x) / d
734 ===
735 x
736
737
738forwSust' lup rhs = foldl' f (rhs?[]) ls
739 where
740 ls = [ (l k , b k) | k <- rowRange lup ]
741 where
742 l k = lup ?? (at k, Take k)
743 b k = rhs ?? (at k, All)
744
745 f x (l,b) = x
746 ===
747 (b - l<>x)
748
749
750luSolve'' (lup,p) b = backSust' lup (forwSust' lup pb)
751 where
752 pb = b ?? (Pos (fixPerm' p), All)
753
754--------------------------------------------------------------------------------
755
756forwSust lup rhs = fst $ mutable f rhs
757 where
758 f (r,c) x = do
759 l <- unsafeThawMatrix lup
760 let go k = gemmm 1 (Slice x k 0 1 c) (-1) (Slice l k 0 1 k) (Slice x 0 0 k c)
761 mapM_ go [0..r-1]
762
763
764backSust lup rhs = fst $ mutable f rhs
765 where
766 f (r,c) m = do
767 l <- unsafeThawMatrix lup
768 let d k = recip (lup `atIndex` (k,k))
769 u k = Slice l k (k+1) 1 (r-1-k)
770 b k = Slice m k 0 1 c
771 x k = Slice m (k+1) 0 (r-1-k) c
772 scal k = rowOper (SCAL (d k) (Row k) AllCols) m
773
774 go k = gemmm 1 (b k) (-1) (u k) (x k) >> scal k
775 mapM_ go [r-1,r-2..0]
776
777
778{- | Experimental implementation of 'luSolve' for any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'.
779
780>>> let a = (2><2) [1,2,3,5] :: Matrix (Z ./. 13)
781(2><2)
782 [ 1, 2
783 , 3, 5 ]
784>>> b
785(2><3)
786 [ 5, 1, 3
787 , 8, 6, 3 ]
788
789>>> luSolve' (luPacked' a) b
790(2><3)
791 [ 4, 7, 4
792 , 7, 10, 6 ]
793
794-}
795luSolve' (lup,p) b = backSust lup (forwSust lup pb)
796 where
797 pb = b ?? (Pos (fixPerm' p), All)
690 798
691-------------------------------------------------------------------------------- 799--------------------------------------------------------------------------------
692 800
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index 2e6b8ca..0b8abbb 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -77,10 +77,10 @@ module Numeric.LinearAlgebra (
77 linearSolveLS, 77 linearSolveLS,
78 linearSolveSVD, 78 linearSolveSVD,
79 luSolve, 79 luSolve,
80 luSolve',
80 cholSolve, 81 cholSolve,
81 cgSolve, 82 cgSolve,
82 cgSolve', 83 cgSolve',
83 linearSolve',
84 84
85 -- * Inverse and pseudoinverse 85 -- * Inverse and pseudoinverse
86 inv, pinv, pinvTol, 86 inv, pinv, pinvTol,
@@ -122,7 +122,7 @@ module Numeric.LinearAlgebra (
122 schur, 122 schur,
123 123
124 -- * LU 124 -- * LU
125 lu, luPacked, luFact, luPacked', 125 lu, luPacked, luPacked', luFact,
126 126
127 -- * Matrix functions 127 -- * Matrix functions
128 expm, 128 expm,
@@ -158,14 +158,13 @@ import Numeric.Vector()
158import Internal.Matrix 158import Internal.Matrix
159import Internal.Container hiding ((<>)) 159import Internal.Container hiding ((<>))
160import Internal.Numeric hiding (mul) 160import Internal.Numeric hiding (mul)
161import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked',linearSolve') 161import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked',linearSolve',luSolve')
162import qualified Internal.Algorithms as A 162import qualified Internal.Algorithms as A
163import Internal.Util 163import Internal.Util
164import Internal.Random 164import Internal.Random
165import Internal.Sparse((!#>)) 165import Internal.Sparse((!#>))
166import Internal.CG 166import Internal.CG
167import Internal.Conversion 167import Internal.Conversion
168import Internal.ST(mutable)
169 168
170{- | infix synonym of 'mul' 169{- | infix synonym of 'mul'
171 170
@@ -240,53 +239,4 @@ nullspace m = nullspaceSVD (Left (1*eps)) m (rightSV m)
240-- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. 239-- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'.
241orth m = orthSVD (Left (1*eps)) m (leftSV m) 240orth m = orthSVD (Left (1*eps)) m (leftSV m)
242 241
243{- | Experimental implementation of LU factorization
244 working on any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'.
245
246>>> let m = ident 5 + (5><5) [0..] :: Matrix (Z ./. 17)
247(5><5)
248 [ 1, 1, 2, 3, 4
249 , 5, 7, 7, 8, 9
250 , 10, 11, 13, 13, 14
251 , 15, 16, 0, 2, 2
252 , 3, 4, 5, 6, 8 ]
253
254>>> let (l,u,p,s) = luFact $ luPacked' m
255>>> l
256(5><5)
257 [ 1, 0, 0, 0, 0
258 , 6, 1, 0, 0, 0
259 , 12, 7, 1, 0, 0
260 , 7, 10, 7, 1, 0
261 , 8, 2, 6, 11, 1 ]
262>>> u
263(5><5)
264 [ 15, 16, 0, 2, 2
265 , 0, 13, 7, 13, 14
266 , 0, 0, 15, 0, 11
267 , 0, 0, 0, 15, 15
268 , 0, 0, 0, 0, 1 ]
269
270-}
271luPacked' x = mutable (luST (magnit 0)) x
272
273{- | Experimental implementation of gaussian elimination
274 working on any Fractional element type, including 'Mod' n 'I' and 'Mod' n 'Z'.
275
276>>> let a = (2><2) [1,2,3,5] :: Matrix (Z ./. 13)
277(2><2)
278 [ 1, 2
279 , 3, 5 ]
280>>> b
281(2><3)
282 [ 5, 1, 3
283 , 8, 6, 3 ]
284
285>>> let x = linearSolve' a b
286(2><3)
287 [ 4, 7, 4
288 , 7, 10, 6 ]
289
290-}
291linearSolve' x y = gaussElim x y
292 242
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
index 36c5f03..db4236b 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
@@ -43,7 +43,7 @@ module Numeric.LinearAlgebra.Devel(
43 -- ** Mutable Matrices 43 -- ** Mutable Matrices
44 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 44 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
45 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 45 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
46 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), 46 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), gemmm, Slice(..),
47 -- ** Unsafe functions 47 -- ** Unsafe functions
48 newUndefinedVector, 48 newUndefinedVector,
49 unsafeReadVector, unsafeWriteVector, 49 unsafeReadVector, unsafeWriteVector,