diff options
Diffstat (limited to 'packages/base/src/Internal/Modular.hs')
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 136 |
1 files changed, 99 insertions, 37 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 0274607..6b34010 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -7,6 +7,7 @@ | |||
7 | {-# LANGUAGE ScopedTypeVariables #-} | 7 | {-# LANGUAGE ScopedTypeVariables #-} |
8 | {-# LANGUAGE Rank2Types #-} | 8 | {-# LANGUAGE Rank2Types #-} |
9 | {-# LANGUAGE FlexibleInstances #-} | 9 | {-# LANGUAGE FlexibleInstances #-} |
10 | {-# LANGUAGE UndecidableInstances #-} | ||
10 | {-# LANGUAGE GADTs #-} | 11 | {-# LANGUAGE GADTs #-} |
11 | {-# LANGUAGE TypeFamilies #-} | 12 | {-# LANGUAGE TypeFamilies #-} |
12 | 13 | ||
@@ -22,7 +23,7 @@ Proof of concept of statically checked modular arithmetic. | |||
22 | -} | 23 | -} |
23 | 24 | ||
24 | module Internal.Modular( | 25 | module Internal.Modular( |
25 | Mod, F | 26 | Mod |
26 | ) where | 27 | ) where |
27 | 28 | ||
28 | import Internal.Vector | 29 | import Internal.Vector |
@@ -30,8 +31,8 @@ import Internal.Matrix hiding (mat,size) | |||
30 | import Internal.Numeric | 31 | import Internal.Numeric |
31 | import Internal.Element | 32 | import Internal.Element |
32 | import Internal.Container | 33 | import Internal.Container |
33 | import Internal.Vectorized (prodI,sumI) | 34 | import Internal.Vectorized (prodI,sumI,prodL,sumL) |
34 | import Internal.LAPACK (multiplyI) | 35 | import Internal.LAPACK (multiplyI, multiplyL) |
35 | import Internal.Util(Indexable(..),gaussElim) | 36 | import Internal.Util(Indexable(..),gaussElim) |
36 | import GHC.TypeLits | 37 | import GHC.TypeLits |
37 | import Data.Proxy(Proxy) | 38 | import Data.Proxy(Proxy) |
@@ -45,24 +46,24 @@ import Data.Ratio | |||
45 | newtype Mod (n :: Nat) t = Mod {unMod:: t} | 46 | newtype Mod (n :: Nat) t = Mod {unMod:: t} |
46 | deriving (Storable) | 47 | deriving (Storable) |
47 | 48 | ||
48 | instance KnownNat m => Enum (F m) | 49 | instance (Integral t, Enum t, KnownNat m) => Enum (Mod m t) |
49 | where | 50 | where |
50 | toEnum = l0 (\m x -> fromIntegral $ x `mod` (fromIntegral m)) | 51 | toEnum = l0 (\m x -> fromIntegral $ x `mod` (fromIntegral m)) |
51 | fromEnum = fromIntegral . unMod | 52 | fromEnum = fromIntegral . unMod |
52 | 53 | ||
53 | instance KnownNat m => Eq (F m) | 54 | instance (Eq t, KnownNat m) => Eq (Mod m t) |
54 | where | 55 | where |
55 | a == b = (unMod a) == (unMod b) | 56 | a == b = (unMod a) == (unMod b) |
56 | 57 | ||
57 | instance KnownNat m => Ord (F m) | 58 | instance (Ord t, KnownNat m) => Ord (Mod m t) |
58 | where | 59 | where |
59 | compare a b = compare (unMod a) (unMod b) | 60 | compare a b = compare (unMod a) (unMod b) |
60 | 61 | ||
61 | instance KnownNat m => Real (F m) | 62 | instance (Real t, KnownNat m, Integral (Mod m t)) => Real (Mod m t) |
62 | where | 63 | where |
63 | toRational x = toInteger x % 1 | 64 | toRational x = toInteger x % 1 |
64 | 65 | ||
65 | instance KnownNat m => Integral (F m) | 66 | instance (Integral t, KnownNat m, Num (Mod m t)) => Integral (Mod m t) |
66 | where | 67 | where |
67 | toInteger = toInteger . unMod | 68 | toInteger = toInteger . unMod |
68 | quotRem a b = (Mod q, Mod r) | 69 | quotRem a b = (Mod q, Mod r) |
@@ -70,7 +71,7 @@ instance KnownNat m => Integral (F m) | |||
70 | (q,r) = quotRem (unMod a) (unMod b) | 71 | (q,r) = quotRem (unMod a) (unMod b) |
71 | 72 | ||
72 | -- | this instance is only valid for prime m | 73 | -- | this instance is only valid for prime m |
73 | instance KnownNat m => Fractional (F m) | 74 | instance (Show (Mod m t), Num (Mod m t), Eq t, KnownNat m) => Fractional (Mod m t) |
74 | where | 75 | where |
75 | recip x | 76 | recip x |
76 | | x*r == 1 = r | 77 | | x*r == 1 = r |
@@ -80,27 +81,27 @@ instance KnownNat m => Fractional (F m) | |||
80 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 81 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
81 | fromRational x = fromInteger (numerator x) / fromInteger (denominator x) | 82 | fromRational x = fromInteger (numerator x) / fromInteger (denominator x) |
82 | 83 | ||
83 | l2 :: forall m a b c. (KnownNat m) => (Int -> a -> b -> c) -> Mod m a -> Mod m b -> Mod m c | 84 | l2 :: forall m a b c. (Num c, KnownNat m) => (c -> a -> b -> c) -> Mod m a -> Mod m b -> Mod m c |
84 | l2 f (Mod u) (Mod v) = Mod (f m' u v) | 85 | l2 f (Mod u) (Mod v) = Mod (f m' u v) |
85 | where | 86 | where |
86 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 87 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
87 | 88 | ||
88 | l1 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> Mod m a -> Mod m b | 89 | l1 :: forall m a b . (Num b, KnownNat m) => (b -> a -> b) -> Mod m a -> Mod m b |
89 | l1 f (Mod u) = Mod (f m' u) | 90 | l1 f (Mod u) = Mod (f m' u) |
90 | where | 91 | where |
91 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 92 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
92 | 93 | ||
93 | l0 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> a -> Mod m b | 94 | l0 :: forall m a b . (Num b, KnownNat m) => (b -> a -> b) -> a -> Mod m b |
94 | l0 f u = Mod (f m' u) | 95 | l0 f u = Mod (f m' u) |
95 | where | 96 | where |
96 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 97 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
97 | 98 | ||
98 | 99 | ||
99 | instance Show (F n) | 100 | instance Show t => Show (Mod n t) |
100 | where | 101 | where |
101 | show = show . unMod | 102 | show = show . unMod |
102 | 103 | ||
103 | instance forall n . KnownNat n => Num (F n) | 104 | instance forall n t . (Integral t, KnownNat n) => Num (Mod n t) |
104 | where | 105 | where |
105 | (+) = l2 (\m a b -> (a + b) `mod` (fromIntegral m)) | 106 | (+) = l2 (\m a b -> (a + b) `mod` (fromIntegral m)) |
106 | (*) = l2 (\m a b -> (a * b) `mod` (fromIntegral m)) | 107 | (*) = l2 (\m a b -> (a * b) `mod` (fromIntegral m)) |
@@ -110,14 +111,8 @@ instance forall n . KnownNat n => Num (F n) | |||
110 | fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) | 111 | fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) |
111 | 112 | ||
112 | 113 | ||
113 | -- | Integer modulo n | ||
114 | type F n = Mod n I | ||
115 | 114 | ||
116 | type V n = Vector (F n) | 115 | instance (Ord t, Element t) => Element (Mod n t) |
117 | type M n = Matrix (F n) | ||
118 | |||
119 | |||
120 | instance Element (F n) | ||
121 | where | 116 | where |
122 | transdata n v m = i2f (transdata n (f2i v) m) | 117 | transdata n v m = i2f (transdata n (f2i v) m) |
123 | constantD x n = i2f (constantD (unMod x) n) | 118 | constantD x n = i2f (constantD (unMod x) n) |
@@ -128,7 +123,8 @@ instance Element (F n) | |||
128 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | 123 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) |
129 | remapM i j m = i2fM (remap i j (f2iM m)) | 124 | remapM i j m = i2fM (remap i j (f2iM m)) |
130 | 125 | ||
131 | instance forall m . KnownNat m => Container Vector (F m) | 126 | |
127 | instance forall m . KnownNat m => Container Vector (Mod m I) | ||
132 | where | 128 | where |
133 | conj' = id | 129 | conj' = id |
134 | size' = dim | 130 | size' = dim |
@@ -168,24 +164,77 @@ instance forall m . KnownNat m => Container Vector (F m) | |||
168 | fromZ' = vmod . fromZ' | 164 | fromZ' = vmod . fromZ' |
169 | toZ' = toZ' . f2i | 165 | toZ' = toZ' . f2i |
170 | 166 | ||
167 | instance forall m . KnownNat m => Container Vector (Mod m Z) | ||
168 | where | ||
169 | conj' = id | ||
170 | size' = dim | ||
171 | scale' s x = vmod (scale (unMod s) (f2i x)) | ||
172 | addConstant c x = vmod (addConstant (unMod c) (f2i x)) | ||
173 | add a b = vmod (add (f2i a) (f2i b)) | ||
174 | sub a b = vmod (sub (f2i a) (f2i b)) | ||
175 | mul a b = vmod (mul (f2i a) (f2i b)) | ||
176 | equal u v = equal (f2i u) (f2i v) | ||
177 | scalar' x = fromList [x] | ||
178 | konst' x = i2f . konst (unMod x) | ||
179 | build' n f = build n (fromIntegral . f) | ||
180 | cmap' = cmap | ||
181 | atIndex' x k = fromIntegral (atIndex (f2i x) k) | ||
182 | minIndex' = minIndex . f2i | ||
183 | maxIndex' = maxIndex . f2i | ||
184 | minElement' = Mod . minElement . f2i | ||
185 | maxElement' = Mod . maxElement . f2i | ||
186 | sumElements' = fromIntegral . sumL m' . f2i | ||
187 | where | ||
188 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
189 | prodElements' = fromIntegral . prodL m' . f2i | ||
190 | where | ||
191 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
192 | step' = i2f . step . f2i | ||
193 | find' = findV | ||
194 | assoc' = assocV | ||
195 | accum' = accumV | ||
196 | ccompare' a b = ccompare (f2i a) (f2i b) | ||
197 | cselect' c l e g = i2f $ cselect c (f2i l) (f2i e) (f2i g) | ||
198 | scaleRecip s x = scale' s (cmap recip x) | ||
199 | divide x y = mul x (cmap recip y) | ||
200 | arctan2' = undefined | ||
201 | cmod' m = vmod . cmod' (unMod m) . f2i | ||
202 | fromInt' = vmod . fromInt' | ||
203 | toInt' = toInt . f2i | ||
204 | fromZ' = vmod | ||
205 | toZ' = f2i | ||
206 | |||
207 | |||
171 | 208 | ||
172 | instance Indexable (Vector (F m)) (F m) | 209 | instance (Storable t, Indexable (Vector t) t) => Indexable (Vector (Mod m t)) (Mod m t) |
173 | where | 210 | where |
174 | (!) = (@>) | 211 | (!) = (@>) |
175 | 212 | ||
176 | 213 | ||
177 | type instance RealOf (F n) = I | 214 | type instance RealOf (Mod n I) = I |
215 | type instance RealOf (Mod n Z) = Z | ||
178 | 216 | ||
179 | instance KnownNat m => Product (F m) where | 217 | instance KnownNat m => Product (Mod m I) where |
180 | norm2 = undefined | 218 | norm2 = undefined |
181 | absSum = undefined | 219 | absSum = undefined |
182 | norm1 = undefined | 220 | norm1 = undefined |
183 | normInf = undefined | 221 | normInf = undefined |
184 | multiply = lift2 (multiplyI m') | 222 | multiply = lift2m (multiplyI m') |
223 | where | ||
224 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
225 | |||
226 | instance KnownNat m => Product (Mod m Z) where | ||
227 | norm2 = undefined | ||
228 | absSum = undefined | ||
229 | norm1 = undefined | ||
230 | normInf = undefined | ||
231 | multiply = lift2m (multiplyL m') | ||
185 | where | 232 | where |
186 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 233 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
187 | 234 | ||
188 | instance KnownNat m => Numeric (F m) | 235 | |
236 | instance KnownNat m => Numeric (Mod m I) | ||
237 | instance KnownNat m => Numeric (Mod m Z) | ||
189 | 238 | ||
190 | i2f :: Storable t => Vector t -> Vector (Mod n t) | 239 | i2f :: Storable t => Vector t -> Vector (Mod n t) |
191 | i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | 240 | i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) |
@@ -206,10 +255,12 @@ vmod = i2f . cmod' m' | |||
206 | where | 255 | where |
207 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 256 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
208 | 257 | ||
209 | lift1 f a = fromInt (f (toInt a)) | 258 | lift1 f a = vmod (f (f2i a)) |
210 | lift2 f a b = fromInt (f (toInt a) (toInt b)) | 259 | lift2 f a b = vmod (f (f2i a) (f2i b)) |
260 | |||
261 | lift2m f a b = liftMatrix vmod (f (f2iM a) (f2iM b)) | ||
211 | 262 | ||
212 | instance forall m . KnownNat m => Num (V m) | 263 | instance forall m . KnownNat m => Num (Vector (Mod m I)) |
213 | where | 264 | where |
214 | (+) = lift2 (+) | 265 | (+) = lift2 (+) |
215 | (*) = lift2 (*) | 266 | (*) = lift2 (*) |
@@ -222,14 +273,14 @@ instance forall m . KnownNat m => Num (V m) | |||
222 | 273 | ||
223 | -------------------------------------------------------------------------------- | 274 | -------------------------------------------------------------------------------- |
224 | 275 | ||
225 | instance (KnownNat m) => Testable (M m) | 276 | instance (KnownNat m) => Testable (Matrix (Mod m I)) |
226 | where | 277 | where |
227 | checkT _ = test | 278 | checkT _ = test |
228 | 279 | ||
229 | test = (ok, info) | 280 | test = (ok, info) |
230 | where | 281 | where |
231 | v = fromList [3,-5,75] :: V 11 | 282 | v = fromList [3,-5,75] :: Vector (Mod 11 I) |
232 | m = (3><3) [1..] :: M 11 | 283 | m = (3><3) [1..] :: Matrix (Mod 11 I) |
233 | 284 | ||
234 | a = (3><3) [1,2 , 3 | 285 | a = (3><3) [1,2 , 3 |
235 | ,4,5 , 6 | 286 | ,4,5 , 6 |
@@ -237,13 +288,17 @@ test = (ok, info) | |||
237 | 288 | ||
238 | b = (3><2) [0..] :: Matrix I | 289 | b = (3><2) [0..] :: Matrix I |
239 | 290 | ||
240 | am = fromInt a :: Matrix (F 13) | 291 | am = fromInt a :: Matrix (Mod 13 I) |
241 | bm = fromInt b :: Matrix (F 13) | 292 | bm = fromInt b :: Matrix (Mod 13 I) |
242 | ad = fromInt a :: Matrix Double | 293 | ad = fromInt a :: Matrix Double |
243 | bd = fromInt b :: Matrix Double | 294 | bd = fromInt b :: Matrix Double |
244 | 295 | ||
245 | g = (3><3) (repeat (40000)) :: Matrix I | 296 | g = (3><3) (repeat (40000)) :: Matrix I |
246 | gm = fromInt g :: Matrix (F 100000) | 297 | gm = fromInt g :: Matrix (Mod 100000 I) |
298 | |||
299 | lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z | ||
300 | lgm = fromZ lg :: Matrix (Mod 10000000000 Z) | ||
301 | |||
247 | 302 | ||
248 | info = do | 303 | info = do |
249 | print v | 304 | print v |
@@ -262,11 +317,18 @@ test = (ok, info) | |||
262 | print gm | 317 | print gm |
263 | print $ gm <> gm | 318 | print $ gm <> gm |
264 | 319 | ||
320 | print lg | ||
321 | print $ lg <> lg | ||
322 | print lgm | ||
323 | print $ lgm <> lgm | ||
324 | |||
325 | |||
265 | ok = and | 326 | ok = and |
266 | [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) | 327 | [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) |
267 | , am <> gaussElim am bm == bm | 328 | , am <> gaussElim am bm == bm |
268 | , prodElements (konst (9:: F 10) (12::Int)) == product (replicate 12 (9:: F 10)) | 329 | , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) |
269 | , gm <> gm == konst 0 (3,3) | 330 | , gm <> gm == konst 0 (3,3) |
331 | , lgm <> lgm == konst 0 (3,3) | ||
270 | ] | 332 | ] |
271 | 333 | ||
272 | 334 | ||