summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Internal/Devel.hs5
-rw-r--r--packages/base/src/Internal/LAPACK.hs54
-rw-r--r--packages/base/src/Internal/Matrix.hs35
-rw-r--r--packages/base/src/Internal/Sparse.hs5
-rw-r--r--packages/base/src/Internal/Vector.hs24
-rw-r--r--packages/base/src/Internal/Vectorized.hs34
6 files changed, 80 insertions, 77 deletions
diff --git a/packages/base/src/Internal/Devel.hs b/packages/base/src/Internal/Devel.hs
index 92b5604..3887663 100644
--- a/packages/base/src/Internal/Devel.hs
+++ b/packages/base/src/Internal/Devel.hs
@@ -80,8 +80,8 @@ class TransArray c
80 where 80 where
81 type Trans c b 81 type Trans c b
82 type TransRaw c b 82 type TransRaw c b
83 apply :: (Trans c b) -> c -> b 83 apply :: c -> (b -> IO r) -> (Trans c b) -> IO r
84 applyRaw :: (TransRaw c b) -> c -> b 84 applyRaw :: c -> (b -> IO r) -> (TransRaw c b) -> IO r
85 infixl 1 `apply`, `applyRaw` 85 infixl 1 `apply`, `applyRaw`
86 86
87instance Storable t => TransArray (Vector t) 87instance Storable t => TransArray (Vector t)
@@ -92,4 +92,3 @@ instance Storable t => TransArray (Vector t)
92 {-# INLINE apply #-} 92 {-# INLINE apply #-}
93 applyRaw = avec 93 applyRaw = avec
94 {-# INLINE applyRaw #-} 94 {-# INLINE applyRaw #-}
95
diff --git a/packages/base/src/Internal/LAPACK.hs b/packages/base/src/Internal/LAPACK.hs
index c2c140b..231109a 100644
--- a/packages/base/src/Internal/LAPACK.hs
+++ b/packages/base/src/Internal/LAPACK.hs
@@ -18,7 +18,7 @@ module Internal.LAPACK where
18 18
19import Internal.Devel 19import Internal.Devel
20import Internal.Vector 20import Internal.Vector
21import Internal.Matrix hiding ((#)) 21import Internal.Matrix hiding ((#), (#!))
22import Internal.Conversion 22import Internal.Conversion
23import Internal.Element 23import Internal.Element
24import Foreign.Ptr(nullPtr) 24import Foreign.Ptr(nullPtr)
@@ -28,10 +28,13 @@ import System.IO.Unsafe(unsafePerformIO)
28 28
29----------------------------------------------------------------------------------- 29-----------------------------------------------------------------------------------
30 30
31infixl 1 # 31infixr 1 #
32a # b = apply a b 32a # b = apply a b
33{-# INLINE (#) #-} 33{-# INLINE (#) #-}
34 34
35a #! b = a # b # id
36{-# INLINE (#!) #-}
37
35----------------------------------------------------------------------------------- 38-----------------------------------------------------------------------------------
36 39
37type TMMM t = t ::> t ::> t ::> Ok 40type TMMM t = t ::> t ::> t ::> Ok
@@ -56,7 +59,7 @@ multiplyAux f st a b = unsafePerformIO $ do
56 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++ 59 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++
57 show (rows a,cols a) ++ " x " ++ show (rows b, cols b) 60 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
58 s <- createMatrix ColumnMajor (rows a) (cols b) 61 s <- createMatrix ColumnMajor (rows a) (cols b)
59 f (isT a) (isT b) # (tt a) # (tt b) # s #| st 62 ((tt a) # (tt b) #! s) (f (isT a) (isT b)) #| st
60 return s 63 return s
61 64
62-- | Matrix product based on BLAS's /dgemm/. 65-- | Matrix product based on BLAS's /dgemm/.
@@ -80,7 +83,7 @@ multiplyI m a b = unsafePerformIO $ do
80 when (cols a /= rows b) $ error $ 83 when (cols a /= rows b) $ error $
81 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 84 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
82 s <- createMatrix ColumnMajor (rows a) (cols b) 85 s <- createMatrix ColumnMajor (rows a) (cols b)
83 c_multiplyI m # a # b # s #|"c_multiplyI" 86 (a # b #! s) (c_multiplyI m) #|"c_multiplyI"
84 return s 87 return s
85 88
86multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z 89multiplyL :: Z -> Matrix Z -> Matrix Z -> Matrix Z
@@ -88,7 +91,7 @@ multiplyL m a b = unsafePerformIO $ do
88 when (cols a /= rows b) $ error $ 91 when (cols a /= rows b) $ error $
89 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b 92 "inconsistent dimensions in matrix product "++ shSize a ++ " x " ++ shSize b
90 s <- createMatrix ColumnMajor (rows a) (cols b) 93 s <- createMatrix ColumnMajor (rows a) (cols b)
91 c_multiplyL m # a # b # s #|"c_multiplyL" 94 (a # b #! s) (c_multiplyL m) #|"c_multiplyL"
92 return s 95 return s
93 96
94----------------------------------------------------------------------------- 97-----------------------------------------------------------------------------
@@ -121,7 +124,7 @@ svdAux f st x = unsafePerformIO $ do
121 u <- createMatrix ColumnMajor r r 124 u <- createMatrix ColumnMajor r r
122 s <- createVector (min r c) 125 s <- createVector (min r c)
123 v <- createMatrix ColumnMajor c c 126 v <- createMatrix ColumnMajor c c
124 f # a # u # s # v #| st 127 (a # u # s #! v) f #| st
125 return (u,s,v) 128 return (u,s,v)
126 where 129 where
127 r = rows x 130 r = rows x
@@ -149,7 +152,7 @@ thinSVDAux f st x = unsafePerformIO $ do
149 u <- createMatrix ColumnMajor r q 152 u <- createMatrix ColumnMajor r q
150 s <- createVector q 153 s <- createVector q
151 v <- createMatrix ColumnMajor q c 154 v <- createMatrix ColumnMajor q c
152 f # a # u # s # v #| st 155 (a # u # s #! v) f #| st
153 return (u,s,v) 156 return (u,s,v)
154 where 157 where
155 r = rows x 158 r = rows x
@@ -176,7 +179,7 @@ svCd = svAux zgesdd "svCd"
176svAux f st x = unsafePerformIO $ do 179svAux f st x = unsafePerformIO $ do
177 a <- copy ColumnMajor x 180 a <- copy ColumnMajor x
178 s <- createVector q 181 s <- createVector q
179 g # a # s #| st 182 (a #! s) g #| st
180 return s 183 return s
181 where 184 where
182 r = rows x 185 r = rows x
@@ -197,7 +200,7 @@ rightSVAux f st x = unsafePerformIO $ do
197 a <- copy ColumnMajor x 200 a <- copy ColumnMajor x
198 s <- createVector q 201 s <- createVector q
199 v <- createMatrix ColumnMajor c c 202 v <- createMatrix ColumnMajor c c
200 g # a # s # v #| st 203 (a # s #! v) g #| st
201 return (s,v) 204 return (s,v)
202 where 205 where
203 r = rows x 206 r = rows x
@@ -218,7 +221,7 @@ leftSVAux f st x = unsafePerformIO $ do
218 a <- copy ColumnMajor x 221 a <- copy ColumnMajor x
219 u <- createMatrix ColumnMajor r r 222 u <- createMatrix ColumnMajor r r
220 s <- createVector q 223 s <- createVector q
221 g # a # u # s #| st 224 (a # u #! s) g #| st
222 return (u,s) 225 return (u,s)
223 where 226 where
224 r = rows x 227 r = rows x
@@ -237,7 +240,7 @@ eigAux f st m = unsafePerformIO $ do
237 a <- copy ColumnMajor m 240 a <- copy ColumnMajor m
238 l <- createVector r 241 l <- createVector r
239 v <- createMatrix ColumnMajor r r 242 v <- createMatrix ColumnMajor r r
240 g # a # l # v #| st 243 (a # l #! v) g #| st
241 return (l,v) 244 return (l,v)
242 where 245 where
243 r = rows m 246 r = rows m
@@ -252,7 +255,7 @@ eigC = eigAux zgeev "eigC"
252eigOnlyAux f st m = unsafePerformIO $ do 255eigOnlyAux f st m = unsafePerformIO $ do
253 a <- copy ColumnMajor m 256 a <- copy ColumnMajor m
254 l <- createVector r 257 l <- createVector r
255 g # a # l #| st 258 (a #! l) g #| st
256 return l 259 return l
257 where 260 where
258 r = rows m 261 r = rows m
@@ -277,7 +280,7 @@ eigRaux m = unsafePerformIO $ do
277 a <- copy ColumnMajor m 280 a <- copy ColumnMajor m
278 l <- createVector r 281 l <- createVector r
279 v <- createMatrix ColumnMajor r r 282 v <- createMatrix ColumnMajor r r
280 g # a # l # v #| "eigR" 283 (a # l #! v) g #| "eigR"
281 return (l,v) 284 return (l,v)
282 where 285 where
283 r = rows m 286 r = rows m
@@ -305,7 +308,7 @@ eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR"
305eigSHAux f st m = unsafePerformIO $ do 308eigSHAux f st m = unsafePerformIO $ do
306 l <- createVector r 309 l <- createVector r
307 v <- copy ColumnMajor m 310 v <- copy ColumnMajor m
308 f # l # v #| st 311 (l #! v) f #| st
309 return (l,v) 312 return (l,v)
310 where 313 where
311 r = rows m 314 r = rows m
@@ -356,7 +359,7 @@ linearSolveSQAux g f st a b
356 | n1==n2 && n1==r = unsafePerformIO . g $ do 359 | n1==n2 && n1==r = unsafePerformIO . g $ do
357 a' <- copy ColumnMajor a 360 a' <- copy ColumnMajor a
358 s <- copy ColumnMajor b 361 s <- copy ColumnMajor b
359 f # a' # s #| st 362 (a' #! s) f #| st
360 return s 363 return s
361 | otherwise = error $ st ++ " of nonsquare matrix" 364 | otherwise = error $ st ++ " of nonsquare matrix"
362 where 365 where
@@ -387,7 +390,7 @@ foreign import ccall unsafe "cholSolveC_l" zpotrs :: C ::> C ::> Ok
387linearSolveSQAux2 g f st a b 390linearSolveSQAux2 g f st a b
388 | n1==n2 && n1==r = unsafePerformIO . g $ do 391 | n1==n2 && n1==r = unsafePerformIO . g $ do
389 s <- copy ColumnMajor b 392 s <- copy ColumnMajor b
390 f # a # s #| st 393 (a #! s) f #| st
391 return s 394 return s
392 | otherwise = error $ st ++ " of nonsquare matrix" 395 | otherwise = error $ st ++ " of nonsquare matrix"
393 where 396 where
@@ -415,7 +418,7 @@ linearSolveAux f st a b
415 a' <- copy ColumnMajor a 418 a' <- copy ColumnMajor a
416 r <- createMatrix ColumnMajor (max m n) nrhs 419 r <- createMatrix ColumnMajor (max m n) nrhs
417 setRect 0 0 b r 420 setRect 0 0 b r
418 f # a' # r #| st 421 (a' #! r) f #| st
419 return r 422 return r
420 | otherwise = error $ "different number of rows in linearSolve ("++st++")" 423 | otherwise = error $ "different number of rows in linearSolve ("++st++")"
421 where 424 where
@@ -458,7 +461,7 @@ foreign import ccall unsafe "chol_l_S" dpotrf :: R ::> Ok
458 461
459cholAux f st a = do 462cholAux f st a = do
460 r <- copy ColumnMajor a 463 r <- copy ColumnMajor a
461 f # r #| st 464 (r # id) f #| st
462 return r 465 return r
463 466
464-- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/. 467-- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/.
@@ -495,7 +498,7 @@ qrC = qrAux zgeqr2 "qrC"
495qrAux f st a = unsafePerformIO $ do 498qrAux f st a = unsafePerformIO $ do
496 r <- copy ColumnMajor a 499 r <- copy ColumnMajor a
497 tau <- createVector mn 500 tau <- createVector mn
498 f # tau # r #| st 501 (tau #! r) f #| st
499 return (r,tau) 502 return (r,tau)
500 where 503 where
501 m = rows a 504 m = rows a
@@ -514,7 +517,7 @@ qrgrC = qrgrAux zungqr "qrgrC"
514 517
515qrgrAux f st n (a, tau) = unsafePerformIO $ do 518qrgrAux f st n (a, tau) = unsafePerformIO $ do
516 res <- copy ColumnMajor (subMatrix (0,0) (rows a,n) a) 519 res <- copy ColumnMajor (subMatrix (0,0) (rows a,n) a)
517 f # (subVector 0 n tau') # res #| st 520 ((subVector 0 n tau') #! res) f #| st
518 return res 521 return res
519 where 522 where
520 tau' = vjoin [tau, constantD 0 n] 523 tau' = vjoin [tau, constantD 0 n]
@@ -534,7 +537,7 @@ hessC = hessAux zgehrd "hessC"
534hessAux f st a = unsafePerformIO $ do 537hessAux f st a = unsafePerformIO $ do
535 r <- copy ColumnMajor a 538 r <- copy ColumnMajor a
536 tau <- createVector (mn-1) 539 tau <- createVector (mn-1)
537 f # tau # r #| st 540 (tau #! r) f #| st
538 return (r,tau) 541 return (r,tau)
539 where 542 where
540 m = rows a 543 m = rows a
@@ -556,7 +559,7 @@ schurC = schurAux zgees "schurC"
556schurAux f st a = unsafePerformIO $ do 559schurAux f st a = unsafePerformIO $ do
557 u <- createMatrix ColumnMajor n n 560 u <- createMatrix ColumnMajor n n
558 s <- copy ColumnMajor a 561 s <- copy ColumnMajor a
559 f # u # s #| st 562 (u #! s) f #| st
560 return (u,s) 563 return (u,s)
561 where 564 where
562 n = rows a 565 n = rows a
@@ -576,7 +579,7 @@ luC = luAux zgetrf "luC"
576luAux f st a = unsafePerformIO $ do 579luAux f st a = unsafePerformIO $ do
577 lu <- copy ColumnMajor a 580 lu <- copy ColumnMajor a
578 piv <- createVector (min n m) 581 piv <- createVector (min n m)
579 f # piv # lu #| st 582 (piv #! lu) f #| st
580 return (lu, map (pred.round) (toList piv)) 583 return (lu, map (pred.round) (toList piv))
581 where 584 where
582 n = rows a 585 n = rows a
@@ -598,7 +601,7 @@ lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv b
598lusAux f st a piv b 601lusAux f st a piv b
599 | n1==n2 && n2==n =unsafePerformIO $ do 602 | n1==n2 && n2==n =unsafePerformIO $ do
600 x <- copy ColumnMajor b 603 x <- copy ColumnMajor b
601 f # a # piv' # x #| st 604 (a # piv' #! x) f #| st
602 return x 605 return x
603 | otherwise = error st 606 | otherwise = error st
604 where 607 where
@@ -622,7 +625,7 @@ ldlC = ldlAux zhetrf "ldlC"
622ldlAux f st a = unsafePerformIO $ do 625ldlAux f st a = unsafePerformIO $ do
623 ldl <- copy ColumnMajor a 626 ldl <- copy ColumnMajor a
624 piv <- createVector (rows a) 627 piv <- createVector (rows a)
625 f # piv # ldl #| st 628 (piv #! ldl) f #| st
626 return (ldl, map (pred.round) (toList piv)) 629 return (ldl, map (pred.round) (toList piv))
627 630
628----------------------------------------------------------------------------------- 631-----------------------------------------------------------------------------------
@@ -637,4 +640,3 @@ ldlsR a piv b = lusAux dsytrs "ldlsR" (fmat a) piv b
637-- | Solve a complex linear system from a precomputed LDL decomposition ('ldlC'), using LAPACK's /zsytrs/. 640-- | Solve a complex linear system from a precomputed LDL decomposition ('ldlC'), using LAPACK's /zsytrs/.
638ldlsC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double) 641ldlsC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double)
639ldlsC a piv b = lusAux zsytrs "ldlsC" (fmat a) piv b 642ldlsC a piv b = lusAux zsytrs "ldlsC" (fmat a) piv b
640
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index c47c625..0135288 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -22,7 +22,7 @@ module Internal.Matrix where
22 22
23import Internal.Vector 23import Internal.Vector
24import Internal.Devel 24import Internal.Devel
25import Internal.Vectorized hiding ((#)) 25import Internal.Vectorized hiding ((#), (#!))
26import Foreign.Marshal.Alloc ( free ) 26import Foreign.Marshal.Alloc ( free )
27import Foreign.Marshal.Array(newArray) 27import Foreign.Marshal.Array(newArray)
28import Foreign.Ptr ( Ptr ) 28import Foreign.Ptr ( Ptr )
@@ -110,15 +110,15 @@ fmat m
110 110
111-- C-Haskell matrix adapters 111-- C-Haskell matrix adapters
112{-# INLINE amatr #-} 112{-# INLINE amatr #-}
113amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b 113amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r
114amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) 114amatr x f g = unsafeWith (xdat x) (f . g r c)
115 where 115 where
116 r = fi (rows x) 116 r = fi (rows x)
117 c = fi (cols x) 117 c = fi (cols x)
118 118
119{-# INLINE amat #-} 119{-# INLINE amat #-}
120amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b 120amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r
121amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) 121amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
122 where 122 where
123 r = fi (rows x) 123 r = fi (rows x)
124 c = fi (cols x) 124 c = fi (cols x)
@@ -135,10 +135,13 @@ instance Storable t => TransArray (Matrix t)
135 applyRaw = amatr 135 applyRaw = amatr
136 {-# INLINE applyRaw #-} 136 {-# INLINE applyRaw #-}
137 137
138infixl 1 # 138infixr 1 #
139a # b = apply a b 139a # b = apply a b
140{-# INLINE (#) #-} 140{-# INLINE (#) #-}
141 141
142a #! b = a # b # id
143{-# INLINE (#!) #-}
144
142-------------------------------------------------------------------------------- 145--------------------------------------------------------------------------------
143 146
144copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) 147copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1])
@@ -426,7 +429,8 @@ extractAux f ord m moder vr modec vc = do
426 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 429 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
427 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 430 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
428 r <- createMatrix ord nr nc 431 r <- createMatrix ord nr nc
429 f moder modec # vr # vc # m # r #|"extract" 432 (vr # vc # m #! r) (f moder modec) #|"extract"
433
430 return r 434 return r
431 435
432type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) 436type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
@@ -440,7 +444,7 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z
440 444
441--------------------------------------------------------------- 445---------------------------------------------------------------
442 446
443setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect" 447setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
444 448
445type SetRect x = I -> I -> x ::> x::> Ok 449type SetRect x = I -> I -> x ::> x::> Ok
446 450
@@ -455,7 +459,7 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
455 459
456sortG f v = unsafePerformIO $ do 460sortG f v = unsafePerformIO $ do
457 r <- createVector (dim v) 461 r <- createVector (dim v)
458 f # v # r #|"sortG" 462 (v #! r) f #|"sortG"
459 return r 463 return r
460 464
461sortIdxD = sortG c_sort_indexD 465sortIdxD = sortG c_sort_indexD
@@ -482,7 +486,7 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
482 486
483compareG f u v = unsafePerformIO $ do 487compareG f u v = unsafePerformIO $ do
484 r <- createVector (dim v) 488 r <- createVector (dim v)
485 f # u # v # r #|"compareG" 489 (u # v #! r) f #|"compareG"
486 return r 490 return r
487 491
488compareD = compareG c_compareD 492compareD = compareG c_compareD
@@ -499,7 +503,7 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
499 503
500selectG f c u v w = unsafePerformIO $ do 504selectG f c u v w = unsafePerformIO $ do
501 r <- createVector (dim v) 505 r <- createVector (dim v)
502 f # c # u # v # w # r #|"selectG" 506 (c # u # v # w #! r) f #|"selectG"
503 return r 507 return r
504 508
505selectD = selectG c_selectD 509selectD = selectG c_selectD
@@ -522,7 +526,7 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
522 526
523remapG f i j m = unsafePerformIO $ do 527remapG f i j m = unsafePerformIO $ do
524 r <- createMatrix RowMajor (rows i) (cols i) 528 r <- createMatrix RowMajor (rows i) (cols i)
525 f # i # j # m # r #|"remapG" 529 (i # j # m #! r) f #|"remapG"
526 return r 530 return r
527 531
528remapD = remapG c_remapD 532remapD = remapG c_remapD
@@ -545,7 +549,7 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
545 549
546rowOpAux f c x i1 i2 j1 j2 m = do 550rowOpAux f c x i1 i2 j1 j2 m = do
547 px <- newArray [x] 551 px <- newArray [x]
548 f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp" 552 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
549 free px 553 free px
550 554
551type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok 555type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
@@ -561,7 +565,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
561 565
562-------------------------------------------------------------------------------- 566--------------------------------------------------------------------------------
563 567
564gemmg f v m1 m2 m3 = f # v # m1 # m2 # m3 #|"gemmg" 568gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
565 569
566type Tgemm x = x :> x ::> x ::> x ::> Ok 570type Tgemm x = x :> x ::> x ::> x ::> Ok
567 571
@@ -589,10 +593,9 @@ saveMatrix
589saveMatrix name format m = do 593saveMatrix name format m = do
590 cname <- newCString name 594 cname <- newCString name
591 cformat <- newCString format 595 cformat <- newCString format
592 c_saveMatrix cname cformat # m #|"saveMatrix" 596 (m # id) (c_saveMatrix cname cformat) #|"saveMatrix"
593 free cname 597 free cname
594 free cformat 598 free cformat
595 return () 599 return ()
596 600
597-------------------------------------------------------------------------------- 601--------------------------------------------------------------------------------
598
diff --git a/packages/base/src/Internal/Sparse.hs b/packages/base/src/Internal/Sparse.hs
index 1604e7e..1ff3f57 100644
--- a/packages/base/src/Internal/Sparse.hs
+++ b/packages/base/src/Internal/Sparse.hs
@@ -144,13 +144,13 @@ gmXv :: GMatrix -> Vector Double -> Vector Double
144gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do 144gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do
145 dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) 145 dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
146 r <- createVector nRows 146 r <- createVector nRows
147 c_smXv # csrVals # csrCols # csrRows # v # r #|"CSRXv" 147 (csrVals # csrCols # csrRows # v #! r) c_smXv #|"CSRXv"
148 return r 148 return r
149 149
150gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do 150gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do
151 dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) 151 dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
152 r <- createVector nRows 152 r <- createVector nRows
153 c_smTXv # cscVals # cscRows # cscCols # v # r #|"CSCXv" 153 (cscVals # cscRows # cscCols # v #! r) c_smTXv #|"CSCXv"
154 return r 154 return r
155 155
156gmXv Diag{..} v 156gmXv Diag{..} v
@@ -211,4 +211,3 @@ instance Transposable GMatrix GMatrix
211 tr (Diag v n m) = Diag v m n 211 tr (Diag v n m) = Diag v m n
212 tr (Dense a n m) = Dense (tr a) m n 212 tr (Dense a n m) = Dense (tr a) m n
213 tr' = tr 213 tr' = tr
214
diff --git a/packages/base/src/Internal/Vector.hs b/packages/base/src/Internal/Vector.hs
index b4e235c..c4a310d 100644
--- a/packages/base/src/Internal/Vector.hs
+++ b/packages/base/src/Internal/Vector.hs
@@ -66,9 +66,8 @@ dim = Vector.length
66 66
67-- C-Haskell vector adapter 67-- C-Haskell vector adapter
68{-# INLINE avec #-} 68{-# INLINE avec #-}
69avec :: Storable a => (CInt -> Ptr a -> b) -> Vector a -> b 69avec :: Storable a => Vector a -> (f -> IO r) -> ((CInt -> Ptr a -> f) -> IO r)
70avec f v = inlinePerformIO (unsafeWith v (return . f (fromIntegral (Vector.length v)))) 70avec v f g = unsafeWith v $ \ptr -> f (g (fromIntegral (Vector.length v)) ptr)
71infixl 1 `avec`
72 71
73-- allocates memory for a new vector 72-- allocates memory for a new vector
74createVector :: Storable a => Int -> IO (Vector a) 73createVector :: Storable a => Int -> IO (Vector a)
@@ -199,7 +198,7 @@ takesV ms w | sum ms > dim w = error $ "takesV " ++ show ms ++ " on dim = " ++ (
199 198
200--------------------------------------------------------------- 199---------------------------------------------------------------
201 200
202-- | transforms a complex vector into a real vector with alternating real and imaginary parts 201-- | transforms a complex vector into a real vector with alternating real and imaginary parts
203asReal :: (RealFloat a, Storable a) => Vector (Complex a) -> Vector a 202asReal :: (RealFloat a, Storable a) => Vector (Complex a) -> Vector a
204asReal v = unsafeFromForeignPtr (castForeignPtr fp) (2*i) (2*n) 203asReal v = unsafeFromForeignPtr (castForeignPtr fp) (2*i) (2*n)
205 where (fp,i,n) = unsafeToForeignPtr v 204 where (fp,i,n) = unsafeToForeignPtr v
@@ -244,7 +243,7 @@ zipVectorWith f u v = unsafePerformIO $ do
244{-# INLINE zipVectorWith #-} 243{-# INLINE zipVectorWith #-}
245 244
246-- | unzipWith for Vectors 245-- | unzipWith for Vectors
247unzipVectorWith :: (Storable (a,b), Storable c, Storable d) 246unzipVectorWith :: (Storable (a,b), Storable c, Storable d)
248 => ((a,b) -> (c,d)) -> Vector (a,b) -> (Vector c,Vector d) 247 => ((a,b) -> (c,d)) -> Vector (a,b) -> (Vector c,Vector d)
249unzipVectorWith f u = unsafePerformIO $ do 248unzipVectorWith f u = unsafePerformIO $ do
250 let n = dim u 249 let n = dim u
@@ -255,7 +254,7 @@ unzipVectorWith f u = unsafePerformIO $ do
255 unsafeWith w $ \pw -> do 254 unsafeWith w $ \pw -> do
256 let go (-1) = return () 255 let go (-1) = return ()
257 go !k = do z <- peekElemOff pu k 256 go !k = do z <- peekElemOff pu k
258 let (x,y) = f z 257 let (x,y) = f z
259 pokeElemOff pv k x 258 pokeElemOff pv k x
260 pokeElemOff pw k y 259 pokeElemOff pw k y
261 go (k-1) 260 go (k-1)
@@ -303,11 +302,11 @@ mapVectorM f v = do
303 return w 302 return w
304 where mapVectorM' w' !k !t 303 where mapVectorM' w' !k !t
305 | k == t = do 304 | k == t = do
306 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k 305 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
307 y <- f x 306 y <- f x
308 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y 307 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
309 | otherwise = do 308 | otherwise = do
310 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k 309 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
311 y <- f x 310 y <- f x
312 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y 311 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
313 mapVectorM' w' (k+1) t 312 mapVectorM' w' (k+1) t
@@ -322,7 +321,7 @@ mapVectorM_ f v = do
322 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k 321 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
323 f x 322 f x
324 | otherwise = do 323 | otherwise = do
325 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k 324 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
326 _ <- f x 325 _ <- f x
327 mapVectorM' (k+1) t 326 mapVectorM' (k+1) t
328{-# INLINE mapVectorM_ #-} 327{-# INLINE mapVectorM_ #-}
@@ -336,11 +335,11 @@ mapVectorWithIndexM f v = do
336 return w 335 return w
337 where mapVectorM' w' !k !t 336 where mapVectorM' w' !k !t
338 | k == t = do 337 | k == t = do
339 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k 338 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
340 y <- f k x 339 y <- f k x
341 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y 340 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
342 | otherwise = do 341 | otherwise = do
343 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k 342 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
344 y <- f k x 343 y <- f k x
345 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y 344 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
346 mapVectorM' w' (k+1) t 345 mapVectorM' w' (k+1) t
@@ -355,7 +354,7 @@ mapVectorWithIndexM_ f v = do
355 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k 354 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
356 f k x 355 f k x
357 | otherwise = do 356 | otherwise = do
358 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k 357 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
359 _ <- f k x 358 _ <- f k x
360 mapVectorM' (k+1) t 359 mapVectorM' (k+1) t
361{-# INLINE mapVectorWithIndexM_ #-} 360{-# INLINE mapVectorWithIndexM_ #-}
@@ -454,4 +453,3 @@ unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vect
454unzipVector = unzipVectorWith id 453unzipVector = unzipVectorWith id
455 454
456------------------------------------------------------------------- 455-------------------------------------------------------------------
457
diff --git a/packages/base/src/Internal/Vectorized.hs b/packages/base/src/Internal/Vectorized.hs
index 03bcf90..a410bb2 100644
--- a/packages/base/src/Internal/Vectorized.hs
+++ b/packages/base/src/Internal/Vectorized.hs
@@ -27,10 +27,13 @@ import Foreign.C.String
27import System.IO.Unsafe(unsafePerformIO) 27import System.IO.Unsafe(unsafePerformIO)
28import Control.Monad(when) 28import Control.Monad(when)
29 29
30infixl 1 # 30infixr 1 #
31a # b = applyRaw a b 31a # b = applyRaw a b
32{-# INLINE (#) #-} 32{-# INLINE (#) #-}
33 33
34a #! b = a # b # id
35{-# INLINE (#!) #-}
36
34fromei x = fromIntegral (fromEnum x) :: CInt 37fromei x = fromIntegral (fromEnum x) :: CInt
35 38
36data FunCodeV = Sin 39data FunCodeV = Sin
@@ -103,7 +106,7 @@ sumL m = sumg (c_sumL m)
103 106
104sumg f x = unsafePerformIO $ do 107sumg f x = unsafePerformIO $ do
105 r <- createVector 1 108 r <- createVector 1
106 f # x # r #| "sum" 109 (x #! r) f #| "sum"
107 return $ r @> 0 110 return $ r @> 0
108 111
109type TVV t = t :> t :> Ok 112type TVV t = t :> t :> Ok
@@ -139,7 +142,7 @@ prodL = prodg . c_prodL
139 142
140prodg f x = unsafePerformIO $ do 143prodg f x = unsafePerformIO $ do
141 r <- createVector 1 144 r <- createVector 1
142 f # x # r #| "prod" 145 (x #! r) f #| "prod"
143 return $ r @> 0 146 return $ r @> 0
144 147
145 148
@@ -154,24 +157,24 @@ foreign import ccall unsafe "prodL" c_prodL :: Z -> TVV Z
154 157
155toScalarAux fun code v = unsafePerformIO $ do 158toScalarAux fun code v = unsafePerformIO $ do
156 r <- createVector 1 159 r <- createVector 1
157 fun (fromei code) # v # r #|"toScalarAux" 160 (v #! r) (fun (fromei code)) #|"toScalarAux"
158 return (r @> 0) 161 return (r @> 0)
159 162
160vectorMapAux fun code v = unsafePerformIO $ do 163vectorMapAux fun code v = unsafePerformIO $ do
161 r <- createVector (dim v) 164 r <- createVector (dim v)
162 fun (fromei code) # v # r #|"vectorMapAux" 165 (v #! r) (fun (fromei code)) #|"vectorMapAux"
163 return r 166 return r
164 167
165vectorMapValAux fun code val v = unsafePerformIO $ do 168vectorMapValAux fun code val v = unsafePerformIO $ do
166 r <- createVector (dim v) 169 r <- createVector (dim v)
167 pval <- newArray [val] 170 pval <- newArray [val]
168 fun (fromei code) pval # v # r #|"vectorMapValAux" 171 (v #! r) (fun (fromei code) pval) #|"vectorMapValAux"
169 free pval 172 free pval
170 return r 173 return r
171 174
172vectorZipAux fun code u v = unsafePerformIO $ do 175vectorZipAux fun code u v = unsafePerformIO $ do
173 r <- createVector (dim u) 176 r <- createVector (dim u)
174 fun (fromei code) # u # v # r #|"vectorZipAux" 177 (u # v #! r) (fun (fromei code)) #|"vectorZipAux"
175 return r 178 return r
176 179
177--------------------------------------------------------------------- 180---------------------------------------------------------------------
@@ -368,7 +371,7 @@ randomVector :: Seed
368 -> Vector Double 371 -> Vector Double
369randomVector seed dist n = unsafePerformIO $ do 372randomVector seed dist n = unsafePerformIO $ do
370 r <- createVector n 373 r <- createVector n
371 c_random_vector (fi seed) ((fi.fromEnum) dist) # r #|"randomVector" 374 (r # id) (c_random_vector (fi seed) ((fi.fromEnum) dist)) #|"randomVector"
372 return r 375 return r
373 376
374foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok 377foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> Double :> Ok
@@ -377,7 +380,7 @@ foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> D
377 380
378roundVector v = unsafePerformIO $ do 381roundVector v = unsafePerformIO $ do
379 r <- createVector (dim v) 382 r <- createVector (dim v)
380 c_round_vector # v # r #|"roundVector" 383 (v #! r) c_round_vector #|"roundVector"
381 return r 384 return r
382 385
383foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double 386foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double
@@ -391,7 +394,7 @@ foreign import ccall unsafe "round_vector" c_round_vector :: TVV Double
391range :: Int -> Vector I 394range :: Int -> Vector I
392range n = unsafePerformIO $ do 395range n = unsafePerformIO $ do
393 r <- createVector n 396 r <- createVector n
394 c_range_vector # r #|"range" 397 (r # id) c_range_vector #|"range"
395 return r 398 return r
396 399
397foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok 400foreign import ccall unsafe "range_vector" c_range_vector :: CInt :> Ok
@@ -431,7 +434,7 @@ long2intV = tog c_long2int
431 434
432tog f v = unsafePerformIO $ do 435tog f v = unsafePerformIO $ do
433 r <- createVector (dim v) 436 r <- createVector (dim v)
434 f # v # r #|"tog" 437 (v #! r) f #|"tog"
435 return r 438 return r
436 439
437foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok 440foreign import ccall unsafe "float2double" c_float2double :: Float :> Double :> Ok
@@ -450,7 +453,7 @@ foreign import ccall unsafe "long2int" c_long2int :: Z :> I :> Ok
450 453
451stepg f v = unsafePerformIO $ do 454stepg f v = unsafePerformIO $ do
452 r <- createVector (dim v) 455 r <- createVector (dim v)
453 f # v # r #|"step" 456 (v #! r) f #|"step"
454 return r 457 return r
455 458
456stepD :: Vector Double -> Vector Double 459stepD :: Vector Double -> Vector Double
@@ -475,7 +478,7 @@ foreign import ccall unsafe "stepL" c_stepL :: TVV Z
475 478
476conjugateAux fun x = unsafePerformIO $ do 479conjugateAux fun x = unsafePerformIO $ do
477 v <- createVector (dim x) 480 v <- createVector (dim x)
478 fun # x # v #|"conjugateAux" 481 (x #! v) fun #|"conjugateAux"
479 return v 482 return v
480 483
481conjugateQ :: Vector (Complex Float) -> Vector (Complex Float) 484conjugateQ :: Vector (Complex Float) -> Vector (Complex Float)
@@ -493,7 +496,7 @@ cloneVector v = do
493 let n = dim v 496 let n = dim v
494 r <- createVector n 497 r <- createVector n
495 let f _ s _ d = copyArray d s n >> return 0 498 let f _ s _ d = copyArray d s n >> return 0
496 f # v # r #|"cloneVector" 499 (v #! r) f #|"cloneVector"
497 return r 500 return r
498 501
499-------------------------------------------------------------------------------- 502--------------------------------------------------------------------------------
@@ -501,7 +504,7 @@ cloneVector v = do
501constantAux fun x n = unsafePerformIO $ do 504constantAux fun x n = unsafePerformIO $ do
502 v <- createVector n 505 v <- createVector n
503 px <- newArray [x] 506 px <- newArray [x]
504 fun px # v #|"constantAux" 507 (v # id) (fun px) #|"constantAux"
505 free px 508 free px
506 return v 509 return v
507 510
@@ -515,4 +518,3 @@ foreign import ccall unsafe "constantI" cconstantI :: TConst CInt
515foreign import ccall unsafe "constantL" cconstantL :: TConst Z 518foreign import ccall unsafe "constantL" cconstantL :: TConst Z
516 519
517---------------------------------------------------------------------- 520----------------------------------------------------------------------
518