diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Data/Packed/Numeric.hs | 28 | ||||
-rw-r--r-- | packages/base/src/Numeric/HMatrix.hs | 4 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Data.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Devel.hs | 5 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Real.hs | 59 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 15 | ||||
-rw-r--r-- | packages/base/src/Numeric/Sparse.hs | 13 |
7 files changed, 57 insertions, 69 deletions
diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs index 7d88cbc..e324ab6 100644 --- a/packages/base/src/Data/Packed/Numeric.hs +++ b/packages/base/src/Data/Packed/Numeric.hs | |||
@@ -40,9 +40,9 @@ module Data.Packed.Numeric ( | |||
40 | step, cond, find, assoc, accum, | 40 | step, cond, find, assoc, accum, |
41 | Transposable(..), Linear(..), | 41 | Transposable(..), Linear(..), |
42 | -- * Matrix product | 42 | -- * Matrix product |
43 | Product(..), udot, dot, (◇), (<·>), (#>), | 43 | Product(..), udot, dot, (<·>), (#>), |
44 | Mul(..), | 44 | Mul(..), |
45 | Contraction(..),(<.>), | 45 | (<.>), |
46 | optimiseMult, | 46 | optimiseMult, |
47 | mXm,mXv,vXm,LSDiv,(<\>), | 47 | mXm,mXv,vXm,LSDiv,(<\>), |
48 | outer, kronecker, | 48 | outer, kronecker, |
@@ -140,25 +140,6 @@ fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0] | |||
140 | -} | 140 | -} |
141 | 141 | ||
142 | 142 | ||
143 | |||
144 | |||
145 | class Contraction a b c | a b -> c | ||
146 | where | ||
147 | -- | Matrix product, matrix - vector product, and dot product | ||
148 | contraction :: a -> b -> c | ||
149 | |||
150 | instance (Product t, Container Vector t) => Contraction (Vector t) (Vector t) t where | ||
151 | u `contraction` v = conj u `udot` v | ||
152 | |||
153 | instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where | ||
154 | contraction = mXv | ||
155 | |||
156 | instance (Container Vector t, Product t) => Contraction (Vector t) (Matrix t) (Vector t) where | ||
157 | contraction v m = (conj v) `vXm` m | ||
158 | |||
159 | instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where | ||
160 | contraction = mXm | ||
161 | |||
162 | -------------------------------------------------------------------------------- | 143 | -------------------------------------------------------------------------------- |
163 | 144 | ||
164 | infixl 7 <.> | 145 | infixl 7 <.> |
@@ -265,11 +246,6 @@ instance Container Matrix e => Build (Int,Int) (e -> e -> e) Matrix e | |||
265 | 246 | ||
266 | -------------------------------------------------------------------------------- | 247 | -------------------------------------------------------------------------------- |
267 | 248 | ||
268 | -- | alternative unicode symbol (25c7) for 'contraction' | ||
269 | (◇) :: Contraction a b c => a -> b -> c | ||
270 | infixl 7 ◇ | ||
271 | (◇) = contraction | ||
272 | |||
273 | -- | dot product: @cdot u v = 'udot' ('conj' u) v@ | 249 | -- | dot product: @cdot u v = 'udot' ('conj' u) v@ |
274 | dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t | 250 | dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t |
275 | dot u v = udot (conj u) v | 251 | dot u v = udot (conj u) v |
diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs index 1c70ef6..7f27fd4 100644 --- a/packages/base/src/Numeric/HMatrix.hs +++ b/packages/base/src/Numeric/HMatrix.hs | |||
@@ -41,7 +41,7 @@ module Numeric.HMatrix ( | |||
41 | -- ** dot | 41 | -- ** dot |
42 | (<·>), | 42 | (<·>), |
43 | -- ** matrix-vector | 43 | -- ** matrix-vector |
44 | (#>),(!#>), | 44 | (#>), (!#>), |
45 | -- ** matrix-matrix | 45 | -- ** matrix-matrix |
46 | (<>), | 46 | (<>), |
47 | -- | The matrix x matrix product is also implemented in the "Data.Monoid" instance, where | 47 | -- | The matrix x matrix product is also implemented in the "Data.Monoid" instance, where |
@@ -135,7 +135,7 @@ module Numeric.HMatrix ( | |||
135 | -- * Misc | 135 | -- * Misc |
136 | meanCov, peps, relativeError, haussholder, optimiseMult, udot, | 136 | meanCov, peps, relativeError, haussholder, optimiseMult, udot, |
137 | -- * Auxiliary classes | 137 | -- * Auxiliary classes |
138 | Element, Container, Product, Contraction(..), Numeric, LSDiv, | 138 | Element, Container, Product, Numeric, LSDiv, |
139 | Complexable, RealElement, | 139 | Complexable, RealElement, |
140 | RealOf, ComplexOf, SingleOf, DoubleOf, | 140 | RealOf, ComplexOf, SingleOf, DoubleOf, |
141 | IndexOf, | 141 | IndexOf, |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 33a2c9a..20f3b81 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs | |||
@@ -68,7 +68,7 @@ module Numeric.LinearAlgebra.Data( | |||
68 | 68 | ||
69 | module Data.Complex, | 69 | module Data.Complex, |
70 | 70 | ||
71 | Vector, Matrix, GMatrix, CSR(..), mkCSR | 71 | Vector, Matrix, GMatrix, nRows, nCols |
72 | 72 | ||
73 | ) where | 73 | ) where |
74 | 74 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs index ca9e53a..fce8b71 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs | |||
@@ -49,10 +49,15 @@ module Numeric.LinearAlgebra.Devel( | |||
49 | mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, | 49 | mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, |
50 | liftMatrix, liftMatrix2, liftMatrix2Auto, | 50 | liftMatrix, liftMatrix2, liftMatrix2Auto, |
51 | 51 | ||
52 | -- * Misc | ||
53 | CSR(..), fromCSR, mkCSR, | ||
54 | GMatrix(..) | ||
55 | |||
52 | ) where | 56 | ) where |
53 | 57 | ||
54 | import Data.Packed.Foreign | 58 | import Data.Packed.Foreign |
55 | import Data.Packed.Development | 59 | import Data.Packed.Development |
56 | import Data.Packed.ST | 60 | import Data.Packed.ST |
57 | import Data.Packed | 61 | import Data.Packed |
62 | import Numeric.Sparse | ||
58 | 63 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index 1e8b544..5634031 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs | |||
@@ -29,26 +29,26 @@ module Numeric.LinearAlgebra.Real( | |||
29 | vec2, vec3, vec4, 𝕧, (&), | 29 | vec2, vec3, vec4, 𝕧, (&), |
30 | -- * Matrix | 30 | -- * Matrix |
31 | L, Sq, | 31 | L, Sq, |
32 | 𝕞, | 32 | row, col, (¦),(——), |
33 | (#),(¦),(——), | 33 | Konst(..), |
34 | Konst(..), | 34 | eye, |
35 | eye, | 35 | diagR, diag, |
36 | diagR, diag, | 36 | blockAt, |
37 | blockAt, | ||
38 | -- * Products | 37 | -- * Products |
39 | (<>),(#>),(<·>), | 38 | (<>),(#>),(<·>), |
40 | -- * Pretty printing | 39 | -- * Pretty printing |
41 | Disp(..), | 40 | Disp(..), |
42 | -- * Misc | 41 | -- * Misc |
43 | Dim, unDim, | 42 | Dim, unDim, |
44 | module Numeric.HMatrix | 43 | module Numeric.HMatrix |
45 | ) where | 44 | ) where |
46 | 45 | ||
47 | 46 | ||
48 | import GHC.TypeLits | 47 | import GHC.TypeLits |
49 | import Numeric.HMatrix hiding ((<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——)) | 48 | import Numeric.HMatrix hiding ((<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col) |
50 | import qualified Numeric.HMatrix as LA | 49 | import qualified Numeric.HMatrix as LA |
51 | import Data.Packed.ST | 50 | import Data.Packed.ST |
51 | import Data.Proxy(Proxy) | ||
52 | 52 | ||
53 | newtype Dim (n :: Nat) t = Dim t | 53 | newtype Dim (n :: Nat) t = Dim t |
54 | deriving Show | 54 | deriving Show |
@@ -56,7 +56,7 @@ newtype Dim (n :: Nat) t = Dim t | |||
56 | unDim :: Dim n t -> t | 56 | unDim :: Dim n t -> t |
57 | unDim (Dim x) = x | 57 | unDim (Dim x) = x |
58 | 58 | ||
59 | data Proxy :: Nat -> * | 59 | -- data Proxy :: Nat -> * |
60 | 60 | ||
61 | 61 | ||
62 | lift1F | 62 | lift1F |
@@ -223,7 +223,7 @@ instance Disp (R n) | |||
223 | else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su) | 223 | else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su) |
224 | 224 | ||
225 | -------------------------------------------------------------------------------- | 225 | -------------------------------------------------------------------------------- |
226 | 226 | {- | |
227 | infixl 3 # | 227 | infixl 3 # |
228 | (#) :: L r c -> R c -> L (r+1) c | 228 | (#) :: L r c -> R c -> L (r+1) c |
229 | Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v)) | 229 | Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v)) |
@@ -233,14 +233,31 @@ Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v)) | |||
233 | 𝕞 = Dim (Dim (LA.konst 0 (0,d))) | 233 | 𝕞 = Dim (Dim (LA.konst 0 (0,d))) |
234 | where | 234 | where |
235 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 235 | d = fromIntegral . natVal $ (undefined :: Proxy n) |
236 | -} | ||
237 | |||
238 | row :: R n -> L 1 n | ||
239 | row (Dim v) = Dim (Dim (asRow v)) | ||
240 | |||
241 | col :: R n -> L n 1 | ||
242 | col = tr . row | ||
236 | 243 | ||
237 | infixl 3 ¦ | 244 | infixl 3 ¦ |
238 | (¦) :: L r c1 -> L r c2 -> L r (c1+c2) | 245 | (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) |
239 | Dim (Dim a) ¦ Dim (Dim b) = Dim (Dim (a LA.¦ b)) | 246 | a ¦ b = rjoin (expk a) (expk b) |
247 | where | ||
248 | Dim (Dim a') `rjoin` Dim (Dim b') = Dim (Dim (a' LA.¦ b')) | ||
240 | 249 | ||
241 | infixl 2 —— | 250 | infixl 2 —— |
242 | (——) :: L r1 c -> L r2 c -> L (r1+r2) c | 251 | (——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c |
243 | Dim (Dim a) —— Dim (Dim b) = Dim (Dim (a LA.—— b)) | 252 | a —— b = cjoin (expk a) (expk b) |
253 | where | ||
254 | Dim (Dim a') `cjoin` Dim (Dim b') = Dim (Dim (a' LA.—— b')) | ||
255 | |||
256 | expk :: (KnownNat n, KnownNat m) => L m n -> L m n | ||
257 | expk x | singleton x = konst (d2 x `atIndex` (0,0)) | ||
258 | | otherwise = x | ||
259 | where | ||
260 | singleton (d2 -> m) = rows m == 1 && cols m == 1 | ||
244 | 261 | ||
245 | 262 | ||
246 | {- | 263 | {- |
@@ -338,10 +355,4 @@ instance (KnownNat n', KnownNat m') => Testable (L n' m') | |||
338 | where | 355 | where |
339 | checkT _ = test | 356 | checkT _ = test |
340 | 357 | ||
341 | {- | ||
342 | do (snd test) | ||
343 | fst test | ||
344 | -} | ||
345 | |||
346 | |||
347 | 358 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs index f821b57..b82c74f 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | |||
@@ -41,22 +41,21 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | |||
41 | ap1 = a p | 41 | ap1 = a p |
42 | ap | sym = ap1 | 42 | ap | sym = ap1 |
43 | | otherwise = at ap1 | 43 | | otherwise = at ap1 |
44 | pap | sym = p ◇ ap1 | 44 | pap | sym = p <·> ap1 |
45 | | otherwise = norm2 ap1 ** 2 | 45 | | otherwise = norm2 ap1 ** 2 |
46 | alpha = r2 / pap | 46 | alpha = r2 / pap |
47 | dx = scale alpha p | 47 | dx = scale alpha p |
48 | x' = x + dx | 48 | x' = x + dx |
49 | r' = r - scale alpha ap | 49 | r' = r - scale alpha ap |
50 | r'2 = r' ◇ r' | 50 | r'2 = r' <·> r' |
51 | beta = r'2 / r2 | 51 | beta = r'2 / r2 |
52 | p' = r' + scale beta p | 52 | p' = r' + scale beta p |
53 | 53 | ||
54 | rdx = norm2 dx / max 1 (norm2 x) | 54 | rdx = norm2 dx / max 1 (norm2 x) |
55 | 55 | ||
56 | conjugrad | 56 | conjugrad |
57 | :: (Transposable m mt, Contraction m V V, Contraction mt V V) | 57 | :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState] |
58 | => Bool -> m -> V -> V -> R -> R -> [CGState] | 58 | conjugrad sym a b = solveG (tr a !#>) (a !#>) (cg sym) b |
59 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b | ||
60 | 59 | ||
61 | solveG | 60 | solveG |
62 | :: (V -> V) -> (V -> V) | 61 | :: (V -> V) -> (V -> V) |
@@ -72,9 +71,9 @@ solveG mat ma meth rawb x0' ϵb ϵx | |||
72 | b = mat rawb | 71 | b = mat rawb |
73 | x0 = if x0' == 0 then konst 0 (dim b) else x0' | 72 | x0 = if x0' == 0 then konst 0 (dim b) else x0' |
74 | r0 = b - a x0 | 73 | r0 = b - a x0 |
75 | r20 = r0 ◇ r0 | 74 | r20 = r0 <·> r0 |
76 | p0 = r0 | 75 | p0 = r0 |
77 | nb2 = b ◇ b | 76 | nb2 = b <·> b |
78 | ok CGState {..} | 77 | ok CGState {..} |
79 | = cgr2 <nb2*ϵb**2 | 78 | = cgr2 <nb2*ϵb**2 |
80 | || cgdx < ϵx | 79 | || cgdx < ϵx |
@@ -115,7 +114,7 @@ instance Testable GMatrix | |||
115 | sma = convo2 20 3 | 114 | sma = convo2 20 3 |
116 | x1 = vect [1..20] | 115 | x1 = vect [1..20] |
117 | x2 = vect [1..40] | 116 | x2 = vect [1..40] |
118 | sm = (mkSparse . mkCSR) sma | 117 | sm = mkSparse sma |
119 | dm = toDense sma | 118 | dm = toDense sma |
120 | 119 | ||
121 | s1 = sm !#> x1 | 120 | s1 = sm !#> x1 |
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs index 3c19c93..1b8a7b3 100644 --- a/packages/base/src/Numeric/Sparse.hs +++ b/packages/base/src/Numeric/Sparse.hs | |||
@@ -3,7 +3,7 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | 4 | ||
5 | module Numeric.Sparse( | 5 | module Numeric.Sparse( |
6 | GMatrix, CSR(..), mkCSR, | 6 | GMatrix(..), CSR(..), mkCSR, fromCSR, |
7 | mkSparse, mkDiagR, mkDense, | 7 | mkSparse, mkDiagR, mkDense, |
8 | AssocMatrix, | 8 | AssocMatrix, |
9 | toDense, | 9 | toDense, |
@@ -95,9 +95,11 @@ mkDense m = Dense{..} | |||
95 | nRows = rows m | 95 | nRows = rows m |
96 | nCols = cols m | 96 | nCols = cols m |
97 | 97 | ||
98 | mkSparse :: AssocMatrix -> GMatrix | ||
99 | mkSparse = fromCSR . mkCSR | ||
98 | 100 | ||
99 | mkSparse :: CSR -> GMatrix | 101 | fromCSR :: CSR -> GMatrix |
100 | mkSparse csr = SparseR {..} | 102 | fromCSR csr = SparseR {..} |
101 | where | 103 | where |
102 | gmCSR @ CSR {..} = csr | 104 | gmCSR @ CSR {..} = csr |
103 | nRows = csrNRows | 105 | nRows = csrNRows |
@@ -149,11 +151,6 @@ infixr 8 !#> | |||
149 | (!#>) :: GMatrix -> Vector Double -> Vector Double | 151 | (!#>) :: GMatrix -> Vector Double -> Vector Double |
150 | (!#>) = gmXv | 152 | (!#>) = gmXv |
151 | 153 | ||
152 | |||
153 | instance Contraction GMatrix (Vector Double) (Vector Double) | ||
154 | where | ||
155 | contraction = gmXv | ||
156 | |||
157 | -------------------------------------------------------------------------------- | 154 | -------------------------------------------------------------------------------- |
158 | 155 | ||
159 | foreign import ccall unsafe "smXv" | 156 | foreign import ccall unsafe "smXv" |