summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-19 13:55:39 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-19 13:55:39 +0200
commitdb50bc11dafa6834a4367427156306674063ed6b (patch)
tree721e9d0235168be1d0ebb2bd1dd254a66251f274 /packages/base/src/Internal/Matrix.hs
parent7f9c7b5adf8f05653d15f19358f41c1916e8db70 (diff)
removed the annoying appN adapter for the foreign functions.
replaced by several overloaded app variants in the style of the module Internal.Foreign contributed by Mike Ledger.
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs72
1 files changed, 54 insertions, 18 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 8f8c219..db0a609 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -3,6 +3,8 @@
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-} 4{-# LANGUAGE BangPatterns #-}
5{-# LANGUAGE TypeOperators #-} 5{-# LANGUAGE TypeOperators #-}
6{-# LANGUAGE TypeFamilies #-}
7
6 8
7-- | 9-- |
8-- Module : Internal.Matrix 10-- Module : Internal.Matrix
@@ -18,7 +20,7 @@ module Internal.Matrix where
18 20
19import Internal.Vector 21import Internal.Vector
20import Internal.Devel 22import Internal.Devel
21import Internal.Vectorized 23import Internal.Vectorized hiding ((#))
22import Foreign.Marshal.Alloc ( free ) 24import Foreign.Marshal.Alloc ( free )
23import Foreign.Marshal.Array(newArray) 25import Foreign.Marshal.Array(newArray)
24import Foreign.Ptr ( Ptr ) 26import Foreign.Ptr ( Ptr )
@@ -79,8 +81,6 @@ data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int
79-- RowMajor: preferred by C, fdat may require a transposition 81-- RowMajor: preferred by C, fdat may require a transposition
80-- ColumnMajor: preferred by LAPACK, cdat may require a transposition 82-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
81 83
82--cdat = xdat
83--fdat = xdat
84 84
85rows :: Matrix t -> Int 85rows :: Matrix t -> Int
86rows = irows 86rows = irows
@@ -129,6 +129,48 @@ omat a f =
129 g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p 129 g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p
130 f m 130 f m
131 131
132--------------------------------------------------------------------------------
133
134{-# INLINE amatr #-}
135amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
136amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c))
137 where
138 r = fromIntegral (rows x)
139 c = fromIntegral (cols x)
140
141{-# INLINE amat #-}
142amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
143amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc))
144 where
145 r = fromIntegral (rows x)
146 c = fromIntegral (cols x)
147 sr = stepRow x
148 sc = stepCol x
149
150{-# INLINE arrmat #-}
151arrmat :: Storable a => (Ptr CInt -> Ptr a -> b) -> Matrix a -> b
152arrmat f x = inlinePerformIO (unsafeWith s (\p -> unsafeWith (xdat x) (return . f p)))
153 where
154 s = fromList [fi (rows x), fi (cols x), stepRow x, stepCol x]
155
156
157instance Storable t => TransArray (Matrix t)
158 where
159 type Elem (Matrix t) = t
160 type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b
161 type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b
162 apply = amat
163 {-# INLINE apply #-}
164 applyRaw = amatr
165 {-# INLINE applyRaw #-}
166 applyArray = arrmat
167 {-# INLINE applyArray #-}
168
169infixl 1 #
170a # b = apply a b
171{-# INLINE (#) #-}
172
173--------------------------------------------------------------------------------
132 174
133{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 175{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
134 176
@@ -139,12 +181,6 @@ fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
139flatten :: Element t => Matrix t -> Vector t 181flatten :: Element t => Matrix t -> Vector t
140flatten = xdat . cmat 182flatten = xdat . cmat
141 183
142{-
143type Mt t s = Int -> Int -> Ptr t -> s
144
145infixr 6 ::>
146type t ::> s = Mt t s
147-}
148 184
149-- | the inverse of 'Data.Packed.Matrix.fromLists' 185-- | the inverse of 'Data.Packed.Matrix.fromLists'
150toLists :: (Element t) => Matrix t -> [[t]] 186toLists :: (Element t) => Matrix t -> [[t]]
@@ -445,7 +481,7 @@ extractAux f m moder vr modec vc = do
445 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 481 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
446 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 482 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
447 r <- createMatrix RowMajor nr nc 483 r <- createMatrix RowMajor nr nc
448 app4 (f moder modec) vec vr vec vc omat m omat r "extractAux" 484 f moder modec # vr # vc # m # r #|"extract"
449 return r 485 return r
450 486
451type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) 487type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
@@ -459,7 +495,7 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z
459 495
460--------------------------------------------------------------- 496---------------------------------------------------------------
461 497
462setRectAux f i j m r = app2 (f (fi i) (fi j)) omat m omat r "setRect" 498setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect"
463 499
464type SetRect x = I -> I -> x ::> x::> Ok 500type SetRect x = I -> I -> x ::> x::> Ok
465 501
@@ -474,7 +510,7 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
474 510
475sortG f v = unsafePerformIO $ do 511sortG f v = unsafePerformIO $ do
476 r <- createVector (dim v) 512 r <- createVector (dim v)
477 app2 f vec v vec r "sortG" 513 f # v # r #|"sortG"
478 return r 514 return r
479 515
480sortIdxD = sortG c_sort_indexD 516sortIdxD = sortG c_sort_indexD
@@ -501,7 +537,7 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
501 537
502compareG f u v = unsafePerformIO $ do 538compareG f u v = unsafePerformIO $ do
503 r <- createVector (dim v) 539 r <- createVector (dim v)
504 app3 f vec u vec v vec r "compareG" 540 f # u # v # r #|"compareG"
505 return r 541 return r
506 542
507compareD = compareG c_compareD 543compareD = compareG c_compareD
@@ -518,7 +554,7 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
518 554
519selectG f c u v w = unsafePerformIO $ do 555selectG f c u v w = unsafePerformIO $ do
520 r <- createVector (dim v) 556 r <- createVector (dim v)
521 app5 f vec c vec u vec v vec w vec r "selectG" 557 f # c # u # v # w # r #|"selectG"
522 return r 558 return r
523 559
524selectD = selectG c_selectD 560selectD = selectG c_selectD
@@ -541,7 +577,7 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
541 577
542remapG f i j m = unsafePerformIO $ do 578remapG f i j m = unsafePerformIO $ do
543 r <- createMatrix RowMajor (rows i) (cols i) 579 r <- createMatrix RowMajor (rows i) (cols i)
544 app4 f omat i omat j omat m omat r "remapG" 580 f # i # j # m # r #|"remapG"
545 return r 581 return r
546 582
547remapD = remapG c_remapD 583remapD = remapG c_remapD
@@ -564,7 +600,7 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
564 600
565rowOpAux f c x i1 i2 j1 j2 m = do 601rowOpAux f c x i1 i2 j1 j2 m = do
566 px <- newArray [x] 602 px <- newArray [x]
567 app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp" 603 f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp"
568 free px 604 free px
569 605
570type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok 606type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
@@ -580,7 +616,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
580 616
581-------------------------------------------------------------------------------- 617--------------------------------------------------------------------------------
582 618
583gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg" 619gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg"
584 620
585type Tgemm x = x :> I :> x ::> x ::> x ::> Ok 621type Tgemm x = x :> I :> x ::> x ::> x ::> Ok
586 622
@@ -608,7 +644,7 @@ saveMatrix
608saveMatrix name format m = do 644saveMatrix name format m = do
609 cname <- newCString name 645 cname <- newCString name
610 cformat <- newCString format 646 cformat <- newCString format
611 app1 (c_saveMatrix cname cformat) mat m "saveMatrix" 647 c_saveMatrix cname cformat `applyRaw` m #|"saveMatrix"
612 free cname 648 free cname
613 free cformat 649 free cformat
614 return () 650 return ()