diff options
author | Alberto Ruiz <aruiz@um.es> | 2014-04-24 13:17:55 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2014-04-24 13:17:55 +0200 |
commit | de0219353ca9631135a3f750cef05b9636bef232 (patch) | |
tree | 2943867ceca43bcf5037f60077a6269f589deff8 /lib | |
parent | 3c1bbdd450304945c035a1e49cdb67871ea50451 (diff) |
konst with bidirectional type inference
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Data/Packed/Random.hs | 4 | ||||
-rw-r--r-- | lib/Numeric/Container.hs | 18 | ||||
-rw-r--r-- | lib/Numeric/ContainerBoot.hs | 40 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Algorithms.hs | 6 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Util/Convolution.hs | 10 |
5 files changed, 47 insertions, 31 deletions
diff --git a/lib/Data/Packed/Random.hs b/lib/Data/Packed/Random.hs index dabb17d..e8b0268 100644 --- a/lib/Data/Packed/Random.hs +++ b/lib/Data/Packed/Random.hs | |||
@@ -36,7 +36,7 @@ gaussianSample :: Seed | |||
36 | -> Matrix Double -- ^ result | 36 | -> Matrix Double -- ^ result |
37 | gaussianSample seed n med cov = m where | 37 | gaussianSample seed n med cov = m where |
38 | c = dim med | 38 | c = dim med |
39 | meds = konst 1 n `outer` med | 39 | meds = konst' 1 n `outer` med |
40 | rs = reshape c $ randomVector seed Gaussian (c * n) | 40 | rs = reshape c $ randomVector seed Gaussian (c * n) |
41 | m = rs `mXm` cholSH cov `add` meds | 41 | m = rs `mXm` cholSH cov `add` meds |
42 | 42 | ||
@@ -52,6 +52,6 @@ uniformSample seed n rgs = m where | |||
52 | cs = zipWith subtract as bs | 52 | cs = zipWith subtract as bs |
53 | d = dim a | 53 | d = dim a |
54 | dat = toRows $ reshape n $ randomVector seed Uniform (n*d) | 54 | dat = toRows $ reshape n $ randomVector seed Uniform (n*d) |
55 | am = konst 1 n `outer` a | 55 | am = konst' 1 n `outer` a |
56 | m = fromColumns (zipWith scale cs dat) `add` am | 56 | m = fromColumns (zipWith scale cs dat) `add` am |
57 | 57 | ||
diff --git a/lib/Numeric/Container.hs b/lib/Numeric/Container.hs index ed6714f..a31acfe 100644 --- a/lib/Numeric/Container.hs +++ b/lib/Numeric/Container.hs | |||
@@ -28,6 +28,7 @@ | |||
28 | module Numeric.Container ( | 28 | module Numeric.Container ( |
29 | -- * Basic functions | 29 | -- * Basic functions |
30 | module Data.Packed, | 30 | module Data.Packed, |
31 | konst, -- build, | ||
31 | constant, linspace, | 32 | constant, linspace, |
32 | diag, ident, | 33 | diag, ident, |
33 | ctrans, | 34 | ctrans, |
@@ -59,8 +60,6 @@ module Numeric.Container ( | |||
59 | loadMatrix, saveMatrix, fromFile, fileDimensions, | 60 | loadMatrix, saveMatrix, fromFile, fileDimensions, |
60 | readMatrix, | 61 | readMatrix, |
61 | fscanfVector, fprintfVector, freadVector, fwriteVector, | 62 | fscanfVector, fprintfVector, freadVector, fwriteVector, |
62 | -- * Experimental | ||
63 | build', konst' | ||
64 | ) where | 63 | ) where |
65 | 64 | ||
66 | import Data.Packed | 65 | import Data.Packed |
@@ -174,3 +173,18 @@ instance Container Matrix t => Contraction t (Matrix t) (Matrix t) where | |||
174 | instance Container Matrix t => Contraction (Matrix t) t (Matrix t) where | 173 | instance Container Matrix t => Contraction (Matrix t) t (Matrix t) where |
175 | (×) = flip scale | 174 | (×) = flip scale |
176 | 175 | ||
176 | -------------------------------------------------------------------------------- | ||
177 | |||
178 | -- bidirectional type inference | ||
179 | class Konst e d c | d -> c, c -> d | ||
180 | where | ||
181 | konst :: e -> d -> c e | ||
182 | |||
183 | instance Container Vector e => Konst e Int Vector | ||
184 | where | ||
185 | konst = konst' | ||
186 | |||
187 | instance Container Vector e => Konst e (Int,Int) Matrix | ||
188 | where | ||
189 | konst = konst' | ||
190 | |||
diff --git a/lib/Numeric/ContainerBoot.hs b/lib/Numeric/ContainerBoot.hs index 4c5bbd0..8707473 100644 --- a/lib/Numeric/ContainerBoot.hs +++ b/lib/Numeric/ContainerBoot.hs | |||
@@ -36,9 +36,7 @@ module Numeric.ContainerBoot ( | |||
36 | RealOf, ComplexOf, SingleOf, DoubleOf, | 36 | RealOf, ComplexOf, SingleOf, DoubleOf, |
37 | 37 | ||
38 | IndexOf, | 38 | IndexOf, |
39 | module Data.Complex, | 39 | module Data.Complex |
40 | -- * Experimental | ||
41 | build', konst' | ||
42 | ) where | 40 | ) where |
43 | 41 | ||
44 | import Data.Packed | 42 | import Data.Packed |
@@ -91,15 +89,13 @@ class (Complexable c, Fractional e, Element e) => Container c e where | |||
91 | -- | cannot implement instance Functor because of Element class constraint | 89 | -- | cannot implement instance Functor because of Element class constraint |
92 | cmap :: (Element b) => (e -> b) -> c e -> c b | 90 | cmap :: (Element b) => (e -> b) -> c e -> c b |
93 | -- | constant structure of given size | 91 | -- | constant structure of given size |
94 | konst :: e -> IndexOf c -> c e | 92 | konst' :: e -> IndexOf c -> c e |
95 | -- | create a structure using a function | 93 | -- | create a structure using a function |
96 | -- | 94 | -- |
97 | -- Hilbert matrix of order N: | 95 | -- Hilbert matrix of order N: |
98 | -- | 96 | -- |
99 | -- @hilb n = build (n,n) (\\i j -> 1/(i+j+1))@ | 97 | -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@ |
100 | build :: IndexOf c -> (ArgOf c e) -> c e | 98 | build' :: IndexOf c -> (ArgOf c e) -> c e |
101 | --build :: BoundsOf f -> f -> (ContainerOf f) e | ||
102 | -- | ||
103 | -- | indexing function | 99 | -- | indexing function |
104 | atIndex :: c e -> IndexOf c -> e | 100 | atIndex :: c e -> IndexOf c -> e |
105 | -- | index of min element | 101 | -- | index of min element |
@@ -186,8 +182,8 @@ instance Container Vector Float where | |||
186 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 | 182 | equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0 |
187 | arctan2 = vectorZipF ATan2 | 183 | arctan2 = vectorZipF ATan2 |
188 | scalar x = fromList [x] | 184 | scalar x = fromList [x] |
189 | konst = constantD | 185 | konst' = constantD |
190 | build = buildV | 186 | build' = buildV |
191 | conj = id | 187 | conj = id |
192 | cmap = mapVector | 188 | cmap = mapVector |
193 | atIndex = (@>) | 189 | atIndex = (@>) |
@@ -214,8 +210,8 @@ instance Container Vector Double where | |||
214 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 | 210 | equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0 |
215 | arctan2 = vectorZipR ATan2 | 211 | arctan2 = vectorZipR ATan2 |
216 | scalar x = fromList [x] | 212 | scalar x = fromList [x] |
217 | konst = constantD | 213 | konst' = constantD |
218 | build = buildV | 214 | build' = buildV |
219 | conj = id | 215 | conj = id |
220 | cmap = mapVector | 216 | cmap = mapVector |
221 | atIndex = (@>) | 217 | atIndex = (@>) |
@@ -242,8 +238,8 @@ instance Container Vector (Complex Double) where | |||
242 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | 238 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 |
243 | arctan2 = vectorZipC ATan2 | 239 | arctan2 = vectorZipC ATan2 |
244 | scalar x = fromList [x] | 240 | scalar x = fromList [x] |
245 | konst = constantD | 241 | konst' = constantD |
246 | build = buildV | 242 | build' = buildV |
247 | conj = conjugateC | 243 | conj = conjugateC |
248 | cmap = mapVector | 244 | cmap = mapVector |
249 | atIndex = (@>) | 245 | atIndex = (@>) |
@@ -270,8 +266,8 @@ instance Container Vector (Complex Float) where | |||
270 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 | 266 | equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0 |
271 | arctan2 = vectorZipQ ATan2 | 267 | arctan2 = vectorZipQ ATan2 |
272 | scalar x = fromList [x] | 268 | scalar x = fromList [x] |
273 | konst = constantD | 269 | konst' = constantD |
274 | build = buildV | 270 | build' = buildV |
275 | conj = conjugateQ | 271 | conj = conjugateQ |
276 | cmap = mapVector | 272 | cmap = mapVector |
277 | atIndex = (@>) | 273 | atIndex = (@>) |
@@ -300,8 +296,8 @@ instance (Container Vector a) => Container Matrix a where | |||
300 | equal a b = cols a == cols b && flatten a `equal` flatten b | 296 | equal a b = cols a == cols b && flatten a `equal` flatten b |
301 | arctan2 = liftMatrix2 arctan2 | 297 | arctan2 = liftMatrix2 arctan2 |
302 | scalar x = (1><1) [x] | 298 | scalar x = (1><1) [x] |
303 | konst v (r,c) = reshape c (konst v (r*c)) | 299 | konst' v (r,c) = reshape c (konst' v (r*c)) |
304 | build = buildM | 300 | build' = buildM |
305 | conj = liftMatrix conj | 301 | conj = liftMatrix conj |
306 | cmap f = liftMatrix (mapVector f) | 302 | cmap f = liftMatrix (mapVector f) |
307 | atIndex = (@@>) | 303 | atIndex = (@@>) |
@@ -506,7 +502,7 @@ type instance ElementOf (Vector a) = a | |||
506 | type instance ElementOf (Matrix a) = a | 502 | type instance ElementOf (Matrix a) = a |
507 | 503 | ||
508 | ------------------------------------------------------------ | 504 | ------------------------------------------------------------ |
509 | 505 | {- | |
510 | class Build f where | 506 | class Build f where |
511 | build' :: BoundsOf f -> f -> ContainerOf f | 507 | build' :: BoundsOf f -> f -> ContainerOf f |
512 | 508 | ||
@@ -546,6 +542,8 @@ instance (Element a, | |||
546 | => Build (a->a->a) where | 542 | => Build (a->a->a) where |
547 | build' = buildM | 543 | build' = buildM |
548 | 544 | ||
545 | -} | ||
546 | |||
549 | buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ] | 547 | buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ] |
550 | where rs = map fromIntegral [0 .. (rc-1)] | 548 | where rs = map fromIntegral [0 .. (rc-1)] |
551 | cs = map fromIntegral [0 .. (cc-1)] | 549 | cs = map fromIntegral [0 .. (cc-1)] |
@@ -553,6 +551,8 @@ buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ] | |||
553 | buildV n f = fromList [f k | k <- ks] | 551 | buildV n f = fromList [f k | k <- ks] |
554 | where ks = map fromIntegral [0 .. (n-1)] | 552 | where ks = map fromIntegral [0 .. (n-1)] |
555 | 553 | ||
554 | {- | ||
555 | |||
556 | ---------------------------------------------------- | 556 | ---------------------------------------------------- |
557 | -- experimental | 557 | -- experimental |
558 | 558 | ||
@@ -570,6 +570,8 @@ instance Konst Int where | |||
570 | instance Konst (Int,Int) where | 570 | instance Konst (Int,Int) where |
571 | konst' k (r,c) = reshape c $ konst' k (r*c) | 571 | konst' k (r,c) = reshape c $ konst' k (r*c) |
572 | 572 | ||
573 | -} | ||
574 | |||
573 | -------------------------------------------------------- | 575 | -------------------------------------------------------- |
574 | -- | conjugate transpose | 576 | -- | conjugate transpose |
575 | ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e | 577 | ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e |
diff --git a/lib/Numeric/LinearAlgebra/Algorithms.hs b/lib/Numeric/LinearAlgebra/Algorithms.hs index a3f541b..7223cd9 100644 --- a/lib/Numeric/LinearAlgebra/Algorithms.hs +++ b/lib/Numeric/LinearAlgebra/Algorithms.hs | |||
@@ -484,7 +484,7 @@ zh k v = fromList $ replicate (k-1) 0 ++ (1:drop k xs) | |||
484 | where xs = toList v | 484 | where xs = toList v |
485 | 485 | ||
486 | zt 0 v = v | 486 | zt 0 v = v |
487 | zt k v = vjoin [subVector 0 (dim v - k) v, konst 0 k] | 487 | zt k v = vjoin [subVector 0 (dim v - k) v, konst' 0 k] |
488 | 488 | ||
489 | 489 | ||
490 | unpackQR :: (Field t) => (Matrix t, Vector t) -> (Matrix t, Matrix t) | 490 | unpackQR :: (Field t) => (Matrix t, Vector t) -> (Matrix t, Matrix t) |
@@ -640,10 +640,10 @@ luFact (l_u,perm) | r <= c = (l ,u ,p, s) | |||
640 | c = cols l_u | 640 | c = cols l_u |
641 | tu = triang r c 0 1 | 641 | tu = triang r c 0 1 |
642 | tl = triang r c 0 0 | 642 | tl = triang r c 0 0 |
643 | l = takeColumns r (l_u |*| tl) |+| diagRect 0 (konst 1 r) r r | 643 | l = takeColumns r (l_u |*| tl) |+| diagRect 0 (konst' 1 r) r r |
644 | u = l_u |*| tu | 644 | u = l_u |*| tu |
645 | (p,s) = fixPerm r perm | 645 | (p,s) = fixPerm r perm |
646 | l' = (l_u |*| tl) |+| diagRect 0 (konst 1 c) r c | 646 | l' = (l_u |*| tl) |+| diagRect 0 (konst' 1 c) r c |
647 | u' = takeRows c (l_u |*| tu) | 647 | u' = takeRows c (l_u |*| tu) |
648 | (|+|) = add | 648 | (|+|) = add |
649 | (|*|) = mul | 649 | (|*|) = mul |
diff --git a/lib/Numeric/LinearAlgebra/Util/Convolution.hs b/lib/Numeric/LinearAlgebra/Util/Convolution.hs index 1043614..82de476 100644 --- a/lib/Numeric/LinearAlgebra/Util/Convolution.hs +++ b/lib/Numeric/LinearAlgebra/Util/Convolution.hs | |||
@@ -59,7 +59,7 @@ corrMin ker v = minEvery ss (asRow ker) <> ones | |||
59 | where | 59 | where |
60 | minEvery a b = cond a b a a b | 60 | minEvery a b = cond a b a a b |
61 | ss = vectSS (dim ker) v | 61 | ss = vectSS (dim ker) v |
62 | ones = konst' 1 (dim ker) | 62 | ones = konst 1 (dim ker) |
63 | 63 | ||
64 | 64 | ||
65 | 65 | ||
@@ -87,7 +87,7 @@ corr2 ker mat = dims | |||
87 | | otherwise = error $ "corr2: dim kernel ("++sz ker++") > dim matrix ("++sz mat++")" | 87 | | otherwise = error $ "corr2: dim kernel ("++sz ker++") > dim matrix ("++sz mat++")" |
88 | sz m = show (rows m)++"x"++show (cols m) | 88 | sz m = show (rows m)++"x"++show (cols m) |
89 | 89 | ||
90 | conv2 :: (Num a, Product a) => Matrix a -> Matrix a -> Matrix a | 90 | conv2 :: (Num a, Product a, Container Vector a) => Matrix a -> Matrix a -> Matrix a |
91 | -- ^ 2D convolution | 91 | -- ^ 2D convolution |
92 | conv2 k m = corr2 (fliprl . flipud $ k) pm | 92 | conv2 k m = corr2 (fliprl . flipud $ k) pm |
93 | where | 93 | where |
@@ -101,9 +101,9 @@ conv2 k m = corr2 (fliprl . flipud $ k) pm | |||
101 | c = cols k - 1 | 101 | c = cols k - 1 |
102 | h = rows m | 102 | h = rows m |
103 | w = cols m | 103 | w = cols m |
104 | z1 = konst' 0 (r,c) | 104 | z1 = konst 0 (r,c) |
105 | z2 = konst' 0 (r,w) | 105 | z2 = konst 0 (r,w) |
106 | z3 = konst' 0 (h,c) | 106 | z3 = konst 0 (h,c) |
107 | 107 | ||
108 | -- TODO: could be simplified using future empty arrays | 108 | -- TODO: could be simplified using future empty arrays |
109 | 109 | ||