summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra/LAPACK.hs
blob: 8bc249261c9c7bb6b9fdd4d26a4b30d85166e303 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
{-# OPTIONS_GHC #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.LinearAlgebra.LAPACK
-- Copyright   :  (c) Alberto Ruiz 2006-7
-- License     :  GPL-style
-- 
-- Maintainer  :  Alberto Ruiz (aruiz at um dot es)
-- Stability   :  provisional
-- Portability :  portable (uses FFI)
--
-- Wrappers for a few LAPACK functions (<http://www.netlib.org/lapack>).
--
-----------------------------------------------------------------------------

module Numeric.LinearAlgebra.LAPACK (
    svdR, svdRdd, svdC,
    eigC, eigR, eigS, eigH, eigS', eigH',
    linearSolveR, linearSolveC,
    linearSolveLSR, linearSolveLSC,
    linearSolveSVDR, linearSolveSVDC,
    luR, luC, lusR, lusC,
    cholS, cholH,
    qrR, qrC,
    hessR, hessC,
    schurR, schurC
) where

import Data.Packed.Internal
import Data.Packed.Internal.Vector
import Data.Packed.Internal.Matrix
import Data.Packed.Vector
import Data.Packed.Matrix
import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale))
import Complex
import Foreign
import Foreign.C.Types (CInt)

-----------------------------------------------------------------------------
-- Bindings to the C singular-value routines declared in LAPACK/lapack-aux.h.
-- The type aliases (TMMVM, TCMCMVCM) are defined elsewhere, not in the part of
-- the file shown here.
-- C routine svd_l_R, passed to 'svdAux' by 'svdR'.
foreign import ccall "LAPACK/lapack-aux.h svd_l_R" dgesvd :: TMMVM
-- C routine svd_l_C, passed to 'svdAux' by 'svdC'.
foreign import ccall "LAPACK/lapack-aux.h svd_l_C" zgesvd :: TCMCMVCM
-- C routine svd_l_Rdd, passed to 'svdAux' by 'svdRdd'.
foreign import ccall "LAPACK/lapack-aux.h svd_l_Rdd" dgesdd :: TMMVM

-- | Full singular value decomposition of a real matrix, computed by LAPACK's /dgesvd/.
--
-- @(u,s,v)=full svdR m@ so that @m=u \<\> s \<\> 'trans' v@.
--
-- The argument is converted with 'fmat' before it is handed to the C routine.
svdR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
svdR m = svdAux dgesvd "svdR" (fmat m)

-- | Full singular value decomposition of a real matrix, computed through the
-- /dgesdd/ binding (C routine @svd_l_Rdd@) instead of /dgesvd/.
--
-- @(u,s,v)=full svdRdd m@ so that @m=u \<\> s \<\> 'trans' v@.
--
-- The argument is converted with 'fmat' before it is handed to the C routine.
svdRdd :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
svdRdd m = svdAux dgesdd "svdRdd" (fmat m)

-- | Full singular value decomposition of a complex matrix, computed by LAPACK's /zgesvd/.
--
-- @(u,s,v)=full svdC m@ so that @m=u \<\> comp s \<\> 'trans' v@.
--
-- The singular values are returned as a real vector; the argument is
-- converted with 'fmat' before it is handed to the C routine.
svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double))
svdC m = svdAux zgesvd "svdC" (fmat m)

-- Shared driver for the SVD wrappers.
--
-- For an input with nr rows and nc columns it allocates column-major outputs
-- u (nr x nr), s (length min nr nc) and v (nc x nc), runs the foreign routine
-- f on them through app4, and returns (u, s, trans v).
-- st is the label forwarded to app4 (presumably used in its error reports).
svdAux f st x = unsafePerformIO $ do
    let nr = rows x
        nc = cols x
    leftVecs  <- createMatrix ColumnMajor nr nr
    singVals  <- createVector (min nr nc)
    rightVecs <- createMatrix ColumnMajor nc nc
    app4 f mat x mat leftVecs vec singVals mat rightVecs st
    -- the matrix written by the C routine is transposed before being returned
    return (leftVecs, singVals, trans rightVecs)

