summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2017-03-19 20:37:07 +0100
committerGitHub <noreply@github.com>2017-03-19 20:37:07 +0100
commitddae74c9f73a1d7fcb8ad00bb74ee77ac3d01086 (patch)
treea48fc977c78fb3886e08dd98375e82e0cd28b946 /packages/base/src
parent76fee2219280b40b994796d1cfb4b2813618e863 (diff)
parentfa1642dcf26f1da0a6f4c1324bcd1e8baf9fd478 (diff)
Merge pull request #223 from idontgetoutmuch/master
Support triangular matrices.
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Internal/Algorithms.hs14
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c84
-rw-r--r--packages/base/src/Internal/LAPACK.hs30
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs3
4 files changed, 129 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs
index 70d65d7..23a5e13 100644
--- a/packages/base/src/Internal/Algorithms.hs
+++ b/packages/base/src/Internal/Algorithms.hs
@@ -20,13 +20,16 @@ imported from "Numeric.LinearAlgebra.LAPACK".
20-} 20-}
21----------------------------------------------------------------------------- 21-----------------------------------------------------------------------------
22 22
23module Internal.Algorithms where 23module Internal.Algorithms (
24 module Internal.Algorithms,
25 UpLo(..)
26) where
24 27
25import Internal.Vector 28import Internal.Vector
26import Internal.Matrix 29import Internal.Matrix
27import Internal.Element 30import Internal.Element
28import Internal.Conversion 31import Internal.Conversion
29import Internal.LAPACK as LAPACK 32import Internal.LAPACK
30import Internal.Numeric 33import Internal.Numeric
31import Data.List(foldl1') 34import Data.List(foldl1')
32import qualified Data.Array as A 35import qualified Data.Array as A
@@ -58,6 +61,7 @@ class (Numeric t,
58 mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t) 61 mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t)
59 linearSolve' :: Matrix t -> Matrix t -> Matrix t 62 linearSolve' :: Matrix t -> Matrix t -> Matrix t
60 cholSolve' :: Matrix t -> Matrix t -> Matrix t 63 cholSolve' :: Matrix t -> Matrix t -> Matrix t
64 triSolve' :: UpLo -> Matrix t -> Matrix t -> Matrix t
61 ldlPacked' :: Matrix t -> (Matrix t, [Int]) 65 ldlPacked' :: Matrix t -> (Matrix t, [Int])
62 ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t 66 ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t
63 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t 67 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t
@@ -83,6 +87,7 @@ instance Field Double where
83 linearSolve' = linearSolveR -- (luSolve . luPacked) ?? 87 linearSolve' = linearSolveR -- (luSolve . luPacked) ??
84 mbLinearSolve' = mbLinearSolveR 88 mbLinearSolve' = mbLinearSolveR
85 cholSolve' = cholSolveR 89 cholSolve' = cholSolveR
90 triSolve' = triSolveR
86 linearSolveLS' = linearSolveLSR 91 linearSolveLS' = linearSolveLSR
87 linearSolveSVD' = linearSolveSVDR Nothing 92 linearSolveSVD' = linearSolveSVDR Nothing
88 eig' = eigR 93 eig' = eigR
@@ -112,6 +117,7 @@ instance Field (Complex Double) where
112 linearSolve' = linearSolveC 117 linearSolve' = linearSolveC
113 mbLinearSolve' = mbLinearSolveC 118 mbLinearSolve' = mbLinearSolveC
114 cholSolve' = cholSolveC 119 cholSolve' = cholSolveC
120 triSolve' = triSolveC
115 linearSolveLS' = linearSolveLSC 121 linearSolveLS' = linearSolveLSC
116 linearSolveSVD' = linearSolveSVDC Nothing 122 linearSolveSVD' = linearSolveSVDC Nothing
117 eig' = eigC 123 eig' = eigC
@@ -350,6 +356,10 @@ cholSolve
350 -> Matrix t -- ^ solution 356 -> Matrix t -- ^ solution
351cholSolve = {-# SCC "cholSolve" #-} cholSolve' 357cholSolve = {-# SCC "cholSolve" #-} cholSolve'
352 358
359-- | Solve a triangular linear system.
360triSolve :: Field t => UpLo -> Matrix t -> Matrix t -> Matrix t
361triSolve = {-# SCC "triSolve" #-} triSolve'
362
353-- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. 363-- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value.
354linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t 364linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t
355linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' 365linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD'
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index ff7ad92..4a8129c 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -584,6 +584,90 @@ int cholSolveC_l(KOCMAT(a),OCMAT(b)) {
584 OK 584 OK
585} 585}
586 586
587//////// triangular real linear system ////////////
588
589int dtrtrs_(char *uplo, char *trans, char *diag, integer *n, integer *nrhs,
590 doublereal *a, integer *lda, doublereal *b, integer *ldb, integer *
591 info);
592
593int triSolveR_l_u(KODMAT(a),ODMAT(b)) {
594 integer n = ar;
595 integer lda = aXc;
596 integer nhrs = bc;
597 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
598 DEBUGMSG("triSolveR_l_u");
599 integer res;
600 dtrtrs_ ("U",
601 "N",
602 "N",
603 &n,&nhrs,
604 (double*)ap, &lda,
605 bp, &n,
606 &res);
607 CHECK(res,res);
608 OK
609}
610
611int triSolveR_l_l(KODMAT(a),ODMAT(b)) {
612 integer n = ar;
613 integer lda = aXc;
614 integer nhrs = bc;
615 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
616 DEBUGMSG("triSolveR_l_l");
617 integer res;
618 dtrtrs_ ("L",
619 "N",
620 "N",
621 &n,&nhrs,
622 (double*)ap, &lda,
623 bp, &n,
624 &res);
625 CHECK(res,res);
626 OK
627}
628
629//////// triangular complex linear system ////////////
630
631int ztrtrs_(char *uplo, char *trans, char *diag, integer *n, integer *nrhs,
632 doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb,
633 integer *info);
634
635int triSolveC_l_u(KOCMAT(a),OCMAT(b)) {
636 integer n = ar;
637 integer lda = aXc;
638 integer nhrs = bc;
639 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
640 DEBUGMSG("triSolveC_l_u");
641 integer res;
642 ztrtrs_ ("U",
643 "N",
644 "N",
645 &n,&nhrs,
646 (doublecomplex*)ap, &lda,
647 bp, &n,
648 &res);
649 CHECK(res,res);
650 OK
651}
652
653int triSolveC_l_l(KOCMAT(a),OCMAT(b)) {
654 integer n = ar;
655 integer lda = aXc;
656 integer nhrs = bc;
657 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
658 DEBUGMSG("triSolveC_l_u");
659 integer res;
660 ztrtrs_ ("L",
661 "N",
662 "N",
663 &n,&nhrs,
664 (doublecomplex*)ap, &lda,
665 bp, &n,
666 &res);
667 CHECK(res,res);
668 OK
669}
670
587//////////////////// least squares real linear system //////////// 671//////////////////// least squares real linear system ////////////
588 672
589int dgels_(char *trans, integer *m, integer *n, integer * 673int dgels_(char *trans, integer *m, integer *n, integer *
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index 231109a..b4dd5cf 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -406,6 +406,36 @@ cholSolveR a b = linearSolveSQAux2 id dpotrs "cholSolveR" (fmat a) b
406cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 406cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
407cholSolveC a b = linearSolveSQAux2 id zpotrs "cholSolveC" (fmat a) b 407cholSolveC a b = linearSolveSQAux2 id zpotrs "cholSolveC" (fmat a) b
408 408
409--------------------------------------------------------------------------------
410foreign import ccall unsafe "triSolveR_l_u" dtrtrs_u :: R ::> R ::> Ok
411foreign import ccall unsafe "triSolveC_l_u" ztrtrs_u :: C ::> C ::> Ok
412foreign import ccall unsafe "triSolveR_l_l" dtrtrs_l :: R ::> R ::> Ok
413foreign import ccall unsafe "triSolveC_l_l" ztrtrs_l :: C ::> C ::> Ok
414
415
416linearSolveTRAux2 g f st a b
417 | n1==n2 && n1==r = unsafePerformIO . g $ do
418 s <- copy ColumnMajor b
419 (a #! s) f #| st
420 return s
421 | otherwise = error $ st ++ " of nonsquare matrix"
422 where
423 n1 = rows a
424 n2 = cols a
425 r = rows b
426
427data UpLo = Lower | Upper
428
429-- | Solves a triangular system of linear equations.
430triSolveR :: UpLo -> Matrix Double -> Matrix Double -> Matrix Double
431triSolveR Lower a b = linearSolveTRAux2 id dtrtrs_l "triSolveR" (fmat a) b
432triSolveR Upper a b = linearSolveTRAux2 id dtrtrs_u "triSolveR" (fmat a) b
433
434-- | Solves a triangular system of linear equations.
435triSolveC :: UpLo -> Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
436triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b
437triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b
438
409----------------------------------------------------------------------------------- 439-----------------------------------------------------------------------------------
410 440
411foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok 441foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index badf8f9..869330c 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -94,6 +94,9 @@ module Numeric.LinearAlgebra (
94 ldlSolve, ldlPacked, 94 ldlSolve, ldlPacked,
95 -- ** Positive definite 95 -- ** Positive definite
96 cholSolve, 96 cholSolve,
97 -- ** Triangular
98 UpLo(..),
99 triSolve,
97 -- ** Sparse 100 -- ** Sparse
98 cgSolve, 101 cgSolve,
99 cgSolve', 102 cgSolve',