From 1a68793247b8845cefad4d157e4f4d25b1731b42 Mon Sep 17 00:00:00 2001 From: Dominic Steinitz Date: Fri, 30 Mar 2018 12:48:20 +0100 Subject: Implement CI --- packages/base/src/Internal/Matrix.hs | 87 +++++++++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 1 deletion(-) (limited to 'packages/base/src/Internal/Matrix.hs') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 4905f61..4bfa13d 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -57,19 +57,24 @@ cols :: Matrix t -> Int cols = icols {-# INLINE cols #-} +size :: Matrix t -> (Int, Int) size m = (irows m, icols m) {-# INLINE size #-} +rowOrder :: Matrix t -> Bool rowOrder m = xCol m == 1 || cols m == 1 {-# INLINE rowOrder #-} +colOrder :: Matrix t -> Bool colOrder m = xRow m == 1 || rows m == 1 {-# INLINE colOrder #-} +is1d :: Matrix t -> Bool is1d (size->(r,c)) = r==1 || c==1 {-# INLINE is1d #-} -- data is not contiguous +isSlice :: Storable t => Matrix t -> Bool isSlice m@(size->(r,c)) = r*c < dim (xdat m) {-# INLINE isSlice #-} @@ -136,16 +141,20 @@ instance Storable t => TransArray (Matrix t) {-# INLINE applyRaw #-} infixr 1 # +(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r a # b = apply a b {-# INLINE (#) #-} +(#!) :: (TransArray c, TransArray c1) => c1 -> c -> Trans c1 (Trans c (IO r)) -> IO r a #! b = a # b # id {-# INLINE (#!) #-} -------------------------------------------------------------------------------- +copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) +extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t extractAll ord m = unsafePerformIO (copy ord m) {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. @@ -223,11 +232,13 @@ m@Matrix {irows = r, icols = c} @@> (i,j) {-# INLINE (@@>) #-} -- Unsafe matrix access without range checking +atM' :: Storable t => Matrix t -> Int -> Int -> t atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m)) {-# INLINE atM' #-} ------------------------------------------------------------------ +matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 } matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d } matrixFromVector o r c v @@ -387,18 +398,21 @@ subMatrix (r0,c0) (rt,ct) m -------------------------------------------------------------------------- +maxZ :: (Num t1, Ord t1, Foldable t) => t t1 -> t1 maxZ xs = if minimum xs == 0 then 0 else maximum xs +conformMs :: Element t => [Matrix t] -> [Matrix t] conformMs ms = map (conformMTo (r,c)) ms where r = maxZ (map rows ms) c = maxZ (map cols ms) - +conformVs :: Element t => [Vector t] -> [Vector t] conformVs vs = map (conformVTo n) vs where n = maxZ (map dim vs) +conformMTo :: Element t => (Int, Int) -> Matrix t -> Matrix t conformMTo (r,c) m | size m == (r,c) = m | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c)) @@ -406,18 +420,24 @@ conformMTo (r,c) m | size m == (1,c) = repRows r m | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c) +conformVTo :: Element t => Int -> Vector t -> Vector t conformVTo n v | dim v == n = v | dim v == 1 = constantD (v@>0) n | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n +repRows :: Element t => Int -> Matrix t -> Matrix t repRows n x = fromRows (replicate n (flatten x)) +repCols :: Element t => Int -> Matrix t -> Matrix t repCols n x = fromColumns (replicate n (flatten x)) +shSize :: Matrix t -> [Char] shSize = shDim . size +shDim :: (Show a, Show a1) => (a1, a) -> [Char] shDim (r,c) = "(" ++ show r ++"x"++ show c ++")" +emptyM :: Storable t => Int -> Int -> Matrix t emptyM r c = matrixFromVector RowMajor r c (fromList[]) ---------------------------------------------------------------------- @@ -432,6 +452,11 @@ instance (Storable t, NFData t) => NFData (Matrix t) --------------------------------------------------------------- +extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1, + Storable t, Num t3, Num t2, Integral t1, Integral t) + => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t + -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)) + -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a) extractAux f ord m moder vr modec vc = do let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc @@ -451,6 +476,9 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z --------------------------------------------------------------- +setRectAux :: (TransArray c1, TransArray c) + => (CInt -> CInt -> Trans c1 (Trans c (IO CInt))) + -> Int -> Int -> c1 -> c -> IO () setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect" type SetRect x = I -> I -> x ::> x::> Ok @@ -464,19 +492,29 @@ foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z -------------------------------------------------------------------------------- +sortG :: (Storable t, Storable a) + => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a sortG f v = unsafePerformIO $ do r <- createVector (dim v) (v #! r) f #|"sortG" return r +sortIdxD :: Vector Double -> Vector CInt sortIdxD = sortG c_sort_indexD +sortIdxF :: Vector Float -> Vector CInt sortIdxF = sortG c_sort_indexF +sortIdxI :: Vector CInt -> Vector CInt sortIdxI = sortG c_sort_indexI +sortIdxL :: Vector Z -> Vector I sortIdxL = sortG c_sort_indexL +sortValD :: Vector Double -> Vector Double sortValD = sortG c_sort_valD +sortValF :: Vector Float -> Vector Float sortValF = sortG c_sort_valF +sortValI :: Vector CInt -> Vector CInt sortValI = sortG c_sort_valI +sortValL :: Vector Z -> Vector Z sortValL = sortG c_sort_valL foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt)) @@ -491,14 +529,21 @@ foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok -------------------------------------------------------------------------------- +compareG :: (TransArray c, Storable t, Storable a) + => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) + -> c -> Vector t -> Vector a compareG f u v = unsafePerformIO $ do r <- createVector (dim v) (u # v #! r) f #|"compareG" return r +compareD :: Vector Double -> Vector Double -> Vector CInt compareD = compareG c_compareD +compareF :: Vector Float -> Vector Float -> Vector CInt compareF = compareG c_compareF +compareI :: Vector CInt -> Vector CInt -> Vector CInt compareI = compareG c_compareI +compareL :: Vector Z -> Vector Z -> Vector CInt compareL = compareG c_compareL foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt))) @@ -508,16 +553,33 @@ foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok -------------------------------------------------------------------------------- +selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a) + => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt))) + -> c2 -> c1 -> Vector t -> c -> Vector a selectG f c u v w = unsafePerformIO $ do r <- createVector (dim v) (c # u # v # w #! r) f #|"selectG" return r +selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double selectD = selectG c_selectD +selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float selectF = selectG c_selectF +selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt selectI = selectG c_selectI +selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z selectL = selectG c_selectL +selectC :: Vector CInt + -> Vector (Complex Double) + -> Vector (Complex Double) + -> Vector (Complex Double) + -> Vector (Complex Double) selectC = selectG c_selectC +selectQ :: Vector CInt + -> Vector (Complex Float) + -> Vector (Complex Float) + -> Vector (Complex Float) + -> Vector (Complex Float) selectQ = selectG c_selectQ type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt))))) @@ -531,16 +593,29 @@ foreign import ccall unsafe "chooseL" c_selectL :: Sel Z --------------------------------------------------------------------------- +remapG :: (TransArray c, TransArray c1, Storable t, Storable a) + => (CInt -> CInt -> CInt -> CInt -> Ptr t + -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))) + -> Matrix t -> c1 -> c -> Matrix a remapG f i j m = unsafePerformIO $ do r <- createMatrix RowMajor (rows i) (cols i) (i # j # m #! r) f #|"remapG" return r +remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double remapD = remapG c_remapD +remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float remapF = remapG c_remapF +remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt remapI = remapG c_remapI +remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z remapL = remapG c_remapL +remapC :: Matrix CInt + -> Matrix CInt + -> Matrix (Complex Double) + -> Matrix (Complex Double) remapC = remapG c_remapC +remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float) remapQ = remapG c_remapQ type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt)))) @@ -554,6 +629,9 @@ foreign import ccall unsafe "remapL" c_remapL :: Rem Z -------------------------------------------------------------------------------- +rowOpAux :: (TransArray c, Storable a) => + (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt)) + -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO () rowOpAux f c x i1 i2 j1 j2 m = do px <- newArray [x] (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp" @@ -572,6 +650,9 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z -------------------------------------------------------------------------------- +gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3) + => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt)))) + -> c3 -> c2 -> c1 -> c -> IO () gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg" type Tgemm x = x :> x ::> x ::> x ::> Ok @@ -587,6 +668,10 @@ foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z -------------------------------------------------------------------------------- +reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) => + (CInt -> Ptr a -> CInt -> Ptr t1 + -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt)) + -> Vector t1 -> c -> Vector t -> Vector a1 reorderAux f s d v = unsafePerformIO $ do k <- createVector (dim s) r <- createVector (dim v) -- cgit v1.2.3