diff options
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 12ef05a..5163421 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -108,6 +108,14 @@ fmat m | |||
108 | | otherwise = extractAll ColumnMajor m | 108 | | otherwise = extractAll ColumnMajor m |
109 | 109 | ||
110 | 110 | ||
111 | -- C-Haskell matrix adapters | ||
112 | {-# INLINE amatr #-} | ||
113 | amatr :: Storable a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | ||
114 | amatr f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c)) | ||
115 | where | ||
116 | r = fi (rows x) | ||
117 | c = fi (cols x) | ||
118 | |||
111 | {-# INLINE amat #-} | 119 | {-# INLINE amat #-} |
112 | amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b | 120 | amat :: Storable a => (CInt -> CInt -> CInt -> CInt -> Ptr a -> b) -> Matrix a -> b |
113 | amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) | 121 | amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) |
@@ -117,11 +125,16 @@ amat f x = inlinePerformIO (unsafeWith (xdat x) (return . f r c sr sc)) | |||
117 | sr = fi (xRow x) | 125 | sr = fi (xRow x) |
118 | sc = fi (xCol x) | 126 | sc = fi (xCol x) |
119 | 127 | ||
128 | |||
120 | instance Storable t => TransArray (Matrix t) | 129 | instance Storable t => TransArray (Matrix t) |
121 | where | 130 | where |
131 | type Elem (Matrix t) = t | ||
132 | type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b | ||
122 | type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b | 133 | type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b |
123 | apply = amat | 134 | apply = amat |
124 | {-# INLINE apply #-} | 135 | {-# INLINE apply #-} |
136 | applyRaw = amatr | ||
137 | {-# INLINE applyRaw #-} | ||
125 | 138 | ||
126 | infixl 1 # | 139 | infixl 1 # |
127 | a # b = apply a b | 140 | a # b = apply a b |
@@ -564,7 +577,7 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | |||
564 | -------------------------------------------------------------------------------- | 577 | -------------------------------------------------------------------------------- |
565 | 578 | ||
566 | foreign import ccall unsafe "saveMatrix" c_saveMatrix | 579 | foreign import ccall unsafe "saveMatrix" c_saveMatrix |
567 | :: CString -> CString -> Double ::> Ok | 580 | :: CString -> CString -> Double ..> Ok |
568 | 581 | ||
569 | {- | save a matrix as a 2D ASCII table | 582 | {- | save a matrix as a 2D ASCII table |
570 | -} | 583 | -} |
@@ -576,7 +589,7 @@ saveMatrix | |||
576 | saveMatrix name format m = do | 589 | saveMatrix name format m = do |
577 | cname <- newCString name | 590 | cname <- newCString name |
578 | cformat <- newCString format | 591 | cformat <- newCString format |
579 | c_saveMatrix cname cformat `apply` m #|"saveMatrix" | 592 | c_saveMatrix cname cformat `applyRaw` m #|"saveMatrix" |
580 | free cname | 593 | free cname |
581 | free cformat | 594 | free cformat |
582 | return () | 595 | return () |