From 0476c58d0b9da4fdcbbcb05ea055f6d14097e116 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 4 Jun 2014 14:14:31 +0200 Subject: operations with nonexpanded constant and diagonal matrices --- packages/base/src/Numeric/LinearAlgebra/Real.hs | 198 ++++++++++++++++------ packages/base/src/Numeric/LinearAlgebra/Static.hs | 19 ++- 2 files changed, 158 insertions(+), 59 deletions(-) (limited to 'packages/base/src') diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index 424e766..2ff69c7 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs @@ -27,7 +27,7 @@ Experimental interface for real arrays with statically checked dimensions. module Numeric.LinearAlgebra.Real( -- * Vector - R, + R, C, vec2, vec3, vec4, (&), (#), vect, linspace, range, dim, @@ -35,27 +35,30 @@ module Numeric.LinearAlgebra.Real( L, Sq, row, col, (¦),(——), unrow, uncol, - Sized(..), + eye, - diagR, diag, Diag(..), + diagR, diag, blockAt, mat, -- * Products (<>),(#>),(<·>), -- * Linear Systems - linSolve, -- (<\>), + linSolve, (<\>), + -- * Factorizations + svd, svdTall, svdFlat, eig, -- * Pretty printing Disp(..), -- * Misc - C, withVector, withMatrix, + Sized(..), Diag(..), Sym, sym, -- Her, her, module Numeric.HMatrix ) where import GHC.TypeLits import Numeric.HMatrix hiding ( - (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace,(<\>),fromList,takeDiag) + (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace, + (<\>),fromList,takeDiag,svd,eig) import qualified Numeric.HMatrix as LA import Data.Proxy(Proxy) import Numeric.LinearAlgebra.Static @@ -122,8 +125,8 @@ dim = mkR (scalar d) -------------------------------------------------------------------------------- newtype L m n = L (Dim m (Dim n (Matrix ℝ))) - deriving (Num,Fractional) +-- newtype CL m n = CL (Dim m (Dim n (Matrix ℂ))) ud2 :: L m n -> Matrix ℝ ud2 (L (Dim (Dim x))) = x @@ -137,27 +140,22 @@ mkL x = L (Dim (Dim x)) instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) where + show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' show (ud2 -> x) | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' - | isDiag = printf "(diag %s %s :: L %d %d)" (show z) shy m' n' | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" where m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int - isDiag = rows x == 1 && m' > 1 - v = flatten x - z = v!0 - y = subVector 1 (size v-1) v - shy = drop 9 (show y) -------------------------------------------------------------------------------- instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) where konst x = mkR (LA.scalar x) - extract = ud1 + unwrap = ud1 fromList = vect - expand (extract -> v) + extract (unwrap -> v) | singleV v = LA.konst (v!0) d | otherwise = v where @@ -167,23 +165,25 @@ instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) where konst x = mkL (LA.scalar x) - extract = ud2 + unwrap = ud2 fromList = mat - expand (extract -> a) + extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' + extract (unwrap -> a) | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') - | rows a == 1 && m'>1 = diagRect x y m' n' | otherwise = a where m' = fromIntegral . natVal $ (undefined :: Proxy m) n' = fromIntegral . natVal $ (undefined :: Proxy n) - v = flatten a - x = v!0 - y = subVector 1 (size v -1) v -------------------------------------------------------------------------------- diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n -diagR x v = mkL (asRow (vjoin [scalar x, expand v])) +diagR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) + where + ev = extract v + zeros = LA.konst x (max 0 ((min m' n') - size ev)) + m' = fromIntegral . natVal $ (undefined :: Proxy m) + n' = fromIntegral . natVal $ (undefined :: Proxy n) diag :: KnownNat n => R n -> Sq n diag = diagR 0 @@ -221,21 +221,14 @@ class Disp t instance (KnownNat m, KnownNat n) => Disp (L m n) where disp n x = do - let a = expand x + let a = extract x let su = LA.dispf n a printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) -{- - disp n (ud2 -> a) = do - if rows a == 1 && cols a == 1 - then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a) - else putStr "Dim " >> LA.disp n a --} - instance KnownNat n => Disp (R n) where disp n v = do - let su = LA.dispf n (asRow $ expand v) + let su = LA.dispf n (asRow $ extract v) putStr "R " >> putStr (tail . dropWhile (/='x') $ su) -------------------------------------------------------------------------------- @@ -256,7 +249,7 @@ uncol = unrow . tr infixl 2 —— (——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c -a —— b = mkL (expand a LA.—— expand b) +a —— b = mkL (extract a LA.—— extract b) infixl 3 ¦ @@ -264,35 +257,61 @@ infixl 3 ¦ a ¦ b = tr (tr a —— tr b) -type Sq n = L n n +type Sq n = L n n +--type CSq n = CL n n type GL = (KnownNat n, KnownNat m) => L m n type GSq = KnownNat n => Sq n -isDiag0 :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (Vector ℝ) -isDiag0 (extract -> x) - | rows x == 1 && m' > 1 && z == 0 = Just y +isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) +isKonst (unwrap -> x) + | singleM x = Just (x `atIndex` (0,0), (m',n')) + | otherwise = Nothing + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + + + +isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) +isDiag (unwrap -> x) + | singleM x = Nothing + | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) | otherwise = Nothing where m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int v = flatten x z = v!0 y = subVector 1 (size v-1) v + ny = size y + zeros = LA.konst 0 (max 0 (min m' n' - ny)) + yz = vjoin [y,zeros] infixr 8 <> -(<>) :: (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n -a <> b = mkL (expand a LA.<> expand b) +(<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n + +(isKonst -> Just (a,(_,k))) <> (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k) + +(isDiag -> Just (0,a,_)) <> (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) + where + v = a' * b' + n = min (size a) (size b) + a' = subVector 0 n a + b' = subVector 0 n b + +(isDiag -> Just (0,a,_)) <> (extract -> b) = mkL (asColumn a * takeRows (size a) b) + +(extract -> a) <> (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) + +a <> b = mkL (extract a LA.<> extract b) infixr 8 #> (#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m -(isDiag0 -> Just w) #> v = mkR (w' * v') - where - v' = expand v - w' = subVector 0 (max 0 (size w - size v')) (vjoin [w , z]) - z = LA.konst 0 (max 0 (size v' - size w)) +(isDiag -> Just (0, w, _)) #> v = mkR (w * subVector 0 (size w) (extract v)) +m #> v = mkR (extract m LA.#> extract v) -m #> v = mkR (expand m LA.#> expand v) infixr 8 <·> (<·>) :: R n -> R n -> ℝ @@ -306,6 +325,36 @@ instance Transposable (L m n) (L n m) tr (ud2 -> a) = mkL (tr a) -------------------------------------------------------------------------------- + +adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b +adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b)) +adaptDiag f a b = f a b + +isFull m = isDiag m == Nothing && not (singleM (unwrap m)) + + +lift1L f (L v) = L (f v) +lift2L f (L a) (L b) = L (f a b) +lift2LD f = adaptDiag (lift2L f) + + +instance (KnownNat n, KnownNat m) => Num (L n m) + where + (+) = lift2LD (+) + (*) = lift2LD (*) + (-) = lift2LD (-) + abs = lift1L abs + signum = lift1L signum + negate = lift1L negate + fromInteger = L . Dim . Dim . fromInteger + +instance (KnownNat n, KnownNat m) => Fractional (L n m) + where + fromRational = L . Dim . Dim . fromRational + (/) = lift2LD (/) + +-------------------------------------------------------------------------------- + {- class Minim (n :: Nat) (m :: Nat) where @@ -333,24 +382,73 @@ class Diag m d | m -> d instance forall n . (KnownNat n) => Diag (L n n) (R n) where - takeDiag m = mkR (LA.takeDiag (expand m)) + takeDiag m = mkR (LA.takeDiag (extract m)) instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m) where - takeDiag m = mkR (LA.takeDiag (expand m)) + takeDiag m = mkR (LA.takeDiag (extract m)) instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n) where - takeDiag m = mkR (LA.takeDiag (expand m)) + takeDiag m = mkR (LA.takeDiag (extract m)) + + +-------------------------------------------------------------------------------- + +linSolve :: (KnownNat m, KnownNat n) => L m m -> L m n -> L m n +linSolve (extract -> a) (extract -> b) = mkL (LA.linearSolve a b) + +(<\>) :: (KnownNat m, KnownNat n, KnownNat r) => L m n -> L m r -> L n r +(extract -> a) <\> (extract -> b) = mkL (a LA.<\> b) + +svd :: (KnownNat m, KnownNat n) => L m n -> (L m m, R n, L n n) +svd (extract -> m) = (mkL u, mkR s', mkL v) + where + (u,s,v) = LA.svd m + s' = vjoin [s, z] + z = LA.konst 0 (max 0 (cols m - size s)) + + +svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) +svdTall (extract -> m) = (mkL u, mkR s, mkL v) + where + (u,s,v) = LA.thinSVD m +svdFlat :: (KnownNat m, KnownNat n, m <= n) => L m n -> (L m m, R m, L m n) +svdFlat (extract -> m) = (mkL u, mkR s, mkL v) + where + (u,s,v) = LA.thinSVD m + -------------------------------------------------------------------------------- -linSolve :: L m m -> L m n -> L m n -linSolve (ud2 -> a) (ud2 -> b) = mkL (LA.linearSolve a b) +class Eig m r | m -> r + where + eig :: m -> r + +newtype Sym n = Sym (Sq n) + +--newtype Her n = Her (CSq n) + +sym :: KnownNat n => Sq n -> Sym n +sym m = Sym $ (m + tr m)/2 + +--her :: KnownNat n => CSq n -> Her n +--her = undefined -- Her $ (m + tr m)/2 + +instance KnownNat n => Eig (Sym n) (R n, Sq n) + where + eig (Sym (extract -> m)) = (mkR l, mkL v) + where + (l,v) = eigSH m + +instance KnownNat n => Eig (Sq n) (C n) + where + eig (extract -> m) = C . Dim . eigenvalues $ m + -------------------------------------------------------------------------------- withVector @@ -383,7 +481,7 @@ withMatrix a f = test :: (Bool, IO ()) test = (ok,info) where - ok = expand (eye :: Sq 5) == ident 5 + ok = extract (eye :: Sq 5) == ident 5 && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] && ud2 (tm :: L 3 5) == LA.mat 5 [1..15] && thingS == thingD diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index f9e935d..5caf6f8 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs @@ -74,6 +74,12 @@ instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) negate = lift1F negate fromInteger x = Dim (fromInteger x) +instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t)) + where + fromRational x = Dim (fromRational x) + (/) = lift2F (/) + + instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) where (+) = (lift2F . lift2F) (+) @@ -84,11 +90,6 @@ instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) negate = (lift1F . lift1F) negate fromInteger x = Dim (Dim (fromInteger x)) -instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t)) - where - fromRational x = Dim (fromRational x) - (/) = lift2F (/) - instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t))) where fromRational x = Dim (Dim (fromRational x)) @@ -106,8 +107,8 @@ mkV = Dim type M m n t = Dim m (Dim n (Matrix t)) -ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t -ud2 (Dim (Dim m)) = m +--ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t +--ud2 (Dim (Dim m)) = m mkM :: forall (m :: Nat) (n :: Nat) t . t -> Dim m (Dim n t) mkM = Dim . Dim @@ -184,9 +185,9 @@ gmat st xs' class Num t => Sized t s d | s -> t, s -> d where konst :: t -> s - extract :: s -> d + unwrap :: s -> d fromList :: [t] -> s - expand :: s -> d + extract :: s -> d singleV v = size v == 1 singleM m = rows m == 1 && cols m == 1 -- cgit v1.2.3