diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-14 19:49:10 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-14 19:49:10 +0200 |
commit | 57487d828065ea219cdb33c9dc177b67c60b34c7 (patch) | |
tree | f6cc1e11ba41165e3a65930c66954a5220a4a8cb /packages/base/src/Internal | |
parent | 517dfdbf884ef2b3f3f3d365294a6a714ba7ff9d (diff) |
minor changes
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 10 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 59 | ||||
-rw-r--r-- | packages/base/src/Internal/Static.hs | 3 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 44 |
4 files changed, 77 insertions, 39 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index e0f5ed2..e4b1226 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -262,15 +262,7 @@ compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2 | |||
262 | 262 | ||
263 | ------------------------------------------------------------------ | 263 | ------------------------------------------------------------------ |
264 | 264 | ||
265 | {- | Supported matrix elements. | 265 | -- | Supported matrix elements. |
266 | |||
267 | This class provides optimized internal | ||
268 | operations for selected element types. | ||
269 | It provides unoptimised defaults for any 'Storable' type, | ||
270 | so you can create instances simply as: | ||
271 | |||
272 | >instance Element Foo | ||
273 | -} | ||
274 | class (Storable a) => Element a where | 266 | class (Storable a) => Element a where |
275 | transdata :: Int -> Vector a -> Int -> Vector a | 267 | transdata :: Int -> Vector a -> Int -> Vector a |
276 | constantD :: a -> Int -> Vector a | 268 | constantD :: a -> Int -> Vector a |
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index a84ca25..434fe63 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -10,7 +10,7 @@ | |||
10 | -- Stability : provisional | 10 | -- Stability : provisional |
11 | -- | 11 | -- |
12 | -- In-place manipulation inside the ST monad. | 12 | -- In-place manipulation inside the ST monad. |
13 | -- See examples/inplace.hs in the distribution. | 13 | -- See @examples/inplace.hs@ in the repository. |
14 | -- | 14 | -- |
15 | ----------------------------------------------------------------------------- | 15 | ----------------------------------------------------------------------------- |
16 | 16 | ||
@@ -21,8 +21,8 @@ module Internal.ST ( | |||
21 | -- * Mutable Matrices | 21 | -- * Mutable Matrices |
22 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 22 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
23 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 23 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
24 | axpy, scal, swap, extractMatrix, setMatrix, rowOpST, | 24 | -- axpy, scal, swap, rowOp, |
25 | mutable, | 25 | mutable, extractMatrix, setMatrix, rowOper, RowOper(..), RowRange(..), ColRange(..), |
26 | -- * Unsafe functions | 26 | -- * Unsafe functions |
27 | newUndefinedVector, | 27 | newUndefinedVector, |
28 | unsafeReadVector, unsafeWriteVector, | 28 | unsafeReadVector, unsafeWriteVector, |
@@ -178,16 +178,55 @@ newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c) | |||
178 | 178 | ||
179 | -------------------------------------------------------------------------------- | 179 | -------------------------------------------------------------------------------- |
180 | 180 | ||
181 | rowOpST :: Element t => Int -> t -> Int -> Int -> Int -> Int -> STMatrix s t -> ST s () | 181 | data ColRange = AllCols |
182 | rowOpST c x i1 i2 j1 j2 (STMatrix m) = unsafeIOToST (rowOp c x i1 i2 j1 j2 m) | 182 | | ColRange Int Int |
183 | | Col Int | ||
184 | | FromCol Int | ||
183 | 185 | ||
184 | axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) | 186 | getColRange c AllCols = (0,c-1) |
185 | scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) | 187 | getColRange c (ColRange a b) = (a `mod` c, b `mod` c) |
186 | swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) | 188 | getColRange c (Col a) = (a `mod` c, a `mod` c) |
189 | getColRange c (FromCol a) = (a `mod` c, c-1) | ||
187 | 190 | ||
188 | extractMatrix (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | 191 | data RowRange = AllRows |
192 | | RowRange Int Int | ||
193 | | Row Int | ||
194 | | FromRow Int | ||
195 | |||
196 | getRowRange r AllRows = (0,r-1) | ||
197 | getRowRange r (RowRange a b) = (a `mod` r, b `mod` r) | ||
198 | getRowRange r (Row a) = (a `mod` r, a `mod` r) | ||
199 | getRowRange r (FromRow a) = (a `mod` r, r-1) | ||
200 | |||
201 | data RowOper t = AXPY t Int Int ColRange | ||
202 | | SCAL t RowRange ColRange | ||
203 | | SWAP Int Int ColRange | ||
204 | |||
205 | rowOper :: (Num t, Element t) => RowOper t -> STMatrix s t -> ST s () | ||
206 | |||
207 | rowOper (AXPY x i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 0 x i1' i2' j1 j2 m | ||
208 | where | ||
209 | (j1,j2) = getColRange (cols m) r | ||
210 | i1' = i1 `mod` (rows m) | ||
211 | i2' = i2 `mod` (rows m) | ||
212 | |||
213 | rowOper (SCAL x rr rc) (STMatrix m) = unsafeIOToST $ rowOp 1 x i1 i2 j1 j2 m | ||
214 | where | ||
215 | (i1,i2) = getRowRange (rows m) rr | ||
216 | (j1,j2) = getColRange (cols m) rc | ||
217 | |||
218 | rowOper (SWAP i1 i2 r) (STMatrix m) = unsafeIOToST $ rowOp 2 0 i1' i2' j1 j2 m | ||
219 | where | ||
220 | (j1,j2) = getColRange (cols m) r | ||
221 | i1' = i1 `mod` (rows m) | ||
222 | i2' = i2 `mod` (rows m) | ||
223 | |||
224 | |||
225 | extractMatrix (STMatrix m) rr rc = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | ||
226 | where | ||
227 | (i1,i2) = getRowRange (rows m) rr | ||
228 | (j1,j2) = getColRange (cols m) rc | ||
189 | 229 | ||
190 | -------------------------------------------------------------------------------- | ||
191 | 230 | ||
192 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 231 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |
193 | mutable f a = runST $ do | 232 | mutable f a = runST $ do |
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs index 01c2205..0068313 100644 --- a/packages/base/src/Internal/Static.hs +++ b/packages/base/src/Internal/Static.hs | |||
@@ -34,6 +34,9 @@ import Text.Printf | |||
34 | 34 | ||
35 | -------------------------------------------------------------------------------- | 35 | -------------------------------------------------------------------------------- |
36 | 36 | ||
37 | type ℝ = Double | ||
38 | type ℂ = Complex Double | ||
39 | |||
37 | newtype Dim (n :: Nat) t = Dim t | 40 | newtype Dim (n :: Nat) t = Dim t |
38 | deriving Show | 41 | deriving Show |
39 | 42 | ||
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 2650ac8..09ba21c 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -65,7 +65,7 @@ import Internal.Element | |||
65 | import Internal.Container | 65 | import Internal.Container |
66 | import Internal.Vectorized | 66 | import Internal.Vectorized |
67 | import Internal.IO | 67 | import Internal.IO |
68 | import Internal.Algorithms hiding (i,Normed,swap) | 68 | import Internal.Algorithms hiding (i,Normed,swap,linearSolve') |
69 | import Numeric.Matrix() | 69 | import Numeric.Matrix() |
70 | import Numeric.Vector() | 70 | import Numeric.Vector() |
71 | import Internal.Random | 71 | import Internal.Random |
@@ -155,7 +155,7 @@ infixl 3 & | |||
155 | (&) :: Vector Double -> Vector Double -> Vector Double | 155 | (&) :: Vector Double -> Vector Double -> Vector Double |
156 | a & b = vjoin [a,b] | 156 | a & b = vjoin [a,b] |
157 | 157 | ||
158 | {- | horizontal concatenation of real matrices | 158 | {- | horizontal concatenation |
159 | 159 | ||
160 | >>> ident 3 ||| konst 7 (3,4) | 160 | >>> ident 3 ||| konst 7 (3,4) |
161 | (3><7) | 161 | (3><7) |
@@ -165,7 +165,7 @@ a & b = vjoin [a,b] | |||
165 | 165 | ||
166 | -} | 166 | -} |
167 | infixl 3 ||| | 167 | infixl 3 ||| |
168 | (|||) :: Matrix Double -> Matrix Double -> Matrix Double | 168 | (|||) :: Element t => Matrix t -> Matrix t -> Matrix t |
169 | a ||| b = fromBlocks [[a,b]] | 169 | a ||| b = fromBlocks [[a,b]] |
170 | 170 | ||
171 | -- | a synonym for ('|||') (unicode 0x00a6, broken bar) | 171 | -- | a synonym for ('|||') (unicode 0x00a6, broken bar) |
@@ -174,9 +174,9 @@ infixl 3 ¦ | |||
174 | (¦) = (|||) | 174 | (¦) = (|||) |
175 | 175 | ||
176 | 176 | ||
177 | -- | vertical concatenation of real matrices | 177 | -- | vertical concatenation |
178 | -- | 178 | -- |
179 | (===) :: Matrix Double -> Matrix Double -> Matrix Double | 179 | (===) :: Element t => Matrix t -> Matrix t -> Matrix t |
180 | infixl 2 === | 180 | infixl 2 === |
181 | a === b = fromBlocks [[a],[b]] | 181 | a === b = fromBlocks [[a],[b]] |
182 | 182 | ||
@@ -588,7 +588,7 @@ gaussElim_2 a b = flipudrl r | |||
588 | where | 588 | where |
589 | flipudrl = flipud . fliprl | 589 | flipudrl = flipud . fliprl |
590 | splitColsAt n = (takeColumns n &&& dropColumns n) | 590 | splitColsAt n = (takeColumns n &&& dropColumns n) |
591 | go f x y = splitColsAt (cols a) (down f $ fromBlocks [[x,y]]) | 591 | go f x y = splitColsAt (cols a) (down f $ x ||| y) |
592 | (a1,b1) = go (snd . swapMax 0) a b | 592 | (a1,b1) = go (snd . swapMax 0) a b |
593 | ( _, r) = go id (flipudrl $ a1) (flipudrl $ b1) | 593 | ( _, r) = go id (flipudrl $ a1) (flipudrl $ b1) |
594 | 594 | ||
@@ -600,7 +600,7 @@ gaussElim_1 | |||
600 | 600 | ||
601 | gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2) | 601 | gaussElim_1 x y = dropColumns (rows x) (flipud $ fromRows s2) |
602 | where | 602 | where |
603 | rs = toRows $ fromBlocks [[x , y]] | 603 | rs = toRows $ x ||| y |
604 | s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting | 604 | s1 = fromRows $ pivotDown (rows x) 0 rs -- interesting |
605 | s2 = pivotUp (rows x-1) (toRows $ flipud s1) | 605 | s2 = pivotUp (rows x-1) (toRows $ flipud s1) |
606 | 606 | ||
@@ -637,12 +637,15 @@ pivotUp n xs | |||
637 | 637 | ||
638 | -------------------------------------------------------------------------------- | 638 | -------------------------------------------------------------------------------- |
639 | 639 | ||
640 | gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]]) | 640 | gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (a ||| b) |
641 | 641 | ||
642 | gaussST (r,_) x = do | 642 | gaussST (r,_) x = do |
643 | let n = r-1 | 643 | let n = r-1 |
644 | axpy m a i j = rowOper (AXPY a i j AllCols) m | ||
645 | swap m i j = rowOper (SWAP i j AllCols) m | ||
646 | scal m a i = rowOper (SCAL a (Row i) AllCols) m | ||
644 | forM_ [0..n] $ \i -> do | 647 | forM_ [0..n] $ \i -> do |
645 | c <- maxIndex . abs . flatten <$> extractMatrix x i n i i | 648 | c <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i) |
646 | swap x i (i+c) | 649 | swap x i (i+c) |
647 | a <- readMatrix x i i | 650 | a <- readMatrix x i i |
648 | when (a == 0) $ error "singular!" | 651 | when (a == 0) $ error "singular!" |
@@ -656,22 +659,23 @@ gaussST (r,_) x = do | |||
656 | axpy x (-b) i j | 659 | axpy x (-b) i j |
657 | 660 | ||
658 | 661 | ||
659 | luST ok (r,c) x = do | 662 | |
660 | let n = r-1 | 663 | luST ok (r,_) x = do |
661 | axpy' m a i j = rowOpST 0 a i j (i+1) (c-1) m | 664 | let axpy m a i j = rowOper (AXPY a i j (FromCol (i+1))) m |
662 | p <- thawMatrix . asColumn . range $ r | 665 | swap m i j = rowOper (SWAP i j AllCols) m |
663 | forM_ [0..n] $ \i -> do | 666 | p <- newUndefinedVector r |
664 | k <- maxIndex . abs . flatten <$> extractMatrix x i n i i | 667 | forM_ [0..r-1] $ \i -> do |
665 | writeMatrix p i 0 (fi (k+i)) | 668 | k <- maxIndex . abs . flatten <$> extractMatrix x (FromRow i) (Col i) |
669 | writeVector p i (k+i) | ||
666 | swap x i (i+k) | 670 | swap x i (i+k) |
667 | a <- readMatrix x i i | 671 | a <- readMatrix x i i |
668 | when (ok a) $ do | 672 | when (ok a) $ do |
669 | forM_ [i+1..n] $ \j -> do | 673 | forM_ [i+1..r-1] $ \j -> do |
670 | b <- (/a) <$> readMatrix x j i | 674 | b <- (/a) <$> readMatrix x j i |
671 | axpy' x (-b) i j | 675 | axpy x (-b) i j |
672 | writeMatrix x j i b | 676 | writeMatrix x j i b |
673 | v <- unsafeFreezeMatrix p | 677 | v <- unsafeFreezeVector p |
674 | return (map ti $ toList $ flatten v) | 678 | return (toList v) |
675 | 679 | ||
676 | 680 | ||
677 | -------------------------------------------------------------------------------- | 681 | -------------------------------------------------------------------------------- |