From 2addcfb5db6721b9520e8be9942278dfc17b7021 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Wed, 4 Jun 2014 18:49:50 +0200 Subject: complex instances --- packages/base/src/Numeric/LinearAlgebra/Real.hs | 105 +++++++++++++++++----- packages/base/src/Numeric/LinearAlgebra/Static.hs | 6 +- 2 files changed, 84 insertions(+), 27 deletions(-) (limited to 'packages/base/src') diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index 2ff69c7..d03ca6e 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs @@ -45,7 +45,7 @@ module Numeric.LinearAlgebra.Real( -- * Linear Systems linSolve, (<\>), -- * Factorizations - svd, svdTall, svdFlat, eig, + svd, svdTall, svdFlat, Eigen(..), -- * Pretty printing Disp(..), -- * Misc @@ -58,8 +58,9 @@ module Numeric.LinearAlgebra.Real( import GHC.TypeLits import Numeric.HMatrix hiding ( (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace, - (<\>),fromList,takeDiag,svd,eig) + (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH') import qualified Numeric.HMatrix as LA +import Data.Packed.Internal(mbCatch) import Data.Proxy(Proxy) import Numeric.LinearAlgebra.Static import Text.Printf @@ -80,6 +81,8 @@ ud1 (R (Dim v)) = v mkR :: Vector ℝ -> R n mkR = R . Dim +mkC :: Vector ℂ -> C n +mkC = C . Dim infixl 4 & (&) :: forall n . KnownNat n @@ -126,17 +129,17 @@ dim = mkR (scalar d) newtype L m n = L (Dim m (Dim n (Matrix ℝ))) --- newtype CL m n = CL (Dim m (Dim n (Matrix ℂ))) +newtype M m n = M (Dim m (Dim n (Matrix ℂ))) ud2 :: L m n -> Matrix ℝ ud2 (L (Dim (Dim x))) = x - - mkL :: Matrix ℝ -> L m n mkL x = L (Dim (Dim x)) +mkM :: Matrix ℂ -> M m n +mkM x = M (Dim (Dim x)) instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) where @@ -150,6 +153,18 @@ instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) -------------------------------------------------------------------------------- +instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ) + where + konst x = mkC (LA.scalar x) + unwrap (C (Dim v)) = v + fromList xs = C (gvect "C" xs) + extract (unwrap -> v) + | singleV v = LA.konst (v!0) d + | otherwise = v + where + d = fromIntegral . natVal $ (undefined :: Proxy n) + + instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) where konst x = mkR (LA.scalar x) @@ -162,11 +177,12 @@ instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) d = fromIntegral . natVal $ (undefined :: Proxy n) + instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) where konst x = mkL (LA.scalar x) - unwrap = ud2 fromList = mat + unwrap = ud2 extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' extract (unwrap -> a) | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') @@ -175,6 +191,20 @@ instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) m' = fromIntegral . natVal $ (undefined :: Proxy m) n' = fromIntegral . natVal $ (undefined :: Proxy n) + +instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ) + where + konst x = mkM (LA.scalar x) + fromList xs = M (gmat "M" xs) + unwrap (M (Dim (Dim m))) = m + extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' + extract (unwrap -> a) + | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') + | otherwise = a + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) + n' = fromIntegral . natVal $ (undefined :: Proxy n) + -------------------------------------------------------------------------------- diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n @@ -225,26 +255,41 @@ instance (KnownNat m, KnownNat n) => Disp (L m n) let su = LA.dispf n a printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) +instance (KnownNat m, KnownNat n) => Disp (M m n) + where + disp n x = do + let a = extract x + let su = LA.dispcf n a + printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) + + instance KnownNat n => Disp (R n) where disp n v = do let su = LA.dispf n (asRow $ extract v) putStr "R " >> putStr (tail . dropWhile (/='x') $ su) +instance KnownNat n => Disp (C n) + where + disp n v = do + let su = LA.dispcf n (asRow $ extract v) + putStr "C " >> putStr (tail . dropWhile (/='x') $ su) + + -------------------------------------------------------------------------------- row :: R n -> L 1 n row = mkL . asRow . ud1 -col :: R n -> L n 1 -col = tr . row +--col :: R n -> L n 1 +col v = tr . row $ v unrow :: L 1 n -> R n unrow = mkR . head . toRows . ud2 -uncol :: L n 1 -> R n -uncol = unrow . tr +--uncol :: L n 1 -> R n +uncol v = unrow . tr $ v infixl 2 —— @@ -253,7 +298,7 @@ a —— b = mkL (extract a LA.—— extract b) infixl 3 ¦ -(¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) +-- (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) a ¦ b = tr (tr a —— tr b) @@ -274,7 +319,14 @@ isKonst (unwrap -> x) isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) -isDiag (unwrap -> x) +isDiag (L x) = isDiagg x + +isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int)) +isDiagC (M x) = isDiagg x + + +isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int)) +isDiagg (Dim (Dim x)) | singleM x = Nothing | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) | otherwise = Nothing @@ -282,7 +334,7 @@ isDiag (unwrap -> x) m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int v = flatten x - z = v!0 + z = v `atIndex` 0 y = subVector 1 (size v-1) v ny = size y zeros = LA.konst 0 (max 0 (min m' n' - ny)) @@ -320,9 +372,10 @@ infixr 8 <·> | otherwise = udot u v -instance Transposable (L m n) (L n m) +instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m) where - tr (ud2 -> a) = mkL (tr a) + tr a@(isDiag -> Just _) = mkL (extract a) + tr (extract -> a) = mkL (tr a) -------------------------------------------------------------------------------- @@ -424,11 +477,12 @@ svdFlat (extract -> m) = (mkL u, mkR s, mkL v) -------------------------------------------------------------------------------- -class Eig m r | m -> r +class Eigen m l v | m -> l, m -> v where - eig :: m -> r + eigensystem :: m -> (l,v) + eigenvalues :: m -> l -newtype Sym n = Sym (Sq n) +newtype Sym n = Sym (Sq n) deriving Show --newtype Her n = Her (CSq n) @@ -438,16 +492,19 @@ sym m = Sym $ (m + tr m)/2 --her :: KnownNat n => CSq n -> Her n --her = undefined -- Her $ (m + tr m)/2 - -instance KnownNat n => Eig (Sym n) (R n, Sq n) +instance KnownNat n => Eigen (Sym n) (R n) (L n n) where - eig (Sym (extract -> m)) = (mkR l, mkL v) + eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m + eigensystem (Sym (extract -> m)) = (mkR l, mkL v) where - (l,v) = eigSH m + (l,v) = LA.eigSH' m -instance KnownNat n => Eig (Sq n) (C n) +instance KnownNat n => Eigen (Sq n) (C n) (M n n) where - eig (extract -> m) = C . Dim . eigenvalues $ m + eigenvalues (extract -> m) = mkC . LA.eigenvalues $ m + eigensystem (extract -> m) = (mkC l, mkM v) + where + (l,v) = LA.eig m -------------------------------------------------------------------------------- diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 5caf6f8..6acd9a3 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs @@ -27,7 +27,7 @@ module Numeric.LinearAlgebra.Static( lift1F, lift2F, vconcat, gvec2, gvec3, gvec4, gvect, gmat, Sized(..), - singleV, singleM + singleV, singleM,GM ) where @@ -105,7 +105,7 @@ ud (Dim v) = v mkV :: forall (n :: Nat) t . t -> Dim n t mkV = Dim -type M m n t = Dim m (Dim n (Matrix t)) +type GM m n t = Dim m (Dim n (Matrix t)) --ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t --ud2 (Dim (Dim m)) = m @@ -166,7 +166,7 @@ gvect st xs' abort info = error $ st++" "++show d++" can't be created from elements "++info -gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> M m n t +gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t gmat st xs' | ok = mkM x | not (null rest) && null (tail rest) = abort (show xs') -- cgit v1.2.3