diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Data/Packed/Matrix.hs | 27 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Instances.hs | 31 | ||||
-rw-r--r-- | lib/Numeric/LinearAlgebra/Tests.hs | 1 |
3 files changed, 33 insertions, 26 deletions
diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 4cb7a88..66e5082 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs | |||
@@ -27,7 +27,7 @@ module Data.Packed.Matrix ( | |||
27 | subMatrix, takeRows, dropRows, takeColumns, dropColumns, | 27 | subMatrix, takeRows, dropRows, takeColumns, dropColumns, |
28 | extractRows, | 28 | extractRows, |
29 | ident, diag, diagRect, takeDiag, | 29 | ident, diag, diagRect, takeDiag, |
30 | liftMatrix, liftMatrix2, | 30 | liftMatrix, liftMatrix2, liftMatrix2Auto, |
31 | format, dispf, disps, | 31 | format, dispf, disps, |
32 | loadMatrix, saveMatrix, fromFile, fileDimensions, | 32 | loadMatrix, saveMatrix, fromFile, fileDimensions, |
33 | readMatrix, fromArray2D | 33 | readMatrix, fromArray2D |
@@ -386,3 +386,28 @@ extractRows l m = fromRows $ extract (toRows $ m) l | |||
386 | -} | 386 | -} |
387 | repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t | 387 | repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t |
388 | repmat m r c = fromBlocks $ partit c $ replicate (r*c) m | 388 | repmat m r c = fromBlocks $ partit c $ replicate (r*c) m |
389 | |||
390 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. | ||
391 | liftMatrix2Auto :: (Element t, Element a, Element b) | ||
392 | => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | ||
393 | liftMatrix2Auto f m1 m2 | compat' m1 m2 = lM f m1 m2 | ||
394 | | rows m1 == rows m2 && cols m2 == 1 = lM f m1 (repCols (cols m1) m2) | ||
395 | | rows m1 == rows m2 && cols m1 == 1 = lM f (repCols (cols m2) m1) m2 | ||
396 | | cols m1 == cols m2 && rows m2 == 1 = lM f m1 (repRows (rows m1) m2) | ||
397 | | cols m1 == cols m2 && cols m1 == 1 = lM f (repRows (rows m2) m1) m2 | ||
398 | | rows m1 == 1 && cols m2 == 1 = lM f (repRows (rows m2) m1) (repCols (cols m1) m2) | ||
399 | | cols m1 == 1 && rows m2 == 1 = lM f (repCols (cols m2) m1) (repRows (rows m1) m2) | ||
400 | | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ show (size m1) ++ ", " ++ show (size m2) | ||
401 | |||
402 | size m = (rows m, cols m) | ||
403 | |||
404 | lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) | ||
405 | |||
406 | repRows n x = fromRows (replicate n (flatten x)) | ||
407 | repCols n x = fromColumns (replicate n (flatten x)) | ||
408 | |||
409 | compat' :: Matrix a -> Matrix b -> Bool | ||
410 | compat' m1 m2 = rows m1 == 1 && cols m1 == 1 | ||
411 | || rows m2 == 1 && cols m2 == 1 | ||
412 | || rows m1 == rows m2 && cols m1 == cols m2 | ||
413 | |||
diff --git a/lib/Numeric/LinearAlgebra/Instances.hs b/lib/Numeric/LinearAlgebra/Instances.hs index 1f8b5a0..67496f2 100644 --- a/lib/Numeric/LinearAlgebra/Instances.hs +++ b/lib/Numeric/LinearAlgebra/Instances.hs | |||
@@ -71,26 +71,6 @@ adaptScalar f1 f2 f3 x y | |||
71 | | dim y == 1 = f3 x (y@>0) | 71 | | dim y == 1 = f3 x (y@>0) |
72 | | otherwise = f2 x y | 72 | | otherwise = f2 x y |
73 | 73 | ||
74 | liftMatrix2' :: (Element t, Element a, Element b) | ||
75 | => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | ||
76 | liftMatrix2' f m1 m2 | compat' m1 m2 = lM f m1 m2 | ||
77 | | rows m1 == rows m2 && cols m2 == 1 = lM f m1 (repCols (cols m1) m2) | ||
78 | | rows m1 == rows m2 && cols m1 == 1 = lM f (repCols (cols m2) m1) m2 | ||
79 | | cols m1 == cols m2 && rows m2 == 1 = lM f m1 (repRows (rows m1) m2) | ||
80 | | cols m1 == cols m2 && cols m1 == 1 = lM f (repRows (rows m2) m1) m2 | ||
81 | | rows m1 == 1 && cols m2 == 1 = lM f (repRows (rows m2) m1) (repCols (cols m1) m2) | ||
82 | | cols m1 == 1 && rows m2 == 1 = lM f (repCols (cols m2) m1) (repRows (rows m1) m2) | ||
83 | | otherwise = error "nonconformable matrices in liftMatrix2'" | ||
84 | |||
85 | lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2)) | ||
86 | |||
87 | repRows n x = fromRows (replicate n (flatten x)) | ||
88 | repCols n x = fromColumns (replicate n (flatten x)) | ||
89 | |||
90 | compat' :: Matrix a -> Matrix b -> Bool | ||
91 | compat' m1 m2 = rows m1 == 1 && cols m1 == 1 | ||
92 | || rows m2 == 1 && cols m2 == 1 | ||
93 | || rows m1 == rows m2 && cols m1 == cols m2 | ||
94 | 74 | ||
95 | instance Linear Vector a => Eq (Vector a) where | 75 | instance Linear Vector a => Eq (Vector a) where |
96 | (==) = equal | 76 | (==) = equal |
@@ -115,10 +95,10 @@ instance Linear Matrix a => Eq (Matrix a) where | |||
115 | (==) = equal | 95 | (==) = equal |
116 | 96 | ||
117 | instance (Linear Matrix a, Num (Vector a)) => Num (Matrix a) where | 97 | instance (Linear Matrix a, Num (Vector a)) => Num (Matrix a) where |
118 | (+) = liftMatrix2' (+) | 98 | (+) = liftMatrix2Auto (+) |
119 | (-) = liftMatrix2' (-) | 99 | (-) = liftMatrix2Auto (-) |
120 | negate = liftMatrix negate | 100 | negate = liftMatrix negate |
121 | (*) = liftMatrix2' (*) | 101 | (*) = liftMatrix2Auto (*) |
122 | signum = liftMatrix signum | 102 | signum = liftMatrix signum |
123 | abs = liftMatrix abs | 103 | abs = liftMatrix abs |
124 | fromInteger = (1><1) . return . fromInteger | 104 | fromInteger = (1><1) . return . fromInteger |
@@ -135,7 +115,7 @@ instance (Linear Vector a, Num (Vector a)) => Fractional (Vector a) where | |||
135 | 115 | ||
136 | instance (Linear Vector a, Fractional (Vector a), Num (Matrix a)) => Fractional (Matrix a) where | 116 | instance (Linear Vector a, Fractional (Vector a), Num (Matrix a)) => Fractional (Matrix a) where |
137 | fromRational n = (1><1) [fromRational n] | 117 | fromRational n = (1><1) [fromRational n] |
138 | (/) = liftMatrix2' (/) | 118 | (/) = liftMatrix2Auto (/) |
139 | 119 | ||
140 | --------------------------------------------------------- | 120 | --------------------------------------------------------- |
141 | 121 | ||
@@ -196,7 +176,7 @@ instance (Linear Vector a, Floating (Vector a), Fractional (Matrix a)) => Floati | |||
196 | atanh = liftMatrix atanh | 176 | atanh = liftMatrix atanh |
197 | exp = liftMatrix exp | 177 | exp = liftMatrix exp |
198 | log = liftMatrix log | 178 | log = liftMatrix log |
199 | (**) = liftMatrix2' (**) | 179 | (**) = liftMatrix2Auto (**) |
200 | sqrt = liftMatrix sqrt | 180 | sqrt = liftMatrix sqrt |
201 | pi = (1><1) [pi] | 181 | pi = (1><1) [pi] |
202 | 182 | ||
@@ -216,3 +196,4 @@ instance (Storable a, Num (Vector a)) => Monoid (Vector a) where | |||
216 | -- | 196 | -- |
217 | -- instance (NFData a, Element a) => NFData (Matrix a) where | 197 | -- instance (NFData a, Element a) => NFData (Matrix a) where |
218 | -- rnf = rnf . flatten | 198 | -- rnf = rnf . flatten |
199 | |||
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs index 89c8297..e6b26a0 100644 --- a/lib/Numeric/LinearAlgebra/Tests.hs +++ b/lib/Numeric/LinearAlgebra/Tests.hs | |||
@@ -366,3 +366,4 @@ svdBench = do | |||
366 | time "full svd 3000x500" (fv $ svd a) | 366 | time "full svd 3000x500" (fv $ svd a) |
367 | time "singular values 1000x1000" (singularValues b) | 367 | time "singular values 1000x1000" (singularValues b) |
368 | time "full svd 1000x1000" (fv $ svd b) | 368 | time "full svd 1000x1000" (fv $ svd b) |
369 | |||