From de0219353ca9631135a3f750cef05b9636bef232 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Thu, 24 Apr 2014 13:17:55 +0200 Subject: konst with bidirectional type inference --- lib/Data/Packed/Random.hs | 4 +-- lib/Numeric/Container.hs | 18 ++++++++++-- lib/Numeric/ContainerBoot.hs | 40 ++++++++++++++------------- lib/Numeric/LinearAlgebra/Algorithms.hs | 6 ++-- lib/Numeric/LinearAlgebra/Util/Convolution.hs | 10 +++---- 5 files changed, 47 insertions(+), 31 deletions(-) diff --git a/lib/Data/Packed/Random.hs b/lib/Data/Packed/Random.hs index dabb17d..e8b0268 100644 --- a/lib/Data/Packed/Random.hs +++ b/lib/Data/Packed/Random.hs @@ -36,7 +36,7 @@ gaussianSample :: Seed -> Matrix Double -- ^ result gaussianSample seed n med cov = m where c = dim med - meds = konst 1 n `outer` med + meds = konst' 1 n `outer` med rs = reshape c $ randomVector seed Gaussian (c * n) m = rs `mXm` cholSH cov `add` meds @@ -52,6 +52,6 @@ uniformSample seed n rgs = m where cs = zipWith subtract as bs d = dim a dat = toRows $ reshape n $ randomVector seed Uniform (n*d) - am = konst 1 n `outer` a + am = konst' 1 n `outer` a m = fromColumns (zipWith scale cs dat) `add` am diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs index ed6714f..a31acfe 100644 --- a/lib/Numeric/Container.hs +++ b/lib/Numeric/Container.hs @@ -28,6 +28,7 @@ module Numeric.Container ( -- * Basic functions module Data.Packed, + konst, -- build, constant, linspace, diag, ident, ctrans, @@ -59,8 +60,6 @@ module Numeric.Container ( loadMatrix, saveMatrix, fromFile, fileDimensions, readMatrix, fscanfVector, fprintfVector, freadVector, fwriteVector, - -- * Experimental - build', konst' ) where import Data.Packed @@ -174,3 +173,18 @@ instance Container Matrix t => Contraction t (Matrix t) (Matrix t) where instance Container Matrix t => Contraction (Matrix t) t (Matrix t) where (×) = flip scale +-------------------------------------------------------------------------------- + +-- bidirectional type inference +class Konst e d c | d -> c, c -> d + where + konst :: e -> d -> c e + +instance Container Vector e => Konst e Int Vector + where + konst = konst' + +instance Container Vector e => Konst e (Int,Int) Matrix + where + konst = konst' + diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index 4c5bbd0..8707473 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs @@ -36,9 +36,7 @@ module Numeric.ContainerBoot ( RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf, - module Data.Complex, - -- * Experimental - build', konst' + module Data.Complex ) where import Data.Packed @@ -91,15 +89,13 @@ class (Complexable c, Fractional e, Element e) => Container c e where -- | cannot implement instance Functor because of Element class constraint cmap :: (Element b) => (e -> b) -> c e -> c b -- | constant structure of given size - konst :: e -> IndexOf c -> c e + konst' :: e -> IndexOf c -> c e -- | create a structure using a function -- -- Hilbert matrix of order N: -- - -- @hilb n = build (n,n) (\\i j -> 1/(i+j+1))@ - build :: IndexOf c -> (ArgOf c e) -> c e - --build :: BoundsOf f -> f -> (ContainerOf f) e - -- + -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@ + build' :: IndexOf c -> (ArgOf c e) -> c e -- | indexing function atIndex :: c e -> IndexOf c -> e -- | index of min element @@ -186,8 +182,8 @@ instance Container Vector Float where equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 arctan2 = vectorZipF ATan2 scalar x = fromList [x] - konst = constantD - build = buildV + konst' = constantD + build' = buildV conj = id cmap = mapVector atIndex = (@>) @@ -214,8 +210,8 @@ instance Container Vector Double where equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 arctan2 = vectorZipR ATan2 scalar x = fromList [x] - konst = constantD - build = buildV + konst' = constantD + build' = buildV conj = id cmap = mapVector atIndex = (@>) @@ -242,8 +238,8 @@ instance Container Vector (Complex Double) where equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 arctan2 = vectorZipC ATan2 scalar x = fromList [x] - konst = constantD - build = buildV + konst' = constantD + build' = buildV conj = conjugateC cmap = mapVector atIndex = (@>) @@ -270,8 +266,8 @@ instance Container Vector (Complex Float) where equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 arctan2 = vectorZipQ ATan2 scalar x = fromList [x] - konst = constantD - build = buildV + konst' = constantD + build' = buildV conj = conjugateQ cmap = mapVector atIndex = (@>) @@ -300,8 +296,8 @@ instance (Container Vector a) => Container Matrix a where equal a b = cols a == cols b && flatten a `equal` flatten b arctan2 = liftMatrix2 arctan2 scalar x = (1><1) [x] - konst v (r,c) = reshape c (konst v (r*c)) - build = buildM + konst' v (r,c) = reshape c (konst' v (r*c)) + build' = buildM conj = liftMatrix conj cmap f = liftMatrix (mapVector f) atIndex = (@@>) @@ -506,7 +502,7 @@ type instance ElementOf (Vector a) = a type instance ElementOf (Matrix a) = a ------------------------------------------------------------ - +{- class Build f where build' :: BoundsOf f -> f -> ContainerOf f @@ -546,6 +542,8 @@ instance (Element a, => Build (a->a->a) where build' = buildM +-} + buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ] where rs = map fromIntegral [0 .. (rc-1)] cs = map fromIntegral [0 .. (cc-1)] @@ -553,6 +551,8 @@ buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ] buildV n f = fromList [f k | k <- ks] where ks = map fromIntegral [0 .. (n-1)] +{- + ---------------------------------------------------- -- experimental @@ -570,6 +570,8 @@ instance Konst Int where instance Konst (Int,Int) where konst' k (r,c) = reshape c $ konst' k (r*c) +-} + -------------------------------------------------------- -- | conjugate transpose ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index a3f541b..7223cd9 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs @@ -484,7 +484,7 @@ zh k v = fromList $ replicate (k-1) 0 ++ (1:drop k xs) where xs = toList v zt 0 v = v -zt k v = vjoin [subVector 0 (dim v - k) v, konst 0 k] +zt k v = vjoin [subVector 0 (dim v - k) v, konst' 0 k] unpackQR :: (Field t) => (Matrix t, Vector t) -> (Matrix t, Matrix t) @@ -640,10 +640,10 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s) c = cols l_u tu = triang r c 0 1 tl = triang r c 0 0 - l = takeColumns r (l_u |*| tl) |+| diagRect 0 (konst 1 r) r r + l = takeColumns r (l_u |*| tl) |+| diagRect 0 (konst' 1 r) r r u = l_u |*| tu (p,s) = fixPerm r perm - l' = (l_u |*| tl) |+| diagRect 0 (konst 1 c) r c + l' = (l_u |*| tl) |+| diagRect 0 (konst' 1 c) r c u' = takeRows c (l_u |*| tu) (|+|) = add (|*|) = mul diff --git a/lib/Numeric/LinearAlgebra/Util/Convolution.hs b/lib/Numeric/LinearAlgebra/Util/Convolution.hs index 1043614..82de476 100644 --- a/lib/Numeric/LinearAlgebra/Util/Convolution.hs +++ b/lib/Numeric/LinearAlgebra/Util/Convolution.hs @@ -59,7 +59,7 @@ corrMin ker v = minEvery ss (asRow ker) <> ones where minEvery a b = cond a b a a b ss = vectSS (dim ker) v - ones = konst' 1 (dim ker) + ones = konst 1 (dim ker) @@ -87,7 +87,7 @@ corr2 ker mat = dims | otherwise = error $ "corr2: dim kernel ("++sz ker++") > dim matrix ("++sz mat++")" sz m = show (rows m)++"x"++show (cols m) -conv2 :: (Num a, Product a) => Matrix a -> Matrix a -> Matrix a +conv2 :: (Num a, Product a, Container Vector a) => Matrix a -> Matrix a -> Matrix a -- ^ 2D convolution conv2 k m = corr2 (fliprl . flipud $ k) pm where @@ -101,9 +101,9 @@ conv2 k m = corr2 (fliprl . flipud $ k) pm c = cols k - 1 h = rows m w = cols m - z1 = konst' 0 (r,c) - z2 = konst' 0 (r,w) - z3 = konst' 0 (h,c) + z1 = konst 0 (r,c) + z2 = konst 0 (r,w) + z3 = konst 0 (h,c) -- TODO: could be simplified using future empty arrays -- cgit v1.2.3