summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r--packages/base/src/Internal/Matrix.hs10
-rw-r--r--packages/base/src/Internal/ST.hs59
-rw-r--r--packages/base/src/Internal/Static.hs3
-rw-r--r--packages/base/src/Internal/Util.hs44
4 files changed, 77 insertions, 39 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index e0f5ed2..e4b1226 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -262,15 +262,7 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
262 262
263------------------------------------------------------------------ 263------------------------------------------------------------------
264 264
265{- | Supported matrix elements. 265-- | Supported matrix elements.
266
267 This class provides optimized internal
268 operations for selected element types.
269 It provides unoptimised defaults for any 'Storable' type,
270 so you can create instances simply as:
271
272 >instance Element Foo
273-}
274class (Storable a) => Element a where 266class (Storable a) => Element a where
275 transdata :: Int -> Vector a -> Int -> Vector a 267 transdata :: Int -> Vector a -> Int -> Vector a
276 constantD :: a -> Int -> Vector a 268 constantD :: a -> Int -> Vector a
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index a84ca25..434fe63 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -10,7 +10,7 @@
10-- Stability : provisional 10-- Stability : provisional
11-- 11--
12-- In-place manipulation inside the ST monad. 12-- In-place manipulation inside the ST monad.
13-- See examples/inplace.hs in the distribution. 13-- See @examples/inplace.hs@ in the repository.
14-- 14--
15----------------------------------------------------------------------------- 15-----------------------------------------------------------------------------
16 16
@@ -21,8 +21,8 @@ module Internal.ST (
21 -- * Mutable Matrices 21 -- * Mutable Matrices
22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24 axpy, scal, swap, extractMatrix, setMatrix, rowOpST, 24-- axpy, scal, swap, rowOp,
25 mutable, 25 mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..),
26 -- * Unsafe functions 26 -- * Unsafe functions
27 newUndefinedVector, 27 newUndefinedVector,
28 unsafeReadVector, unsafeWriteVector, 28 unsafeReadVector, unsafeWriteVector,
@@ -178,16 +178,55 @@ newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
178 178
179-------------------------------------------------------------------------------- 179--------------------------------------------------------------------------------
180 180
181rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s () 181data ColRange = AllCols
182rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m) 182 | ColRange Int Int
183 | Col Int
184 | FromCol Int
183 185
184axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) 186getColRange c AllCols = (0,c-1)
185scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) 187getColRange c (ColRange a b) = (a `mod` c, b `mod` c)
186swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) 188getColRange c (Col a) = (a `mod` c, a `mod` c)
189getColRange c (FromCol a) = (a `mod` c, c-1)
187 190
188extractMatrix (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) 191data RowRange = AllRows
192 | RowRange Int Int
193 | Row Int
194 | FromRow Int
195
196getRowRange r AllRows = (0,r-1)
197getRowRange r (RowRange a b) = (a `mod` r, b `mod` r)
198getRowRange r (Row a) = (a `mod` r, a `mod` r)
199getRowRange r (FromRow a) = (a `mod` r, r-1)
200
201data RowOper t = AXPY t Int Int ColRange
202 | SCAL t RowRange ColRange
203 | SWAP Int Int ColRange
204
205rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s ()
206
207rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m
208 where
209 (j1,j2) = getColRange (cols m) r
210 i1' = i1 `mod` (rows m)
211 i2' = i2 `mod` (rows m)
212
213rowOper (SCAL x rr rc) (STMatrix m) = unsafeIOToST $ rowOp 1 x i1 i2 j1 j2 m
214 where
215 (i1,i2) = getRowRange (rows m) rr
216 (j1,j2) = getColRange (cols m) rc
217
218rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m
219 where
220 (j1,j2) = getColRange (cols m) r
221 i1' = i1 `mod` (rows m)
222 i2' = i2 `mod` (rows m)
223
224
225extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
226 where
227 (i1,i2) = getRowRange (rows m) rr
228 (j1,j2) = getColRange (cols m) rc
189 229
190--------------------------------------------------------------------------------
191 230
192mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 231mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
193mutable f a = runST $ do 232mutable f a = runST $ do
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs
index 01c2205..0068313 100644
--- a/packages/base/src/Internal/Static.hs
+++ b/packages/base/src/Internal/Static.hs
@@ -34,6 +34,9 @@ import Text.Printf
34 34
35-------------------------------------------------------------------------------- 35--------------------------------------------------------------------------------
36 36
37type ℝ = Double
38type ℂ = Complex Double
39
37newtype Dim (n :: Nat) t = Dim t 40newtype Dim (n :: Nat) t = Dim t
38 deriving Show 41 deriving Show
39 42
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index 2650ac8..09ba21c 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -65,7 +65,7 @@ import Internal.Element
65import Internal.Container 65import Internal.Container
66import Internal.Vectorized 66import Internal.Vectorized
67import Internal.IO 67import Internal.IO
68import Internal.Algorithms hiding (i,Normed,swap) 68import Internal.Algorithms hiding (i,Normed,swap,linearSolve')
69import Numeric.Matrix() 69import Numeric.Matrix()
70import Numeric.Vector() 70import Numeric.Vector()
71import Internal.Random 71import Internal.Random
@@ -155,7 +155,7 @@ infixl 3 &
155(&) :: Vector Double -> Vector Double -> Vector Double 155(&) :: Vector Double -> Vector Double -> Vector Double
156a & b = vjoin [a,b] 156a & b = vjoin [a,b]
157 157
158{- | horizontal concatenation of real matrices 158{- | horizontal concatenation
159 159
160>>> ident 3 ||| konst 7 (3,4) 160>>> ident 3 ||| konst 7 (3,4)
161(3><7) 161(3><7)
@@ -165,7 +165,7 @@ a & b = vjoin [a,b]
165 165
166-} 166-}
167infixl 3 ||| 167infixl 3 |||
168(|||) :: Matrix Double -> Matrix Double -> Matrix Double 168(|||) :: Element t => Matrix t -> Matrix t -> Matrix t
169a ||| b = fromBlocks [[a,b]] 169a ||| b = fromBlocks [[a,b]]
170 170
171-- | a synonym for ('|||') (unicode 0x00a6, broken bar) 171-- | a synonym for ('|||') (unicode 0x00a6, broken bar)
@@ -174,9 +174,9 @@ infixl 3 ¦
174(¦) = (|||) 174(¦) = (|||)
175 175
176 176
177-- | vertical concatenation of real matrices 177-- | vertical concatenation
178-- 178--
179(===) :: Matrix Double -> Matrix Double -> Matrix Double 179(===) :: Element t => Matrix t -> Matrix t -> Matrix t
180infixl 2 === 180infixl 2 ===
181a === b = fromBlocks [[a],[b]] 181a === b = fromBlocks [[a],[b]]
182 182
@@ -588,7 +588,7 @@ gaussElim_2 a b = flipudrl r
588 where 588 where
589 flipudrl = flipud . fliprl 589 flipudrl = flipud . fliprl
590 splitColsAt n = (takeColumns n &&& dropColumns n) 590 splitColsAt n = (takeColumns n &&& dropColumns n)
591 go f x y = splitColsAt (cols a) (down f $ fromBlocks [[x,y]]) 591 go f x y = splitColsAt (cols a) (down f $ x ||| y)
592 (a1,b1) = go (snd . swapMax 0) a b 592 (a1,b1) = go (snd . swapMax 0) a b
593 ( _, r) = go id (flipudrl $ a1) (flipudrl $ b1) 593 ( _, r) = go id (flipudrl $ a1) (flipudrl $ b1)
594 594
@@ -600,7 +600,7 @@ gaussElim_1
600 600
601gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2) 601gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2)
602 where 602 where
603 rs = toRows $ fromBlocks [[x , y]] 603 rs = toRows $ x ||| y
604 s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting 604 s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting
605 s2 = pivotUp (rows x-1) (toRows $ flipud s1) 605 s2 = pivotUp (rows x-1) (toRows $ flipud s1)
606 606
@@ -637,12 +637,15 @@ pivotUp n xs
637 637
638-------------------------------------------------------------------------------- 638--------------------------------------------------------------------------------
639 639
640gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]]) 640gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (a ||| b)
641 641
642gaussST (r,_) x = do 642gaussST (r,_) x = do
643 let n = r-1 643 let n = r-1
644 axpy m a i j = rowOper (AXPY a i j AllCols) m
645 swap m i j = rowOper (SWAP i j AllCols) m
646 scal m a i = rowOper (SCAL a (Row i) AllCols) m
644 forM_ [0..n] $ \i -> do 647 forM_ [0..n] $ \i -> do
645 c <- maxIndex . abs . flatten <$> extractMatrix x i n i i 648 c <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i)
646 swap x i (i+c) 649 swap x i (i+c)
647 a <- readMatrix x i i 650 a <- readMatrix x i i
648 when (a == 0) $ error "singular!" 651 when (a == 0) $ error "singular!"
@@ -656,22 +659,23 @@ gaussST (r,_) x = do
656 axpy x (-b) i j 659 axpy x (-b) i j
657 660
658 661
659luST ok (r,c) x = do 662
660 let n = r-1 663luST ok (r,_) x = do
661 axpy' m a i j = rowOpST 0 a i j (i+1) (c-1) m 664 let axpy m a i j = rowOper (AXPY a i j (FromCol (i+1))) m
662 p <- thawMatrix . asColumn . range $ r 665 swap m i j = rowOper (SWAP i j AllCols) m
663 forM_ [0..n] $ \i -> do 666 p <- newUndefinedVector r
664 k <- maxIndex . abs . flatten <$> extractMatrix x i n i i 667 forM_ [0..r-1] $ \i -> do
665 writeMatrix p i 0 (fi (k+i)) 668 k <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i)
669 writeVector p i (k+i)
666 swap x i (i+k) 670 swap x i (i+k)
667 a <- readMatrix x i i 671 a <- readMatrix x i i
668 when (ok a) $ do 672 when (ok a) $ do
669 forM_ [i+1..n] $ \j -> do 673 forM_ [i+1..r-1] $ \j -> do
670 b <- (/a) <$> readMatrix x j i 674 b <- (/a) <$> readMatrix x j i
671 axpy' x (-b) i j 675 axpy x (-b) i j
672 writeMatrix x j i b 676 writeMatrix x j i b
673 v <- unsafeFreezeMatrix p 677 v <- unsafeFreezeVector p
674 return (map ti $ toList $ flatten v) 678 return (toList v)
675 679
676 680
677-------------------------------------------------------------------------------- 681--------------------------------------------------------------------------------