From eb28c0981f4da42c15ac267f7f6ba28d6f8bffbc Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Mon, 11 Jun 2007 12:34:06 +0000 Subject: ok linearSolve --- examples/tests.hs | 26 ++++++++++++++++++- lib/Data/Packed/Internal/Vector.hs | 10 ++++++- lib/LAPACK.hs | 5 ++-- lib/LAPACK/Internal.hs | 53 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 4 deletions(-) diff --git a/examples/tests.hs b/examples/tests.hs index 53436c8..50f0a03 100644 --- a/examples/tests.hs +++ b/examples/tests.hs @@ -138,6 +138,18 @@ instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where return $ Her (m `addM` (liftMatrix conj) (trans m)) coarbitrary = undefined +data PairSM a = PairSM (Matrix a) (Matrix a) deriving Show +instance (Num a, Field a, Arbitrary a) => Arbitrary (PairSM a) where + arbitrary = do + a <- choose (1,10) + c <- choose (1,10) + l1 <- vector (a*a) + l2 <- vector (a*c) + return $ PairSM ((a> v) |~~| (v <> diag (comp s)) where (s,v) = eigH m (<>) = prod +linearSolveSQTest fun eqfun singu prod (PairSM a b) = singu a || (a <> fun a b) ==== b + where (<>) = prod + (====) = eqfun + +prec = 1E-15 + +singular fun m = s1 < prec || s2/s1 < prec + where (_,ss,v) = fun m + s = toList ss + s1 = maximum s + s2 = minimum s main = do quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType]) @@ -204,7 +227,8 @@ main = do quickCheck (eigTestS mulF) quickCheck (eigTestH mulC) quickCheck (eigTestH mulF) - + quickCheck (linearSolveSQTest linearSolveR (|~|) (singular svdR') mulC) + quickCheck (linearSolveSQTest linearSolveC (|~~|) (singular svdC') mulC) kk = (2><2) [ 1.0, 0.0 diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs index 8f4e6a4..4836bdb 100644 --- a/lib/Data/Packed/Internal/Vector.hs +++ b/lib/Data/Packed/Internal/Vector.hs @@ -41,9 +41,17 @@ on f g = \x y -> f (g x) (g y) infixl 0 // (//) = flip ($) +errorCode 1000 = "bad size" +errorCode 1001 = "bad function code" +errorCode 1002 = "memory problem" +errorCode 1003 = "bad file" +errorCode 1004 = "singular" +errorCode 1005 = "didn't converge" +errorCode n = "code "++show n + check msg ls f = do err <- f - when (err/=0) (error msg) + when (err/=0) (error (msg++": "++errorCode err)) mapM_ (touchForeignPtr . fptr) ls return () diff --git a/lib/LAPACK.hs b/lib/LAPACK.hs index 0f1a178..0019fbe 100644 --- a/lib/LAPACK.hs +++ b/lib/LAPACK.hs @@ -13,10 +13,11 @@ ----------------------------------------------------------------------------- module LAPACK ( - --module LAPACK.Internal svdR, svdR', svdC, svdC', eigC, eigR, eigS, eigH, - linearSolveLSR + linearSolveR, linearSolveC, + linearSolveLSR, linearSolveLSC, + linearSolveSVDR, linearSolveSVDC, ) where import LAPACK.Internal diff --git a/lib/LAPACK/Internal.hs b/lib/LAPACK/Internal.hs index ba50e6b..ec46b66 100644 --- a/lib/LAPACK/Internal.hs +++ b/lib/LAPACK/Internal.hs @@ -174,11 +174,29 @@ eigH' (m@M {rows = r}) foreign import ccall "lapack-aux.h linearSolveR_l" dgesv :: Double ::> Double ::> Double ::> IO Int +-- | Wrapper for LAPACK's /dgesv/, which solves a general real linear system (for several right-hand sides) internally using the lu decomposition. +linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double +linearSolveR a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) + | n1==n2 && n1==r = unsafePerformIO $ do + s <- createMatrix ColumnMajor r c + dgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveR" [fdat a, fdat b] + return s + | otherwise = error "linearSolveR of nonsquare matrix" + ----------------------------------------------------------------------------- -- zgesv foreign import ccall "lapack-aux.h linearSolveC_l" zgesv :: (Complex Double) ::> (Complex Double) ::> (Complex Double) ::> IO Int +-- | Wrapper for LAPACK's /zgesv/, which solves a general complex linear system (for several right-hand sides) internally using the lu decomposition. +linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) +linearSolveC a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) + | n1==n2 && n1==r = unsafePerformIO $ do + s <- createMatrix ColumnMajor r c + zgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveC" [fdat a, fdat b] + return s + | otherwise = error "linearSolveC of nonsquare matrix" + ----------------------------------------------------------------------------------- -- dgels foreign import ccall "lapack-aux.h linearSolveLSR_l" @@ -198,12 +216,47 @@ linearSolveLSR_l a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformI foreign import ccall "lapack-aux.h linearSolveLSC_l" zgels :: (Complex Double) ::> (Complex Double) ::> (Complex Double) ::> IO Int +-- | Wrapper for LAPACK's /zgels/, which obtains the least squared error solution of an overconstrained complex linear system or the minimum norm solution of an underdetermined system, for several right-hand sides. For rank deficient systems use 'linearSolveSVDC'. +linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) +linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSC_l a b + +linearSolveLSC_l a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do + r <- createMatrix ColumnMajor (max m n) nrhs + zgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSC" [fdat a, fdat b] + return r + ----------------------------------------------------------------------------------- -- dgelss foreign import ccall "lapack-aux.h linearSolveSVDR_l" dgelss :: Double -> Double ::> Double ::> Double ::> IO Int +-- | Wrapper for LAPACK's /dgelss/, which obtains the minimum norm solution to a real linear least squares problem Ax=B using the svd, for several right-hand sides. Admits rank deficient systems but it is slower than 'linearSolveLSR'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used. +linearSolveSVDR :: Maybe Double -- ^ rcond + -> Matrix Double -- ^ coefficient matrix + -> Matrix Double -- ^ right hand sides (as columns) + -> Matrix Double -- ^ solution vectors (as columns) +linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDR_l rcond a b +linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) a b + +linearSolveSVDR_l rcond a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do + r <- createMatrix ColumnMajor (max m n) nrhs + dgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDR" [fdat a, fdat b] + return r + ----------------------------------------------------------------------------------- -- zgelss foreign import ccall "lapack-aux.h linearSolveSVDC_l" zgelss :: Double -> (Complex Double) ::> (Complex Double) ::> (Complex Double) ::> IO Int + +-- | Wrapper for LAPACK's /zgelss/, which obtains the minimum norm solution to a complex linear least squares problem Ax=B using the svd, for several right-hand sides. Admits rank deficient systems but it is slower than 'linearSolveLSC'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used. +linearSolveSVDC :: Maybe Double -- ^ rcond + -> Matrix (Complex Double) -- ^ coefficient matrix + -> Matrix (Complex Double) -- ^ right hand sides (as columns) + -> Matrix (Complex Double) -- ^ solution vectors (as columns) +linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDC_l rcond a b +linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) a b + +linearSolveSVDC_l rcond a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do + r <- createMatrix ColumnMajor (max m n) nrhs + zgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDC" [fdat a, fdat b] + return r -- cgit v1.2.3