-----------------------------------------------------------------------------
-- Shared driver for the general eigen solvers ('eigC' uses it with zgeev).
--
-- A matrix with a single row is answered directly: its only entry is the
-- eigenvalue and @singleton 1@ the eigenvector matrix.  Otherwise it allocates
-- an eigenvalue vector of length n and an n x n eigenvector matrix and runs
-- the foreign routine f on them through app4.
-- st is the label forwarded to app4 (presumably used in its error reports).
eigAux f st m
    | n == 1 = (fromList [flatten m `at` 0], singleton 1)
    | otherwise = unsafePerformIO $ do
        eigvals <- createVector n
        eigvecs <- createMatrix ColumnMajor n n
        -- 1x1 placeholder filling the second matrix slot of the C signature
        -- (presumably the unused left eigenvectors -- TODO confirm)
        unused <- createMatrix ColumnMajor 1 1
        app4 f mat m mat unused vec eigvals mat eigvecs st
        return (eigvals, eigvecs)
  where n = rows m


-- Bindings to the C eigenvalue routines declared in LAPACK/lapack-aux.h.
-- The type aliases on the right are defined elsewhere, not in the part of the
-- file shown here.
-- C routine eig_l_C, passed to 'eigAux' by 'eigC'.
foreign import ccall "LAPACK/lapack-aux.h eig_l_C" zgeev :: TCMCMCVCM
-- C routine eig_l_R, called by 'eigRaux'.
foreign import ccall "LAPACK/lapack-aux.h eig_l_R" dgeev :: TMMCVM
-- C routine eig_l_S, called by eigS'.
foreign import ccall "LAPACK/lapack-aux.h eig_l_S" dsyev :: TMVM
-- C routine eig_l_H, called by eigH'.
foreign import ccall "LAPACK/lapack-aux.h eig_l_H" zheev :: TCMVCM

-- | Eigenvalues and right eigenvectors of a general complex matrix, computed by LAPACK's /zgeev/:
--
-- if @(l,v)=eigC m@ then @m \<\> v = v \<\> diag l@.
--
-- The eigenvectors are the columns of v.
-- The eigenvalues are not sorted.
eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Double))
eigC m = eigAux zgeev "eigC" (fmat m)

-----------------------------------------------------------------------------

-- | Eigenvalues and right eigenvectors of a general real matrix, computed by LAPACK's /dgeev/:
--
-- if @(l,v)=eigR m@ then @m \<\> v = v \<\> diag l@.
--
-- The eigenvectors are the columns of v.
-- The eigenvalues are not sorted.
eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double))
eigR m = (eigvals, eigvecs)
  where
    n = rows m
    (packed, rawVecs) = eigRaux (fmat m)
    -- the packed result is read as 2n reals: the first n are taken as the
    -- real parts and the last n as the imaginary parts of the eigenvalues
    reals = asReal packed
    eigvals = toComplex (subVector 0 n reals, subVector n n reals)
    -- 'fixeig' rebuilds complex eigenvectors from the real columns
    realCols = toRows (trans rawVecs)
    eigvecs = fromColumns (fixeig (toList eigvals) realCols)

-- Raw call into dgeev for 'eigR'.
--
-- Returns the eigenvalue vector exactly as the C routine filled it, plus the
-- real n x n matrix it wrote; 'eigR' post-processes both.  A matrix with a
-- single row is answered directly without calling the C routine.
eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double)
eigRaux m
    | n == 1 = (fromList [(flatten m `at` 0):+0], singleton 1)
    | otherwise = unsafePerformIO $ do
        packed <- createVector n
        vecs <- createMatrix ColumnMajor n n
        -- 1x1 placeholder filling the second matrix slot of the C signature
        unused <- createMatrix ColumnMajor 1 1
        app4 dgeev mat m mat unused vec packed mat vecs "eigR"
        return (packed, vecs)
  where n = rows m

