diff options
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 87 |
1 files changed, 86 insertions, 1 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 2856ec2..5436e59 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -57,19 +57,24 @@ cols :: Matrix t -> Int | |||
57 | cols = icols | 57 | cols = icols |
58 | {-# INLINE cols #-} | 58 | {-# INLINE cols #-} |
59 | 59 | ||
60 | size :: Matrix t -> (Int, Int) | ||
60 | size m = (irows m, icols m) | 61 | size m = (irows m, icols m) |
61 | {-# INLINE size #-} | 62 | {-# INLINE size #-} |
62 | 63 | ||
64 | rowOrder :: Matrix t -> Bool | ||
63 | rowOrder m = xCol m == 1 || cols m == 1 | 65 | rowOrder m = xCol m == 1 || cols m == 1 |
64 | {-# INLINE rowOrder #-} | 66 | {-# INLINE rowOrder #-} |
65 | 67 | ||
68 | colOrder :: Matrix t -> Bool | ||
66 | colOrder m = xRow m == 1 || rows m == 1 | 69 | colOrder m = xRow m == 1 || rows m == 1 |
67 | {-# INLINE colOrder #-} | 70 | {-# INLINE colOrder #-} |
68 | 71 | ||
72 | is1d :: Matrix t -> Bool | ||
69 | is1d (size->(r,c)) = r==1 || c==1 | 73 | is1d (size->(r,c)) = r==1 || c==1 |
70 | {-# INLINE is1d #-} | 74 | {-# INLINE is1d #-} |
71 | 75 | ||
72 | -- data is not contiguous | 76 | -- data is not contiguous |
77 | isSlice :: Storable t => Matrix t -> Bool | ||
73 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) | 78 | isSlice m@(size->(r,c)) = r*c < dim (xdat m) |
74 | {-# INLINE isSlice #-} | 79 | {-# INLINE isSlice #-} |
75 | 80 | ||
@@ -136,16 +141,20 @@ instance Storable t => TransArray (Matrix t) | |||
136 | {-# INLINE applyRaw #-} | 141 | {-# INLINE applyRaw #-} |
137 | 142 | ||
138 | infixr 1 # | 143 | infixr 1 # |
144 | (#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r | ||
139 | a # b = apply a b | 145 | a # b = apply a b |
140 | {-# INLINE (#) #-} | 146 | {-# INLINE (#) #-} |
141 | 147 | ||
148 | (#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r | ||
142 | a #! b = a # b # id | 149 | a #! b = a # b # id |
143 | {-# INLINE (#!) #-} | 150 | {-# INLINE (#!) #-} |
144 | 151 | ||
145 | -------------------------------------------------------------------------------- | 152 | -------------------------------------------------------------------------------- |
146 | 153 | ||
154 | copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) | ||
147 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) | 155 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) |
148 | 156 | ||
157 | extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t | ||
149 | extractAll ord m = unsafePerformIO (copy ord m) | 158 | extractAll ord m = unsafePerformIO (copy ord m) |
150 | 159 | ||
151 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 160 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
@@ -224,11 +233,13 @@ m@Matrix {irows = r, icols = c} @@> (i,j) | |||
224 | {-# INLINE (@@>) #-} | 233 | {-# INLINE (@@>) #-} |
225 | 234 | ||
226 | -- Unsafe matrix access without range checking | 235 | -- Unsafe matrix access without range checking |
236 | atM' :: Storable t => Matrix t -> Int -> Int -> t | ||
227 | atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) | 237 | atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) |
228 | {-# INLINE atM' #-} | 238 | {-# INLINE atM' #-} |
229 | 239 | ||
230 | ------------------------------------------------------------------ | 240 | ------------------------------------------------------------------ |
231 | 241 | ||
242 | matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t | ||
232 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } | 243 | matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } |
233 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } | 244 | matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } |
234 | matrixFromVector o r c v | 245 | matrixFromVector o r c v |
@@ -388,18 +399,21 @@ subMatrix (r0,c0) (rt,ct) m | |||
388 | 399 | ||
389 | -------------------------------------------------------------------------- | 400 | -------------------------------------------------------------------------- |
390 | 401 | ||
402 | maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 | ||
391 | maxZ xs = if minimum xs == 0 then 0 else maximum xs | 403 | maxZ xs = if minimum xs == 0 then 0 else maximum xs |
392 | 404 | ||
405 | conformMs :: Element t => [Matrix t] -> [Matrix t] | ||
393 | conformMs ms = map (conformMTo (r,c)) ms | 406 | conformMs ms = map (conformMTo (r,c)) ms |
394 | where | 407 | where |
395 | r = maxZ (map rows ms) | 408 | r = maxZ (map rows ms) |
396 | c = maxZ (map cols ms) | 409 | c = maxZ (map cols ms) |
397 | 410 | ||
398 | 411 | conformVs :: Element t => [Vector t] -> [Vector t] | |
399 | conformVs vs = map (conformVTo n) vs | 412 | conformVs vs = map (conformVTo n) vs |
400 | where | 413 | where |
401 | n = maxZ (map dim vs) | 414 | n = maxZ (map dim vs) |
402 | 415 | ||
416 | conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t | ||
403 | conformMTo (r,c) m | 417 | conformMTo (r,c) m |
404 | | size m == (r,c) = m | 418 | | size m == (r,c) = m |
405 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) | 419 | | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) |
@@ -407,18 +421,24 @@ conformMTo (r,c) m | |||
407 | | size m == (1,c) = repRows r m | 421 | | size m == (1,c) = repRows r m |
408 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) | 422 | | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) |
409 | 423 | ||
424 | conformVTo :: Element t => Int -> Vector t -> Vector t | ||
410 | conformVTo n v | 425 | conformVTo n v |
411 | | dim v == n = v | 426 | | dim v == n = v |
412 | | dim v == 1 = constantD (v@>0) n | 427 | | dim v == 1 = constantD (v@>0) n |
413 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n | 428 | | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n |
414 | 429 | ||
430 | repRows :: Element t => Int -> Matrix t -> Matrix t | ||
415 | repRows n x = fromRows (replicate n (flatten x)) | 431 | repRows n x = fromRows (replicate n (flatten x)) |
432 | repCols :: Element t => Int -> Matrix t -> Matrix t | ||
416 | repCols n x = fromColumns (replicate n (flatten x)) | 433 | repCols n x = fromColumns (replicate n (flatten x)) |
417 | 434 | ||
435 | shSize :: Matrix t -> [Char] | ||
418 | shSize = shDim . size | 436 | shSize = shDim . size |
419 | 437 | ||
438 | shDim :: (Show a, Show a1) => (a1, a) -> [Char] | ||
420 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" | 439 | shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" |
421 | 440 | ||
441 | emptyM :: Storable t => Int -> Int -> Matrix t | ||
422 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) | 442 | emptyM r c = matrixFromVector RowMajor r c (fromList[]) |
423 | 443 | ||
424 | ---------------------------------------------------------------------- | 444 | ---------------------------------------------------------------------- |
@@ -433,6 +453,11 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
433 | 453 | ||
434 | --------------------------------------------------------------- | 454 | --------------------------------------------------------------- |
435 | 455 | ||
456 | extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, | ||
457 | Storable t, Num t3, Num t2, Integral t1, Integral t) | ||
458 | => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t | ||
459 | -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) | ||
460 | -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) | ||
436 | extractAux f ord m moder vr modec vc = do | 461 | extractAux f ord m moder vr modec vc = do |
437 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr | 462 | let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr |
438 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc | 463 | nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc |
@@ -452,6 +477,9 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z | |||
452 | 477 | ||
453 | --------------------------------------------------------------- | 478 | --------------------------------------------------------------- |
454 | 479 | ||
480 | setRectAux :: (TransArray c1, TransArray c) | ||
481 | => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) | ||
482 | -> Int -> Int -> c1 -> c -> IO () | ||
455 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" | 483 | setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" |
456 | 484 | ||
457 | type SetRect x = I -> I -> x ::> x::> Ok | 485 | type SetRect x = I -> I -> x ::> x::> Ok |
@@ -465,19 +493,29 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | |||
465 | 493 | ||
466 | -------------------------------------------------------------------------------- | 494 | -------------------------------------------------------------------------------- |
467 | 495 | ||
496 | sortG :: (Storable t, Storable a) | ||
497 | => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a | ||
468 | sortG f v = unsafePerformIO $ do | 498 | sortG f v = unsafePerformIO $ do |
469 | r <- createVector (dim v) | 499 | r <- createVector (dim v) |
470 | (v #! r) f #|"sortG" | 500 | (v #! r) f #|"sortG" |
471 | return r | 501 | return r |
472 | 502 | ||
503 | sortIdxD :: Vector Double -> Vector CInt | ||
473 | sortIdxD = sortG c_sort_indexD | 504 | sortIdxD = sortG c_sort_indexD |
505 | sortIdxF :: Vector Float -> Vector CInt | ||
474 | sortIdxF = sortG c_sort_indexF | 506 | sortIdxF = sortG c_sort_indexF |
507 | sortIdxI :: Vector CInt -> Vector CInt | ||
475 | sortIdxI = sortG c_sort_indexI | 508 | sortIdxI = sortG c_sort_indexI |
509 | sortIdxL :: Vector Z -> Vector I | ||
476 | sortIdxL = sortG c_sort_indexL | 510 | sortIdxL = sortG c_sort_indexL |
477 | 511 | ||
512 | sortValD :: Vector Double -> Vector Double | ||
478 | sortValD = sortG c_sort_valD | 513 | sortValD = sortG c_sort_valD |
514 | sortValF :: Vector Float -> Vector Float | ||
479 | sortValF = sortG c_sort_valF | 515 | sortValF = sortG c_sort_valF |
516 | sortValI :: Vector CInt -> Vector CInt | ||
480 | sortValI = sortG c_sort_valI | 517 | sortValI = sortG c_sort_valI |
518 | sortValL :: Vector Z -> Vector Z | ||
481 | sortValL = sortG c_sort_valL | 519 | sortValL = sortG c_sort_valL |
482 | 520 | ||
483 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) | 521 | foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) |
@@ -492,14 +530,21 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok | |||
492 | 530 | ||
493 | -------------------------------------------------------------------------------- | 531 | -------------------------------------------------------------------------------- |
494 | 532 | ||
533 | compareG :: (TransArray c, Storable t, Storable a) | ||
534 | => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) | ||
535 | -> c -> Vector t -> Vector a | ||
495 | compareG f u v = unsafePerformIO $ do | 536 | compareG f u v = unsafePerformIO $ do |
496 | r <- createVector (dim v) | 537 | r <- createVector (dim v) |
497 | (u # v #! r) f #|"compareG" | 538 | (u # v #! r) f #|"compareG" |
498 | return r | 539 | return r |
499 | 540 | ||
541 | compareD :: Vector Double -> Vector Double -> Vector CInt | ||
500 | compareD = compareG c_compareD | 542 | compareD = compareG c_compareD |
543 | compareF :: Vector Float -> Vector Float -> Vector CInt | ||
501 | compareF = compareG c_compareF | 544 | compareF = compareG c_compareF |
545 | compareI :: Vector CInt -> Vector CInt -> Vector CInt | ||
502 | compareI = compareG c_compareI | 546 | compareI = compareG c_compareI |
547 | compareL :: Vector Z -> Vector Z -> Vector CInt | ||
503 | compareL = compareG c_compareL | 548 | compareL = compareG c_compareL |
504 | 549 | ||
505 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) | 550 | foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) |
@@ -509,16 +554,33 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok | |||
509 | 554 | ||
510 | -------------------------------------------------------------------------------- | 555 | -------------------------------------------------------------------------------- |
511 | 556 | ||
557 | selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) | ||
558 | => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) | ||
559 | -> c2 -> c1 -> Vector t -> c -> Vector a | ||
512 | selectG f c u v w = unsafePerformIO $ do | 560 | selectG f c u v w = unsafePerformIO $ do |
513 | r <- createVector (dim v) | 561 | r <- createVector (dim v) |
514 | (c # u # v # w #! r) f #|"selectG" | 562 | (c # u # v # w #! r) f #|"selectG" |
515 | return r | 563 | return r |
516 | 564 | ||
565 | selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double | ||
517 | selectD = selectG c_selectD | 566 | selectD = selectG c_selectD |
567 | selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float | ||
518 | selectF = selectG c_selectF | 568 | selectF = selectG c_selectF |
569 | selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt | ||
519 | selectI = selectG c_selectI | 570 | selectI = selectG c_selectI |
571 | selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z | ||
520 | selectL = selectG c_selectL | 572 | selectL = selectG c_selectL |
573 | selectC :: Vector CInt | ||
574 | -> Vector (Complex Double) | ||
575 | -> Vector (Complex Double) | ||
576 | -> Vector (Complex Double) | ||
577 | -> Vector (Complex Double) | ||
521 | selectC = selectG c_selectC | 578 | selectC = selectG c_selectC |
579 | selectQ :: Vector CInt | ||
580 | -> Vector (Complex Float) | ||
581 | -> Vector (Complex Float) | ||
582 | -> Vector (Complex Float) | ||
583 | -> Vector (Complex Float) | ||
522 | selectQ = selectG c_selectQ | 584 | selectQ = selectG c_selectQ |
523 | 585 | ||
524 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) | 586 | type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) |
@@ -532,16 +594,29 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z | |||
532 | 594 | ||
533 | --------------------------------------------------------------------------- | 595 | --------------------------------------------------------------------------- |
534 | 596 | ||
597 | remapG :: (TransArray c, TransArray c1, Storable t, Storable a) | ||
598 | => (CInt -> CInt -> CInt -> CInt -> Ptr t | ||
599 | -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) | ||
600 | -> Matrix t -> c1 -> c -> Matrix a | ||
535 | remapG f i j m = unsafePerformIO $ do | 601 | remapG f i j m = unsafePerformIO $ do |
536 | r <- createMatrix RowMajor (rows i) (cols i) | 602 | r <- createMatrix RowMajor (rows i) (cols i) |
537 | (i # j # m #! r) f #|"remapG" | 603 | (i # j # m #! r) f #|"remapG" |
538 | return r | 604 | return r |
539 | 605 | ||
606 | remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double | ||
540 | remapD = remapG c_remapD | 607 | remapD = remapG c_remapD |
608 | remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float | ||
541 | remapF = remapG c_remapF | 609 | remapF = remapG c_remapF |
610 | remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt | ||
542 | remapI = remapG c_remapI | 611 | remapI = remapG c_remapI |
612 | remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z | ||
543 | remapL = remapG c_remapL | 613 | remapL = remapG c_remapL |
614 | remapC :: Matrix CInt | ||
615 | -> Matrix CInt | ||
616 | -> Matrix (Complex Double) | ||
617 | -> Matrix (Complex Double) | ||
544 | remapC = remapG c_remapC | 618 | remapC = remapG c_remapC |
619 | remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) | ||
545 | remapQ = remapG c_remapQ | 620 | remapQ = remapG c_remapQ |
546 | 621 | ||
547 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) | 622 | type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) |
@@ -555,6 +630,9 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z | |||
555 | 630 | ||
556 | -------------------------------------------------------------------------------- | 631 | -------------------------------------------------------------------------------- |
557 | 632 | ||
633 | rowOpAux :: (TransArray c, Storable a) => | ||
634 | (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) | ||
635 | -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () | ||
558 | rowOpAux f c x i1 i2 j1 j2 m = do | 636 | rowOpAux f c x i1 i2 j1 j2 m = do |
559 | px <- newArray [x] | 637 | px <- newArray [x] |
560 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" | 638 | (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" |
@@ -573,6 +651,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
573 | 651 | ||
574 | -------------------------------------------------------------------------------- | 652 | -------------------------------------------------------------------------------- |
575 | 653 | ||
654 | gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) | ||
655 | => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) | ||
656 | -> c3 -> c2 -> c1 -> c -> IO () | ||
576 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" | 657 | gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" |
577 | 658 | ||
578 | type Tgemm x = x :> x ::> x ::> x ::> Ok | 659 | type Tgemm x = x :> x ::> x ::> x ::> Ok |
@@ -588,6 +669,10 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | |||
588 | 669 | ||
589 | -------------------------------------------------------------------------------- | 670 | -------------------------------------------------------------------------------- |
590 | 671 | ||
672 | reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => | ||
673 | (CInt -> Ptr a -> CInt -> Ptr t1 | ||
674 | -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) | ||
675 | -> Vector t1 -> c -> Vector t -> Vector a1 | ||
591 | reorderAux f s d v = unsafePerformIO $ do | 676 | reorderAux f s d v = unsafePerformIO $ do |
592 | k <- createVector (dim s) | 677 | k <- createVector (dim s) |
593 | r <- createVector (dim v) | 678 | r <- createVector (dim v) |