From 0fe05d4f692ed3e199ce2ad3b64df6f756f8807a Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Fri, 5 Jun 2015 16:32:59 +0200 Subject: move internal static --- packages/base/src/Internal/Static.hs | 524 +++++++++++++++++++++ .../src/Numeric/LinearAlgebra/Static/Internal.hs | 524 --------------------- 2 files changed, 524 insertions(+), 524 deletions(-) create mode 100644 packages/base/src/Internal/Static.hs delete mode 100644 packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs (limited to 'packages/base/src') diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs new file mode 100644 index 0000000..48327e5 --- /dev/null +++ b/packages/base/src/Internal/Static.hs @@ -0,0 +1,524 @@ +#if __GLASGOW_HASKELL__ >= 708 + +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE Rank2Types #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} + +{- | +Module : Internal.Static +Copyright : (c) Alberto Ruiz 2006-14 +License : BSD3 +Stability : provisional + +-} + +module Internal.Static where + + +import GHC.TypeLits +import qualified Numeric.LinearAlgebra as LA +import Numeric.LinearAlgebra hiding (konst,size) +import Internal.Vector as D +import Internal.ST +import Data.Proxy(Proxy) +import Foreign.Storable(Storable) +import Text.Printf + +-------------------------------------------------------------------------------- + +newtype Dim (n :: Nat) t = Dim t + deriving Show + +lift1F + :: (c t -> c t) + -> Dim n (c t) -> Dim n (c t) +lift1F f (Dim v) = Dim (f v) + +lift2F + :: (c t -> c t -> c t) + -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) +lift2F f (Dim u) (Dim v) = Dim (f u v) + +-------------------------------------------------------------------------------- + +newtype R n = R (Dim n (Vector ℝ)) + deriving (Num,Fractional,Floating) + +newtype C n = C (Dim n (Vector ℂ)) + deriving (Num,Fractional,Floating) + +newtype L m n = L (Dim m (Dim n (Matrix ℝ))) + +newtype M m n = M (Dim m (Dim n (Matrix ℂ))) + + +mkR :: Vector ℝ -> R n +mkR = R . Dim + +mkC :: Vector ℂ -> C n +mkC = C . Dim + +mkL :: Matrix ℝ -> L m n +mkL x = L (Dim (Dim x)) + +mkM :: Matrix ℂ -> M m n +mkM x = M (Dim (Dim x)) + +-------------------------------------------------------------------------------- + +type V n t = Dim n (Vector t) + +ud :: Dim n (Vector t) -> Vector t +ud (Dim v) = v + +mkV :: forall (n :: Nat) t . t -> Dim n t +mkV = Dim + + +vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) + => V n t -> V m t -> V (n+m) t +(ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v']) + where + du = fromIntegral . natVal $ (undefined :: Proxy n) + dv = fromIntegral . natVal $ (undefined :: Proxy m) + u' | du > 1 && LA.size u == 1 = LA.konst (u D.@> 0) du + | otherwise = u + v' | dv > 1 && LA.size v == 1 = LA.konst (v D.@> 0) dv + | otherwise = v + + +gvec2 :: Storable t => t -> t -> V 2 t +gvec2 a b = mkV $ runSTVector $ do + v <- newUndefinedVector 2 + writeVector v 0 a + writeVector v 1 b + return v + +gvec3 :: Storable t => t -> t -> t -> V 3 t +gvec3 a b c = mkV $ runSTVector $ do + v <- newUndefinedVector 3 + writeVector v 0 a + writeVector v 1 b + writeVector v 2 c + return v + + +gvec4 :: Storable t => t -> t -> t -> t -> V 4 t +gvec4 a b c d = mkV $ runSTVector $ do + v <- newUndefinedVector 4 + writeVector v 0 a + writeVector v 1 b + writeVector v 2 c + writeVector v 3 d + return v + + +gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t +gvect st xs' + | ok = mkV v + | not (null rest) && null (tail rest) = abort (show xs') + | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") + | otherwise = abort (show xs) + where + (xs,rest) = splitAt d xs' + ok = LA.size v == d && null rest + v = LA.fromList xs + d = fromIntegral . natVal $ (undefined :: Proxy n) + abort info = error $ st++" "++show d++" can't be created from elements "++info + + +-------------------------------------------------------------------------------- + +type GM m n t = Dim m (Dim n (Matrix t)) + + +gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t +gmat st xs' + | ok = Dim (Dim x) + | not (null rest) && null (tail rest) = abort (show xs') + | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") + | otherwise = abort (show xs) + where + (xs,rest) = splitAt (m'*n') xs' + v = LA.fromList xs + x = reshape n' v + ok = null rest && ((n' == 0 && dim v == 0) || n'> 0 && (rem (LA.size v) n' == 0) && LA.size x == (m',n')) + m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info + +-------------------------------------------------------------------------------- + +class Num t => Sized t s d | s -> t, s -> d + where + konst :: t -> s + unwrap :: s -> d t + fromList :: [t] -> s + extract :: s -> d t + create :: d t -> Maybe s + size :: s -> IndexOf d + +singleV v = LA.size v == 1 +singleM m = rows m == 1 && cols m == 1 + + +instance forall n. KnownNat n => Sized ℂ (C n) Vector + where + size _ = fromIntegral . natVal $ (undefined :: Proxy n) + konst x = mkC (LA.scalar x) + unwrap (C (Dim v)) = v + fromList xs = C (gvect "C" xs) + extract s@(unwrap -> v) + | singleV v = LA.konst (v!0) (size s) + | otherwise = v + create v + | LA.size v == size r = Just r + | otherwise = Nothing + where + r = mkC v :: C n + + +instance forall n. KnownNat n => Sized ℝ (R n) Vector + where + size _ = fromIntegral . natVal $ (undefined :: Proxy n) + konst x = mkR (LA.scalar x) + unwrap (R (Dim v)) = v + fromList xs = R (gvect "R" xs) + extract s@(unwrap -> v) + | singleV v = LA.konst (v!0) (size s) + | otherwise = v + create v + | LA.size v == size r = Just r + | otherwise = Nothing + where + r = mkR v :: R n + + + +instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) Matrix + where + size _ = ((fromIntegral . natVal) (undefined :: Proxy m) + ,(fromIntegral . natVal) (undefined :: Proxy n)) + konst x = mkL (LA.scalar x) + fromList xs = L (gmat "L" xs) + unwrap (L (Dim (Dim m))) = m + extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' + extract s@(unwrap -> a) + | singleM a = LA.konst (a `atIndex` (0,0)) (size s) + | otherwise = a + create x + | LA.size x == size r = Just r + | otherwise = Nothing + where + r = mkL x :: L m n + + +instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) Matrix + where + size _ = ((fromIntegral . natVal) (undefined :: Proxy m) + ,(fromIntegral . natVal) (undefined :: Proxy n)) + konst x = mkM (LA.scalar x) + fromList xs = M (gmat "M" xs) + unwrap (M (Dim (Dim m))) = m + extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' + extract s@(unwrap -> a) + | singleM a = LA.konst (a `atIndex` (0,0)) (size s) + | otherwise = a + create x + | LA.size x == size r = Just r + | otherwise = Nothing + where + r = mkM x :: M m n + +-------------------------------------------------------------------------------- + +instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m) + where + tr a@(isDiag -> Just _) = mkL (extract a) + tr (extract -> a) = mkL (tr a) + tr' = tr + +instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m) + where + tr a@(isDiagC -> Just _) = mkM (extract a) + tr (extract -> a) = mkM (tr a) + tr' a@(isDiagC -> Just _) = mkM (extract a) + tr' (extract -> a) = mkM (tr' a) + +-------------------------------------------------------------------------------- + +isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) +isDiag (L x) = isDiagg x + +isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int)) +isDiagC (M x) = isDiagg x + + +isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int)) +isDiagg (Dim (Dim x)) + | singleM x = Nothing + | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) + | otherwise = Nothing + where + m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int + n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + v = flatten x + z = v `atIndex` 0 + y = subVector 1 (LA.size v-1) v + ny = LA.size y + zeros = LA.konst 0 (max 0 (min m' n' - ny)) + yz = vjoin [y,zeros] + +-------------------------------------------------------------------------------- + +instance forall n . KnownNat n => Show (R n) + where + show s@(R (Dim v)) + | singleV v = "("++show (v!0)++" :: R "++show d++")" + | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" + where + d = size s + +instance forall n . KnownNat n => Show (C n) + where + show s@(C (Dim v)) + | singleV v = "("++show (v!0)++" :: C "++show d++")" + | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" + where + d = size s + +instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) + where + show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' + show s@(L (Dim (Dim x))) + | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' + | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" + where + (m',n') = size s + +instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) + where + show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' + show s@(M (Dim (Dim x))) + | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' + | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" + where + (m',n') = size s + +-------------------------------------------------------------------------------- + +instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) + where + (+) = lift2F (+) + (*) = lift2F (*) + (-) = lift2F (-) + abs = lift1F abs + signum = lift1F signum + negate = lift1F negate + fromInteger x = Dim (fromInteger x) + +instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim n (Vector t)) + where + fromRational x = Dim (fromRational x) + (/) = lift2F (/) + +instance (Fractional t, Floating (Vector t), Numeric t) => Floating (Dim n (Vector t)) where + sin = lift1F sin + cos = lift1F cos + tan = lift1F tan + asin = lift1F asin + acos = lift1F acos + atan = lift1F atan + sinh = lift1F sinh + cosh = lift1F cosh + tanh = lift1F tanh + asinh = lift1F asinh + acosh = lift1F acosh + atanh = lift1F atanh + exp = lift1F exp + log = lift1F log + sqrt = lift1F sqrt + (**) = lift2F (**) + pi = Dim pi + + +instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) + where + (+) = (lift2F . lift2F) (+) + (*) = (lift2F . lift2F) (*) + (-) = (lift2F . lift2F) (-) + abs = (lift1F . lift1F) abs + signum = (lift1F . lift1F) signum + negate = (lift1F . lift1F) negate + fromInteger x = Dim (Dim (fromInteger x)) + +instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim m (Dim n (Matrix t))) + where + fromRational x = Dim (Dim (fromRational x)) + (/) = (lift2F.lift2F) (/) + +instance (Num (Vector t), Floating (Matrix t), Fractional t, Numeric t) => Floating (Dim m (Dim n (Matrix t))) where + sin = (lift1F . lift1F) sin + cos = (lift1F . lift1F) cos + tan = (lift1F . lift1F) tan + asin = (lift1F . lift1F) asin + acos = (lift1F . lift1F) acos + atan = (lift1F . lift1F) atan + sinh = (lift1F . lift1F) sinh + cosh = (lift1F . lift1F) cosh + tanh = (lift1F . lift1F) tanh + asinh = (lift1F . lift1F) asinh + acosh = (lift1F . lift1F) acosh + atanh = (lift1F . lift1F) atanh + exp = (lift1F . lift1F) exp + log = (lift1F . lift1F) log + sqrt = (lift1F . lift1F) sqrt + (**) = (lift2F . lift2F) (**) + pi = Dim (Dim pi) + +-------------------------------------------------------------------------------- + + +adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b +adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b)) +adaptDiag f a b = f a b + +isFull m = isDiag m == Nothing && not (singleM (unwrap m)) + + +lift1L f (L v) = L (f v) +lift2L f (L a) (L b) = L (f a b) +lift2LD f = adaptDiag (lift2L f) + + +instance (KnownNat n, KnownNat m) => Num (L n m) + where + (+) = lift2LD (+) + (*) = lift2LD (*) + (-) = lift2LD (-) + abs = lift1L abs + signum = lift1L signum + negate = lift1L negate + fromInteger = L . Dim . Dim . fromInteger + +instance (KnownNat n, KnownNat m) => Fractional (L n m) + where + fromRational = L . Dim . Dim . fromRational + (/) = lift2LD (/) + +instance (KnownNat n, KnownNat m) => Floating (L n m) where + sin = lift1L sin + cos = lift1L cos + tan = lift1L tan + asin = lift1L asin + acos = lift1L acos + atan = lift1L atan + sinh = lift1L sinh + cosh = lift1L cosh + tanh = lift1L tanh + asinh = lift1L asinh + acosh = lift1L acosh + atanh = lift1L atanh + exp = lift1L exp + log = lift1L log + sqrt = lift1L sqrt + (**) = lift2LD (**) + pi = konst pi + +-------------------------------------------------------------------------------- + +adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b +adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b)) +adaptDiagC f a b = f a b + +isFullC m = isDiagC m == Nothing && not (singleM (unwrap m)) + +lift1M f (M v) = M (f v) +lift2M f (M a) (M b) = M (f a b) +lift2MD f = adaptDiagC (lift2M f) + +instance (KnownNat n, KnownNat m) => Num (M n m) + where + (+) = lift2MD (+) + (*) = lift2MD (*) + (-) = lift2MD (-) + abs = lift1M abs + signum = lift1M signum + negate = lift1M negate + fromInteger = M . Dim . Dim . fromInteger + +instance (KnownNat n, KnownNat m) => Fractional (M n m) + where + fromRational = M . Dim . Dim . fromRational + (/) = lift2MD (/) + +instance (KnownNat n, KnownNat m) => Floating (M n m) where + sin = lift1M sin + cos = lift1M cos + tan = lift1M tan + asin = lift1M asin + acos = lift1M acos + atan = lift1M atan + sinh = lift1M sinh + cosh = lift1M cosh + tanh = lift1M tanh + asinh = lift1M asinh + acosh = lift1M acosh + atanh = lift1M atanh + exp = lift1M exp + log = lift1M log + sqrt = lift1M sqrt + (**) = lift2MD (**) + pi = M pi + +-------------------------------------------------------------------------------- + + +class Disp t + where + disp :: Int -> t -> IO () + + +instance (KnownNat m, KnownNat n) => Disp (L m n) + where + disp n x = do + let a = extract x + let su = LA.dispf n a + printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) + +instance (KnownNat m, KnownNat n) => Disp (M m n) + where + disp n x = do + let a = extract x + let su = LA.dispcf n a + printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) + + +instance KnownNat n => Disp (R n) + where + disp n v = do + let su = LA.dispf n (asRow $ extract v) + putStr "R " >> putStr (tail . dropWhile (/='x') $ su) + +instance KnownNat n => Disp (C n) + where + disp n v = do + let su = LA.dispcf n (asRow $ extract v) + putStr "C " >> putStr (tail . dropWhile (/='x') $ su) + +-------------------------------------------------------------------------------- + +#else + +module Numeric.LinearAlgebra.Static.Internal where + +#endif + diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs deleted file mode 100644 index 7b770e0..0000000 --- a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs +++ /dev/null @@ -1,524 +0,0 @@ -#if __GLASGOW_HASKELL__ >= 708 - -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE FunctionalDependencies #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE Rank2Types #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} - -{- | -Module : Numeric.LinearAlgebra.Static.Internal -Copyright : (c) Alberto Ruiz 2006-14 -License : BSD3 -Stability : provisional - --} - -module Numeric.LinearAlgebra.Static.Internal where - - -import GHC.TypeLits -import qualified Numeric.LinearAlgebra as LA -import Numeric.LinearAlgebra hiding (konst,size) -import Data.Packed as D -import Data.Packed.ST -import Data.Proxy(Proxy) -import Foreign.Storable(Storable) -import Text.Printf - --------------------------------------------------------------------------------- - -newtype Dim (n :: Nat) t = Dim t - deriving Show - -lift1F - :: (c t -> c t) - -> Dim n (c t) -> Dim n (c t) -lift1F f (Dim v) = Dim (f v) - -lift2F - :: (c t -> c t -> c t) - -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) -lift2F f (Dim u) (Dim v) = Dim (f u v) - --------------------------------------------------------------------------------- - -newtype R n = R (Dim n (Vector ℝ)) - deriving (Num,Fractional,Floating) - -newtype C n = C (Dim n (Vector ℂ)) - deriving (Num,Fractional,Floating) - -newtype L m n = L (Dim m (Dim n (Matrix ℝ))) - -newtype M m n = M (Dim m (Dim n (Matrix ℂ))) - - -mkR :: Vector ℝ -> R n -mkR = R . Dim - -mkC :: Vector ℂ -> C n -mkC = C . Dim - -mkL :: Matrix ℝ -> L m n -mkL x = L (Dim (Dim x)) - -mkM :: Matrix ℂ -> M m n -mkM x = M (Dim (Dim x)) - --------------------------------------------------------------------------------- - -type V n t = Dim n (Vector t) - -ud :: Dim n (Vector t) -> Vector t -ud (Dim v) = v - -mkV :: forall (n :: Nat) t . t -> Dim n t -mkV = Dim - - -vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) - => V n t -> V m t -> V (n+m) t -(ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v']) - where - du = fromIntegral . natVal $ (undefined :: Proxy n) - dv = fromIntegral . natVal $ (undefined :: Proxy m) - u' | du > 1 && LA.size u == 1 = LA.konst (u D.@> 0) du - | otherwise = u - v' | dv > 1 && LA.size v == 1 = LA.konst (v D.@> 0) dv - | otherwise = v - - -gvec2 :: Storable t => t -> t -> V 2 t -gvec2 a b = mkV $ runSTVector $ do - v <- newUndefinedVector 2 - writeVector v 0 a - writeVector v 1 b - return v - -gvec3 :: Storable t => t -> t -> t -> V 3 t -gvec3 a b c = mkV $ runSTVector $ do - v <- newUndefinedVector 3 - writeVector v 0 a - writeVector v 1 b - writeVector v 2 c - return v - - -gvec4 :: Storable t => t -> t -> t -> t -> V 4 t -gvec4 a b c d = mkV $ runSTVector $ do - v <- newUndefinedVector 4 - writeVector v 0 a - writeVector v 1 b - writeVector v 2 c - writeVector v 3 d - return v - - -gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t -gvect st xs' - | ok = mkV v - | not (null rest) && null (tail rest) = abort (show xs') - | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") - | otherwise = abort (show xs) - where - (xs,rest) = splitAt d xs' - ok = LA.size v == d && null rest - v = LA.fromList xs - d = fromIntegral . natVal $ (undefined :: Proxy n) - abort info = error $ st++" "++show d++" can't be created from elements "++info - - --------------------------------------------------------------------------------- - -type GM m n t = Dim m (Dim n (Matrix t)) - - -gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t -gmat st xs' - | ok = Dim (Dim x) - | not (null rest) && null (tail rest) = abort (show xs') - | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") - | otherwise = abort (show xs) - where - (xs,rest) = splitAt (m'*n') xs' - v = LA.fromList xs - x = reshape n' v - ok = null rest && ((n' == 0 && dim v == 0) || n'> 0 && (rem (LA.size v) n' == 0) && LA.size x == (m',n')) - m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int - n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int - abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info - --------------------------------------------------------------------------------- - -class Num t => Sized t s d | s -> t, s -> d - where - konst :: t -> s - unwrap :: s -> d t - fromList :: [t] -> s - extract :: s -> d t - create :: d t -> Maybe s - size :: s -> IndexOf d - -singleV v = LA.size v == 1 -singleM m = rows m == 1 && cols m == 1 - - -instance forall n. KnownNat n => Sized ℂ (C n) Vector - where - size _ = fromIntegral . natVal $ (undefined :: Proxy n) - konst x = mkC (LA.scalar x) - unwrap (C (Dim v)) = v - fromList xs = C (gvect "C" xs) - extract s@(unwrap -> v) - | singleV v = LA.konst (v!0) (size s) - | otherwise = v - create v - | LA.size v == size r = Just r - | otherwise = Nothing - where - r = mkC v :: C n - - -instance forall n. KnownNat n => Sized ℝ (R n) Vector - where - size _ = fromIntegral . natVal $ (undefined :: Proxy n) - konst x = mkR (LA.scalar x) - unwrap (R (Dim v)) = v - fromList xs = R (gvect "R" xs) - extract s@(unwrap -> v) - | singleV v = LA.konst (v!0) (size s) - | otherwise = v - create v - | LA.size v == size r = Just r - | otherwise = Nothing - where - r = mkR v :: R n - - - -instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) Matrix - where - size _ = ((fromIntegral . natVal) (undefined :: Proxy m) - ,(fromIntegral . natVal) (undefined :: Proxy n)) - konst x = mkL (LA.scalar x) - fromList xs = L (gmat "L" xs) - unwrap (L (Dim (Dim m))) = m - extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' - extract s@(unwrap -> a) - | singleM a = LA.konst (a `atIndex` (0,0)) (size s) - | otherwise = a - create x - | LA.size x == size r = Just r - | otherwise = Nothing - where - r = mkL x :: L m n - - -instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) Matrix - where - size _ = ((fromIntegral . natVal) (undefined :: Proxy m) - ,(fromIntegral . natVal) (undefined :: Proxy n)) - konst x = mkM (LA.scalar x) - fromList xs = M (gmat "M" xs) - unwrap (M (Dim (Dim m))) = m - extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' - extract s@(unwrap -> a) - | singleM a = LA.konst (a `atIndex` (0,0)) (size s) - | otherwise = a - create x - | LA.size x == size r = Just r - | otherwise = Nothing - where - r = mkM x :: M m n - --------------------------------------------------------------------------------- - -instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m) - where - tr a@(isDiag -> Just _) = mkL (extract a) - tr (extract -> a) = mkL (tr a) - tr' = tr - -instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m) - where - tr a@(isDiagC -> Just _) = mkM (extract a) - tr (extract -> a) = mkM (tr a) - tr' a@(isDiagC -> Just _) = mkM (extract a) - tr' (extract -> a) = mkM (tr' a) - --------------------------------------------------------------------------------- - -isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) -isDiag (L x) = isDiagg x - -isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int)) -isDiagC (M x) = isDiagg x - - -isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int)) -isDiagg (Dim (Dim x)) - | singleM x = Nothing - | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) - | otherwise = Nothing - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int - n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int - v = flatten x - z = v `atIndex` 0 - y = subVector 1 (LA.size v-1) v - ny = LA.size y - zeros = LA.konst 0 (max 0 (min m' n' - ny)) - yz = vjoin [y,zeros] - --------------------------------------------------------------------------------- - -instance forall n . KnownNat n => Show (R n) - where - show s@(R (Dim v)) - | singleV v = "("++show (v!0)++" :: R "++show d++")" - | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" - where - d = size s - -instance forall n . KnownNat n => Show (C n) - where - show s@(C (Dim v)) - | singleV v = "("++show (v!0)++" :: C "++show d++")" - | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" - where - d = size s - -instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) - where - show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' - show s@(L (Dim (Dim x))) - | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' - | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" - where - (m',n') = size s - -instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) - where - show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' - show s@(M (Dim (Dim x))) - | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' - | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" - where - (m',n') = size s - --------------------------------------------------------------------------------- - -instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) - where - (+) = lift2F (+) - (*) = lift2F (*) - (-) = lift2F (-) - abs = lift1F abs - signum = lift1F signum - negate = lift1F negate - fromInteger x = Dim (fromInteger x) - -instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim n (Vector t)) - where - fromRational x = Dim (fromRational x) - (/) = lift2F (/) - -instance (Fractional t, Floating (Vector t), Numeric t) => Floating (Dim n (Vector t)) where - sin = lift1F sin - cos = lift1F cos - tan = lift1F tan - asin = lift1F asin - acos = lift1F acos - atan = lift1F atan - sinh = lift1F sinh - cosh = lift1F cosh - tanh = lift1F tanh - asinh = lift1F asinh - acosh = lift1F acosh - atanh = lift1F atanh - exp = lift1F exp - log = lift1F log - sqrt = lift1F sqrt - (**) = lift2F (**) - pi = Dim pi - - -instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) - where - (+) = (lift2F . lift2F) (+) - (*) = (lift2F . lift2F) (*) - (-) = (lift2F . lift2F) (-) - abs = (lift1F . lift1F) abs - signum = (lift1F . lift1F) signum - negate = (lift1F . lift1F) negate - fromInteger x = Dim (Dim (fromInteger x)) - -instance (Num (Vector t), Num (Matrix t), Fractional t, Numeric t) => Fractional (Dim m (Dim n (Matrix t))) - where - fromRational x = Dim (Dim (fromRational x)) - (/) = (lift2F.lift2F) (/) - -instance (Num (Vector t), Floating (Matrix t), Fractional t, Numeric t) => Floating (Dim m (Dim n (Matrix t))) where - sin = (lift1F . lift1F) sin - cos = (lift1F . lift1F) cos - tan = (lift1F . lift1F) tan - asin = (lift1F . lift1F) asin - acos = (lift1F . lift1F) acos - atan = (lift1F . lift1F) atan - sinh = (lift1F . lift1F) sinh - cosh = (lift1F . lift1F) cosh - tanh = (lift1F . lift1F) tanh - asinh = (lift1F . lift1F) asinh - acosh = (lift1F . lift1F) acosh - atanh = (lift1F . lift1F) atanh - exp = (lift1F . lift1F) exp - log = (lift1F . lift1F) log - sqrt = (lift1F . lift1F) sqrt - (**) = (lift2F . lift2F) (**) - pi = Dim (Dim pi) - --------------------------------------------------------------------------------- - - -adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b -adaptDiag f a b@(isDiag -> Just _) | isFull a = f a (mkL (extract b)) -adaptDiag f a b = f a b - -isFull m = isDiag m == Nothing && not (singleM (unwrap m)) - - -lift1L f (L v) = L (f v) -lift2L f (L a) (L b) = L (f a b) -lift2LD f = adaptDiag (lift2L f) - - -instance (KnownNat n, KnownNat m) => Num (L n m) - where - (+) = lift2LD (+) - (*) = lift2LD (*) - (-) = lift2LD (-) - abs = lift1L abs - signum = lift1L signum - negate = lift1L negate - fromInteger = L . Dim . Dim . fromInteger - -instance (KnownNat n, KnownNat m) => Fractional (L n m) - where - fromRational = L . Dim . Dim . fromRational - (/) = lift2LD (/) - -instance (KnownNat n, KnownNat m) => Floating (L n m) where - sin = lift1L sin - cos = lift1L cos - tan = lift1L tan - asin = lift1L asin - acos = lift1L acos - atan = lift1L atan - sinh = lift1L sinh - cosh = lift1L cosh - tanh = lift1L tanh - asinh = lift1L asinh - acosh = lift1L acosh - atanh = lift1L atanh - exp = lift1L exp - log = lift1L log - sqrt = lift1L sqrt - (**) = lift2LD (**) - pi = konst pi - --------------------------------------------------------------------------------- - -adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b -adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b)) -adaptDiagC f a b = f a b - -isFullC m = isDiagC m == Nothing && not (singleM (unwrap m)) - -lift1M f (M v) = M (f v) -lift2M f (M a) (M b) = M (f a b) -lift2MD f = adaptDiagC (lift2M f) - -instance (KnownNat n, KnownNat m) => Num (M n m) - where - (+) = lift2MD (+) - (*) = lift2MD (*) - (-) = lift2MD (-) - abs = lift1M abs - signum = lift1M signum - negate = lift1M negate - fromInteger = M . Dim . Dim . fromInteger - -instance (KnownNat n, KnownNat m) => Fractional (M n m) - where - fromRational = M . Dim . Dim . fromRational - (/) = lift2MD (/) - -instance (KnownNat n, KnownNat m) => Floating (M n m) where - sin = lift1M sin - cos = lift1M cos - tan = lift1M tan - asin = lift1M asin - acos = lift1M acos - atan = lift1M atan - sinh = lift1M sinh - cosh = lift1M cosh - tanh = lift1M tanh - asinh = lift1M asinh - acosh = lift1M acosh - atanh = lift1M atanh - exp = lift1M exp - log = lift1M log - sqrt = lift1M sqrt - (**) = lift2MD (**) - pi = M pi - --------------------------------------------------------------------------------- - - -class Disp t - where - disp :: Int -> t -> IO () - - -instance (KnownNat m, KnownNat n) => Disp (L m n) - where - disp n x = do - let a = extract x - let su = LA.dispf n a - printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) - -instance (KnownNat m, KnownNat n) => Disp (M m n) - where - disp n x = do - let a = extract x - let su = LA.dispcf n a - printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) - - -instance KnownNat n => Disp (R n) - where - disp n v = do - let su = LA.dispf n (asRow $ extract v) - putStr "R " >> putStr (tail . dropWhile (/='x') $ su) - -instance KnownNat n => Disp (C n) - where - disp n v = do - let su = LA.dispcf n (asRow $ extract v) - putStr "C " >> putStr (tail . dropWhile (/='x') $ su) - --------------------------------------------------------------------------------- - -#else - -module Numeric.LinearAlgebra.Static.Internal where - -#endif - -- cgit v1.2.3