summaryrefslogtreecommitdiff
path: root/packages
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2015-06-13 19:18:16 +0200
committerAlberto Ruiz <aruiz@um.es>2015-06-13 19:18:16 +0200
commit717c680a4b65a2226b0dd6fc13f7c63e7bc0431d (patch)
tree1775c3c363a0b61f5f6a6ec1f22fe9b7d5864dc4 /packages
parent4b3e29097aa272d429f8005fe17b459cf0c049c8 (diff)
setRect, general luPacked' based on luST
Diffstat (limited to 'packages')
-rw-r--r--packages/base/src/Internal/C/lapack-aux.c54
-rw-r--r--packages/base/src/Internal/Matrix.hs20
-rw-r--r--packages/base/src/Internal/Modular.hs75
-rw-r--r--packages/base/src/Internal/ST.hs7
-rw-r--r--packages/base/src/Internal/Util.hs48
-rw-r--r--packages/base/src/Numeric/LinearAlgebra.hs12
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Devel.hs2
7 files changed, 169 insertions, 49 deletions
diff --git a/packages/base/src/Internal/C/lapack-aux.c b/packages/base/src/Internal/C/lapack-aux.c
index e42889d..2843ab5 100644
--- a/packages/base/src/Internal/C/lapack-aux.c
+++ b/packages/base/src/Internal/C/lapack-aux.c
@@ -1448,7 +1448,7 @@ int transL(KLMAT(x),LMAT(t)) TRANS_IMP
1448 1448
1449//////////////////////// extract ///////////////////////////////// 1449//////////////////////// extract /////////////////////////////////
1450 1450
1451#define EXTRACT_IMP \ 1451#define EXTRACT_IMP { \
1452 int i,j,si,sj,ni,nj; \ 1452 int i,j,si,sj,ni,nj; \
1453 ni = modei ? in : ip[1]-ip[0]+1; \ 1453 ni = modei ? in : ip[1]-ip[0]+1; \
1454 nj = modej ? jn : jp[1]-jp[0]+1; \ 1454 nj = modej ? jn : jp[1]-jp[0]+1; \
@@ -1461,33 +1461,35 @@ int transL(KLMAT(x),LMAT(t)) TRANS_IMP
1461 \ 1461 \
1462 AT(r,i,j) = AT(m,si,sj); \ 1462 AT(r,i,j) = AT(m,si,sj); \
1463 } \ 1463 } \
1464 } \ 1464 } OK }
1465 OK
1466
1467int extractD(int modei, int modej, KIVEC(i), KIVEC(j), KODMAT(m), ODMAT(r)) {
1468 EXTRACT_IMP
1469}
1470
1471int extractF(int modei, int modej, KIVEC(i), KIVEC(j), KOFMAT(m), OFMAT(r)) {
1472 EXTRACT_IMP
1473}
1474
1475int extractC(int modei, int modej, KIVEC(i), KIVEC(j), KOCMAT(m), OCMAT(r)) {
1476 EXTRACT_IMP
1477}
1478
1479int extractQ(int modei, int modej, KIVEC(i), KIVEC(j), KOQMAT(m), OQMAT(r)) {
1480 EXTRACT_IMP
1481}
1482
1483int extractI(int modei, int modej, KIVEC(i), KIVEC(j), KOIMAT(m), OIMAT(r)) {
1484 EXTRACT_IMP
1485}
1486 1465
1487int extractL(int modei, int modej, KIVEC(i), KIVEC(j), KOLMAT(m), OLMAT(r)) { 1466#define EXTRACT(T) int extract##T(int modei, int modej, KIVEC(i), KIVEC(j), KO##T##MAT(m), O##T##MAT(r)) EXTRACT_IMP
1488 EXTRACT_IMP 1467
1489} 1468EXTRACT(D)
1469EXTRACT(F)
1470EXTRACT(C)
1471EXTRACT(Q)
1472EXTRACT(I)
1473EXTRACT(L)
1474
1475//////////////////////// setRect /////////////////////////////////
1476
1477#define SETRECT(T) \
1478int setRect##T(int i, int j, KO##T##MAT(m), O##T##MAT(r)) { \
1479 { TRAV(m,a,b) { \
1480 int x = a+i, y = b+j; \
1481 if(x>=0 && x<rr && y>=0 && y<rc) { \
1482 AT(r,x,y) = AT(m,a,b); \
1483 } \
1484 } \
1485 } OK }
1490 1486
1487SETRECT(D)
1488SETRECT(F)
1489SETRECT(C)
1490SETRECT(Q)
1491SETRECT(I)
1492SETRECT(L)
1491 1493
1492//////////////////////// remap ///////////////////////////////// 1494//////////////////////// remap /////////////////////////////////
1493 1495
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index fa1aad6..e0f5ed2 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -275,6 +275,7 @@ class (Storable a) => Element a where
275 transdata :: Int -> Vector a -> Int -> Vector a 275 transdata :: Int -> Vector a -> Int -> Vector a
276 constantD :: a -> Int -> Vector a 276 constantD :: a -> Int -> Vector a
277 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a) 277 extractR :: Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
278 setRect :: Int -> Int -> Matrix a -> Matrix a -> IO ()
278 sortI :: Ord a => Vector a -> Vector CInt 279 sortI :: Ord a => Vector a -> Vector CInt
279 sortV :: Ord a => Vector a -> Vector a 280 sortV :: Ord a => Vector a -> Vector a
280 compareV :: Ord a => Vector a -> Vector a -> Vector CInt 281 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
@@ -287,6 +288,7 @@ instance Element Float where
287 transdata = transdataAux ctransF 288 transdata = transdataAux ctransF
288 constantD = constantAux cconstantF 289 constantD = constantAux cconstantF
289 extractR = extractAux c_extractF 290 extractR = extractAux c_extractF
291 setRect = setRectAux c_setRectF
290 sortI = sortIdxF 292 sortI = sortIdxF
291 sortV = sortValF 293 sortV = sortValF
292 compareV = compareF 294 compareV = compareF
@@ -298,6 +300,7 @@ instance Element Double where
298 transdata = transdataAux ctransR 300 transdata = transdataAux ctransR
299 constantD = constantAux cconstantR 301 constantD = constantAux cconstantR
300 extractR = extractAux c_extractD 302 extractR = extractAux c_extractD
303 setRect = setRectAux c_setRectD
301 sortI = sortIdxD 304 sortI = sortIdxD
302 sortV = sortValD 305 sortV = sortValD
303 compareV = compareD 306 compareV = compareD
@@ -310,6 +313,7 @@ instance Element (Complex Float) where
310 transdata = transdataAux ctransQ 313 transdata = transdataAux ctransQ
311 constantD = constantAux cconstantQ 314 constantD = constantAux cconstantQ
312 extractR = extractAux c_extractQ 315 extractR = extractAux c_extractQ
316 setRect = setRectAux c_setRectQ
313 sortI = undefined 317 sortI = undefined
314 sortV = undefined 318 sortV = undefined
315 compareV = undefined 319 compareV = undefined
@@ -322,6 +326,7 @@ instance Element (Complex Double) where
322 transdata = transdataAux ctransC 326 transdata = transdataAux ctransC
323 constantD = constantAux cconstantC 327 constantD = constantAux cconstantC
324 extractR = extractAux c_extractC 328 extractR = extractAux c_extractC
329 setRect = setRectAux c_setRectC
325 sortI = undefined 330 sortI = undefined
326 sortV = undefined 331 sortV = undefined
327 compareV = undefined 332 compareV = undefined
@@ -333,6 +338,7 @@ instance Element (CInt) where
333 transdata = transdataAux ctransI 338 transdata = transdataAux ctransI
334 constantD = constantAux cconstantI 339 constantD = constantAux cconstantI
335 extractR = extractAux c_extractI 340 extractR = extractAux c_extractI
341 setRect = setRectAux c_setRectI
336 sortI = sortIdxI 342 sortI = sortIdxI
337 sortV = sortValI 343 sortV = sortValI
338 compareV = compareI 344 compareV = compareI
@@ -344,6 +350,7 @@ instance Element Z where
344 transdata = transdataAux ctransL 350 transdata = transdataAux ctransL
345 constantD = constantAux cconstantL 351 constantD = constantAux cconstantL
346 extractR = extractAux c_extractL 352 extractR = extractAux c_extractL
353 setRect = setRectAux c_setRectL
347 sortI = sortIdxL 354 sortI = sortIdxL
348 sortV = sortValL 355 sortV = sortValL
349 compareV = compareL 356 compareV = compareL
@@ -454,6 +461,19 @@ foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
454foreign import ccall unsafe "extractI" c_extractI :: Extr CInt 461foreign import ccall unsafe "extractI" c_extractI :: Extr CInt
455foreign import ccall unsafe "extractL" c_extractL :: Extr Z 462foreign import ccall unsafe "extractL" c_extractL :: Extr Z
456 463
464---------------------------------------------------------------
465
466setRectAux f i j m r = app2 (f (fi i) (fi j)) omat m omat r "setRect"
467
468type SetRect x = I -> I -> x ::> x::> Ok
469
470foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double
471foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float
472foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double)
473foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float)
474foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I
475foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
476
457-------------------------------------------------------------------------------- 477--------------------------------------------------------------------------------
458 478
459sortG f v = unsafePerformIO $ do 479sortG f v = unsafePerformIO $ do
diff --git a/packages/base/src/Internal/Modular.hs b/packages/base/src/Internal/Modular.hs
index 824fc57..3b27310 100644
--- a/packages/base/src/Internal/Modular.hs
+++ b/packages/base/src/Internal/Modular.hs
@@ -33,12 +33,15 @@ import Internal.Element
33import Internal.Container 33import Internal.Container
34import Internal.Vectorized (prodI,sumI,prodL,sumL) 34import Internal.Vectorized (prodI,sumI,prodL,sumL)
35import Internal.LAPACK (multiplyI, multiplyL) 35import Internal.LAPACK (multiplyI, multiplyL)
36import Internal.Util(Indexable(..),gaussElim) 36import Internal.Algorithms(luFact)
37import Internal.Util(Normed(..),Indexable(..),gaussElim, gaussElim_1, gaussElim_2,luST, magnit)
38import Internal.ST(mutable)
37import GHC.TypeLits 39import GHC.TypeLits
38import Data.Proxy(Proxy) 40import Data.Proxy(Proxy)
39import Foreign.ForeignPtr(castForeignPtr) 41import Foreign.ForeignPtr(castForeignPtr)
40import Foreign.Storable 42import Foreign.Storable
41import Data.Ratio 43import Data.Ratio
44import Data.Complex
42 45
43 46
44 47
@@ -116,6 +119,7 @@ instance KnownNat m => Element (Mod m I)
116 transdata n v m = i2f (transdata n (f2i v) m) 119 transdata n v m = i2f (transdata n (f2i v) m)
117 constantD x n = i2f (constantD (unMod x) n) 120 constantD x n = i2f (constantD (unMod x) n)
118 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js 121 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js
122 setRect i j m x = setRect i j (f2iM m) (f2iM x)
119 sortI = sortI . f2i 123 sortI = sortI . f2i
120 sortV = i2f . sortV . f2i 124 sortV = i2f . sortV . f2i
121 compareV u v = compareV (f2i u) (f2i v) 125 compareV u v = compareV (f2i u) (f2i v)
@@ -130,6 +134,7 @@ instance KnownNat m => Element (Mod m Z)
130 transdata n v m = i2f (transdata n (f2i v) m) 134 transdata n v m = i2f (transdata n (f2i v) m)
131 constantD x n = i2f (constantD (unMod x) n) 135 constantD x n = i2f (constantD (unMod x) n)
132 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js 136 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js
137 setRect i j m x = setRect i j (f2iM m) (f2iM x)
133 sortI = sortI . f2i 138 sortI = sortI . f2i
134 sortV = i2f . sortV . f2i 139 sortV = i2f . sortV . f2i
135 compareV u v = compareV (f2i u) (f2i v) 140 compareV u v = compareV (f2i u) (f2i v)
@@ -139,18 +144,6 @@ instance KnownNat m => Element (Mod m Z)
139 where 144 where
140 m' = fromIntegral . natVal $ (undefined :: Proxy m) 145 m' = fromIntegral . natVal $ (undefined :: Proxy m)
141 146
142{-
143instance (Ord t, Element t) => Element (Mod m t)
144 where
145 transdata n v m = i2f (transdata n (f2i v) m)
146 constantD x n = i2f (constantD (unMod x) n)
147 extractR m mi is mj js = i2fM <$> extractR (f2iM m) mi is mj js
148 sortI = sortI . f2i
149 sortV = i2f . sortV . f2i
150 compareV u v = compareV (f2i u) (f2i v)
151 selectV c l e g = i2f (selectV c (f2i l) (f2i e) (f2i g))
152 remapM i j m = i2fM (remap i j (f2iM m))
153-}
154 147
155instance forall m . KnownNat m => Container Vector (Mod m I) 148instance forall m . KnownNat m => Container Vector (Mod m I)
156 where 149 where
@@ -258,6 +251,20 @@ instance KnownNat m => Product (Mod m Z) where
258 where 251 where
259 m' = fromIntegral . natVal $ (undefined :: Proxy m) 252 m' = fromIntegral . natVal $ (undefined :: Proxy m)
260 253
254instance KnownNat m => Normed (Vector (Mod m I))
255 where
256 norm_0 = norm_0 . toInt
257 norm_1 = norm_1 . toInt
258 norm_2 = norm_2 . toInt
259 norm_Inf = norm_Inf . toInt
260
261instance KnownNat m => Normed (Vector (Mod m Z))
262 where
263 norm_0 = norm_0 . toZ
264 norm_1 = norm_1 . toZ
265 norm_2 = norm_2 . toZ
266 norm_Inf = norm_Inf . toZ
267
261 268
262instance KnownNat m => Numeric (Mod m I) 269instance KnownNat m => Numeric (Mod m I)
263instance KnownNat m => Numeric (Mod m Z) 270instance KnownNat m => Numeric (Mod m Z)
@@ -334,6 +341,15 @@ test = (ok, info)
334 lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z 341 lg = (3><3) (repeat (3*10^(9::Int))) :: Matrix Z
335 lgm = fromZ lg :: Matrix (Mod 10000000000 Z) 342 lgm = fromZ lg :: Matrix (Mod 10000000000 Z)
336 343
344 gen n = diagRect 1 (konst 5 n) n n :: Numeric t => Matrix t
345
346 checkGen x = norm_Inf $ flatten $ invg x <> x - ident (rows x)
347
348 invg t = gaussElim t (ident (rows t))
349
350 checkLU okf t = norm_Inf $ flatten (l <> u <> p - t)
351 where
352 (l,u,p,_ :: Int) = luFact $ mutable (luST okf) t
337 353
338 info = do 354 info = do
339 print v 355 print v
@@ -356,11 +372,42 @@ test = (ok, info)
356 print $ lg <> lg 372 print $ lg <> lg
357 print lgm 373 print lgm
358 print $ lgm <> lgm 374 print $ lgm <> lgm
375
376 print (checkGen (gen 5 :: Matrix R))
377 print (checkGen (gen 5 :: Matrix C))
378 print (checkGen (gen 5 :: Matrix Float))
379 print (checkGen (gen 5 :: Matrix (Complex Float)))
380 print (invg (gen 5) :: Matrix (Mod 7 I))
381 print (invg (gen 5) :: Matrix (Mod 7 Z))
382
383 print $ mutable (luST (const True)) (gen 5 :: Matrix R)
384 print $ mutable (luST (const True)) (gen 5 :: Matrix (Mod 11 Z))
385
386 print $ checkLU (magnit 0) (gen 5 :: Matrix R)
387 print $ checkLU (magnit 0) (gen 5 :: Matrix Float)
388 print $ checkLU (magnit 0) (gen 5 :: Matrix C)
389 print $ checkLU (magnit 0) (gen 5 :: Matrix (Complex Float))
390 print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))
391 print $ checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))
359 392
360 393
361 ok = and 394 ok = and
362 [ toInt (m #> v) == cmod 11 (toInt m #> toInt v ) 395 [ toInt (m #> v) == cmod 11 (toInt m #> toInt v )
363 , am <> gaussElim am bm == bm 396 , am <> gaussElim_1 am bm == bm
397 , am <> gaussElim_2 am bm == bm
398 , am <> gaussElim am bm == bm
399 , (checkGen (gen 5 :: Matrix R)) < 1E-15
400 , (checkGen (gen 5 :: Matrix Float)) < 1E-7
401 , (checkGen (gen 5 :: Matrix C)) < 1E-15
402 , (checkGen (gen 5 :: Matrix (Complex Float))) < 1E-7
403 , (checkGen (gen 5 :: Matrix (Mod 7 I))) == 0
404 , (checkGen (gen 5 :: Matrix (Mod 7 Z))) == 0
405 , (checkLU (magnit 1E-10) (gen 5 :: Matrix R)) < 1E-15
406 , (checkLU (magnit 1E-5) (gen 5 :: Matrix Float)) < 1E-6
407 , (checkLU (magnit 1E-10) (gen 5 :: Matrix C)) < 1E-15
408 , (checkLU (magnit 1E-5) (gen 5 :: Matrix (Complex Float))) < 1E-6
409 , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 I))) == 0
410 , (checkLU (magnit 0) (gen 5 :: Matrix (Mod 7 Z))) == 0
364 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I)) 411 , prodElements (konst (9:: Mod 10 I) (12::Int)) == product (replicate 12 (9:: Mod 10 I))
365 , gm <> gm == konst 0 (3,3) 412 , gm <> gm == konst 0 (3,3)
366 , lgm <> lgm == konst 0 (3,3) 413 , lgm <> lgm == konst 0 (3,3)
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 107d3c3..a84ca25 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -21,7 +21,7 @@ module Internal.ST (
21 -- * Mutable Matrices 21 -- * Mutable Matrices
22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 22 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 23 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
24 axpy, scal, swap, extractRect, 24 axpy, scal, swap, extractMatrix, setMatrix, rowOpST,
25 mutable, 25 mutable,
26 -- * Unsafe functions 26 -- * Unsafe functions
27 newUndefinedVector, 27 newUndefinedVector,
@@ -166,6 +166,9 @@ readMatrix = safeIndexM unsafeReadMatrix
166writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s () 166writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
167writeMatrix = safeIndexM unsafeWriteMatrix 167writeMatrix = safeIndexM unsafeWriteMatrix
168 168
169setMatrix :: Element t => STMatrix s t -> Int -> Int -> Matrix t -> ST s ()
170setMatrix (STMatrix x) i j m = unsafeIOToST $ setRect i j m x
171
169newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t) 172newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
170newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c 173newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c
171 174
@@ -182,7 +185,7 @@ axpy (STMatrix m) a i j = rowOpST 0 a i j 0 (cols m -1) (STMatrix m)
182scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m) 185scal (STMatrix m) a i = rowOpST 1 a i i 0 (cols m -1) (STMatrix m)
183swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m) 186swap (STMatrix m) i j = rowOpST 2 0 i j 0 (cols m -1) (STMatrix m)
184 187
185extractRect (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2])) 188extractMatrix (STMatrix m) i1 i2 j1 j2 = unsafeIOToST (extractR m 0 (idxs[i1,i2]) 0 (idxs[j1,j2]))
186 189
187-------------------------------------------------------------------------------- 190--------------------------------------------------------------------------------
188 191
diff --git a/packages/base/src/Internal/Util.hs b/packages/base/src/Internal/Util.hs
index 7a556e9..2650ac8 100644
--- a/packages/base/src/Internal/Util.hs
+++ b/packages/base/src/Internal/Util.hs
@@ -41,6 +41,7 @@ module Internal.Util(
41 norm, 41 norm,
42 ℕ,ℤ,ℝ,ℂ,iC, 42 ℕ,ℤ,ℝ,ℂ,iC,
43 Normed(..), norm_Frob, norm_nuclear, 43 Normed(..), norm_Frob, norm_nuclear,
44 magnit,
44 unitary, 45 unitary,
45 mt, 46 mt,
46 (~!~), 47 (~!~),
@@ -54,7 +55,7 @@ module Internal.Util(
54 -- ** 2D 55 -- ** 2D
55 corr2, conv2, separable, 56 corr2, conv2, separable,
56 block2x2,block3x3,view1,unView1,foldMatrix, 57 block2x2,block3x3,view1,unView1,foldMatrix,
57 gaussElim_1, gaussElim_2, gaussElim 58 gaussElim_1, gaussElim_2, gaussElim, luST
58) where 59) where
59 60
60import Internal.Vector 61import Internal.Vector
@@ -300,6 +301,26 @@ instance Normed (Vector I)
300 norm_2 v = sqrt . fromIntegral $ dot v v 301 norm_2 v = sqrt . fromIntegral $ dot v v
301 norm_Inf = fromIntegral . normInf 302 norm_Inf = fromIntegral . normInf
302 303
304instance Normed (Vector Z)
305 where
306 norm_0 = fromIntegral . sumElements . step . abs
307 norm_1 = fromIntegral . norm1
308 norm_2 v = sqrt . fromIntegral $ dot v v
309 norm_Inf = fromIntegral . normInf
310
311instance Normed (Vector Float)
312 where
313 norm_0 = norm_0 . double
314 norm_1 = norm_1 . double
315 norm_2 = norm_2 . double
316 norm_Inf = norm_Inf . double
317
318instance Normed (Vector (Complex Float))
319 where
320 norm_0 = norm_0 . double
321 norm_1 = norm_1 . double
322 norm_2 = norm_2 . double
323 norm_Inf = norm_Inf . double
303 324
304 325
305norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> ℝ 326norm_Frob :: (Normed (Vector t), Element t) => Matrix t -> ℝ
@@ -308,6 +329,9 @@ norm_Frob = norm_2 . flatten
308norm_nuclear :: Field t => Matrix t -> ℝ 329norm_nuclear :: Field t => Matrix t -> ℝ
309norm_nuclear = sumElements . singularValues 330norm_nuclear = sumElements . singularValues
310 331
332magnit :: (Element t, Normed (Vector t)) => R -> t -> Bool
333magnit e x = norm_1 (fromList [x]) > e
334
311 335
312-- | Obtains a vector in the same direction with 2-norm=1 336-- | Obtains a vector in the same direction with 2-norm=1
313unitary :: Vector Double -> Vector Double 337unitary :: Vector Double -> Vector Double
@@ -618,9 +642,10 @@ gaussElim a b = dropColumns (rows a) $ fst $ mutable gaussST (fromBlocks [[a,b]]
618gaussST (r,_) x = do 642gaussST (r,_) x = do
619 let n = r-1 643 let n = r-1
620 forM_ [0..n] $ \i -> do 644 forM_ [0..n] $ \i -> do
621 c <- maxIndex . abs . flatten <$> extractRect x i n i i 645 c <- maxIndex . abs . flatten <$> extractMatrix x i n i i
622 swap x i (i+c) 646 swap x i (i+c)
623 a <- readMatrix x i i 647 a <- readMatrix x i i
648 when (a == 0) $ error "singular!"
624 scal x (recip a) i 649 scal x (recip a) i
625 forM_ [i+1..n] $ \j -> do 650 forM_ [i+1..n] $ \j -> do
626 b <- readMatrix x j i 651 b <- readMatrix x j i
@@ -630,6 +655,25 @@ gaussST (r,_) x = do
630 b <- readMatrix x j i 655 b <- readMatrix x j i
631 axpy x (-b) i j 656 axpy x (-b) i j
632 657
658
659luST ok (r,c) x = do
660 let n = r-1
661 axpy' m a i j = rowOpST 0 a i j (i+1) (c-1) m
662 p <- thawMatrix . asColumn . range $ r
663 forM_ [0..n] $ \i -> do
664 k <- maxIndex . abs . flatten <$> extractMatrix x i n i i
665 writeMatrix p i 0 (fi (k+i))
666 swap x i (i+k)
667 a <- readMatrix x i i
668 when (ok a) $ do
669 forM_ [i+1..n] $ \j -> do
670 b <- (/a) <$> readMatrix x j i
671 axpy' x (-b) i j
672 writeMatrix x j i b
673 v <- unsafeFreezeMatrix p
674 return (map ti $ toList $ flatten v)
675
676
633-------------------------------------------------------------------------------- 677--------------------------------------------------------------------------------
634 678
635instance Testable (Matrix I) where 679instance Testable (Matrix I) where
diff --git a/packages/base/src/Numeric/LinearAlgebra.hs b/packages/base/src/Numeric/LinearAlgebra.hs
index c97f415..0f8efa4 100644
--- a/packages/base/src/Numeric/LinearAlgebra.hs
+++ b/packages/base/src/Numeric/LinearAlgebra.hs
@@ -1,3 +1,5 @@
1{-# LANGUAGE FlexibleContexts #-}
2
1----------------------------------------------------------------------------- 3-----------------------------------------------------------------------------
2{- | 4{- |
3Module : Numeric.LinearAlgebra 5Module : Numeric.LinearAlgebra
@@ -119,7 +121,7 @@ module Numeric.LinearAlgebra (
119 schur, 121 schur,
120 122
121 -- * LU 123 -- * LU
122 lu, luPacked, 124 lu, luPacked, luFact, luPacked',
123 125
124 -- * Matrix functions 126 -- * Matrix functions
125 expm, 127 expm,
@@ -134,7 +136,7 @@ module Numeric.LinearAlgebra (
134 Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, 136 Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample,
135 137
136 -- * Misc 138 -- * Misc
137 meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, gaussElim_1, gaussElim_2, 139 meanCov, rowOuters, pairwiseD2, unitary, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, gaussElim, luST, magnit,
138 ℝ,ℂ,iC, 140 ℝ,ℂ,iC,
139 -- * Auxiliary classes 141 -- * Auxiliary classes
140 Element, Container, Product, Numeric, LSDiv, 142 Element, Container, Product, Numeric, LSDiv,
@@ -142,7 +144,6 @@ module Numeric.LinearAlgebra (
142 RealOf, ComplexOf, SingleOf, DoubleOf, 144 RealOf, ComplexOf, SingleOf, DoubleOf,
143 IndexOf, 145 IndexOf,
144 Field, 146 Field,
145-- Normed,
146 Transposable, 147 Transposable,
147 CGState(..), 148 CGState(..),
148 Testable(..) 149 Testable(..)
@@ -155,13 +156,14 @@ import Numeric.Vector()
155import Internal.Matrix 156import Internal.Matrix
156import Internal.Container hiding ((<>)) 157import Internal.Container hiding ((<>))
157import Internal.Numeric hiding (mul) 158import Internal.Numeric hiding (mul)
158import Internal.Algorithms hiding (linearSolve,Normed,orth) 159import Internal.Algorithms hiding (linearSolve,Normed,orth,luPacked')
159import qualified Internal.Algorithms as A 160import qualified Internal.Algorithms as A
160import Internal.Util 161import Internal.Util
161import Internal.Random 162import Internal.Random
162import Internal.Sparse((!#>)) 163import Internal.Sparse((!#>))
163import Internal.CG 164import Internal.CG
164import Internal.Conversion 165import Internal.Conversion
166import Internal.ST(mutable)
165 167
166{- | infix synonym of 'mul' 168{- | infix synonym of 'mul'
167 169
@@ -236,3 +238,5 @@ nullspace m = nullspaceSVD (Left (1*eps)) m (rightSV m)
236-- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. 238-- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'.
237orth m = orthSVD (Left (1*eps)) m (leftSV m) 239orth m = orthSVD (Left (1*eps)) m (leftSV m)
238 240
241luPacked' x = mutable (luST (magnit 0)) x
242
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
index 84763fe..f572656 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
@@ -44,7 +44,7 @@ module Numeric.LinearAlgebra.Devel(
44 -- ** Mutable Matrices 44 -- ** Mutable Matrices
45 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix, 45 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
46 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix, 46 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
47 axpy,scal,swap, extractRect, mutable, 47 axpy,scal,swap, extractMatrix, setMatrix, mutable, rowOpST,
48 -- ** Unsafe functions 48 -- ** Unsafe functions
49 newUndefinedVector, 49 newUndefinedVector,
50 unsafeReadVector, unsafeWriteVector, 50 unsafeReadVector, unsafeWriteVector,