diff options
author | Joe Crayne <joe@jerkface.net> | 2019-08-08 02:22:30 -0400 |
---|---|---|
committer | Joe Crayne <joe@jerkface.net> | 2019-08-08 22:47:46 -0400 |
commit | badcbdfddc4be31fc79a6df4553795af18069efe (patch) | |
tree | 90c38bd8793b53a5e6f00049eb78acaa8d88d711 /packages/base/src/Internal/Matrix.hs | |
parent | d844a145f2e8808c9f75cd99c673d5f5c8960bf2 (diff) |
Removed the Element class.tower
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 307 |
1 files changed, 123 insertions, 184 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 5436e59..04092f9 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -2,6 +2,7 @@ | |||
2 | {-# LANGUAGE FlexibleContexts #-} | 2 | {-# LANGUAGE FlexibleContexts #-} |
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | {-# LANGUAGE BangPatterns #-} | 4 | {-# LANGUAGE BangPatterns #-} |
5 | {-# LANGUAGE CPP #-} | ||
5 | {-# LANGUAGE TypeOperators #-} | 6 | {-# LANGUAGE TypeOperators #-} |
6 | {-# LANGUAGE TypeFamilies #-} | 7 | {-# LANGUAGE TypeFamilies #-} |
7 | {-# LANGUAGE ViewPatterns #-} | 8 | {-# LANGUAGE ViewPatterns #-} |
@@ -22,12 +23,14 @@ module Internal.Matrix where | |||
22 | 23 | ||
23 | import Internal.Vector | 24 | import Internal.Vector |
24 | import Internal.Devel | 25 | import Internal.Devel |
26 | import Internal.Extract | ||
25 | import Internal.Vectorized hiding ((#), (#!)) | 27 | import Internal.Vectorized hiding ((#), (#!)) |
26 | import Foreign.Marshal.Alloc ( free ) | 28 | import Foreign.Marshal.Alloc ( free ) |
27 | import Foreign.Marshal.Array(newArray) | 29 | import Foreign.Marshal.Array(newArray) |
28 | import Foreign.Ptr ( Ptr ) | 30 | import Foreign.Ptr ( Ptr ) |
29 | import Foreign.Storable ( Storable ) | 31 | import Foreign.Storable ( Storable ) |
30 | import Data.Complex ( Complex ) | 32 | import Data.Complex ( Complex ) |
33 | import Data.Int | ||
31 | import Foreign.C.Types ( CInt(..) ) | 34 | import Foreign.C.Types ( CInt(..) ) |
32 | import Foreign.C.String ( CString, newCString ) | 35 | import Foreign.C.String ( CString, newCString ) |
33 | import System.IO.Unsafe ( unsafePerformIO ) | 36 | import System.IO.Unsafe ( unsafePerformIO ) |
@@ -61,19 +64,23 @@ size :: Matrix t -> (Int, Int) | |||
61 | size m = (irows m, icols m) | 64 | size m = (irows m, icols m) |
62 | {-# INLINE size #-} | 65 | {-# INLINE size #-} |
63 | 66 | ||
67 | -- | True if the matrix is in RowMajor form. | ||
64 | rowOrder :: Matrix t -> Bool | 68 | rowOrder :: Matrix t -> Bool |
65 | rowOrder m = xCol m == 1 || cols m == 1 | 69 | rowOrder m = xCol m == 1 || cols m == 1 |
66 | {-# INLINE rowOrder #-} | 70 | {-# INLINE rowOrder #-} |
67 | 71 | ||
72 | -- | True if the matrix is in ColMajor form or if their is only one row. | ||
68 | colOrder :: Matrix t -> Bool | 73 | colOrder :: Matrix t -> Bool |
69 | colOrder m = xRow m == 1 || rows m == 1 | 74 | colOrder m = xRow m == 1 || rows m == 1 |
70 | {-# INLINE colOrder #-} | 75 | {-# INLINE colOrder #-} |
71 | 76 | ||
77 | -- | True if the matrix is a single row or column vector. | ||
72 | is1d :: Matrix t -> Bool | 78 | is1d :: Matrix t -> Bool |
73 | is1d (size->(r,c)) = r==1 || c==1 | 79 | is1d (size->(r,c)) = r==1 || c==1 |
74 | {-# INLINE is1d #-} | 80 | {-# INLINE is1d #-} |
75 | 81 | ||
76 | -- data is not contiguous | 82 | -- | True if the matrix is not contiguous. This usually |
83 | -- means it is a slice of some larger matrix. | ||
77 | isSlice :: Storable t => Matrix t -> Bool | 84 | isSlice :: Storable t => Matrix t -> Bool |
78 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) | 85 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) |
79 | {-# INLINE isSlice #-} | 86 | {-# INLINE isSlice #-} |
@@ -95,19 +102,23 @@ showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv | |||
95 | 102 | ||
96 | -------------------------------------------------------------------------------- | 103 | -------------------------------------------------------------------------------- |
97 | 104 | ||
98 | -- | Matrix transpose. | 105 | -- | O(1) Matrix transpose. This is only a logical transposition that does not |
106 | -- re-order the element storage. If the storage order is important, use 'cmat' | ||
107 | -- or 'fmat'. | ||
99 | trans :: Matrix t -> Matrix t | 108 | trans :: Matrix t -> Matrix t |
100 | trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = | 109 | trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = |
101 | m { irows = c, icols = r, xRow = xc, xCol = xr } | 110 | m { irows = c, icols = r, xRow = xc, xCol = xr } |
102 | 111 | ||
103 | 112 | ||
104 | cmat :: (Element t) => Matrix t -> Matrix t | 113 | -- | Obtain the RowMajor equivalent of a given Matrix. |
114 | cmat :: (Storable t) => Matrix t -> Matrix t | ||
105 | cmat m | 115 | cmat m |
106 | | rowOrder m = m | 116 | | rowOrder m = m |
107 | | otherwise = extractAll RowMajor m | 117 | | otherwise = extractAll RowMajor m |
108 | 118 | ||
109 | 119 | ||
110 | fmat :: (Element t) => Matrix t -> Matrix t | 120 | -- | Obtain the ColumnMajor equivalent of a given Matrix. |
121 | fmat :: (Storable t) => Matrix t -> Matrix t | ||
111 | fmat m | 122 | fmat m |
112 | | colOrder m = m | 123 | | colOrder m = m |
113 | | otherwise = extractAll ColumnMajor m | 124 | | otherwise = extractAll ColumnMajor m |
@@ -115,14 +126,14 @@ fmat m | |||
115 | 126 | ||
116 | -- C-Haskell matrix adapters | 127 | -- C-Haskell matrix adapters |
117 | {-# INLINE amatr #-} | 128 | {-# INLINE amatr #-} |
118 | amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r | 129 | amatr :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Ptr a -> f) -> IO r |
119 | amatr x f g = unsafeWith (xdat x) (f . g r c) | 130 | amatr x f g = unsafeWith (xdat x) (f . g r c) |
120 | where | 131 | where |
121 | r = fi (rows x) | 132 | r = fi (rows x) |
122 | c = fi (cols x) | 133 | c = fi (cols x) |
123 | 134 | ||
124 | {-# INLINE amat #-} | 135 | {-# INLINE amat #-} |
125 | amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r | 136 | amat :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> f) -> IO r |
126 | amat x f g = unsafeWith (xdat x) (f . g r c sr sc) | 137 | amat x f g = unsafeWith (xdat x) (f . g r c sr sc) |
127 | where | 138 | where |
128 | r = fi (rows x) | 139 | r = fi (rows x) |
@@ -133,8 +144,8 @@ amat x f g = unsafeWith (xdat x) (f . g r c sr sc) | |||
133 | 144 | ||
134 | instance Storable t => TransArray (Matrix t) | 145 | instance Storable t => TransArray (Matrix t) |
135 | where | 146 | where |
136 | type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b | 147 | type TransRaw (Matrix t) b = Int32 -> Int32 -> Ptr t -> b |
137 | type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b | 148 | type Trans (Matrix t) b = Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -> b |
138 | apply = amat | 149 | apply = amat |
139 | {-# INLINE apply #-} | 150 | {-# INLINE apply #-} |
140 | applyRaw = amatr | 151 | applyRaw = amatr |
@@ -151,10 +162,10 @@ a #! b = a # b # id | |||
151 | 162 | ||
152 | -------------------------------------------------------------------------------- | 163 | -------------------------------------------------------------------------------- |
153 | 164 | ||
154 | copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) | 165 | copy :: Storable t => MatrixOrder -> Matrix t -> IO (Matrix t) |
155 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) | 166 | copy ord m = extractAux ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) |
156 | 167 | ||
157 | extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t | 168 | extractAll :: Storable t => MatrixOrder -> Matrix t -> Matrix t |
158 | extractAll ord m = unsafePerformIO (copy ord m) | 169 | extractAll ord m = unsafePerformIO (copy ord m) |
159 | 170 | ||
160 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 171 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
@@ -164,14 +175,14 @@ extractAll ord m = unsafePerformIO (copy ord m) | |||
164 | it :: (Num t, Element t) => Vector t | 175 | it :: (Num t, Element t) => Vector t |
165 | 176 | ||
166 | -} | 177 | -} |
167 | flatten :: Element t => Matrix t -> Vector t | 178 | flatten :: Storable t => Matrix t -> Vector t |
168 | flatten m | 179 | flatten m |
169 | | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) | 180 | | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) |
170 | | otherwise = xdat m | 181 | | otherwise = xdat m |
171 | 182 | ||
172 | 183 | ||
173 | -- | the inverse of 'Data.Packed.Matrix.fromLists' | 184 | -- | the inverse of 'Data.Packed.Matrix.fromLists' |
174 | toLists :: (Element t) => Matrix t -> [[t]] | 185 | toLists :: (Storable t) => Matrix t -> [[t]] |
175 | toLists = map toList . toRows | 186 | toLists = map toList . toRows |
176 | 187 | ||
177 | 188 | ||
@@ -192,7 +203,7 @@ compatdim (a:b:xs) | |||
192 | -- | Create a matrix from a list of vectors. | 203 | -- | Create a matrix from a list of vectors. |
193 | -- All vectors must have the same dimension, | 204 | -- All vectors must have the same dimension, |
194 | -- or dimension 1, which is are automatically expanded. | 205 | -- or dimension 1, which is are automatically expanded. |
195 | fromRows :: Element t => [Vector t] -> Matrix t | 206 | fromRows :: Storable t => [Vector t] -> Matrix t |
196 | fromRows [] = emptyM 0 0 | 207 | fromRows [] = emptyM 0 0 |
197 | fromRows vs = case compatdim (map dim vs) of | 208 | fromRows vs = case compatdim (map dim vs) of |
198 | Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) | 209 | Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) |
@@ -203,25 +214,25 @@ fromRows vs = case compatdim (map dim vs) of | |||
203 | adapt c v | 214 | adapt c v |
204 | | c == 0 = fromList[] | 215 | | c == 0 = fromList[] |
205 | | dim v == c = v | 216 | | dim v == c = v |
206 | | otherwise = constantD (v@>0) c | 217 | | otherwise = constantAux (v@>0) c |
207 | 218 | ||
208 | -- | extracts the rows of a matrix as a list of vectors | 219 | -- | extracts the rows of a matrix as a list of vectors |
209 | toRows :: Element t => Matrix t -> [Vector t] | 220 | toRows :: Storable t => Matrix t -> [Vector t] |
210 | toRows m | 221 | toRows m |
211 | | rowOrder m = map sub rowRange | 222 | | rowOrder m = map sub rowRange |
212 | | otherwise = map ext rowRange | 223 | | otherwise = map ext rowRange |
213 | where | 224 | where |
214 | rowRange = [0..rows m-1] | 225 | rowRange = [0..rows m-1] |
215 | sub k = subVector (k*xRow m) (cols m) (xdat m) | 226 | sub k = subVector (k*xRow m) (cols m) (xdat m) |
216 | ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) | 227 | ext k = xdat $ unsafePerformIO $ extractAux RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) |
217 | 228 | ||
218 | 229 | ||
219 | -- | Creates a matrix from a list of vectors, as columns | 230 | -- | Creates a matrix from a list of vectors, as columns |
220 | fromColumns :: Element t => [Vector t] -> Matrix t | 231 | fromColumns :: Storable t => [Vector t] -> Matrix t |
221 | fromColumns m = trans . fromRows $ m | 232 | fromColumns m = trans . fromRows $ m |
222 | 233 | ||
223 | -- | Creates a list of vectors from the columns of a matrix | 234 | -- | Creates a list of vectors from the columns of a matrix |
224 | toColumns :: Element t => Matrix t -> [Vector t] | 235 | toColumns :: Storable t => Matrix t -> [Vector t] |
225 | toColumns m = toRows . trans $ m | 236 | toColumns m = toRows . trans $ m |
226 | 237 | ||
227 | -- | Reads a matrix position. | 238 | -- | Reads a matrix position. |
@@ -271,13 +282,13 @@ reshape c v = matrixFromVector RowMajor (dim v `div` c) c v | |||
271 | 282 | ||
272 | 283 | ||
273 | -- | application of a vector function on the flattened matrix elements | 284 | -- | application of a vector function on the flattened matrix elements |
274 | liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b | 285 | liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b |
275 | liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} | 286 | liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} |
276 | | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) | 287 | | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) |
277 | | otherwise = matrixFromVector (orderOf m) r c (f d) | 288 | | otherwise = matrixFromVector (orderOf m) r c (f d) |
278 | 289 | ||
279 | -- | application of a vector function on the flattened matrices elements | 290 | -- | application of a vector function on the flattened matrices elements |
280 | liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | 291 | liftMatrix2 :: (Storable t, Storable a, Storable b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t |
281 | liftMatrix2 f m1@(size->(r,c)) m2 | 292 | liftMatrix2 f m1@(size->(r,c)) m2 |
282 | | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" | 293 | | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" |
283 | | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) | 294 | | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) |
@@ -285,103 +296,8 @@ liftMatrix2 f m1@(size->(r,c)) m2 | |||
285 | 296 | ||
286 | ------------------------------------------------------------------ | 297 | ------------------------------------------------------------------ |
287 | 298 | ||
288 | -- | Supported matrix elements. | ||
289 | class (Storable a) => Element a where | ||
290 | constantD :: a -> Int -> Vector a | ||
291 | extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) | ||
292 | setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () | ||
293 | sortI :: Ord a => Vector a -> Vector CInt | ||
294 | sortV :: Ord a => Vector a -> Vector a | ||
295 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt | ||
296 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | ||
297 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | ||
298 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | ||
299 | gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO () | ||
300 | reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation | ||
301 | |||
302 | |||
303 | instance Element Float where | ||
304 | constantD = constantAux cconstantF | ||
305 | extractR = extractAux c_extractF | ||
306 | setRect = setRectAux c_setRectF | ||
307 | sortI = sortIdxF | ||
308 | sortV = sortValF | ||
309 | compareV = compareF | ||
310 | selectV = selectF | ||
311 | remapM = remapF | ||
312 | rowOp = rowOpAux c_rowOpF | ||
313 | gemm = gemmg c_gemmF | ||
314 | reorderV = reorderAux c_reorderF | ||
315 | |||
316 | instance Element Double where | ||
317 | constantD = constantAux cconstantR | ||
318 | extractR = extractAux c_extractD | ||
319 | setRect = setRectAux c_setRectD | ||
320 | sortI = sortIdxD | ||
321 | sortV = sortValD | ||
322 | compareV = compareD | ||
323 | selectV = selectD | ||
324 | remapM = remapD | ||
325 | rowOp = rowOpAux c_rowOpD | ||
326 | gemm = gemmg c_gemmD | ||
327 | reorderV = reorderAux c_reorderD | ||
328 | |||
329 | instance Element (Complex Float) where | ||
330 | constantD = constantAux cconstantQ | ||
331 | extractR = extractAux c_extractQ | ||
332 | setRect = setRectAux c_setRectQ | ||
333 | sortI = undefined | ||
334 | sortV = undefined | ||
335 | compareV = undefined | ||
336 | selectV = selectQ | ||
337 | remapM = remapQ | ||
338 | rowOp = rowOpAux c_rowOpQ | ||
339 | gemm = gemmg c_gemmQ | ||
340 | reorderV = reorderAux c_reorderQ | ||
341 | |||
342 | instance Element (Complex Double) where | ||
343 | constantD = constantAux cconstantC | ||
344 | extractR = extractAux c_extractC | ||
345 | setRect = setRectAux c_setRectC | ||
346 | sortI = undefined | ||
347 | sortV = undefined | ||
348 | compareV = undefined | ||
349 | selectV = selectC | ||
350 | remapM = remapC | ||
351 | rowOp = rowOpAux c_rowOpC | ||
352 | gemm = gemmg c_gemmC | ||
353 | reorderV = reorderAux c_reorderC | ||
354 | |||
355 | instance Element (CInt) where | ||
356 | constantD = constantAux cconstantI | ||
357 | extractR = extractAux c_extractI | ||
358 | setRect = setRectAux c_setRectI | ||
359 | sortI = sortIdxI | ||
360 | sortV = sortValI | ||
361 | compareV = compareI | ||
362 | selectV = selectI | ||
363 | remapM = remapI | ||
364 | rowOp = rowOpAux c_rowOpI | ||
365 | gemm = gemmg c_gemmI | ||
366 | reorderV = reorderAux c_reorderI | ||
367 | |||
368 | instance Element Z where | ||
369 | constantD = constantAux cconstantL | ||
370 | extractR = extractAux c_extractL | ||
371 | setRect = setRectAux c_setRectL | ||
372 | sortI = sortIdxL | ||
373 | sortV = sortValL | ||
374 | compareV = compareL | ||
375 | selectV = selectL | ||
376 | remapM = remapL | ||
377 | rowOp = rowOpAux c_rowOpL | ||
378 | gemm = gemmg c_gemmL | ||
379 | reorderV = reorderAux c_reorderL | ||
380 | |||
381 | ------------------------------------------------------------------- | ||
382 | |||
383 | -- | reference to a rectangular slice of a matrix (no data copy) | 299 | -- | reference to a rectangular slice of a matrix (no data copy) |
384 | subMatrix :: Element a | 300 | subMatrix :: Storable a |
385 | => (Int,Int) -- ^ (r0,c0) starting position | 301 | => (Int,Int) -- ^ (r0,c0) starting position |
386 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix | 302 | -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix |
387 | -> Matrix a -- ^ input matrix | 303 | -> Matrix a -- ^ input matrix |
@@ -402,34 +318,34 @@ subMatrix (r0,c0) (rt,ct) m | |||
402 | maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 | 318 | maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 |
403 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | 319 | maxZ xs = if minimum xs == 0 then 0 else maximum xs |
404 | 320 | ||
405 | conformMs :: Element t => [Matrix t] -> [Matrix t] | 321 | conformMs :: Storable t => [Matrix t] -> [Matrix t] |
406 | conformMs ms = map (conformMTo (r,c)) ms | 322 | conformMs ms = map (conformMTo (r,c)) ms |
407 | where | 323 | where |
408 | r = maxZ (map rows ms) | 324 | r = maxZ (map rows ms) |
409 | c = maxZ (map cols ms) | 325 | c = maxZ (map cols ms) |
410 | 326 | ||
411 | conformVs :: Element t => [Vector t] -> [Vector t] | 327 | conformVs :: Storable t => [Vector t] -> [Vector t] |
412 | conformVs vs = map (conformVTo n) vs | 328 | conformVs vs = map (conformVTo n) vs |
413 | where | 329 | where |
414 | n = maxZ (map dim vs) | 330 | n = maxZ (map dim vs) |
415 | 331 | ||
416 | conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t | 332 | conformMTo :: Storable t => (Int, Int) -> Matrix t -> Matrix t |
417 | conformMTo (r,c) m | 333 | conformMTo (r,c) m |
418 | | size m == (r,c) = m | 334 | | size m == (r,c) = m |
419 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) | 335 | | size m == (1,1) = matrixFromVector RowMajor r c (constantAux (m@@>(0,0)) (r*c)) |
420 | | size m == (r,1) = repCols c m | 336 | | size m == (r,1) = repCols c m |
421 | | size m == (1,c) = repRows r m | 337 | | size m == (1,c) = repRows r m |
422 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) | 338 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) |
423 | 339 | ||
424 | conformVTo :: Element t => Int -> Vector t -> Vector t | 340 | conformVTo :: Storable t => Int -> Vector t -> Vector t |
425 | conformVTo n v | 341 | conformVTo n v |
426 | | dim v == n = v | 342 | | dim v == n = v |
427 | | dim v == 1 = constantD (v@>0) n | 343 | | dim v == 1 = constantAux (v@>0) n |
428 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | 344 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n |
429 | 345 | ||
430 | repRows :: Element t => Int -> Matrix t -> Matrix t | 346 | repRows :: Storable t => Int -> Matrix t -> Matrix t |
431 | repRows n x = fromRows (replicate n (flatten x)) | 347 | repRows n x = fromRows (replicate n (flatten x)) |
432 | repCols :: Element t => Int -> Matrix t -> Matrix t | 348 | repCols :: Storable t => Int -> Matrix t -> Matrix t |
433 | repCols n x = fromColumns (replicate n (flatten x)) | 349 | repCols n x = fromColumns (replicate n (flatten x)) |
434 | 350 | ||
435 | shSize :: Matrix t -> [Char] | 351 | shSize :: Matrix t -> [Char] |
@@ -453,32 +369,50 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
453 | 369 | ||
454 | --------------------------------------------------------------- | 370 | --------------------------------------------------------------- |
455 | 371 | ||
372 | {- | ||
456 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, | 373 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, |
457 | Storable t, Num t3, Num t2, Integral t1, Integral t) | 374 | Storable t, Num t3, Num t2, Integral t1, Integral t) |
458 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t | 375 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) -- f |
459 | -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) | 376 | -> MatrixOrder -- ord |
460 | -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) | 377 | -> c -- m |
461 | extractAux f ord m moder vr modec vc = do | 378 | -> t3 -- moder |
379 | -> Vector t1 -- vr | ||
380 | -> t2 -- modec | ||
381 | -> Vector t -- vc | ||
382 | -> IO (Matrix a) | ||
383 | -} | ||
384 | |||
385 | extractAux :: Storable a => | ||
386 | MatrixOrder | ||
387 | -> Matrix a | ||
388 | -> Int32 | ||
389 | -> Vector Int32 | ||
390 | -> Int32 | ||
391 | -> Vector Int32 | ||
392 | -> IO (Matrix a) | ||
393 | extractAux ord m moder vr modec vc = do | ||
462 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 394 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
463 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 395 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
464 | r <- createMatrix ord nr nc | 396 | r <- createMatrix ord nr nc |
465 | (vr # vc # m #! r) (f moder modec) #|"extract" | 397 | (vr # vc # m #! r) (extractStorable moder modec) #|"extract" |
466 | 398 | ||
467 | return r | 399 | return r |
468 | 400 | ||
469 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) | 401 | {- |
402 | type Extr x = Int32 -> Int32 -> CIdxs (CIdxs (OM x (OM x (IO Int32)))) | ||
470 | 403 | ||
471 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double | 404 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double |
472 | foreign import ccall unsafe "extractF" c_extractF :: Extr Float | 405 | foreign import ccall unsafe "extractF" c_extractF :: Extr Float |
473 | foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) | 406 | foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) |
474 | foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) | 407 | foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) |
475 | foreign import ccall unsafe "extractI" c_extractI :: Extr CInt | 408 | foreign import ccall unsafe "extractI" c_extractI :: Extr Int32 |
476 | foreign import ccall unsafe "extractL" c_extractL :: Extr Z | 409 | foreign import ccall unsafe "extractL" c_extractL :: Extr Z |
410 | -} | ||
477 | 411 | ||
478 | --------------------------------------------------------------- | 412 | --------------------------------------------------------------- |
479 | 413 | ||
480 | setRectAux :: (TransArray c1, TransArray c) | 414 | setRectAux :: (TransArray c1, TransArray c) |
481 | => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) | 415 | => (Int32 -> Int32 -> Trans c1 (Trans c (IO Int32))) |
482 | -> Int -> Int -> c1 -> c -> IO () | 416 | -> Int -> Int -> c1 -> c -> IO () |
483 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" | 417 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" |
484 | 418 | ||
@@ -494,17 +428,17 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | |||
494 | -------------------------------------------------------------------------------- | 428 | -------------------------------------------------------------------------------- |
495 | 429 | ||
496 | sortG :: (Storable t, Storable a) | 430 | sortG :: (Storable t, Storable a) |
497 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | 431 | => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a |
498 | sortG f v = unsafePerformIO $ do | 432 | sortG f v = unsafePerformIO $ do |
499 | r <- createVector (dim v) | 433 | r <- createVector (dim v) |
500 | (v #! r) f #|"sortG" | 434 | (v #! r) f #|"sortG" |
501 | return r | 435 | return r |
502 | 436 | ||
503 | sortIdxD :: Vector Double -> Vector CInt | 437 | sortIdxD :: Vector Double -> Vector Int32 |
504 | sortIdxD = sortG c_sort_indexD | 438 | sortIdxD = sortG c_sort_indexD |
505 | sortIdxF :: Vector Float -> Vector CInt | 439 | sortIdxF :: Vector Float -> Vector Int32 |
506 | sortIdxF = sortG c_sort_indexF | 440 | sortIdxF = sortG c_sort_indexF |
507 | sortIdxI :: Vector CInt -> Vector CInt | 441 | sortIdxI :: Vector Int32 -> Vector Int32 |
508 | sortIdxI = sortG c_sort_indexI | 442 | sortIdxI = sortG c_sort_indexI |
509 | sortIdxL :: Vector Z -> Vector I | 443 | sortIdxL :: Vector Z -> Vector I |
510 | sortIdxL = sortG c_sort_indexL | 444 | sortIdxL = sortG c_sort_indexL |
@@ -513,81 +447,81 @@ sortValD :: Vector Double -> Vector Double | |||
513 | sortValD = sortG c_sort_valD | 447 | sortValD = sortG c_sort_valD |
514 | sortValF :: Vector Float -> Vector Float | 448 | sortValF :: Vector Float -> Vector Float |
515 | sortValF = sortG c_sort_valF | 449 | sortValF = sortG c_sort_valF |
516 | sortValI :: Vector CInt -> Vector CInt | 450 | sortValI :: Vector Int32 -> Vector Int32 |
517 | sortValI = sortG c_sort_valI | 451 | sortValI = sortG c_sort_valI |
518 | sortValL :: Vector Z -> Vector Z | 452 | sortValL :: Vector Z -> Vector Z |
519 | sortValL = sortG c_sort_valL | 453 | sortValL = sortG c_sort_valL |
520 | 454 | ||
521 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) | 455 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV Int32 (IO Int32)) |
522 | foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) | 456 | foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV Int32 (IO Int32)) |
523 | foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) | 457 | foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV Int32 (CV Int32 (IO Int32)) |
524 | foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok | 458 | foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok |
525 | 459 | ||
526 | foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) | 460 | foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO Int32)) |
527 | foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) | 461 | foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO Int32)) |
528 | foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) | 462 | foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV Int32 (CV Int32 (IO Int32)) |
529 | foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | 463 | foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok |
530 | 464 | ||
531 | -------------------------------------------------------------------------------- | 465 | -------------------------------------------------------------------------------- |
532 | 466 | ||
533 | compareG :: (TransArray c, Storable t, Storable a) | 467 | compareG :: (TransArray c, Storable t, Storable a) |
534 | => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) | 468 | => Trans c (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) |
535 | -> c -> Vector t -> Vector a | 469 | -> c -> Vector t -> Vector a |
536 | compareG f u v = unsafePerformIO $ do | 470 | compareG f u v = unsafePerformIO $ do |
537 | r <- createVector (dim v) | 471 | r <- createVector (dim v) |
538 | (u # v #! r) f #|"compareG" | 472 | (u # v #! r) f #|"compareG" |
539 | return r | 473 | return r |
540 | 474 | ||
541 | compareD :: Vector Double -> Vector Double -> Vector CInt | 475 | compareD :: Vector Double -> Vector Double -> Vector Int32 |
542 | compareD = compareG c_compareD | 476 | compareD = compareG c_compareD |
543 | compareF :: Vector Float -> Vector Float -> Vector CInt | 477 | compareF :: Vector Float -> Vector Float -> Vector Int32 |
544 | compareF = compareG c_compareF | 478 | compareF = compareG c_compareF |
545 | compareI :: Vector CInt -> Vector CInt -> Vector CInt | 479 | compareI :: Vector Int32 -> Vector Int32 -> Vector Int32 |
546 | compareI = compareG c_compareI | 480 | compareI = compareG c_compareI |
547 | compareL :: Vector Z -> Vector Z -> Vector CInt | 481 | compareL :: Vector Z -> Vector Z -> Vector Int32 |
548 | compareL = compareG c_compareL | 482 | compareL = compareG c_compareL |
549 | 483 | ||
550 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | 484 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV Int32 (IO Int32))) |
551 | foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) | 485 | foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV Int32 (IO Int32))) |
552 | foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) | 486 | foreign import ccall unsafe "compareI" c_compareI :: CV Int32 (CV Int32 (CV Int32 (IO Int32))) |
553 | foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | 487 | foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok |
554 | 488 | ||
555 | -------------------------------------------------------------------------------- | 489 | -------------------------------------------------------------------------------- |
556 | 490 | ||
557 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) | 491 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) |
558 | => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) | 492 | => Trans c2 (Trans c1 (Int32 -> Ptr t -> Trans c (Int32 -> Ptr a -> IO Int32))) |
559 | -> c2 -> c1 -> Vector t -> c -> Vector a | 493 | -> c2 -> c1 -> Vector t -> c -> Vector a |
560 | selectG f c u v w = unsafePerformIO $ do | 494 | selectG f c u v w = unsafePerformIO $ do |
561 | r <- createVector (dim v) | 495 | r <- createVector (dim v) |
562 | (c # u # v # w #! r) f #|"selectG" | 496 | (c # u # v # w #! r) f #|"selectG" |
563 | return r | 497 | return r |
564 | 498 | ||
565 | selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double | 499 | selectD :: Vector Int32 -> Vector Double -> Vector Double -> Vector Double -> Vector Double |
566 | selectD = selectG c_selectD | 500 | selectD = selectG c_selectD |
567 | selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float | 501 | selectF :: Vector Int32 -> Vector Float -> Vector Float -> Vector Float -> Vector Float |
568 | selectF = selectG c_selectF | 502 | selectF = selectG c_selectF |
569 | selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt | 503 | selectI :: Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32 |
570 | selectI = selectG c_selectI | 504 | selectI = selectG c_selectI |
571 | selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z | 505 | selectL :: Vector Int32 -> Vector Z -> Vector Z -> Vector Z -> Vector Z |
572 | selectL = selectG c_selectL | 506 | selectL = selectG c_selectL |
573 | selectC :: Vector CInt | 507 | selectC :: Vector Int32 |
574 | -> Vector (Complex Double) | 508 | -> Vector (Complex Double) |
575 | -> Vector (Complex Double) | 509 | -> Vector (Complex Double) |
576 | -> Vector (Complex Double) | 510 | -> Vector (Complex Double) |
577 | -> Vector (Complex Double) | 511 | -> Vector (Complex Double) |
578 | selectC = selectG c_selectC | 512 | selectC = selectG c_selectC |
579 | selectQ :: Vector CInt | 513 | selectQ :: Vector Int32 |
580 | -> Vector (Complex Float) | 514 | -> Vector (Complex Float) |
581 | -> Vector (Complex Float) | 515 | -> Vector (Complex Float) |
582 | -> Vector (Complex Float) | 516 | -> Vector (Complex Float) |
583 | -> Vector (Complex Float) | 517 | -> Vector (Complex Float) |
584 | selectQ = selectG c_selectQ | 518 | selectQ = selectG c_selectQ |
585 | 519 | ||
586 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | 520 | type Sel x = CV Int32 (CV x (CV x (CV x (CV x (IO Int32))))) |
587 | 521 | ||
588 | foreign import ccall unsafe "chooseD" c_selectD :: Sel Double | 522 | foreign import ccall unsafe "chooseD" c_selectD :: Sel Double |
589 | foreign import ccall unsafe "chooseF" c_selectF :: Sel Float | 523 | foreign import ccall unsafe "chooseF" c_selectF :: Sel Float |
590 | foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt | 524 | foreign import ccall unsafe "chooseI" c_selectI :: Sel Int32 |
591 | foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) | 525 | foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) |
592 | foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) | 526 | foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) |
593 | foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | 527 | foreign import ccall unsafe "chooseL" c_selectL :: Sel Z |
@@ -595,35 +529,35 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | |||
595 | --------------------------------------------------------------------------- | 529 | --------------------------------------------------------------------------- |
596 | 530 | ||
597 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) | 531 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) |
598 | => (CInt -> CInt -> CInt -> CInt -> Ptr t | 532 | => (Int32 -> Int32 -> Int32 -> Int32 -> Ptr t |
599 | -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) | 533 | -> Trans c1 (Trans c (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> IO Int32))) |
600 | -> Matrix t -> c1 -> c -> Matrix a | 534 | -> Matrix t -> c1 -> c -> Matrix a |
601 | remapG f i j m = unsafePerformIO $ do | 535 | remapG f i j m = unsafePerformIO $ do |
602 | r <- createMatrix RowMajor (rows i) (cols i) | 536 | r <- createMatrix RowMajor (rows i) (cols i) |
603 | (i # j # m #! r) f #|"remapG" | 537 | (i # j # m #! r) f #|"remapG" |
604 | return r | 538 | return r |
605 | 539 | ||
606 | remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double | 540 | remapD :: Matrix Int32 -> Matrix Int32 -> Matrix Double -> Matrix Double |
607 | remapD = remapG c_remapD | 541 | remapD = remapG c_remapD |
608 | remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float | 542 | remapF :: Matrix Int32 -> Matrix Int32 -> Matrix Float -> Matrix Float |
609 | remapF = remapG c_remapF | 543 | remapF = remapG c_remapF |
610 | remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt | 544 | remapI :: Matrix Int32 -> Matrix Int32 -> Matrix Int32 -> Matrix Int32 |
611 | remapI = remapG c_remapI | 545 | remapI = remapG c_remapI |
612 | remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z | 546 | remapL :: Matrix Int32 -> Matrix Int32 -> Matrix Z -> Matrix Z |
613 | remapL = remapG c_remapL | 547 | remapL = remapG c_remapL |
614 | remapC :: Matrix CInt | 548 | remapC :: Matrix Int32 |
615 | -> Matrix CInt | 549 | -> Matrix Int32 |
616 | -> Matrix (Complex Double) | 550 | -> Matrix (Complex Double) |
617 | -> Matrix (Complex Double) | 551 | -> Matrix (Complex Double) |
618 | remapC = remapG c_remapC | 552 | remapC = remapG c_remapC |
619 | remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) | 553 | remapQ :: Matrix Int32 -> Matrix Int32 -> Matrix (Complex Float) -> Matrix (Complex Float) |
620 | remapQ = remapG c_remapQ | 554 | remapQ = remapG c_remapQ |
621 | 555 | ||
622 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) | 556 | type Rem x = OM Int32 (OM Int32 (OM x (OM x (IO Int32)))) |
623 | 557 | ||
624 | foreign import ccall unsafe "remapD" c_remapD :: Rem Double | 558 | foreign import ccall unsafe "remapD" c_remapD :: Rem Double |
625 | foreign import ccall unsafe "remapF" c_remapF :: Rem Float | 559 | foreign import ccall unsafe "remapF" c_remapF :: Rem Float |
626 | foreign import ccall unsafe "remapI" c_remapI :: Rem CInt | 560 | foreign import ccall unsafe "remapI" c_remapI :: Rem Int32 |
627 | foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) | 561 | foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) |
628 | foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) | 562 | foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) |
629 | foreign import ccall unsafe "remapL" c_remapL :: Rem Z | 563 | foreign import ccall unsafe "remapL" c_remapL :: Rem Z |
@@ -631,14 +565,14 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
631 | -------------------------------------------------------------------------------- | 565 | -------------------------------------------------------------------------------- |
632 | 566 | ||
633 | rowOpAux :: (TransArray c, Storable a) => | 567 | rowOpAux :: (TransArray c, Storable a) => |
634 | (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) | 568 | (Int32 -> Ptr a -> Int32 -> Int32 -> Int32 -> Int32 -> Trans c (IO Int32)) |
635 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () | 569 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () |
636 | rowOpAux f c x i1 i2 j1 j2 m = do | 570 | rowOpAux f c x i1 i2 j1 j2 m = do |
637 | px <- newArray [x] | 571 | px <- newArray [x] |
638 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" | 572 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" |
639 | free px | 573 | free px |
640 | 574 | ||
641 | type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok | 575 | type RowOp x = Int32 -> Ptr x -> Int32 -> Int32 -> Int32 -> Int32 -> x ::> Ok |
642 | 576 | ||
643 | foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R | 577 | foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R |
644 | foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float | 578 | foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float |
@@ -652,7 +586,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
652 | -------------------------------------------------------------------------------- | 586 | -------------------------------------------------------------------------------- |
653 | 587 | ||
654 | gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) | 588 | gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) |
655 | => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) | 589 | => Trans c3 (Trans c2 (Trans c1 (Trans c (IO Int32)))) |
656 | -> c3 -> c2 -> c1 -> c -> IO () | 590 | -> c3 -> c2 -> c1 -> c -> IO () |
657 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" | 591 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" |
658 | 592 | ||
@@ -669,21 +603,26 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | |||
669 | 603 | ||
670 | -------------------------------------------------------------------------------- | 604 | -------------------------------------------------------------------------------- |
671 | 605 | ||
606 | {- | ||
672 | reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => | 607 | reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => |
673 | (CInt -> Ptr a -> CInt -> Ptr t1 | 608 | (Int32 -> Ptr a -> Int32 -> Ptr t1 |
674 | -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) | 609 | -> Trans c (Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32)) |
675 | -> Vector t1 -> c -> Vector t -> Vector a1 | 610 | -> Vector t1 -> c -> Vector t -> Vector a1 |
611 | -} | ||
612 | reorderAux :: (TransArray c, Storable a, | ||
613 | Trans c (Int32 -> Ptr a -> Int32 -> Ptr a -> IO Int32) ~ (Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr a -> Int32 -> Ptr a -> IO Int32)) => | ||
614 | p -> Vector Int32 -> c -> Vector a -> Vector a | ||
676 | reorderAux f s d v = unsafePerformIO $ do | 615 | reorderAux f s d v = unsafePerformIO $ do |
677 | k <- createVector (dim s) | 616 | k <- createVector (dim s) |
678 | r <- createVector (dim v) | 617 | r <- createVector (dim v) |
679 | (k # s # d # v #! r) f #| "reorderV" | 618 | (k # s # d # v #! r) reorderStorable #| "reorderV" |
680 | return r | 619 | return r |
681 | 620 | ||
682 | type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) | 621 | type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32))))) |
683 | 622 | ||
684 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double | 623 | foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double |
685 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float | 624 | foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float |
686 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt | 625 | foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32 |
687 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) | 626 | foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) |
688 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) | 627 | foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) |
689 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | 628 | foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z |
@@ -691,12 +630,12 @@ foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z | |||
691 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, | 630 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, |
692 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ | 631 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ |
693 | -- This function is intended to be used internally by tensor libraries. | 632 | -- This function is intended to be used internally by tensor libraries. |
694 | reorderVector :: Element a | 633 | reorderVector :: Storable a |
695 | => Vector CInt -- ^ @strides@: array strides | 634 | => Vector Int32 -- ^ @strides@: array strides |
696 | -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ | 635 | -> Vector Int32 -- ^ @dims@: array dimensions of new array @v@ |
697 | -> Vector a -- ^ @v@: flattened input array | 636 | -> Vector a -- ^ @v@: flattened input array |
698 | -> Vector a -- ^ @v'@: flattened output array | 637 | -> Vector a -- ^ @v'@: flattened output array |
699 | reorderVector = reorderV | 638 | reorderVector = reorderAux () |
700 | 639 | ||
701 | -------------------------------------------------------------------------------- | 640 | -------------------------------------------------------------------------------- |
702 | 641 | ||