summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Modular.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Modular.hs')
-rw-r--r--packages/base/src/Internal/Modular.hs74
1 files changed, 39 insertions, 35 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index cf50a05..36ffb57 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -74,24 +74,24 @@ instance KnownNat m => Fractional (F m)
74 | x*r == 1 = r 74 | x*r == 1 = r
75 | otherwise = error $ show x ++" does not have a multiplicative inverse mod "++show m' 75 | otherwise = error $ show x ++" does not have a multiplicative inverse mod "++show m'
76 where 76 where
77 r = x^(m'-2) 77 r = x^(m'-2 :: Integer)
78 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 78 m' = fromIntegral . natVal $ (undefined :: Proxy m)
79 fromRational x = fromInteger (numerator x) / fromInteger (denominator x) 79 fromRational x = fromInteger (numerator x) / fromInteger (denominator x)
80 80
81l2 :: forall m a b c. (KnownNat m) => (Int -> a -> b -> c) -> Mod m a -> Mod m b -> Mod m c 81l2 :: forall m a b c. (KnownNat m) => (Int -> a -> b -> c) -> Mod m a -> Mod m b -> Mod m c
82l2 f (Mod u) (Mod v) = Mod (f m' u v) 82l2 f (Mod u) (Mod v) = Mod (f m' u v)
83 where 83 where
84 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 84 m' = fromIntegral . natVal $ (undefined :: Proxy m)
85 85
86l1 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> Mod m a -> Mod m b 86l1 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> Mod m a -> Mod m b
87l1 f (Mod u) = Mod (f m' u) 87l1 f (Mod u) = Mod (f m' u)
88 where 88 where
89 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 89 m' = fromIntegral . natVal $ (undefined :: Proxy m)
90 90
91l0 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> a -> Mod m b 91l0 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> a -> Mod m b
92l0 f u = Mod (f m' u) 92l0 f u = Mod (f m' u)
93 where 93 where
94 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int 94 m' = fromIntegral . natVal $ (undefined :: Proxy m)
95 95
96 96
97instance Show (F n) 97instance Show (F n)
@@ -106,7 +106,7 @@ instance forall n . KnownNat n => Num (F n)
106 abs = l1 (const abs) 106 abs = l1 (const abs)
107 signum = l1 (const signum) 107 signum = l1 (const signum)
108 fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m)) 108 fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m))
109 109
110 110
111-- | Integer modulo n 111-- | Integer modulo n
112type F n = Mod n I 112type F n = Mod n I
@@ -114,7 +114,7 @@ type F n = Mod n I
114type V n = Vector (F n) 114type V n = Vector (F n)
115type M n = Matrix (F n) 115type M n = Matrix (F n)
116 116
117 117
118instance Element (F n) 118instance Element (F n)
119 where 119 where
120 transdata n v m = i2f (transdata n (f2i v) m) 120 transdata n v m = i2f (transdata n (f2i v) m)
@@ -130,37 +130,37 @@ instance forall m . KnownNat m => Container Vector (F m)
130 where 130 where
131 conj' = id 131 conj' = id
132 size' = dim 132 size' = dim
133 scale' s x = fromInt (scale (unMod s) (toInt x)) 133 scale' s x = vmod (scale (unMod s) (f2i x))
134 addConstant c x = fromInt (addConstant (unMod c) (toInt x)) 134 addConstant c x = vmod (addConstant (unMod c) (f2i x))
135 add a b = fromInt (add (toInt a) (toInt b)) 135 add a b = vmod (add (f2i a) (f2i b))
136 sub a b = fromInt (sub (toInt a) (toInt b)) 136 sub a b = vmod (sub (f2i a) (f2i b))
137 mul a b = fromInt (mul (toInt a) (toInt b)) 137 mul a b = vmod (mul (f2i a) (f2i b))
138 equal u v = equal (toInt u) (toInt v) 138 equal u v = equal (f2i u) (f2i v)
139 scalar' x = fromList [x] 139 scalar' x = fromList [x]
140 konst' x = i2f . konst (unMod x) 140 konst' x = i2f . konst (unMod x)
141 build' n f = build n (fromIntegral . f) 141 build' n f = build n (fromIntegral . f)
142 cmap' = cmap 142 cmap' = cmap
143 atIndex' x k = fromIntegral (atIndex (toInt x) k) 143 atIndex' x k = fromIntegral (atIndex (f2i x) k)
144 minIndex' = minIndex . toInt 144 minIndex' = minIndex . f2i
145 maxIndex' = maxIndex . toInt 145 maxIndex' = maxIndex . f2i
146 minElement' = Mod . minElement . toInt 146 minElement' = Mod . minElement . f2i
147 maxElement' = Mod . maxElement . toInt 147 maxElement' = Mod . maxElement . f2i
148 sumElements' = fromIntegral . sumElements . toInt 148 sumElements' = fromIntegral . sumElements . f2i -- FIXME
149 prodElements' = fromIntegral . sumElements . toInt 149 prodElements' = fromIntegral . sumElements . f2i -- FIXME
150 step' = i2f . step . toInt 150 step' = i2f . step . f2i
151 find' = findV 151 find' = findV
152 assoc' = assocV 152 assoc' = assocV
153 accum' = accumV 153 accum' = accumV
154 ccompare' a b = ccompare (toInt a) (toInt b) 154 ccompare' a b = ccompare (f2i a) (f2i b)
155 cselect' c l e g = i2f $ cselect c (toInt l) (toInt e) (toInt g) 155 cselect' c l e g = i2f $ cselect c (f2i l) (f2i e) (f2i g)
156 scaleRecip s x = scale' s (cmap recip x) 156 scaleRecip s x = scale' s (cmap recip x)
157 divide x y = mul x (cmap recip y) 157 divide x y = mul x (cmap recip y)
158 arctan2' = undefined 158 arctan2' = undefined
159 cmod' m = fromInt' . cmod' (unMod m) . toInt' 159 cmod' m = vmod . cmod' (unMod m) . f2i
160 fromInt' v = i2f $ cmod' (fromIntegral m') (fromInt' v) 160 fromInt' = vmod
161 where
162 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
163 toInt' = f2i 161 toInt' = f2i
162 fromZ' = vmod . fromZ'
163 toZ' = toZ' . f2i
164 164
165 165
166instance Indexable (Vector (F m)) (F m) 166instance Indexable (Vector (F m)) (F m)
@@ -176,25 +176,29 @@ instance KnownNat m => Product (F m) where
176 absSum = undefined 176 absSum = undefined
177 norm1 = undefined 177 norm1 = undefined
178 normInf = undefined 178 normInf = undefined
179 multiply = lift2 multiply 179 multiply = lift2 multiply -- FIXME
180 180
181 181
182instance KnownNat m => Numeric (F m) 182instance KnownNat m => Numeric (F m)
183 183
184i2f :: Vector I -> Vector (F n) 184i2f :: Storable t => Vector t -> Vector (Mod n t)
185i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) 185i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
186 where (fp,i,n) = unsafeToForeignPtr v 186 where (fp,i,n) = unsafeToForeignPtr v
187 187
188f2i :: Vector (F n) -> Vector I 188f2i :: Storable t => Vector (Mod n t) -> Vector t
189f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) 189f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
190 where (fp,i,n) = unsafeToForeignPtr v 190 where (fp,i,n) = unsafeToForeignPtr v
191 191
192f2iM :: Matrix (F n) -> Matrix I 192f2iM :: Storable t => Matrix (Mod n t) -> Matrix t
193f2iM = liftMatrix f2i 193f2iM = liftMatrix f2i
194 194
195i2fM :: Matrix I -> Matrix (F n) 195i2fM :: Storable t => Matrix t -> Matrix (Mod n t)
196i2fM = liftMatrix i2f 196i2fM = liftMatrix i2f
197 197
198vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t)
199vmod = i2f . cmod' m'
200 where
201 m' = fromIntegral . natVal $ (undefined :: Proxy m)
198 202
199lift1 f a = fromInt (f (toInt a)) 203lift1 f a = fromInt (f (toInt a))
200lift2 f a b = fromInt (f (toInt a) (toInt b)) 204lift2 f a b = fromInt (f (toInt a) (toInt b))
@@ -220,18 +224,18 @@ test = (ok, info)
220 where 224 where
221 v = fromList [3,-5,75] :: V 11 225 v = fromList [3,-5,75] :: V 11
222 m = (3><3) [1..] :: M 11 226 m = (3><3) [1..] :: M 11
223 227
224 a = (3><3) [1,2 , 3 228 a = (3><3) [1,2 , 3
225 ,4,5 , 6 229 ,4,5 , 6
226 ,0,10,-3] :: Matrix I 230 ,0,10,-3] :: Matrix I
227 231
228 b = (3><2) [0..] :: Matrix I 232 b = (3><2) [0..] :: Matrix I
229 233
230 am = fromInt a :: Matrix (F 13) 234 am = fromInt a :: Matrix (F 13)
231 bm = fromInt b :: Matrix (F 13) 235 bm = fromInt b :: Matrix (F 13)
232 ad = fromInt a :: Matrix Double 236 ad = fromInt a :: Matrix Double
233 bd = fromInt b :: Matrix Double 237 bd = fromInt b :: Matrix Double
234 238
235 info = do 239 info = do
236 print v 240 print v
237 print m 241 print m