From 57487d828065ea219cdb33c9dc177b67c60b34c7 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sun, 14 Jun 2015 19:49:10 +0200 Subject: minor changes --- packages/base/src/Internal/Matrix.hs | 10 +----- packages/base/src/Internal/ST.hs | 59 ++++++++++++++++++++++++++++++------ packages/base/src/Internal/Static.hs | 3 ++ packages/base/src/Internal/Util.hs | 44 +++++++++++++++------------ 4 files changed, 77 insertions(+), 39 deletions(-) (limited to 'packages/base/src/Internal') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index e0f5ed2..e4b1226 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -262,15 +262,7 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 ------------------------------------------------------------------ -{- | Supported matrix elements. - - This class provides optimized internal - operations for selected element types. - It provides unoptimised defaults for any 'Storable' type, - so you can create instances simply as: - - >instance Element Foo --} +-- | Supported matrix elements. class (Storable a) => Element a where transdata :: Int -> Vector a -> Int -> Vector a constantD :: a -> Int -> Vector a diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index a84ca25..434fe63 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -10,7 +10,7 @@ -- Stability : provisional -- -- In-place manipulation inside the ST monad. --- See examples/inplace.hs in the distribution. +-- See @examples/inplace.hs@ in the repository. -- ----------------------------------------------------------------------------- @@ -21,8 +21,8 @@ module Internal.ST ( -- * Mutable Matrices STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, - axpy, scal, swap, extractMatrix, setMatrix, rowOpST, - mutable, +-- axpy, scal, swap, rowOp, + mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), -- * Unsafe functions newUndefinedVector, unsafeReadVector, unsafeWriteVector, @@ -178,16 +178,55 @@ newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) -------------------------------------------------------------------------------- -rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s () -rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m) +data ColRange = AllCols + | ColRange Int Int + | Col Int + | FromCol Int -axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) -scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) -swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) +getColRange c AllCols = (0,c-1) +getColRange c (ColRange a b) = (a `mod` c, b `mod` c) +getColRange c (Col a) = (a `mod` c, a `mod` c) +getColRange c (FromCol a) = (a `mod` c, c-1) -extractMatrix (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) +data RowRange = AllRows + | RowRange Int Int + | Row Int + | FromRow Int + +getRowRange r AllRows = (0,r-1) +getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) +getRowRange r (Row a) = (a `mod` r, a `mod` r) +getRowRange r (FromRow a) = (a `mod` r, r-1) + +data RowOper t = AXPY t Int Int ColRange + | SCAL t RowRange ColRange + | SWAP Int Int ColRange + +rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s () + +rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m + where + (j1,j2) = getColRange (cols m) r + i1' = i1 `mod` (rows m) + i2' = i2 `mod` (rows m) + +rowOper (SCAL x rr rc) (STMatrix m) = unsafeIOToST $ rowOp 1 x i1 i2 j1 j2 m + where + (i1,i2) = getRowRange (rows m) rr + (j1,j2) = getColRange (cols m) rc + +rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m + where + (j1,j2) = getColRange (cols m) r + i1' = i1 `mod` (rows m) + i2' = i2 `mod` (rows m) + + +extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) + where + (i1,i2) = getRowRange (rows m) rr + (j1,j2) = getColRange (cols m) rc --------------------------------------------------------------------------------- mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) mutable f a = runST $ do diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs index 01c2205..0068313 100644 --- a/packages/base/src/Internal/Static.hs +++ b/packages/base/src/Internal/Static.hs @@ -34,6 +34,9 @@ import Text.Printf -------------------------------------------------------------------------------- +type ℝ = Double +type ℂ = Complex Double + newtype Dim (n :: Nat) t = Dim t deriving Show diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 2650ac8..09ba21c 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs @@ -65,7 +65,7 @@ import Internal.Element import Internal.Container import Internal.Vectorized import Internal.IO -import Internal.Algorithms hiding (i,Normed,swap) +import Internal.Algorithms hiding (i,Normed,swap,linearSolve') import Numeric.Matrix() import Numeric.Vector() import Internal.Random @@ -155,7 +155,7 @@ infixl 3 & (&) :: Vector Double -> Vector Double -> Vector Double a & b = vjoin [a,b] -{- | horizontal concatenation of real matrices +{- | horizontal concatenation >>> ident 3 ||| konst 7 (3,4) (3><7) @@ -165,7 +165,7 @@ a & b = vjoin [a,b] -} infixl 3 ||| -(|||) :: Matrix Double -> Matrix Double -> Matrix Double +(|||) :: Element t => Matrix t -> Matrix t -> Matrix t a ||| b = fromBlocks [[a,b]] -- | a synonym for ('|||') (unicode 0x00a6, broken bar) @@ -174,9 +174,9 @@ infixl 3 ¦ (¦) = (|||) --- | vertical concatenation of real matrices +-- | vertical concatenation -- -(===) :: Matrix Double -> Matrix Double -> Matrix Double +(===) :: Element t => Matrix t -> Matrix t -> Matrix t infixl 2 === a === b = fromBlocks [[a],[b]] @@ -588,7 +588,7 @@ gaussElim_2 a b = flipudrl r where flipudrl = flipud . fliprl splitColsAt n = (takeColumns n &&& dropColumns n) - go f x y = splitColsAt (cols a) (down f $ fromBlocks [[x,y]]) + go f x y = splitColsAt (cols a) (down f $ x ||| y) (a1,b1) = go (snd . swapMax 0) a b ( _, r) = go id (flipudrl $ a1) (flipudrl $ b1) @@ -600,7 +600,7 @@ gaussElim_1 gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2) where - rs = toRows $ fromBlocks [[x , y]] + rs = toRows $ x ||| y s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting s2 = pivotUp (rows x-1) (toRows $ flipud s1) @@ -637,12 +637,15 @@ pivotUp n xs -------------------------------------------------------------------------------- -gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]]) +gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (a ||| b) gaussST (r,_) x = do let n = r-1 + axpy m a i j = rowOper (AXPY a i j AllCols) m + swap m i j = rowOper (SWAP i j AllCols) m + scal m a i = rowOper (SCAL a (Row i) AllCols) m forM_ [0..n] $ \i -> do - c <- maxIndex . abs . flatten <$> extractMatrix x i n i i + c <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i) swap x i (i+c) a <- readMatrix x i i when (a == 0) $ error "singular!" @@ -656,22 +659,23 @@ gaussST (r,_) x = do axpy x (-b) i j -luST ok (r,c) x = do - let n = r-1 - axpy' m a i j = rowOpST 0 a i j (i+1) (c-1) m - p <- thawMatrix . asColumn . range $ r - forM_ [0..n] $ \i -> do - k <- maxIndex . abs . flatten <$> extractMatrix x i n i i - writeMatrix p i 0 (fi (k+i)) + +luST ok (r,_) x = do + let axpy m a i j = rowOper (AXPY a i j (FromCol (i+1))) m + swap m i j = rowOper (SWAP i j AllCols) m + p <- newUndefinedVector r + forM_ [0..r-1] $ \i -> do + k <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i) + writeVector p i (k+i) swap x i (i+k) a <- readMatrix x i i when (ok a) $ do - forM_ [i+1..n] $ \j -> do + forM_ [i+1..r-1] $ \j -> do b <- (/a) <$> readMatrix x j i - axpy' x (-b) i j + axpy x (-b) i j writeMatrix x j i b - v <- unsafeFreezeMatrix p - return (map ti $ toList $ flatten v) + v <- unsafeFreezeVector p + return (toList v) -------------------------------------------------------------------------------- -- cgit v1.2.3