From e7d2916f78b5c140738fc4f4f95c9b13c1768293 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 17 Jun 2015 13:02:26 +0200 Subject: luSolve' --- packages/base/src/Internal/Algorithms.hs | 33 ++++++++++++++--------- packages/base/src/Internal/Util.hs | 43 +++++++++++++++++++++++++++--- packages/base/src/Numeric/LinearAlgebra.hs | 3 ++- 3 files changed, 61 insertions(+), 18 deletions(-) diff --git a/packages/base/src/Internal/Algorithms.hs b/packages/base/src/Internal/Algorithms.hs index aaf6fbb..1235da3 100644 --- a/packages/base/src/Internal/Algorithms.hs +++ b/packages/base/src/Internal/Algorithms.hs @@ -29,7 +29,9 @@ import Internal.Conversion import Internal.LAPACK as LAPACK import Internal.Numeric import Data.List(foldl1') -import Data.Array +import qualified Data.Array as A +import Internal.ST +import Internal.Vectorized(range) {- | Generic linear algebra functions for double precision real and complex matrices. @@ -578,11 +580,6 @@ eps = 2.22044604925031e-16 peps :: RealFloat x => x peps = x where x = 2.0 ** fromIntegral (1 - floatDigits x) - --- | The imaginary unit: @i = 0.0 :+ 1.0@ -i :: Complex Double -i = 0:+1 - ----------------------------------------------------------------------- -- | The nullspace of a matrix from its precomputed SVD decomposition. @@ -796,13 +793,23 @@ signlp r vals = foldl f 1 (zip [0..r-1] vals) where f s (a,b) | a /= b = -s | otherwise = s -swap (arr,s) (a,b) | a /= b = (arr // [(a, arr!b),(b,arr!a)],-s) - | otherwise = (arr,s) - -fixPerm r vals = (fromColumns $ elems res, sign) - where v = [0..r-1] - s = toColumns (ident r) - (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals) +fixPerm r vals = (fromColumns $ A.elems res, sign) + where + v = [0..r-1] + t = toColumns (ident r) + (res,sign) = foldl swap (A.listArray (0,r-1) t, 1) (zip v vals) + swap (arr,s) (a,b) + | a /= b = (arr A.// [(a, arr A.! b),(b,arr A.! a)],-s) + | otherwise = (arr,s) + +fixPerm' :: [Int] -> Vector I +fixPerm' s = res $ mutable f s0 + where + s0 = reshape 1 (range (length s)) + res = flatten . fst + swap m i j = rowOper (SWAP i j AllCols) m + f :: (Num t, Element t) => (Int, Int) -> STMatrix s t -> ST s () -- needed because of TypeFamilies + f _ p = sequence_ $ zipWith (swap p) [0..] s triang r c h v = (r>=h then v else 1 - v diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index f08f710..d9777ae 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs @@ -1,6 +1,5 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE ViewPatterns #-} @@ -55,7 +54,7 @@ module Internal.Util( -- ** 2D corr2, conv2, separable, block2x2,block3x3,view1,unView1,foldMatrix, - gaussElim_1, gaussElim_2, gaussElim, luST + gaussElim_1, gaussElim_2, gaussElim, luST, luSolve' ) where import Internal.Vector @@ -65,7 +64,7 @@ import Internal.Element import Internal.Container import Internal.Vectorized import Internal.IO -import Internal.Algorithms hiding (i,Normed,swap,linearSolve') +import Internal.Algorithms hiding (Normed,linearSolve',luSolve') import Numeric.Matrix() import Numeric.Vector() import Internal.Random @@ -73,7 +72,7 @@ import Internal.Convolution import Control.Monad(when,forM_) import Text.Printf import Data.List.Split(splitOn) -import Data.List(intercalate,sortBy) +import Data.List(intercalate,sortBy,foldl') import Control.Arrow((&&&)) import Data.Complex import Data.Function(on) @@ -688,6 +687,42 @@ luST ok (r,_) x = do return (toList v) +-------------------------------------------------------------------------------- + +rowRange m = [0..rows m -1] + +at k = Pos (idxs[k]) + +backSust lup rhs = foldl' f (rhs?[]) (reverse ls) + where + ls = [ (d k , u k , b k) | k <- rowRange lup ] + where + d k = lup ?? (at k, at k) + u k = lup ?? (at k, Drop (k+1)) + b k = rhs ?? (at k, All) + + f x (d,u,b) = (b - u<>x) / d + === + x + + +forwSust lup rhs = foldl' f (rhs?[]) ls + where + ls = [ (l k , b k) | k <- rowRange lup ] + where + l k = lup ?? (at k, Take k) + b k = rhs ?? (at k, All) + + f x (l,b) = x + === + (b - l<>x) + + +luSolve' (lup,p) b = backSust lup (forwSust lup pb) + where + pb = b ?? (Pos (fixPerm' p), All) + + -------------------------------------------------------------------------------- instance Testable (Matrix I) where diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index 2e6b8ca..e899445 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs @@ -77,6 +77,7 @@ module Numeric.LinearAlgebra ( linearSolveLS, linearSolveSVD, luSolve, + luSolve', cholSolve, cgSolve, cgSolve', @@ -158,7 +159,7 @@ import Numeric.Vector() import Internal.Matrix import Internal.Container hiding ((<>)) import Internal.Numeric hiding (mul) -import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked',linearSolve') +import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked',linearSolve',luSolve') import qualified Internal.Algorithms as A import Internal.Util import Internal.Random -- cgit v1.2.3