summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Util/Modular.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Util/Modular.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/Modular.hs233
1 files changed, 233 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/Modular.hs b/packages/base/src/Numeric/LinearAlgebra/Util/Modular.hs
new file mode 100644
index 0000000..18b635d
--- /dev/null
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/Modular.hs
@@ -0,0 +1,233 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE FlexibleContexts #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE Rank2Types #-}
9{-# LANGUAGE FlexibleInstances #-}
10{-# LANGUAGE GADTs #-}
11{-# LANGUAGE TypeFamilies #-}
12
13
14{- |
15Module : Numeric.LinearAlgebra.Util.Modular
16Copyright : (c) Alberto Ruiz 2015
17License : BSD3
18Stability : experimental
19
20Proof of concept of statically checked modular arithmetic.
21
22-}
23
24module Numeric.LinearAlgebra.Util.Modular(
25 Mod, F
26) where
27
28import Data.Packed.Numeric
29import Numeric.LinearAlgebra.Util(Indexable(..),size)
30import GHC.TypeLits
31import Data.Proxy(Proxy)
32import Foreign.ForeignPtr(castForeignPtr)
33import Data.Vector.Storable(unsafeToForeignPtr, unsafeFromForeignPtr)
34import Foreign.Storable
35import Data.Ratio
36import Data.Packed.Internal.Matrix hiding (mat,size)
37import Data.Packed.Internal.Numeric
38
39
40-- | Wrapper with a phantom integer for statically checked modular arithmetic.
41newtype Mod (n :: Nat) t = Mod {unMod:: t}
42 deriving (Storable)
43
44instance KnownNat m => Enum (F m)
45 where
46 toEnum = l0 (\m x -> fromIntegral $ x `mod` (fromIntegral m))
47 fromEnum = fromIntegral . unMod
48
49instance KnownNat m => Eq (F m)
50 where
51 a == b = (unMod a) == (unMod b)
52
53instance KnownNat m => Ord (F m)
54 where
55 compare a b = compare (unMod a) (unMod b)
56
57instance KnownNat m => Real (F m)
58 where
59 toRational x = toInteger x % 1
60
61instance KnownNat m => Integral (F m)
62 where
63 toInteger = toInteger . unMod
64 quotRem a b = (Mod q, Mod r)
65 where
66 (q,r) = quotRem (unMod a) (unMod b)
67
68-- | this instance is only valid for prime m (not checked)
69instance KnownNat m => Fractional (F m)
70 where
71 recip x = x^(m'-2)
72 where
73 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
74 fromRational x = fromInteger (numerator x) / fromInteger (denominator x)
75
76l2 :: forall m a b c. (KnownNat m) => (Int -> a -> b -> c) -> Mod m a -> Mod m b -> Mod m c
77l2 f (Mod u) (Mod v) = Mod (f m' u v)
78 where
79 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
80
81l1 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> Mod m a -> Mod m b
82l1 f (Mod u) = Mod (f m' u)
83 where
84 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
85
86l0 :: forall m a b . (KnownNat m) => (Int -> a -> b) -> a -> Mod m b
87l0 f u = Mod (f m' u)
88 where
89 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
90
91
92instance Show (F n)
93 where
94 show = show . unMod
95
96instance forall n . KnownNat n => Num (F n)
97 where
98 (+) = l2 (\m a b -> (a + b) `mod` (fromIntegral m))
99 (*) = l2 (\m a b -> (a * b) `mod` (fromIntegral m))
100 (-) = l2 (\m a b -> (a - b) `mod` (fromIntegral m))
101 abs = l1 (const abs)
102 signum = l1 (const signum)
103 fromInteger = l0 (\m x -> fromInteger x `mod` (fromIntegral m))
104
105
106-- | Integer modulo n
107type F n = Mod n I
108
109type V n = Vector (F n)
110type M n = Matrix (F n)
111
112
113instance Element (F n)
114 where
115 transdata n v m = i2f (transdata n (f2i v) m)
116 constantD x n = i2f (constantD (unMod x) n)
117 extractR m mi is mj js = i2fM (extractR (f2iM m) mi is mj js)
118 sortI = sortI . f2i
119 sortV = i2f . sortV . f2i
120 compareV u v = compareV (f2i u) (f2i v)
121 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
122 remapM i j m = i2fM (remap i j (f2iM m))
123
124instance forall m . KnownNat m => Container Vector (F m)
125 where
126 conj' = id
127 size' = size
128 scale' s x = fromInt (scale (unMod s) (toInt x))
129 addConstant c x = fromInt (addConstant (unMod c) (toInt x))
130 add a b = fromInt (add (toInt a) (toInt b))
131 sub a b = fromInt (sub (toInt a) (toInt b))
132 mul a b = fromInt (mul (toInt a) (toInt b))
133 equal u v = equal (toInt u) (toInt v)
134 scalar' x = fromList [x]
135 konst' x = i2f . konst (unMod x)
136 build' n f = build n (fromIntegral . f)
137 cmap' = cmap
138 atIndex' x k = fromIntegral (atIndex (toInt x) k)
139 minIndex' = minIndex . toInt
140 maxIndex' = maxIndex . toInt
141 minElement' = Mod . minElement . toInt
142 maxElement' = Mod . maxElement . toInt
143 sumElements' = fromIntegral . sumElements . toInt
144 prodElements' = fromIntegral . sumElements . toInt
145 step' = i2f . step . toInt
146 find' = findV
147 assoc' = assocV
148 accum' = accumV
149 cond' x y l e g = cselect (ccompare x y) l e g
150 ccompare' a b = ccompare (toInt a) (toInt b)
151 cselect' c l e g = i2f $ cselect c (toInt l) (toInt e) (toInt g)
152 scaleRecip s x = scale' s (cmap recip x)
153 divide x y = mul x (cmap recip y)
154 arctan2' = undefined
155 cmod' m = fromInt' . cmod' (unMod m) . toInt'
156 fromInt' v = i2f $ cmod' (fromIntegral m') (fromInt' v)
157 where
158 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
159 toInt' = f2i
160
161
162instance Indexable (Vector (F m)) (F m)
163 where
164 (!) = (@>)
165
166
167type instance RealOf (F n) = I
168
169
170instance KnownNat m => Product (F m) where
171 norm2 = undefined
172 absSum = undefined
173 norm1 = undefined
174 normInf = undefined
175 multiply = lift2 multiply
176
177
178instance KnownNat m => Numeric (F m)
179
180i2f :: Vector I -> Vector (F n)
181i2f v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
182 where (fp,i,n) = unsafeToForeignPtr v
183
184f2i :: Vector (F n) -> Vector I
185f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
186 where (fp,i,n) = unsafeToForeignPtr v
187
188f2iM :: Matrix (F n) -> Matrix I
189f2iM = liftMatrix f2i
190
191i2fM :: Matrix I -> Matrix (F n)
192i2fM = liftMatrix i2f
193
194
195lift1 f a = fromInt (f (toInt a))
196lift2 f a b = fromInt (f (toInt a) (toInt b))
197
198instance forall m . KnownNat m => Num (V m)
199 where
200 (+) = lift2 (+)
201 (*) = lift2 (*)
202 (-) = lift2 (-)
203 abs = lift1 abs
204 signum = lift1 signum
205 negate = lift1 negate
206 fromInteger x = fromInt (fromInteger x)
207
208
209--------------------------------------------------------------------------------
210
211instance (KnownNat m) => Testable (M m)
212 where
213 checkT _ = test
214
215test = (ok, info)
216 where
217 v = fromList [3,-5,75] :: V 11
218 m = (3><3) [1..] :: M 11
219 info = do
220 print v
221 print m
222 print (tr m)
223 print $ v+v
224 print $ m+m
225 print $ m <> m
226 print $ m #> v
227
228
229 ok = and
230 [ toInt (m #> v) == cmod 11 (toInt m #> toInt v )
231 ]
232
233