summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs87
1 files changed, 86 insertions, 1 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 4905f61..4bfa13d 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -57,19 +57,24 @@ cols :: Matrix t -> Int
57cols = icols 57cols = icols
58{-# INLINE cols #-} 58{-# INLINE cols #-}
59 59
60size :: Matrix t -> (Int, Int)
60size m = (irows m, icols m) 61size m = (irows m, icols m)
61{-# INLINE size #-} 62{-# INLINE size #-}
62 63
64rowOrder :: Matrix t -> Bool
63rowOrder m = xCol m == 1 || cols m == 1 65rowOrder m = xCol m == 1 || cols m == 1
64{-# INLINE rowOrder #-} 66{-# INLINE rowOrder #-}
65 67
68colOrder :: Matrix t -> Bool
66colOrder m = xRow m == 1 || rows m == 1 69colOrder m = xRow m == 1 || rows m == 1
67{-# INLINE colOrder #-} 70{-# INLINE colOrder #-}
68 71
72is1d :: Matrix t -> Bool
69is1d (size->(r,c)) = r==1 || c==1 73is1d (size->(r,c)) = r==1 || c==1
70{-# INLINE is1d #-} 74{-# INLINE is1d #-}
71 75
72-- data is not contiguous 76-- data is not contiguous
77isSlice :: Storable t => Matrix t -> Bool
73isSlice m@(size->(r,c)) = r*c < dim (xdat m) 78isSlice m@(size->(r,c)) = r*c < dim (xdat m)
74{-# INLINE isSlice #-} 79{-# INLINE isSlice #-}
75 80
@@ -136,16 +141,20 @@ instance Storable t => TransArray (Matrix t)
136 {-# INLINE applyRaw #-} 141 {-# INLINE applyRaw #-}
137 142
138infixr 1 # 143infixr 1 #
144(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r
139a # b = apply a b 145a # b = apply a b
140{-# INLINE (#) #-} 146{-# INLINE (#) #-}
141 147
148(#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r
142a #! b = a # b # id 149a #! b = a # b # id
143{-# INLINE (#!) #-} 150{-# INLINE (#!) #-}
144 151
145-------------------------------------------------------------------------------- 152--------------------------------------------------------------------------------
146 153
154copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t)
147copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) 155copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1])
148 156
157extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t
149extractAll ord m = unsafePerformIO (copy ord m) 158extractAll ord m = unsafePerformIO (copy ord m)
150 159
151{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 160{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
@@ -223,11 +232,13 @@ m@Matrix {irows = r, icols = c} @@> (i,j)
223{-# INLINE (@@>) #-} 232{-# INLINE (@@>) #-}
224 233
225-- Unsafe matrix access without range checking 234-- Unsafe matrix access without range checking
235atM' :: Storable t => Matrix t -> Int -> Int -> t
226atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) 236atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m))
227{-# INLINE atM' #-} 237{-# INLINE atM' #-}
228 238
229------------------------------------------------------------------ 239------------------------------------------------------------------
230 240
241matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
231matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } 242matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
232matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } 243matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
233matrixFromVector o r c v 244matrixFromVector o r c v
@@ -387,18 +398,21 @@ subMatrix (r0,c0) (rt,ct) m
387 398
388-------------------------------------------------------------------------- 399--------------------------------------------------------------------------
389 400
401maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1
390maxZ xs = if minimum xs == 0 then 0 else maximum xs 402maxZ xs = if minimum xs == 0 then 0 else maximum xs
391 403
404conformMs :: Element t => [Matrix t] -> [Matrix t]
392conformMs ms = map (conformMTo (r,c)) ms 405conformMs ms = map (conformMTo (r,c)) ms
393 where 406 where
394 r = maxZ (map rows ms) 407 r = maxZ (map rows ms)
395 c = maxZ (map cols ms) 408 c = maxZ (map cols ms)
396 409
397 410conformVs :: Element t => [Vector t] -> [Vector t]
398conformVs vs = map (conformVTo n) vs 411conformVs vs = map (conformVTo n) vs
399 where 412 where
400 n = maxZ (map dim vs) 413 n = maxZ (map dim vs)
401 414
415conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t
402conformMTo (r,c) m 416conformMTo (r,c) m
403 | size m == (r,c) = m 417 | size m == (r,c) = m
404 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) 418 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
@@ -406,18 +420,24 @@ conformMTo (r,c) m
406 | size m == (1,c) = repRows r m 420 | size m == (1,c) = repRows r m
407 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) 421 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c)
408 422
423conformVTo :: Element t => Int -> Vector t -> Vector t
409conformVTo n v 424conformVTo n v
410 | dim v == n = v 425 | dim v == n = v
411 | dim v == 1 = constantD (v@>0) n 426 | dim v == 1 = constantD (v@>0) n
412 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n 427 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
413 428
429repRows :: Element t => Int -> Matrix t -> Matrix t
414repRows n x = fromRows (replicate n (flatten x)) 430repRows n x = fromRows (replicate n (flatten x))
431repCols :: Element t => Int -> Matrix t -> Matrix t
415repCols n x = fromColumns (replicate n (flatten x)) 432repCols n x = fromColumns (replicate n (flatten x))
416 433
434shSize :: Matrix t -> [Char]
417shSize = shDim . size 435shSize = shDim . size
418 436
437shDim :: (Show a, Show a1) => (a1, a) -> [Char]
419shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" 438shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"
420 439
440emptyM :: Storable t => Int -> Int -> Matrix t
421emptyM r c = matrixFromVector RowMajor r c (fromList[]) 441emptyM r c = matrixFromVector RowMajor r c (fromList[])
422 442
423---------------------------------------------------------------------- 443----------------------------------------------------------------------
@@ -432,6 +452,11 @@ instance (Storable t, NFData t) => NFData (Matrix t)
432 452
433--------------------------------------------------------------- 453---------------------------------------------------------------
434 454
455extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1,
456 Storable t, Num t3, Num t2, Integral t1, Integral t)
457 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t
458 -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))
459 -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a)
435extractAux f ord m moder vr modec vc = do 460extractAux f ord m moder vr modec vc = do
436 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 461 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
437 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 462 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
@@ -451,6 +476,9 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z
451 476
452--------------------------------------------------------------- 477---------------------------------------------------------------
453 478
479setRectAux :: (TransArray c1, TransArray c)
480 => (CInt -> CInt -> Trans c1 (Trans c (IO CInt)))
481 -> Int -> Int -> c1 -> c -> IO ()
454setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" 482setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
455 483
456type SetRect x = I -> I -> x ::> x::> Ok 484type SetRect x = I -> I -> x ::> x::> Ok
@@ -464,19 +492,29 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
464 492
465-------------------------------------------------------------------------------- 493--------------------------------------------------------------------------------
466 494
495sortG :: (Storable t, Storable a)
496 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
467sortG f v = unsafePerformIO $ do 497sortG f v = unsafePerformIO $ do
468 r <- createVector (dim v) 498 r <- createVector (dim v)
469 (v #! r) f #|"sortG" 499 (v #! r) f #|"sortG"
470 return r 500 return r
471 501
502sortIdxD :: Vector Double -> Vector CInt
472sortIdxD = sortG c_sort_indexD 503sortIdxD = sortG c_sort_indexD
504sortIdxF :: Vector Float -> Vector CInt
473sortIdxF = sortG c_sort_indexF 505sortIdxF = sortG c_sort_indexF
506sortIdxI :: Vector CInt -> Vector CInt
474sortIdxI = sortG c_sort_indexI 507sortIdxI = sortG c_sort_indexI
508sortIdxL :: Vector Z -> Vector I
475sortIdxL = sortG c_sort_indexL 509sortIdxL = sortG c_sort_indexL
476 510
511sortValD :: Vector Double -> Vector Double
477sortValD = sortG c_sort_valD 512sortValD = sortG c_sort_valD
513sortValF :: Vector Float -> Vector Float
478sortValF = sortG c_sort_valF 514sortValF = sortG c_sort_valF
515sortValI :: Vector CInt -> Vector CInt
479sortValI = sortG c_sort_valI 516sortValI = sortG c_sort_valI
517sortValL :: Vector Z -> Vector Z
480sortValL = sortG c_sort_valL 518sortValL = sortG c_sort_valL
481 519
482foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) 520foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
@@ -491,14 +529,21 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
491 529
492-------------------------------------------------------------------------------- 530--------------------------------------------------------------------------------
493 531
532compareG :: (TransArray c, Storable t, Storable a)
533 => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt)
534 -> c -> Vector t -> Vector a
494compareG f u v = unsafePerformIO $ do 535compareG f u v = unsafePerformIO $ do
495 r <- createVector (dim v) 536 r <- createVector (dim v)
496 (u # v #! r) f #|"compareG" 537 (u # v #! r) f #|"compareG"
497 return r 538 return r
498 539
540compareD :: Vector Double -> Vector Double -> Vector CInt
499compareD = compareG c_compareD 541compareD = compareG c_compareD
542compareF :: Vector Float -> Vector Float -> Vector CInt
500compareF = compareG c_compareF 543compareF = compareG c_compareF
544compareI :: Vector CInt -> Vector CInt -> Vector CInt
501compareI = compareG c_compareI 545compareI = compareG c_compareI
546compareL :: Vector Z -> Vector Z -> Vector CInt
502compareL = compareG c_compareL 547compareL = compareG c_compareL
503 548
504foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) 549foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
@@ -508,16 +553,33 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
508 553
509-------------------------------------------------------------------------------- 554--------------------------------------------------------------------------------
510 555
556selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a)
557 => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt)))
558 -> c2 -> c1 -> Vector t -> c -> Vector a
511selectG f c u v w = unsafePerformIO $ do 559selectG f c u v w = unsafePerformIO $ do
512 r <- createVector (dim v) 560 r <- createVector (dim v)
513 (c # u # v # w #! r) f #|"selectG" 561 (c # u # v # w #! r) f #|"selectG"
514 return r 562 return r
515 563
564selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double
516selectD = selectG c_selectD 565selectD = selectG c_selectD
566selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float
517selectF = selectG c_selectF 567selectF = selectG c_selectF
568selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt
518selectI = selectG c_selectI 569selectI = selectG c_selectI
570selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z
519selectL = selectG c_selectL 571selectL = selectG c_selectL
572selectC :: Vector CInt
573 -> Vector (Complex Double)
574 -> Vector (Complex Double)
575 -> Vector (Complex Double)
576 -> Vector (Complex Double)
520selectC = selectG c_selectC 577selectC = selectG c_selectC
578selectQ :: Vector CInt
579 -> Vector (Complex Float)
580 -> Vector (Complex Float)
581 -> Vector (Complex Float)
582 -> Vector (Complex Float)
521selectQ = selectG c_selectQ 583selectQ = selectG c_selectQ
522 584
523type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) 585type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
@@ -531,16 +593,29 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
531 593
532--------------------------------------------------------------------------- 594---------------------------------------------------------------------------
533 595
596remapG :: (TransArray c, TransArray c1, Storable t, Storable a)
597 => (CInt -> CInt -> CInt -> CInt -> Ptr t
598 -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)))
599 -> Matrix t -> c1 -> c -> Matrix a
534remapG f i j m = unsafePerformIO $ do 600remapG f i j m = unsafePerformIO $ do
535 r <- createMatrix RowMajor (rows i) (cols i) 601 r <- createMatrix RowMajor (rows i) (cols i)
536 (i # j # m #! r) f #|"remapG" 602 (i # j # m #! r) f #|"remapG"
537 return r 603 return r
538 604
605remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double
539remapD = remapG c_remapD 606remapD = remapG c_remapD
607remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float
540remapF = remapG c_remapF 608remapF = remapG c_remapF
609remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt
541remapI = remapG c_remapI 610remapI = remapG c_remapI
611remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z
542remapL = remapG c_remapL 612remapL = remapG c_remapL
613remapC :: Matrix CInt
614 -> Matrix CInt
615 -> Matrix (Complex Double)
616 -> Matrix (Complex Double)
543remapC = remapG c_remapC 617remapC = remapG c_remapC
618remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float)
544remapQ = remapG c_remapQ 619remapQ = remapG c_remapQ
545 620
546type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) 621type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
@@ -554,6 +629,9 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
554 629
555-------------------------------------------------------------------------------- 630--------------------------------------------------------------------------------
556 631
632rowOpAux :: (TransArray c, Storable a) =>
633 (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt))
634 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
557rowOpAux f c x i1 i2 j1 j2 m = do 635rowOpAux f c x i1 i2 j1 j2 m = do
558 px <- newArray [x] 636 px <- newArray [x]
559 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" 637 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
@@ -572,6 +650,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
572 650
573-------------------------------------------------------------------------------- 651--------------------------------------------------------------------------------
574 652
653gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3)
654 => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt))))
655 -> c3 -> c2 -> c1 -> c -> IO ()
575gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" 656gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
576 657
577type Tgemm x = x :> x ::> x ::> x ::> Ok 658type Tgemm x = x :> x ::> x ::> x ::> Ok
@@ -587,6 +668,10 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
587 668
588-------------------------------------------------------------------------------- 669--------------------------------------------------------------------------------
589 670
671reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) =>
672 (CInt -> Ptr a -> CInt -> Ptr t1
673 -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt))
674 -> Vector t1 -> c -> Vector t -> Vector a1
590reorderAux f s d v = unsafePerformIO $ do 675reorderAux f s d v = unsafePerformIO $ do
591 k <- createVector (dim s) 676 k <- createVector (dim s)
592 r <- createVector (dim v) 677 r <- createVector (dim v)