diff options
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 254 |
1 files changed, 254 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs new file mode 100644 index 0000000..1116b96 --- /dev/null +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -0,0 +1,254 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE FlexibleContexts #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE Rank2Types #-} | ||
9 | {-# LANGUAGE FlexibleInstances #-} | ||
10 | {-# LANGUAGE GADTs #-} | ||
11 | {-# LANGUAGE TypeFamilies #-} | ||
12 | |||
13 | |||
14 | {- | | ||
15 | Module : Internal.Modular | ||
16 | Copyright : (c) Alberto Ruiz 2015 | ||
17 | License : BSD3 | ||
18 | Stability : experimental | ||
19 | |||
20 | Proof of concept of statically checked modular arithmetic. | ||
21 | |||
22 | -} | ||
23 | |||
24 | module Internal.Modular( | ||
25 | Mod, F | ||
26 | ) where | ||
27 | |||
28 | import Internal.Vector | ||
29 | import Internal.Matrix hiding (mat,size) | ||
30 | import Internal.Numeric | ||
31 | import Internal.Element | ||
32 | import Internal.Tools | ||
33 | import Internal.Container | ||
34 | import Internal.Util(Indexable(..),gaussElim) | ||
35 | import GHC.TypeLits | ||
36 | import Data.Proxy(Proxy) | ||
37 | import Foreign.ForeignPtr(castForeignPtr) | ||
38 | import Data.Vector.Storable(fromList,unsafeToForeignPtr, unsafeFromForeignPtr) | ||
39 | import Foreign.Storable | ||
40 | import Data.Ratio | ||
41 | |||
42 | |||
43 | |||
44 | -- | Wrapper with a phantom integer for statically checked modular arithmetic. | ||
45 | newtype Mod (n :: Nat) t = Mod {unMod:: t} | ||
46 | deriving (Storable) | ||
47 | |||
48 | instance KnownNat m => Enum (F m) | ||
49 | where | ||
50 | toEnum = l0 (\m x -> fromIntegral $ x `mod` (fromIntegral m)) | ||
51 | fromEnum = fromIntegral . unMod | ||
52 | |||
53 | instance KnownNat m => Eq (F m) | ||
54 | where | ||
55 | a == b = (unMod a) == (unMod b) | ||
56 | |||
57 | instance KnownNat m => Ord (F m) | ||
58 | where | ||
59 | compare a b = compare (unMod a) (unMod b) | ||
60 | |||
61 | instance KnownNat m => Real (F m) | ||
62 | where | ||
63 | toRational x = toInteger x % 1 | ||
64 | |||
65 | instance KnownNat m => Integral (F m) | ||
66 | where | ||
67 | toInteger = toInteger . unMod | ||
68 | quotRem a b = (Mod q, Mod r) | ||
69 | where | ||
70 | (q,r) = quotRem (unMod a) (unMod b) | ||
71 | |||
72 | -- | this instance is only valid for prime m | ||
73 | instance KnownNat m => Fractional (F m) | ||
74 | where | ||
75 | recip x | ||
76 | | x*r == 1 = r | ||
77 | | otherwise = error $ show x ++" does not have a multiplicative inverse mod "++show m' | ||
78 | where | ||
79 | r = x^(m'-2) | ||
80 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
81 | fromRational x = fromInteger (numerator x) / fromInteger (denominator x) | ||
82 | |||
83 | l2 :: forall m a b c. (KnownNat m) => (Int -> a -> b -> c) -> Mod m a -> Mod m b -> Mod m c | ||
84 | l2 f (Mod u) (Mod v) = Mod (f m' u v) | ||
85 | where | ||
86 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
87 | |||
88 | l1 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> Mod m a -> Mod m b | ||
89 | l1 f (Mod u) = Mod (f m' u) | ||
90 | where | ||
91 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
92 | |||
93 | l0 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> a -> Mod m b | ||
94 | l0 f u = Mod (f m' u) | ||
95 | where | ||
96 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
97 | |||
98 | |||
99 | instance Show (F n) | ||
100 | where | ||
101 | show = show . unMod | ||
102 | |||
103 | instance forall n . KnownNat n => Num (F n) | ||
104 | where | ||
105 | (+) = l2 (\m a b -> (a + b) `mod` (fromIntegral m)) | ||
106 | (*) = l2 (\m a b -> (a * b) `mod` (fromIntegral m)) | ||
107 | (-) = l2 (\m a b -> (a - b) `mod` (fromIntegral m)) | ||
108 | abs = l1 (const abs) | ||
109 | signum = l1 (const signum) | ||
110 | fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) | ||
111 | |||
112 | |||
113 | -- | Integer modulo n | ||
114 | type F n = Mod n I | ||
115 | |||
116 | type V n = Vector (F n) | ||
117 | type M n = Matrix (F n) | ||
118 | |||
119 | |||
120 | instance Element (F n) | ||
121 | where | ||
122 | transdata n v m = i2f (transdata n (f2i v) m) | ||
123 | constantD x n = i2f (constantD (unMod x) n) | ||
124 | extractR m mi is mj js = i2fM (extractR (f2iM m) mi is mj js) | ||
125 | sortI = sortI . f2i | ||
126 | sortV = i2f . sortV . f2i | ||
127 | compareV u v = compareV (f2i u) (f2i v) | ||
128 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | ||
129 | remapM i j m = i2fM (remap i j (f2iM m)) | ||
130 | |||
131 | instance forall m . KnownNat m => Container Vector (F m) | ||
132 | where | ||
133 | conj' = id | ||
134 | size' = dim | ||
135 | scale' s x = fromInt (scale (unMod s) (toInt x)) | ||
136 | addConstant c x = fromInt (addConstant (unMod c) (toInt x)) | ||
137 | add a b = fromInt (add (toInt a) (toInt b)) | ||
138 | sub a b = fromInt (sub (toInt a) (toInt b)) | ||
139 | mul a b = fromInt (mul (toInt a) (toInt b)) | ||
140 | equal u v = equal (toInt u) (toInt v) | ||
141 | scalar' x = fromList [x] | ||
142 | konst' x = i2f . konst (unMod x) | ||
143 | build' n f = build n (fromIntegral . f) | ||
144 | cmap' = cmap | ||
145 | atIndex' x k = fromIntegral (atIndex (toInt x) k) | ||
146 | minIndex' = minIndex . toInt | ||
147 | maxIndex' = maxIndex . toInt | ||
148 | minElement' = Mod . minElement . toInt | ||
149 | maxElement' = Mod . maxElement . toInt | ||
150 | sumElements' = fromIntegral . sumElements . toInt | ||
151 | prodElements' = fromIntegral . sumElements . toInt | ||
152 | step' = i2f . step . toInt | ||
153 | find' = findV | ||
154 | assoc' = assocV | ||
155 | accum' = accumV | ||
156 | ccompare' a b = ccompare (toInt a) (toInt b) | ||
157 | cselect' c l e g = i2f $ cselect c (toInt l) (toInt e) (toInt g) | ||
158 | scaleRecip s x = scale' s (cmap recip x) | ||
159 | divide x y = mul x (cmap recip y) | ||
160 | arctan2' = undefined | ||
161 | cmod' m = fromInt' . cmod' (unMod m) . toInt' | ||
162 | fromInt' v = i2f $ cmod' (fromIntegral m') (fromInt' v) | ||
163 | where | ||
164 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
165 | toInt' = f2i | ||
166 | |||
167 | |||
168 | instance Indexable (Vector (F m)) (F m) | ||
169 | where | ||
170 | (!) = (@>) | ||
171 | |||
172 | |||
173 | type instance RealOf (F n) = I | ||
174 | |||
175 | |||
176 | instance KnownNat m => Product (F m) where | ||
177 | norm2 = undefined | ||
178 | absSum = undefined | ||
179 | norm1 = undefined | ||
180 | normInf = undefined | ||
181 | multiply = lift2 multiply | ||
182 | |||
183 | |||
184 | instance KnownNat m => Numeric (F m) | ||
185 | |||
186 | i2f :: Vector I -> Vector (F n) | ||
187 | i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | ||
188 | where (fp,i,n) = unsafeToForeignPtr v | ||
189 | |||
190 | f2i :: Vector (F n) -> Vector I | ||
191 | f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) | ||
192 | where (fp,i,n) = unsafeToForeignPtr v | ||
193 | |||
194 | f2iM :: Matrix (F n) -> Matrix I | ||
195 | f2iM = liftMatrix f2i | ||
196 | |||
197 | i2fM :: Matrix I -> Matrix (F n) | ||
198 | i2fM = liftMatrix i2f | ||
199 | |||
200 | |||
201 | lift1 f a = fromInt (f (toInt a)) | ||
202 | lift2 f a b = fromInt (f (toInt a) (toInt b)) | ||
203 | |||
204 | instance forall m . KnownNat m => Num (V m) | ||
205 | where | ||
206 | (+) = lift2 (+) | ||
207 | (*) = lift2 (*) | ||
208 | (-) = lift2 (-) | ||
209 | abs = lift1 abs | ||
210 | signum = lift1 signum | ||
211 | negate = lift1 negate | ||
212 | fromInteger x = fromInt (fromInteger x) | ||
213 | |||
214 | |||
215 | -------------------------------------------------------------------------------- | ||
216 | |||
217 | instance (KnownNat m) => Testable (M m) | ||
218 | where | ||
219 | checkT _ = test | ||
220 | |||
221 | test = (ok, info) | ||
222 | where | ||
223 | v = fromList [3,-5,75] :: V 11 | ||
224 | m = (3><3) [1..] :: M 11 | ||
225 | |||
226 | a = (3><3) [1,2 , 3 | ||
227 | ,4,5 , 6 | ||
228 | ,0,10,-3] :: Matrix I | ||
229 | |||
230 | b = (3><2) [0..] :: Matrix I | ||
231 | |||
232 | am = fromInt a :: Matrix (F 13) | ||
233 | bm = fromInt b :: Matrix (F 13) | ||
234 | ad = fromInt a :: Matrix Double | ||
235 | bd = fromInt b :: Matrix Double | ||
236 | |||
237 | info = do | ||
238 | print v | ||
239 | print m | ||
240 | print (tr m) | ||
241 | print $ v+v | ||
242 | print $ m+m | ||
243 | print $ m <> m | ||
244 | print $ m #> v | ||
245 | |||
246 | print $ am <> gaussElim am bm - bm | ||
247 | print $ ad <> gaussElim ad bd - bd | ||
248 | |||
249 | ok = and | ||
250 | [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) | ||
251 | , am <> gaussElim am bm == bm | ||
252 | ] | ||
253 | |||
254 | |||