summaryrefslogtreecommitdiff
path: root/lib/Numeric/LinearAlgebra
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2010-02-01 10:29:20 +0000
committerAlberto Ruiz <aruiz@um.es>2010-02-01 10:29:20 +0000
commit2a3baa18b508851a9d30e4abc7d7de7978146557 (patch)
treeadc4a32b9461678bb468b1ce48cb9759679d5a61 /lib/Numeric/LinearAlgebra
parent6f4137cabbc16fa616e823db3d1b2cf90c03e5c9 (diff)
export liftMatrix2Auto
Diffstat (limited to 'lib/Numeric/LinearAlgebra')
-rw-r--r--lib/Numeric/LinearAlgebra/Instances.hs31
-rw-r--r--lib/Numeric/LinearAlgebra/Tests.hs1
2 files changed, 7 insertions, 25 deletions
diff --git a/lib/Numeric/LinearAlgebra/Instances.hs b/lib/Numeric/LinearAlgebra/Instances.hs
index 1f8b5a0..67496f2 100644
--- a/lib/Numeric/LinearAlgebra/Instances.hs
+++ b/lib/Numeric/LinearAlgebra/Instances.hs
@@ -71,26 +71,6 @@ adaptScalar f1 f2 f3 x y
71 | dim y == 1 = f3 x (y@>0) 71 | dim y == 1 = f3 x (y@>0)
72 | otherwise = f2 x y 72 | otherwise = f2 x y
73 73
74liftMatrix2' :: (Element t, Element a, Element b)
75 => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
76liftMatrix2' f m1 m2 | compat' m1 m2 = lM f m1 m2
77 | rows m1 == rows m2 && cols m2 == 1 = lM f m1 (repCols (cols m1) m2)
78 | rows m1 == rows m2 && cols m1 == 1 = lM f (repCols (cols m2) m1) m2
79 | cols m1 == cols m2 && rows m2 == 1 = lM f m1 (repRows (rows m1) m2)
80 | cols m1 == cols m2 && cols m1 == 1 = lM f (repRows (rows m2) m1) m2
81 | rows m1 == 1 && cols m2 == 1 = lM f (repRows (rows m2) m1) (repCols (cols m1) m2)
82 | cols m1 == 1 && rows m2 == 1 = lM f (repCols (cols m2) m1) (repRows (rows m1) m2)
83 | otherwise = error "nonconformable matrices in liftMatrix2'"
84
85lM f m1 m2 = reshape (max (cols m1) (cols m2)) (f (flatten m1) (flatten m2))
86
87repRows n x = fromRows (replicate n (flatten x))
88repCols n x = fromColumns (replicate n (flatten x))
89
90compat' :: Matrix a -> Matrix b -> Bool
91compat' m1 m2 = rows m1 == 1 && cols m1 == 1
92 || rows m2 == 1 && cols m2 == 1
93 || rows m1 == rows m2 && cols m1 == cols m2
94 74
95instance Linear Vector a => Eq (Vector a) where 75instance Linear Vector a => Eq (Vector a) where
96 (==) = equal 76 (==) = equal
@@ -115,10 +95,10 @@ instance Linear Matrix a => Eq (Matrix a) where
115 (==) = equal 95 (==) = equal
116 96
117instance (Linear Matrix a, Num (Vector a)) => Num (Matrix a) where 97instance (Linear Matrix a, Num (Vector a)) => Num (Matrix a) where
118 (+) = liftMatrix2' (+) 98 (+) = liftMatrix2Auto (+)
119 (-) = liftMatrix2' (-) 99 (-) = liftMatrix2Auto (-)
120 negate = liftMatrix negate 100 negate = liftMatrix negate
121 (*) = liftMatrix2' (*) 101 (*) = liftMatrix2Auto (*)
122 signum = liftMatrix signum 102 signum = liftMatrix signum
123 abs = liftMatrix abs 103 abs = liftMatrix abs
124 fromInteger = (1><1) . return . fromInteger 104 fromInteger = (1><1) . return . fromInteger
@@ -135,7 +115,7 @@ instance (Linear Vector a, Num (Vector a)) => Fractional (Vector a) where
135 115
136instance (Linear Vector a, Fractional (Vector a), Num (Matrix a)) => Fractional (Matrix a) where 116instance (Linear Vector a, Fractional (Vector a), Num (Matrix a)) => Fractional (Matrix a) where
137 fromRational n = (1><1) [fromRational n] 117 fromRational n = (1><1) [fromRational n]
138 (/) = liftMatrix2' (/) 118 (/) = liftMatrix2Auto (/)
139 119
140--------------------------------------------------------- 120---------------------------------------------------------
141 121
@@ -196,7 +176,7 @@ instance (Linear Vector a, Floating (Vector a), Fractional (Matrix a)) => Floati
196 atanh = liftMatrix atanh 176 atanh = liftMatrix atanh
197 exp = liftMatrix exp 177 exp = liftMatrix exp
198 log = liftMatrix log 178 log = liftMatrix log
199 (**) = liftMatrix2' (**) 179 (**) = liftMatrix2Auto (**)
200 sqrt = liftMatrix sqrt 180 sqrt = liftMatrix sqrt
201 pi = (1><1) [pi] 181 pi = (1><1) [pi]
202 182
@@ -216,3 +196,4 @@ instance (Storable a, Num (Vector a)) => Monoid (Vector a) where
216-- 196--
217-- instance (NFData a, Element a) => NFData (Matrix a) where 197-- instance (NFData a, Element a) => NFData (Matrix a) where
218-- rnf = rnf . flatten 198-- rnf = rnf . flatten
199
diff --git a/lib/Numeric/LinearAlgebra/Tests.hs b/lib/Numeric/LinearAlgebra/Tests.hs
index 89c8297..e6b26a0 100644
--- a/lib/Numeric/LinearAlgebra/Tests.hs
+++ b/lib/Numeric/LinearAlgebra/Tests.hs
@@ -366,3 +366,4 @@ svdBench = do
366 time "full svd 3000x500" (fv $ svd a) 366 time "full svd 3000x500" (fv $ svd a)
367 time "singular values 1000x1000" (singularValues b) 367 time "singular values 1000x1000" (singularValues b)
368 time "full svd 1000x1000" (fv $ svd b) 368 time "full svd 1000x1000" (fv $ svd b)
369