From 57487d828065ea219cdb33c9dc177b67c60b34c7 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Sun, 14 Jun 2015 19:49:10 +0200 Subject: minor changes --- packages/base/src/Internal/Matrix.hs | 10 +--- packages/base/src/Internal/ST.hs | 59 ++++++++++++++++++---- packages/base/src/Internal/Static.hs | 3 ++ packages/base/src/Internal/Util.hs | 44 ++++++++-------- packages/base/src/Numeric/LinearAlgebra.hs | 10 ++-- packages/base/src/Numeric/LinearAlgebra/Data.hs | 3 +- packages/base/src/Numeric/LinearAlgebra/Devel.hs | 7 ++- packages/base/src/Numeric/LinearAlgebra/HMatrix.hs | 4 +- packages/base/src/Numeric/LinearAlgebra/Static.hs | 6 +-- 9 files changed, 91 insertions(+), 55 deletions(-) (limited to 'packages/base/src') diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index e0f5ed2..e4b1226 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs @@ -262,15 +262,7 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 ------------------------------------------------------------------ -{- | Supported matrix elements. - - This class provides optimized internal - operations for selected element types. - It provides unoptimised defaults for any 'Storable' type, - so you can create instances simply as: - - >instance Element Foo --} +-- | Supported matrix elements. class (Storable a) => Element a where transdata :: Int -> Vector a -> Int -> Vector a constantD :: a -> Int -> Vector a diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index a84ca25..434fe63 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs @@ -10,7 +10,7 @@ -- Stability : provisional -- -- In-place manipulation inside the ST monad. --- See examples/inplace.hs in the distribution. +-- See @examples/inplace.hs@ in the repository. -- ----------------------------------------------------------------------------- @@ -21,8 +21,8 @@ module Internal.ST ( -- * Mutable Matrices STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, - axpy, scal, swap, extractMatrix, setMatrix, rowOpST, - mutable, +-- axpy, scal, swap, rowOp, + mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), -- * Unsafe functions newUndefinedVector, unsafeReadVector, unsafeWriteVector, @@ -178,16 +178,55 @@ newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) -------------------------------------------------------------------------------- -rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s () -rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m) +data ColRange = AllCols + | ColRange Int Int + | Col Int + | FromCol Int -axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) -scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) -swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) +getColRange c AllCols = (0,c-1) +getColRange c (ColRange a b) = (a `mod` c, b `mod` c) +getColRange c (Col a) = (a `mod` c, a `mod` c) +getColRange c (FromCol a) = (a `mod` c, c-1) -extractMatrix (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) +data RowRange = AllRows + | RowRange Int Int + | Row Int + | FromRow Int + +getRowRange r AllRows = (0,r-1) +getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) +getRowRange r (Row a) = (a `mod` r, a `mod` r) +getRowRange r (FromRow a) = (a `mod` r, r-1) + +data RowOper t = AXPY t Int Int ColRange + | SCAL t RowRange ColRange + | SWAP Int Int ColRange + +rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s () + +rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m + where + (j1,j2) = getColRange (cols m) r + i1' = i1 `mod` (rows m) + i2' = i2 `mod` (rows m) + +rowOper (SCAL x rr rc) (STMatrix m) = unsafeIOToST $ rowOp 1 x i1 i2 j1 j2 m + where + (i1,i2) = getRowRange (rows m) rr + (j1,j2) = getColRange (cols m) rc + +rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m + where + (j1,j2) = getColRange (cols m) r + i1' = i1 `mod` (rows m) + i2' = i2 `mod` (rows m) + + +extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) + where + (i1,i2) = getRowRange (rows m) rr + (j1,j2) = getColRange (cols m) rc --------------------------------------------------------------------------------- mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) mutable f a = runST $ do diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs index 01c2205..0068313 100644 --- a/packages/base/src/Internal/Static.hs +++ b/packages/base/src/Internal/Static.hs @@ -34,6 +34,9 @@ import Text.Printf -------------------------------------------------------------------------------- +type ℝ = Double +type ℂ = Complex Double + newtype Dim (n :: Nat) t = Dim t deriving Show diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 2650ac8..09ba21c 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs @@ -65,7 +65,7 @@ import Internal.Element import Internal.Container import Internal.Vectorized import Internal.IO -import Internal.Algorithms hiding (i,Normed,swap) +import Internal.Algorithms hiding (i,Normed,swap,linearSolve') import Numeric.Matrix() import Numeric.Vector() import Internal.Random @@ -155,7 +155,7 @@ infixl 3 & (&) :: Vector Double -> Vector Double -> Vector Double a & b = vjoin [a,b] -{- | horizontal concatenation of real matrices +{- | horizontal concatenation >>> ident 3 ||| konst 7 (3,4) (3><7) @@ -165,7 +165,7 @@ a & b = vjoin [a,b] -} infixl 3 ||| -(|||) :: Matrix Double -> Matrix Double -> Matrix Double +(|||) :: Element t => Matrix t -> Matrix t -> Matrix t a ||| b = fromBlocks [[a,b]] -- | a synonym for ('|||') (unicode 0x00a6, broken bar) @@ -174,9 +174,9 @@ infixl 3 ¦ (¦) = (|||) --- | vertical concatenation of real matrices +-- | vertical concatenation -- -(===) :: Matrix Double -> Matrix Double -> Matrix Double +(===) :: Element t => Matrix t -> Matrix t -> Matrix t infixl 2 === a === b = fromBlocks [[a],[b]] @@ -588,7 +588,7 @@ gaussElim_2 a b = flipudrl r where flipudrl = flipud . fliprl splitColsAt n = (takeColumns n &&& dropColumns n) - go f x y = splitColsAt (cols a) (down f $ fromBlocks [[x,y]]) + go f x y = splitColsAt (cols a) (down f $ x ||| y) (a1,b1) = go (snd . swapMax 0) a b ( _, r) = go id (flipudrl $ a1) (flipudrl $ b1) @@ -600,7 +600,7 @@ gaussElim_1 gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2) where - rs = toRows $ fromBlocks [[x , y]] + rs = toRows $ x ||| y s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting s2 = pivotUp (rows x-1) (toRows $ flipud s1) @@ -637,12 +637,15 @@ pivotUp n xs -------------------------------------------------------------------------------- -gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]]) +gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (a ||| b) gaussST (r,_) x = do let n = r-1 + axpy m a i j = rowOper (AXPY a i j AllCols) m + swap m i j = rowOper (SWAP i j AllCols) m + scal m a i = rowOper (SCAL a (Row i) AllCols) m forM_ [0..n] $ \i -> do - c <- maxIndex . abs . flatten <$> extractMatrix x i n i i + c <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i) swap x i (i+c) a <- readMatrix x i i when (a == 0) $ error "singular!" @@ -656,22 +659,23 @@ gaussST (r,_) x = do axpy x (-b) i j -luST ok (r,c) x = do - let n = r-1 - axpy' m a i j = rowOpST 0 a i j (i+1) (c-1) m - p <- thawMatrix . asColumn . range $ r - forM_ [0..n] $ \i -> do - k <- maxIndex . abs . flatten <$> extractMatrix x i n i i - writeMatrix p i 0 (fi (k+i)) + +luST ok (r,_) x = do + let axpy m a i j = rowOper (AXPY a i j (FromCol (i+1))) m + swap m i j = rowOper (SWAP i j AllCols) m + p <- newUndefinedVector r + forM_ [0..r-1] $ \i -> do + k <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i) + writeVector p i (k+i) swap x i (i+k) a <- readMatrix x i i when (ok a) $ do - forM_ [i+1..n] $ \j -> do + forM_ [i+1..r-1] $ \j -> do b <- (/a) <$> readMatrix x j i - axpy' x (-b) i j + axpy x (-b) i j writeMatrix x j i b - v <- unsafeFreezeMatrix p - return (map ti $ toList $ flatten v) + v <- unsafeFreezeVector p + return (toList v) -------------------------------------------------------------------------------- diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 0f8efa4..fe524cc 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs @@ -80,6 +80,7 @@ module Numeric.LinearAlgebra ( cholSolve, cgSolve, cgSolve', + linearSolve', -- * Inverse and pseudoinverse inv, pinv, pinvTol, @@ -136,8 +137,9 @@ module Numeric.LinearAlgebra ( Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, -- * Misc - meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, luST, magnit, - ℝ,ℂ,iC, + meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, magnit, + haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, + iC, -- * Auxiliary classes Element, Container, Product, Numeric, LSDiv, Complexable, RealElement, @@ -156,7 +158,7 @@ import Numeric.Vector() import Internal.Matrix import Internal.Container hiding ((<>)) import Internal.Numeric hiding (mul) -import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked') +import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked',linearSolve') import qualified Internal.Algorithms as A import Internal.Util import Internal.Random @@ -240,3 +242,5 @@ orth m = orthSVD (Left (1*eps)) m (leftSV m) luPacked' x = mutable (luST (magnit 0)) x +linearSolve' x y = gaussElim x y + diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 1c9bb68..fffc2bd 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs @@ -53,8 +53,7 @@ module Numeric.LinearAlgebra.Data( -- * Matrix extraction Extractor(..), (??), - takeRows, takeLastRows, dropRows, dropLastRows, - takeColumns, takeLastColumns, dropColumns, dropLastColumns, + takeRows, dropRows, takeColumns, dropColumns, subMatrix, (?), (¿), fliprl, flipud, remap, -- * Block matrix diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index f572656..36c5f03 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs @@ -20,8 +20,7 @@ module Numeric.LinearAlgebra.Devel( module Internal.Foreign, -- * FFI tools - -- | Illustrative usage examples can be found - -- in the @examples\/devel@ folder included in the package. + -- | See @examples/devel@ in the repository. createVector, createMatrix, vec, mat, omat, @@ -36,7 +35,7 @@ module Numeric.LinearAlgebra.Devel( -- * ST -- | In-place manipulation inside the ST monad. - -- See examples\/inplace.hs in the distribution. + -- See @examples/inplace.hs@ in the repository. -- ** Mutable Vectors STVector, newVector, thawVector, freezeVector, runSTVector, @@ -44,7 +43,7 @@ module Numeric.LinearAlgebra.Devel( -- ** Mutable Matrices STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, - axpy,scal,swap, extractMatrix, setMatrix, mutable, rowOpST, + mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), -- ** Unsafe functions newUndefinedVector, unsafeReadVector, unsafeWriteVector, diff --git a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs index 327f284..11c2487 100644 --- a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs +++ b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs @@ -13,10 +13,10 @@ compatibility with previous version, to be removed module Numeric.LinearAlgebra.HMatrix ( module Numeric.LinearAlgebra, - (¦),(——) + (¦),(——),ℝ,ℂ, ) where import Numeric.LinearAlgebra import Internal.Util - + diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index dee5b2c..a657bd0 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs @@ -28,7 +28,7 @@ This module is under active development and the interface is subject to changes. module Numeric.LinearAlgebra.Static( -- * Vector - ℝ, R, + ℝ, R, vec2, vec3, vec4, (&), (#), split, headTail, vector, linspace, range, dim, @@ -71,10 +71,6 @@ import Data.Proxy(Proxy) import Internal.Static import Control.Arrow((***)) - - - - ud1 :: R n -> Vector ℝ ud1 (R (Dim v)) = v -- cgit v1.2.3