diff options
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 7 | ||||
-rw-r--r-- | lib/LAPACK.hs | 2 | ||||
-rw-r--r-- | lib/LAPACK/Internal.hs | 21 |
3 files changed, 23 insertions, 7 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index bd333d4..4383e79 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -315,3 +315,10 @@ diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1.. | |||
315 | l = toList v | 315 | l = toList v |
316 | delta i j | i==j = 1 | 316 | delta i j | i==j = 1 |
317 | | otherwise = 0 | 317 | | otherwise = 0 |
318 | |||
319 | diagRect s r c | ||
320 | | dim s < min r c = error "diagRect" | ||
321 | | r == c = diag s | ||
322 | | r < c = joinHoriz [diag s , zeros (r,c-r)] | ||
323 | | otherwise = joinVert [diag s , zeros (r-c,c)] | ||
324 | where zeros (r,c) = reshape c $ constant (r*c) 0 | ||
diff --git a/lib/LAPACK.hs b/lib/LAPACK.hs index 67e49af..088b6d7 100644 --- a/lib/LAPACK.hs +++ b/lib/LAPACK.hs | |||
@@ -14,7 +14,7 @@ | |||
14 | 14 | ||
15 | module LAPACK ( | 15 | module LAPACK ( |
16 | --module LAPACK.Internal | 16 | --module LAPACK.Internal |
17 | svdR, svdR', | 17 | svdR, svdR', svdC, svdC', |
18 | eigC, | 18 | eigC, |
19 | linearSolveLSR | 19 | linearSolveLSR |
20 | ) where | 20 | ) where |
diff --git a/lib/LAPACK/Internal.hs b/lib/LAPACK/Internal.hs index 2569215..e39dd10 100644 --- a/lib/LAPACK/Internal.hs +++ b/lib/LAPACK/Internal.hs | |||
@@ -31,12 +31,7 @@ foreign import ccall "lapack-aux.h svd_l_R" | |||
31 | -- | 31 | -- |
32 | -- @(u,s,v)=svdR m@ so that @m=u \<\> s \<\> 'trans' v@. | 32 | -- @(u,s,v)=svdR m@ so that @m=u \<\> s \<\> 'trans' v@. |
33 | svdR :: Matrix Double -> (Matrix Double, Matrix Double , Matrix Double) | 33 | svdR :: Matrix Double -> (Matrix Double, Matrix Double , Matrix Double) |
34 | svdR x@M {rows = r, cols = c} = (u, s, v) | 34 | svdR x@M {rows = r, cols = c} = (u, diagRect s r c, v) where (u,s,v) = svdR' x |
35 | where (u,s',v) = svdR' x | ||
36 | s | r == c = diag s' | ||
37 | | r < c = joinHoriz [diag s' , zeros (r,c-r)] | ||
38 | | otherwise = joinVert [diag s' , zeros (r-c,c)] | ||
39 | zeros (r,c) = reshape c $ constant (r*c) 0 | ||
40 | 35 | ||
41 | svdR' x@M {rows = r, cols = c} = unsafePerformIO $ do | 36 | svdR' x@M {rows = r, cols = c} = unsafePerformIO $ do |
42 | u <- createMatrix ColumnMajor r r | 37 | u <- createMatrix ColumnMajor r r |
@@ -55,6 +50,20 @@ foreign import ccall "lapack-aux.h svd_l_Rdd" | |||
55 | foreign import ccall "lapack-aux.h svd_l_C" | 50 | foreign import ccall "lapack-aux.h svd_l_C" |
56 | zgesvd :: (Complex Double) ::> (Complex Double) ::> (Double :> (Complex Double) ::> IO Int) | 51 | zgesvd :: (Complex Double) ::> (Complex Double) ::> (Double :> (Complex Double) ::> IO Int) |
57 | 52 | ||
53 | -- | Wrapper for LAPACK's /zgesvd/, which computes the full svd decomposition of a complex matrix. | ||
54 | -- | ||
55 | -- @(u,s,v)=svdC m@ so that @m=u \<\> s \<\> 'trans' v@. | ||
56 | svdC :: Matrix (Complex Double) | ||
57 | -> (Matrix (Complex Double), Matrix Double, Matrix (Complex Double)) | ||
58 | svdC x@M {rows = r, cols = c} = (u, diagRect s r c, v) where (u,s,v) = svdC' x | ||
59 | |||
60 | svdC' x@M {rows = r, cols = c} = unsafePerformIO $ do | ||
61 | u <- createMatrix ColumnMajor r r | ||
62 | s <- createVector (min r c) | ||
63 | v <- createMatrix ColumnMajor c c | ||
64 | zgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdC" [fdat x] | ||
65 | return (u,s,trans v) | ||
66 | |||
58 | ----------------------------------------------------------------------------- | 67 | ----------------------------------------------------------------------------- |
59 | -- zgeev | 68 | -- zgeev |
60 | foreign import ccall "lapack-aux.h eig_l_C" | 69 | foreign import ccall "lapack-aux.h eig_l_C" |