From 9a17969ad0ea9f940db6201a37b9aed19ad605df Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 3 Jun 2014 21:06:17 +0200 Subject: fix linspace, expose udot, complex static, wip --- packages/base/src/Numeric/Container.hs | 2 +- packages/base/src/Numeric/LinearAlgebra/Complex.hs | 80 +++++ packages/base/src/Numeric/LinearAlgebra/Real.hs | 395 +++++++++++++-------- packages/base/src/Numeric/LinearAlgebra/Static.hs | 193 ++++++++++ 4 files changed, 515 insertions(+), 155 deletions(-) create mode 100644 packages/base/src/Numeric/LinearAlgebra/Complex.hs create mode 100644 packages/base/src/Numeric/LinearAlgebra/Static.hs (limited to 'packages/base/src/Numeric') diff --git a/packages/base/src/Numeric/Container.hs b/packages/base/src/Numeric/Container.hs index 6a841aa..f78bfb9 100644 --- a/packages/base/src/Numeric/Container.hs +++ b/packages/base/src/Numeric/Container.hs @@ -20,7 +20,7 @@ module Numeric.Container( sumElements, prodElements, step, cond, find, assoc, accum, Element(..), - Product(..), + Product(..), dot, udot, optimiseMult, mXm, mXv, vXm, (<.>), Mul(..), diff --git a/packages/base/src/Numeric/LinearAlgebra/Complex.hs b/packages/base/src/Numeric/LinearAlgebra/Complex.hs new file mode 100644 index 0000000..17bc397 --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Complex.hs @@ -0,0 +1,80 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE EmptyDataDecls #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE GADTs #-} + + +{- | +Module : Numeric.LinearAlgebra.Complex +Copyright : (c) Alberto Ruiz 2006-14 +License : BSD3 +Stability : experimental + +-} + +module Numeric.LinearAlgebra.Complex( + C, + vec2, vec3, vec4, (&), (#), + vect, + R +) where + +import GHC.TypeLits +import Numeric.HMatrix hiding ( + (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace) +import qualified Numeric.HMatrix as LA +import Data.Proxy(Proxy) +import Numeric.LinearAlgebra.Static + + + +instance forall n . KnownNat n => Show (C n) + where + show (ud1 -> v) + | size v == 1 = "("++show (v!0)++" :: C "++show d++")" + | otherwise = "(vect"++ drop 8 (show v)++" :: C "++show d++")" + where + d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + + +ud1 :: C n -> Vector ℂ +ud1 (C (Dim v)) = v + +mkC :: Vector ℂ -> C n +mkC = C . Dim + + +infixl 4 & +(&) :: forall n . KnownNat n + => C n -> ℂ -> C (n+1) +u & x = u # (mkC (LA.scalar x) :: C 1) + +infixl 4 # +(#) :: forall n m . (KnownNat n, KnownNat m) + => C n -> C m -> C (n+m) +(C u) # (C v) = C (vconcat u v) + + + +vec2 :: ℂ -> ℂ -> C 2 +vec2 a b = C (gvec2 a b) + +vec3 :: ℂ -> ℂ -> ℂ -> C 3 +vec3 a b c = C (gvec3 a b c) + + +vec4 :: ℂ -> ℂ -> ℂ -> ℂ -> C 4 +vec4 a b c d = C (gvec4 a b c d) + +vect :: forall n . KnownNat n => [ℂ] -> C n +vect xs = C (gvect "C" xs) + diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index 5634031..424e766 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs @@ -11,13 +11,15 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE OverlappingInstances #-} +{-# LANGUAGE TypeFamilies #-} {- | Module : Numeric.LinearAlgebra.Real Copyright : (c) Alberto Ruiz 2006-14 License : BSD3 -Stability : provisional +Stability : experimental Experimental interface for real arrays with statically checked dimensions. @@ -26,165 +28,173 @@ Experimental interface for real arrays with statically checked dimensions. module Numeric.LinearAlgebra.Real( -- * Vector R, - vec2, vec3, vec4, 𝕧, (&), + vec2, vec3, vec4, (&), (#), + vect, + linspace, range, dim, -- * Matrix L, Sq, row, col, (¦),(——), - Konst(..), + unrow, uncol, + Sized(..), eye, - diagR, diag, + diagR, diag, Diag(..), blockAt, + mat, -- * Products (<>),(#>),(<·>), + -- * Linear Systems + linSolve, -- (<\>), -- * Pretty printing Disp(..), -- * Misc - Dim, unDim, + C, + withVector, withMatrix, module Numeric.HMatrix ) where import GHC.TypeLits -import Numeric.HMatrix hiding ((<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col) +import Numeric.HMatrix hiding ( + (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace,(<\>),fromList,takeDiag) import qualified Numeric.HMatrix as LA -import Data.Packed.ST import Data.Proxy(Proxy) +import Numeric.LinearAlgebra.Static +import Text.Printf + +instance forall n . KnownNat n => Show (R n) + where + show (ud1 -> v) + | singleV v = "("++show (v!0)++" :: R "++show d++")" + | otherwise = "(vect"++ drop 8 (show v)++" :: R "++show d++")" + where + d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int -newtype Dim (n :: Nat) t = Dim t - deriving Show -unDim :: Dim n t -> t -unDim (Dim x) = x +ud1 :: R n -> Vector ℝ +ud1 (R (Dim v)) = v --- data Proxy :: Nat -> * +mkR :: Vector ℝ -> R n +mkR = R . Dim -lift1F - :: (c t -> c t) - -> Dim n (c t) -> Dim n (c t) -lift1F f (Dim v) = Dim (f v) -lift2F - :: (c t -> c t -> c t) - -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) -lift2F f (Dim u) (Dim v) = Dim (f u v) +infixl 4 & +(&) :: forall n . KnownNat n + => R n -> ℝ -> R (n+1) +u & x = u # (konst x :: R 1) + +infixl 4 # +(#) :: forall n m . (KnownNat n, KnownNat m) + => R n -> R m -> R (n+m) +(R u) # (R v) = R (vconcat u v) -type R n = Dim n (Vector ℝ) +vec2 :: ℝ -> ℝ -> R 2 +vec2 a b = R (gvec2 a b) -type L m n = Dim m (Dim n (Matrix ℝ)) +vec3 :: ℝ -> ℝ -> ℝ -> R 3 +vec3 a b c = R (gvec3 a b c) -infixl 4 & -(&) :: forall n . KnownNat n - => R n -> ℝ -> R (n+1) -Dim v & x = Dim (vjoin [v', scalar x]) +vec4 :: ℝ -> ℝ -> ℝ -> ℝ -> R 4 +vec4 a b c d = R (gvec4 a b c d) + +vect :: forall n . KnownNat n => [ℝ] -> R n +vect xs = R (gvect "R" xs) + +linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n +linspace (a,b) = mkR (LA.linspace d (a,b)) where d = fromIntegral . natVal $ (undefined :: Proxy n) - v' | d > 1 && size v == 1 = LA.konst (v!0) d - | otherwise = v +range :: forall n . KnownNat n => R n +range = mkR (LA.linspace d (1,fromIntegral d)) + where + d = fromIntegral . natVal $ (undefined :: Proxy n) --- vect0 :: R 0 --- vect0 = Dim (fromList[]) +dim :: forall n . KnownNat n => R n +dim = mkR (scalar d) + where + d = fromIntegral . natVal $ (undefined :: Proxy n) -𝕧 :: ℝ -> R 1 -𝕧 = Dim . scalar +-------------------------------------------------------------------------------- -vec2 :: ℝ -> ℝ -> R 2 -vec2 a b = Dim $ runSTVector $ do - v <- newUndefinedVector 2 - writeVector v 0 a - writeVector v 1 b - return v +newtype L m n = L (Dim m (Dim n (Matrix ℝ))) + deriving (Num,Fractional) -vec3 :: ℝ -> ℝ -> ℝ -> R 3 -vec3 a b c = Dim $ runSTVector $ do - v <- newUndefinedVector 3 - writeVector v 0 a - writeVector v 1 b - writeVector v 2 c - return v +ud2 :: L m n -> Matrix ℝ +ud2 (L (Dim (Dim x))) = x -vec4 :: ℝ -> ℝ -> ℝ -> ℝ -> R 4 -vec4 a b c d = Dim $ runSTVector $ do - v <- newUndefinedVector 4 - writeVector v 0 a - writeVector v 1 b - writeVector v 2 c - writeVector v 3 d - return v +mkL :: Matrix ℝ -> L m n +mkL x = L (Dim (Dim x)) -instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) - where - (+) = lift2F (+) - (*) = lift2F (*) - (-) = lift2F (-) - abs = lift1F abs - signum = lift1F signum - negate = lift1F negate - fromInteger x = Dim (fromInteger x) - -instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) - where - (+) = (lift2F . lift2F) (+) - (*) = (lift2F . lift2F) (*) - (-) = (lift2F . lift2F) (-) - abs = (lift1F . lift1F) abs - signum = (lift1F . lift1F) signum - negate = (lift1F . lift1F) negate - fromInteger x = Dim (Dim (fromInteger x)) - -instance Fractional (Dim n (Vector Double)) - where - fromRational x = Dim (fromRational x) - (/) = lift2F (/) -instance Fractional (Dim m (Dim n (Matrix Double))) +instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) where - fromRational x = Dim (Dim (fromRational x)) - (/) = (lift2F.lift2F) (/) + show (ud2 -> x) + | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' + | isDiag = printf "(diag %s %s :: L %d %d)" (show z) shy m' n' + | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + isDiag = rows x == 1 && m' > 1 + v = flatten x + z = v!0 + y = subVector 1 (size v-1) v + shy = drop 9 (show y) -------------------------------------------------------------------------------- -class Konst t +instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) where - konst :: ℝ -> t + konst x = mkR (LA.scalar x) + extract = ud1 + fromList = vect + expand (extract -> v) + | singleV v = LA.konst (v!0) d + | otherwise = v + where + d = fromIntegral . natVal $ (undefined :: Proxy n) -instance forall n. KnownNat n => Konst (R n) - where - konst x = Dim (LA.konst x d) - where - d = fromIntegral . natVal $ (undefined :: Proxy n) -instance forall m n . (KnownNat m, KnownNat n) => Konst (L m n) +instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) where - konst x = Dim (Dim (LA.konst x (m',n'))) + konst x = mkL (LA.scalar x) + extract = ud2 + fromList = mat + expand (extract -> a) + | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') + | rows a == 1 && m'>1 = diagRect x y m' n' + | otherwise = a where m' = fromIntegral . natVal $ (undefined :: Proxy m) n' = fromIntegral . natVal $ (undefined :: Proxy n) + v = flatten a + x = v!0 + y = subVector 1 (size v -1) v -------------------------------------------------------------------------------- -diagR :: forall m n k . (KnownNat m, KnownNat n) => ℝ -> R k -> L m n -diagR x v = Dim (Dim (diagRect x (unDim v) m' n')) - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - n' = fromIntegral . natVal $ (undefined :: Proxy n) +diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n +diagR x v = mkL (asRow (vjoin [scalar x, expand v])) diag :: KnownNat n => R n -> Sq n diag = diagR 0 +eye :: KnownNat n => Sq n +eye = diag 1 + -------------------------------------------------------------------------------- blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n -blockAt x r c a = Dim (Dim res) +blockAt x r c a = mkL res where z = scalar x z1 = LA.konst x (r,c) @@ -196,117 +206,189 @@ blockAt x r c a = Dim (Dim res) n' = fromIntegral . natVal $ (undefined :: Proxy n) res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] -{- -matrix :: (KnownNat m, KnownNat n) => Matrix Double -> L n m -matrix = blockAt 0 0 0 --} + +mat :: forall m n . (KnownNat m, KnownNat n) => [ℝ] -> L m n +mat xs = L (gmat "L" xs) + -------------------------------------------------------------------------------- class Disp t where disp :: Int -> t -> IO () -instance Disp (L n m) + +instance (KnownNat m, KnownNat n) => Disp (L m n) where - disp n (d2 -> a) = do + disp n x = do + let a = expand x + let su = LA.dispf n a + printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) + +{- + disp n (ud2 -> a) = do if rows a == 1 && cols a == 1 then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a) else putStr "Dim " >> LA.disp n a +-} -instance Disp (R n) +instance KnownNat n => Disp (R n) where - disp n (unDim -> v) = do - let su = LA.dispf n (asRow v) - if LA.size v == 1 - then putStrLn $ "Const " ++ (last . words $ su ) - else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su) + disp n v = do + let su = LA.dispf n (asRow $ expand v) + putStr "R " >> putStr (tail . dropWhile (/='x') $ su) -------------------------------------------------------------------------------- -{- -infixl 3 # -(#) :: L r c -> R c -> L (r+1) c -Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v)) -𝕞 :: forall n . KnownNat n => L 0 n -𝕞 = Dim (Dim (LA.konst 0 (0,d))) - where - d = fromIntegral . natVal $ (undefined :: Proxy n) --} - row :: R n -> L 1 n -row (Dim v) = Dim (Dim (asRow v)) +row = mkL . asRow . ud1 col :: R n -> L n 1 col = tr . row -infixl 3 ¦ -(¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) -a ¦ b = rjoin (expk a) (expk b) - where - Dim (Dim a') `rjoin` Dim (Dim b') = Dim (Dim (a' LA.¦ b')) +unrow :: L 1 n -> R n +unrow = mkR . head . toRows . ud2 + +uncol :: L n 1 -> R n +uncol = unrow . tr + infixl 2 —— (——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c -a —— b = cjoin (expk a) (expk b) - where - Dim (Dim a') `cjoin` Dim (Dim b') = Dim (Dim (a' LA.—— b')) - -expk :: (KnownNat n, KnownNat m) => L m n -> L m n -expk x | singleton x = konst (d2 x `atIndex` (0,0)) - | otherwise = x - where - singleton (d2 -> m) = rows m == 1 && cols m == 1 +a —— b = mkL (expand a LA.—— expand b) -{- +infixl 3 ¦ +(¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) +a ¦ b = tr (tr a —— tr b) --} type Sq n = L n n type GL = (KnownNat n, KnownNat m) => L m n type GSq = KnownNat n => Sq n +isDiag0 :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (Vector ℝ) +isDiag0 (extract -> x) + | rows x == 1 && m' > 1 && z == 0 = Just y + | otherwise = Nothing + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + v = flatten x + z = v!0 + y = subVector 1 (size v-1) v + + infixr 8 <> -(<>) :: L m k -> L k n -> L m n -(d2 -> a) <> (d2 -> b) = Dim (Dim (a LA.<> b)) +(<>) :: (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n +a <> b = mkL (expand a LA.<> expand b) infixr 8 #> -(#>) :: L m n -> R n -> R m -(d2 -> m) #> (unDim -> v) = Dim (m LA.#> v) +(#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m +(isDiag0 -> Just w) #> v = mkR (w' * v') + where + v' = expand v + w' = subVector 0 (max 0 (size w - size v')) (vjoin [w , z]) + z = LA.konst 0 (max 0 (size v' - size w)) + +m #> v = mkR (expand m LA.#> expand v) infixr 8 <·> (<·>) :: R n -> R n -> ℝ -(unDim -> u) <·> (unDim -> v) = udot u v +(ud1 -> u) <·> (ud1 -> v) + | singleV u || singleV v = sumElements (u * v) + | otherwise = udot u v -d2 :: forall c (n :: Nat) (n1 :: Nat). Dim n1 (Dim n c) -> c -d2 = unDim . unDim +instance Transposable (L m n) (L n m) + where + tr (ud2 -> a) = mkL (tr a) +-------------------------------------------------------------------------------- +{- +class Minim (n :: Nat) (m :: Nat) + where + type Mini n m :: Nat -instance Transposable (L m n) (L n m) +instance forall (n :: Nat) . Minim n n where - tr (Dim (Dim a)) = Dim (Dim (tr a)) + type Mini n n = n -eye :: forall n . KnownNat n => Sq n -eye = Dim (Dim (ident d)) +instance forall (n :: Nat) (m :: Nat) . (n <= m+1) => Minim n m where - d = fromIntegral . natVal $ (undefined :: Proxy n) + type Mini n m = n + +instance forall (n :: Nat) (m :: Nat) . (m <= n+1) => Minim n m + where + type Mini n m = m +-} + +class Diag m d | m -> d + where + takeDiag :: m -> d + + + +instance forall n . (KnownNat n) => Diag (L n n) (R n) + where + takeDiag m = mkR (LA.takeDiag (expand m)) + + +instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m) + where + takeDiag m = mkR (LA.takeDiag (expand m)) + + +instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n) + where + takeDiag m = mkR (LA.takeDiag (expand m)) -------------------------------------------------------------------------------- +linSolve :: L m m -> L m n -> L m n +linSolve (ud2 -> a) (ud2 -> b) = mkL (LA.linearSolve a b) + +-------------------------------------------------------------------------------- + +withVector + :: forall z + . Vector ℝ + -> (forall n . (KnownNat n) => R n -> z) + -> z +withVector v f = + case someNatVal $ fromIntegral $ size v of + Nothing -> error "static/dynamic mismatch" + Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) + + +withMatrix + :: forall z + . Matrix ℝ + -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z) + -> z +withMatrix a f = + case someNatVal $ fromIntegral $ rows a of + Nothing -> error "static/dynamic mismatch" + Just (SomeNat (_ :: Proxy m)) -> + case someNatVal $ fromIntegral $ cols a of + Nothing -> error "static/dynamic mismatch" + Just (SomeNat (_ :: Proxy n)) -> + f (mkL a :: L n m) + +-------------------------------------------------------------------------------- + test :: (Bool, IO ()) test = (ok,info) where - ok = d2 (eye :: Sq 5) == ident 5 - && d2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] - && d2 (tm :: L 3 5) == mat 5 [1..15] + ok = expand (eye :: Sq 5) == ident 5 + && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] + && ud2 (tm :: L 3 5) == LA.mat 5 [1..15] && thingS == thingD && precS == precD + && withVector (LA.vect [1..15]) sumV == sumElements (LA.fromList [1..15]) info = do print $ u @@ -319,19 +401,24 @@ test = (ok,info) print thingD print precS print precD + print $ withVector (LA.vect [1..15]) sumV + + sumV w = w <·> konst 1 u = vec2 3 5 + 𝕧 x = vect [x] :: R 1 + v = 𝕧 2 & 4 & 7 - mTm :: L n m -> Sq m +-- mTm :: L n m -> Sq m mTm a = tr a <> a tm :: GL tm = lmat 0 [1..] lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n - lmat z xs = Dim . Dim . reshape n' . fromList . take (m'*n') $ xs ++ repeat z + lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z where m' = fromIntegral . natVal $ (undefined :: Proxy m) n' = fromIntegral . natVal $ (undefined :: Proxy n) @@ -343,12 +430,12 @@ test = (ok,info) where q = tm :: L 10 3 - thingD = vjoin [unDim u, 1] LA.<·> tr m LA.#> m LA.#> unDim v + thingD = vjoin [ud1 u, 1] LA.<·> tr m LA.#> m LA.#> ud1 v where - m = mat 3 [1..30] + m = LA.mat 3 [1..30] precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v - precD = 1 + 2 * vjoin[unDim u, 6] LA.<·> LA.konst 2 (size (unDim u) +1, size (unDim v)) LA.#> unDim v + precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v instance (KnownNat n', KnownNat m') => Testable (L n' m') diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs new file mode 100644 index 0000000..f9e935d --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs @@ -0,0 +1,193 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE EmptyDataDecls #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE GADTs #-} + + +{- | +Module : Numeric.LinearAlgebra.Static +Copyright : (c) Alberto Ruiz 2006-14 +License : BSD3 +Stability : provisional + +-} + +module Numeric.LinearAlgebra.Static( + Dim(..), + R(..), C(..), + lift1F, lift2F, + vconcat, gvec2, gvec3, gvec4, gvect, gmat, + Sized(..), + singleV, singleM +) where + + +import GHC.TypeLits +import Numeric.HMatrix as LA +import Data.Packed as D +import Data.Packed.ST +import Data.Proxy(Proxy) +import Foreign.Storable(Storable) + + + +newtype R n = R (Dim n (Vector ℝ)) + deriving (Num,Fractional) + + +newtype C n = C (Dim n (Vector ℂ)) + deriving (Num,Fractional) + + + +newtype Dim (n :: Nat) t = Dim t + deriving Show + +lift1F + :: (c t -> c t) + -> Dim n (c t) -> Dim n (c t) +lift1F f (Dim v) = Dim (f v) + +lift2F + :: (c t -> c t -> c t) + -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) +lift2F f (Dim u) (Dim v) = Dim (f u v) + +-------------------------------------------------------------------------------- + +instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) + where + (+) = lift2F (+) + (*) = lift2F (*) + (-) = lift2F (-) + abs = lift1F abs + signum = lift1F signum + negate = lift1F negate + fromInteger x = Dim (fromInteger x) + +instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) + where + (+) = (lift2F . lift2F) (+) + (*) = (lift2F . lift2F) (*) + (-) = (lift2F . lift2F) (-) + abs = (lift1F . lift1F) abs + signum = (lift1F . lift1F) signum + negate = (lift1F . lift1F) negate + fromInteger x = Dim (Dim (fromInteger x)) + +instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t)) + where + fromRational x = Dim (fromRational x) + (/) = lift2F (/) + +instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t))) + where + fromRational x = Dim (Dim (fromRational x)) + (/) = (lift2F.lift2F) (/) + +-------------------------------------------------------------------------------- + +type V n t = Dim n (Vector t) + +ud :: Dim n (Vector t) -> Vector t +ud (Dim v) = v + +mkV :: forall (n :: Nat) t . t -> Dim n t +mkV = Dim + +type M m n t = Dim m (Dim n (Matrix t)) + +ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t +ud2 (Dim (Dim m)) = m + +mkM :: forall (m :: Nat) (n :: Nat) t . t -> Dim m (Dim n t) +mkM = Dim . Dim + + +vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) + => V n t -> V m t -> V (n+m) t +(ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v']) + where + du = fromIntegral . natVal $ (undefined :: Proxy n) + dv = fromIntegral . natVal $ (undefined :: Proxy m) + u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du + | otherwise = u + v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv + | otherwise = v + + +gvec2 :: Storable t => t -> t -> V 2 t +gvec2 a b = mkV $ runSTVector $ do + v <- newUndefinedVector 2 + writeVector v 0 a + writeVector v 1 b + return v + +gvec3 :: Storable t => t -> t -> t -> V 3 t +gvec3 a b c = mkV $ runSTVector $ do + v <- newUndefinedVector 3 + writeVector v 0 a + writeVector v 1 b + writeVector v 2 c + return v + + +gvec4 :: Storable t => t -> t -> t -> t -> V 4 t +gvec4 a b c d = mkV $ runSTVector $ do + v <- newUndefinedVector 4 + writeVector v 0 a + writeVector v 1 b + writeVector v 2 c + writeVector v 3 d + return v + + +gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t +gvect st xs' + | ok = mkV v + | not (null rest) && null (tail rest) = abort (show xs') + | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") + | otherwise = abort (show xs) + where + (xs,rest) = splitAt d xs' + ok = size v == d && null rest + v = LA.fromList xs + d = fromIntegral . natVal $ (undefined :: Proxy n) + abort info = error $ st++" "++show d++" can't be created from elements "++info + + +gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> M m n t +gmat st xs' + | ok = mkM x + | not (null rest) && null (tail rest) = abort (show xs') + | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") + | otherwise = abort (show xs) + where + (xs,rest) = splitAt (m'*n') xs' + v = LA.fromList xs + x = reshape n' v + ok = rem (size v) n' == 0 && size x == (m',n') && null rest + m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info + + +class Num t => Sized t s d | s -> t, s -> d + where + konst :: t -> s + extract :: s -> d + fromList :: [t] -> s + expand :: s -> d + +singleV v = size v == 1 +singleM m = rows m == 1 && cols m == 1 + -- cgit v1.2.3