diff options
author | exfalso <0slemi0@gmail.com> | 2016-10-07 16:49:57 +0100 |
---|---|---|
committer | exfalso <0slemi0@gmail.com> | 2016-10-07 17:03:35 +0100 |
commit | 59cb364ebd7bff09a19f5f83104752a14f6a5177 (patch) | |
tree | b95f05bc88eb6b811d1e77fbde9ae8ddb1ac9aa0 /packages/base/src/Internal/LAPACK.hs | |
parent | 2f773c0148a1a50b84226f69852997d53b0653fb (diff) |
Redefine (#)
Diffstat (limited to 'packages/base/src/Internal/LAPACK.hs')
-rw-r--r-- | packages/base/src/Internal/LAPACK.hs | 54 |
1 files changed, 28 insertions, 26 deletions
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs index c2c140b..231109a 100644 --- a/packages/base/src/Internal/LAPACK.hs +++ b/packages/base/src/Internal/LAPACK.hs | |||
@@ -18,7 +18,7 @@ module Internal.LAPACK where | |||
18 | 18 | ||
19 | import Internal.Devel | 19 | import Internal.Devel |
20 | import Internal.Vector | 20 | import Internal.Vector |
21 | import Internal.Matrix hiding ((#)) | 21 | import Internal.Matrix hiding ((#), (#!)) |
22 | import Internal.Conversion | 22 | import Internal.Conversion |
23 | import Internal.Element | 23 | import Internal.Element |
24 | import Foreign.Ptr(nullPtr) | 24 | import Foreign.Ptr(nullPtr) |
@@ -28,10 +28,13 @@ import System.IO.Unsafe(unsafePerformIO) | |||
28 | 28 | ||
29 | ----------------------------------------------------------------------------------- | 29 | ----------------------------------------------------------------------------------- |
30 | 30 | ||
31 | infixl 1 # | 31 | infixr 1 # |
32 | a # b = apply a b | 32 | a # b = apply a b |
33 | {-# INLINE (#) #-} | 33 | {-# INLINE (#) #-} |
34 | 34 | ||
35 | a #! b = a # b # id | ||
36 | {-# INLINE (#!) #-} | ||
37 | |||
35 | ----------------------------------------------------------------------------------- | 38 | ----------------------------------------------------------------------------------- |
36 | 39 | ||
37 | type TMMM t = t ::> t ::> t ::> Ok | 40 | type TMMM t = t ::> t ::> t ::> Ok |
@@ -56,7 +59,7 @@ multiplyAux f st a b = unsafePerformIO $ do | |||
56 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ | 59 | when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ |
57 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | 60 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) |
58 | s <- createMatrix ColumnMajor (rows a) (cols b) | 61 | s <- createMatrix ColumnMajor (rows a) (cols b) |
59 | f (isT a) (isT b) # (tt a) # (tt b) # s #| st | 62 | ((tt a) # (tt b) #! s) (f (isT a) (isT b)) #| st |
60 | return s | 63 | return s |
61 | 64 | ||
62 | -- | Matrix product based on BLAS's /dgemm/. | 65 | -- | Matrix product based on BLAS's /dgemm/. |
@@ -80,7 +83,7 @@ multiplyI m a b = unsafePerformIO $ do | |||
80 | when (cols a /= rows b) $ error $ | 83 | when (cols a /= rows b) $ error $ |
81 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 84 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
82 | s <- createMatrix ColumnMajor (rows a) (cols b) | 85 | s <- createMatrix ColumnMajor (rows a) (cols b) |
83 | c_multiplyI m # a # b # s #|"c_multiplyI" | 86 | (a # b #! s) (c_multiplyI m) #|"c_multiplyI" |
84 | return s | 87 | return s |
85 | 88 | ||
86 | multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z | 89 | multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z |
@@ -88,7 +91,7 @@ multiplyL m a b = unsafePerformIO $ do | |||
88 | when (cols a /= rows b) $ error $ | 91 | when (cols a /= rows b) $ error $ |
89 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b | 92 | "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b |
90 | s <- createMatrix ColumnMajor (rows a) (cols b) | 93 | s <- createMatrix ColumnMajor (rows a) (cols b) |
91 | c_multiplyL m # a # b # s #|"c_multiplyL" | 94 | (a # b #! s) (c_multiplyL m) #|"c_multiplyL" |
92 | return s | 95 | return s |
93 | 96 | ||
94 | ----------------------------------------------------------------------------- | 97 | ----------------------------------------------------------------------------- |
@@ -121,7 +124,7 @@ svdAux f st x = unsafePerformIO $ do | |||
121 | u <- createMatrix ColumnMajor r r | 124 | u <- createMatrix ColumnMajor r r |
122 | s <- createVector (min r c) | 125 | s <- createVector (min r c) |
123 | v <- createMatrix ColumnMajor c c | 126 | v <- createMatrix ColumnMajor c c |
124 | f # a # u # s # v #| st | 127 | (a # u # s #! v) f #| st |
125 | return (u,s,v) | 128 | return (u,s,v) |
126 | where | 129 | where |
127 | r = rows x | 130 | r = rows x |
@@ -149,7 +152,7 @@ thinSVDAux f st x = unsafePerformIO $ do | |||
149 | u <- createMatrix ColumnMajor r q | 152 | u <- createMatrix ColumnMajor r q |
150 | s <- createVector q | 153 | s <- createVector q |
151 | v <- createMatrix ColumnMajor q c | 154 | v <- createMatrix ColumnMajor q c |
152 | f # a # u # s # v #| st | 155 | (a # u # s #! v) f #| st |
153 | return (u,s,v) | 156 | return (u,s,v) |
154 | where | 157 | where |
155 | r = rows x | 158 | r = rows x |
@@ -176,7 +179,7 @@ svCd = svAux zgesdd "svCd" | |||
176 | svAux f st x = unsafePerformIO $ do | 179 | svAux f st x = unsafePerformIO $ do |
177 | a <- copy ColumnMajor x | 180 | a <- copy ColumnMajor x |
178 | s <- createVector q | 181 | s <- createVector q |
179 | g # a # s #| st | 182 | (a #! s) g #| st |
180 | return s | 183 | return s |
181 | where | 184 | where |
182 | r = rows x | 185 | r = rows x |
@@ -197,7 +200,7 @@ rightSVAux f st x = unsafePerformIO $ do | |||
197 | a <- copy ColumnMajor x | 200 | a <- copy ColumnMajor x |
198 | s <- createVector q | 201 | s <- createVector q |
199 | v <- createMatrix ColumnMajor c c | 202 | v <- createMatrix ColumnMajor c c |
200 | g # a # s # v #| st | 203 | (a # s #! v) g #| st |
201 | return (s,v) | 204 | return (s,v) |
202 | where | 205 | where |
203 | r = rows x | 206 | r = rows x |
@@ -218,7 +221,7 @@ leftSVAux f st x = unsafePerformIO $ do | |||
218 | a <- copy ColumnMajor x | 221 | a <- copy ColumnMajor x |
219 | u <- createMatrix ColumnMajor r r | 222 | u <- createMatrix ColumnMajor r r |
220 | s <- createVector q | 223 | s <- createVector q |
221 | g # a # u # s #| st | 224 | (a # u #! s) g #| st |
222 | return (u,s) | 225 | return (u,s) |
223 | where | 226 | where |
224 | r = rows x | 227 | r = rows x |
@@ -237,7 +240,7 @@ eigAux f st m = unsafePerformIO $ do | |||
237 | a <- copy ColumnMajor m | 240 | a <- copy ColumnMajor m |
238 | l <- createVector r | 241 | l <- createVector r |
239 | v <- createMatrix ColumnMajor r r | 242 | v <- createMatrix ColumnMajor r r |
240 | g # a # l # v #| st | 243 | (a # l #! v) g #| st |
241 | return (l,v) | 244 | return (l,v) |
242 | where | 245 | where |
243 | r = rows m | 246 | r = rows m |
@@ -252,7 +255,7 @@ eigC = eigAux zgeev "eigC" | |||
252 | eigOnlyAux f st m = unsafePerformIO $ do | 255 | eigOnlyAux f st m = unsafePerformIO $ do |
253 | a <- copy ColumnMajor m | 256 | a <- copy ColumnMajor m |
254 | l <- createVector r | 257 | l <- createVector r |
255 | g # a # l #| st | 258 | (a #! l) g #| st |
256 | return l | 259 | return l |
257 | where | 260 | where |
258 | r = rows m | 261 | r = rows m |
@@ -277,7 +280,7 @@ eigRaux m = unsafePerformIO $ do | |||
277 | a <- copy ColumnMajor m | 280 | a <- copy ColumnMajor m |
278 | l <- createVector r | 281 | l <- createVector r |
279 | v <- createMatrix ColumnMajor r r | 282 | v <- createMatrix ColumnMajor r r |
280 | g # a # l # v #| "eigR" | 283 | (a # l #! v) g #| "eigR" |
281 | return (l,v) | 284 | return (l,v) |
282 | where | 285 | where |
283 | r = rows m | 286 | r = rows m |
@@ -305,7 +308,7 @@ eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" | |||
305 | eigSHAux f st m = unsafePerformIO $ do | 308 | eigSHAux f st m = unsafePerformIO $ do |
306 | l <- createVector r | 309 | l <- createVector r |
307 | v <- copy ColumnMajor m | 310 | v <- copy ColumnMajor m |
308 | f # l # v #| st | 311 | (l #! v) f #| st |
309 | return (l,v) | 312 | return (l,v) |
310 | where | 313 | where |
311 | r = rows m | 314 | r = rows m |
@@ -356,7 +359,7 @@ linearSolveSQAux g f st a b | |||
356 | | n1==n2 && n1==r = unsafePerformIO . g $ do | 359 | | n1==n2 && n1==r = unsafePerformIO . g $ do |
357 | a' <- copy ColumnMajor a | 360 | a' <- copy ColumnMajor a |
358 | s <- copy ColumnMajor b | 361 | s <- copy ColumnMajor b |
359 | f # a' # s #| st | 362 | (a' #! s) f #| st |
360 | return s | 363 | return s |
361 | | otherwise = error $ st ++ " of nonsquare matrix" | 364 | | otherwise = error $ st ++ " of nonsquare matrix" |
362 | where | 365 | where |
@@ -387,7 +390,7 @@ foreign import ccall unsafe "cholSolveC_l" zpotrs :: C ::> C ::> Ok | |||
387 | linearSolveSQAux2 g f st a b | 390 | linearSolveSQAux2 g f st a b |
388 | | n1==n2 && n1==r = unsafePerformIO . g $ do | 391 | | n1==n2 && n1==r = unsafePerformIO . g $ do |
389 | s <- copy ColumnMajor b | 392 | s <- copy ColumnMajor b |
390 | f # a # s #| st | 393 | (a #! s) f #| st |
391 | return s | 394 | return s |
392 | | otherwise = error $ st ++ " of nonsquare matrix" | 395 | | otherwise = error $ st ++ " of nonsquare matrix" |
393 | where | 396 | where |
@@ -415,7 +418,7 @@ linearSolveAux f st a b | |||
415 | a' <- copy ColumnMajor a | 418 | a' <- copy ColumnMajor a |
416 | r <- createMatrix ColumnMajor (max m n) nrhs | 419 | r <- createMatrix ColumnMajor (max m n) nrhs |
417 | setRect 0 0 b r | 420 | setRect 0 0 b r |
418 | f # a' # r #| st | 421 | (a' #! r) f #| st |
419 | return r | 422 | return r |
420 | | otherwise = error $ "different number of rows in linearSolve ("++st++")" | 423 | | otherwise = error $ "different number of rows in linearSolve ("++st++")" |
421 | where | 424 | where |
@@ -458,7 +461,7 @@ foreign import ccall unsafe "chol_l_S" dpotrf :: R ::> Ok | |||
458 | 461 | ||
459 | cholAux f st a = do | 462 | cholAux f st a = do |
460 | r <- copy ColumnMajor a | 463 | r <- copy ColumnMajor a |
461 | f # r #| st | 464 | (r # id) f #| st |
462 | return r | 465 | return r |
463 | 466 | ||
464 | -- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/. | 467 | -- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/. |
@@ -495,7 +498,7 @@ qrC = qrAux zgeqr2 "qrC" | |||
495 | qrAux f st a = unsafePerformIO $ do | 498 | qrAux f st a = unsafePerformIO $ do |
496 | r <- copy ColumnMajor a | 499 | r <- copy ColumnMajor a |
497 | tau <- createVector mn | 500 | tau <- createVector mn |
498 | f # tau # r #| st | 501 | (tau #! r) f #| st |
499 | return (r,tau) | 502 | return (r,tau) |
500 | where | 503 | where |
501 | m = rows a | 504 | m = rows a |
@@ -514,7 +517,7 @@ qrgrC = qrgrAux zungqr "qrgrC" | |||
514 | 517 | ||
515 | qrgrAux f st n (a, tau) = unsafePerformIO $ do | 518 | qrgrAux f st n (a, tau) = unsafePerformIO $ do |
516 | res <- copy ColumnMajor (subMatrix (0,0) (rows a,n) a) | 519 | res <- copy ColumnMajor (subMatrix (0,0) (rows a,n) a) |
517 | f # (subVector 0 n tau') # res #| st | 520 | ((subVector 0 n tau') #! res) f #| st |
518 | return res | 521 | return res |
519 | where | 522 | where |
520 | tau' = vjoin [tau, constantD 0 n] | 523 | tau' = vjoin [tau, constantD 0 n] |
@@ -534,7 +537,7 @@ hessC = hessAux zgehrd "hessC" | |||
534 | hessAux f st a = unsafePerformIO $ do | 537 | hessAux f st a = unsafePerformIO $ do |
535 | r <- copy ColumnMajor a | 538 | r <- copy ColumnMajor a |
536 | tau <- createVector (mn-1) | 539 | tau <- createVector (mn-1) |
537 | f # tau # r #| st | 540 | (tau #! r) f #| st |
538 | return (r,tau) | 541 | return (r,tau) |
539 | where | 542 | where |
540 | m = rows a | 543 | m = rows a |
@@ -556,7 +559,7 @@ schurC = schurAux zgees "schurC" | |||
556 | schurAux f st a = unsafePerformIO $ do | 559 | schurAux f st a = unsafePerformIO $ do |
557 | u <- createMatrix ColumnMajor n n | 560 | u <- createMatrix ColumnMajor n n |
558 | s <- copy ColumnMajor a | 561 | s <- copy ColumnMajor a |
559 | f # u # s #| st | 562 | (u #! s) f #| st |
560 | return (u,s) | 563 | return (u,s) |
561 | where | 564 | where |
562 | n = rows a | 565 | n = rows a |
@@ -576,7 +579,7 @@ luC = luAux zgetrf "luC" | |||
576 | luAux f st a = unsafePerformIO $ do | 579 | luAux f st a = unsafePerformIO $ do |
577 | lu <- copy ColumnMajor a | 580 | lu <- copy ColumnMajor a |
578 | piv <- createVector (min n m) | 581 | piv <- createVector (min n m) |
579 | f # piv # lu #| st | 582 | (piv #! lu) f #| st |
580 | return (lu, map (pred.round) (toList piv)) | 583 | return (lu, map (pred.round) (toList piv)) |
581 | where | 584 | where |
582 | n = rows a | 585 | n = rows a |
@@ -598,7 +601,7 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv b | |||
598 | lusAux f st a piv b | 601 | lusAux f st a piv b |
599 | | n1==n2 && n2==n =unsafePerformIO $ do | 602 | | n1==n2 && n2==n =unsafePerformIO $ do |
600 | x <- copy ColumnMajor b | 603 | x <- copy ColumnMajor b |
601 | f # a # piv' # x #| st | 604 | (a # piv' #! x) f #| st |
602 | return x | 605 | return x |
603 | | otherwise = error st | 606 | | otherwise = error st |
604 | where | 607 | where |
@@ -622,7 +625,7 @@ ldlC = ldlAux zhetrf "ldlC" | |||
622 | ldlAux f st a = unsafePerformIO $ do | 625 | ldlAux f st a = unsafePerformIO $ do |
623 | ldl <- copy ColumnMajor a | 626 | ldl <- copy ColumnMajor a |
624 | piv <- createVector (rows a) | 627 | piv <- createVector (rows a) |
625 | f # piv # ldl #| st | 628 | (piv #! ldl) f #| st |
626 | return (ldl, map (pred.round) (toList piv)) | 629 | return (ldl, map (pred.round) (toList piv)) |
627 | 630 | ||
628 | ----------------------------------------------------------------------------------- | 631 | ----------------------------------------------------------------------------------- |
@@ -637,4 +640,3 @@ ldlsR a piv b = lusAux dsytrs "ldlsR" (fmat a) piv b | |||
637 | -- | Solve a complex linear system from a precomputed LDL decomposition ('ldlC'), using LAPACK's /zsytrs/. | 640 | -- | Solve a complex linear system from a precomputed LDL decomposition ('ldlC'), using LAPACK's /zsytrs/. |
638 | ldlsC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) | 641 | ldlsC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) |
639 | ldlsC a piv b = lusAux zsytrs "ldlsC" (fmat a) piv b | 642 | ldlsC a piv b = lusAux zsytrs "ldlsC" (fmat a) piv b |
640 | |||