diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-13 19:18:16 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-13 19:18:16 +0200 |
commit | 717c680a4b65a2226b0dd6fc13f7c63e7bc0431d (patch) | |
tree | 1775c3c363a0b61f5f6a6ec1f22fe9b7d5864dc4 /packages/base/src | |
parent | 4b3e29097aa272d429f8005fe17b459cf0c049c8 (diff) |
setRect, general luPacked' based on luST
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Internal/C/lapack-aux.c | 54 | ||||
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 20 | ||||
-rw-r--r-- | packages/base/src/Internal/Modular.hs | 75 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 7 | ||||
-rw-r--r-- | packages/base/src/Internal/Util.hs | 48 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra.hs | 12 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Devel.hs | 2 |
7 files changed, 169 insertions, 49 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c index e42889d..2843ab5 100644 --- a/packages/base/src/Internal/C/lapack-aux.c +++ b/packages/base/src/Internal/C/lapack-aux.c | |||
@@ -1448,7 +1448,7 @@ int transL(KLMAT(x),LMAT(t)) TRANS_IMP | |||
1448 | 1448 | ||
1449 | //////////////////////// extract ///////////////////////////////// | 1449 | //////////////////////// extract ///////////////////////////////// |
1450 | 1450 | ||
1451 | #define EXTRACT_IMP \ | 1451 | #define EXTRACT_IMP { \ |
1452 | int i,j,si,sj,ni,nj; \ | 1452 | int i,j,si,sj,ni,nj; \ |
1453 | ni = modei ? in : ip[1]-ip[0]+1; \ | 1453 | ni = modei ? in : ip[1]-ip[0]+1; \ |
1454 | nj = modej ? jn : jp[1]-jp[0]+1; \ | 1454 | nj = modej ? jn : jp[1]-jp[0]+1; \ |
@@ -1461,33 +1461,35 @@ int transL(KLMAT(x),LMAT(t)) TRANS_IMP | |||
1461 | \ | 1461 | \ |
1462 | AT(r,i,j) = AT(m,si,sj); \ | 1462 | AT(r,i,j) = AT(m,si,sj); \ |
1463 | } \ | 1463 | } \ |
1464 | } \ | 1464 | } OK } |
1465 | OK | ||
1466 | |||
1467 | int extractD(int modei, int modej, KIVEC(i), KIVEC(j), KODMAT(m), ODMAT(r)) { | ||
1468 | EXTRACT_IMP | ||
1469 | } | ||
1470 | |||
1471 | int extractF(int modei, int modej, KIVEC(i), KIVEC(j), KOFMAT(m), OFMAT(r)) { | ||
1472 | EXTRACT_IMP | ||
1473 | } | ||
1474 | |||
1475 | int extractC(int modei, int modej, KIVEC(i), KIVEC(j), KOCMAT(m), OCMAT(r)) { | ||
1476 | EXTRACT_IMP | ||
1477 | } | ||
1478 | |||
1479 | int extractQ(int modei, int modej, KIVEC(i), KIVEC(j), KOQMAT(m), OQMAT(r)) { | ||
1480 | EXTRACT_IMP | ||
1481 | } | ||
1482 | |||
1483 | int extractI(int modei, int modej, KIVEC(i), KIVEC(j), KOIMAT(m), OIMAT(r)) { | ||
1484 | EXTRACT_IMP | ||
1485 | } | ||
1486 | 1465 | ||
1487 | int extractL(int modei, int modej, KIVEC(i), KIVEC(j), KOLMAT(m), OLMAT(r)) { | 1466 | #define EXTRACT(T) int extract##T(int modei, int modej, KIVEC(i), KIVEC(j), KO##T##MAT(m), O##T##MAT(r)) EXTRACT_IMP |
1488 | EXTRACT_IMP | 1467 | |
1489 | } | 1468 | EXTRACT(D) |
1469 | EXTRACT(F) | ||
1470 | EXTRACT(C) | ||
1471 | EXTRACT(Q) | ||
1472 | EXTRACT(I) | ||
1473 | EXTRACT(L) | ||
1474 | |||
1475 | //////////////////////// setRect ///////////////////////////////// | ||
1476 | |||
1477 | #define SETRECT(T) \ | ||
1478 | int setRect##T(int i, int j, KO##T##MAT(m), O##T##MAT(r)) { \ | ||
1479 | { TRAV(m,a,b) { \ | ||
1480 | int x = a+i, y = b+j; \ | ||
1481 | if(x>=0 && x<rr && y>=0 && y<rc) { \ | ||
1482 | AT(r,x,y) = AT(m,a,b); \ | ||
1483 | } \ | ||
1484 | } \ | ||
1485 | } OK } | ||
1490 | 1486 | ||
1487 | SETRECT(D) | ||
1488 | SETRECT(F) | ||
1489 | SETRECT(C) | ||
1490 | SETRECT(Q) | ||
1491 | SETRECT(I) | ||
1492 | SETRECT(L) | ||
1491 | 1493 | ||
1492 | //////////////////////// remap ///////////////////////////////// | 1494 | //////////////////////// remap ///////////////////////////////// |
1493 | 1495 | ||
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index fa1aad6..e0f5ed2 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -275,6 +275,7 @@ class (Storable a) => Element a where | |||
275 | transdata :: Int -> Vector a -> Int -> Vector a | 275 | transdata :: Int -> Vector a -> Int -> Vector a |
276 | constantD :: a -> Int -> Vector a | 276 | constantD :: a -> Int -> Vector a |
277 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) | 277 | extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) |
278 | setRect :: Int -> Int -> Matrix a -> Matrix a -> IO () | ||
278 | sortI :: Ord a => Vector a -> Vector CInt | 279 | sortI :: Ord a => Vector a -> Vector CInt |
279 | sortV :: Ord a => Vector a -> Vector a | 280 | sortV :: Ord a => Vector a -> Vector a |
280 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt | 281 | compareV :: Ord a => Vector a -> Vector a -> Vector CInt |
@@ -287,6 +288,7 @@ instance Element Float where | |||
287 | transdata = transdataAux ctransF | 288 | transdata = transdataAux ctransF |
288 | constantD = constantAux cconstantF | 289 | constantD = constantAux cconstantF |
289 | extractR = extractAux c_extractF | 290 | extractR = extractAux c_extractF |
291 | setRect = setRectAux c_setRectF | ||
290 | sortI = sortIdxF | 292 | sortI = sortIdxF |
291 | sortV = sortValF | 293 | sortV = sortValF |
292 | compareV = compareF | 294 | compareV = compareF |
@@ -298,6 +300,7 @@ instance Element Double where | |||
298 | transdata = transdataAux ctransR | 300 | transdata = transdataAux ctransR |
299 | constantD = constantAux cconstantR | 301 | constantD = constantAux cconstantR |
300 | extractR = extractAux c_extractD | 302 | extractR = extractAux c_extractD |
303 | setRect = setRectAux c_setRectD | ||
301 | sortI = sortIdxD | 304 | sortI = sortIdxD |
302 | sortV = sortValD | 305 | sortV = sortValD |
303 | compareV = compareD | 306 | compareV = compareD |
@@ -310,6 +313,7 @@ instance Element (Complex Float) where | |||
310 | transdata = transdataAux ctransQ | 313 | transdata = transdataAux ctransQ |
311 | constantD = constantAux cconstantQ | 314 | constantD = constantAux cconstantQ |
312 | extractR = extractAux c_extractQ | 315 | extractR = extractAux c_extractQ |
316 | setRect = setRectAux c_setRectQ | ||
313 | sortI = undefined | 317 | sortI = undefined |
314 | sortV = undefined | 318 | sortV = undefined |
315 | compareV = undefined | 319 | compareV = undefined |
@@ -322,6 +326,7 @@ instance Element (Complex Double) where | |||
322 | transdata = transdataAux ctransC | 326 | transdata = transdataAux ctransC |
323 | constantD = constantAux cconstantC | 327 | constantD = constantAux cconstantC |
324 | extractR = extractAux c_extractC | 328 | extractR = extractAux c_extractC |
329 | setRect = setRectAux c_setRectC | ||
325 | sortI = undefined | 330 | sortI = undefined |
326 | sortV = undefined | 331 | sortV = undefined |
327 | compareV = undefined | 332 | compareV = undefined |
@@ -333,6 +338,7 @@ instance Element (CInt) where | |||
333 | transdata = transdataAux ctransI | 338 | transdata = transdataAux ctransI |
334 | constantD = constantAux cconstantI | 339 | constantD = constantAux cconstantI |
335 | extractR = extractAux c_extractI | 340 | extractR = extractAux c_extractI |
341 | setRect = setRectAux c_setRectI | ||
336 | sortI = sortIdxI | 342 | sortI = sortIdxI |
337 | sortV = sortValI | 343 | sortV = sortValI |
338 | compareV = compareI | 344 | compareV = compareI |
@@ -344,6 +350,7 @@ instance Element Z where | |||
344 | transdata = transdataAux ctransL | 350 | transdata = transdataAux ctransL |
345 | constantD = constantAux cconstantL | 351 | constantD = constantAux cconstantL |
346 | extractR = extractAux c_extractL | 352 | extractR = extractAux c_extractL |
353 | setRect = setRectAux c_setRectL | ||
347 | sortI = sortIdxL | 354 | sortI = sortIdxL |
348 | sortV = sortValL | 355 | sortV = sortValL |
349 | compareV = compareL | 356 | compareV = compareL |
@@ -454,6 +461,19 @@ foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) | |||
454 | foreign import ccall unsafe "extractI" c_extractI :: Extr CInt | 461 | foreign import ccall unsafe "extractI" c_extractI :: Extr CInt |
455 | foreign import ccall unsafe "extractL" c_extractL :: Extr Z | 462 | foreign import ccall unsafe "extractL" c_extractL :: Extr Z |
456 | 463 | ||
464 | --------------------------------------------------------------- | ||
465 | |||
466 | setRectAux f i j m r = app2 (f (fi i) (fi j)) omat m omat r "setRect" | ||
467 | |||
468 | type SetRect x = I -> I -> x ::> x::> Ok | ||
469 | |||
470 | foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double | ||
471 | foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float | ||
472 | foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double) | ||
473 | foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float) | ||
474 | foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I | ||
475 | foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z | ||
476 | |||
457 | -------------------------------------------------------------------------------- | 477 | -------------------------------------------------------------------------------- |
458 | 478 | ||
459 | sortG f v = unsafePerformIO $ do | 479 | sortG f v = unsafePerformIO $ do |
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs index 824fc57..3b27310 100644 --- a/packages/base/src/Internal/Modular.hs +++ b/packages/base/src/Internal/Modular.hs | |||
@@ -33,12 +33,15 @@ import Internal.Element | |||
33 | import Internal.Container | 33 | import Internal.Container |
34 | import Internal.Vectorized (prodI,sumI,prodL,sumL) | 34 | import Internal.Vectorized (prodI,sumI,prodL,sumL) |
35 | import Internal.LAPACK (multiplyI, multiplyL) | 35 | import Internal.LAPACK (multiplyI, multiplyL) |
36 | import Internal.Util(Indexable(..),gaussElim) | 36 | import Internal.Algorithms(luFact) |
37 | import Internal.Util(Normed(..),Indexable(..),gaussElim, gaussElim_1, gaussElim_2,luST, magnit) | ||
38 | import Internal.ST(mutable) | ||
37 | import GHC.TypeLits | 39 | import GHC.TypeLits |
38 | import Data.Proxy(Proxy) | 40 | import Data.Proxy(Proxy) |
39 | import Foreign.ForeignPtr(castForeignPtr) | 41 | import Foreign.ForeignPtr(castForeignPtr) |
40 | import Foreign.Storable | 42 | import Foreign.Storable |
41 | import Data.Ratio | 43 | import Data.Ratio |
44 | import Data.Complex | ||
42 | 45 | ||
43 | 46 | ||
44 | 47 | ||
@@ -116,6 +119,7 @@ instance KnownNat m => Element (Mod m I) | |||
116 | transdata n v m = i2f (transdata n (f2i v) m) | 119 | transdata n v m = i2f (transdata n (f2i v) m) |
117 | constantD x n = i2f (constantD (unMod x) n) | 120 | constantD x n = i2f (constantD (unMod x) n) |
118 | extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js | 121 | extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js |
122 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | ||
119 | sortI = sortI . f2i | 123 | sortI = sortI . f2i |
120 | sortV = i2f . sortV . f2i | 124 | sortV = i2f . sortV . f2i |
121 | compareV u v = compareV (f2i u) (f2i v) | 125 | compareV u v = compareV (f2i u) (f2i v) |
@@ -130,6 +134,7 @@ instance KnownNat m => Element (Mod m Z) | |||
130 | transdata n v m = i2f (transdata n (f2i v) m) | 134 | transdata n v m = i2f (transdata n (f2i v) m) |
131 | constantD x n = i2f (constantD (unMod x) n) | 135 | constantD x n = i2f (constantD (unMod x) n) |
132 | extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js | 136 | extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js |
137 | setRect i j m x = setRect i j (f2iM m) (f2iM x) | ||
133 | sortI = sortI . f2i | 138 | sortI = sortI . f2i |
134 | sortV = i2f . sortV . f2i | 139 | sortV = i2f . sortV . f2i |
135 | compareV u v = compareV (f2i u) (f2i v) | 140 | compareV u v = compareV (f2i u) (f2i v) |
@@ -139,18 +144,6 @@ instance KnownNat m => Element (Mod m Z) | |||
139 | where | 144 | where |
140 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 145 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
141 | 146 | ||
142 | {- | ||
143 | instance (Ord t, Element t) => Element (Mod m t) | ||
144 | where | ||
145 | transdata n v m = i2f (transdata n (f2i v) m) | ||
146 | constantD x n = i2f (constantD (unMod x) n) | ||
147 | extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js | ||
148 | sortI = sortI . f2i | ||
149 | sortV = i2f . sortV . f2i | ||
150 | compareV u v = compareV (f2i u) (f2i v) | ||
151 | selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g)) | ||
152 | remapM i j m = i2fM (remap i j (f2iM m)) | ||
153 | -} | ||
154 | 147 | ||
155 | instance forall m . KnownNat m => Container Vector (Mod m I) | 148 | instance forall m . KnownNat m => Container Vector (Mod m I) |
156 | where | 149 | where |
@@ -258,6 +251,20 @@ instance KnownNat m => Product (Mod m Z) where | |||
258 | where | 251 | where |
259 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 252 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
260 | 253 | ||
254 | instance KnownNat m => Normed (Vector (Mod m I)) | ||
255 | where | ||
256 | norm_0 = norm_0 . toInt | ||
257 | norm_1 = norm_1 . toInt | ||
258 | norm_2 = norm_2 . toInt | ||
259 | norm_Inf = norm_Inf . toInt | ||
260 | |||
261 | instance KnownNat m => Normed (Vector (Mod m Z)) | ||
262 | where | ||
263 | norm_0 = norm_0 . toZ | ||
264 | norm_1 = norm_1 . toZ | ||
265 | norm_2 = norm_2 . toZ | ||
266 | norm_Inf = norm_Inf . toZ | ||
267 | |||
261 | 268 | ||
262 | instance KnownNat m => Numeric (Mod m I) | 269 | instance KnownNat m => Numeric (Mod m I) |
263 | instance KnownNat m => Numeric (Mod m Z) | 270 | instance KnownNat m => Numeric (Mod m Z) |
@@ -334,6 +341,15 @@ test = (ok, info) | |||
334 | lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z | 341 | lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z |
335 | lgm = fromZ lg :: Matrix (Mod 10000000000 Z) | 342 | lgm = fromZ lg :: Matrix (Mod 10000000000 Z) |
336 | 343 | ||
344 | gen n = diagRect 1 (konst 5 n) n n :: Numeric t => Matrix t | ||
345 | |||
346 | checkGen x = norm_Inf $ flatten $ invg x <> x - ident (rows x) | ||
347 | |||
348 | invg t = gaussElim t (ident (rows t)) | ||
349 | |||
350 | checkLU okf t = norm_Inf $ flatten (l <> u <> p - t) | ||
351 | where | ||
352 | (l,u,p,_ :: Int) = luFact $ mutable (luST okf) t | ||
337 | 353 | ||
338 | info = do | 354 | info = do |
339 | print v | 355 | print v |
@@ -356,11 +372,42 @@ test = (ok, info) | |||
356 | print $ lg <> lg | 372 | print $ lg <> lg |
357 | print lgm | 373 | print lgm |
358 | print $ lgm <> lgm | 374 | print $ lgm <> lgm |
375 | |||
376 | print (checkGen (gen 5 :: Matrix R)) | ||
377 | print (checkGen (gen 5 :: Matrix C)) | ||
378 | print (checkGen (gen 5 :: Matrix Float)) | ||
379 | print (checkGen (gen 5 :: Matrix (Complex Float))) | ||
380 | print (invg (gen 5) :: Matrix (Mod 7 I)) | ||
381 | print (invg (gen 5) :: Matrix (Mod 7 Z)) | ||
382 | |||
383 | print $ mutable (luST (const True)) (gen 5 :: Matrix R) | ||
384 | print $ mutable (luST (const True)) (gen 5 :: Matrix (Mod 11 Z)) | ||
385 | |||
386 | print $ checkLU (magnit 0) (gen 5 :: Matrix R) | ||
387 | print $ checkLU (magnit 0) (gen 5 :: Matrix Float) | ||
388 | print $ checkLU (magnit 0) (gen 5 :: Matrix C) | ||
389 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Complex Float)) | ||
390 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I)) | ||
391 | print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z)) | ||
359 | 392 | ||
360 | 393 | ||
361 | ok = and | 394 | ok = and |
362 | [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) | 395 | [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) |
363 | , am <> gaussElim am bm == bm | 396 | , am <> gaussElim_1 am bm == bm |
397 | , am <> gaussElim_2 am bm == bm | ||
398 | , am <> gaussElim am bm == bm | ||
399 | , (checkGen (gen 5 :: Matrix R)) < 1E-15 | ||
400 | , (checkGen (gen 5 :: Matrix Float)) < 1E-7 | ||
401 | , (checkGen (gen 5 :: Matrix C)) < 1E-15 | ||
402 | , (checkGen (gen 5 :: Matrix (Complex Float))) < 1E-7 | ||
403 | , (checkGen (gen 5 :: Matrix (Mod 7 I))) == 0 | ||
404 | , (checkGen (gen 5 :: Matrix (Mod 7 Z))) == 0 | ||
405 | , (checkLU (magnit 1E-10) (gen 5 :: Matrix R)) < 1E-15 | ||
406 | , (checkLU (magnit 1E-5) (gen 5 :: Matrix Float)) < 1E-6 | ||
407 | , (checkLU (magnit 1E-10) (gen 5 :: Matrix C)) < 1E-15 | ||
408 | , (checkLU (magnit 1E-5) (gen 5 :: Matrix (Complex Float))) < 1E-6 | ||
409 | , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))) == 0 | ||
410 | , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))) == 0 | ||
364 | , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) | 411 | , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) |
365 | , gm <> gm == konst 0 (3,3) | 412 | , gm <> gm == konst 0 (3,3) |
366 | , lgm <> lgm == konst 0 (3,3) | 413 | , lgm <> lgm == konst 0 (3,3) |
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 107d3c3..a84ca25 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -21,7 +21,7 @@ module Internal.ST ( | |||
21 | -- * Mutable Matrices | 21 | -- * Mutable Matrices |
22 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 22 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
23 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 23 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
24 | axpy, scal, swap, extractRect, | 24 | axpy, scal, swap, extractMatrix, setMatrix, rowOpST, |
25 | mutable, | 25 | mutable, |
26 | -- * Unsafe functions | 26 | -- * Unsafe functions |
27 | newUndefinedVector, | 27 | newUndefinedVector, |
@@ -166,6 +166,9 @@ readMatrix = safeIndexM unsafeReadMatrix | |||
166 | writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () | 166 | writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () |
167 | writeMatrix = safeIndexM unsafeWriteMatrix | 167 | writeMatrix = safeIndexM unsafeWriteMatrix |
168 | 168 | ||
169 | setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s () | ||
170 | setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x | ||
171 | |||
169 | newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) | 172 | newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) |
170 | newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c | 173 | newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c |
171 | 174 | ||
@@ -182,7 +185,7 @@ axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m) | |||
182 | scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) | 185 | scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) |
183 | swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) | 186 | swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) |
184 | 187 | ||
185 | extractRect (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) | 188 | extractMatrix (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) |
186 | 189 | ||
187 | -------------------------------------------------------------------------------- | 190 | -------------------------------------------------------------------------------- |
188 | 191 | ||
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs index 7a556e9..2650ac8 100644 --- a/packages/base/src/Internal/Util.hs +++ b/packages/base/src/Internal/Util.hs | |||
@@ -41,6 +41,7 @@ module Internal.Util( | |||
41 | norm, | 41 | norm, |
42 | ℕ,ℤ,ℝ,ℂ,iC, | 42 | ℕ,ℤ,ℝ,ℂ,iC, |
43 | Normed(..), norm_Frob, norm_nuclear, | 43 | Normed(..), norm_Frob, norm_nuclear, |
44 | magnit, | ||
44 | unitary, | 45 | unitary, |
45 | mt, | 46 | mt, |
46 | (~!~), | 47 | (~!~), |
@@ -54,7 +55,7 @@ module Internal.Util( | |||
54 | -- ** 2D | 55 | -- ** 2D |
55 | corr2, conv2, separable, | 56 | corr2, conv2, separable, |
56 | block2x2,block3x3,view1,unView1,foldMatrix, | 57 | block2x2,block3x3,view1,unView1,foldMatrix, |
57 | gaussElim_1, gaussElim_2, gaussElim | 58 | gaussElim_1, gaussElim_2, gaussElim, luST |
58 | ) where | 59 | ) where |
59 | 60 | ||
60 | import Internal.Vector | 61 | import Internal.Vector |
@@ -300,6 +301,26 @@ instance Normed (Vector I) | |||
300 | norm_2 v = sqrt . fromIntegral $ dot v v | 301 | norm_2 v = sqrt . fromIntegral $ dot v v |
301 | norm_Inf = fromIntegral . normInf | 302 | norm_Inf = fromIntegral . normInf |
302 | 303 | ||
304 | instance Normed (Vector Z) | ||
305 | where | ||
306 | norm_0 = fromIntegral . sumElements . step . abs | ||
307 | norm_1 = fromIntegral . norm1 | ||
308 | norm_2 v = sqrt . fromIntegral $ dot v v | ||
309 | norm_Inf = fromIntegral . normInf | ||
310 | |||
311 | instance Normed (Vector Float) | ||
312 | where | ||
313 | norm_0 = norm_0 . double | ||
314 | norm_1 = norm_1 . double | ||
315 | norm_2 = norm_2 . double | ||
316 | norm_Inf = norm_Inf . double | ||
317 | |||
318 | instance Normed (Vector (Complex Float)) | ||
319 | where | ||
320 | norm_0 = norm_0 . double | ||
321 | norm_1 = norm_1 . double | ||
322 | norm_2 = norm_2 . double | ||
323 | norm_Inf = norm_Inf . double | ||
303 | 324 | ||
304 | 325 | ||
305 | norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> ℝ | 326 | norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> ℝ |
@@ -308,6 +329,9 @@ norm_Frob = norm_2 . flatten | |||
308 | norm_nuclear :: Field t => Matrix t -> ℝ | 329 | norm_nuclear :: Field t => Matrix t -> ℝ |
309 | norm_nuclear = sumElements . singularValues | 330 | norm_nuclear = sumElements . singularValues |
310 | 331 | ||
332 | magnit :: (Element t, Normed (Vector t)) => R -> t -> Bool | ||
333 | magnit e x = norm_1 (fromList [x]) > e | ||
334 | |||
311 | 335 | ||
312 | -- | Obtains a vector in the same direction with 2-norm=1 | 336 | -- | Obtains a vector in the same direction with 2-norm=1 |
313 | unitary :: Vector Double -> Vector Double | 337 | unitary :: Vector Double -> Vector Double |
@@ -618,9 +642,10 @@ gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]] | |||
618 | gaussST (r,_) x = do | 642 | gaussST (r,_) x = do |
619 | let n = r-1 | 643 | let n = r-1 |
620 | forM_ [0..n] $ \i -> do | 644 | forM_ [0..n] $ \i -> do |
621 | c <- maxIndex . abs . flatten <$> extractRect x i n i i | 645 | c <- maxIndex . abs . flatten <$> extractMatrix x i n i i |
622 | swap x i (i+c) | 646 | swap x i (i+c) |
623 | a <- readMatrix x i i | 647 | a <- readMatrix x i i |
648 | when (a == 0) $ error "singular!" | ||
624 | scal x (recip a) i | 649 | scal x (recip a) i |
625 | forM_ [i+1..n] $ \j -> do | 650 | forM_ [i+1..n] $ \j -> do |
626 | b <- readMatrix x j i | 651 | b <- readMatrix x j i |
@@ -630,6 +655,25 @@ gaussST (r,_) x = do | |||
630 | b <- readMatrix x j i | 655 | b <- readMatrix x j i |
631 | axpy x (-b) i j | 656 | axpy x (-b) i j |
632 | 657 | ||
658 | |||
659 | luST ok (r,c) x = do | ||
660 | let n = r-1 | ||
661 | axpy' m a i j = rowOpST 0 a i j (i+1) (c-1) m | ||
662 | p <- thawMatrix . asColumn . range $ r | ||
663 | forM_ [0..n] $ \i -> do | ||
664 | k <- maxIndex . abs . flatten <$> extractMatrix x i n i i | ||
665 | writeMatrix p i 0 (fi (k+i)) | ||
666 | swap x i (i+k) | ||
667 | a <- readMatrix x i i | ||
668 | when (ok a) $ do | ||
669 | forM_ [i+1..n] $ \j -> do | ||
670 | b <- (/a) <$> readMatrix x j i | ||
671 | axpy' x (-b) i j | ||
672 | writeMatrix x j i b | ||
673 | v <- unsafeFreezeMatrix p | ||
674 | return (map ti $ toList $ flatten v) | ||
675 | |||
676 | |||
633 | -------------------------------------------------------------------------------- | 677 | -------------------------------------------------------------------------------- |
634 | 678 | ||
635 | instance Testable (Matrix I) where | 679 | instance Testable (Matrix I) where |
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs index c97f415..0f8efa4 100644 --- a/packages/base/src/Numeric/LinearAlgebra.hs +++ b/packages/base/src/Numeric/LinearAlgebra.hs | |||
@@ -1,3 +1,5 @@ | |||
1 | {-# LANGUAGE FlexibleContexts #-} | ||
2 | |||
1 | ----------------------------------------------------------------------------- | 3 | ----------------------------------------------------------------------------- |
2 | {- | | 4 | {- | |
3 | Module : Numeric.LinearAlgebra | 5 | Module : Numeric.LinearAlgebra |
@@ -119,7 +121,7 @@ module Numeric.LinearAlgebra ( | |||
119 | schur, | 121 | schur, |
120 | 122 | ||
121 | -- * LU | 123 | -- * LU |
122 | lu, luPacked, | 124 | lu, luPacked, luFact, luPacked', |
123 | 125 | ||
124 | -- * Matrix functions | 126 | -- * Matrix functions |
125 | expm, | 127 | expm, |
@@ -134,7 +136,7 @@ module Numeric.LinearAlgebra ( | |||
134 | Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, | 136 | Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, |
135 | 137 | ||
136 | -- * Misc | 138 | -- * Misc |
137 | meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, gaussElim_1, gaussElim_2, | 139 | meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, luST, magnit, |
138 | ℝ,ℂ,iC, | 140 | ℝ,ℂ,iC, |
139 | -- * Auxiliary classes | 141 | -- * Auxiliary classes |
140 | Element, Container, Product, Numeric, LSDiv, | 142 | Element, Container, Product, Numeric, LSDiv, |
@@ -142,7 +144,6 @@ module Numeric.LinearAlgebra ( | |||
142 | RealOf, ComplexOf, SingleOf, DoubleOf, | 144 | RealOf, ComplexOf, SingleOf, DoubleOf, |
143 | IndexOf, | 145 | IndexOf, |
144 | Field, | 146 | Field, |
145 | -- Normed, | ||
146 | Transposable, | 147 | Transposable, |
147 | CGState(..), | 148 | CGState(..), |
148 | Testable(..) | 149 | Testable(..) |
@@ -155,13 +156,14 @@ import Numeric.Vector() | |||
155 | import Internal.Matrix | 156 | import Internal.Matrix |
156 | import Internal.Container hiding ((<>)) | 157 | import Internal.Container hiding ((<>)) |
157 | import Internal.Numeric hiding (mul) | 158 | import Internal.Numeric hiding (mul) |
158 | import Internal.Algorithms hiding (linearSolve,Normed,orth) | 159 | import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked') |
159 | import qualified Internal.Algorithms as A | 160 | import qualified Internal.Algorithms as A |
160 | import Internal.Util | 161 | import Internal.Util |
161 | import Internal.Random | 162 | import Internal.Random |
162 | import Internal.Sparse((!#>)) | 163 | import Internal.Sparse((!#>)) |
163 | import Internal.CG | 164 | import Internal.CG |
164 | import Internal.Conversion | 165 | import Internal.Conversion |
166 | import Internal.ST(mutable) | ||
165 | 167 | ||
166 | {- | infix synonym of 'mul' | 168 | {- | infix synonym of 'mul' |
167 | 169 | ||
@@ -236,3 +238,5 @@ nullspace m = nullspaceSVD (Left (1*eps)) m (rightSV m) | |||
236 | -- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. | 238 | -- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. |
237 | orth m = orthSVD (Left (1*eps)) m (leftSV m) | 239 | orth m = orthSVD (Left (1*eps)) m (leftSV m) |
238 | 240 | ||
241 | luPacked' x = mutable (luST (magnit 0)) x | ||
242 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index 84763fe..f572656 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs | |||
@@ -44,7 +44,7 @@ module Numeric.LinearAlgebra.Devel( | |||
44 | -- ** Mutable Matrices | 44 | -- ** Mutable Matrices |
45 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, | 45 | STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, |
46 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, | 46 | readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, |
47 | axpy,scal,swap, extractRect, mutable, | 47 | axpy,scal,swap, extractMatrix, setMatrix, mutable, rowOpST, |
48 | -- ** Unsafe functions | 48 | -- ** Unsafe functions |
49 | newUndefinedVector, | 49 | newUndefinedVector, |
50 | unsafeReadVector, unsafeWriteVector, | 50 | unsafeReadVector, unsafeWriteVector, |