summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2017-03-22 13:53:01 +0100
committerGitHub <noreply@github.com>2017-03-22 13:53:01 +0100
commitdb8efa9f0d46ee21f0dacdfe35c0d966d91d751d (patch)
treecfda55f02230554e9e6ecd78039d1b6ab3857672 /packages/base/src
parentddae74c9f73a1d7fcb8ad00bb74ee77ac3d01086 (diff)
parent49d718705d205d62aea2762445f95735a671f305 (diff)
Merge pull request #224 from idontgetoutmuch/master
Add tridiagonal solver and tests for it and triagonal solver.
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Internal/Algorithms.hs76
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c72
-rw-r--r--packages/base/src/Internal/LAPACK.hs22
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs2
4 files changed, 170 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs
index 23a5e13..99c9e34 100644
--- a/packages/base/src/Internal/Algorithms.hs
+++ b/packages/base/src/Internal/Algorithms.hs
@@ -62,6 +62,7 @@ class (Numeric t,
62 linearSolve' :: Matrix t -> Matrix t -> Matrix t 62 linearSolve' :: Matrix t -> Matrix t -> Matrix t
63 cholSolve' :: Matrix t -> Matrix t -> Matrix t 63 cholSolve' :: Matrix t -> Matrix t -> Matrix t
64 triSolve' :: UpLo -> Matrix t -> Matrix t -> Matrix t 64 triSolve' :: UpLo -> Matrix t -> Matrix t -> Matrix t
65 triDiagSolve' :: Vector t -> Vector t -> Vector t -> Matrix t -> Matrix t
65 ldlPacked' :: Matrix t -> (Matrix t, [Int]) 66 ldlPacked' :: Matrix t -> (Matrix t, [Int])
66 ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t 67 ldlSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t
67 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t 68 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t
@@ -88,6 +89,7 @@ instance Field Double where
88 mbLinearSolve' = mbLinearSolveR 89 mbLinearSolve' = mbLinearSolveR
89 cholSolve' = cholSolveR 90 cholSolve' = cholSolveR
90 triSolve' = triSolveR 91 triSolve' = triSolveR
92 triDiagSolve' = triDiagSolveR
91 linearSolveLS' = linearSolveLSR 93 linearSolveLS' = linearSolveLSR
92 linearSolveSVD' = linearSolveSVDR Nothing 94 linearSolveSVD' = linearSolveSVDR Nothing
93 eig' = eigR 95 eig' = eigR
@@ -118,6 +120,7 @@ instance Field (Complex Double) where
118 mbLinearSolve' = mbLinearSolveC 120 mbLinearSolve' = mbLinearSolveC
119 cholSolve' = cholSolveC 121 cholSolve' = cholSolveC
120 triSolve' = triSolveC 122 triSolve' = triSolveC
123 triDiagSolve' = triDiagSolveC
121 linearSolveLS' = linearSolveLSC 124 linearSolveLS' = linearSolveLSC
122 linearSolveSVD' = linearSolveSVDC Nothing 125 linearSolveSVD' = linearSolveSVDC Nothing
123 eig' = eigC 126 eig' = eigC
@@ -356,10 +359,79 @@ cholSolve
356 -> Matrix t -- ^ solution 359 -> Matrix t -- ^ solution
357cholSolve = {-# SCC "cholSolve" #-} cholSolve' 360cholSolve = {-# SCC "cholSolve" #-} cholSolve'
358 361
359-- | Solve a triangular linear system. 362-- | Solve a triangular linear system. If `Upper` is specified then
360triSolve :: Field t => UpLo -> Matrix t -> Matrix t -> Matrix t 363-- all elements below the diagonal are ignored; if `Lower` is
364-- specified then all elements above the diagonal are ignored.
365triSolve
366 :: Field t
367 => UpLo -- ^ `Lower` or `Upper`
368 -> Matrix t -- ^ coefficient matrix
369 -> Matrix t -- ^ right hand sides
370 -> Matrix t -- ^ solution
361triSolve = {-# SCC "triSolve" #-} triSolve' 371triSolve = {-# SCC "triSolve" #-} triSolve'
362 372
373-- | Solve a tridiagonal linear system. Suppose you wish to solve \(Ax = b\) where
374--
375-- \[
376-- A =
377-- \begin{bmatrix}
378-- 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0
379-- \\ 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0
380-- \\ 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0
381-- \\ 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0 & 0.0
382-- \\ 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0 & 0.0
383-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0 & 0.0
384-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0 & 0.0
385-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0 & 4.0
386-- \\ 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 3.0 & 1.0
387-- \end{bmatrix}
388-- \quad
389-- b =
390-- \begin{bmatrix}
391-- 1.0 & 1.0 & 1.0
392-- \\ 1.0 & -1.0 & 2.0
393-- \\ 1.0 & 1.0 & 3.0
394-- \\ 1.0 & -1.0 & 4.0
395-- \\ 1.0 & 1.0 & 5.0
396-- \\ 1.0 & -1.0 & 6.0
397-- \\ 1.0 & 1.0 & 7.0
398-- \\ 1.0 & -1.0 & 8.0
399-- \\ 1.0 & 1.0 & 9.0
400-- \end{bmatrix}
401-- \]
402--
403-- then
404--
405-- @
406-- dL = fromList [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0]
407-- d = fromList [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
408-- dU = fromList [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0]
409--
410-- b = (9><3)
411-- [
412-- 1.0, 1.0, 1.0,
413-- 1.0, -1.0, 2.0,
414-- 1.0, 1.0, 3.0,
415-- 1.0, -1.0, 4.0,
416-- 1.0, 1.0, 5.0,
417-- 1.0, -1.0, 6.0,
418-- 1.0, 1.0, 7.0,
419-- 1.0, -1.0, 8.0,
420-- 1.0, 1.0, 9.0
421-- ]
422--
423-- x = triDiagSolve dL d dU b
424-- @
425--
426triDiagSolve
427 :: Field t
428 => Vector t -- ^ lower diagonal: \(n - 1\) elements
429 -> Vector t -- ^ diagonal: \(n\) elements
430 -> Vector t -- ^ upper diagonal: \(n - 1\) elements
431 -> Matrix t -- ^ right hand sides
432 -> Matrix t -- ^ solution
433triDiagSolve = {-# SCC "triDiagSolve" #-} triDiagSolve'
434
363-- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value. 435-- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value.
364linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t 436linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t
365linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD' 437linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD'
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index 4a8129c..5018a45 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -668,6 +668,78 @@ int triSolveC_l_l(KOCMAT(a),OCMAT(b)) {
668 OK 668 OK
669} 669}
670 670
671//////// tridiagonal real linear system ////////////
672
673int dgttrf_(integer *n,
674 doublereal *dl, doublereal *d, doublereal *du, doublereal *du2,
675 integer *ipiv,
676 integer *info);
677
678int dgttrs_(char *trans, integer *n, integer *nrhs,
679 doublereal *dl, doublereal *d, doublereal *du, doublereal *du2,
680 integer *ipiv, doublereal *b, integer *ldb,
681 integer *info);
682
683int triDiagSolveR_l(DVEC(dl), DVEC(d), DVEC(du), ODMAT(b)) {
684 integer n = dn;
685 integer nhrs = bc;
686 REQUIRES(n >= 1 && dln == dn - 1 && dun == dn - 1 && br == n, BAD_SIZE);
687 DEBUGMSG("triDiagSolveR_l");
688 integer res;
689 integer* ipiv = (integer*)malloc(n*sizeof(integer));
690 double* du2 = (double*)malloc((n - 2)*sizeof(double));
691 dgttrf_ (&n,
692 dlp, dp, dup, du2,
693 ipiv,
694 &res);
695 CHECK(res,res);
696 dgttrs_ ("N",
697 &n,&nhrs,
698 dlp, dp, dup, du2,
699 ipiv, bp, &n,
700 &res);
701 CHECK(res,res);
702 free(ipiv);
703 free(du2);
704 OK
705}
706
707//////// tridiagonal complex linear system ////////////
708
709int zgttrf_(integer *n,
710 doublecomplex *dl, doublecomplex *d, doublecomplex *du, doublecomplex *du2,
711 integer *ipiv,
712 integer *info);
713
714int zgttrs_(char *trans, integer *n, integer *nrhs,
715 doublecomplex *dl, doublecomplex *d, doublecomplex *du, doublecomplex *du2,
716 integer *ipiv, doublecomplex *b, integer *ldb,
717 integer *info);
718
719int triDiagSolveC_l(CVEC(dl), CVEC(d), CVEC(du), OCMAT(b)) {
720 integer n = dn;
721 integer nhrs = bc;
722 REQUIRES(n >= 1 && dln == dn - 1 && dun == dn - 1 && br == n, BAD_SIZE);
723 DEBUGMSG("triDiagSolveC_l");
724 integer res;
725 integer* ipiv = (integer*)malloc(n*sizeof(integer));
726 doublecomplex* du2 = (doublecomplex*)malloc((n - 2)*sizeof(doublecomplex));
727 zgttrf_ (&n,
728 dlp, dp, dup, du2,
729 ipiv,
730 &res);
731 CHECK(res,res);
732 zgttrs_ ("N",
733 &n,&nhrs,
734 dlp, dp, dup, du2,
735 ipiv, bp, &n,
736 &res);
737 CHECK(res,res);
738 free(ipiv);
739 free(du2);
740 OK
741}
742
671//////////////////// least squares real linear system //////////// 743//////////////////// least squares real linear system ////////////
672 744
673int dgels_(char *trans, integer *m, integer *n, integer * 745int dgels_(char *trans, integer *m, integer *n, integer *
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index b4dd5cf..e306454 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -436,6 +436,28 @@ triSolveC :: UpLo -> Matrix (Complex Double) -> Matrix (Complex Double) -> Matri
436triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b 436triSolveC Lower a b = linearSolveTRAux2 id ztrtrs_l "triSolveC" (fmat a) b
437triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b 437triSolveC Upper a b = linearSolveTRAux2 id ztrtrs_u "triSolveC" (fmat a) b
438 438
439--------------------------------------------------------------------------------
440foreign import ccall unsafe "triDiagSolveR_l" dgttrs :: R :> R :> R :> R ::> Ok
441foreign import ccall unsafe "triDiagSolveC_l" zgttrs :: C :> C :> C :> C ::> Ok
442
443linearSolveGTAux2 g f st dl d du b
444 | ndl == nd - 1 &&
445 ndu == nd - 1 &&
446 nd == r = unsafePerformIO . g $ do
447 s <- copy ColumnMajor b
448 (dl # d # du #! s) f #| st
449 return s
450 | otherwise = error $ st ++ " of nonsquare matrix"
451 where
452 ndl = dim dl
453 nd = dim d
454 ndu = dim du
455 r = rows b
456
457-- | Solves a tridiagonal system of linear equations.
458triDiagSolveR dl d du b = linearSolveGTAux2 id dgttrs "triDiagSolveR" dl d du b
459triDiagSolveC dl d du b = linearSolveGTAux2 id zgttrs "triDiagSolveC" dl d du b
460
439----------------------------------------------------------------------------------- 461-----------------------------------------------------------------------------------
440 462
441foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok 463foreign import ccall unsafe "linearSolveLSR_l" dgels :: R ::> R ::> Ok
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index 869330c..fd100e0 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -97,6 +97,8 @@ module Numeric.LinearAlgebra (
97 -- ** Triangular 97 -- ** Triangular
98 UpLo(..), 98 UpLo(..),
99 triSolve, 99 triSolve,
100 -- ** Tridiagonal
101 triDiagSolve,
100 -- ** Sparse 102 -- ** Sparse
101 cgSolve, 103 cgSolve,
102 cgSolve', 104 cgSolve',