summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--packages/base/CHANGELOG5
-rw-r--r--packages/base/src/Internal/Matrix.hs10
-rw-r--r--packages/base/src/Internal/ST.hs59
-rw-r--r--packages/base/src/Internal/Static.hs3
-rw-r--r--packages/base/src/Internal/Util.hs44
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs10
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Data.hs3
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Devel.hs7
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/HMatrix.hs4
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs6
-rw-r--r--packages/tests/src/Numeric/LinearAlgebra/Tests.hs1
11 files changed, 96 insertions, 56 deletions
diff --git a/packages/base/CHANGELOG b/packages/base/CHANGELOG
index 27c4d31..93b2594 100644
--- a/packages/base/CHANGELOG
+++ b/packages/base/CHANGELOG
@@ -7,7 +7,10 @@
7 7
8 * remap, ccompare, sortIndex 8 * remap, ccompare, sortIndex
9 9
10 * experimental support of type safe modular arithmetic 10 * experimental support of type safe modular arithmetic, including linear
11 systems and lu factorization
12
13 * elementary row operations in ST monad
11 14
12 * old compatibility modules removed 15 * old compatibility modules removed
13 16
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index e0f5ed2..e4b1226 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -262,15 +262,7 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
262 262
263------------------------------------------------------------------ 263------------------------------------------------------------------
264 264
265{- | Supported matrix elements. 265-- | Supported matrix elements.
266
267 This class provides optimized internal
268 operations for selected element types.
269 It provides unoptimised defaults for any 'Storable' type,
270 so you can create instances simply as:
271
272 >instance Element Foo
273-}
274class (Storable a) => Element a where 266class (Storable a) => Element a where
275 transdata :: Int -> Vector a -> Int -> Vector a 267 transdata :: Int -> Vector a -> Int -> Vector a
276 constantD :: a -> Int -> Vector a 268 constantD :: a -> Int -> Vector a
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index a84ca25..434fe63 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -10,7 +10,7 @@
10-- Stability : provisional 10-- Stability : provisional
11-- 11--
12-- In-place manipulation inside the ST monad. 12-- In-place manipulation inside the ST monad.
13-- See examples/inplace.hs in the distribution. 13-- See @examples/inplace.hs@ in the repository.
14-- 14--
15----------------------------------------------------------------------------- 15-----------------------------------------------------------------------------
16 16
@@ -21,8 +21,8 @@ module Internal.ST (
21 -- * Mutable Matrices 21 -- * Mutable Matrices
22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24 axpy, scal, swap, extractMatrix, setMatrix, rowOpST, 24-- axpy, scal, swap, rowOp,
25 mutable, 25 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..),
26 -- * Unsafe functions 26 -- * Unsafe functions
27 newUndefinedVector, 27 newUndefinedVector,
28 unsafeReadVector, unsafeWriteVector, 28 unsafeReadVector, unsafeWriteVector,
@@ -178,16 +178,55 @@ newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
178 178
179-------------------------------------------------------------------------------- 179--------------------------------------------------------------------------------
180 180
181rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s () 181data ColRange = AllCols
182rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m) 182 | ColRange Int Int
183 | Col Int
184 | FromCol Int
183 185
184axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) 186getColRange c AllCols = (0,c-1)
185scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) 187getColRange c (ColRange a b) = (a `mod` c, b `mod` c)
186swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) 188getColRange c (Col a) = (a `mod` c, a `mod` c)
189getColRange c (FromCol a) = (a `mod` c, c-1)
187 190
188extractMatrix (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) 191data RowRange = AllRows
192 | RowRange Int Int
193 | Row Int
194 | FromRow Int
195
196getRowRange r AllRows = (0,r-1)
197getRowRange r (RowRange a b) = (a `mod` r, b `mod` r)
198getRowRange r (Row a) = (a `mod` r, a `mod` r)
199getRowRange r (FromRow a) = (a `mod` r, r-1)
200
201data RowOper t = AXPY t Int Int ColRange
202 | SCAL t RowRange ColRange
203 | SWAP Int Int ColRange
204
205rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s ()
206
207rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m
208 where
209 (j1,j2) = getColRange (cols m) r
210 i1' = i1 `mod` (rows m)
211 i2' = i2 `mod` (rows m)
212
213rowOper (SCAL x rr rc) (STMatrix m) = unsafeIOToST $ rowOp 1 x i1 i2 j1 j2 m
214 where
215 (i1,i2) = getRowRange (rows m) rr
216 (j1,j2) = getColRange (cols m) rc
217
218rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
219 where
220 (j1,j2) = getColRange (cols m) r
221 i1' = i1 `mod` (rows m)
222 i2' = i2 `mod` (rows m)
223
224
225extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
226 where
227 (i1,i2) = getRowRange (rows m) rr
228 (j1,j2) = getColRange (cols m) rc
189 229
190--------------------------------------------------------------------------------
191 230
192mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 231mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
193mutable f a = runST $ do 232mutable f a = runST $ do
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs
index 01c2205..0068313 100644
--- a/packages/base/src/Internal/Static.hs
+++ b/packages/base/src/Internal/Static.hs
@@ -34,6 +34,9 @@ import Text.Printf
34 34
35-------------------------------------------------------------------------------- 35--------------------------------------------------------------------------------
36 36
37type ℝ = Double
38type ℂ = Complex Double
39
37newtype Dim (n :: Nat) t = Dim t 40newtype Dim (n :: Nat) t = Dim t
38 deriving Show 41 deriving Show
39 42
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index 2650ac8..09ba21c 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -65,7 +65,7 @@ import Internal.Element
65import Internal.Container 65import Internal.Container
66import Internal.Vectorized 66import Internal.Vectorized
67import Internal.IO 67import Internal.IO
68import Internal.Algorithms hiding (i,Normed,swap) 68import Internal.Algorithms hiding (i,Normed,swap,linearSolve')
69import Numeric.Matrix() 69import Numeric.Matrix()
70import Numeric.Vector() 70import Numeric.Vector()
71import Internal.Random 71import Internal.Random
@@ -155,7 +155,7 @@ infixl 3 &
155(&) :: Vector Double -> Vector Double -> Vector Double 155(&) :: Vector Double -> Vector Double -> Vector Double
156a & b = vjoin [a,b] 156a & b = vjoin [a,b]
157 157
158{- | horizontal concatenation of real matrices 158{- | horizontal concatenation
159 159
160>>> ident 3 ||| konst 7 (3,4) 160>>> ident 3 ||| konst 7 (3,4)
161(3><7) 161(3><7)
@@ -165,7 +165,7 @@ a & b = vjoin [a,b]
165 165
166-} 166-}
167infixl 3 ||| 167infixl 3 |||
168(|||) :: Matrix Double -> Matrix Double -> Matrix Double 168(|||) :: Element t => Matrix t -> Matrix t -> Matrix t
169a ||| b = fromBlocks [[a,b]] 169a ||| b = fromBlocks [[a,b]]
170 170
171-- | a synonym for ('|||') (unicode 0x00a6, broken bar) 171-- | a synonym for ('|||') (unicode 0x00a6, broken bar)
@@ -174,9 +174,9 @@ infixl 3 ¦
174(¦) = (|||) 174(¦) = (|||)
175 175
176 176
177-- | vertical concatenation of real matrices 177-- | vertical concatenation
178-- 178--
179(===) :: Matrix Double -> Matrix Double -> Matrix Double 179(===) :: Element t => Matrix t -> Matrix t -> Matrix t
180infixl 2 === 180infixl 2 ===
181a === b = fromBlocks [[a],[b]] 181a === b = fromBlocks [[a],[b]]
182 182
@@ -588,7 +588,7 @@ gaussElim_2 a b = flipudrl r
588 where 588 where
589 flipudrl = flipud . fliprl 589 flipudrl = flipud . fliprl
590 splitColsAt n = (takeColumns n &&& dropColumns n) 590 splitColsAt n = (takeColumns n &&& dropColumns n)
591 go f x y = splitColsAt (cols a) (down f $ fromBlocks [[x,y]]) 591 go f x y = splitColsAt (cols a) (down f $ x ||| y)
592 (a1,b1) = go (snd . swapMax 0) a b 592 (a1,b1) = go (snd . swapMax 0) a b
593 ( _, r) = go id (flipudrl $ a1) (flipudrl $ b1) 593 ( _, r) = go id (flipudrl $ a1) (flipudrl $ b1)
594 594
@@ -600,7 +600,7 @@ gaussElim_1
600 600
601gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2) 601gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2)
602 where 602 where
603 rs = toRows $ fromBlocks [[x , y]] 603 rs = toRows $ x ||| y
604 s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting 604 s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting
605 s2 = pivotUp (rows x-1) (toRows $ flipud s1) 605 s2 = pivotUp (rows x-1) (toRows $ flipud s1)
606 606
@@ -637,12 +637,15 @@ pivotUp n xs
637 637
638-------------------------------------------------------------------------------- 638--------------------------------------------------------------------------------
639 639
640gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]]) 640gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (a ||| b)
641 641
642gaussST (r,_) x = do 642gaussST (r,_) x = do
643 let n = r-1 643 let n = r-1
644 axpy m a i j = rowOper (AXPY a i j AllCols) m
645 swap m i j = rowOper (SWAP i j AllCols) m
646 scal m a i = rowOper (SCAL a (Row i) AllCols) m
644 forM_ [0..n] $ \i -> do 647 forM_ [0..n] $ \i -> do
645 c <- maxIndex . abs . flatten <$> extractMatrix x i n i i 648 c <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i)
646 swap x i (i+c) 649 swap x i (i+c)
647 a <- readMatrix x i i 650 a <- readMatrix x i i
648 when (a == 0) $ error "singular!" 651 when (a == 0) $ error "singular!"
@@ -656,22 +659,23 @@ gaussST (r,_) x = do
656 axpy x (-b) i j 659 axpy x (-b) i j
657 660
658 661
659luST ok (r,c) x = do 662
660 let n = r-1 663luST ok (r,_) x = do
661 axpy' m a i j = rowOpST 0 a i j (i+1) (c-1) m 664 let axpy m a i j = rowOper (AXPY a i j (FromCol (i+1))) m
662 p <- thawMatrix . asColumn . range $ r 665 swap m i j = rowOper (SWAP i j AllCols) m
663 forM_ [0..n] $ \i -> do 666 p <- newUndefinedVector r
664 k <- maxIndex . abs . flatten <$> extractMatrix x i n i i 667 forM_ [0..r-1] $ \i -> do
665 writeMatrix p i 0 (fi (k+i)) 668 k <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i)
669 writeVector p i (k+i)
666 swap x i (i+k) 670 swap x i (i+k)
667 a <- readMatrix x i i 671 a <- readMatrix x i i
668 when (ok a) $ do 672 when (ok a) $ do
669 forM_ [i+1..n] $ \j -> do 673 forM_ [i+1..r-1] $ \j -> do
670 b <- (/a) <$> readMatrix x j i 674 b <- (/a) <$> readMatrix x j i
671 axpy' x (-b) i j 675 axpy x (-b) i j
672 writeMatrix x j i b 676 writeMatrix x j i b
673 v <- unsafeFreezeMatrix p 677 v <- unsafeFreezeVector p
674 return (map ti $ toList $ flatten v) 678 return (toList v)
675 679
676 680
677-------------------------------------------------------------------------------- 681--------------------------------------------------------------------------------
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index 0f8efa4..fe524cc 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -80,6 +80,7 @@ module Numeric.LinearAlgebra (
80 cholSolve, 80 cholSolve,
81 cgSolve, 81 cgSolve,
82 cgSolve', 82 cgSolve',
83 linearSolve',
83 84
84 -- * Inverse and pseudoinverse 85 -- * Inverse and pseudoinverse
85 inv, pinv, pinvTol, 86 inv, pinv, pinvTol,
@@ -136,8 +137,9 @@ module Numeric.LinearAlgebra (
136 Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, 137 Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample,
137 138
138 -- * Misc 139 -- * Misc
139 meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, luST, magnit, 140 meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, magnit,
140 ℝ,ℂ,iC, 141 haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv,
142 iC,
141 -- * Auxiliary classes 143 -- * Auxiliary classes
142 Element, Container, Product, Numeric, LSDiv, 144 Element, Container, Product, Numeric, LSDiv,
143 Complexable, RealElement, 145 Complexable, RealElement,
@@ -156,7 +158,7 @@ import Numeric.Vector()
156import Internal.Matrix 158import Internal.Matrix
157import Internal.Container hiding ((<>)) 159import Internal.Container hiding ((<>))
158import Internal.Numeric hiding (mul) 160import Internal.Numeric hiding (mul)
159import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked') 161import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked',linearSolve')
160import qualified Internal.Algorithms as A 162import qualified Internal.Algorithms as A
161import Internal.Util 163import Internal.Util
162import Internal.Random 164import Internal.Random
@@ -240,3 +242,5 @@ orth m = orthSVD (Left (1*eps)) m (leftSV m)
240 242
241luPacked' x = mutable (luST (magnit 0)) x 243luPacked' x = mutable (luST (magnit 0)) x
242 244
245linearSolve' x y = gaussElim x y
246
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs
index 1c9bb68..fffc2bd 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Data.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs
@@ -53,8 +53,7 @@ module Numeric.LinearAlgebra.Data(
53 53
54 -- * Matrix extraction 54 -- * Matrix extraction
55 Extractor(..), (??), 55 Extractor(..), (??),
56 takeRows, takeLastRows, dropRows, dropLastRows, 56 takeRows, dropRows, takeColumns, dropColumns,
57 takeColumns, takeLastColumns, dropColumns, dropLastColumns,
58 subMatrix, (?), (¿), fliprl, flipud, remap, 57 subMatrix, (?), (¿), fliprl, flipud, remap,
59 58
60 -- * Block matrix 59 -- * Block matrix
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
index f572656..36c5f03 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
@@ -20,8 +20,7 @@ module Numeric.LinearAlgebra.Devel(
20 module Internal.Foreign, 20 module Internal.Foreign,
21 21
22 -- * FFI tools 22 -- * FFI tools
23 -- | Illustrative usage examples can be found 23 -- | See @examples/devel@ in the repository.
24 -- in the @examples\/devel@ folder included in the package.
25 24
26 createVector, createMatrix, 25 createVector, createMatrix,
27 vec, mat, omat, 26 vec, mat, omat,
@@ -36,7 +35,7 @@ module Numeric.LinearAlgebra.Devel(
36 35
37 -- * ST 36 -- * ST
38 -- | In-place manipulation inside the ST monad. 37 -- | In-place manipulation inside the ST monad.
39 -- See examples\/inplace.hs in the distribution. 38 -- See @examples/inplace.hs@ in the repository.
40 39
41 -- ** Mutable Vectors 40 -- ** Mutable Vectors
42 STVector, newVector, thawVector, freezeVector, runSTVector, 41 STVector, newVector, thawVector, freezeVector, runSTVector,
@@ -44,7 +43,7 @@ module Numeric.LinearAlgebra.Devel(
44 -- ** Mutable Matrices 43 -- ** Mutable Matrices
45 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 44 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
46 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 45 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
47 axpy,scal,swap, extractMatrix, setMatrix, mutable, rowOpST, 46 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..),
48 -- ** Unsafe functions 47 -- ** Unsafe functions
49 newUndefinedVector, 48 newUndefinedVector,
50 unsafeReadVector, unsafeWriteVector, 49 unsafeReadVector, unsafeWriteVector,
diff --git a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs
index 327f284..11c2487 100644
--- a/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/HMatrix.hs
@@ -13,10 +13,10 @@ compatibility with previous version, to be removed
13 13
14module Numeric.LinearAlgebra.HMatrix ( 14module Numeric.LinearAlgebra.HMatrix (
15 module Numeric.LinearAlgebra, 15 module Numeric.LinearAlgebra,
16 (¦),(——) 16 (¦),(——),ℝ,ℂ,
17) where 17) where
18 18
19import Numeric.LinearAlgebra 19import Numeric.LinearAlgebra
20import Internal.Util 20import Internal.Util
21 21
22 22
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index dee5b2c..a657bd0 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -28,7 +28,7 @@ This module is under active development and the interface is subject to changes.
28 28
29module Numeric.LinearAlgebra.Static( 29module Numeric.LinearAlgebra.Static(
30 -- * Vector 30 -- * Vector
31 ℝ, R, 31 ℝ, R,
32 vec2, vec3, vec4, (&), (#), split, headTail, 32 vec2, vec3, vec4, (&), (#), split, headTail,
33 vector, 33 vector,
34 linspace, range, dim, 34 linspace, range, dim,
@@ -71,10 +71,6 @@ import Data.Proxy(Proxy)
71import Internal.Static 71import Internal.Static
72import Control.Arrow((***)) 72import Control.Arrow((***))
73 73
74
75
76
77
78ud1 :: R n -> Vector ℝ 74ud1 :: R n -> Vector ℝ
79ud1 (R (Dim v)) = v 75ud1 (R (Dim v)) = v
80 76
diff --git a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
index ffa45e7..148bbb9 100644
--- a/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
+++ b/packages/tests/src/Numeric/LinearAlgebra/Tests.hs
@@ -772,6 +772,7 @@ luBenchN f n x msg = do
772 time (msg ++ " "++ show n) (f m) 772 time (msg ++ " "++ show n) (f m)
773 773
774luBench = do 774luBench = do
775 putStrLn ""
775 luBenchN luPacked 1000 (5::R) "luPacked Double " 776 luBenchN luPacked 1000 (5::R) "luPacked Double "
776 luBenchN luPacked' 1000 (5::R) "luPacked' Double " 777 luBenchN luPacked' 1000 (5::R) "luPacked' Double "
777 luBenchN luPacked' 1000 (5::Mod 9973 I) "luPacked' I mod 9973" 778 luBenchN luPacked' 1000 (5::Mod 9973 I) "luPacked' I mod 9973"