summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-04-24 13:17:55 +0200
committerAlberto Ruiz <aruiz@um.es>2014-04-24 13:17:55 +0200
commitde0219353ca9631135a3f750cef05b9636bef232 (patch)
tree2943867ceca43bcf5037f60077a6269f589deff8 /lib
parent3c1bbdd450304945c035a1e49cdb67871ea50451 (diff)
konst with bidirectional type inference
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Random.hs4
-rw-r--r--lib/Numeric/Container.hs18
-rw-r--r--lib/Numeric/ContainerBoot.hs40
-rw-r--r--lib/Numeric/LinearAlgebra/Algorithms.hs6
-rw-r--r--lib/Numeric/LinearAlgebra/Util/Convolution.hs10
5 files changed, 47 insertions, 31 deletions
diff --git a/lib/Data/Packed/Random.hs b/lib/Data/Packed/Random.hs
index dabb17d..e8b0268 100644
--- a/lib/Data/Packed/Random.hs
+++ b/lib/Data/Packed/Random.hs
@@ -36,7 +36,7 @@ gaussianSample :: Seed
36 -> Matrix Double -- ^ result 36 -> Matrix Double -- ^ result
37gaussianSample seed n med cov = m where 37gaussianSample seed n med cov = m where
38 c = dim med 38 c = dim med
39 meds = konst 1 n `outer` med 39 meds = konst' 1 n `outer` med
40 rs = reshape c $ randomVector seed Gaussian (c * n) 40 rs = reshape c $ randomVector seed Gaussian (c * n)
41 m = rs `mXm` cholSH cov `add` meds 41 m = rs `mXm` cholSH cov `add` meds
42 42
@@ -52,6 +52,6 @@ uniformSample seed n rgs = m where
52 cs = zipWith subtract as bs 52 cs = zipWith subtract as bs
53 d = dim a 53 d = dim a
54 dat = toRows $ reshape n $ randomVector seed Uniform (n*d) 54 dat = toRows $ reshape n $ randomVector seed Uniform (n*d)
55 am = konst 1 n `outer` a 55 am = konst' 1 n `outer` a
56 m = fromColumns (zipWith scale cs dat) `add` am 56 m = fromColumns (zipWith scale cs dat) `add` am
57 57
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs
index ed6714f..a31acfe 100644
--- a/lib/Numeric/Container.hs
+++ b/lib/Numeric/Container.hs
@@ -28,6 +28,7 @@
28module Numeric.Container ( 28module Numeric.Container (
29 -- * Basic functions 29 -- * Basic functions
30 module Data.Packed, 30 module Data.Packed,
31 konst, -- build,
31 constant, linspace, 32 constant, linspace,
32 diag, ident, 33 diag, ident,
33 ctrans, 34 ctrans,
@@ -59,8 +60,6 @@ module Numeric.Container (
59 loadMatrix, saveMatrix, fromFile, fileDimensions, 60 loadMatrix, saveMatrix, fromFile, fileDimensions,
60 readMatrix, 61 readMatrix,
61 fscanfVector, fprintfVector, freadVector, fwriteVector, 62 fscanfVector, fprintfVector, freadVector, fwriteVector,
62 -- * Experimental
63 build', konst'
64) where 63) where
65 64
66import Data.Packed 65import Data.Packed
@@ -174,3 +173,18 @@ instance Container Matrix t => Contraction t (Matrix t) (Matrix t) where
174instance Container Matrix t => Contraction (Matrix t) t (Matrix t) where 173instance Container Matrix t => Contraction (Matrix t) t (Matrix t) where
175 (×) = flip scale 174 (×) = flip scale
176 175
176--------------------------------------------------------------------------------
177
178-- bidirectional type inference
179class Konst e d c | d -> c, c -> d
180 where
181 konst :: e -> d -> c e
182
183instance Container Vector e => Konst e Int Vector
184 where
185 konst = konst'
186
187instance Container Vector e => Konst e (Int,Int) Matrix
188 where
189 konst = konst'
190
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs
index 4c5bbd0..8707473 100644
--- a/lib/Numeric/ContainerBoot.hs
+++ b/lib/Numeric/ContainerBoot.hs
@@ -36,9 +36,7 @@ module Numeric.ContainerBoot (
36 RealOf, ComplexOf, SingleOf, DoubleOf, 36 RealOf, ComplexOf, SingleOf, DoubleOf,
37 37
38 IndexOf, 38 IndexOf,
39 module Data.Complex, 39 module Data.Complex
40 -- * Experimental
41 build', konst'
42) where 40) where
43 41
44import Data.Packed 42import Data.Packed
@@ -91,15 +89,13 @@ class (Complexable c, Fractional e, Element e) => Container c e where
91 -- | cannot implement instance Functor because of Element class constraint 89 -- | cannot implement instance Functor because of Element class constraint
92 cmap :: (Element b) => (e -> b) -> c e -> c b 90 cmap :: (Element b) => (e -> b) -> c e -> c b
93 -- | constant structure of given size 91 -- | constant structure of given size
94 konst :: e -> IndexOf c -> c e 92 konst' :: e -> IndexOf c -> c e
95 -- | create a structure using a function 93 -- | create a structure using a function
96 -- 94 --
97 -- Hilbert matrix of order N: 95 -- Hilbert matrix of order N:
98 -- 96 --
99 -- @hilb n = build (n,n) (\\i j -> 1/(i+j+1))@ 97 -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@
100 build :: IndexOf c -> (ArgOf c e) -> c e 98 build' :: IndexOf c -> (ArgOf c e) -> c e
101 --build :: BoundsOf f -> f -> (ContainerOf f) e
102 --
103 -- | indexing function 99 -- | indexing function
104 atIndex :: c e -> IndexOf c -> e 100 atIndex :: c e -> IndexOf c -> e
105 -- | index of min element 101 -- | index of min element
@@ -186,8 +182,8 @@ instance Container Vector Float where
186 equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 182 equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0
187 arctan2 = vectorZipF ATan2 183 arctan2 = vectorZipF ATan2
188 scalar x = fromList [x] 184 scalar x = fromList [x]
189 konst = constantD 185 konst' = constantD
190 build = buildV 186 build' = buildV
191 conj = id 187 conj = id
192 cmap = mapVector 188 cmap = mapVector
193 atIndex = (@>) 189 atIndex = (@>)
@@ -214,8 +210,8 @@ instance Container Vector Double where
214 equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 210 equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0
215 arctan2 = vectorZipR ATan2 211 arctan2 = vectorZipR ATan2
216 scalar x = fromList [x] 212 scalar x = fromList [x]
217 konst = constantD 213 konst' = constantD
218 build = buildV 214 build' = buildV
219 conj = id 215 conj = id
220 cmap = mapVector 216 cmap = mapVector
221 atIndex = (@>) 217 atIndex = (@>)
@@ -242,8 +238,8 @@ instance Container Vector (Complex Double) where
242 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 238 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
243 arctan2 = vectorZipC ATan2 239 arctan2 = vectorZipC ATan2
244 scalar x = fromList [x] 240 scalar x = fromList [x]
245 konst = constantD 241 konst' = constantD
246 build = buildV 242 build' = buildV
247 conj = conjugateC 243 conj = conjugateC
248 cmap = mapVector 244 cmap = mapVector
249 atIndex = (@>) 245 atIndex = (@>)
@@ -270,8 +266,8 @@ instance Container Vector (Complex Float) where
270 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 266 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
271 arctan2 = vectorZipQ ATan2 267 arctan2 = vectorZipQ ATan2
272 scalar x = fromList [x] 268 scalar x = fromList [x]
273 konst = constantD 269 konst' = constantD
274 build = buildV 270 build' = buildV
275 conj = conjugateQ 271 conj = conjugateQ
276 cmap = mapVector 272 cmap = mapVector
277 atIndex = (@>) 273 atIndex = (@>)
@@ -300,8 +296,8 @@ instance (Container Vector a) => Container Matrix a where
300 equal a b = cols a == cols b && flatten a `equal` flatten b 296 equal a b = cols a == cols b && flatten a `equal` flatten b
301 arctan2 = liftMatrix2 arctan2 297 arctan2 = liftMatrix2 arctan2
302 scalar x = (1><1) [x] 298 scalar x = (1><1) [x]
303 konst v (r,c) = reshape c (konst v (r*c)) 299 konst' v (r,c) = reshape c (konst' v (r*c))
304 build = buildM 300 build' = buildM
305 conj = liftMatrix conj 301 conj = liftMatrix conj
306 cmap f = liftMatrix (mapVector f) 302 cmap f = liftMatrix (mapVector f)
307 atIndex = (@@>) 303 atIndex = (@@>)
@@ -506,7 +502,7 @@ type instance ElementOf (Vector a) = a
506type instance ElementOf (Matrix a) = a 502type instance ElementOf (Matrix a) = a
507 503
508------------------------------------------------------------ 504------------------------------------------------------------
509 505{-
510class Build f where 506class Build f where
511 build' :: BoundsOf f -> f -> ContainerOf f 507 build' :: BoundsOf f -> f -> ContainerOf f
512 508
@@ -546,6 +542,8 @@ instance (Element a,
546 => Build (a->a->a) where 542 => Build (a->a->a) where
547 build' = buildM 543 build' = buildM
548 544
545-}
546
549buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ] 547buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ]
550 where rs = map fromIntegral [0 .. (rc-1)] 548 where rs = map fromIntegral [0 .. (rc-1)]
551 cs = map fromIntegral [0 .. (cc-1)] 549 cs = map fromIntegral [0 .. (cc-1)]
@@ -553,6 +551,8 @@ buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ]
553buildV n f = fromList [f k | k <- ks] 551buildV n f = fromList [f k | k <- ks]
554 where ks = map fromIntegral [0 .. (n-1)] 552 where ks = map fromIntegral [0 .. (n-1)]
555 553
554{-
555
556---------------------------------------------------- 556----------------------------------------------------
557-- experimental 557-- experimental
558 558
@@ -570,6 +570,8 @@ instance Konst Int where
570instance Konst (Int,Int) where 570instance Konst (Int,Int) where
571 konst' k (r,c) = reshape c $ konst' k (r*c) 571 konst' k (r,c) = reshape c $ konst' k (r*c)
572 572
573-}
574
573-------------------------------------------------------- 575--------------------------------------------------------
574-- | conjugate transpose 576-- | conjugate transpose
575ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e 577ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs
index a3f541b..7223cd9 100644
--- a/lib/Numeric/LinearAlgebra/Algorithms.hs
+++ b/lib/Numeric/LinearAlgebra/Algorithms.hs
@@ -484,7 +484,7 @@ zh k v = fromList $ replicate (k-1) 0 ++ (1:drop k xs)
484 where xs = toList v 484 where xs = toList v
485 485
486zt 0 v = v 486zt 0 v = v
487zt k v = vjoin [subVector 0 (dim v - k) v, konst 0 k] 487zt k v = vjoin [subVector 0 (dim v - k) v, konst' 0 k]
488 488
489 489
490unpackQR :: (Field t) => (Matrix t, Vector t) -> (Matrix t, Matrix t) 490unpackQR :: (Field t) => (Matrix t, Vector t) -> (Matrix t, Matrix t)
@@ -640,10 +640,10 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s)
640 c = cols l_u 640 c = cols l_u
641 tu = triang r c 0 1 641 tu = triang r c 0 1
642 tl = triang r c 0 0 642 tl = triang r c 0 0
643 l = takeColumns r (l_u |*| tl) |+| diagRect 0 (konst 1 r) r r 643 l = takeColumns r (l_u |*| tl) |+| diagRect 0 (konst' 1 r) r r
644 u = l_u |*| tu 644 u = l_u |*| tu
645 (p,s) = fixPerm r perm 645 (p,s) = fixPerm r perm
646 l' = (l_u |*| tl) |+| diagRect 0 (konst 1 c) r c 646 l' = (l_u |*| tl) |+| diagRect 0 (konst' 1 c) r c
647 u' = takeRows c (l_u |*| tu) 647 u' = takeRows c (l_u |*| tu)
648 (|+|) = add 648 (|+|) = add
649 (|*|) = mul 649 (|*|) = mul
diff --git a/lib/Numeric/LinearAlgebra/Util/Convolution.hs b/lib/Numeric/LinearAlgebra/Util/Convolution.hs
index 1043614..82de476 100644
--- a/lib/Numeric/LinearAlgebra/Util/Convolution.hs
+++ b/lib/Numeric/LinearAlgebra/Util/Convolution.hs
@@ -59,7 +59,7 @@ corrMin ker v = minEvery ss (asRow ker) <> ones
59 where 59 where
60 minEvery a b = cond a b a a b 60 minEvery a b = cond a b a a b
61 ss = vectSS (dim ker) v 61 ss = vectSS (dim ker) v
62 ones = konst' 1 (dim ker) 62 ones = konst 1 (dim ker)
63 63
64 64
65 65
@@ -87,7 +87,7 @@ corr2 ker mat = dims
87 | otherwise = error $ "corr2: dim kernel ("++sz ker++") > dim matrix ("++sz mat++")" 87 | otherwise = error $ "corr2: dim kernel ("++sz ker++") > dim matrix ("++sz mat++")"
88 sz m = show (rows m)++"x"++show (cols m) 88 sz m = show (rows m)++"x"++show (cols m)
89 89
90conv2 :: (Num a, Product a) => Matrix a -> Matrix a -> Matrix a 90conv2 :: (Num a, Product a, Container Vector a) => Matrix a -> Matrix a -> Matrix a
91-- ^ 2D convolution 91-- ^ 2D convolution
92conv2 k m = corr2 (fliprl . flipud $ k) pm 92conv2 k m = corr2 (fliprl . flipud $ k) pm
93 where 93 where
@@ -101,9 +101,9 @@ conv2 k m = corr2 (fliprl . flipud $ k) pm
101 c = cols k - 1 101 c = cols k - 1
102 h = rows m 102 h = rows m
103 w = cols m 103 w = cols m
104 z1 = konst' 0 (r,c) 104 z1 = konst 0 (r,c)
105 z2 = konst' 0 (r,w) 105 z2 = konst 0 (r,w)
106 z3 = konst' 0 (h,c) 106 z3 = konst 0 (h,c)
107 107
108-- TODO: could be simplified using future empty arrays 108-- TODO: could be simplified using future empty arrays
109 109