summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs72
1 files changed, 54 insertions, 18 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 8f8c219..db0a609 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -3,6 +3,8 @@
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-} 4{-# LANGUAGE BangPatterns #-}
5{-# LANGUAGE TypeOperators #-} 5{-# LANGUAGE TypeOperators #-}
6{-# LANGUAGE TypeFamilies #-}
7
6 8
7-- | 9-- |
8-- Module : Internal.Matrix 10-- Module : Internal.Matrix
@@ -18,7 +20,7 @@ module Internal.Matrix where
18 20
19import Internal.Vector 21import Internal.Vector
20import Internal.Devel 22import Internal.Devel
21import Internal.Vectorized 23import Internal.Vectorized hiding ((#))
22import Foreign.Marshal.Alloc ( free ) 24import Foreign.Marshal.Alloc ( free )
23import Foreign.Marshal.Array(newArray) 25import Foreign.Marshal.Array(newArray)
24import Foreign.Ptr ( Ptr ) 26import Foreign.Ptr ( Ptr )
@@ -79,8 +81,6 @@ data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int
79-- RowMajor: preferred by C, fdat may require a transposition 81-- RowMajor: preferred by C, fdat may require a transposition
80-- ColumnMajor: preferred by LAPACK, cdat may require a transposition 82-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
81 83
82--cdat = xdat
83--fdat = xdat
84 84
85rows :: Matrix t -> Int 85rows :: Matrix t -> Int
86rows = irows 86rows = irows
@@ -129,6 +129,48 @@ omat a f =
129 g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p 129 g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p
130 f m 130 f m
131 131
132--------------------------------------------------------------------------------
133
134{-# INLINE amatr #-}
135amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
136amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c))
137 where
138 r = fromIntegral (rows x)
139 c = fromIntegral (cols x)
140
141{-# INLINE amat #-}
142amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
143amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc))
144 where
145 r = fromIntegral (rows x)
146 c = fromIntegral (cols x)
147 sr = stepRow x
148 sc = stepCol x
149
150{-# INLINE arrmat #-}
151arrmat :: Storable a => (Ptr CInt -> Ptr a -> b) -> Matrix a -> b
152arrmat f x = inlinePerformIO (unsafeWith s (\p -> unsafeWith (xdat x) (return . f p)))
153 where
154 s = fromList [fi (rows x), fi (cols x), stepRow x, stepCol x]
155
156
157instance Storable t => TransArray (Matrix t)
158 where
159 type Elem (Matrix t) = t
160 type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b
161 type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b
162 apply = amat
163 {-# INLINE apply #-}
164 applyRaw = amatr
165 {-# INLINE applyRaw #-}
166 applyArray = arrmat
167 {-# INLINE applyArray #-}
168
169infixl 1 #
170a # b = apply a b
171{-# INLINE (#) #-}
172
173--------------------------------------------------------------------------------
132 174
133{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 175{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
134 176
@@ -139,12 +181,6 @@ fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
139flatten :: Element t => Matrix t -> Vector t 181flatten :: Element t => Matrix t -> Vector t
140flatten = xdat . cmat 182flatten = xdat . cmat
141 183
142{-
143type Mt t s = Int -> Int -> Ptr t -> s
144
145infixr 6 ::>
146type t ::> s = Mt t s
147-}
148 184
149-- | the inverse of 'Data.Packed.Matrix.fromLists' 185-- | the inverse of 'Data.Packed.Matrix.fromLists'
150toLists :: (Element t) => Matrix t -> [[t]] 186toLists :: (Element t) => Matrix t -> [[t]]
@@ -445,7 +481,7 @@ extractAux f m moder vr modec vc = do
445 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 481 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
446 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 482 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
447 r <- createMatrix RowMajor nr nc 483 r <- createMatrix RowMajor nr nc
448 app4 (f moder modec) vec vr vec vc omat m omat r "extractAux" 484 f moder modec # vr # vc # m # r #|"extract"
449 return r 485 return r
450 486
451type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) 487type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
@@ -459,7 +495,7 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z
459 495
460--------------------------------------------------------------- 496---------------------------------------------------------------
461 497
462setRectAux f i j m r = app2 (f (fi i) (fi j)) omat m omat r "setRect" 498setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect"
463 499
464type SetRect x = I -> I -> x ::> x::> Ok 500type SetRect x = I -> I -> x ::> x::> Ok
465 501
@@ -474,7 +510,7 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
474 510
475sortG f v = unsafePerformIO $ do 511sortG f v = unsafePerformIO $ do
476 r <- createVector (dim v) 512 r <- createVector (dim v)
477 app2 f vec v vec r "sortG" 513 f # v # r #|"sortG"
478 return r 514 return r
479 515
480sortIdxD = sortG c_sort_indexD 516sortIdxD = sortG c_sort_indexD
@@ -501,7 +537,7 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
501 537
502compareG f u v = unsafePerformIO $ do 538compareG f u v = unsafePerformIO $ do
503 r <- createVector (dim v) 539 r <- createVector (dim v)
504 app3 f vec u vec v vec r "compareG" 540 f # u # v # r #|"compareG"
505 return r 541 return r
506 542
507compareD = compareG c_compareD 543compareD = compareG c_compareD
@@ -518,7 +554,7 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
518 554
519selectG f c u v w = unsafePerformIO $ do 555selectG f c u v w = unsafePerformIO $ do
520 r <- createVector (dim v) 556 r <- createVector (dim v)
521 app5 f vec c vec u vec v vec w vec r "selectG" 557 f # c # u # v # w # r #|"selectG"
522 return r 558 return r
523 559
524selectD = selectG c_selectD 560selectD = selectG c_selectD
@@ -541,7 +577,7 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
541 577
542remapG f i j m = unsafePerformIO $ do 578remapG f i j m = unsafePerformIO $ do
543 r <- createMatrix RowMajor (rows i) (cols i) 579 r <- createMatrix RowMajor (rows i) (cols i)
544 app4 f omat i omat j omat m omat r "remapG" 580 f # i # j # m # r #|"remapG"
545 return r 581 return r
546 582
547remapD = remapG c_remapD 583remapD = remapG c_remapD
@@ -564,7 +600,7 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
564 600
565rowOpAux f c x i1 i2 j1 j2 m = do 601rowOpAux f c x i1 i2 j1 j2 m = do
566 px <- newArray [x] 602 px <- newArray [x]
567 app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp" 603 f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp"
568 free px 604 free px
569 605
570type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok 606type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
@@ -580,7 +616,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
580 616
581-------------------------------------------------------------------------------- 617--------------------------------------------------------------------------------
582 618
583gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg" 619gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg"
584 620
585type Tgemm x = x :> I :> x ::> x ::> x ::> Ok 621type Tgemm x = x :> I :> x ::> x ::> x ::> Ok
586 622
@@ -608,7 +644,7 @@ saveMatrix
608saveMatrix name format m = do 644saveMatrix name format m = do
609 cname <- newCString name 645 cname <- newCString name
610 cformat <- newCString format 646 cformat <- newCString format
611 app1 (c_saveMatrix cname cformat) mat m "saveMatrix" 647 c_saveMatrix cname cformat `applyRaw` m #|"saveMatrix"
612 free cname 648 free cname
613 free cformat 649 free cformat
614 return () 650 return ()