summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
authorJoe Crayne <joe@jerkface.net>2019-08-08 02:22:30 -0400
committerJoe Crayne <joe@jerkface.net>2019-08-08 22:47:46 -0400
commitbadcbdfddc4be31fc79a6df4553795af18069efe (patch)
tree90c38bd8793b53a5e6f00049eb78acaa8d88d711 /packages/base/src/Internal/Matrix.hs
parentd844a145f2e8808c9f75cd99c673d5f5c8960bf2 (diff)
Removed the Element class.tower
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs307
1 files changed, 123 insertions, 184 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 5436e59..04092f9 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -2,6 +2,7 @@
2{-# LANGUAGE FlexibleContexts #-} 2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-} 4{-# LANGUAGE BangPatterns #-}
5{-# LANGUAGE CPP #-}
5{-# LANGUAGE TypeOperators #-} 6{-# LANGUAGE TypeOperators #-}
6{-# LANGUAGE TypeFamilies #-} 7{-# LANGUAGE TypeFamilies #-}
7{-# LANGUAGE ViewPatterns #-} 8{-# LANGUAGE ViewPatterns #-}
@@ -22,12 +23,14 @@ module Internal.Matrix where
22 23
23import Internal.Vector 24import Internal.Vector
24import Internal.Devel 25import Internal.Devel
26import Internal.Extract
25import Internal.Vectorized hiding ((#), (#!)) 27import Internal.Vectorized hiding ((#), (#!))
26import Foreign.Marshal.Alloc ( free ) 28import Foreign.Marshal.Alloc ( free )
27import Foreign.Marshal.Array(newArray) 29import Foreign.Marshal.Array(newArray)
28import Foreign.Ptr ( Ptr ) 30import Foreign.Ptr ( Ptr )
29import Foreign.Storable ( Storable ) 31import Foreign.Storable ( Storable )
30import Data.Complex ( Complex ) 32import Data.Complex ( Complex )
33import Data.Int
31import Foreign.C.Types ( CInt(..) ) 34import Foreign.C.Types ( CInt(..) )
32import Foreign.C.String ( CString, newCString ) 35import Foreign.C.String ( CString, newCString )
33import System.IO.Unsafe ( unsafePerformIO ) 36import System.IO.Unsafe ( unsafePerformIO )
@@ -61,19 +64,23 @@ size :: Matrix t -> (Int, Int)
61size m = (irows m, icols m) 64size m = (irows m, icols m)
62{-# INLINE size #-} 65{-# INLINE size #-}
63 66
67-- | True if the matrix is in RowMajor form.
64rowOrder :: Matrix t -> Bool 68rowOrder :: Matrix t -> Bool
65rowOrder m = xCol m == 1 || cols m == 1 69rowOrder m = xCol m == 1 || cols m == 1
66{-# INLINE rowOrder #-} 70{-# INLINE rowOrder #-}
67 71
72-- | True if the matrix is in ColMajor form or if their is only one row.
68colOrder :: Matrix t -> Bool 73colOrder :: Matrix t -> Bool
69colOrder m = xRow m == 1 || rows m == 1 74colOrder m = xRow m == 1 || rows m == 1
70{-# INLINE colOrder #-} 75{-# INLINE colOrder #-}
71 76
77-- | True if the matrix is a single row or column vector.
72is1d :: Matrix t -> Bool 78is1d :: Matrix t -> Bool
73is1d (size->(r,c)) = r==1 || c==1 79is1d (size->(r,c)) = r==1 || c==1
74{-# INLINE is1d #-} 80{-# INLINE is1d #-}
75 81
76-- data is not contiguous 82-- | True if the matrix is not contiguous. This usually
83-- means it is a slice of some larger matrix.
77isSlice :: Storable t => Matrix t -> Bool 84isSlice :: Storable t => Matrix t -> Bool
78isSlice m@(size->(r,c)) = r*c < dim (xdat m) 85isSlice m@(size->(r,c)) = r*c < dim (xdat m)
79{-# INLINE isSlice #-} 86{-# INLINE isSlice #-}
@@ -95,19 +102,23 @@ showInternal m = printf "%dx%d %s %s %d:%d (%d)\n" r c slc ord xr xc dv
95 102
96-------------------------------------------------------------------------------- 103--------------------------------------------------------------------------------
97 104
98-- | Matrix transpose. 105-- | O(1) Matrix transpose. This is only a logical transposition that does not
106-- re-order the element storage. If the storage order is important, use 'cmat'
107-- or 'fmat'.
99trans :: Matrix t -> Matrix t 108trans :: Matrix t -> Matrix t
100trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = 109trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } =
101 m { irows = c, icols = r, xRow = xc, xCol = xr } 110 m { irows = c, icols = r, xRow = xc, xCol = xr }
102 111
103 112
104cmat :: (Element t) => Matrix t -> Matrix t 113-- | Obtain the RowMajor equivalent of a given Matrix.
114cmat :: (Storable t) => Matrix t -> Matrix t
105cmat m 115cmat m
106 | rowOrder m = m 116 | rowOrder m = m
107 | otherwise = extractAll RowMajor m 117 | otherwise = extractAll RowMajor m
108 118
109 119
110fmat :: (Element t) => Matrix t -> Matrix t 120-- | Obtain the ColumnMajor equivalent of a given Matrix.
121fmat :: (Storable t) => Matrix t -> Matrix t
111fmat m 122fmat m
112 | colOrder m = m 123 | colOrder m = m
113 | otherwise = extractAll ColumnMajor m 124 | otherwise = extractAll ColumnMajor m
@@ -115,14 +126,14 @@ fmat m
115 126
116-- C-Haskell matrix adapters 127-- C-Haskell matrix adapters
117{-# INLINE amatr #-} 128{-# INLINE amatr #-}
118amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r 129amatr :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Ptr a -> f) -> IO r
119amatr x f g = unsafeWith (xdat x) (f . g r c) 130amatr x f g = unsafeWith (xdat x) (f . g r c)
120 where 131 where
121 r = fi (rows x) 132 r = fi (rows x)
122 c = fi (cols x) 133 c = fi (cols x)
123 134
124{-# INLINE amat #-} 135{-# INLINE amat #-}
125amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r 136amat :: Storable a => Matrix a -> (f -> IO r) -> (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> f) -> IO r
126amat x f g = unsafeWith (xdat x) (f . g r c sr sc) 137amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
127 where 138 where
128 r = fi (rows x) 139 r = fi (rows x)
@@ -133,8 +144,8 @@ amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
133 144
134instance Storable t => TransArray (Matrix t) 145instance Storable t => TransArray (Matrix t)
135 where 146 where
136 type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b 147 type TransRaw (Matrix t) b = Int32 -> Int32 -> Ptr t -> b
137 type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b 148 type Trans (Matrix t) b = Int32 -> Int32 -> Int32 -> Int32 -> Ptr t -> b
138 apply = amat 149 apply = amat
139 {-# INLINE apply #-} 150 {-# INLINE apply #-}
140 applyRaw = amatr 151 applyRaw = amatr
@@ -151,10 +162,10 @@ a #! b = a # b # id
151 162
152-------------------------------------------------------------------------------- 163--------------------------------------------------------------------------------
153 164
154copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) 165copy :: Storable t => MatrixOrder -> Matrix t -> IO (Matrix t)
155copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) 166copy ord m = extractAux ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1])
156 167
157extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t 168extractAll :: Storable t => MatrixOrder -> Matrix t -> Matrix t
158extractAll ord m = unsafePerformIO (copy ord m) 169extractAll ord m = unsafePerformIO (copy ord m)
159 170
160{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 171{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
@@ -164,14 +175,14 @@ extractAll ord m = unsafePerformIO (copy ord m)
164it :: (Num t, Element t) => Vector t 175it :: (Num t, Element t) => Vector t
165 176
166-} 177-}
167flatten :: Element t => Matrix t -> Vector t 178flatten :: Storable t => Matrix t -> Vector t
168flatten m 179flatten m
169 | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m) 180 | isSlice m || not (rowOrder m) = xdat (extractAll RowMajor m)
170 | otherwise = xdat m 181 | otherwise = xdat m
171 182
172 183
173-- | the inverse of 'Data.Packed.Matrix.fromLists' 184-- | the inverse of 'Data.Packed.Matrix.fromLists'
174toLists :: (Element t) => Matrix t -> [[t]] 185toLists :: (Storable t) => Matrix t -> [[t]]
175toLists = map toList . toRows 186toLists = map toList . toRows
176 187
177 188
@@ -192,7 +203,7 @@ compatdim (a:b:xs)
192-- | Create a matrix from a list of vectors. 203-- | Create a matrix from a list of vectors.
193-- All vectors must have the same dimension, 204-- All vectors must have the same dimension,
194-- or dimension 1, which is are automatically expanded. 205-- or dimension 1, which is are automatically expanded.
195fromRows :: Element t => [Vector t] -> Matrix t 206fromRows :: Storable t => [Vector t] -> Matrix t
196fromRows [] = emptyM 0 0 207fromRows [] = emptyM 0 0
197fromRows vs = case compatdim (map dim vs) of 208fromRows vs = case compatdim (map dim vs) of
198 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs) 209 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
@@ -203,25 +214,25 @@ fromRows vs = case compatdim (map dim vs) of
203 adapt c v 214 adapt c v
204 | c == 0 = fromList[] 215 | c == 0 = fromList[]
205 | dim v == c = v 216 | dim v == c = v
206 | otherwise = constantD (v@>0) c 217 | otherwise = constantAux (v@>0) c
207 218
208-- | extracts the rows of a matrix as a list of vectors 219-- | extracts the rows of a matrix as a list of vectors
209toRows :: Element t => Matrix t -> [Vector t] 220toRows :: Storable t => Matrix t -> [Vector t]
210toRows m 221toRows m
211 | rowOrder m = map sub rowRange 222 | rowOrder m = map sub rowRange
212 | otherwise = map ext rowRange 223 | otherwise = map ext rowRange
213 where 224 where
214 rowRange = [0..rows m-1] 225 rowRange = [0..rows m-1]
215 sub k = subVector (k*xRow m) (cols m) (xdat m) 226 sub k = subVector (k*xRow m) (cols m) (xdat m)
216 ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) 227 ext k = xdat $ unsafePerformIO $ extractAux RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1])
217 228
218 229
219-- | Creates a matrix from a list of vectors, as columns 230-- | Creates a matrix from a list of vectors, as columns
220fromColumns :: Element t => [Vector t] -> Matrix t 231fromColumns :: Storable t => [Vector t] -> Matrix t
221fromColumns m = trans . fromRows $ m 232fromColumns m = trans . fromRows $ m
222 233
223-- | Creates a list of vectors from the columns of a matrix 234-- | Creates a list of vectors from the columns of a matrix
224toColumns :: Element t => Matrix t -> [Vector t] 235toColumns :: Storable t => Matrix t -> [Vector t]
225toColumns m = toRows . trans $ m 236toColumns m = toRows . trans $ m
226 237
227-- | Reads a matrix position. 238-- | Reads a matrix position.
@@ -271,13 +282,13 @@ reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
271 282
272 283
273-- | application of a vector function on the flattened matrix elements 284-- | application of a vector function on the flattened matrix elements
274liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b 285liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
275liftMatrix f m@Matrix { irows = r, icols = c, xdat = d} 286liftMatrix f m@Matrix { irows = r, icols = c, xdat = d}
276 | isSlice m = matrixFromVector RowMajor r c (f (flatten m)) 287 | isSlice m = matrixFromVector RowMajor r c (f (flatten m))
277 | otherwise = matrixFromVector (orderOf m) r c (f d) 288 | otherwise = matrixFromVector (orderOf m) r c (f d)
278 289
279-- | application of a vector function on the flattened matrices elements 290-- | application of a vector function on the flattened matrices elements
280liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 291liftMatrix2 :: (Storable t, Storable a, Storable b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
281liftMatrix2 f m1@(size->(r,c)) m2 292liftMatrix2 f m1@(size->(r,c)) m2
282 | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2" 293 | (r,c)/=size m2 = error "nonconformant matrices in liftMatrix2"
283 | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2)) 294 | rowOrder m1 = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2))
@@ -285,103 +296,8 @@ liftMatrix2 f m1@(size->(r,c)) m2
285 296
286------------------------------------------------------------------ 297------------------------------------------------------------------
287 298
288-- | Supported matrix elements.
289class (Storable a) => Element a where
290 constantD :: a -> Int -> Vector a
291 extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
292 setRect :: Int -> Int -> Matrix a -> Matrix a -> IO ()
293 sortI :: Ord a => Vector a -> Vector CInt
294 sortV :: Ord a => Vector a -> Vector a
295 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
296 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
297 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
298 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
299 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
300 reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
301
302
303instance Element Float where
304 constantD = constantAux cconstantF
305 extractR = extractAux c_extractF
306 setRect = setRectAux c_setRectF
307 sortI = sortIdxF
308 sortV = sortValF
309 compareV = compareF
310 selectV = selectF
311 remapM = remapF
312 rowOp = rowOpAux c_rowOpF
313 gemm = gemmg c_gemmF
314 reorderV = reorderAux c_reorderF
315
316instance Element Double where
317 constantD = constantAux cconstantR
318 extractR = extractAux c_extractD
319 setRect = setRectAux c_setRectD
320 sortI = sortIdxD
321 sortV = sortValD
322 compareV = compareD
323 selectV = selectD
324 remapM = remapD
325 rowOp = rowOpAux c_rowOpD
326 gemm = gemmg c_gemmD
327 reorderV = reorderAux c_reorderD
328
329instance Element (Complex Float) where
330 constantD = constantAux cconstantQ
331 extractR = extractAux c_extractQ
332 setRect = setRectAux c_setRectQ
333 sortI = undefined
334 sortV = undefined
335 compareV = undefined
336 selectV = selectQ
337 remapM = remapQ
338 rowOp = rowOpAux c_rowOpQ
339 gemm = gemmg c_gemmQ
340 reorderV = reorderAux c_reorderQ
341
342instance Element (Complex Double) where
343 constantD = constantAux cconstantC
344 extractR = extractAux c_extractC
345 setRect = setRectAux c_setRectC
346 sortI = undefined
347 sortV = undefined
348 compareV = undefined
349 selectV = selectC
350 remapM = remapC
351 rowOp = rowOpAux c_rowOpC
352 gemm = gemmg c_gemmC
353 reorderV = reorderAux c_reorderC
354
355instance Element (CInt) where
356 constantD = constantAux cconstantI
357 extractR = extractAux c_extractI
358 setRect = setRectAux c_setRectI
359 sortI = sortIdxI
360 sortV = sortValI
361 compareV = compareI
362 selectV = selectI
363 remapM = remapI
364 rowOp = rowOpAux c_rowOpI
365 gemm = gemmg c_gemmI
366 reorderV = reorderAux c_reorderI
367
368instance Element Z where
369 constantD = constantAux cconstantL
370 extractR = extractAux c_extractL
371 setRect = setRectAux c_setRectL
372 sortI = sortIdxL
373 sortV = sortValL
374 compareV = compareL
375 selectV = selectL
376 remapM = remapL
377 rowOp = rowOpAux c_rowOpL
378 gemm = gemmg c_gemmL
379 reorderV = reorderAux c_reorderL
380
381-------------------------------------------------------------------
382
383-- | reference to a rectangular slice of a matrix (no data copy) 299-- | reference to a rectangular slice of a matrix (no data copy)
384subMatrix :: Element a 300subMatrix :: Storable a
385 => (Int,Int) -- ^ (r0,c0) starting position 301 => (Int,Int) -- ^ (r0,c0) starting position
386 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix 302 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
387 -> Matrix a -- ^ input matrix 303 -> Matrix a -- ^ input matrix
@@ -402,34 +318,34 @@ subMatrix (r0,c0) (rt,ct) m
402maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 318maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1
403maxZ xs = if minimum xs == 0 then 0 else maximum xs 319maxZ xs = if minimum xs == 0 then 0 else maximum xs
404 320
405conformMs :: Element t => [Matrix t] -> [Matrix t] 321conformMs :: Storable t => [Matrix t] -> [Matrix t]
406conformMs ms = map (conformMTo (r,c)) ms 322conformMs ms = map (conformMTo (r,c)) ms
407 where 323 where
408 r = maxZ (map rows ms) 324 r = maxZ (map rows ms)
409 c = maxZ (map cols ms) 325 c = maxZ (map cols ms)
410 326
411conformVs :: Element t => [Vector t] -> [Vector t] 327conformVs :: Storable t => [Vector t] -> [Vector t]
412conformVs vs = map (conformVTo n) vs 328conformVs vs = map (conformVTo n) vs
413 where 329 where
414 n = maxZ (map dim vs) 330 n = maxZ (map dim vs)
415 331
416conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t 332conformMTo :: Storable t => (Int, Int) -> Matrix t -> Matrix t
417conformMTo (r,c) m 333conformMTo (r,c) m
418 | size m == (r,c) = m 334 | size m == (r,c) = m
419 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) 335 | size m == (1,1) = matrixFromVector RowMajor r c (constantAux (m@@>(0,0)) (r*c))
420 | size m == (r,1) = repCols c m 336 | size m == (r,1) = repCols c m
421 | size m == (1,c) = repRows r m 337 | size m == (1,c) = repRows r m
422 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) 338 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c)
423 339
424conformVTo :: Element t => Int -> Vector t -> Vector t 340conformVTo :: Storable t => Int -> Vector t -> Vector t
425conformVTo n v 341conformVTo n v
426 | dim v == n = v 342 | dim v == n = v
427 | dim v == 1 = constantD (v@>0) n 343 | dim v == 1 = constantAux (v@>0) n
428 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n 344 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
429 345
430repRows :: Element t => Int -> Matrix t -> Matrix t 346repRows :: Storable t => Int -> Matrix t -> Matrix t
431repRows n x = fromRows (replicate n (flatten x)) 347repRows n x = fromRows (replicate n (flatten x))
432repCols :: Element t => Int -> Matrix t -> Matrix t 348repCols :: Storable t => Int -> Matrix t -> Matrix t
433repCols n x = fromColumns (replicate n (flatten x)) 349repCols n x = fromColumns (replicate n (flatten x))
434 350
435shSize :: Matrix t -> [Char] 351shSize :: Matrix t -> [Char]
@@ -453,32 +369,50 @@ instance (Storable t, NFData t) => NFData (Matrix t)
453 369
454--------------------------------------------------------------- 370---------------------------------------------------------------
455 371
372{-
456extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, 373extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1,
457 Storable t, Num t3, Num t2, Integral t1, Integral t) 374 Storable t, Num t3, Num t2, Integral t1, Integral t)
458 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t 375 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) -- f
459 -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) 376 -> MatrixOrder -- ord
460 -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) 377 -> c -- m
461extractAux f ord m moder vr modec vc = do 378 -> t3 -- moder
379 -> Vector t1 -- vr
380 -> t2 -- modec
381 -> Vector t -- vc
382 -> IO (Matrix a)
383-}
384
385extractAux :: Storable a =>
386 MatrixOrder
387 -> Matrix a
388 -> Int32
389 -> Vector Int32
390 -> Int32
391 -> Vector Int32
392 -> IO (Matrix a)
393extractAux ord m moder vr modec vc = do
462 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 394 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
463 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 395 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
464 r <- createMatrix ord nr nc 396 r <- createMatrix ord nr nc
465 (vr # vc # m #! r) (f moder modec) #|"extract" 397 (vr # vc # m #! r) (extractStorable moder modec) #|"extract"
466 398
467 return r 399 return r
468 400
469type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) 401{-
402type Extr x = Int32 -> Int32 -> CIdxs (CIdxs (OM x (OM x (IO Int32))))
470 403
471foreign import ccall unsafe "extractD" c_extractD :: Extr Double 404foreign import ccall unsafe "extractD" c_extractD :: Extr Double
472foreign import ccall unsafe "extractF" c_extractF :: Extr Float 405foreign import ccall unsafe "extractF" c_extractF :: Extr Float
473foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) 406foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double)
474foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) 407foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
475foreign import ccall unsafe "extractI" c_extractI :: Extr CInt 408foreign import ccall unsafe "extractI" c_extractI :: Extr Int32
476foreign import ccall unsafe "extractL" c_extractL :: Extr Z 409foreign import ccall unsafe "extractL" c_extractL :: Extr Z
410-}
477 411
478--------------------------------------------------------------- 412---------------------------------------------------------------
479 413
480setRectAux :: (TransArray c1, TransArray c) 414setRectAux :: (TransArray c1, TransArray c)
481 => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) 415 => (Int32 -> Int32 -> Trans c1 (Trans c (IO Int32)))
482 -> Int -> Int -> c1 -> c -> IO () 416 -> Int -> Int -> c1 -> c -> IO ()
483setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" 417setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
484 418
@@ -494,17 +428,17 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
494-------------------------------------------------------------------------------- 428--------------------------------------------------------------------------------
495 429
496sortG :: (Storable t, Storable a) 430sortG :: (Storable t, Storable a)
497 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a 431 => (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32) -> Vector t -> Vector a
498sortG f v = unsafePerformIO $ do 432sortG f v = unsafePerformIO $ do
499 r <- createVector (dim v) 433 r <- createVector (dim v)
500 (v #! r) f #|"sortG" 434 (v #! r) f #|"sortG"
501 return r 435 return r
502 436
503sortIdxD :: Vector Double -> Vector CInt 437sortIdxD :: Vector Double -> Vector Int32
504sortIdxD = sortG c_sort_indexD 438sortIdxD = sortG c_sort_indexD
505sortIdxF :: Vector Float -> Vector CInt 439sortIdxF :: Vector Float -> Vector Int32
506sortIdxF = sortG c_sort_indexF 440sortIdxF = sortG c_sort_indexF
507sortIdxI :: Vector CInt -> Vector CInt 441sortIdxI :: Vector Int32 -> Vector Int32
508sortIdxI = sortG c_sort_indexI 442sortIdxI = sortG c_sort_indexI
509sortIdxL :: Vector Z -> Vector I 443sortIdxL :: Vector Z -> Vector I
510sortIdxL = sortG c_sort_indexL 444sortIdxL = sortG c_sort_indexL
@@ -513,81 +447,81 @@ sortValD :: Vector Double -> Vector Double
513sortValD = sortG c_sort_valD 447sortValD = sortG c_sort_valD
514sortValF :: Vector Float -> Vector Float 448sortValF :: Vector Float -> Vector Float
515sortValF = sortG c_sort_valF 449sortValF = sortG c_sort_valF
516sortValI :: Vector CInt -> Vector CInt 450sortValI :: Vector Int32 -> Vector Int32
517sortValI = sortG c_sort_valI 451sortValI = sortG c_sort_valI
518sortValL :: Vector Z -> Vector Z 452sortValL :: Vector Z -> Vector Z
519sortValL = sortG c_sort_valL 453sortValL = sortG c_sort_valL
520 454
521foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) 455foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV Int32 (IO Int32))
522foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt)) 456foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV Int32 (IO Int32))
523foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt)) 457foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV Int32 (CV Int32 (IO Int32))
524foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok 458foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok
525 459
526foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt)) 460foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO Int32))
527foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt)) 461foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO Int32))
528foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt)) 462foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV Int32 (CV Int32 (IO Int32))
529foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok 463foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
530 464
531-------------------------------------------------------------------------------- 465--------------------------------------------------------------------------------
532 466
533compareG :: (TransArray c, Storable t, Storable a) 467compareG :: (TransArray c, Storable t, Storable a)
534 => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) 468 => Trans c (Int32 -> Ptr t -> Int32 -> Ptr a -> IO Int32)
535 -> c -> Vector t -> Vector a 469 -> c -> Vector t -> Vector a
536compareG f u v = unsafePerformIO $ do 470compareG f u v = unsafePerformIO $ do
537 r <- createVector (dim v) 471 r <- createVector (dim v)
538 (u # v #! r) f #|"compareG" 472 (u # v #! r) f #|"compareG"
539 return r 473 return r
540 474
541compareD :: Vector Double -> Vector Double -> Vector CInt 475compareD :: Vector Double -> Vector Double -> Vector Int32
542compareD = compareG c_compareD 476compareD = compareG c_compareD
543compareF :: Vector Float -> Vector Float -> Vector CInt 477compareF :: Vector Float -> Vector Float -> Vector Int32
544compareF = compareG c_compareF 478compareF = compareG c_compareF
545compareI :: Vector CInt -> Vector CInt -> Vector CInt 479compareI :: Vector Int32 -> Vector Int32 -> Vector Int32
546compareI = compareG c_compareI 480compareI = compareG c_compareI
547compareL :: Vector Z -> Vector Z -> Vector CInt 481compareL :: Vector Z -> Vector Z -> Vector Int32
548compareL = compareG c_compareL 482compareL = compareG c_compareL
549 483
550foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) 484foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV Int32 (IO Int32)))
551foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt))) 485foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV Int32 (IO Int32)))
552foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt))) 486foreign import ccall unsafe "compareI" c_compareI :: CV Int32 (CV Int32 (CV Int32 (IO Int32)))
553foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok 487foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
554 488
555-------------------------------------------------------------------------------- 489--------------------------------------------------------------------------------
556 490
557selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) 491selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a)
558 => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) 492 => Trans c2 (Trans c1 (Int32 -> Ptr t -> Trans c (Int32 -> Ptr a -> IO Int32)))
559 -> c2 -> c1 -> Vector t -> c -> Vector a 493 -> c2 -> c1 -> Vector t -> c -> Vector a
560selectG f c u v w = unsafePerformIO $ do 494selectG f c u v w = unsafePerformIO $ do
561 r <- createVector (dim v) 495 r <- createVector (dim v)
562 (c # u # v # w #! r) f #|"selectG" 496 (c # u # v # w #! r) f #|"selectG"
563 return r 497 return r
564 498
565selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double 499selectD :: Vector Int32 -> Vector Double -> Vector Double -> Vector Double -> Vector Double
566selectD = selectG c_selectD 500selectD = selectG c_selectD
567selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float 501selectF :: Vector Int32 -> Vector Float -> Vector Float -> Vector Float -> Vector Float
568selectF = selectG c_selectF 502selectF = selectG c_selectF
569selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt 503selectI :: Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32 -> Vector Int32
570selectI = selectG c_selectI 504selectI = selectG c_selectI
571selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z 505selectL :: Vector Int32 -> Vector Z -> Vector Z -> Vector Z -> Vector Z
572selectL = selectG c_selectL 506selectL = selectG c_selectL
573selectC :: Vector CInt 507selectC :: Vector Int32
574 -> Vector (Complex Double) 508 -> Vector (Complex Double)
575 -> Vector (Complex Double) 509 -> Vector (Complex Double)
576 -> Vector (Complex Double) 510 -> Vector (Complex Double)
577 -> Vector (Complex Double) 511 -> Vector (Complex Double)
578selectC = selectG c_selectC 512selectC = selectG c_selectC
579selectQ :: Vector CInt 513selectQ :: Vector Int32
580 -> Vector (Complex Float) 514 -> Vector (Complex Float)
581 -> Vector (Complex Float) 515 -> Vector (Complex Float)
582 -> Vector (Complex Float) 516 -> Vector (Complex Float)
583 -> Vector (Complex Float) 517 -> Vector (Complex Float)
584selectQ = selectG c_selectQ 518selectQ = selectG c_selectQ
585 519
586type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) 520type Sel x = CV Int32 (CV x (CV x (CV x (CV x (IO Int32)))))
587 521
588foreign import ccall unsafe "chooseD" c_selectD :: Sel Double 522foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
589foreign import ccall unsafe "chooseF" c_selectF :: Sel Float 523foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
590foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt 524foreign import ccall unsafe "chooseI" c_selectI :: Sel Int32
591foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double) 525foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
592foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float) 526foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
593foreign import ccall unsafe "chooseL" c_selectL :: Sel Z 527foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
@@ -595,35 +529,35 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
595--------------------------------------------------------------------------- 529---------------------------------------------------------------------------
596 530
597remapG :: (TransArray c, TransArray c1, Storable t, Storable a) 531remapG :: (TransArray c, TransArray c1, Storable t, Storable a)
598 => (CInt -> CInt -> CInt -> CInt -> Ptr t 532 => (Int32 -> Int32 -> Int32 -> Int32 -> Ptr t
599 -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) 533 -> Trans c1 (Trans c (Int32 -> Int32 -> Int32 -> Int32 -> Ptr a -> IO Int32)))
600 -> Matrix t -> c1 -> c -> Matrix a 534 -> Matrix t -> c1 -> c -> Matrix a
601remapG f i j m = unsafePerformIO $ do 535remapG f i j m = unsafePerformIO $ do
602 r <- createMatrix RowMajor (rows i) (cols i) 536 r <- createMatrix RowMajor (rows i) (cols i)
603 (i # j # m #! r) f #|"remapG" 537 (i # j # m #! r) f #|"remapG"
604 return r 538 return r
605 539
606remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double 540remapD :: Matrix Int32 -> Matrix Int32 -> Matrix Double -> Matrix Double
607remapD = remapG c_remapD 541remapD = remapG c_remapD
608remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float 542remapF :: Matrix Int32 -> Matrix Int32 -> Matrix Float -> Matrix Float
609remapF = remapG c_remapF 543remapF = remapG c_remapF
610remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt 544remapI :: Matrix Int32 -> Matrix Int32 -> Matrix Int32 -> Matrix Int32
611remapI = remapG c_remapI 545remapI = remapG c_remapI
612remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z 546remapL :: Matrix Int32 -> Matrix Int32 -> Matrix Z -> Matrix Z
613remapL = remapG c_remapL 547remapL = remapG c_remapL
614remapC :: Matrix CInt 548remapC :: Matrix Int32
615 -> Matrix CInt 549 -> Matrix Int32
616 -> Matrix (Complex Double) 550 -> Matrix (Complex Double)
617 -> Matrix (Complex Double) 551 -> Matrix (Complex Double)
618remapC = remapG c_remapC 552remapC = remapG c_remapC
619remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) 553remapQ :: Matrix Int32 -> Matrix Int32 -> Matrix (Complex Float) -> Matrix (Complex Float)
620remapQ = remapG c_remapQ 554remapQ = remapG c_remapQ
621 555
622type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) 556type Rem x = OM Int32 (OM Int32 (OM x (OM x (IO Int32))))
623 557
624foreign import ccall unsafe "remapD" c_remapD :: Rem Double 558foreign import ccall unsafe "remapD" c_remapD :: Rem Double
625foreign import ccall unsafe "remapF" c_remapF :: Rem Float 559foreign import ccall unsafe "remapF" c_remapF :: Rem Float
626foreign import ccall unsafe "remapI" c_remapI :: Rem CInt 560foreign import ccall unsafe "remapI" c_remapI :: Rem Int32
627foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double) 561foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
628foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float) 562foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
629foreign import ccall unsafe "remapL" c_remapL :: Rem Z 563foreign import ccall unsafe "remapL" c_remapL :: Rem Z
@@ -631,14 +565,14 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
631-------------------------------------------------------------------------------- 565--------------------------------------------------------------------------------
632 566
633rowOpAux :: (TransArray c, Storable a) => 567rowOpAux :: (TransArray c, Storable a) =>
634 (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) 568 (Int32 -> Ptr a -> Int32 -> Int32 -> Int32 -> Int32 -> Trans c (IO Int32))
635 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () 569 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
636rowOpAux f c x i1 i2 j1 j2 m = do 570rowOpAux f c x i1 i2 j1 j2 m = do
637 px <- newArray [x] 571 px <- newArray [x]
638 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" 572 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
639 free px 573 free px
640 574
641type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok 575type RowOp x = Int32 -> Ptr x -> Int32 -> Int32 -> Int32 -> Int32 -> x ::> Ok
642 576
643foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R 577foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R
644foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float 578foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float
@@ -652,7 +586,7 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
652-------------------------------------------------------------------------------- 586--------------------------------------------------------------------------------
653 587
654gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) 588gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3)
655 => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) 589 => Trans c3 (Trans c2 (Trans c1 (Trans c (IO Int32))))
656 -> c3 -> c2 -> c1 -> c -> IO () 590 -> c3 -> c2 -> c1 -> c -> IO ()
657gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" 591gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
658 592
@@ -669,21 +603,26 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
669 603
670-------------------------------------------------------------------------------- 604--------------------------------------------------------------------------------
671 605
606{-
672reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => 607reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) =>
673 (CInt -> Ptr a -> CInt -> Ptr t1 608 (Int32 -> Ptr a -> Int32 -> Ptr t1
674 -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) 609 -> Trans c (Int32 -> Ptr t -> Int32 -> Ptr a1 -> IO Int32))
675 -> Vector t1 -> c -> Vector t -> Vector a1 610 -> Vector t1 -> c -> Vector t -> Vector a1
611-}
612reorderAux :: (TransArray c, Storable a,
613 Trans c (Int32 -> Ptr a -> Int32 -> Ptr a -> IO Int32) ~ (Int32 -> ConstPtr Int32 -> Int32 -> ConstPtr a -> Int32 -> Ptr a -> IO Int32)) =>
614 p -> Vector Int32 -> c -> Vector a -> Vector a
676reorderAux f s d v = unsafePerformIO $ do 615reorderAux f s d v = unsafePerformIO $ do
677 k <- createVector (dim s) 616 k <- createVector (dim s)
678 r <- createVector (dim v) 617 r <- createVector (dim v)
679 (k # s # d # v #! r) f #| "reorderV" 618 (k # s # d # v #! r) reorderStorable #| "reorderV"
680 return r 619 return r
681 620
682type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt))))) 621type Reorder x = CV Int32 (CV Int32 (CV Int32 (CV x (CV x (IO Int32)))))
683 622
684foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double 623foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
685foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float 624foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
686foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt 625foreign import ccall unsafe "reorderI" c_reorderI :: Reorder Int32
687foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double) 626foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
688foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float) 627foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
689foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z 628foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
@@ -691,12 +630,12 @@ foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
691-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, 630-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices,
692-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ 631-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@
693-- This function is intended to be used internally by tensor libraries. 632-- This function is intended to be used internally by tensor libraries.
694reorderVector :: Element a 633reorderVector :: Storable a
695 => Vector CInt -- ^ @strides@: array strides 634 => Vector Int32 -- ^ @strides@: array strides
696 -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ 635 -> Vector Int32 -- ^ @dims@: array dimensions of new array @v@
697 -> Vector a -- ^ @v@: flattened input array 636 -> Vector a -- ^ @v@: flattened input array
698 -> Vector a -- ^ @v'@: flattened output array 637 -> Vector a -- ^ @v'@: flattened output array
699reorderVector = reorderV 638reorderVector = reorderAux ()
700 639
701-------------------------------------------------------------------------------- 640--------------------------------------------------------------------------------
702 641