-- Rebuilds complex eigenvectors from the real columns produced for 'eigR'.
--
-- The first argument is the list of eigenvalues, the second the matching list
-- of real column vectors.  Two consecutive eigenvalues forming a complex
-- conjugate pair consume two columns (v1, v2) and yield v1 + i*v2 followed by
-- v1 - i*v2; any other eigenvalue yields its own column lifted with 'comp'.
--
-- A pair is only recognised when the imaginary part is non-zero.  Without
-- that check two equal real eigenvalues (0 == -0) would be mistaken for a
-- conjugate pair and their two real eigenvectors merged into complex
-- combinations.  LAPACK's dgeev documentation describes the paired storage
-- only for genuinely complex eigenvalues.
--
-- Any other combination of list lengths is reported with 'error'.
fixeig :: [Complex Double] -> [Vector Double] -> [Vector (Complex Double)]
fixeig  []  _ =  []
fixeig [_] [v] = [comp v]
fixeig ((r1:+i1):(r2:+i2):r) (v1:v2:vs)
    | i1 /= 0 && r1 == r2 && i1 == (-i2) = toComplex (v1,v2) : toComplex (v1,scale (-1) v2) : fixeig r vs
    | otherwise = comp v1 : fixeig ((r2:+i2):r) (v2:vs)
  where scale = vectorMapValR Scale
fixeig _ _ = error "fixeig with impossible inputs"

-----------------------------------------------------------------------------

-- | Eigenvalues and right eigenvectors of a symmetric real matrix, computed by LAPACK's /dsyev/:
--
-- if @(l,v)=eigS m@ then @m \<\> v = v \<\> diag l@.
--
-- The eigenvectors are the columns of v.
-- The eigenvalues are sorted in descending order (use eigS' for ascending order).
eigS :: Matrix Double -> (Vector Double, Matrix Double)
eigS m = (descending, fliprl vecs)
  where
    (ascending, vecs) = eigS' (fmat m)
    -- reverse the eigenvalues; 'fliprl' reorders the columns to match
    descending = fromList (reverse (toList ascending))

-- | Variant of 'eigS' that returns the eigenvalues and eigenvectors in the
-- order delivered by /dsyev/ ('eigS' documents this as ascending order).
--
-- The argument is converted with 'fmat' first, like in every other exported
-- wrapper of this module, so the result does not depend on how the caller's
-- matrix happens to be stored.  A matrix with a single row is answered
-- directly without calling the C routine.
eigS' :: Matrix Double -> (Vector Double, Matrix Double)
eigS' m0
    | r == 1 = (fromList [flatten m `at` 0], singleton 1)
    | otherwise = unsafePerformIO $ do
        l <- createVector r
        v <- createMatrix ColumnMajor r r
        app3 dsyev mat m vec l mat v "eigS"
        return (l,v)
  where m = fmat m0
        r = rows m

-----------------------------------------------------------------------------

-- | Eigenvalues and right eigenvectors of a hermitian complex matrix, computed by LAPACK's /zheev/:
--
-- if @(l,v)=eigH m@ then @m \<\> v = v \<\> diag l@.
--
-- The eigenvectors are the columns of v.
-- The eigenvalues are sorted in descending order (use eigH' for ascending order).
eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double))
eigH m = (descending, fliprl vecs)
  where
    (ascending, vecs) = eigH' (fmat m)
    -- reverse the eigenvalues; 'fliprl' reorders the columns to match
    descending = fromList (reverse (toList ascending))

-- | Variant of 'eigH' that returns the eigenvalues and eigenvectors in the
-- order delivered by /zheev/ ('eigH' documents this as ascending order).
--
-- The argument is converted with 'fmat' first, like in every other exported
-- wrapper of this module, so the result does not depend on how the caller's
-- matrix happens to be stored.  A matrix with a single row is answered
-- directly: the real part of its only entry is the eigenvalue.
eigH' :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double))
eigH' m0
    | r == 1 = (fromList [realPart (flatten m `at` 0)], singleton 1)
    | otherwise = unsafePerformIO $ do
        l <- createVector r
        v <- createMatrix ColumnMajor r r
        app3 zheev mat m vec l mat v "eigH"
        return (l,v)
  where m = fmat m0
        r = rows m

-----------------------------------------------------------------------------
-- Bindings to the C square-system solvers declared in LAPACK/lapack-aux.h.
-- C routine linearSolveR_l, passed to 'linearSolveSQAux' by 'linearSolveR'.
foreign import ccall "LAPACK/lapack-aux.h linearSolveR_l" dgesv :: TMMM
-- C routine linearSolveC_l, passed to 'linearSolveSQAux' by 'linearSolveC'.
foreign import ccall "LAPACK/lapack-aux.h linearSolveC_l" zgesv :: TCMCMCM

