diff options
Diffstat (limited to 'lib/Numeric')
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 58 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK.hs | 21 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 46 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | 13 |
4 files changed, 120 insertions, 18 deletions
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index 069d9a3..b19c0ec 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -37,6 +37,8 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
37 | hess, | 37 | hess, |
38 | -- ** Schur | 38 | -- ** Schur |
39 | schur, | 39 | schur, |
40 | -- ** LU | ||
41 | lu, | ||
40 | -- * Matrix functions | 42 | -- * Matrix functions |
41 | expm, | 43 | expm, |
42 | sqrtm, | 44 | sqrtm, |
@@ -52,11 +54,11 @@ module Numeric.LinearAlgebra.Algorithms ( | |||
52 | -- * Util | 54 | -- * Util |
53 | haussholder, | 55 | haussholder, |
54 | unpackQR, unpackHess, | 56 | unpackQR, unpackHess, |
55 | Field(linearSolveSVD,lu,eigSH',cholSH) | 57 | Field(linearSolveSVD,eigSH',cholSH) |
56 | ) where | 58 | ) where |
57 | 59 | ||
58 | 60 | ||
59 | import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj) | 61 | import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//)) |
60 | import Data.Packed | 62 | import Data.Packed |
61 | import qualified Numeric.GSL.Matrix as GSL | 63 | import qualified Numeric.GSL.Matrix as GSL |
62 | import Numeric.GSL.Vector | 64 | import Numeric.GSL.Vector |
@@ -64,12 +66,13 @@ import Numeric.LinearAlgebra.LAPACK as LAPACK | |||
64 | import Complex | 66 | import Complex |
65 | import Numeric.LinearAlgebra.Linear | 67 | import Numeric.LinearAlgebra.Linear |
66 | import Data.List(foldl1') | 68 | import Data.List(foldl1') |
69 | import Data.Array | ||
67 | 70 | ||
68 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. | 71 | -- | Auxiliary typeclass used to define generic computations for both real and complex matrices. |
69 | class (Normed (Matrix t), Linear Matrix t) => Field t where | 72 | class (Normed (Matrix t), Linear Matrix t) => Field t where |
70 | -- | Singular value decomposition using lapack's dgesvd or zgesvd. | 73 | -- | Singular value decomposition using lapack's dgesvd or zgesvd. |
71 | svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) | 74 | svd :: Matrix t -> (Matrix t, Vector Double, Matrix t) |
72 | lu :: Matrix t -> (Matrix t, Matrix t, [Int], t) | 75 | luPacked :: Matrix t -> (Matrix t, [Int]) |
73 | -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv. | 76 | -- | Solution of a general linear system (for several right-hand sides) using lapacks' dgesv and zgesv. |
74 | -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". | 77 | -- See also other versions of linearSolve in "Numeric.LinearAlgebra.LAPACK". |
75 | linearSolve :: Matrix t -> Matrix t -> Matrix t | 78 | linearSolve :: Matrix t -> Matrix t -> Matrix t |
@@ -106,7 +109,7 @@ class (Normed (Matrix t), Linear Matrix t) => Field t where | |||
106 | 109 | ||
107 | instance Field Double where | 110 | instance Field Double where |
108 | svd = svdR | 111 | svd = svdR |
109 | lu = GSL.luR | 112 | luPacked = luR |
110 | linearSolve = linearSolveR | 113 | linearSolve = linearSolveR |
111 | linearSolveSVD = linearSolveSVDR Nothing | 114 | linearSolveSVD = linearSolveSVDR Nothing |
112 | ctrans = trans | 115 | ctrans = trans |
@@ -119,7 +122,7 @@ instance Field Double where | |||
119 | 122 | ||
120 | instance Field (Complex Double) where | 123 | instance Field (Complex Double) where |
121 | svd = svdC | 124 | svd = svdC |
122 | lu = GSL.luC | 125 | luPacked = luC |
123 | linearSolve = linearSolveC | 126 | linearSolve = linearSolveC |
124 | linearSolveSVD = linearSolveSVDC Nothing | 127 | linearSolveSVD = linearSolveSVDC Nothing |
125 | ctrans = conj . trans | 128 | ctrans = conj . trans |
@@ -146,10 +149,19 @@ chol m | m `equal` ctrans m = cholSH m | |||
146 | 149 | ||
147 | square m = rows m == cols m | 150 | square m = rows m == cols m |
148 | 151 | ||
152 | -- | determinant of a square matrix, computed from the LU decomposition. | ||
149 | det :: Field t => Matrix t -> t | 153 | det :: Field t => Matrix t -> t |
150 | det m | square m = s * (product $ toList $ takeDiag $ u) | 154 | det m | square m = s * (product $ toList $ takeDiag $ lu) |
151 | | otherwise = error "det of nonsquare matrix" | 155 | | otherwise = error "det of nonsquare matrix" |
152 | where (_,u,_,s) = lu m | 156 | where (lu,perm) = luPacked m |
157 | s = signlp (rows m) perm | ||
158 | |||
159 | -- | LU factorization of a general matrix using lapack's dgetrf or zgetrf. | ||
160 | -- | ||
161 | -- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular, | ||
162 | -- u is upper triangular, p is a permutation matrix and s is the signature of the permutation. | ||
163 | lu :: Field t => Matrix t -> (Matrix t, Matrix t, Matrix t, t) | ||
164 | lu = luFact . luPacked | ||
153 | 165 | ||
154 | -- | Inverse of a square matrix using lapacks' dgesv and zgesv. | 166 | -- | Inverse of a square matrix using lapacks' dgesv and zgesv. |
155 | inv :: Field t => Matrix t -> Matrix t | 167 | inv :: Field t => Matrix t -> Matrix t |
@@ -457,3 +469,35 @@ sqrtmInv x = fst $ fixedPoint $ iterate f (x, ident (rows x)) | |||
457 | (.*) = scale | 469 | (.*) = scale |
458 | (|+|) = add | 470 | (|+|) = add |
459 | (|-|) = sub | 471 | (|-|) = sub |
472 | |||
473 | ------------------------------------------------------------------ | ||
474 | |||
475 | signlp r vals = foldl f 1 (zip [0..r-1] vals) | ||
476 | where f s (a,b) | a /= b = -s | ||
477 | | otherwise = s | ||
478 | |||
479 | swap (arr,s) (a,b) | a /= b = (arr // [(a, arr!b),(b,arr!a)],-s) | ||
480 | | otherwise = (arr,s) | ||
481 | |||
482 | fixPerm r vals = (fromColumns $ elems res, sign) | ||
483 | where v = [0..r-1] | ||
484 | s = toColumns (ident r) | ||
485 | (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) | ||
486 | |||
487 | triang r c h v = reshape c $ fromList [el i j | i<-[0..r-1], j<-[0..c-1]] | ||
488 | where el i j = if j-i>=h then v else 1 - v | ||
489 | |||
490 | luFact (lu,perm) | r <= c = (l ,u ,p, s) | ||
491 | | otherwise = (l',u',p, s) | ||
492 | where | ||
493 | r = rows lu | ||
494 | c = cols lu | ||
495 | tu = triang r c 0 1 | ||
496 | tl = triang r c 0 0 | ||
497 | l = takeColumns r (lu |*| tl) |+| diagRect (constant 1 r) r r | ||
498 | u = lu |*| tu | ||
499 | (p,s) = fixPerm r perm | ||
500 | l' = (lu |*| tl) |+| diagRect (constant 1 c) r c | ||
501 | u' = takeRows c (lu |*| tu) | ||
502 | (|+|) = add | ||
503 | (|*|) = mul | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index cacad87..83db901 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs | |||
@@ -19,6 +19,7 @@ module Numeric.LinearAlgebra.LAPACK ( | |||
19 | linearSolveR, linearSolveC, | 19 | linearSolveR, linearSolveC, |
20 | linearSolveLSR, linearSolveLSC, | 20 | linearSolveLSR, linearSolveLSC, |
21 | linearSolveSVDR, linearSolveSVDC, | 21 | linearSolveSVDR, linearSolveSVDC, |
22 | luR, luC, | ||
22 | cholS, cholH, | 23 | cholS, cholH, |
23 | qrR, qrC, | 24 | qrR, qrC, |
24 | hessR, hessC, | 25 | hessR, hessC, |
@@ -299,7 +300,7 @@ hessAux f st a = unsafePerformIO $ do | |||
299 | mn = min m n | 300 | mn = min m n |
300 | 301 | ||
301 | ----------------------------------------------------------------------------------- | 302 | ----------------------------------------------------------------------------------- |
302 | foreign import ccall safe "LAPACK/lapack-aux.h schur_l_R" dgees :: TMMM | 303 | foreign import ccall "LAPACK/lapack-aux.h schur_l_R" dgees :: TMMM |
303 | foreign import ccall "LAPACK/lapack-aux.h schur_l_C" zgees :: TCMCMCM | 304 | foreign import ccall "LAPACK/lapack-aux.h schur_l_C" zgees :: TCMCMCM |
304 | 305 | ||
305 | -- | Wrapper for LAPACK's /dgees/, which computes a Schur factorization of a square real matrix. | 306 | -- | Wrapper for LAPACK's /dgees/, which computes a Schur factorization of a square real matrix. |
@@ -318,3 +319,21 @@ schurAux f st a = unsafePerformIO $ do | |||
318 | where n = rows a | 319 | where n = rows a |
319 | 320 | ||
320 | ----------------------------------------------------------------------------------- | 321 | ----------------------------------------------------------------------------------- |
322 | foreign import ccall "LAPACK/lapack-aux.h lu_l_R" dgetrf :: TMVM | ||
323 | foreign import ccall "LAPACK/lapack-aux.h lu_l_C" zgetrf :: TCMVCM | ||
324 | |||
325 | -- | Wrapper for LAPACK's /dgetrf/, which computes a LU factorization of a general real matrix. | ||
326 | luR :: Matrix Double -> (Matrix Double, [Int]) | ||
327 | luR = luAux dgetrf "luR" . fmat | ||
328 | |||
329 | -- | Wrapper for LAPACK's /zgees/, which computes a Schur factorization of a square complex matrix. | ||
330 | luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int]) | ||
331 | luC = luAux zgetrf "luC" . fmat | ||
332 | |||
333 | luAux f st a = unsafePerformIO $ do | ||
334 | lu <- createMatrix ColumnMajor n m | ||
335 | piv <- createVector (min n m) | ||
336 | app3 f mat a vec piv mat lu st | ||
337 | return (lu, map (pred.round) (toList piv)) | ||
338 | where n = rows a | ||
339 | m = cols a | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 8392feb..310f6ee 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -768,3 +768,49 @@ int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) { | |||
768 | OK | 768 | OK |
769 | #endif | 769 | #endif |
770 | } | 770 | } |
771 | |||
772 | //////////////////// LU factorization ///////////////////////// | ||
773 | |||
774 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)) { | ||
775 | integer m = ar; | ||
776 | integer n = ac; | ||
777 | integer mn = MIN(m,n); | ||
778 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); | ||
779 | DEBUGMSG("lu_l_R"); | ||
780 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); | ||
781 | memcpy(rp,ap,m*n*sizeof(double)); | ||
782 | integer res; | ||
783 | dgetrf_ (&m,&n,rp,&m,auxipiv,&res); | ||
784 | if(res>0) { | ||
785 | res = 0; // fixme | ||
786 | } | ||
787 | CHECK(res,res); | ||
788 | int k; | ||
789 | for (k=0; k<mn; k++) { | ||
790 | ipivp[k] = auxipiv[k]; | ||
791 | } | ||
792 | free(auxipiv); | ||
793 | OK | ||
794 | } | ||
795 | |||
796 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) { | ||
797 | integer m = ar; | ||
798 | integer n = ac; | ||
799 | integer mn = MIN(m,n); | ||
800 | REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE); | ||
801 | DEBUGMSG("lu_l_C"); | ||
802 | integer* auxipiv = (integer*)malloc(mn*sizeof(integer)); | ||
803 | memcpy(rp,ap,m*n*sizeof(doublecomplex)); | ||
804 | integer res; | ||
805 | zgetrf_ (&m,&n,(doublecomplex*)rp,&m,auxipiv,&res); | ||
806 | if(res>0) { | ||
807 | res = 0; // fixme | ||
808 | } | ||
809 | CHECK(res,res); | ||
810 | int k; | ||
811 | for (k=0; k<mn; k++) { | ||
812 | ipivp[k] = auxipiv[k]; | ||
813 | } | ||
814 | free(auxipiv); | ||
815 | OK | ||
816 | } | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h index e5d74d7..bccd4b8 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.h | |||
@@ -14,41 +14,34 @@ | |||
14 | 14 | ||
15 | int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); | 15 | int svd_l_R(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); |
16 | int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); | 16 | int svd_l_Rdd(KDMAT(x),DMAT(u),DVEC(s),DMAT(v)); |
17 | |||
18 | int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v)); | 17 | int svd_l_C(KCMAT(a),CMAT(u),DVEC(s),CMAT(v)); |
19 | 18 | ||
20 | int eig_l_C(KCMAT(a),CMAT(u),CVEC(s),CMAT(v)); | 19 | int eig_l_C(KCMAT(a),CMAT(u),CVEC(s),CMAT(v)); |
21 | |||
22 | int eig_l_R(KDMAT(a),DMAT(u),CVEC(s),DMAT(v)); | 20 | int eig_l_R(KDMAT(a),DMAT(u),CVEC(s),DMAT(v)); |
23 | 21 | ||
24 | int eig_l_S(KDMAT(a),DVEC(s),DMAT(v)); | 22 | int eig_l_S(KDMAT(a),DVEC(s),DMAT(v)); |
25 | |||
26 | int eig_l_H(KCMAT(a),DVEC(s),CMAT(v)); | 23 | int eig_l_H(KCMAT(a),DVEC(s),CMAT(v)); |
27 | 24 | ||
28 | int linearSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)); | 25 | int linearSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)); |
29 | |||
30 | int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)); | 26 | int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)); |
31 | 27 | ||
32 | int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)); | 28 | int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)); |
33 | |||
34 | int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x)); | 29 | int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x)); |
35 | 30 | ||
36 | int linearSolveSVDR_l(double,KDMAT(a),KDMAT(b),DMAT(x)); | 31 | int linearSolveSVDR_l(double,KDMAT(a),KDMAT(b),DMAT(x)); |
37 | |||
38 | int linearSolveSVDC_l(double,KCMAT(a),KCMAT(b),CMAT(x)); | 32 | int linearSolveSVDC_l(double,KCMAT(a),KCMAT(b),CMAT(x)); |
39 | 33 | ||
40 | int chol_l_H(KCMAT(a),CMAT(r)); | 34 | int chol_l_H(KCMAT(a),CMAT(r)); |
41 | |||
42 | int chol_l_S(KDMAT(a),DMAT(r)); | 35 | int chol_l_S(KDMAT(a),DMAT(r)); |
43 | 36 | ||
44 | int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)); | 37 | int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)); |
45 | |||
46 | int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)); | 38 | int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)); |
47 | 39 | ||
48 | int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)); | 40 | int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)); |
49 | |||
50 | int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)); | 41 | int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)); |
51 | 42 | ||
52 | int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)); | 43 | int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)); |
53 | |||
54 | int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)); | 44 | int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)); |
45 | |||
46 | int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)); | ||
47 | int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)); | ||