summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Modular.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Modular.hs')
-rw-r--r--packages/base/src/Internal/Modular.hs22
1 files changed, 13 insertions, 9 deletions
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 37f6e9b..c4f95d8 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -36,7 +36,7 @@ import Internal.LAPACK (multiplyI, multiplyL)
36import Internal.Algorithms(luFact) 36import Internal.Algorithms(luFact)
37import Internal.Util(Normed(..),Indexable(..), 37import Internal.Util(Normed(..),Indexable(..),
38 gaussElim, gaussElim_1, gaussElim_2, 38 gaussElim, gaussElim_1, gaussElim_2,
39 luST, luSolve', luPacked', magnit) 39 luST, luSolve', luPacked', magnit, invershur)
40import Internal.ST(mutable) 40import Internal.ST(mutable)
41import GHC.TypeLits 41import GHC.TypeLits
42import Data.Proxy(Proxy) 42import Data.Proxy(Proxy)
@@ -126,9 +126,8 @@ instance forall n t . (Integral t, KnownNat n) => Num (Mod n t)
126 126
127instance KnownNat m => Element (Mod m I) 127instance KnownNat m => Element (Mod m I)
128 where 128 where
129 transdata n v m = i2f (transdata n (f2i v) m)
130 constantD x n = i2f (constantD (unMod x) n) 129 constantD x n = i2f (constantD (unMod x) n)
131 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js 130 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js
132 setRect i j m x = setRect i j (f2iM m) (f2iM x) 131 setRect i j m x = setRect i j (f2iM m) (f2iM x)
133 sortI = sortI . f2i 132 sortI = sortI . f2i
134 sortV = i2f . sortV . f2i 133 sortV = i2f . sortV . f2i
@@ -144,9 +143,8 @@ instance KnownNat m => Element (Mod m I)
144 143
145instance KnownNat m => Element (Mod m Z) 144instance KnownNat m => Element (Mod m Z)
146 where 145 where
147 transdata n v m = i2f (transdata n (f2i v) m)
148 constantD x n = i2f (constantD (unMod x) n) 146 constantD x n = i2f (constantD (unMod x) n)
149 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js 147 extractR ord m mi is mj js = i2fM <$> extractR ord (f2iM m) mi is mj js
150 setRect i j m x = setRect i j (f2iM m) (f2iM x) 148 setRect i j m x = setRect i j (f2iM m) (f2iM x)
151 sortI = sortI . f2i 149 sortI = sortI . f2i
152 sortV = i2f . sortV . f2i 150 sortV = i2f . sortV . f2i
@@ -293,11 +291,11 @@ f2i :: Storable t => Vector (Mod n t) -> Vector t
293f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n) 291f2i v = unsafeFromForeignPtr (castForeignPtr fp) (i) (n)
294 where (fp,i,n) = unsafeToForeignPtr v 292 where (fp,i,n) = unsafeToForeignPtr v
295 293
296f2iM :: Storable t => Matrix (Mod n t) -> Matrix t 294f2iM :: (Element t, Element (Mod n t)) => Matrix (Mod n t) -> Matrix t
297f2iM = liftMatrix f2i 295f2iM m = m { xdat = f2i (xdat m) }
298 296
299i2fM :: Storable t => Matrix t -> Matrix (Mod n t) 297i2fM :: (Element t, Element (Mod n t)) => Matrix t -> Matrix (Mod n t)
300i2fM = liftMatrix i2f 298i2fM m = m { xdat = i2f (xdat m) }
301 299
302vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t) 300vmod :: forall m t. (KnownNat m, Storable t, Integral t, Numeric t) => Vector t -> Vector (Mod m t)
303vmod = i2f . cmod' m' 301vmod = i2f . cmod' m'
@@ -376,6 +374,8 @@ test = (ok, info)
376 bb = flipud aa 374 bb = flipud aa
377 x = luSolve' (luPacked' aa) bb 375 x = luSolve' (luPacked' aa) bb
378 376
377 tmm = diagRect 1 (fromList [2..6]) 5 5 :: Matrix (Mod 19 I)
378
379 info = do 379 info = do
380 print v 380 print v
381 print m 381 print m
@@ -421,6 +421,9 @@ test = (ok, info)
421 print $ checkSolve (sgen 5 :: Matrix (Complex Float)) 421 print $ checkSolve (sgen 5 :: Matrix (Complex Float))
422 print $ checkSolve (gen 5 :: Matrix (Mod 7 I)) 422 print $ checkSolve (gen 5 :: Matrix (Mod 7 I))
423 print $ checkSolve (gen 5 :: Matrix (Mod 7 Z)) 423 print $ checkSolve (gen 5 :: Matrix (Mod 7 Z))
424
425 print $ luSolve' (luPacked' tmm) (ident (rows tmm))
426 print $ invershur tmm
424 427
425 428
426 ok = and 429 ok = and
@@ -449,6 +452,7 @@ test = (ok, info)
449 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) 452 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I))
450 , gm <> gm == konst 0 (3,3) 453 , gm <> gm == konst 0 (3,3)
451 , lgm <> lgm == konst 0 (3,3) 454 , lgm <> lgm == konst 0 (3,3)
455 , invershur tmm == luSolve' (luPacked' tmm) (ident (rows tmm))
452 ] 456 ]
453 457
454 458