diff options
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 72 |
1 files changed, 54 insertions, 18 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 8f8c219..db0a609 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -3,6 +3,8 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE BangPatterns #-} | 4 | {-# LANGUAGE BangPatterns #-} |
5 | {-# LANGUAGE TypeOperators #-} | 5 | {-# LANGUAGE TypeOperators #-} |
6 | {-# LANGUAGE TypeFamilies #-} | ||
7 | |||
6 | 8 | ||
7 | -- | | 9 | -- | |
8 | -- Module : Internal.Matrix | 10 | -- Module : Internal.Matrix |
@@ -18,7 +20,7 @@ module Internal.Matrix where | |||
18 | 20 | ||
19 | import Internal.Vector | 21 | import Internal.Vector |
20 | import Internal.Devel | 22 | import Internal.Devel |
21 | import Internal.Vectorized | 23 | import Internal.Vectorized hiding ((#)) |
22 | import Foreign.Marshal.Alloc ( free ) | 24 | import Foreign.Marshal.Alloc ( free ) |
23 | import Foreign.Marshal.Array(newArray) | 25 | import Foreign.Marshal.Array(newArray) |
24 | import Foreign.Ptr ( Ptr ) | 26 | import Foreign.Ptr ( Ptr ) |
@@ -79,8 +81,6 @@ data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int | |||
79 | -- RowMajor: preferred by C, fdat may require a transposition | 81 | -- RowMajor: preferred by C, fdat may require a transposition |
80 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition | 82 | -- ColumnMajor: preferred by LAPACK, cdat may require a transposition |
81 | 83 | ||
82 | --cdat = xdat | ||
83 | --fdat = xdat | ||
84 | 84 | ||
85 | rows :: Matrix t -> Int | 85 | rows :: Matrix t -> Int |
86 | rows = irows | 86 | rows = irows |
@@ -129,6 +129,48 @@ omat a f = | |||
129 | g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p | 129 | g (fi (rows a)) (fi (cols a)) (stepRow a) (stepCol a) p |
130 | f m | 130 | f m |
131 | 131 | ||
132 | -------------------------------------------------------------------------------- | ||
133 | |||
134 | {-# INLINE amatr #-} | ||
135 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | ||
136 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) | ||
137 | where | ||
138 | r = fromIntegral (rows x) | ||
139 | c = fromIntegral (cols x) | ||
140 | |||
141 | {-# INLINE amat #-} | ||
142 | amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | ||
143 | amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) | ||
144 | where | ||
145 | r = fromIntegral (rows x) | ||
146 | c = fromIntegral (cols x) | ||
147 | sr = stepRow x | ||
148 | sc = stepCol x | ||
149 | |||
150 | {-# INLINE arrmat #-} | ||
151 | arrmat :: Storable a => (Ptr CInt -> Ptr a -> b) -> Matrix a -> b | ||
152 | arrmat f x = inlinePerformIO (unsafeWith s (\p -> unsafeWith (xdat x) (return . f p))) | ||
153 | where | ||
154 | s = fromList [fi (rows x), fi (cols x), stepRow x, stepCol x] | ||
155 | |||
156 | |||
157 | instance Storable t => TransArray (Matrix t) | ||
158 | where | ||
159 | type Elem (Matrix t) = t | ||
160 | type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b | ||
161 | type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b | ||
162 | apply = amat | ||
163 | {-# INLINE apply #-} | ||
164 | applyRaw = amatr | ||
165 | {-# INLINE applyRaw #-} | ||
166 | applyArray = arrmat | ||
167 | {-# INLINE applyArray #-} | ||
168 | |||
169 | infixl 1 # | ||
170 | a # b = apply a b | ||
171 | {-# INLINE (#) #-} | ||
172 | |||
173 | -------------------------------------------------------------------------------- | ||
132 | 174 | ||
133 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 175 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
134 | 176 | ||
@@ -139,12 +181,6 @@ fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] | |||
139 | flatten :: Element t => Matrix t -> Vector t | 181 | flatten :: Element t => Matrix t -> Vector t |
140 | flatten = xdat . cmat | 182 | flatten = xdat . cmat |
141 | 183 | ||
142 | {- | ||
143 | type Mt t s = Int -> Int -> Ptr t -> s | ||
144 | |||
145 | infixr 6 ::> | ||
146 | type t ::> s = Mt t s | ||
147 | -} | ||
148 | 184 | ||
149 | -- | the inverse of 'Data.Packed.Matrix.fromLists' | 185 | -- | the inverse of 'Data.Packed.Matrix.fromLists' |
150 | toLists :: (Element t) => Matrix t -> [[t]] | 186 | toLists :: (Element t) => Matrix t -> [[t]] |
@@ -445,7 +481,7 @@ extractAux f m moder vr modec vc = do | |||
445 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 481 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
446 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 482 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
447 | r <- createMatrix RowMajor nr nc | 483 | r <- createMatrix RowMajor nr nc |
448 | app4 (f moder modec) vec vr vec vc omat m omat r "extractAux" | 484 | f moder modec # vr # vc # m # r #|"extract" |
449 | return r | 485 | return r |
450 | 486 | ||
451 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) | 487 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) |
@@ -459,7 +495,7 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z | |||
459 | 495 | ||
460 | --------------------------------------------------------------- | 496 | --------------------------------------------------------------- |
461 | 497 | ||
462 | setRectAux f i j m r = app2 (f (fi i) (fi j)) omat m omat r "setRect" | 498 | setRectAux f i j m r = f (fi i) (fi j) # m # r #|"setRect" |
463 | 499 | ||
464 | type SetRect x = I -> I -> x ::> x::> Ok | 500 | type SetRect x = I -> I -> x ::> x::> Ok |
465 | 501 | ||
@@ -474,7 +510,7 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | |||
474 | 510 | ||
475 | sortG f v = unsafePerformIO $ do | 511 | sortG f v = unsafePerformIO $ do |
476 | r <- createVector (dim v) | 512 | r <- createVector (dim v) |
477 | app2 f vec v vec r "sortG" | 513 | f # v # r #|"sortG" |
478 | return r | 514 | return r |
479 | 515 | ||
480 | sortIdxD = sortG c_sort_indexD | 516 | sortIdxD = sortG c_sort_indexD |
@@ -501,7 +537,7 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | |||
501 | 537 | ||
502 | compareG f u v = unsafePerformIO $ do | 538 | compareG f u v = unsafePerformIO $ do |
503 | r <- createVector (dim v) | 539 | r <- createVector (dim v) |
504 | app3 f vec u vec v vec r "compareG" | 540 | f # u # v # r #|"compareG" |
505 | return r | 541 | return r |
506 | 542 | ||
507 | compareD = compareG c_compareD | 543 | compareD = compareG c_compareD |
@@ -518,7 +554,7 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | |||
518 | 554 | ||
519 | selectG f c u v w = unsafePerformIO $ do | 555 | selectG f c u v w = unsafePerformIO $ do |
520 | r <- createVector (dim v) | 556 | r <- createVector (dim v) |
521 | app5 f vec c vec u vec v vec w vec r "selectG" | 557 | f # c # u # v # w # r #|"selectG" |
522 | return r | 558 | return r |
523 | 559 | ||
524 | selectD = selectG c_selectD | 560 | selectD = selectG c_selectD |
@@ -541,7 +577,7 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | |||
541 | 577 | ||
542 | remapG f i j m = unsafePerformIO $ do | 578 | remapG f i j m = unsafePerformIO $ do |
543 | r <- createMatrix RowMajor (rows i) (cols i) | 579 | r <- createMatrix RowMajor (rows i) (cols i) |
544 | app4 f omat i omat j omat m omat r "remapG" | 580 | f # i # j # m # r #|"remapG" |
545 | return r | 581 | return r |
546 | 582 | ||
547 | remapD = remapG c_remapD | 583 | remapD = remapG c_remapD |
@@ -564,7 +600,7 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
564 | 600 | ||
565 | rowOpAux f c x i1 i2 j1 j2 m = do | 601 | rowOpAux f c x i1 i2 j1 j2 m = do |
566 | px <- newArray [x] | 602 | px <- newArray [x] |
567 | app1 (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) omat m "rowOp" | 603 | f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2) # m #|"rowOp" |
568 | free px | 604 | free px |
569 | 605 | ||
570 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | 606 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok |
@@ -580,7 +616,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
580 | 616 | ||
581 | -------------------------------------------------------------------------------- | 617 | -------------------------------------------------------------------------------- |
582 | 618 | ||
583 | gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg" | 619 | gemmg f u v m1 m2 m3 = f # u # v # m1 # m2 # m3 #|"gemmg" |
584 | 620 | ||
585 | type Tgemm x = x :> I :> x ::> x ::> x ::> Ok | 621 | type Tgemm x = x :> I :> x ::> x ::> x ::> Ok |
586 | 622 | ||
@@ -608,7 +644,7 @@ saveMatrix | |||
608 | saveMatrix name format m = do | 644 | saveMatrix name format m = do |
609 | cname <- newCString name | 645 | cname <- newCString name |
610 | cformat <- newCString format | 646 | cformat <- newCString format |
611 | app1 (c_saveMatrix cname cformat) mat m "saveMatrix" | 647 | c_saveMatrix cname cformat `applyRaw` m #|"saveMatrix" |
612 | free cname | 648 | free cname |
613 | free cformat | 649 | free cformat |
614 | return () | 650 | return () |