From c41d21fefa04c66039a0b218daaa53c2577ef838 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 12 Nov 2007 10:01:39 +0000 Subject: data structures simplification --- lib/Numeric/LinearAlgebra/LAPACK.hs | 240 +++++++++++++----------------------- 1 file changed, 86 insertions(+), 154 deletions(-) (limited to 'lib/Numeric/LinearAlgebra/LAPACK.hs') diff --git a/lib/Numeric/LinearAlgebra/LAPACK.hs b/lib/Numeric/LinearAlgebra/LAPACK.hs index 628d4f8..315be17 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK.hs +++ b/lib/Numeric/LinearAlgebra/LAPACK.hs @@ -1,4 +1,4 @@ -{-# OPTIONS_GHC -fglasgow-exts #-} +{-# OPTIONS_GHC #-} ----------------------------------------------------------------------------- -- | -- Module : Numeric.LinearAlgebra.LAPACK @@ -36,54 +36,52 @@ import Foreign ----------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM +foreign import ccall "LAPACK/lapack-aux.h svd_l_C" zgesvd :: TCMCMVCM +foreign import ccall "LAPACK/lapack-aux.h svd_l_Rdd" dgesdd :: TMMVM -- | Wrapper for LAPACK's /dgesvd/, which computes the full svd decomposition of a real matrix. -- -- @(u,s,v)=full svdR m@ so that @m=u \<\> s \<\> 'trans' v@. svdR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) -svdR x = unsafePerformIO $ do - u <- createMatrix ColumnMajor r r - s <- createVector (min r c) - v <- createMatrix ColumnMajor c c - dgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdR" [fdat x] - return (u,s,trans v) - where r = rows x - c = cols x ------------------------------------------------------------------------------ -foreign import ccall "LAPACK/lapack-aux.h svd_l_Rdd" dgesdd :: TMMVM +svdR = svdAux dgesvd "svdR" . fmat -- | Wrapper for LAPACK's /dgesvd/, which computes the full svd decomposition of a real matrix. -- -- @(u,s,v)=full svdRdd m@ so that @m=u \<\> s \<\> 'trans' v@. svdRdd :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) -svdRdd x = unsafePerformIO $ do - u <- createMatrix ColumnMajor r r - s <- createVector (min r c) - v <- createMatrix ColumnMajor c c - dgesdd // mat fdat x // mat dat u // vec s // mat dat v // check "svdRdd" [fdat x] - return (u,s,trans v) - where r = rows x - c = cols x - ------------------------------------------------------------------------------ -foreign import ccall "LAPACK/lapack-aux.h svd_l_C" zgesvd :: TCMCMVCM +svdRdd = svdAux dgesdd "svdRdd" . fmat -- | Wrapper for LAPACK's /zgesvd/, which computes the full svd decomposition of a complex matrix. -- -- @(u,s,v)=full svdC m@ so that @m=u \<\> comp s \<\> 'trans' v@. svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) -svdC x = unsafePerformIO $ do +svdC = svdAux zgesvd "svdC" . fmat + +svdAux f st x = unsafePerformIO $ do u <- createMatrix ColumnMajor r r s <- createVector (min r c) v <- createMatrix ColumnMajor c c - zgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdC" [fdat x] + f // matf x // matf u // vec s // matf v // check st [fdat x] return (u,s,trans v) where r = rows x c = cols x - ----------------------------------------------------------------------------- +eigAux f st m + | r == 1 = (fromList [flatten m `at` 0], singleton 1) + | otherwise = unsafePerformIO $ do + l <- createVector r + v <- createMatrix ColumnMajor r r + dummy <- createMatrix ColumnMajor 1 1 + f // matf m // matf dummy // vec l // matf v // check st [fdat m] + return (l,v) + where r = rows m + + foreign import ccall "LAPACK/lapack-aux.h eig_l_C" zgeev :: TCMCMCVCM +foreign import ccall "LAPACK/lapack-aux.h eig_l_R" dgeev :: TMMCVM +foreign import ccall "LAPACK/lapack-aux.h eig_l_S" dsyev :: TMVM +foreign import ccall "LAPACK/lapack-aux.h eig_l_H" zheev :: TCMVCM -- | Wrapper for LAPACK's /zgeev/, which computes the eigenvalues and right eigenvectors of a general complex matrix: -- @@ -92,18 +90,9 @@ foreign import ccall "LAPACK/lapack-aux.h eig_l_C" zgeev :: TCMCMCVCM -- The eigenvectors are the columns of v. -- The eigenvalues are not sorted. eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Double)) -eigC m - | r == 1 = (fromList [cdat m `at` 0], singleton 1) - | otherwise = unsafePerformIO $ do - l <- createVector r - v <- createMatrix ColumnMajor r r - dummy <- createMatrix ColumnMajor 1 1 - zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m] - return (l,v) - where r = rows m +eigC = eigAux zgeev "eigC" . fmat ----------------------------------------------------------------------------- -foreign import ccall "LAPACK/lapack-aux.h eig_l_R" dgeev :: TMMCVM -- | Wrapper for LAPACK's /dgeev/, which computes the eigenvalues and right eigenvectors of a general real matrix: -- @@ -113,7 +102,7 @@ foreign import ccall "LAPACK/lapack-aux.h eig_l_R" dgeev :: TMMCVM -- The eigenvalues are not sorted. eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double)) eigR m = (s', v'') - where (s,v) = eigRaux m + where (s,v) = eigRaux (fmat m) s' = toComplex (subVector 0 r (asReal s), subVector r r (asReal s)) v' = toRows $ trans v v'' = fromColumns $ fixeig (toList s') v' @@ -121,12 +110,12 @@ eigR m = (s', v'') eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) eigRaux m - | r == 1 = (fromList [(cdat m `at` 0):+0], singleton 1) + | r == 1 = (fromList [(flatten m `at` 0):+0], singleton 1) | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix ColumnMajor r r dummy <- createMatrix ColumnMajor 1 1 - dgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigR" [fdat m] + dgeev // matf m // matf dummy // vec l // matf v // check "eigR" [fdat m] return (l,v) where r = rows m @@ -138,7 +127,6 @@ fixeig ((r1:+i1):(r2:+i2):r) (v1:v2:vs) where scale = vectorMapValR Scale ----------------------------------------------------------------------------- -foreign import ccall "LAPACK/lapack-aux.h eig_l_S" dsyev :: TMVM -- | Wrapper for LAPACK's /dsyev/, which computes the eigenvalues and right eigenvectors of a symmetric real matrix: -- @@ -148,20 +136,19 @@ foreign import ccall "LAPACK/lapack-aux.h eig_l_S" dsyev :: TMVM -- The eigenvalues are sorted in descending order (use eigS' for ascending order). eigS :: Matrix Double -> (Vector Double, Matrix Double) eigS m = (s', fliprl v) - where (s,v) = eigS' m + where (s,v) = eigS' (fmat m) s' = fromList . reverse . toList $ s eigS' m - | r == 1 = (fromList [cdat m `at` 0], singleton 1) + | r == 1 = (fromList [flatten m `at` 0], singleton 1) | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix ColumnMajor r r - dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m] + dsyev // matf m // vec l // matf v // check "eigS" [fdat m] return (l,v) where r = rows m ----------------------------------------------------------------------------- -foreign import ccall "LAPACK/lapack-aux.h eig_l_H" zheev :: TCMVCM -- | Wrapper for LAPACK's /zheev/, which computes the eigenvalues and right eigenvectors of a hermitian complex matrix: -- @@ -171,165 +158,120 @@ foreign import ccall "LAPACK/lapack-aux.h eig_l_H" zheev :: TCMVCM -- The eigenvalues are sorted in descending order (use eigH' for ascending order). eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double)) eigH m = (s', fliprl v) - where (s,v) = eigH' m + where (s,v) = eigH' (fmat m) s' = fromList . reverse . toList $ s eigH' m - | r == 1 = (fromList [realPart (cdat m `at` 0)], singleton 1) + | r == 1 = (fromList [realPart (flatten m `at` 0)], singleton 1) | otherwise = unsafePerformIO $ do l <- createVector r v <- createMatrix ColumnMajor r r - zheev // mat fdat m // vec l // mat dat v // check "eigH" [fdat m] + zheev // matf m // vec l // matf v // check "eigH" [fdat m] return (l,v) where r = rows m ----------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h linearSolveR_l" dgesv :: TMMM +foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM --- | Wrapper for LAPACK's /dgesv/, which solves a general real linear system (for several right-hand sides) internally using the lu decomposition. -linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double -linearSolveR a b +linearSolveSQAux f st a b | n1==n2 && n1==r = unsafePerformIO $ do s <- createMatrix ColumnMajor r c - dgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveR" [fdat a, fdat b] + f // matf a // matf b // matf s // check st [fdat a, fdat b] return s - | otherwise = error "linearSolveR of nonsquare matrix" + | otherwise = error $ st ++ " of nonsquare matrix" where n1 = rows a n2 = cols a r = rows b c = cols b ------------------------------------------------------------------------------ -foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM +-- | Wrapper for LAPACK's /dgesv/, which solves a general real linear system (for several right-hand sides) internally using the lu decomposition. +linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double +linearSolveR a b = linearSolveSQAux dgesv "linearSolveR" (fmat a) (fmat b) -- | Wrapper for LAPACK's /zgesv/, which solves a general complex linear system (for several right-hand sides) internally using the lu decomposition. linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) -linearSolveC a b - | n1==n2 && n1==r = unsafePerformIO $ do - s <- createMatrix ColumnMajor r c - zgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveC" [fdat a, fdat b] - return s - | otherwise = error "linearSolveC of nonsquare matrix" - where n1 = rows a - n2 = cols a - r = rows b - c = cols b +linearSolveC a b = linearSolveSQAux zgesv "linearSolveC" (fmat a) (fmat b) ----------------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h linearSolveLSR_l" dgels :: TMMM +foreign import ccall "LAPACK/lapack-aux.h linearSolveLSC_l" zgels :: TCMCMCM +foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDR_l" dgelss :: Double -> TMMM +foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDC_l" zgelss :: Double -> TCMCMCM --- | Wrapper for LAPACK's /dgels/, which obtains the least squared error solution of an overconstrained real linear system or the minimum norm solution of an underdetermined system, for several right-hand sides. For rank deficient systems use 'linearSolveSVDR'. -linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double -linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSR_l a b - -linearSolveLSR_l a b = unsafePerformIO $ do +linearSolveAux f st a b = unsafePerformIO $ do r <- createMatrix ColumnMajor (max m n) nrhs - dgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSR" [fdat a, fdat b] + f // matf a // matf b // matf r // check st [fdat a, fdat b] return r where m = rows a n = cols a nrhs = cols b ------------------------------------------------------------------------------------ -foreign import ccall "LAPACK/lapack-aux.h linearSolveLSC_l" zgels :: TCMCMCM +-- | Wrapper for LAPACK's /dgels/, which obtains the least squared error solution of an overconstrained real linear system or the minimum norm solution of an underdetermined system, for several right-hand sides. For rank deficient systems use 'linearSolveSVDR'. +linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double +linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $ + linearSolveAux dgels "linearSolverLSR" (fmat a) (fmat b) -- | Wrapper for LAPACK's /zgels/, which obtains the least squared error solution of an overconstrained complex linear system or the minimum norm solution of an underdetermined system, for several right-hand sides. For rank deficient systems use 'linearSolveSVDC'. linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) -linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSC_l a b - -linearSolveLSC_l a b = unsafePerformIO $ do - r <- createMatrix ColumnMajor (max m n) nrhs - zgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSC" [fdat a, fdat b] - return r - where m = rows a - n = cols a - nrhs = cols b - ------------------------------------------------------------------------------------ -foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDR_l" dgelss :: Double -> TMMM +linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ + linearSolveAux zgels "linearSolveLSC" (fmat a) (fmat b) -- | Wrapper for LAPACK's /dgelss/, which obtains the minimum norm solution to a real linear least squares problem Ax=B using the svd, for several right-hand sides. Admits rank deficient systems but it is slower than 'linearSolveLSR'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used. linearSolveSVDR :: Maybe Double -- ^ rcond -> Matrix Double -- ^ coefficient matrix -> Matrix Double -- ^ right hand sides (as columns) -> Matrix Double -- ^ solution vectors (as columns) -linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDR_l rcond a b -linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) a b - -linearSolveSVDR_l rcond a b = unsafePerformIO $ do - r <- createMatrix ColumnMajor (max m n) nrhs - dgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDR" [fdat a, fdat b] - return r - where m = rows a - n = cols a - nrhs = cols b - ------------------------------------------------------------------------------------ -foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDC_l" zgelss :: Double -> TCMCMCM +linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ + linearSolveAux (dgelss rcond) "linearSolveSVDR" (fmat a) (fmat b) +linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) (fmat a) (fmat b) -- | Wrapper for LAPACK's /zgelss/, which obtains the minimum norm solution to a complex linear least squares problem Ax=B using the svd, for several right-hand sides. Admits rank deficient systems but it is slower than 'linearSolveLSC'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used. linearSolveSVDC :: Maybe Double -- ^ rcond -> Matrix (Complex Double) -- ^ coefficient matrix -> Matrix (Complex Double) -- ^ right hand sides (as columns) -> Matrix (Complex Double) -- ^ solution vectors (as columns) -linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDC_l rcond a b -linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) a b - -linearSolveSVDC_l rcond a b = unsafePerformIO $ do - r <- createMatrix ColumnMajor (max m n) nrhs - zgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDC" [fdat a, fdat b] - return r - where m = rows a - n = cols a - nrhs = cols b +linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ + linearSolveAux (zgelss rcond) "linearSolveSVDC" (fmat a) (fmat b) +linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b) ----------------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h chol_l_H" zpotrf :: TCMCM +foreign import ccall "LAPACK/lapack-aux.h chol_l_S" dpotrf :: TMM -- | Wrapper for LAPACK's /zpotrf/, which computes the Cholesky factorization of a -- complex Hermitian positive definite matrix. cholH :: Matrix (Complex Double) -> Matrix (Complex Double) -cholH a = unsafePerformIO $ do - r <- createMatrix ColumnMajor n n - zpotrf // mat fdat a // mat dat r // check "cholH" [fdat a] - return r - where n = rows a - ------------------------------------------------------------------------------------ -foreign import ccall "LAPACK/lapack-aux.h chol_l_S" dpotrf :: TMM +cholH = cholAux zpotrf "cholH" . fmat -- | Wrapper for LAPACK's /dpotrf/, which computes the Cholesky factorization of a -- real symmetric positive definite matrix. cholS :: Matrix Double -> Matrix Double -cholS a = unsafePerformIO $ do +cholS = cholAux dpotrf "cholS" . fmat + +cholAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor n n - dpotrf // mat fdat a // mat dat r // check "cholS" [fdat a] + f // matf a // matf r // check st [fdat a] return r where n = rows a ----------------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h qr_l_R" dgeqr2 :: TMVM +foreign import ccall "LAPACK/lapack-aux.h qr_l_C" zgeqr2 :: TCMCVCM -- | Wrapper for LAPACK's /dgeqr2/, which computes a QR factorization of a real matrix. qrR :: Matrix Double -> (Matrix Double, Vector Double) -qrR a = unsafePerformIO $ do - r <- createMatrix ColumnMajor m n - tau <- createVector mn - dgeqr2 // mat fdat a // vec tau // mat dat r // check "qrR" [fdat a] - return (r,tau) - where m = rows a - n = cols a - mn = min m n - ------------------------------------------------------------------------------------ -foreign import ccall "LAPACK/lapack-aux.h qr_l_C" zgeqr2 :: TCMCVCM +qrR = qrAux dgeqr2 "qrR" . fmat -- | Wrapper for LAPACK's /zgeqr2/, which computes a QR factorization of a complex matrix. qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) -qrC a = unsafePerformIO $ do +qrC = qrAux zgeqr2 "qrC" . fmat + +qrAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor m n tau <- createVector mn - zgeqr2 // mat fdat a // vec tau // mat dat r // check "qrC" [fdat a] + withForeignPtr (fptr $ fdat $ a) $ \p -> + f m n p // vec tau // matf r // check st [fdat a] return (r,tau) where m = rows a n = cols a @@ -337,52 +279,42 @@ qrC a = unsafePerformIO $ do ----------------------------------------------------------------------------------- foreign import ccall "LAPACK/lapack-aux.h hess_l_R" dgehrd :: TMVM +foreign import ccall "LAPACK/lapack-aux.h hess_l_C" zgehrd :: TCMCVCM -- | Wrapper for LAPACK's /dgehrd/, which computes a Hessenberg factorization of a square real matrix. hessR :: Matrix Double -> (Matrix Double, Vector Double) -hessR a = unsafePerformIO $ do - r <- createMatrix ColumnMajor m n - tau <- createVector (mn-1) - dgehrd // mat fdat a // vec tau // mat dat r // check "hessR" [fdat a] - return (r,tau) - where m = rows a - n = cols a - mn = min m n - ------------------------------------------------------------------------------------ -foreign import ccall "LAPACK/lapack-aux.h hess_l_C" zgehrd :: TCMCVCM +hessR = hessAux dgehrd "hessR" . fmat -- | Wrapper for LAPACK's /zgehrd/, which computes a Hessenberg factorization of a square complex matrix. hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double)) -hessC a = unsafePerformIO $ do +hessC = hessAux zgehrd "hessC" . fmat + +hessAux f st a = unsafePerformIO $ do r <- createMatrix ColumnMajor m n tau <- createVector (mn-1) - zgehrd // mat fdat a // vec tau // mat dat r // check "hessC" [fdat a] + f // matf a // vec tau // matf r // check st [fdat a] return (r,tau) where m = rows a n = cols a mn = min m n ----------------------------------------------------------------------------------- -foreign import ccall "LAPACK/lapack-aux.h schur_l_R" dgees :: TMMM +foreign import ccall safe "LAPACK/lapack-aux.h schur_l_R" dgees :: TMMM +foreign import ccall "LAPACK/lapack-aux.h schur_l_C" zgees :: TCMCMCM -- | Wrapper for LAPACK's /dgees/, which computes a Schur factorization of a square real matrix. schurR :: Matrix Double -> (Matrix Double, Matrix Double) -schurR a = unsafePerformIO $ do - u <- createMatrix ColumnMajor n n - s <- createMatrix ColumnMajor n n - dgees // mat fdat a // mat dat u // mat dat s // check "schurR" [fdat a] - return (u,s) - where n = rows a - ------------------------------------------------------------------------------------ -foreign import ccall "LAPACK/lapack-aux.h schur_l_C" zgees :: TCMCMCM +schurR = schurAux dgees "schurR" . fmat -- | Wrapper for LAPACK's /zgees/, which computes a Schur factorization of a square complex matrix. schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double)) -schurC a = unsafePerformIO $ do +schurC = schurAux zgees "schurC" . fmat + +schurAux f st a = unsafePerformIO $ do u <- createMatrix ColumnMajor n n s <- createMatrix ColumnMajor n n - zgees // mat fdat a // mat dat u // mat dat s // check "schurC" [fdat a] + f // matf a // matf u // matf s // check st [fdat a] return (u,s) where n = rows a + +----------------------------------------------------------------------------------- -- cgit v1.2.3