-- Shared driver for the square-system solvers.
--
-- f is the foreign routine, st the name of the calling wrapper (used in the
-- error messages and forwarded to app3), a the coefficient matrix and b the
-- right-hand sides stored as columns.  The result has the shape of b.
--
-- Calls 'error' when a is not square, and with a separate message when the
-- number of rows of b differs from the order of a.  Both cases used to be
-- reported as "nonsquare matrix", which was misleading for the second one.
linearSolveSQAux f st a b
    | n1 /= n2 = error $ st ++ " of nonsquare matrix"
    | n1 /= r = error $ st ++ ": coefficient matrix of order " ++ show n1
                           ++ " but right-hand side with " ++ show r ++ " rows"
    | otherwise = unsafePerformIO $ do
        s <- createMatrix ColumnMajor r c
        app3 f mat a mat b mat s st
        return s
  where n1 = rows a
        n2 = cols a
        r  = rows b
        c  = cols b

-- | Solves a general real linear system (for several right-hand sides) with
-- LAPACK's /dgesv/, which internally uses the lu decomposition.
linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double
linearSolveR a b = linearSolveSQAux dgesv "linearSolveR" coeffs rhs
  where coeffs = fmat a
        rhs    = fmat b

-- | Solves a general complex linear system (for several right-hand sides) with
-- LAPACK's /zgesv/, which internally uses the lu decomposition.
linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
linearSolveC a b = linearSolveSQAux zgesv "linearSolveC" coeffs rhs
  where coeffs = fmat a
        rhs    = fmat b

-----------------------------------------------------------------------------------
-- Bindings to the C least-squares solvers declared in LAPACK/lapack-aux.h.
-- C routine linearSolveLSR_l, passed to 'linearSolveAux' by 'linearSolveLSR'.
foreign import ccall "LAPACK/lapack-aux.h linearSolveLSR_l" dgels :: TMMM
-- C routine linearSolveLSC_l, passed to 'linearSolveAux' by 'linearSolveLSC'.
foreign import ccall "LAPACK/lapack-aux.h linearSolveLSC_l" zgels :: TCMCMCM
-- C routine linearSolveSVDR_l; the leading Double is the rcond value that
-- 'linearSolveSVDR' supplies before passing it to 'linearSolveAux'.
foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDR_l" dgelss :: Double -> TMMM
-- C routine linearSolveSVDC_l; the leading Double is the rcond value that
-- 'linearSolveSVDC' supplies before passing it to 'linearSolveAux'.
foreign import ccall "LAPACK/lapack-aux.h linearSolveSVDC_l" zgelss :: Double -> TCMCMCM

-- Shared driver for the least-squares solvers.
--
-- Allocates a column-major result with max (rows a) (cols a) rows and one
-- column per column of b, runs the foreign routine f on a, b and the result
-- through app3, and returns the result untrimmed (the callers cut it down
-- with subMatrix).  st is the label forwarded to app3.
linearSolveAux f st a b = unsafePerformIO $ do
    let paddedRows = max (rows a) (cols a)
    out <- createMatrix ColumnMajor paddedRows (cols b)
    app3 f mat a mat b mat out st
    return out

-- | Wrapper for LAPACK's /dgels/, which obtains the least squared error
-- solution of an overconstrained real linear system or the minimum norm
-- solution of an underdetermined system, for several right-hand sides.
-- For rank deficient systems use 'linearSolveSVDR'.
--
-- The result of the solver is trimmed to @cols a@ rows and @cols b@ columns.
linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double
linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $
                     -- the label now spells the function name correctly
                     -- (it used to read "linearSolverLSR")
                     linearSolveAux dgels "linearSolveLSR" (fmat a) (fmat b)

-- | Least squared error solution of an overconstrained complex linear system,
-- or minimum norm solution of an underdetermined one, for several right-hand
-- sides, computed by LAPACK's /zgels/.
-- For rank deficient systems use 'linearSolveSVDC'.
--
-- The result of the solver is trimmed to @cols a@ rows and @cols b@ columns.
linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) padded
  where padded = linearSolveAux zgels "linearSolveLSC" (fmat a) (fmat b)

