diff options
author | Alberto Ruiz <aruiz@um.es> | 2010-09-11 09:10:24 +0000 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2010-09-11 09:10:24 +0000 |
commit | ec9965371be5b37234684ba392f55a1a1e24f053 (patch) | |
tree | fae00b984fc8499e19952200f1b3f9f7ee5f2d20 | |
parent | ae6d18808cef554979b99cc55f46d5324518df01 (diff) |
optimized conjugate
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 25 | ||||
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 2 | ||||
-rw-r--r-- | lib/Numeric/Conversion.hs | 9 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | 25 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 10 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests/Properties.hs | 22 |
6 files changed, 71 insertions, 22 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs index 5892f1c..d39481d 100644 --- a/lib/Data/Packed/Internal/Matrix.hs +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -255,27 +255,27 @@ class (Storable a, Floating a) => Element a where | |||
255 | transdata = transdata' | 255 | transdata = transdata' |
256 | constantD :: a -> Int -> Vector a | 256 | constantD :: a -> Int -> Vector a |
257 | constantD = constant' | 257 | constantD = constant' |
258 | ctrans' :: Matrix a -> Matrix a | 258 | conjugateD :: Vector a -> Vector a |
259 | 259 | ||
260 | instance Element Float where | 260 | instance Element Float where |
261 | transdata = transdataAux ctransF | 261 | transdata = transdataAux ctransF |
262 | constantD = constantAux cconstantF | 262 | constantD = constantAux cconstantF |
263 | ctrans' = trans | 263 | conjugateD = id |
264 | 264 | ||
265 | instance Element Double where | 265 | instance Element Double where |
266 | transdata = transdataAux ctransR | 266 | transdata = transdataAux ctransR |
267 | constantD = constantAux cconstantR | 267 | constantD = constantAux cconstantR |
268 | ctrans' = trans | 268 | conjugateD = id |
269 | 269 | ||
270 | instance Element (Complex Float) where | 270 | instance Element (Complex Float) where |
271 | transdata = transdataAux ctransQ | 271 | transdata = transdataAux ctransQ |
272 | constantD = constantAux cconstantQ | 272 | constantD = constantAux cconstantQ |
273 | ctrans' = liftMatrix (mapVector conjugate) . trans | 273 | conjugateD = conjugateQ |
274 | 274 | ||
275 | instance Element (Complex Double) where | 275 | instance Element (Complex Double) where |
276 | transdata = transdataAux ctransC | 276 | transdata = transdataAux ctransC |
277 | constantD = constantAux cconstantC | 277 | constantD = constantAux cconstantC |
278 | ctrans' = liftMatrix (mapVector conjugate) . trans | 278 | conjugateD = conjugateC |
279 | 279 | ||
280 | ------------------------------------------------------------------- | 280 | ------------------------------------------------------------------- |
281 | 281 | ||
@@ -359,6 +359,21 @@ constantC :: Complex Double -> Int -> Vector (Complex Double) | |||
359 | constantC = constantAux cconstantC | 359 | constantC = constantAux cconstantC |
360 | foreign import ccall "constantC" cconstantC :: Ptr (Complex Double) -> TCV | 360 | foreign import ccall "constantC" cconstantC :: Ptr (Complex Double) -> TCV |
361 | 361 | ||
362 | --------------------------------------- | ||
363 | |||
364 | conjugateAux fun x = unsafePerformIO $ do | ||
365 | v <- createVector (dim x) | ||
366 | app2 fun vec x vec v "conjugateAux" | ||
367 | return v | ||
368 | |||
369 | conjugateQ :: Vector (Complex Float) -> Vector (Complex Float) | ||
370 | conjugateQ = conjugateAux c_conjugateQ | ||
371 | foreign import ccall "conjugateQ" c_conjugateQ :: TQVQV | ||
372 | |||
373 | conjugateC :: Vector (Complex Double) -> Vector (Complex Double) | ||
374 | conjugateC = conjugateAux c_conjugateC | ||
375 | foreign import ccall "conjugateC" c_conjugateC :: TCVCV | ||
376 | |||
362 | ---------------------------------------------------------------------- | 377 | ---------------------------------------------------------------------- |
363 | 378 | ||
364 | -- | Extracts a submatrix from a matrix. | 379 | -- | Extracts a submatrix from a matrix. |
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index af937f4..0fc7876 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -270,7 +270,7 @@ asColumn v = reshape 1 v | |||
270 | 270 | ||
271 | -- | conjugate transpose | 271 | -- | conjugate transpose |
272 | ctrans :: Element e => Matrix e -> Matrix e | 272 | ctrans :: Element e => Matrix e -> Matrix e |
273 | ctrans = ctrans' | 273 | ctrans = liftMatrix conjugateD . trans |
274 | 274 | ||
275 | 275 | ||
276 | {- | creates a Matrix of the specified size using the supplied function to | 276 | {- | creates a Matrix of the specified size using the supplied function to |
diff --git a/lib/Numeric/Conversion.hs b/lib/Numeric/Conversion.hs index b05069c..809ac51 100644 --- a/lib/Numeric/Conversion.hs +++ b/lib/Numeric/Conversion.hs | |||
@@ -69,7 +69,6 @@ class ComplexContainer c where | |||
69 | fromComplex :: (RealElement e) => c (Complex e) -> (c e, c e) | 69 | fromComplex :: (RealElement e) => c (Complex e) -> (c e, c e) |
70 | comp :: (RealElement e) => c e -> c (Complex e) | 70 | comp :: (RealElement e) => c e -> c (Complex e) |
71 | conj :: (RealElement e) => c (Complex e) -> c (Complex e) | 71 | conj :: (RealElement e) => c (Complex e) -> c (Complex e) |
72 | -- cmap :: (Element a, Element b) => (a -> b) -> c a -> c b | ||
73 | single' :: Precision a b => c b -> c a | 72 | single' :: Precision a b => c b -> c a |
74 | double' :: Precision a b => c a -> c b | 73 | double' :: Precision a b => c a -> c b |
75 | 74 | ||
@@ -78,16 +77,11 @@ instance ComplexContainer Vector where | |||
78 | toComplex = toComplexV | 77 | toComplex = toComplexV |
79 | fromComplex = fromComplexV | 78 | fromComplex = fromComplexV |
80 | comp v = toComplex (v,constantD 0 (dim v)) | 79 | comp v = toComplex (v,constantD 0 (dim v)) |
81 | conj = conjV | 80 | conj = conjugateD |
82 | -- cmap = mapVector | ||
83 | single' = double2FloatG | 81 | single' = double2FloatG |
84 | double' = float2DoubleG | 82 | double' = float2DoubleG |
85 | 83 | ||
86 | 84 | ||
87 | -- | obtains the complex conjugate of a complex vector | ||
88 | conjV :: (RealElement a) => Vector (Complex a) -> Vector (Complex a) | ||
89 | conjV = mapVector conjugate | ||
90 | |||
91 | -- | creates a complex vector from vectors with real and imaginary parts | 85 | -- | creates a complex vector from vectors with real and imaginary parts |
92 | toComplexV :: (RealElement a) => (Vector a, Vector a) -> Vector (Complex a) | 86 | toComplexV :: (RealElement a) => (Vector a, Vector a) -> Vector (Complex a) |
93 | toComplexV (r,i) = asComplex $ flatten $ fromColumns [r,i] | 87 | toComplexV (r,i) = asComplex $ flatten $ fromColumns [r,i] |
@@ -104,7 +98,6 @@ instance ComplexContainer Matrix where | |||
104 | where c = cols z | 98 | where c = cols z |
105 | comp = liftMatrix comp | 99 | comp = liftMatrix comp |
106 | conj = liftMatrix conj | 100 | conj = liftMatrix conj |
107 | -- cmap f = liftMatrix (cmap f) | ||
108 | single' = liftMatrix single' | 101 | single' = liftMatrix single' |
109 | double' = liftMatrix double' | 102 | double' = liftMatrix double' |
110 | 103 | ||
diff --git a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c index 9e44431..2c4c647 100644 --- a/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c +++ b/lib/Numeric/LinearAlgebra/LAPACK/lapack-aux.c | |||
@@ -1201,3 +1201,28 @@ int double2float(DVEC(x),FVEC(y)) { | |||
1201 | } | 1201 | } |
1202 | OK | 1202 | OK |
1203 | } | 1203 | } |
1204 | |||
1205 | //////////////////// conjugate ///////////////////////// | ||
1206 | |||
1207 | int conjugateQ(KQVEC(x),QVEC(t)) { | ||
1208 | REQUIRES(xn==tn,BAD_SIZE); | ||
1209 | DEBUGMSG("conjugateQ"); | ||
1210 | int k; | ||
1211 | for(k=0;k<xn;k++) { | ||
1212 | ((complex*)tp)[k].r=((complex*)xp)[k].r; | ||
1213 | ((complex*)tp)[k].i=-((complex*)xp)[k].i; | ||
1214 | } | ||
1215 | OK | ||
1216 | } | ||
1217 | |||
1218 | int conjugateC(KCVEC(x),CVEC(t)) { | ||
1219 | REQUIRES(xn==tn,BAD_SIZE); | ||
1220 | DEBUGMSG("conjugateC"); | ||
1221 | int k; | ||
1222 | for(k=0;k<xn;k++) { | ||
1223 | ((doublecomplex*)tp)[k].r=((doublecomplex*)xp)[k].r; | ||
1224 | ((doublecomplex*)tp)[k].i=-((doublecomplex*)xp)[k].i; | ||
1225 | } | ||
1226 | OK | ||
1227 | } | ||
1228 | |||
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 0b4e3bf..426700b 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -271,6 +271,10 @@ normsMTest = TestList [ | |||
271 | 271 | ||
272 | --------------------------------------------------------------------- | 272 | --------------------------------------------------------------------- |
273 | 273 | ||
274 | conjuTest m = mapVector conjugate (flatten (trans m)) == flatten (ctrans m) | ||
275 | |||
276 | --------------------------------------------------------------------- | ||
277 | |||
274 | 278 | ||
275 | -- | All tests must pass with a maximum dimension of about 20 | 279 | -- | All tests must pass with a maximum dimension of about 20 |
276 | -- (some tests may fail with bigger sizes due to precision loss). | 280 | -- (some tests may fail with bigger sizes due to precision loss). |
@@ -292,6 +296,9 @@ runTests n = do | |||
292 | putStrLn "------ sub-trans" | 296 | putStrLn "------ sub-trans" |
293 | test (subProp . rM) | 297 | test (subProp . rM) |
294 | test (subProp . cM) | 298 | test (subProp . cM) |
299 | putStrLn "------ ctrans" | ||
300 | test (conjuTest . cM) | ||
301 | test (conjuTest . zM) | ||
295 | putStrLn "------ lu" | 302 | putStrLn "------ lu" |
296 | test (luProp . rM) | 303 | test (luProp . rM) |
297 | test (luProp . cM) | 304 | test (luProp . cM) |
@@ -362,6 +369,9 @@ runTests n = do | |||
362 | test (qrProp . cM) | 369 | test (qrProp . cM) |
363 | test (rqProp . rM) | 370 | test (rqProp . rM) |
364 | test (rqProp . cM) | 371 | test (rqProp . cM) |
372 | test (rqProp1 . cM) | ||
373 | test (rqProp2 . cM) | ||
374 | test (rqProp3 . cM) | ||
365 | putStrLn "------ hess" | 375 | putStrLn "------ hess" |
366 | test (hessProp . rSq) | 376 | test (hessProp . rSq) |
367 | test (hessProp . cSq) | 377 | test (hessProp . cSq) |
diff --git a/lib/Numeric/LinearAlgebra/Tests/Properties.hs b/lib/Numeric/LinearAlgebra/Tests/Properties.hs index b96f53e..e780c35 100644 --- a/lib/Numeric/LinearAlgebra/Tests/Properties.hs +++ b/lib/Numeric/LinearAlgebra/Tests/Properties.hs | |||
@@ -32,7 +32,7 @@ module Numeric.LinearAlgebra.Tests.Properties ( | |||
32 | svdProp1, svdProp1a, svdProp1b, svdProp2, svdProp3, svdProp4, | 32 | svdProp1, svdProp1a, svdProp1b, svdProp2, svdProp3, svdProp4, |
33 | svdProp5a, svdProp5b, svdProp6a, svdProp6b, svdProp7, | 33 | svdProp5a, svdProp5b, svdProp6a, svdProp6b, svdProp7, |
34 | eigProp, eigSHProp, eigProp2, eigSHProp2, | 34 | eigProp, eigSHProp, eigProp2, eigSHProp2, |
35 | qrProp, rqProp, | 35 | qrProp, rqProp, rqProp1, rqProp2, rqProp3, |
36 | hessProp, | 36 | hessProp, |
37 | schurProp1, schurProp2, | 37 | schurProp1, schurProp2, |
38 | cholProp, | 38 | cholProp, |
@@ -210,15 +210,21 @@ eigSHProp2 m = fst (eigSH m) |~| eigenvaluesSH m | |||
210 | qrProp m = q <> r |~| m && unitary q && upperTriang r | 210 | qrProp m = q <> r |~| m && unitary q && upperTriang r |
211 | where (q,r) = qr m | 211 | where (q,r) = qr m |
212 | 212 | ||
213 | rqProp m = r <> q |~| m && unitary q && utr | 213 | rqProp m = r <> q |~| m && unitary q && upperTriang' r |
214 | where (r,q) = rq m | 214 | where (r,q) = rq m |
215 | upptr f c = buildMatrix f c $ \(r',c') -> if r'-t > c' then 0 else 1 | ||
216 | where t = f-c | ||
217 | utr = upptr (rows r) (cols r) * r |~| r | ||
218 | 215 | ||
219 | upperTriang' m = rows m == 1 || down |~| z | 216 | rqProp1 m = r <> q |~| m |
220 | where down = fromList $ concat $ zipWith drop [1..] (toLists (ctrans m)) | 217 | where (r,q) = rq m |
221 | z = constant 0 (dim down) | 218 | |
219 | rqProp2 m = unitary q | ||
220 | where (r,q) = rq m | ||
221 | |||
222 | rqProp3 m = upperTriang' r | ||
223 | where (r,q) = rq m | ||
224 | |||
225 | upperTriang' r = upptr (rows r) (cols r) * r |~| r | ||
226 | where upptr f c = buildMatrix f c $ \(r',c') -> if r'-t > c' then 0 else 1 | ||
227 | where t = f-c | ||
222 | 228 | ||
223 | hessProp m = m |~| p <> h <> ctrans p && unitary p && upperHessenberg h | 229 | hessProp m = m |~| p <> h <> ctrans p && unitary p && upperHessenberg h |
224 | where (p,h) = hess m | 230 | where (p,h) = hess m |