summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal')
-rw-r--r--packages/base/src/Internal/Modular.hs254
1 files changed, 254 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
new file mode 100644
index 0000000..1116b96
--- /dev/null
+++ b/packages/base/src/Internal/Modular.hs
@@ -0,0 +1,254 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE FlexibleContexts #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE Rank2Types #-}
9{-# LANGUAGE FlexibleInstances #-}
10{-# LANGUAGE GADTs #-}
11{-# LANGUAGE TypeFamilies #-}
12
13
14{- |
15Module : Internal.Modular
16Copyright : (c) Alberto Ruiz 2015
17License : BSD3
18Stability : experimental
19
20Proof of concept of statically checked modular arithmetic.
21
22-}
23
24module Internal.Modular(
25 Mod, F
26) where
27
28import Internal.Vector
29import Internal.Matrix hiding (mat,size)
30import Internal.Numeric
31import Internal.Element
32import Internal.Tools
33import Internal.Container
34import Internal.Util(Indexable(..),gaussElim)
35import GHC.TypeLits
36import Data.Proxy(Proxy)
37import Foreign.ForeignPtr(castForeignPtr)
38import Data.Vector.Storable(fromList,unsafeToForeignPtr, unsafeFromForeignPtr)
39import Foreign.Storable
40import Data.Ratio
41
42
43
44-- | Wrapper with a phantom integer for statically checked modular arithmetic.
45newtype Mod (n :: Nat) t = Mod {unMod:: t}
46 deriving (Storable)
47
48instance KnownNat m => Enum (F m)
49 where
50 toEnum = l0 (\m x -> fromIntegral $ x `mod` (fromIntegral m))
51 fromEnum = fromIntegral . unMod
52
53instance KnownNat m => Eq (F m)
54 where
55 a == b = (unMod a) == (unMod b)
56
57instance KnownNat m => Ord (F m)
58 where
59 compare a b = compare (unMod a) (unMod b)
60
61instance KnownNat m => Real (F m)
62 where
63 toRational x = toInteger x % 1
64
65instance KnownNat m => Integral (F m)
66 where
67 toInteger = toInteger . unMod
68 quotRem a b = (Mod q, Mod r)
69 where
70 (q,r) = quotRem (unMod a) (unMod b)
71
72-- | this instance is only valid for prime m
73instance KnownNat m => Fractional (F m)
74 where
75 recip x
76 | x*r == 1 = r
77 | otherwise = error $ show x ++" does not have a multiplicative inverse mod "++show m'
78 where
79 r = x^(m'-2)
80 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
81 fromRational x = fromInteger (numerator x) / fromInteger (denominator x)
82
83l2 :: forall m a b c. (KnownNat m) => (Int -> a -> b -> c) -> Mod m a -> Mod m b -> Mod m c
84l2 f (Mod u) (Mod v) = Mod (f m' u v)
85 where
86 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
87
88l1 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> Mod m a -> Mod m b
89l1 f (Mod u) = Mod (f m' u)
90 where
91 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
92
93l0 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> a -> Mod m b
94l0 f u = Mod (f m' u)
95 where
96 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
97
98
99instance Show (F n)
100 where
101 show = show . unMod
102
103instance forall n . KnownNat n => Num (F n)
104 where
105 (+) = l2 (\m a b -> (a + b) `mod` (fromIntegral m))
106 (*) = l2 (\m a b -> (a * b) `mod` (fromIntegral m))
107 (-) = l2 (\m a b -> (a - b) `mod` (fromIntegral m))
108 abs = l1 (const abs)
109 signum = l1 (const signum)
110 fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m))
111
112
113-- | Integer modulo n
114type F n = Mod n I
115
116type V n = Vector (F n)
117type M n = Matrix (F n)
118
119
120instance Element (F n)
121 where
122 transdata n v m = i2f (transdata n (f2i v) m)
123 constantD x n = i2f (constantD (unMod x) n)
124 extractR m mi is mj js = i2fM (extractR (f2iM m) mi is mj js)
125 sortI = sortI . f2i
126 sortV = i2f . sortV . f2i
127 compareV u v = compareV (f2i u) (f2i v)
128 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
129 remapM i j m = i2fM (remap i j (f2iM m))
130
131instance forall m . KnownNat m => Container Vector (F m)
132 where
133 conj' = id
134 size' = dim
135 scale' s x = fromInt (scale (unMod s) (toInt x))
136 addConstant c x = fromInt (addConstant (unMod c) (toInt x))
137 add a b = fromInt (add (toInt a) (toInt b))
138 sub a b = fromInt (sub (toInt a) (toInt b))
139 mul a b = fromInt (mul (toInt a) (toInt b))
140 equal u v = equal (toInt u) (toInt v)
141 scalar' x = fromList [x]
142 konst' x = i2f . konst (unMod x)
143 build' n f = build n (fromIntegral . f)
144 cmap' = cmap
145 atIndex' x k = fromIntegral (atIndex (toInt x) k)
146 minIndex' = minIndex . toInt
147 maxIndex' = maxIndex . toInt
148 minElement' = Mod . minElement . toInt
149 maxElement' = Mod . maxElement . toInt
150 sumElements' = fromIntegral . sumElements . toInt
151 prodElements' = fromIntegral . sumElements . toInt
152 step' = i2f . step . toInt
153 find' = findV
154 assoc' = assocV
155 accum' = accumV
156 ccompare' a b = ccompare (toInt a) (toInt b)
157 cselect' c l e g = i2f $ cselect c (toInt l) (toInt e) (toInt g)
158 scaleRecip s x = scale' s (cmap recip x)
159 divide x y = mul x (cmap recip y)
160 arctan2' = undefined
161 cmod' m = fromInt' . cmod' (unMod m) . toInt'
162 fromInt' v = i2f $ cmod' (fromIntegral m') (fromInt' v)
163 where
164 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
165 toInt' = f2i
166
167
168instance Indexable (Vector (F m)) (F m)
169 where
170 (!) = (@>)
171
172
173type instance RealOf (F n) = I
174
175
176instance KnownNat m => Product (F m) where
177 norm2 = undefined
178 absSum = undefined
179 norm1 = undefined
180 normInf = undefined
181 multiply = lift2 multiply
182
183
184instance KnownNat m => Numeric (F m)
185
186i2f :: Vector I -> Vector (F n)
187i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
188 where (fp,i,n) = unsafeToForeignPtr v
189
190f2i :: Vector (F n) -> Vector I
191f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
192 where (fp,i,n) = unsafeToForeignPtr v
193
194f2iM :: Matrix (F n) -> Matrix I
195f2iM = liftMatrix f2i
196
197i2fM :: Matrix I -> Matrix (F n)
198i2fM = liftMatrix i2f
199
200
201lift1 f a = fromInt (f (toInt a))
202lift2 f a b = fromInt (f (toInt a) (toInt b))
203
204instance forall m . KnownNat m => Num (V m)
205 where
206 (+) = lift2 (+)
207 (*) = lift2 (*)
208 (-) = lift2 (-)
209 abs = lift1 abs
210 signum = lift1 signum
211 negate = lift1 negate
212 fromInteger x = fromInt (fromInteger x)
213
214
215--------------------------------------------------------------------------------
216
217instance (KnownNat m) => Testable (M m)
218 where
219 checkT _ = test
220
221test = (ok, info)
222 where
223 v = fromList [3,-5,75] :: V 11
224 m = (3><3) [1..] :: M 11
225
226 a = (3><3) [1,2 , 3
227 ,4,5 , 6
228 ,0,10,-3] :: Matrix I
229
230 b = (3><2) [0..] :: Matrix I
231
232 am = fromInt a :: Matrix (F 13)
233 bm = fromInt b :: Matrix (F 13)
234 ad = fromInt a :: Matrix Double
235 bd = fromInt b :: Matrix Double
236
237 info = do
238 print v
239 print m
240 print (tr m)
241 print $ v+v
242 print $ m+m
243 print $ m <> m
244 print $ m #> v
245
246 print $ am <> gaussElim am bm - bm
247 print $ ad <> gaussElim ad bd - bd
248
249 ok = and
250 [ toInt (m #> v) == cmod 11 (toInt m #> toInt v )
251 , am <> gaussElim am bm == bm
252 ]
253
254