-- | Minimum norm solution to a real linear least squares problem Ax=B using
-- the svd, for several right-hand sides, computed by LAPACK's /dgelss/.
-- Admits rank deficient systems but it is slower than 'linearSolveLSR'.
-- The effective rank of A is determined by treating as zero those singular
-- values which are less than rcond times the largest singular value.
-- If rcond == Nothing machine precision is used (the value -1 is passed on).
linearSolveSVDR :: Maybe Double   -- ^ rcond
                -> Matrix Double  -- ^ coefficient matrix
                -> Matrix Double  -- ^ right hand sides (as columns)
                -> Matrix Double  -- ^ solution vectors (as columns)
linearSolveSVDR mrcond a b = case mrcond of
    Nothing    -> linearSolveSVDR (Just (-1)) (fmat a) (fmat b)
    Just rcond ->
        let padded = linearSolveAux (dgelss rcond) "linearSolveSVDR" (fmat a) (fmat b)
        in  subMatrix (0,0) (cols a, cols b) padded

-- | Minimum norm solution to a complex linear least squares problem Ax=B using
-- the svd, for several right-hand sides, computed by LAPACK's /zgelss/.
-- Admits rank deficient systems but it is slower than 'linearSolveLSC'.
-- The effective rank of A is determined by treating as zero those singular
-- values which are less than rcond times the largest singular value.
-- If rcond == Nothing machine precision is used (the value -1 is passed on).
linearSolveSVDC :: Maybe Double            -- ^ rcond
                -> Matrix (Complex Double) -- ^ coefficient matrix
                -> Matrix (Complex Double) -- ^ right hand sides (as columns)
                -> Matrix (Complex Double) -- ^ solution vectors (as columns)
linearSolveSVDC mrcond a b = case mrcond of
    Nothing    -> linearSolveSVDC (Just (-1)) (fmat a) (fmat b)
    Just rcond ->
        let padded = linearSolveAux (zgelss rcond) "linearSolveSVDC" (fmat a) (fmat b)
        in  subMatrix (0,0) (cols a, cols b) padded

-----------------------------------------------------------------------------------
-- Bindings to the C Cholesky routines declared in LAPACK/lapack-aux.h.
-- C routine chol_l_H, passed to 'cholAux' by 'cholH'.
foreign import ccall "LAPACK/lapack-aux.h chol_l_H" zpotrf :: TCMCM
-- C routine chol_l_S, passed to 'cholAux' by 'cholS'.
foreign import ccall "LAPACK/lapack-aux.h chol_l_S" dpotrf :: TMM

-- | Cholesky factorization of a complex Hermitian positive definite matrix,
-- computed by LAPACK's /zpotrf/.
cholH :: Matrix (Complex Double) -> Matrix (Complex Double)
cholH m = cholAux zpotrf "cholH" (fmat m)

-- | Cholesky factorization of a real symmetric positive definite matrix,
-- computed by LAPACK's /dpotrf/.
cholS :: Matrix Double -> Matrix Double
cholS m = cholAux dpotrf "cholS" (fmat m)

-- Shared driver for the Cholesky wrappers.
--
-- Allocates a square column-major result whose order is the number of rows of
-- a, runs the foreign routine f on a and the result through app2, and returns
-- the result.  st is the label forwarded to app2.
cholAux f st a = unsafePerformIO $ do
    let order = rows a
    factor <- createMatrix ColumnMajor order order
    app2 f mat a mat factor st
    return factor

-----------------------------------------------------------------------------------
-- Bindings to the C QR routines declared in LAPACK/lapack-aux.h.
-- C routine qr_l_R, passed to 'qrAux' by 'qrR'.
foreign import ccall "LAPACK/lapack-aux.h qr_l_R" dgeqr2 :: TMVM
-- C routine qr_l_C, passed to 'qrAux' by 'qrC'.
foreign import ccall "LAPACK/lapack-aux.h qr_l_C" zgeqr2 :: TCMCVCM

-- | QR factorization of a real matrix, computed by LAPACK's /dgeqr2/.
--
-- Returns the matrix written by the routine (same shape as the input)
-- together with the accompanying vector of length @min rows cols@.
qrR :: Matrix Double -> (Matrix Double, Vector Double)
qrR m = qrAux dgeqr2 "qrR" (fmat m)

