summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/LAPACK.hs
diff options
context:
space:
mode:
authorexfalso <0slemi0@gmail.com>2016-10-07 16:49:57 +0100
committerexfalso <0slemi0@gmail.com>2016-10-07 17:03:35 +0100
commit59cb364ebd7bff09a19f5f83104752a14f6a5177 (patch)
treeb95f05bc88eb6b811d1e77fbde9ae8ddb1ac9aa0 /packages/base/src/Internal/LAPACK.hs
parent2f773c0148a1a50b84226f69852997d53b0653fb (diff)
Redefine (#)
Diffstat (limited to 'packages/base/src/Internal/LAPACK.hs')
-rw-r--r--packages/base/src/Internal/LAPACK.hs54
1 files changed, 28 insertions, 26 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index c2c140b..231109a 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -18,7 +18,7 @@ module Internal.LAPACK where
18 18
19import Internal.Devel 19import Internal.Devel
20import Internal.Vector 20import Internal.Vector
21import Internal.Matrix hiding ((#)) 21import Internal.Matrix hiding ((#), (#!))
22import Internal.Conversion 22import Internal.Conversion
23import Internal.Element 23import Internal.Element
24import Foreign.Ptr(nullPtr) 24import Foreign.Ptr(nullPtr)
@@ -28,10 +28,13 @@ import System.IO.Unsafe(unsafePerformIO)
28 28
29----------------------------------------------------------------------------------- 29-----------------------------------------------------------------------------------
30 30
31infixl 1 # 31infixr 1 #
32a # b = apply a b 32a # b = apply a b
33{-# INLINE (#) #-} 33{-# INLINE (#) #-}
34 34
35a #! b = a # b # id
36{-# INLINE (#!) #-}
37
35----------------------------------------------------------------------------------- 38-----------------------------------------------------------------------------------
36 39
37type TMMM t = t ::> t ::> t ::> Ok 40type TMMM t = t ::> t ::> t ::> Ok
@@ -56,7 +59,7 @@ multiplyAux f st a b = unsafePerformIO $ do
56 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ 59 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++
57 show (rows a,cols a) ++ " x " ++ show (rows b, cols b) 60 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
58 s <- createMatrix ColumnMajor (rows a) (cols b) 61 s <- createMatrix ColumnMajor (rows a) (cols b)
59 f (isT a) (isT b) # (tt a) # (tt b) # s #| st 62 ((tt a) # (tt b) #! s) (f (isT a) (isT b)) #| st
60 return s 63 return s
61 64
62-- | Matrix product based on BLAS's /dgemm/. 65-- | Matrix product based on BLAS's /dgemm/.
@@ -80,7 +83,7 @@ multiplyI m a b = unsafePerformIO $ do
80 when (cols a /= rows b) $ error $ 83 when (cols a /= rows b) $ error $
81 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 84 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
82 s <- createMatrix ColumnMajor (rows a) (cols b) 85 s <- createMatrix ColumnMajor (rows a) (cols b)
83 c_multiplyI m # a # b # s #|"c_multiplyI" 86 (a # b #! s) (c_multiplyI m) #|"c_multiplyI"
84 return s 87 return s
85 88
86multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z 89multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z
@@ -88,7 +91,7 @@ multiplyL m a b = unsafePerformIO $ do
88 when (cols a /= rows b) $ error $ 91 when (cols a /= rows b) $ error $
89 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 92 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
90 s <- createMatrix ColumnMajor (rows a) (cols b) 93 s <- createMatrix ColumnMajor (rows a) (cols b)
91 c_multiplyL m # a # b # s #|"c_multiplyL" 94 (a # b #! s) (c_multiplyL m) #|"c_multiplyL"
92 return s 95 return s
93 96
94----------------------------------------------------------------------------- 97-----------------------------------------------------------------------------
@@ -121,7 +124,7 @@ svdAux f st x = unsafePerformIO $ do
121 u <- createMatrix ColumnMajor r r 124 u <- createMatrix ColumnMajor r r
122 s <- createVector (min r c) 125 s <- createVector (min r c)
123 v <- createMatrix ColumnMajor c c 126 v <- createMatrix ColumnMajor c c
124 f # a # u # s # v #| st 127 (a # u # s #! v) f #| st
125 return (u,s,v) 128 return (u,s,v)
126 where 129 where
127 r = rows x 130 r = rows x
@@ -149,7 +152,7 @@ thinSVDAux f st x = unsafePerformIO $ do
149 u <- createMatrix ColumnMajor r q 152 u <- createMatrix ColumnMajor r q
150 s <- createVector q 153 s <- createVector q
151 v <- createMatrix ColumnMajor q c 154 v <- createMatrix ColumnMajor q c
152 f # a # u # s # v #| st 155 (a # u # s #! v) f #| st
153 return (u,s,v) 156 return (u,s,v)
154 where 157 where
155 r = rows x 158 r = rows x
@@ -176,7 +179,7 @@ svCd = svAux zgesdd "svCd"
176svAux f st x = unsafePerformIO $ do 179svAux f st x = unsafePerformIO $ do
177 a <- copy ColumnMajor x 180 a <- copy ColumnMajor x
178 s <- createVector q 181 s <- createVector q
179 g # a # s #| st 182 (a #! s) g #| st
180 return s 183 return s
181 where 184 where
182 r = rows x 185 r = rows x
@@ -197,7 +200,7 @@ rightSVAux f st x = unsafePerformIO $ do
197 a <- copy ColumnMajor x 200 a <- copy ColumnMajor x
198 s <- createVector q 201 s <- createVector q
199 v <- createMatrix ColumnMajor c c 202 v <- createMatrix ColumnMajor c c
200 g # a # s # v #| st 203 (a # s #! v) g #| st
201 return (s,v) 204 return (s,v)
202 where 205 where
203 r = rows x 206 r = rows x
@@ -218,7 +221,7 @@ leftSVAux f st x = unsafePerformIO $ do
218 a <- copy ColumnMajor x 221 a <- copy ColumnMajor x
219 u <- createMatrix ColumnMajor r r 222 u <- createMatrix ColumnMajor r r
220 s <- createVector q 223 s <- createVector q
221 g # a # u # s #| st 224 (a # u #! s) g #| st
222 return (u,s) 225 return (u,s)
223 where 226 where
224 r = rows x 227 r = rows x
@@ -237,7 +240,7 @@ eigAux f st m = unsafePerformIO $ do
237 a <- copy ColumnMajor m 240 a <- copy ColumnMajor m
238 l <- createVector r 241 l <- createVector r
239 v <- createMatrix ColumnMajor r r 242 v <- createMatrix ColumnMajor r r
240 g # a # l # v #| st 243 (a # l #! v) g #| st
241 return (l,v) 244 return (l,v)
242 where 245 where
243 r = rows m 246 r = rows m
@@ -252,7 +255,7 @@ eigC = eigAux zgeev "eigC"
252eigOnlyAux f st m = unsafePerformIO $ do 255eigOnlyAux f st m = unsafePerformIO $ do
253 a <- copy ColumnMajor m 256 a <- copy ColumnMajor m
254 l <- createVector r 257 l <- createVector r
255 g # a # l #| st 258 (a #! l) g #| st
256 return l 259 return l
257 where 260 where
258 r = rows m 261 r = rows m
@@ -277,7 +280,7 @@ eigRaux m = unsafePerformIO $ do
277 a <- copy ColumnMajor m 280 a <- copy ColumnMajor m
278 l <- createVector r 281 l <- createVector r
279 v <- createMatrix ColumnMajor r r 282 v <- createMatrix ColumnMajor r r
280 g # a # l # v #| "eigR" 283 (a # l #! v) g #| "eigR"
281 return (l,v) 284 return (l,v)
282 where 285 where
283 r = rows m 286 r = rows m
@@ -305,7 +308,7 @@ eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR"
305eigSHAux f st m = unsafePerformIO $ do 308eigSHAux f st m = unsafePerformIO $ do
306 l <- createVector r 309 l <- createVector r
307 v <- copy ColumnMajor m 310 v <- copy ColumnMajor m
308 f # l # v #| st 311 (l #! v) f #| st
309 return (l,v) 312 return (l,v)
310 where 313 where
311 r = rows m 314 r = rows m
@@ -356,7 +359,7 @@ linearSolveSQAux g f st a b
356 | n1==n2 && n1==r = unsafePerformIO . g $ do 359 | n1==n2 && n1==r = unsafePerformIO . g $ do
357 a' <- copy ColumnMajor a 360 a' <- copy ColumnMajor a
358 s <- copy ColumnMajor b 361 s <- copy ColumnMajor b
359 f # a' # s #| st 362 (a' #! s) f #| st
360 return s 363 return s
361 | otherwise = error $ st ++ " of nonsquare matrix" 364 | otherwise = error $ st ++ " of nonsquare matrix"
362 where 365 where
@@ -387,7 +390,7 @@ foreign import ccall unsafe "cholSolveC_l" zpotrs :: C ::> C ::> Ok
387linearSolveSQAux2 g f st a b 390linearSolveSQAux2 g f st a b
388 | n1==n2 && n1==r = unsafePerformIO . g $ do 391 | n1==n2 && n1==r = unsafePerformIO . g $ do
389 s <- copy ColumnMajor b 392 s <- copy ColumnMajor b
390 f # a # s #| st 393 (a #! s) f #| st
391 return s 394 return s
392 | otherwise = error $ st ++ " of nonsquare matrix" 395 | otherwise = error $ st ++ " of nonsquare matrix"
393 where 396 where
@@ -415,7 +418,7 @@ linearSolveAux f st a b
415 a' <- copy ColumnMajor a 418 a' <- copy ColumnMajor a
416 r <- createMatrix ColumnMajor (max m n) nrhs 419 r <- createMatrix ColumnMajor (max m n) nrhs
417 setRect 0 0 b r 420 setRect 0 0 b r
418 f # a' # r #| st 421 (a' #! r) f #| st
419 return r 422 return r
420 | otherwise = error $ "different number of rows in linearSolve ("++st++")" 423 | otherwise = error $ "different number of rows in linearSolve ("++st++")"
421 where 424 where
@@ -458,7 +461,7 @@ foreign import ccall unsafe "chol_l_S" dpotrf :: R ::> Ok
458 461
459cholAux f st a = do 462cholAux f st a = do
460 r <- copy ColumnMajor a 463 r <- copy ColumnMajor a
461 f # r #| st 464 (r # id) f #| st
462 return r 465 return r
463 466
464-- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/. 467-- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/.
@@ -495,7 +498,7 @@ qrC = qrAux zgeqr2 "qrC"
495qrAux f st a = unsafePerformIO $ do 498qrAux f st a = unsafePerformIO $ do
496 r <- copy ColumnMajor a 499 r <- copy ColumnMajor a
497 tau <- createVector mn 500 tau <- createVector mn
498 f # tau # r #| st 501 (tau #! r) f #| st
499 return (r,tau) 502 return (r,tau)
500 where 503 where
501 m = rows a 504 m = rows a
@@ -514,7 +517,7 @@ qrgrC = qrgrAux zungqr "qrgrC"
514 517
515qrgrAux f st n (a, tau) = unsafePerformIO $ do 518qrgrAux f st n (a, tau) = unsafePerformIO $ do
516 res <- copy ColumnMajor (subMatrix (0,0) (rows a,n) a) 519 res <- copy ColumnMajor (subMatrix (0,0) (rows a,n) a)
517 f # (subVector 0 n tau') # res #| st 520 ((subVector 0 n tau') #! res) f #| st
518 return res 521 return res
519 where 522 where
520 tau' = vjoin [tau, constantD 0 n] 523 tau' = vjoin [tau, constantD 0 n]
@@ -534,7 +537,7 @@ hessC = hessAux zgehrd "hessC"
534hessAux f st a = unsafePerformIO $ do 537hessAux f st a = unsafePerformIO $ do
535 r <- copy ColumnMajor a 538 r <- copy ColumnMajor a
536 tau <- createVector (mn-1) 539 tau <- createVector (mn-1)
537 f # tau # r #| st 540 (tau #! r) f #| st
538 return (r,tau) 541 return (r,tau)
539 where 542 where
540 m = rows a 543 m = rows a
@@ -556,7 +559,7 @@ schurC = schurAux zgees "schurC"
556schurAux f st a = unsafePerformIO $ do 559schurAux f st a = unsafePerformIO $ do
557 u <- createMatrix ColumnMajor n n 560 u <- createMatrix ColumnMajor n n
558 s <- copy ColumnMajor a 561 s <- copy ColumnMajor a
559 f # u # s #| st 562 (u #! s) f #| st
560 return (u,s) 563 return (u,s)
561 where 564 where
562 n = rows a 565 n = rows a
@@ -576,7 +579,7 @@ luC = luAux zgetrf "luC"
576luAux f st a = unsafePerformIO $ do 579luAux f st a = unsafePerformIO $ do
577 lu <- copy ColumnMajor a 580 lu <- copy ColumnMajor a
578 piv <- createVector (min n m) 581 piv <- createVector (min n m)
579 f # piv # lu #| st 582 (piv #! lu) f #| st
580 return (lu, map (pred.round) (toList piv)) 583 return (lu, map (pred.round) (toList piv))
581 where 584 where
582 n = rows a 585 n = rows a
@@ -598,7 +601,7 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv b
598lusAux f st a piv b 601lusAux f st a piv b
599 | n1==n2 && n2==n =unsafePerformIO $ do 602 | n1==n2 && n2==n =unsafePerformIO $ do
600 x <- copy ColumnMajor b 603 x <- copy ColumnMajor b
601 f # a # piv' # x #| st 604 (a # piv' #! x) f #| st
602 return x 605 return x
603 | otherwise = error st 606 | otherwise = error st
604 where 607 where
@@ -622,7 +625,7 @@ ldlC = ldlAux zhetrf "ldlC"
622ldlAux f st a = unsafePerformIO $ do 625ldlAux f st a = unsafePerformIO $ do
623 ldl <- copy ColumnMajor a 626 ldl <- copy ColumnMajor a
624 piv <- createVector (rows a) 627 piv <- createVector (rows a)
625 f # piv # ldl #| st 628 (piv #! ldl) f #| st
626 return (ldl, map (pred.round) (toList piv)) 629 return (ldl, map (pred.round) (toList piv))
627 630
628----------------------------------------------------------------------------------- 631-----------------------------------------------------------------------------------
@@ -637,4 +640,3 @@ ldlsR a piv b = lusAux dsytrs "ldlsR" (fmat a) piv b
637-- | Solve a complex linear system from a precomputed LDL decomposition ('ldlC'), using LAPACK's /zsytrs/. 640-- | Solve a complex linear system from a precomputed LDL decomposition ('ldlC'), using LAPACK's /zsytrs/.
638ldlsC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) 641ldlsC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double)
639ldlsC a piv b = lusAux zsytrs "ldlsC" (fmat a) piv b 642ldlsC a piv b = lusAux zsytrs "ldlsC" (fmat a) piv b
640