summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs87
1 files changed, 86 insertions, 1 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 2856ec2..5436e59 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -57,19 +57,24 @@ cols :: Matrix t -> Int
57cols = icols 57cols = icols
58{-# INLINE cols #-} 58{-# INLINE cols #-}
59 59
60size :: Matrix t -> (Int, Int)
60size m = (irows m, icols m) 61size m = (irows m, icols m)
61{-# INLINE size #-} 62{-# INLINE size #-}
62 63
64rowOrder :: Matrix t -> Bool
63rowOrder m = xCol m == 1 || cols m == 1 65rowOrder m = xCol m == 1 || cols m == 1
64{-# INLINE rowOrder #-} 66{-# INLINE rowOrder #-}
65 67
68colOrder :: Matrix t -> Bool
66colOrder m = xRow m == 1 || rows m == 1 69colOrder m = xRow m == 1 || rows m == 1
67{-# INLINE colOrder #-} 70{-# INLINE colOrder #-}
68 71
72is1d :: Matrix t -> Bool
69is1d (size->(r,c)) = r==1 || c==1 73is1d (size->(r,c)) = r==1 || c==1
70{-# INLINE is1d #-} 74{-# INLINE is1d #-}
71 75
72-- data is not contiguous 76-- data is not contiguous
77isSlice :: Storable t => Matrix t -> Bool
73isSlice m@(size->(r,c)) = r*c < dim (xdat m) 78isSlice m@(size->(r,c)) = r*c < dim (xdat m)
74{-# INLINE isSlice #-} 79{-# INLINE isSlice #-}
75 80
@@ -136,16 +141,20 @@ instance Storable t => TransArray (Matrix t)
136 {-# INLINE applyRaw #-} 141 {-# INLINE applyRaw #-}
137 142
138infixr 1 # 143infixr 1 #
144(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r
139a # b = apply a b 145a # b = apply a b
140{-# INLINE (#) #-} 146{-# INLINE (#) #-}
141 147
148(#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r
142a #! b = a # b # id 149a #! b = a # b # id
143{-# INLINE (#!) #-} 150{-# INLINE (#!) #-}
144 151
145-------------------------------------------------------------------------------- 152--------------------------------------------------------------------------------
146 153
154copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t)
147copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) 155copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1])
148 156
157extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t
149extractAll ord m = unsafePerformIO (copy ord m) 158extractAll ord m = unsafePerformIO (copy ord m)
150 159
151{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 160{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
@@ -224,11 +233,13 @@ m@Matrix {irows = r, icols = c} @@> (i,j)
224{-# INLINE (@@>) #-} 233{-# INLINE (@@>) #-}
225 234
226-- Unsafe matrix access without range checking 235-- Unsafe matrix access without range checking
236atM' :: Storable t => Matrix t -> Int -> Int -> t
227atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) 237atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m))
228{-# INLINE atM' #-} 238{-# INLINE atM' #-}
229 239
230------------------------------------------------------------------ 240------------------------------------------------------------------
231 241
242matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
232matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } 243matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
233matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } 244matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
234matrixFromVector o r c v 245matrixFromVector o r c v
@@ -388,18 +399,21 @@ subMatrix (r0,c0) (rt,ct) m
388 399
389-------------------------------------------------------------------------- 400--------------------------------------------------------------------------
390 401
402maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1
391maxZ xs = if minimum xs == 0 then 0 else maximum xs 403maxZ xs = if minimum xs == 0 then 0 else maximum xs
392 404
405conformMs :: Element t => [Matrix t] -> [Matrix t]
393conformMs ms = map (conformMTo (r,c)) ms 406conformMs ms = map (conformMTo (r,c)) ms
394 where 407 where
395 r = maxZ (map rows ms) 408 r = maxZ (map rows ms)
396 c = maxZ (map cols ms) 409 c = maxZ (map cols ms)
397 410
398 411conformVs :: Element t => [Vector t] -> [Vector t]
399conformVs vs = map (conformVTo n) vs 412conformVs vs = map (conformVTo n) vs
400 where 413 where
401 n = maxZ (map dim vs) 414 n = maxZ (map dim vs)
402 415
416conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t
403conformMTo (r,c) m 417conformMTo (r,c) m
404 | size m == (r,c) = m 418 | size m == (r,c) = m
405 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) 419 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
@@ -407,18 +421,24 @@ conformMTo (r,c) m
407 | size m == (1,c) = repRows r m 421 | size m == (1,c) = repRows r m
408 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) 422 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c)
409 423
424conformVTo :: Element t => Int -> Vector t -> Vector t
410conformVTo n v 425conformVTo n v
411 | dim v == n = v 426 | dim v == n = v
412 | dim v == 1 = constantD (v@>0) n 427 | dim v == 1 = constantD (v@>0) n
413 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n 428 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
414 429
430repRows :: Element t => Int -> Matrix t -> Matrix t
415repRows n x = fromRows (replicate n (flatten x)) 431repRows n x = fromRows (replicate n (flatten x))
432repCols :: Element t => Int -> Matrix t -> Matrix t
416repCols n x = fromColumns (replicate n (flatten x)) 433repCols n x = fromColumns (replicate n (flatten x))
417 434
435shSize :: Matrix t -> [Char]
418shSize = shDim . size 436shSize = shDim . size
419 437
438shDim :: (Show a, Show a1) => (a1, a) -> [Char]
420shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" 439shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"
421 440
441emptyM :: Storable t => Int -> Int -> Matrix t
422emptyM r c = matrixFromVector RowMajor r c (fromList[]) 442emptyM r c = matrixFromVector RowMajor r c (fromList[])
423 443
424---------------------------------------------------------------------- 444----------------------------------------------------------------------
@@ -433,6 +453,11 @@ instance (Storable t, NFData t) => NFData (Matrix t)
433 453
434--------------------------------------------------------------- 454---------------------------------------------------------------
435 455
456extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1,
457 Storable t, Num t3, Num t2, Integral t1, Integral t)
458 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t
459 -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))
460 -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a)
436extractAux f ord m moder vr modec vc = do 461extractAux f ord m moder vr modec vc = do
437 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr 462 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
438 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc 463 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
@@ -452,6 +477,9 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z
452 477
453--------------------------------------------------------------- 478---------------------------------------------------------------
454 479
480setRectAux :: (TransArray c1, TransArray c)
481 => (CInt -> CInt -> Trans c1 (Trans c (IO CInt)))
482 -> Int -> Int -> c1 -> c -> IO ()
455setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" 483setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
456 484
457type SetRect x = I -> I -> x ::> x::> Ok 485type SetRect x = I -> I -> x ::> x::> Ok
@@ -465,19 +493,29 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
465 493
466-------------------------------------------------------------------------------- 494--------------------------------------------------------------------------------
467 495
496sortG :: (Storable t, Storable a)
497 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
468sortG f v = unsafePerformIO $ do 498sortG f v = unsafePerformIO $ do
469 r <- createVector (dim v) 499 r <- createVector (dim v)
470 (v #! r) f #|"sortG" 500 (v #! r) f #|"sortG"
471 return r 501 return r
472 502
503sortIdxD :: Vector Double -> Vector CInt
473sortIdxD = sortG c_sort_indexD 504sortIdxD = sortG c_sort_indexD
505sortIdxF :: Vector Float -> Vector CInt
474sortIdxF = sortG c_sort_indexF 506sortIdxF = sortG c_sort_indexF
507sortIdxI :: Vector CInt -> Vector CInt
475sortIdxI = sortG c_sort_indexI 508sortIdxI = sortG c_sort_indexI
509sortIdxL :: Vector Z -> Vector I
476sortIdxL = sortG c_sort_indexL 510sortIdxL = sortG c_sort_indexL
477 511
512sortValD :: Vector Double -> Vector Double
478sortValD = sortG c_sort_valD 513sortValD = sortG c_sort_valD
514sortValF :: Vector Float -> Vector Float
479sortValF = sortG c_sort_valF 515sortValF = sortG c_sort_valF
516sortValI :: Vector CInt -> Vector CInt
480sortValI = sortG c_sort_valI 517sortValI = sortG c_sort_valI
518sortValL :: Vector Z -> Vector Z
481sortValL = sortG c_sort_valL 519sortValL = sortG c_sort_valL
482 520
483foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) 521foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
@@ -492,14 +530,21 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
492 530
493-------------------------------------------------------------------------------- 531--------------------------------------------------------------------------------
494 532
533compareG :: (TransArray c, Storable t, Storable a)
534 => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt)
535 -> c -> Vector t -> Vector a
495compareG f u v = unsafePerformIO $ do 536compareG f u v = unsafePerformIO $ do
496 r <- createVector (dim v) 537 r <- createVector (dim v)
497 (u # v #! r) f #|"compareG" 538 (u # v #! r) f #|"compareG"
498 return r 539 return r
499 540
541compareD :: Vector Double -> Vector Double -> Vector CInt
500compareD = compareG c_compareD 542compareD = compareG c_compareD
543compareF :: Vector Float -> Vector Float -> Vector CInt
501compareF = compareG c_compareF 544compareF = compareG c_compareF
545compareI :: Vector CInt -> Vector CInt -> Vector CInt
502compareI = compareG c_compareI 546compareI = compareG c_compareI
547compareL :: Vector Z -> Vector Z -> Vector CInt
503compareL = compareG c_compareL 548compareL = compareG c_compareL
504 549
505foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) 550foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
@@ -509,16 +554,33 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
509 554
510-------------------------------------------------------------------------------- 555--------------------------------------------------------------------------------
511 556
557selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a)
558 => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt)))
559 -> c2 -> c1 -> Vector t -> c -> Vector a
512selectG f c u v w = unsafePerformIO $ do 560selectG f c u v w = unsafePerformIO $ do
513 r <- createVector (dim v) 561 r <- createVector (dim v)
514 (c # u # v # w #! r) f #|"selectG" 562 (c # u # v # w #! r) f #|"selectG"
515 return r 563 return r
516 564
565selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double
517selectD = selectG c_selectD 566selectD = selectG c_selectD
567selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float
518selectF = selectG c_selectF 568selectF = selectG c_selectF
569selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt
519selectI = selectG c_selectI 570selectI = selectG c_selectI
571selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z
520selectL = selectG c_selectL 572selectL = selectG c_selectL
573selectC :: Vector CInt
574 -> Vector (Complex Double)
575 -> Vector (Complex Double)
576 -> Vector (Complex Double)
577 -> Vector (Complex Double)
521selectC = selectG c_selectC 578selectC = selectG c_selectC
579selectQ :: Vector CInt
580 -> Vector (Complex Float)
581 -> Vector (Complex Float)
582 -> Vector (Complex Float)
583 -> Vector (Complex Float)
522selectQ = selectG c_selectQ 584selectQ = selectG c_selectQ
523 585
524type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) 586type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
@@ -532,16 +594,29 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
532 594
533--------------------------------------------------------------------------- 595---------------------------------------------------------------------------
534 596
597remapG :: (TransArray c, TransArray c1, Storable t, Storable a)
598 => (CInt -> CInt -> CInt -> CInt -> Ptr t
599 -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)))
600 -> Matrix t -> c1 -> c -> Matrix a
535remapG f i j m = unsafePerformIO $ do 601remapG f i j m = unsafePerformIO $ do
536 r <- createMatrix RowMajor (rows i) (cols i) 602 r <- createMatrix RowMajor (rows i) (cols i)
537 (i # j # m #! r) f #|"remapG" 603 (i # j # m #! r) f #|"remapG"
538 return r 604 return r
539 605
606remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double
540remapD = remapG c_remapD 607remapD = remapG c_remapD
608remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float
541remapF = remapG c_remapF 609remapF = remapG c_remapF
610remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt
542remapI = remapG c_remapI 611remapI = remapG c_remapI
612remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z
543remapL = remapG c_remapL 613remapL = remapG c_remapL
614remapC :: Matrix CInt
615 -> Matrix CInt
616 -> Matrix (Complex Double)
617 -> Matrix (Complex Double)
544remapC = remapG c_remapC 618remapC = remapG c_remapC
619remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float)
545remapQ = remapG c_remapQ 620remapQ = remapG c_remapQ
546 621
547type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) 622type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
@@ -555,6 +630,9 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z
555 630
556-------------------------------------------------------------------------------- 631--------------------------------------------------------------------------------
557 632
633rowOpAux :: (TransArray c, Storable a) =>
634 (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt))
635 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
558rowOpAux f c x i1 i2 j1 j2 m = do 636rowOpAux f c x i1 i2 j1 j2 m = do
559 px <- newArray [x] 637 px <- newArray [x]
560 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" 638 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
@@ -573,6 +651,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
573 651
574-------------------------------------------------------------------------------- 652--------------------------------------------------------------------------------
575 653
654gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3)
655 => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt))))
656 -> c3 -> c2 -> c1 -> c -> IO ()
576gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" 657gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
577 658
578type Tgemm x = x :> x ::> x ::> x ::> Ok 659type Tgemm x = x :> x ::> x ::> x ::> Ok
@@ -588,6 +669,10 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
588 669
589-------------------------------------------------------------------------------- 670--------------------------------------------------------------------------------
590 671
672reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) =>
673 (CInt -> Ptr a -> CInt -> Ptr t1
674 -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt))
675 -> Vector t1 -> c -> Vector t -> Vector a1
591reorderAux f s d v = unsafePerformIO $ do 676reorderAux f s d v = unsafePerformIO $ do
592 k <- createVector (dim s) 677 k <- createVector (dim s)
593 r <- createVector (dim v) 678 r <- createVector (dim v)