-- | QR factorization of a complex matrix, computed by LAPACK's /zgeqr2/.
--
-- Returns the matrix written by the routine (same shape as the input)
-- together with the accompanying vector of length @min rows cols@.
qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double))
qrC m = qrAux zgeqr2 "qrC" (fmat m)

-- Shared driver for the QR wrappers.
--
-- Allocates a column-major result with the shape of a and a vector tau of
-- length min (rows a) (cols a), runs the foreign routine f on them through
-- app3, and returns (result, tau).  st is the label forwarded to app3.
qrAux f st a = unsafePerformIO $ do
    let nr = rows a
        nc = cols a
    out <- createMatrix ColumnMajor nr nc
    tau <- createVector (min nr nc)
    app3 f mat a vec tau mat out st
    return (out, tau)

-----------------------------------------------------------------------------------
-- Bindings to the C Hessenberg routines declared in LAPACK/lapack-aux.h.
-- C routine hess_l_R, passed to 'hessAux' by 'hessR'.
foreign import ccall "LAPACK/lapack-aux.h hess_l_R" dgehrd :: TMVM
-- C routine hess_l_C, passed to 'hessAux' by 'hessC'.
foreign import ccall "LAPACK/lapack-aux.h hess_l_C" zgehrd :: TCMCVCM

-- | Hessenberg factorization of a square real matrix, computed by LAPACK's /dgehrd/.
--
-- Returns the matrix written by the routine (same shape as the input)
-- together with the accompanying vector of length @min rows cols - 1@.
hessR :: Matrix Double -> (Matrix Double, Vector Double)
hessR m = hessAux dgehrd "hessR" (fmat m)

-- | Hessenberg factorization of a square complex matrix, computed by LAPACK's /zgehrd/.
--
-- Returns the matrix written by the routine (same shape as the input)
-- together with the accompanying vector of length @min rows cols - 1@.
hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double))
hessC m = hessAux zgehrd "hessC" (fmat m)

-- Shared driver for the Hessenberg wrappers.
--
-- Allocates a column-major result with the shape of a and a vector tau of
-- length min (rows a) (cols a) - 1, runs the foreign routine f on them
-- through app3, and returns (result, tau).  st is the label forwarded to app3.
hessAux f st a = unsafePerformIO $ do
    let nr = rows a
        nc = cols a
    out <- createMatrix ColumnMajor nr nc
    tau <- createVector (min nr nc - 1)
    app3 f mat a vec tau mat out st
    return (out, tau)

-----------------------------------------------------------------------------------
-- Bindings to the C Schur routines declared in LAPACK/lapack-aux.h.
-- C routine schur_l_R, passed to 'schurAux' by 'schurR'.
foreign import ccall "LAPACK/lapack-aux.h schur_l_R" dgees :: TMMM
-- C routine schur_l_C, passed to 'schurAux' by 'schurC'.
foreign import ccall "LAPACK/lapack-aux.h schur_l_C" zgees :: TCMCMCM

-- | Schur factorization of a square real matrix, computed by LAPACK's /dgees/.
--
-- Returns the two square matrices written by the routine, in the order in
-- which they are passed to it.
schurR :: Matrix Double -> (Matrix Double, Matrix Double)
schurR m = schurAux dgees "schurR" (fmat m)

-- | Schur factorization of a square complex matrix, computed by LAPACK's /zgees/.
--
-- Returns the two square matrices written by the routine, in the order in
-- which they are passed to it.
schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double))
schurC m = schurAux zgees "schurC" (fmat m)

-- Shared driver for the Schur wrappers.
--
-- Allocates two square column-major results whose order is the number of rows
-- of a, runs the foreign routine f on a and both results through app3, and
-- returns them as a pair in argument order.  st is the label forwarded to app3.
schurAux f st a = unsafePerformIO $ do
    let order = rows a
    first  <- createMatrix ColumnMajor order order
    second <- createMatrix ColumnMajor order order
    app3 f mat a mat first mat second st
    return (first, second)

-----------------------------------------------------------------------------------
-- Bindings to the C LU routines declared in LAPACK/lapack-aux.h.
-- C routine lu_l_R, passed to 'luAux' by 'luR'.
foreign import ccall "LAPACK/lapack-aux.h lu_l_R" dgetrf :: TMVM
-- C routine lu_l_C, passed to 'luAux' by 'luC'.
foreign import ccall "LAPACK/lapack-aux.h lu_l_C" zgetrf :: TCMVCM

