diff options
author | Dominic Steinitz <dominic@steinitz.org> | 2017-03-17 14:20:07 +0000 |
---|---|---|
committer | Dominic Steinitz <dominic@steinitz.org> | 2017-03-17 14:20:07 +0000 |
commit | fa1642dcf26f1da0a6f4c1324bcd1e8baf9fd478 (patch) | |
tree | 356a1c759bd5f54f20399e57ff1f99afca14733c /packages | |
parent | d2d0066d2ff3d8e66ce902ee1b9d1317f1710a2c (diff) |
Support triangular matrices.
Diffstat (limited to 'packages')
-rw-r--r-- | packages/base/src/Internal/Algorithms.hs | 14 | ||||
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 84 | ||||
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 30 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra.hs | 3 |
4 files changed, 129 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index 70d65d7..23a5e13 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs | |||
@@ -20,13 +20,16 @@ imported from "Numeric.LinearAlgebra.LAPACK". | |||
20 | -} | 20 | -} |
21 | ----------------------------------------------------------------------------- | 21 | ----------------------------------------------------------------------------- |
22 | 22 | ||
23 | module Internal.Algorithms where | 23 | module Internal.Algorithms ( |
24 | module Internal.Algorithms, | ||
25 | UpLo(..) | ||
26 | ) where | ||
24 | 27 | ||
25 | import Internal.Vector | 28 | import Internal.Vector |
26 | import Internal.Matrix | 29 | import Internal.Matrix |
27 | import Internal.Element | 30 | import Internal.Element |
28 | import Internal.Conversion | 31 | import Internal.Conversion |
29 | import Internal.LAPACK as LAPACK | 32 | import Internal.LAPACK |
30 | import Internal.Numeric | 33 | import Internal.Numeric |
31 | import Data.List(foldl1') | 34 | import Data.List(foldl1') |
32 | import qualified Data.Array as A | 35 | import qualified Data.Array as A |
@@ -58,6 +61,7 @@ class (Numeric t, | |||
58 | mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t) | 61 | mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t) |
59 | linearSolve' :: Matrix t -> Matrix t -> Matrix t | 62 | linearSolve' :: Matrix t -> Matrix t -> Matrix t |
60 | cholSolve' :: Matrix t -> Matrix t -> Matrix t | 63 | cholSolve' :: Matrix t -> Matrix t -> Matrix t |
64 | triSolve' :: UpLo -> Matrix t -> Matrix t -> Matrix t | ||
61 | ldlPacked' :: Matrix t -> (Matrix t, [Int]) | 65 | ldlPacked' :: Matrix t -> (Matrix t, [Int]) |
62 | ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t | 66 | ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t |
63 | linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t | 67 | linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t |
@@ -83,6 +87,7 @@ instance Field Double where | |||
83 | linearSolve' = linearSolveR -- (luSolve . luPacked) ?? | 87 | linearSolve' = linearSolveR -- (luSolve . luPacked) ?? |
84 | mbLinearSolve' = mbLinearSolveR | 88 | mbLinearSolve' = mbLinearSolveR |
85 | cholSolve' = cholSolveR | 89 | cholSolve' = cholSolveR |
90 | triSolve' = triSolveR | ||
86 | linearSolveLS' = linearSolveLSR | 91 | linearSolveLS' = linearSolveLSR |
87 | linearSolveSVD' = linearSolveSVDR Nothing | 92 | linearSolveSVD' = linearSolveSVDR Nothing |
88 | eig' = eigR | 93 | eig' = eigR |
@@ -112,6 +117,7 @@ instance Field (Complex Double) where | |||
112 | linearSolve' = linearSolveC | 117 | linearSolve' = linearSolveC |
113 | mbLinearSolve' = mbLinearSolveC | 118 | mbLinearSolve' = mbLinearSolveC |
114 | cholSolve' = cholSolveC | 119 | cholSolve' = cholSolveC |
120 | triSolve' = triSolveC | ||
115 | linearSolveLS' = linearSolveLSC | 121 | linearSolveLS' = linearSolveLSC |
116 | linearSolveSVD' = linearSolveSVDC Nothing | 122 | linearSolveSVD' = linearSolveSVDC Nothing |
117 | eig' = eigC | 123 | eig' = eigC |
@@ -350,6 +356,10 @@ cholSolve | |||
350 | -> Matrix t -- ^ solution | 356 | -> Matrix t -- ^ solution |
351 | cholSolve = {-# SCC "cholSolve" #-} cholSolve' | 357 | cholSolve = {-# SCC "cholSolve" #-} cholSolve' |
352 | 358 | ||
359 | -- | Solve a triangular linear system. | ||
360 | triSolve :: Field t => UpLo -> Matrix t -> Matrix t -> Matrix t | ||
361 | triSolve = {-# SCC "triSolve" #-} triSolve' | ||
362 | |||
353 | -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. | 363 | -- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. |
354 | linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t | 364 | linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t |
355 | linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' | 365 | linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' |
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index ff7ad92..4a8129c 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -584,6 +584,90 @@ int cholSolveC_l(KOCMAT(a),OCMAT(b)) { | |||
584 | OK | 584 | OK |
585 | } | 585 | } |
586 | 586 | ||
587 | //////// triangular real linear system //////////// | ||
588 | |||
589 | int dtrtrs_(char *uplo, char *trans, char *diag, integer *n, integer *nrhs, | ||
590 | doublereal *a, integer *lda, doublereal *b, integer *ldb, integer * | ||
591 | info); | ||
592 | |||
593 | int triSolveR_l_u(KODMAT(a),ODMAT(b)) { | ||
594 | integer n = ar; | ||
595 | integer lda = aXc; | ||
596 | integer nhrs = bc; | ||
597 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | ||
598 | DEBUGMSG("triSolveR_l_u"); | ||
599 | integer res; | ||
600 | dtrtrs_ ("U", | ||
601 | "N", | ||
602 | "N", | ||
603 | &n,&nhrs, | ||
604 | (double*)ap, &lda, | ||
605 | bp, &n, | ||
606 | &res); | ||
607 | CHECK(res,res); | ||
608 | OK | ||
609 | } | ||
610 | |||
611 | int triSolveR_l_l(KODMAT(a),ODMAT(b)) { | ||
612 | integer n = ar; | ||
613 | integer lda = aXc; | ||
614 | integer nhrs = bc; | ||
615 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | ||
616 | DEBUGMSG("triSolveR_l_l"); | ||
617 | integer res; | ||
618 | dtrtrs_ ("L", | ||
619 | "N", | ||
620 | "N", | ||
621 | &n,&nhrs, | ||
622 | (double*)ap, &lda, | ||
623 | bp, &n, | ||
624 | &res); | ||
625 | CHECK(res,res); | ||
626 | OK | ||
627 | } | ||
628 | |||
629 | //////// triangular complex linear system //////////// | ||
630 | |||
631 | int ztrtrs_(char *uplo, char *trans, char *diag, integer *n, integer *nrhs, | ||
632 | doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, | ||
633 | integer *info); | ||
634 | |||
635 | int triSolveC_l_u(KOCMAT(a),OCMAT(b)) { | ||
636 | integer n = ar; | ||
637 | integer lda = aXc; | ||
638 | integer nhrs = bc; | ||
639 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | ||
640 | DEBUGMSG("triSolveC_l_u"); | ||
641 | integer res; | ||
642 | ztrtrs_ ("U", | ||
643 | "N", | ||
644 | "N", | ||
645 | &n,&nhrs, | ||
646 | (doublecomplex*)ap, &lda, | ||
647 | bp, &n, | ||
648 | &res); | ||
649 | CHECK(res,res); | ||
650 | OK | ||
651 | } | ||
652 | |||
653 | int triSolveC_l_l(KOCMAT(a),OCMAT(b)) { | ||
654 | integer n = ar; | ||
655 | integer lda = aXc; | ||
656 | integer nhrs = bc; | ||
657 | REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE); | ||
658 | DEBUGMSG("triSolveC_l_u"); | ||
659 | integer res; | ||
660 | ztrtrs_ ("L", | ||
661 | "N", | ||
662 | "N", | ||
663 | &n,&nhrs, | ||
664 | (doublecomplex*)ap, &lda, | ||
665 | bp, &n, | ||
666 | &res); | ||
667 | CHECK(res,res); | ||
668 | OK | ||
669 | } | ||
670 | |||
587 | //////////////////// least squares real linear system //////////// | 671 | //////////////////// least squares real linear system //////////// |
588 | 672 | ||
589 | int dgels_(char *trans, integer *m, integer *n, integer * | 673 | int dgels_(char *trans, integer *m, integer *n, integer * |
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index 231109a..b4dd5cf 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -406,6 +406,36 @@ cholSolveR a b = linearSolveSQAux2 id dpotrs "cholSolveR" (fmat a) b | |||
406 | cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | 406 | cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) |
407 | cholSolveC a b = linearSolveSQAux2 id zpotrs "cholSolveC" (fmat a) b | 407 | cholSolveC a b = linearSolveSQAux2 id zpotrs "cholSolveC" (fmat a) b |
408 | 408 | ||
409 | -------------------------------------------------------------------------------- | ||
410 | foreign import ccall unsafe "triSolveR_l_u" dtrtrs_u :: R ::> R ::> Ok | ||
411 | foreign import ccall unsafe "triSolveC_l_u" ztrtrs_u :: C ::> C ::> Ok | ||
412 | foreign import ccall unsafe "triSolveR_l_l" dtrtrs_l :: R ::> R ::> Ok | ||
413 | foreign import ccall unsafe "triSolveC_l_l" ztrtrs_l :: C ::> C ::> Ok | ||
414 | |||
415 | |||
416 | linearSolveTRAux2 g f st a b | ||
417 | | n1==n2 && n1==r = unsafePerformIO . g $ do | ||
418 | s <- copy ColumnMajor b | ||
419 | (a #! s) f #| st | ||
420 | return s | ||
421 | | otherwise = error $ st ++ " of nonsquare matrix" | ||
422 | where | ||
423 | n1 = rows a | ||
424 | n2 = cols a | ||
425 | r = rows b | ||
426 | |||
427 | data UpLo = Lower | Upper | ||
428 | |||
429 | -- | Solves a triangular system of linear equations. | ||
430 | triSolveR :: UpLo -> Matrix Double -> Matrix Double -> Matrix Double | ||
431 | triSolveR Lower a b = linearSolveTRAux2 id dtrtrs_l "triSolveR" (fmat a) b | ||
432 | triSolveR Upper a b = linearSolveTRAux2 id dtrtrs_u "triSolveR" (fmat a) b | ||
433 | |||
434 | -- | Solves a triangular system of linear equations. | ||
435 | triSolveC :: UpLo -> Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) | ||
436 | triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b | ||
437 | triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b | ||
438 | |||
409 | ----------------------------------------------------------------------------------- | 439 | ----------------------------------------------------------------------------------- |
410 | 440 | ||
411 | foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok | 441 | foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok |
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index badf8f9..869330c 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -94,6 +94,9 @@ module Numeric.LinearAlgebra ( | |||
94 | ldlSolve, ldlPacked, | 94 | ldlSolve, ldlPacked, |
95 | -- ** Positive definite | 95 | -- ** Positive definite |
96 | cholSolve, | 96 | cholSolve, |
97 | -- ** Triangular | ||
98 | UpLo(..), | ||
99 | triSolve, | ||
97 | -- ** Sparse | 100 | -- ** Sparse |
98 | cgSolve, | 101 | cgSolve, |
99 | cgSolve', | 102 | cgSolve', |