summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/tests.hs26
-rw-r--r--lib/Data/Packed/Internal/Vector.hs10
-rw-r--r--lib/LAPACK.hs5
-rw-r--r--lib/LAPACK/Internal.hs53
4 files changed, 90 insertions, 4 deletions
diff --git a/examples/tests.hs b/examples/tests.hs
index 53436c8..50f0a03 100644
--- a/examples/tests.hs
+++ b/examples/tests.hs
@@ -138,6 +138,18 @@ instance {-(Field a, Arbitrary a, Num a) =>-} Arbitrary Her where
138 return $ Her (m `addM` (liftMatrix conj) (trans m)) 138 return $ Her (m `addM` (liftMatrix conj) (trans m))
139 coarbitrary = undefined 139 coarbitrary = undefined
140 140
141data PairSM a = PairSM (Matrix a) (Matrix a) deriving Show
142instance (Num a, Field a, Arbitrary a) => Arbitrary (PairSM a) where
143 arbitrary = do
144 a <- choose (1,10)
145 c <- choose (1,10)
146 l1 <- vector (a*a)
147 l2 <- vector (a*c)
148 return $ PairSM ((a><a) (map fromIntegral (l1::[Int]))) ((a><c) (map fromIntegral (l2::[Int])))
149 --return $ PairSM ((a><a) l1) ((a><c) l2)
150 coarbitrary = undefined
151
152
141 153
142 154
143addM m1 m2 = liftMatrix2 addV m1 m2 155addM m1 m2 = liftMatrix2 addV m1 m2
@@ -181,7 +193,18 @@ eigTestH prod (Her m) = (m <> v) |~~| (v <> diag (comp s))
181 where (s,v) = eigH m 193 where (s,v) = eigH m
182 (<>) = prod 194 (<>) = prod
183 195
196linearSolveSQTest fun eqfun singu prod (PairSM a b) = singu a || (a <> fun a b) ==== b
197 where (<>) = prod
198 (====) = eqfun
199
184 200
201prec = 1E-15
202
203singular fun m = s1 < prec || s2/s1 < prec
204 where (_,ss,v) = fun m
205 s = toList ss
206 s1 = maximum s
207 s2 = minimum s
185 208
186main = do 209main = do
187 quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType]) 210 quickCheck $ \l -> null l || (toList . fromList) l == (l :: [BaseType])
@@ -204,7 +227,8 @@ main = do
204 quickCheck (eigTestS mulF) 227 quickCheck (eigTestS mulF)
205 quickCheck (eigTestH mulC) 228 quickCheck (eigTestH mulC)
206 quickCheck (eigTestH mulF) 229 quickCheck (eigTestH mulF)
207 230 quickCheck (linearSolveSQTest linearSolveR (|~|) (singular svdR') mulC)
231 quickCheck (linearSolveSQTest linearSolveC (|~~|) (singular svdC') mulC)
208 232
209kk = (2><2) 233kk = (2><2)
210 [ 1.0, 0.0 234 [ 1.0, 0.0
diff --git a/lib/Data/Packed/Internal/Vector.hs b/lib/Data/Packed/Internal/Vector.hs
index 8f4e6a4..4836bdb 100644
--- a/lib/Data/Packed/Internal/Vector.hs
+++ b/lib/Data/Packed/Internal/Vector.hs
@@ -41,9 +41,17 @@ on f g = \x y -> f (g x) (g y)
41infixl 0 // 41infixl 0 //
42(//) = flip ($) 42(//) = flip ($)
43 43
44errorCode 1000 = "bad size"
45errorCode 1001 = "bad function code"
46errorCode 1002 = "memory problem"
47errorCode 1003 = "bad file"
48errorCode 1004 = "singular"
49errorCode 1005 = "didn't converge"
50errorCode n = "code "++show n
51
44check msg ls f = do 52check msg ls f = do
45 err <- f 53 err <- f
46 when (err/=0) (error msg) 54 when (err/=0) (error (msg++": "++errorCode err))
47 mapM_ (touchForeignPtr . fptr) ls 55 mapM_ (touchForeignPtr . fptr) ls
48 return () 56 return ()
49 57
diff --git a/lib/LAPACK.hs b/lib/LAPACK.hs
index 0f1a178..0019fbe 100644
--- a/lib/LAPACK.hs
+++ b/lib/LAPACK.hs
@@ -13,10 +13,11 @@
13----------------------------------------------------------------------------- 13-----------------------------------------------------------------------------
14 14
15module LAPACK ( 15module LAPACK (
16 --module LAPACK.Internal
17 svdR, svdR', svdC, svdC', 16 svdR, svdR', svdC, svdC',
18 eigC, eigR, eigS, eigH, 17 eigC, eigR, eigS, eigH,
19 linearSolveLSR 18 linearSolveR, linearSolveC,
19 linearSolveLSR, linearSolveLSC,
20 linearSolveSVDR, linearSolveSVDC,
20) where 21) where
21 22
22import LAPACK.Internal 23import LAPACK.Internal
diff --git a/lib/LAPACK/Internal.hs b/lib/LAPACK/Internal.hs
index ba50e6b..ec46b66 100644
--- a/lib/LAPACK/Internal.hs
+++ b/lib/LAPACK/Internal.hs
@@ -174,11 +174,29 @@ eigH' (m@M {rows = r})
174foreign import ccall "lapack-aux.h linearSolveR_l" 174foreign import ccall "lapack-aux.h linearSolveR_l"
175 dgesv :: Double ::> Double ::> Double ::> IO Int 175 dgesv :: Double ::> Double ::> Double ::> IO Int
176 176
177-- | Wrapper for LAPACK's /dgesv/, which solves a general real linear system (for several right-hand sides) internally using the lu decomposition.
178linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double
179linearSolveR a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c})
180 | n1==n2 && n1==r = unsafePerformIO $ do
181 s <- createMatrix ColumnMajor r c
182 dgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveR" [fdat a, fdat b]
183 return s
184 | otherwise = error "linearSolveR of nonsquare matrix"
185
177----------------------------------------------------------------------------- 186-----------------------------------------------------------------------------
178-- zgesv 187-- zgesv
179foreign import ccall "lapack-aux.h linearSolveC_l" 188foreign import ccall "lapack-aux.h linearSolveC_l"
180 zgesv :: (Complex Double) ::> (Complex Double) ::> (Complex Double) ::> IO Int 189 zgesv :: (Complex Double) ::> (Complex Double) ::> (Complex Double) ::> IO Int
181 190
191-- | Wrapper for LAPACK's /zgesv/, which solves a general complex linear system (for several right-hand sides) internally using the lu decomposition.
192linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
193linearSolveC a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c})
194 | n1==n2 && n1==r = unsafePerformIO $ do
195 s <- createMatrix ColumnMajor r c
196 zgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveC" [fdat a, fdat b]
197 return s
198 | otherwise = error "linearSolveC of nonsquare matrix"
199
182----------------------------------------------------------------------------------- 200-----------------------------------------------------------------------------------
183-- dgels 201-- dgels
184foreign import ccall "lapack-aux.h linearSolveLSR_l" 202foreign import ccall "lapack-aux.h linearSolveLSR_l"
@@ -198,12 +216,47 @@ linearSolveLSR_l a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformI
198foreign import ccall "lapack-aux.h linearSolveLSC_l" 216foreign import ccall "lapack-aux.h linearSolveLSC_l"
199 zgels :: (Complex Double) ::> (Complex Double) ::> (Complex Double) ::> IO Int 217 zgels :: (Complex Double) ::> (Complex Double) ::> (Complex Double) ::> IO Int
200 218
219-- | Wrapper for LAPACK's /zgels/, which obtains the least squared error solution of an overconstrained complex linear system or the minimum norm solution of an underdetermined system, for several right-hand sides. For rank deficient systems use 'linearSolveSVDC'.
220linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
221linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSC_l a b
222
223linearSolveLSC_l a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do
224 r <- createMatrix ColumnMajor (max m n) nrhs
225 zgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSC" [fdat a, fdat b]
226 return r
227
201----------------------------------------------------------------------------------- 228-----------------------------------------------------------------------------------
202-- dgelss 229-- dgelss
203foreign import ccall "lapack-aux.h linearSolveSVDR_l" 230foreign import ccall "lapack-aux.h linearSolveSVDR_l"
204 dgelss :: Double -> Double ::> Double ::> Double ::> IO Int 231 dgelss :: Double -> Double ::> Double ::> Double ::> IO Int
205 232
233-- | Wrapper for LAPACK's /dgelss/, which obtains the minimum norm solution to a real linear least squares problem Ax=B using the svd, for several right-hand sides. Admits rank deficient systems but it is slower than 'linearSolveLSR'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used.
234linearSolveSVDR :: Maybe Double -- ^ rcond
235 -> Matrix Double -- ^ coefficient matrix
236 -> Matrix Double -- ^ right hand sides (as columns)
237 -> Matrix Double -- ^ solution vectors (as columns)
238linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDR_l rcond a b
239linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) a b
240
241linearSolveSVDR_l rcond a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do
242 r <- createMatrix ColumnMajor (max m n) nrhs
243 dgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDR" [fdat a, fdat b]
244 return r
245
206----------------------------------------------------------------------------------- 246-----------------------------------------------------------------------------------
207-- zgelss 247-- zgelss
208foreign import ccall "lapack-aux.h linearSolveSVDC_l" 248foreign import ccall "lapack-aux.h linearSolveSVDC_l"
209 zgelss :: Double -> (Complex Double) ::> (Complex Double) ::> (Complex Double) ::> IO Int 249 zgelss :: Double -> (Complex Double) ::> (Complex Double) ::> (Complex Double) ::> IO Int
250
251-- | Wrapper for LAPACK's /zgelss/, which obtains the minimum norm solution to a complex linear least squares problem Ax=B using the svd, for several right-hand sides. Admits rank deficient systems but it is slower than 'linearSolveLSC'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used.
252linearSolveSVDC :: Maybe Double -- ^ rcond
253 -> Matrix (Complex Double) -- ^ coefficient matrix
254 -> Matrix (Complex Double) -- ^ right hand sides (as columns)
255 -> Matrix (Complex Double) -- ^ solution vectors (as columns)
256linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDC_l rcond a b
257linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) a b
258
259linearSolveSVDC_l rcond a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do
260 r <- createMatrix ColumnMajor (max m n) nrhs
261 zgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDC" [fdat a, fdat b]
262 return r