From 9a17969ad0ea9f940db6201a37b9aed19ad605df Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 3 Jun 2014 21:06:17 +0200 Subject: fix linspace, expose udot, complex static, wip --- packages/base/src/Numeric/LinearAlgebra/Real.hs | 395 +++++++++++++++--------- 1 file changed, 241 insertions(+), 154 deletions(-) (limited to 'packages/base/src/Numeric/LinearAlgebra/Real.hs') diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index 5634031..424e766 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs @@ -11,13 +11,15 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE OverlappingInstances #-} +{-# LANGUAGE TypeFamilies #-} {- | Module : Numeric.LinearAlgebra.Real Copyright : (c) Alberto Ruiz 2006-14 License : BSD3 -Stability : provisional +Stability : experimental Experimental interface for real arrays with statically checked dimensions. @@ -26,165 +28,173 @@ Experimental interface for real arrays with statically checked dimensions. module Numeric.LinearAlgebra.Real( -- * Vector R, - vec2, vec3, vec4, ๐•ง, (&), + vec2, vec3, vec4, (&), (#), + vect, + linspace, range, dim, -- * Matrix L, Sq, row, col, (ยฆ),(โ€”โ€”), - Konst(..), + unrow, uncol, + Sized(..), eye, - diagR, diag, + diagR, diag, Diag(..), blockAt, + mat, -- * Products (<>),(#>),(<ยท>), + -- * Linear Systems + linSolve, -- (<\>), -- * Pretty printing Disp(..), -- * Misc - Dim, unDim, + C, + withVector, withMatrix, module Numeric.HMatrix ) where import GHC.TypeLits -import Numeric.HMatrix hiding ((<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โ€”โ€”),row,col) +import Numeric.HMatrix hiding ( + (<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โ€”โ€”),row,col,vect,mat,linspace,(<\>),fromList,takeDiag) import qualified Numeric.HMatrix as LA -import Data.Packed.ST import Data.Proxy(Proxy) +import Numeric.LinearAlgebra.Static +import Text.Printf + +instance forall n . KnownNat n => Show (R n) + where + show (ud1 -> v) + | singleV v = "("++show (v!0)++" :: R "++show d++")" + | otherwise = "(vect"++ drop 8 (show v)++" :: R "++show d++")" + where + d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int -newtype Dim (n :: Nat) t = Dim t - deriving Show -unDim :: Dim n t -> t -unDim (Dim x) = x +ud1 :: R n -> Vector โ„ +ud1 (R (Dim v)) = v --- data Proxy :: Nat -> * +mkR :: Vector โ„ -> R n +mkR = R . Dim -lift1F - :: (c t -> c t) - -> Dim n (c t) -> Dim n (c t) -lift1F f (Dim v) = Dim (f v) -lift2F - :: (c t -> c t -> c t) - -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) -lift2F f (Dim u) (Dim v) = Dim (f u v) +infixl 4 & +(&) :: forall n . KnownNat n + => R n -> โ„ -> R (n+1) +u & x = u # (konst x :: R 1) + +infixl 4 # +(#) :: forall n m . (KnownNat n, KnownNat m) + => R n -> R m -> R (n+m) +(R u) # (R v) = R (vconcat u v) -type R n = Dim n (Vector โ„) +vec2 :: โ„ -> โ„ -> R 2 +vec2 a b = R (gvec2 a b) -type L m n = Dim m (Dim n (Matrix โ„)) +vec3 :: โ„ -> โ„ -> โ„ -> R 3 +vec3 a b c = R (gvec3 a b c) -infixl 4 & -(&) :: forall n . KnownNat n - => R n -> โ„ -> R (n+1) -Dim v & x = Dim (vjoin [v', scalar x]) +vec4 :: โ„ -> โ„ -> โ„ -> โ„ -> R 4 +vec4 a b c d = R (gvec4 a b c d) + +vect :: forall n . KnownNat n => [โ„] -> R n +vect xs = R (gvect "R" xs) + +linspace :: forall n . KnownNat n => (โ„,โ„) -> R n +linspace (a,b) = mkR (LA.linspace d (a,b)) where d = fromIntegral . natVal $ (undefined :: Proxy n) - v' | d > 1 && size v == 1 = LA.konst (v!0) d - | otherwise = v +range :: forall n . KnownNat n => R n +range = mkR (LA.linspace d (1,fromIntegral d)) + where + d = fromIntegral . natVal $ (undefined :: Proxy n) --- vect0 :: R 0 --- vect0 = Dim (fromList[]) +dim :: forall n . KnownNat n => R n +dim = mkR (scalar d) + where + d = fromIntegral . natVal $ (undefined :: Proxy n) -๐•ง :: โ„ -> R 1 -๐•ง = Dim . scalar +-------------------------------------------------------------------------------- -vec2 :: โ„ -> โ„ -> R 2 -vec2 a b = Dim $ runSTVector $ do - v <- newUndefinedVector 2 - writeVector v 0 a - writeVector v 1 b - return v +newtype L m n = L (Dim m (Dim n (Matrix โ„))) + deriving (Num,Fractional) -vec3 :: โ„ -> โ„ -> โ„ -> R 3 -vec3 a b c = Dim $ runSTVector $ do - v <- newUndefinedVector 3 - writeVector v 0 a - writeVector v 1 b - writeVector v 2 c - return v +ud2 :: L m n -> Matrix โ„ +ud2 (L (Dim (Dim x))) = x -vec4 :: โ„ -> โ„ -> โ„ -> โ„ -> R 4 -vec4 a b c d = Dim $ runSTVector $ do - v <- newUndefinedVector 4 - writeVector v 0 a - writeVector v 1 b - writeVector v 2 c - writeVector v 3 d - return v +mkL :: Matrix โ„ -> L m n +mkL x = L (Dim (Dim x)) -instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) - where - (+) = lift2F (+) - (*) = lift2F (*) - (-) = lift2F (-) - abs = lift1F abs - signum = lift1F signum - negate = lift1F negate - fromInteger x = Dim (fromInteger x) - -instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) - where - (+) = (lift2F . lift2F) (+) - (*) = (lift2F . lift2F) (*) - (-) = (lift2F . lift2F) (-) - abs = (lift1F . lift1F) abs - signum = (lift1F . lift1F) signum - negate = (lift1F . lift1F) negate - fromInteger x = Dim (Dim (fromInteger x)) - -instance Fractional (Dim n (Vector Double)) - where - fromRational x = Dim (fromRational x) - (/) = lift2F (/) -instance Fractional (Dim m (Dim n (Matrix Double))) +instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) where - fromRational x = Dim (Dim (fromRational x)) - (/) = (lift2F.lift2F) (/) + show (ud2 -> x) + | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' + | isDiag = printf "(diag %s %s :: L %d %d)" (show z) shy m' n' + | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + isDiag = rows x == 1 && m' > 1 + v = flatten x + z = v!0 + y = subVector 1 (size v-1) v + shy = drop 9 (show y) -------------------------------------------------------------------------------- -class Konst t +instance forall n. KnownNat n => Sized โ„ (R n) (Vector โ„) where - konst :: โ„ -> t + konst x = mkR (LA.scalar x) + extract = ud1 + fromList = vect + expand (extract -> v) + | singleV v = LA.konst (v!0) d + | otherwise = v + where + d = fromIntegral . natVal $ (undefined :: Proxy n) -instance forall n. KnownNat n => Konst (R n) - where - konst x = Dim (LA.konst x d) - where - d = fromIntegral . natVal $ (undefined :: Proxy n) -instance forall m n . (KnownNat m, KnownNat n) => Konst (L m n) +instance forall m n . (KnownNat m, KnownNat n) => Sized โ„ (L m n) (Matrix โ„) where - konst x = Dim (Dim (LA.konst x (m',n'))) + konst x = mkL (LA.scalar x) + extract = ud2 + fromList = mat + expand (extract -> a) + | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') + | rows a == 1 && m'>1 = diagRect x y m' n' + | otherwise = a where m' = fromIntegral . natVal $ (undefined :: Proxy m) n' = fromIntegral . natVal $ (undefined :: Proxy n) + v = flatten a + x = v!0 + y = subVector 1 (size v -1) v -------------------------------------------------------------------------------- -diagR :: forall m n k . (KnownNat m, KnownNat n) => โ„ -> R k -> L m n -diagR x v = Dim (Dim (diagRect x (unDim v) m' n')) - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - n' = fromIntegral . natVal $ (undefined :: Proxy n) +diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => โ„ -> R k -> L m n +diagR x v = mkL (asRow (vjoin [scalar x, expand v])) diag :: KnownNat n => R n -> Sq n diag = diagR 0 +eye :: KnownNat n => Sq n +eye = diag 1 + -------------------------------------------------------------------------------- blockAt :: forall m n . (KnownNat m, KnownNat n) => โ„ -> Int -> Int -> Matrix Double -> L m n -blockAt x r c a = Dim (Dim res) +blockAt x r c a = mkL res where z = scalar x z1 = LA.konst x (r,c) @@ -196,117 +206,189 @@ blockAt x r c a = Dim (Dim res) n' = fromIntegral . natVal $ (undefined :: Proxy n) res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] -{- -matrix :: (KnownNat m, KnownNat n) => Matrix Double -> L n m -matrix = blockAt 0 0 0 --} + +mat :: forall m n . (KnownNat m, KnownNat n) => [โ„] -> L m n +mat xs = L (gmat "L" xs) + -------------------------------------------------------------------------------- class Disp t where disp :: Int -> t -> IO () -instance Disp (L n m) + +instance (KnownNat m, KnownNat n) => Disp (L m n) where - disp n (d2 -> a) = do + disp n x = do + let a = expand x + let su = LA.dispf n a + printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) + +{- + disp n (ud2 -> a) = do if rows a == 1 && cols a == 1 then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a) else putStr "Dim " >> LA.disp n a +-} -instance Disp (R n) +instance KnownNat n => Disp (R n) where - disp n (unDim -> v) = do - let su = LA.dispf n (asRow v) - if LA.size v == 1 - then putStrLn $ "Const " ++ (last . words $ su ) - else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su) + disp n v = do + let su = LA.dispf n (asRow $ expand v) + putStr "R " >> putStr (tail . dropWhile (/='x') $ su) -------------------------------------------------------------------------------- -{- -infixl 3 # -(#) :: L r c -> R c -> L (r+1) c -Dim (Dim m) # Dim v = Dim (Dim (m LA.โ€”โ€” asRow v)) -๐•ž :: forall n . KnownNat n => L 0 n -๐•ž = Dim (Dim (LA.konst 0 (0,d))) - where - d = fromIntegral . natVal $ (undefined :: Proxy n) --} - row :: R n -> L 1 n -row (Dim v) = Dim (Dim (asRow v)) +row = mkL . asRow . ud1 col :: R n -> L n 1 col = tr . row -infixl 3 ยฆ -(ยฆ) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) -a ยฆ b = rjoin (expk a) (expk b) - where - Dim (Dim a') `rjoin` Dim (Dim b') = Dim (Dim (a' LA.ยฆ b')) +unrow :: L 1 n -> R n +unrow = mkR . head . toRows . ud2 + +uncol :: L n 1 -> R n +uncol = unrow . tr + infixl 2 โ€”โ€” (โ€”โ€”) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c -a โ€”โ€” b = cjoin (expk a) (expk b) - where - Dim (Dim a') `cjoin` Dim (Dim b') = Dim (Dim (a' LA.โ€”โ€” b')) - -expk :: (KnownNat n, KnownNat m) => L m n -> L m n -expk x | singleton x = konst (d2 x `atIndex` (0,0)) - | otherwise = x - where - singleton (d2 -> m) = rows m == 1 && cols m == 1 +a โ€”โ€” b = mkL (expand a LA.โ€”โ€” expand b) -{- +infixl 3 ยฆ +(ยฆ) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) +a ยฆ b = tr (tr a โ€”โ€” tr b) --} type Sq n = L n n type GL = (KnownNat n, KnownNat m) => L m n type GSq = KnownNat n => Sq n +isDiag0 :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (Vector โ„) +isDiag0 (extract -> x) + | rows x == 1 && m' > 1 && z == 0 = Just y + | otherwise = Nothing + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + v = flatten x + z = v!0 + y = subVector 1 (size v-1) v + + infixr 8 <> -(<>) :: L m k -> L k n -> L m n -(d2 -> a) <> (d2 -> b) = Dim (Dim (a LA.<> b)) +(<>) :: (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n +a <> b = mkL (expand a LA.<> expand b) infixr 8 #> -(#>) :: L m n -> R n -> R m -(d2 -> m) #> (unDim -> v) = Dim (m LA.#> v) +(#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m +(isDiag0 -> Just w) #> v = mkR (w' * v') + where + v' = expand v + w' = subVector 0 (max 0 (size w - size v')) (vjoin [w , z]) + z = LA.konst 0 (max 0 (size v' - size w)) + +m #> v = mkR (expand m LA.#> expand v) infixr 8 <ยท> (<ยท>) :: R n -> R n -> โ„ -(unDim -> u) <ยท> (unDim -> v) = udot u v +(ud1 -> u) <ยท> (ud1 -> v) + | singleV u || singleV v = sumElements (u * v) + | otherwise = udot u v -d2 :: forall c (n :: Nat) (n1 :: Nat). Dim n1 (Dim n c) -> c -d2 = unDim . unDim +instance Transposable (L m n) (L n m) + where + tr (ud2 -> a) = mkL (tr a) +-------------------------------------------------------------------------------- +{- +class Minim (n :: Nat) (m :: Nat) + where + type Mini n m :: Nat -instance Transposable (L m n) (L n m) +instance forall (n :: Nat) . Minim n n where - tr (Dim (Dim a)) = Dim (Dim (tr a)) + type Mini n n = n -eye :: forall n . KnownNat n => Sq n -eye = Dim (Dim (ident d)) +instance forall (n :: Nat) (m :: Nat) . (n <= m+1) => Minim n m where - d = fromIntegral . natVal $ (undefined :: Proxy n) + type Mini n m = n + +instance forall (n :: Nat) (m :: Nat) . (m <= n+1) => Minim n m + where + type Mini n m = m +-} + +class Diag m d | m -> d + where + takeDiag :: m -> d + + + +instance forall n . (KnownNat n) => Diag (L n n) (R n) + where + takeDiag m = mkR (LA.takeDiag (expand m)) + + +instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m) + where + takeDiag m = mkR (LA.takeDiag (expand m)) + + +instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n) + where + takeDiag m = mkR (LA.takeDiag (expand m)) -------------------------------------------------------------------------------- +linSolve :: L m m -> L m n -> L m n +linSolve (ud2 -> a) (ud2 -> b) = mkL (LA.linearSolve a b) + +-------------------------------------------------------------------------------- + +withVector + :: forall z + . Vector โ„ + -> (forall n . (KnownNat n) => R n -> z) + -> z +withVector v f = + case someNatVal $ fromIntegral $ size v of + Nothing -> error "static/dynamic mismatch" + Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) + + +withMatrix + :: forall z + . Matrix โ„ + -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z) + -> z +withMatrix a f = + case someNatVal $ fromIntegral $ rows a of + Nothing -> error "static/dynamic mismatch" + Just (SomeNat (_ :: Proxy m)) -> + case someNatVal $ fromIntegral $ cols a of + Nothing -> error "static/dynamic mismatch" + Just (SomeNat (_ :: Proxy n)) -> + f (mkL a :: L n m) + +-------------------------------------------------------------------------------- + test :: (Bool, IO ()) test = (ok,info) where - ok = d2 (eye :: Sq 5) == ident 5 - && d2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] - && d2 (tm :: L 3 5) == mat 5 [1..15] + ok = expand (eye :: Sq 5) == ident 5 + && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] + && ud2 (tm :: L 3 5) == LA.mat 5 [1..15] && thingS == thingD && precS == precD + && withVector (LA.vect [1..15]) sumV == sumElements (LA.fromList [1..15]) info = do print $ u @@ -319,19 +401,24 @@ test = (ok,info) print thingD print precS print precD + print $ withVector (LA.vect [1..15]) sumV + + sumV w = w <ยท> konst 1 u = vec2 3 5 + ๐•ง x = vect [x] :: R 1 + v = ๐•ง 2 & 4 & 7 - mTm :: L n m -> Sq m +-- mTm :: L n m -> Sq m mTm a = tr a <> a tm :: GL tm = lmat 0 [1..] lmat :: forall m n . (KnownNat m, KnownNat n) => โ„ -> [โ„] -> L m n - lmat z xs = Dim . Dim . reshape n' . fromList . take (m'*n') $ xs ++ repeat z + lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z where m' = fromIntegral . natVal $ (undefined :: Proxy m) n' = fromIntegral . natVal $ (undefined :: Proxy n) @@ -343,12 +430,12 @@ test = (ok,info) where q = tm :: L 10 3 - thingD = vjoin [unDim u, 1] LA.<ยท> tr m LA.#> m LA.#> unDim v + thingD = vjoin [ud1 u, 1] LA.<ยท> tr m LA.#> m LA.#> ud1 v where - m = mat 3 [1..30] + m = LA.mat 3 [1..30] precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <ยท> konst 2 #> v - precD = 1 + 2 * vjoin[unDim u, 6] LA.<ยท> LA.konst 2 (size (unDim u) +1, size (unDim v)) LA.#> unDim v + precD = 1 + 2 * vjoin[ud1 u, 6] LA.<ยท> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v instance (KnownNat n', KnownNat m') => Testable (L n' m') -- cgit v1.2.3