summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs131
-rw-r--r--lib/Data/Packed/Internal/Tensor.hs2
-rw-r--r--lib/Data/Packed/Matrix.hs4
-rw-r--r--lib/GSL.hs2
-rw-r--r--lib/GSL/Compat.hs3
-rw-r--r--lib/GSL/Matrix.hs42
-rw-r--r--lib/LAPACK.hs73
7 files changed, 188 insertions, 69 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index 48652f3..6ba2d06 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -27,6 +27,10 @@ import Data.Maybe(fromJust)
27 27
28---------------------------------------------------------------- 28----------------------------------------------------------------
29 29
30-- the condition Storable a => Field a means that we can only put
31-- in Field types that are in Storable, and therefore Storable a
32-- is not required in signatures if we have a Field a.
33
30class Storable a => Field a where 34class Storable a => Field a where
31 constant :: a -> Int -> Vector a 35 constant :: a -> Int -> Vector a
32 transdata :: Int -> Vector a -> Int -> Vector a 36 transdata :: Int -> Vector a -> Int -> Vector a
@@ -36,7 +40,6 @@ class Storable a => Field a where
36 -> Matrix a -> Matrix a 40 -> Matrix a -> Matrix a
37 diag :: Vector a -> Matrix a 41 diag :: Vector a -> Matrix a
38 42
39
40instance Field Double where 43instance Field Double where
41 constant = constantR 44 constant = constantR
42 transdata = transdataR 45 transdata = transdataR
@@ -78,12 +81,40 @@ foreign import ccall safe "aux.h transC"
78 81
79transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d 82transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d
80 83
84{- Design considerations for the Matrix Type
85 -----------------------------------------
86
87- we must easily handle both row major and column major order,
88 for bindings to LAPACK and GSL/C
89
90- we'd like to simplify redundant matrix transposes:
91 - Some of them arise from the order requirements of some functions
92 - some functions (matrix product) admit transposed arguments
81 93
94- maybe we don't really need this kind of simplification:
95 - more complex code
96 - some computational overhead
97 - only appreciable gain in code with a lot of redundant transpositions
98 and cheap matrix computations
82 99
100- we could carry both the matrix and its (lazily computed) transpose.
101 This may save some transpositions, but it is necessary to keep track of the
102 data which is actually computed to be used by functions like the matrix product
103 which admit both orders. Therefore, maybe it is better to have something like
104 viewC and viewF, which may actually perform a transpose if required.
83 105
106- but if we need the transposed data and it is not in the structure, we must make
107 sure that we touch the same foreignptr that is used in the computation. Access
108 to such pointer cannot be made by creating a new vector.
109
110-}
84 111
85data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) 112data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
86 113
114{-
115
116
117
87data Matrix t = M { rows :: Int 118data Matrix t = M { rows :: Int
88 , cols :: Int 119 , cols :: Int
89 , dat :: Vector t 120 , dat :: Vector t
@@ -91,28 +122,26 @@ data Matrix t = M { rows :: Int
91 , isTrans :: Bool 122 , isTrans :: Bool
92 , order :: MatrixOrder 123 , order :: MatrixOrder
93 } -- deriving Typeable 124 } -- deriving Typeable
125-}
94 126
127data Matrix t = MC { rows :: Int, cols :: Int, dat :: Vector t } -- row major order
128 | MF { rows :: Int, cols :: Int, dat :: Vector t } -- column major order
95 129
96data NMat t = MC { rws, cls :: Int, dtc :: Vector t} 130-- transposition just changes the data order
97 | MF { rws, cls :: Int, dtf :: Vector t} 131trans :: Matrix t -> Matrix t
98 | Tr (NMat t) 132trans MC {rows = r, cols = c, dat = d} = MF {rows = c, cols = r, dat = d}
99 133trans MF {rows = r, cols = c, dat = d} = MC {rows = c, cols = r, dat = d}
100ntrans (Tr m) = m
101ntrans m = Tr m
102 134
103viewC m@MC{} = m 135viewC m@MC{} = m
104viewF m@MF{} = m 136viewC MF {rows = r, cols = c, dat = d} = MC {rows = r, cols = c, dat = transdata r d c}
105 137
106fortran m = order m == ColumnMajor 138viewF m@MF{} = m
139viewF MC {rows = r, cols = c, dat = d} = MF {rows = r, cols = c, dat = transdata c d r}
107 140
108cdat m = if fortran m `xor` isTrans m then tdat m else dat m 141--fortran m = order m == ColumnMajor
109fdat m = if fortran m `xor` isTrans m then dat m else tdat m
110 142
111trans :: Matrix t -> Matrix t 143cdat m = dat (viewC m)
112trans m = m { rows = cols m 144fdat m = dat (viewF m)
113 , cols = rows m
114 , isTrans = not (isTrans m)
115 }
116 145
117type Mt t s = Int -> Int -> Ptr t -> s 146type Mt t s = Int -> Int -> Ptr t -> s
118-- not yet admitted by my haddock version 147-- not yet admitted by my haddock version
@@ -120,11 +149,14 @@ type Mt t s = Int -> Int -> Ptr t -> s
120-- type t ::> s = Mt t s 149-- type t ::> s = Mt t s
121 150
122mat d m f = f (rows m) (cols m) (ptr (d m)) 151mat d m f = f (rows m) (cols m) (ptr (d m))
152--mat m f = f (rows m) (cols m) (ptr (dat m))
153--matC m f = f (rows m) (cols m) (ptr (cdat m))
154
123 155
124toLists :: (Storable t) => Matrix t -> [[t]] 156--toLists :: (Storable t) => Matrix t -> [[t]]
125toLists m = partit (cols m) . toList . cdat $ m 157toLists m = partit (cols m) . toList . cdat $ m
126 158
127instance (Show a, Storable a) => (Show (Matrix a)) where 159instance (Show a, Field a) => (Show (Matrix a)) where
128 show m = (sizes++) . dsp . map (map show) . toLists $ m 160 show m = (sizes++) . dsp . map (map show) . toLists $ m
129 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" 161 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n"
130 162
@@ -136,6 +168,7 @@ dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unw
136 pad n str = replicate (n - length str) ' ' ++ str 168 pad n str = replicate (n - length str) ' ' ++ str
137 unwords' = concat . intersperse ", " 169 unwords' = concat . intersperse ", "
138 170
171{-
139matrixFromVector RowMajor c v = 172matrixFromVector RowMajor c v =
140 M { rows = r 173 M { rows = r
141 , cols = c 174 , cols = c
@@ -147,8 +180,6 @@ matrixFromVector RowMajor c v =
147 r | m==0 = d 180 r | m==0 = d
148 | otherwise = error "matrixFromVector" 181 | otherwise = error "matrixFromVector"
149 182
150-- r = dim v `div` c -- TODO check mod=0
151
152matrixFromVector ColumnMajor c v = 183matrixFromVector ColumnMajor c v =
153 M { rows = r 184 M { rows = r
154 , cols = c 185 , cols = c
@@ -160,6 +191,23 @@ matrixFromVector ColumnMajor c v =
160 r | m==0 = d 191 r | m==0 = d
161 | otherwise = error "matrixFromVector" 192 | otherwise = error "matrixFromVector"
162 193
194-}
195
196matrixFromVector RowMajor c v = MC { rows = r, cols = c, dat = v}
197 where (d,m) = dim v `divMod` c
198 r | m==0 = d
199 | otherwise = error "matrixFromVector"
200
201matrixFromVector ColumnMajor c v = MF { rows = r, cols = c, dat = v}
202 where (d,m) = dim v `divMod` c
203 r | m==0 = d
204 | otherwise = error "matrixFromVector"
205
206
207
208
209
210
163createMatrix order r c = do 211createMatrix order r c = do
164 p <- createVector (r*c) 212 p <- createVector (r*c)
165 return (matrixFromVector order c p) 213 return (matrixFromVector order c p)
@@ -178,10 +226,10 @@ reshape c v = matrixFromVector RowMajor c v
178 226
179singleton x = reshape 1 (fromList [x]) 227singleton x = reshape 1 (fromList [x])
180 228
181liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 229--liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
182liftMatrix f m = reshape (cols m) (f (cdat m)) 230liftMatrix f m = reshape (cols m) (f (cdat m))
183 231
184liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 232--liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
185liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) 233liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2))
186 | otherwise = error "nonconformant matrices in liftMatrix2" 234 | otherwise = error "nonconformant matrices in liftMatrix2"
187------------------------------------------------------------------ 235------------------------------------------------------------------
@@ -203,6 +251,7 @@ multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multipl
203 251
204------------------------------------------------------------------ 252------------------------------------------------------------------
205 253
254{-
206gmatC m f | fortran m = 255gmatC m f | fortran m =
207 if (isTrans m) 256 if (isTrans m)
208 then f 0 (rows m) (cols m) (ptr (dat m)) 257 then f 0 (rows m) (cols m) (ptr (dat m))
@@ -211,7 +260,11 @@ gmatC m f | fortran m =
211 if isTrans m 260 if isTrans m
212 then f 1 (cols m) (rows m) (ptr (dat m)) 261 then f 1 (cols m) (rows m) (ptr (dat m))
213 else f 0 (rows m) (cols m) (ptr (dat m)) 262 else f 0 (rows m) (cols m) (ptr (dat m))
263-}
214 264
265gmatC MF {rows = r, cols = c, dat = d} f = f 1 c r (ptr d)
266gmatC MC {rows = r, cols = c, dat = d} f = f 0 r c (ptr d)
267{-# INLINE gmatC #-}
215 268
216multiplyAux fun order a b = unsafePerformIO $ do 269multiplyAux fun order a b = unsafePerformIO $ do
217 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ 270 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
@@ -219,6 +272,7 @@ multiplyAux fun order a b = unsafePerformIO $ do
219 r <- createMatrix order (rows a) (cols b) 272 r <- createMatrix order (rows a) (cols b)
220 fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] 273 fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b]
221 return r 274 return r
275{-# INLINE multiplyAux #-}
222 276
223foreign import ccall safe "aux.h multiplyR" 277foreign import ccall safe "aux.h multiplyR"
224 cmultiplyR :: Int -> Int -> Int -> Ptr Double 278 cmultiplyR :: Int -> Int -> Int -> Ptr Double
@@ -234,13 +288,15 @@ foreign import ccall safe "aux.h multiplyC"
234 288
235multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a 289multiply :: (Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
236multiply RowMajor a b = multiplyD RowMajor a b 290multiply RowMajor a b = multiplyD RowMajor a b
237multiply ColumnMajor a b = m {rows = cols m, cols = rows m, order = ColumnMajor} 291multiply ColumnMajor a b = MF {rows = c, cols = r, dat = d}
238 where m = multiplyD RowMajor (trans b) (trans a) 292 where MC {rows = r, cols = c, dat = d } = multiplyD RowMajor (trans b) (trans a)
239 293
240 294
241multiplyR = multiplyAux cmultiplyR 295multiplyR = multiplyAux cmultiplyR'
242multiplyC = multiplyAux cmultiplyC 296multiplyC = multiplyAux cmultiplyC
243 297
298cmultiplyR' p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3 = {-# SCC "mulR" #-} cmultiplyR p1 p2 p3 p4 q1 q2 q3 q4 r1 r2 r3
299
244---------------------------------------------------------------------- 300----------------------------------------------------------------------
245 301
246-- | extraction of a submatrix of a real matrix 302-- | extraction of a submatrix of a real matrix
@@ -249,7 +305,7 @@ subMatrixR :: (Int,Int) -- ^ (r0,c0) starting position
249 -> Matrix Double -> Matrix Double 305 -> Matrix Double -> Matrix Double
250subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do 306subMatrixR (r0,c0) (rt,ct) x = unsafePerformIO $ do
251 r <- createMatrix RowMajor rt ct 307 r <- createMatrix RowMajor rt ct
252 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat cdat r // check "subMatrixR" [dat r] 308 c_submatrixR r0 (r0+rt-1) c0 (c0+ct-1) // mat cdat x // mat dat r // check "subMatrixR" [dat r]
253 return r 309 return r
254foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM 310foreign import ccall "aux.h submatrixR" c_submatrixR :: Int -> Int -> Int -> Int -> TMM
255 311
@@ -278,8 +334,8 @@ subMatrixG (r0,c0) (rt,ct) x = reshape ct $ fromList $ concat $ map (subList c0
278 334
279diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do 335diagAux fun msg (v@V {dim = n}) = unsafePerformIO $ do
280 m <- createMatrix RowMajor n n 336 m <- createMatrix RowMajor n n
281 fun // vec v // mat dat m // check msg [dat m] 337 fun // vec v // mat cdat m // check msg [dat m]
282 return m {tdat = dat m} 338 return m -- {tdat = dat m}
283 339
284-- | diagonal matrix from a real vector 340-- | diagonal matrix from a real vector
285diagR :: Vector Double -> Matrix Double 341diagR :: Vector Double -> Matrix Double
@@ -305,13 +361,13 @@ diagG v = reshape c $ fromList $ [ l!!(i-1) * delta k i | k <- [1..c], i <- [1..
305 | otherwise = 0 361 | otherwise = 0
306 362
307-- | creates a Matrix from a list of vectors 363-- | creates a Matrix from a list of vectors
308fromRows :: Field t => [Vector t] -> Matrix t 364--fromRows :: Field t => [Vector t] -> Matrix t
309fromRows vs = case common dim vs of 365fromRows vs = case common dim vs of
310 Nothing -> error "fromRows applied to [] or to vectors with different sizes" 366 Nothing -> error "fromRows applied to [] or to vectors with different sizes"
311 Just c -> reshape c (join vs) 367 Just c -> reshape c (join vs)
312 368
313-- | extracts the rows of a matrix as a list of vectors 369-- | extracts the rows of a matrix as a list of vectors
314toRows :: Storable t => Matrix t -> [Vector t] 370--toRows :: Storable t => Matrix t -> [Vector t]
315toRows m = toRows' 0 where 371toRows m = toRows' 0 where
316 v = cdat m 372 v = cdat m
317 r = rows m 373 r = rows m
@@ -324,16 +380,25 @@ fromColumns :: Field t => [Vector t] -> Matrix t
324fromColumns m = trans . fromRows $ m 380fromColumns m = trans . fromRows $ m
325 381
326-- | Creates a list of vectors from the columns of a matrix 382-- | Creates a list of vectors from the columns of a matrix
327toColumns :: Storable t => Matrix t -> [Vector t] 383toColumns :: Field t => Matrix t -> [Vector t]
328toColumns m = toRows . trans $ m 384toColumns m = toRows . trans $ m
329 385
330 386
331-- | Reads a matrix position. 387-- | Reads a matrix position.
332(@@>) :: Storable t => Matrix t -> (Int,Int) -> t 388(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
333infixl 9 @@> 389infixl 9 @@>
334m@M {rows = r, cols = c} @@> (i,j) 390--m@M {rows = r, cols = c} @@> (i,j)
391-- | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range"
392-- | otherwise = cdat m `at` (i*c+j)
393
394MC {rows = r, cols = c, dat = v} @@> (i,j)
395 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range"
396 | otherwise = v `at` (i*c+j)
397
398MF {rows = r, cols = c, dat = v} @@> (i,j)
335 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range" 399 | i<0 || i>=r || j<0 || j>=c = error "matrix indexing out of range"
336 | otherwise = cdat m `at` (i*c+j) 400 | otherwise = v `at` (j*r+i)
401
337 402
338------------------------------------------------------------------ 403------------------------------------------------------------------
339 404
diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs
index 6876685..dea1636 100644
--- a/lib/Data/Packed/Internal/Tensor.hs
+++ b/lib/Data/Packed/Internal/Tensor.hs
@@ -92,7 +92,7 @@ tensor dssig vec = T d v `withIdx` seqind where
92tensorFromVector :: IdxType -> Vector t -> Tensor t 92tensorFromVector :: IdxType -> Vector t -> Tensor t
93tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v} 93tensorFromVector tp v = T {dims = [IdxDesc (dim v) tp "1"], ten = v}
94 94
95tensorFromMatrix :: IdxType -> IdxType -> Matrix t -> Tensor t 95tensorFromMatrix :: Field t => IdxType -> IdxType -> Matrix t -> Tensor t
96tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"] 96tensorFromMatrix tpr tpc m = T {dims = [IdxDesc (rows m) tpr "1",IdxDesc (cols m) tpc "2"]
97 , ten = cdat m} 97 , ten = cdat m}
98 98
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs
index 2e8cb3d..45aaaba 100644
--- a/lib/Data/Packed/Matrix.hs
+++ b/lib/Data/Packed/Matrix.hs
@@ -77,7 +77,7 @@ diagRect s r c
77 | r > c = joinVert [diag s , zeros (r-c,c)] 77 | r > c = joinVert [diag s , zeros (r-c,c)]
78 where zeros (r,c) = reshape c $ constant 0 (r*c) 78 where zeros (r,c) = reshape c $ constant 0 (r*c)
79 79
80takeDiag :: (Storable t) => Matrix t -> Vector t 80takeDiag :: (Field t) => Matrix t -> Vector t
81takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] 81takeDiag m = fromList [cdat m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]]
82 82
83ident :: (Num t, Field t) => Int -> Matrix t 83ident :: (Num t, Field t) => Int -> Matrix t
@@ -119,7 +119,7 @@ dropColumns n mat = subMatrix (0,n) (rows mat, cols mat - n) mat
119@\> flatten ('ident' 3) 119@\> flatten ('ident' 3)
1209 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@ 1209 # [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]@
121-} 121-}
122flatten :: Matrix t -> Vector t 122flatten :: Field t => Matrix t -> Vector t
123flatten = cdat 123flatten = cdat
124 124
125-- | Creates a 'Matrix' from a list of lists (considered as rows). 125-- | Creates a 'Matrix' from a list of lists (considered as rows).
diff --git a/lib/GSL.hs b/lib/GSL.hs
index 23865b0..e6ff6df 100644
--- a/lib/GSL.hs
+++ b/lib/GSL.hs
@@ -21,7 +21,7 @@ module LinearAlgebra.Algorithms,
21module LAPACK, 21module LAPACK,
22module GSL.Integration, 22module GSL.Integration,
23module GSL.Differentiation, 23module GSL.Differentiation,
24module GSL.Special, 24--module GSL.Special,
25module GSL.Fourier, 25module GSL.Fourier,
26module GSL.Polynomials, 26module GSL.Polynomials,
27module GSL.Minimization, 27module GSL.Minimization,
diff --git a/lib/GSL/Compat.hs b/lib/GSL/Compat.hs
index 809a1f5..1d6f7b9 100644
--- a/lib/GSL/Compat.hs
+++ b/lib/GSL/Compat.hs
@@ -38,7 +38,7 @@ adaptScalar f1 f2 f3 x y
38 | dim y == 1 = f3 x (y@>0) 38 | dim y == 1 = f3 x (y@>0)
39 | otherwise = f2 x y 39 | otherwise = f2 x y
40 40
41liftMatrix2' :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 41liftMatrix2' :: (Field t, Field a, Field b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
42liftMatrix2' f m1 m2 | compat' m1 m2 = reshape (max (cols m1) (cols m2)) (f (cdat m1) (cdat m2)) 42liftMatrix2' f m1 m2 | compat' m1 m2 = reshape (max (cols m1) (cols m2)) (f (cdat m1) (cdat m2))
43 | otherwise = error "nonconformant matrices in liftMatrix2'" 43 | otherwise = error "nonconformant matrices in liftMatrix2'"
44 44
@@ -63,6 +63,7 @@ instance (Eq a, Field a) => Eq (Matrix a) where
63 63
64instance (Field a, Linear Vector a) => Num (Matrix a) where 64instance (Field a, Linear Vector a) => Num (Matrix a) where
65 (+) = liftMatrix2' (+) 65 (+) = liftMatrix2' (+)
66 (-) = liftMatrix2' (-)
66 negate = liftMatrix negate 67 negate = liftMatrix negate
67 (*) = liftMatrix2' (*) 68 (*) = liftMatrix2' (*)
68 signum = liftMatrix signum 69 signum = liftMatrix signum
diff --git a/lib/GSL/Matrix.hs b/lib/GSL/Matrix.hs
index 26c5e2a..15710df 100644
--- a/lib/GSL/Matrix.hs
+++ b/lib/GSL/Matrix.hs
@@ -46,13 +46,14 @@ import Foreign.C.String
46 46
47-} 47-}
48eigSg :: Matrix Double -> (Vector Double, Matrix Double) 48eigSg :: Matrix Double -> (Vector Double, Matrix Double)
49eigSg (m@M {rows = r}) 49eigSg m
50 | r == 1 = (fromList [cdat m `at` 0], singleton 1) 50 | r == 1 = (fromList [cdat m `at` 0], singleton 1)
51 | otherwise = unsafePerformIO $ do 51 | otherwise = unsafePerformIO $ do
52 l <- createVector r 52 l <- createVector r
53 v <- createMatrix RowMajor r r 53 v <- createMatrix RowMajor r r
54 c_eigS // mat cdat m // vec l // mat dat v // check "eigSg" [cdat m] 54 c_eigS // mat cdat m // vec l // mat dat v // check "eigSg" [cdat m]
55 return (l,v) 55 return (l,v)
56 where r = rows m
56foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM 57foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM
57 58
58------------------------------------------------------------------ 59------------------------------------------------------------------
@@ -76,13 +77,14 @@ foreign import ccall "gsl-aux.h eigensystemR" c_eigS :: TMVM
76 77
77-} 78-}
78eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double)) 79eigHg :: Matrix (Complex Double)-> (Vector Double, Matrix (Complex Double))
79eigHg (m@M {rows = r}) 80eigHg m
80 | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1) 81 | r == 1 = (fromList [realPart $ cdat m `at` 0], singleton 1)
81 | otherwise = unsafePerformIO $ do 82 | otherwise = unsafePerformIO $ do
82 l <- createVector r 83 l <- createVector r
83 v <- createMatrix RowMajor r r 84 v <- createMatrix RowMajor r r
84 c_eigH // mat cdat m // vec l // mat dat v // check "eigHg" [cdat m] 85 c_eigH // mat cdat m // vec l // mat dat v // check "eigHg" [cdat m]
85 return (l,v) 86 return (l,v)
87 where r = rows m
86foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM 88foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM
87 89
88 90
@@ -108,16 +110,18 @@ foreign import ccall "gsl-aux.h eigensystemC" c_eigH :: TCMVCM
108 110
109-} 111-}
110svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) 112svdg :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
111svdg x@M {rows = r, cols = c} = if r>=c 113svdg x = if rows x >= cols x
112 then svd' x 114 then svd' x
113 else (v, s, u) where (u,s,v) = svd' (trans x) 115 else (v, s, u) where (u,s,v) = svd' (trans x)
114 116
115svd' x@M {rows = r, cols = c} = unsafePerformIO $ do 117svd' x = unsafePerformIO $ do
116 u <- createMatrix RowMajor r c 118 u <- createMatrix RowMajor r c
117 s <- createVector c 119 s <- createVector c
118 v <- createMatrix RowMajor c c 120 v <- createMatrix RowMajor c c
119 c_svd // mat cdat x // mat dat u // vec s // mat dat v // check "svdg" [cdat x] 121 c_svd // mat cdat x // mat dat u // vec s // mat dat v // check "svdg" [cdat x]
120 return (u,s,v) 122 return (u,s,v)
123 where r = rows x
124 c = cols x
121foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM 125foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM
122 126
123{- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/. 127{- | QR decomposition of a real matrix using /gsl_linalg_QR_decomp/ and /gsl_linalg_QR_unpack/.
@@ -138,11 +142,13 @@ foreign import ccall "gsl-aux.h svd" c_svd :: TMMVM
138 142
139-} 143-}
140qr :: Matrix Double -> (Matrix Double, Matrix Double) 144qr :: Matrix Double -> (Matrix Double, Matrix Double)
141qr x@M {rows = r, cols = c} = unsafePerformIO $ do 145qr x = unsafePerformIO $ do
142 q <- createMatrix RowMajor r r 146 q <- createMatrix RowMajor r r
143 rot <- createMatrix RowMajor r c 147 rot <- createMatrix RowMajor r c
144 c_qr // mat cdat x // mat dat q // mat dat rot // check "qr" [cdat x] 148 c_qr // mat cdat x // mat dat q // mat dat rot // check "qr" [cdat x]
145 return (q,rot) 149 return (q,rot)
150 where r = rows x
151 c = cols x
146foreign import ccall "gsl-aux.h QR" c_qr :: TMMM 152foreign import ccall "gsl-aux.h QR" c_qr :: TMMM
147 153
148{- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/. 154{- | Cholesky decomposition of a symmetric positive definite real matrix using /gsl_linalg_cholesky_decomp/.
@@ -159,11 +165,11 @@ foreign import ccall "gsl-aux.h QR" c_qr :: TMMM
159 165
160-} 166-}
161chol :: Matrix Double -> Matrix Double 167chol :: Matrix Double -> Matrix Double
162--chol x@(M r _ p) = createM [p] "chol" r r $ m c_chol x 168chol x = unsafePerformIO $ do
163chol x@M {rows = r} = unsafePerformIO $ do
164 res <- createMatrix RowMajor r r 169 res <- createMatrix RowMajor r r
165 c_chol // mat cdat x // mat dat res // check "chol" [cdat x] 170 c_chol // mat cdat x // mat dat res // check "chol" [cdat x]
166 return res 171 return res
172 where r = rows x
167foreign import ccall "gsl-aux.h chol" c_chol :: TMM 173foreign import ccall "gsl-aux.h chol" c_chol :: TMM
168 174
169-------------------------------------------------------- 175--------------------------------------------------------
@@ -171,43 +177,53 @@ foreign import ccall "gsl-aux.h chol" c_chol :: TMM
171{- -| efficient multiplication by the inverse of a matrix (for real matrices) 177{- -| efficient multiplication by the inverse of a matrix (for real matrices)
172-} 178-}
173luSolveR :: Matrix Double -> Matrix Double -> Matrix Double 179luSolveR :: Matrix Double -> Matrix Double -> Matrix Double
174luSolveR a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) 180luSolveR a b
175 | n1==n2 && n1==r = unsafePerformIO $ do 181 | n1==n2 && n1==r = unsafePerformIO $ do
176 s <- createMatrix RowMajor r c 182 s <- createMatrix RowMajor r c
177 c_luSolveR // mat cdat a // mat cdat b // mat dat s // check "luSolveR" [cdat a, cdat b] 183 c_luSolveR // mat cdat a // mat cdat b // mat dat s // check "luSolveR" [cdat a, cdat b]
178 return s 184 return s
179 | otherwise = error "luSolveR of nonsquare matrix" 185 | otherwise = error "luSolveR of nonsquare matrix"
180 186 where n1 = rows a
187 n2 = cols a
188 r = rows b
189 c = cols b
181foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM 190foreign import ccall "gsl-aux.h luSolveR" c_luSolveR :: TMMM
182 191
183{- -| efficient multiplication by the inverse of a matrix (for complex matrices). 192{- -| efficient multiplication by the inverse of a matrix (for complex matrices).
184-} 193-}
185luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 194luSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
186luSolveC a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) 195luSolveC a b
187 | n1==n2 && n1==r = unsafePerformIO $ do 196 | n1==n2 && n1==r = unsafePerformIO $ do
188 s <- createMatrix RowMajor r c 197 s <- createMatrix RowMajor r c
189 c_luSolveC // mat cdat a // mat cdat b // mat dat s // check "luSolveC" [cdat a, cdat b] 198 c_luSolveC // mat cdat a // mat cdat b // mat dat s // check "luSolveC" [cdat a, cdat b]
190 return s 199 return s
191 | otherwise = error "luSolveC of nonsquare matrix" 200 | otherwise = error "luSolveC of nonsquare matrix"
192 201 where n1 = rows a
202 n2 = cols a
203 r = rows b
204 c = cols b
193foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM 205foreign import ccall "gsl-aux.h luSolveC" c_luSolveC :: TCMCMCM
194 206
195{- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign) 207{- | lu decomposition of real matrix (packed as a vector including l, u, the permutation and sign)
196-} 208-}
197luRaux :: Matrix Double -> Vector Double 209luRaux :: Matrix Double -> Vector Double
198luRaux x@M {rows = r, cols = c} = unsafePerformIO $ do 210luRaux x = unsafePerformIO $ do
199 res <- createVector (r*r+r+1) 211 res <- createVector (r*r+r+1)
200 c_luRaux // mat cdat x // vec res // check "luRaux" [cdat x] 212 c_luRaux // mat cdat x // vec res // check "luRaux" [cdat x]
201 return res 213 return res
214 where r = rows x
215 c = cols x
202foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV 216foreign import ccall "gsl-aux.h luRaux" c_luRaux :: TMV
203 217
204{- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign) 218{- | lu decomposition of complex matrix (packed as a vector including l, u, the permutation and sign)
205-} 219-}
206luCaux :: Matrix (Complex Double) -> Vector (Complex Double) 220luCaux :: Matrix (Complex Double) -> Vector (Complex Double)
207luCaux x@M {rows = r, cols = c} = unsafePerformIO $ do 221luCaux x = unsafePerformIO $ do
208 res <- createVector (r*r+r+1) 222 res <- createVector (r*r+r+1)
209 c_luCaux // mat cdat x // vec res // check "luCaux" [cdat x] 223 c_luCaux // mat cdat x // vec res // check "luCaux" [cdat x]
210 return res 224 return res
225 where r = rows x
226 c = cols x
211foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV 227foreign import ccall "gsl-aux.h luCaux" c_luCaux :: TCMCV
212 228
213{- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in <http://www.gnu.org/software/gsl/manual/gsl-ref_13.html#SEC223>. 229{- | The LU decomposition of a square matrix. Is based on /gsl_linalg_LU_decomp/ and /gsl_linalg_complex_LU_decomp/ as described in <http://www.gnu.org/software/gsl/manual/gsl-ref_13.html#SEC223>.
diff --git a/lib/LAPACK.hs b/lib/LAPACK.hs
index b0008b1..ba72681 100644
--- a/lib/LAPACK.hs
+++ b/lib/LAPACK.hs
@@ -37,16 +37,19 @@ foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM
37-- 37--
38-- @(u,s,v)=svdR m@ so that @m=u \<\> s \<\> 'trans' v@. 38-- @(u,s,v)=svdR m@ so that @m=u \<\> s \<\> 'trans' v@.
39svdR :: Matrix Double -> (Matrix Double, Matrix Double, Matrix Double) 39svdR :: Matrix Double -> (Matrix Double, Matrix Double, Matrix Double)
40svdR x@M {rows = r, cols = c} = (u, diagRect s r c, v) where (u,s,v) = svdR' x 40svdR x = (u, diagRect s r c, v) where (u,s,v) = svdR' x
41 r = rows x
42 c = cols x
41 43
42svdR' :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) 44svdR' :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
43svdR' x@M {rows = r, cols = c} = unsafePerformIO $ do 45svdR' x = unsafePerformIO $ do
44 u <- createMatrix ColumnMajor r r 46 u <- createMatrix ColumnMajor r r
45 s <- createVector (min r c) 47 s <- createVector (min r c)
46 v <- createMatrix ColumnMajor c c 48 v <- createMatrix ColumnMajor c c
47 dgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdR" [fdat x] 49 dgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdR" [fdat x]
48 return (u,s,trans v) 50 return (u,s,trans v)
49 51 where r = rows x
52 c = cols x
50----------------------------------------------------------------------------- 53-----------------------------------------------------------------------------
51foreign import ccall "LAPACK/lapack-aux.h svd_l_Rdd" dgesdd :: TMMVM 54foreign import ccall "LAPACK/lapack-aux.h svd_l_Rdd" dgesdd :: TMMVM
52 55
@@ -54,15 +57,19 @@ foreign import ccall "LAPACK/lapack-aux.h svd_l_Rdd" dgesdd :: TMMVM
54-- 57--
55-- @(u,s,v)=svdRdd m@ so that @m=u \<\> s \<\> 'trans' v@. 58-- @(u,s,v)=svdRdd m@ so that @m=u \<\> s \<\> 'trans' v@.
56svdRdd :: Matrix Double -> (Matrix Double, Matrix Double , Matrix Double) 59svdRdd :: Matrix Double -> (Matrix Double, Matrix Double , Matrix Double)
57svdRdd x@M {rows = r, cols = c} = (u, diagRect s r c, v) where (u,s,v) = svdRdd' x 60svdRdd x = (u, diagRect s r c, v) where (u,s,v) = svdRdd' x
61 r = rows x
62 c = cols x
58 63
59svdRdd' :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double) 64svdRdd' :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
60svdRdd' x@M {rows = r, cols = c} = unsafePerformIO $ do 65svdRdd' x = unsafePerformIO $ do
61 u <- createMatrix ColumnMajor r r 66 u <- createMatrix ColumnMajor r r
62 s <- createVector (min r c) 67 s <- createVector (min r c)
63 v <- createMatrix ColumnMajor c c 68 v <- createMatrix ColumnMajor c c
64 dgesdd // mat fdat x // mat dat u // vec s // mat dat v // check "svdRdd" [fdat x] 69 dgesdd // mat fdat x // mat dat u // vec s // mat dat v // check "svdRdd" [fdat x]
65 return (u,s,trans v) 70 return (u,s,trans v)
71 where r = rows x
72 c = cols x
66 73
67----------------------------------------------------------------------------- 74-----------------------------------------------------------------------------
68foreign import ccall "LAPACK/lapack-aux.h svd_l_C" zgesvd :: TCMCMVCM 75foreign import ccall "LAPACK/lapack-aux.h svd_l_C" zgesvd :: TCMCMVCM
@@ -71,15 +78,20 @@ foreign import ccall "LAPACK/lapack-aux.h svd_l_C" zgesvd :: TCMCMVCM
71-- 78--
72-- @(u,s,v)=svdC m@ so that @m=u \<\> s \<\> 'trans' v@. 79-- @(u,s,v)=svdC m@ so that @m=u \<\> s \<\> 'trans' v@.
73svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix Double, Matrix (Complex Double)) 80svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix Double, Matrix (Complex Double))
74svdC x@M {rows = r, cols = c} = (u, diagRect s r c, v) where (u,s,v) = svdC' x 81svdC x = (u, diagRect s r c, v) where (u,s,v) = svdC' x
82 r = rows x
83 c = cols x
75 84
76svdC' :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double)) 85svdC' :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double))
77svdC' x@M {rows = r, cols = c} = unsafePerformIO $ do 86svdC' x = unsafePerformIO $ do
78 u <- createMatrix ColumnMajor r r 87 u <- createMatrix ColumnMajor r r
79 s <- createVector (min r c) 88 s <- createVector (min r c)
80 v <- createMatrix ColumnMajor c c 89 v <- createMatrix ColumnMajor c c
81 zgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdC" [fdat x] 90 zgesvd // mat fdat x // mat dat u // vec s // mat dat v // check "svdC" [fdat x]
82 return (u,s,trans v) 91 return (u,s,trans v)
92 where r = rows x
93 c = cols x
94
83 95
84----------------------------------------------------------------------------- 96-----------------------------------------------------------------------------
85foreign import ccall "LAPACK/lapack-aux.h eig_l_C" zgeev :: TCMCMCVCM 97foreign import ccall "LAPACK/lapack-aux.h eig_l_C" zgeev :: TCMCMCVCM
@@ -91,7 +103,7 @@ foreign import ccall "LAPACK/lapack-aux.h eig_l_C" zgeev :: TCMCMCVCM
91-- The eigenvectors are the columns of v. 103-- The eigenvectors are the columns of v.
92-- The eigenvalues are not sorted. 104-- The eigenvalues are not sorted.
93eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Double)) 105eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Double))
94eigC (m@M {rows = r}) 106eigC m
95 | r == 1 = (fromList [cdat m `at` 0], singleton 1) 107 | r == 1 = (fromList [cdat m `at` 0], singleton 1)
96 | otherwise = unsafePerformIO $ do 108 | otherwise = unsafePerformIO $ do
97 l <- createVector r 109 l <- createVector r
@@ -99,6 +111,7 @@ eigC (m@M {rows = r})
99 dummy <- createMatrix ColumnMajor 1 1 111 dummy <- createMatrix ColumnMajor 1 1
100 zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m] 112 zgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigC" [fdat m]
101 return (l,v) 113 return (l,v)
114 where r = rows m
102 115
103----------------------------------------------------------------------------- 116-----------------------------------------------------------------------------
104foreign import ccall "LAPACK/lapack-aux.h eig_l_R" dgeev :: TMMCVM 117foreign import ccall "LAPACK/lapack-aux.h eig_l_R" dgeev :: TMMCVM
@@ -110,14 +123,15 @@ foreign import ccall "LAPACK/lapack-aux.h eig_l_R" dgeev :: TMMCVM
110-- The eigenvectors are the columns of v. 123-- The eigenvectors are the columns of v.
111-- The eigenvalues are not sorted. 124-- The eigenvalues are not sorted.
112eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double)) 125eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double))
113eigR (m@M {rows = r}) = (s', v'') 126eigR m = (s', v'')
114 where (s,v) = eigRaux m 127 where (s,v) = eigRaux m
115 s' = toComplex (subVector 0 r (asReal s), subVector r r (asReal s)) 128 s' = toComplex (subVector 0 r (asReal s), subVector r r (asReal s))
116 v' = toRows $ trans v 129 v' = toRows $ trans v
117 v'' = fromColumns $ fixeig (toList s') v' 130 v'' = fromColumns $ fixeig (toList s') v'
131 r = rows m
118 132
119eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double) 133eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double)
120eigRaux (m@M {rows = r}) 134eigRaux m
121 | r == 1 = (fromList [(cdat m `at` 0):+0], singleton 1) 135 | r == 1 = (fromList [(cdat m `at` 0):+0], singleton 1)
122 | otherwise = unsafePerformIO $ do 136 | otherwise = unsafePerformIO $ do
123 l <- createVector r 137 l <- createVector r
@@ -125,6 +139,7 @@ eigRaux (m@M {rows = r})
125 dummy <- createMatrix ColumnMajor 1 1 139 dummy <- createMatrix ColumnMajor 1 1
126 dgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigR" [fdat m] 140 dgeev // mat fdat m // mat dat dummy // vec l // mat dat v // check "eigR" [fdat m]
127 return (l,v) 141 return (l,v)
142 where r = rows m
128 143
129fixeig [] _ = [] 144fixeig [] _ = []
130fixeig [r] [v] = [comp v] 145fixeig [r] [v] = [comp v]
@@ -148,13 +163,14 @@ eigS m = (s', fliprl v)
148 where (s,v) = eigS' m 163 where (s,v) = eigS' m
149 s' = fromList . reverse . toList $ s 164 s' = fromList . reverse . toList $ s
150 165
151eigS' (m@M {rows = r}) 166eigS' m
152 | r == 1 = (fromList [cdat m `at` 0], singleton 1) 167 | r == 1 = (fromList [cdat m `at` 0], singleton 1)
153 | otherwise = unsafePerformIO $ do 168 | otherwise = unsafePerformIO $ do
154 l <- createVector r 169 l <- createVector r
155 v <- createMatrix ColumnMajor r r 170 v <- createMatrix ColumnMajor r r
156 dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m] 171 dsyev // mat fdat m // vec l // mat dat v // check "eigS" [fdat m]
157 return (l,v) 172 return (l,v)
173 where r = rows m
158 174
159----------------------------------------------------------------------------- 175-----------------------------------------------------------------------------
160foreign import ccall "LAPACK/lapack-aux.h eig_l_H" zheev :: TCMVCM 176foreign import ccall "LAPACK/lapack-aux.h eig_l_H" zheev :: TCMVCM
@@ -170,37 +186,46 @@ eigH m = (s', fliprl v)
170 where (s,v) = eigH' m 186 where (s,v) = eigH' m
171 s' = fromList . reverse . toList $ s 187 s' = fromList . reverse . toList $ s
172 188
173eigH' (m@M {rows = r}) 189eigH' m
174 | r == 1 = (fromList [realPart (cdat m `at` 0)], singleton 1) 190 | r == 1 = (fromList [realPart (cdat m `at` 0)], singleton 1)
175 | otherwise = unsafePerformIO $ do 191 | otherwise = unsafePerformIO $ do
176 l <- createVector r 192 l <- createVector r
177 v <- createMatrix ColumnMajor r r 193 v <- createMatrix ColumnMajor r r
178 zheev // mat fdat m // vec l // mat dat v // check "eigH" [fdat m] 194 zheev // mat fdat m // vec l // mat dat v // check "eigH" [fdat m]
179 return (l,v) 195 return (l,v)
196 where r = rows m
180 197
181----------------------------------------------------------------------------- 198-----------------------------------------------------------------------------
182foreign import ccall "LAPACK/lapack-aux.h linearSolveR_l" dgesv :: TMMM 199foreign import ccall "LAPACK/lapack-aux.h linearSolveR_l" dgesv :: TMMM
183 200
184-- | Wrapper for LAPACK's /dgesv/, which solves a general real linear system (for several right-hand sides) internally using the lu decomposition. 201-- | Wrapper for LAPACK's /dgesv/, which solves a general real linear system (for several right-hand sides) internally using the lu decomposition.
185linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double 202linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double
186linearSolveR a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) 203linearSolveR a b
187 | n1==n2 && n1==r = unsafePerformIO $ do 204 | n1==n2 && n1==r = unsafePerformIO $ do
188 s <- createMatrix ColumnMajor r c 205 s <- createMatrix ColumnMajor r c
189 dgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveR" [fdat a, fdat b] 206 dgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveR" [fdat a, fdat b]
190 return s 207 return s
191 | otherwise = error "linearSolveR of nonsquare matrix" 208 | otherwise = error "linearSolveR of nonsquare matrix"
209 where n1 = rows a
210 n2 = cols a
211 r = rows b
212 c = cols b
192 213
193----------------------------------------------------------------------------- 214-----------------------------------------------------------------------------
194foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM 215foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM
195 216
196-- | Wrapper for LAPACK's /zgesv/, which solves a general complex linear system (for several right-hand sides) internally using the lu decomposition. 217-- | Wrapper for LAPACK's /zgesv/, which solves a general complex linear system (for several right-hand sides) internally using the lu decomposition.
197linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 218linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
198linearSolveC a@(M {rows = n1, cols = n2}) b@(M {rows = r, cols = c}) 219linearSolveC a b
199 | n1==n2 && n1==r = unsafePerformIO $ do 220 | n1==n2 && n1==r = unsafePerformIO $ do
200 s <- createMatrix ColumnMajor r c 221 s <- createMatrix ColumnMajor r c
201 zgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveC" [fdat a, fdat b] 222 zgesv // mat fdat a // mat fdat b // mat dat s // check "linearSolveC" [fdat a, fdat b]
202 return s 223 return s
203 | otherwise = error "linearSolveC of nonsquare matrix" 224 | otherwise = error "linearSolveC of nonsquare matrix"
225 where n1 = rows a
226 n2 = cols a
227 r = rows b
228 c = cols b
204 229
205----------------------------------------------------------------------------------- 230-----------------------------------------------------------------------------------
206foreign import ccall "LAPACK/lapack-aux.h linearSolveLSR_l" dgels :: TMMM 231foreign import ccall "LAPACK/lapack-aux.h linearSolveLSR_l" dgels :: TMMM
@@ -209,10 +234,13 @@ foreign import ccall "LAPACK/lapack-aux.h linearSolveLSR_l" dgels :: TMMM
209linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double 234linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double
210linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSR_l a b 235linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSR_l a b
211 236
212linearSolveLSR_l a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do 237linearSolveLSR_l a b = unsafePerformIO $ do
213 r <- createMatrix ColumnMajor (max m n) nrhs 238 r <- createMatrix ColumnMajor (max m n) nrhs
214 dgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSR" [fdat a, fdat b] 239 dgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSR" [fdat a, fdat b]
215 return r 240 return r
241 where m = rows a
242 n = cols a
243 nrhs = cols b
216 244
217----------------------------------------------------------------------------------- 245-----------------------------------------------------------------------------------
218foreign import ccall "LAPACK/lapack-aux.h linearSolveLSC_l" zgels :: TCMCMCM 246foreign import ccall "LAPACK/lapack-aux.h linearSolveLSC_l" zgels :: TCMCMCM
@@ -221,10 +249,13 @@ foreign import ccall "LAPACK/lapack-aux.h linearSolveLSC_l" zgels :: TCMCMCM
221linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double) 249linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
222linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSC_l a b 250linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $ linearSolveLSC_l a b
223 251
224linearSolveLSC_l a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do 252linearSolveLSC_l a b = unsafePerformIO $ do
225 r <- createMatrix ColumnMajor (max m n) nrhs 253 r <- createMatrix ColumnMajor (max m n) nrhs
226 zgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSC" [fdat a, fdat b] 254 zgels // mat fdat a // mat fdat b // mat dat r // check "linearSolveLSC" [fdat a, fdat b]
227 return r 255 return r
256 where m = rows a
257 n = cols a
258 nrhs = cols b
228 259
229----------------------------------------------------------------------------------- 260-----------------------------------------------------------------------------------
230foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDR_l" dgelss :: Double -> TMMM 261foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDR_l" dgelss :: Double -> TMMM
@@ -237,10 +268,13 @@ linearSolveSVDR :: Maybe Double -- ^ rcond
237linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDR_l rcond a b 268linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDR_l rcond a b
238linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) a b 269linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) a b
239 270
240linearSolveSVDR_l rcond a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do 271linearSolveSVDR_l rcond a b = unsafePerformIO $ do
241 r <- createMatrix ColumnMajor (max m n) nrhs 272 r <- createMatrix ColumnMajor (max m n) nrhs
242 dgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDR" [fdat a, fdat b] 273 dgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDR" [fdat a, fdat b]
243 return r 274 return r
275 where m = rows a
276 n = cols a
277 nrhs = cols b
244 278
245----------------------------------------------------------------------------------- 279-----------------------------------------------------------------------------------
246foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDC_l" zgelss :: Double -> TCMCMCM 280foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDC_l" zgelss :: Double -> TCMCMCM
@@ -253,8 +287,11 @@ linearSolveSVDC :: Maybe Double -- ^ rcond
253linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDC_l rcond a b 287linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $ linearSolveSVDC_l rcond a b
254linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) a b 288linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) a b
255 289
256linearSolveSVDC_l rcond a@(M {rows = m, cols = n}) b@(M {cols = nrhs}) = unsafePerformIO $ do 290linearSolveSVDC_l rcond a b = unsafePerformIO $ do
257 r <- createMatrix ColumnMajor (max m n) nrhs 291 r <- createMatrix ColumnMajor (max m n) nrhs
258 zgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDC" [fdat a, fdat b] 292 zgelss rcond // mat fdat a // mat fdat b // mat dat r // check "linearSolveSVDC" [fdat a, fdat b]
259 return r 293 return r
294 where m = rows a
295 n = cols a
296 nrhs = cols b
260 297