-- | LU factorization of a general real matrix, computed by LAPACK's /dgetrf/.
--
-- Returns the matrix written by the routine together with the pivot list,
-- in which every value reported by the routine has been decremented by one.
luR :: Matrix Double -> (Matrix Double, [Int])
luR m = luAux dgetrf "luR" (fmat m)

-- | LU factorization of a general complex matrix, computed by LAPACK's /zgetrf/.
--
-- Returns the matrix written by the routine together with the pivot list,
-- in which every value reported by the routine has been decremented by one.
luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int])
luC m = luAux zgetrf "luC" (fmat m)

-- Shared driver for the LU wrappers.
--
-- Allocates a column-major result with the shape of a and a pivot vector of
-- length min (rows a) (cols a), runs the foreign routine f on them through
-- app3, and returns the result with the pivots as a list.  Each pivot is
-- rounded to an integer and decremented by one.  st is the label forwarded
-- to app3.
luAux f st a = unsafePerformIO $ do
    let nr = rows a
        nc = cols a
    factors <- createMatrix ColumnMajor nr nc
    pivots  <- createVector (min nr nc)
    app3 f mat a vec pivots mat factors st
    return (factors, [ pred (round p) | p <- toList pivots ])

-----------------------------------------------------------------------------------
-- Argument groups used to spell out the type of zgetrs below.
-- TW: a size followed by a PD pointer; it lines up with the 'vec' argument
-- (the pivot vector) in 'lusAux'.  PD is defined elsewhere (presumably a
-- pointer to Double -- TODO confirm).
type TW a = CInt -> PD -> a
-- TQ: two sizes followed by a PC pointer; it lines up with the 'mat'
-- arguments of the complex solver in 'lusAux'.  PC is defined elsewhere
-- (presumably a pointer to Complex Double -- TODO confirm).
type TQ a = CInt -> CInt -> PC -> a

-- C routine luS_l_R, passed to 'lusAux' by 'lusR'.
foreign import ccall "LAPACK/lapack-aux.h luS_l_R" dgetrs :: TMVMM
-- C routine luS_l_C, passed to 'lusAux' by 'lusC': matrix, pivot vector,
-- matrix, matrix, returning an IO CInt.
foreign import ccall "LAPACK/lapack-aux.h luS_l_C" zgetrs :: TQ (TW (TQ (TQ (IO CInt))))

-- | Solves a general real linear system (for several right-hand sides) from a
-- precomputed LU decomposition, using LAPACK's /dgetrs/.
--
-- The arguments are the factorization matrix, the pivot list and the
-- right-hand sides.
lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double
lusR a piv b = lusAux dgetrs "lusR" factors piv rhs
  where factors = fmat a
        rhs     = fmat b

-- | Solves a general complex linear system (for several right-hand sides) from
-- a precomputed LU decomposition, using LAPACK's /zgetrs/.
--
-- The arguments are the factorization matrix, the pivot list and the
-- right-hand sides.
lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double)
lusC a piv b = lusAux zgetrs "lusC" factors piv rhs
  where factors = fmat a
        rhs     = fmat b

-- Shared driver for the solvers that reuse an LU factorization.
--
-- f is the foreign routine, st the name of the calling wrapper (used in the
-- error messages and forwarded to app4), a the factorization matrix, piv the
-- pivot list and b the right-hand sides.  The result has the shape of b.
--
-- Calls 'error' when a is not square, and with a separate message when the
-- number of rows of b differs from the order of a.  Both cases used to be
-- reported as "nonsquare matrix", which was misleading for the second one.
lusAux f st a piv b
    | n1 /= n2 = error $ st ++ " on LU factorization of nonsquare matrix"
    | n2 /= n = error $ st ++ ": LU factorization of order " ++ show n2
                           ++ " but right-hand side with " ++ show n ++ " rows"
    | otherwise = unsafePerformIO $ do
         x <- createMatrix ColumnMajor n m
         app4 f mat a vec piv' mat b mat x st
         return x
  where n1 = rows a
        n2 = cols a
        n = rows b
        m = cols b
        -- every pivot is incremented by one and stored as a Double, the
        -- inverse of the conversion applied when the pivots were returned
        piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double