summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Data/Packed/Numeric.hs28
-rw-r--r--packages/base/src/Numeric/HMatrix.hs4
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Data.hs2
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Devel.hs5
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs59
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs15
-rw-r--r--packages/base/src/Numeric/Sparse.hs13
7 files changed, 57 insertions, 69 deletions
diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs
index 7d88cbc..e324ab6 100644
--- a/packages/base/src/Data/Packed/Numeric.hs
+++ b/packages/base/src/Data/Packed/Numeric.hs
@@ -40,9 +40,9 @@ module Data.Packed.Numeric (
40 step, cond, find, assoc, accum, 40 step, cond, find, assoc, accum,
41 Transposable(..), Linear(..), 41 Transposable(..), Linear(..),
42 -- * Matrix product 42 -- * Matrix product
43 Product(..), udot, dot, (◇), (<·>), (#>), 43 Product(..), udot, dot, (<·>), (#>),
44 Mul(..), 44 Mul(..),
45 Contraction(..),(<.>), 45 (<.>),
46 optimiseMult, 46 optimiseMult,
47 mXm,mXv,vXm,LSDiv,(<\>), 47 mXm,mXv,vXm,LSDiv,(<\>),
48 outer, kronecker, 48 outer, kronecker,
@@ -140,25 +140,6 @@ fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0]
140-} 140-}
141 141
142 142
143
144
145class Contraction a b c | a b -> c
146 where
147 -- | Matrix product, matrix - vector product, and dot product
148 contraction :: a -> b -> c
149
150instance (Product t, Container Vector t) => Contraction (Vector t) (Vector t) t where
151 u `contraction` v = conj u `udot` v
152
153instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where
154 contraction = mXv
155
156instance (Container Vector t, Product t) => Contraction (Vector t) (Matrix t) (Vector t) where
157 contraction v m = (conj v) `vXm` m
158
159instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where
160 contraction = mXm
161
162-------------------------------------------------------------------------------- 143--------------------------------------------------------------------------------
163 144
164infixl 7 <.> 145infixl 7 <.>
@@ -265,11 +246,6 @@ instance Container Matrix e => Build (Int,Int) (e -> e -> e) Matrix e
265 246
266-------------------------------------------------------------------------------- 247--------------------------------------------------------------------------------
267 248
268-- | alternative unicode symbol (25c7) for 'contraction'
269(◇) :: Contraction a b c => a -> b -> c
270infixl 7 ◇
271(◇) = contraction
272
273-- | dot product: @cdot u v = 'udot' ('conj' u) v@ 249-- | dot product: @cdot u v = 'udot' ('conj' u) v@
274dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t 250dot :: (Container Vector t, Product t) => Vector t -> Vector t -> t
275dot u v = udot (conj u) v 251dot u v = udot (conj u) v
diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs
index 1c70ef6..7f27fd4 100644
--- a/packages/base/src/Numeric/HMatrix.hs
+++ b/packages/base/src/Numeric/HMatrix.hs
@@ -41,7 +41,7 @@ module Numeric.HMatrix (
41 -- ** dot 41 -- ** dot
42 (<·>), 42 (<·>),
43 -- ** matrix-vector 43 -- ** matrix-vector
44 (#>),(!#>), 44 (#>), (!#>),
45 -- ** matrix-matrix 45 -- ** matrix-matrix
46 (<>), 46 (<>),
47 -- | The matrix x matrix product is also implemented in the "Data.Monoid" instance, where 47 -- | The matrix x matrix product is also implemented in the "Data.Monoid" instance, where
@@ -135,7 +135,7 @@ module Numeric.HMatrix (
135 -- * Misc 135 -- * Misc
136 meanCov, peps, relativeError, haussholder, optimiseMult, udot, 136 meanCov, peps, relativeError, haussholder, optimiseMult, udot,
137 -- * Auxiliary classes 137 -- * Auxiliary classes
138 Element, Container, Product, Contraction(..), Numeric, LSDiv, 138 Element, Container, Product, Numeric, LSDiv,
139 Complexable, RealElement, 139 Complexable, RealElement,
140 RealOf, ComplexOf, SingleOf, DoubleOf, 140 RealOf, ComplexOf, SingleOf, DoubleOf,
141 IndexOf, 141 IndexOf,
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs
index 33a2c9a..20f3b81 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Data.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs
@@ -68,7 +68,7 @@ module Numeric.LinearAlgebra.Data(
68 68
69 module Data.Complex, 69 module Data.Complex,
70 70
71 Vector, Matrix, GMatrix, CSR(..), mkCSR 71 Vector, Matrix, GMatrix, nRows, nCols
72 72
73) where 73) where
74 74
diff --git a/packages/base/src/Numeric/LinearAlgebra/Devel.hs b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
index ca9e53a..fce8b71 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Devel.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Devel.hs
@@ -49,10 +49,15 @@ module Numeric.LinearAlgebra.Devel(
49 mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_, 49 mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_,
50 liftMatrix, liftMatrix2, liftMatrix2Auto, 50 liftMatrix, liftMatrix2, liftMatrix2Auto,
51 51
52 -- * Misc
53 CSR(..), fromCSR, mkCSR,
54 GMatrix(..)
55
52) where 56) where
53 57
54import Data.Packed.Foreign 58import Data.Packed.Foreign
55import Data.Packed.Development 59import Data.Packed.Development
56import Data.Packed.ST 60import Data.Packed.ST
57import Data.Packed 61import Data.Packed
62import Numeric.Sparse
58 63
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
index 1e8b544..5634031 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Real.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs
@@ -29,26 +29,26 @@ module Numeric.LinearAlgebra.Real(
29 vec2, vec3, vec4, 𝕧, (&), 29 vec2, vec3, vec4, 𝕧, (&),
30 -- * Matrix 30 -- * Matrix
31 L, Sq, 31 L, Sq,
32 𝕞, 32 row, col, (¦),(——),
33 (#),(¦),(——), 33 Konst(..),
34 Konst(..), 34 eye,
35 eye, 35 diagR, diag,
36 diagR, diag, 36 blockAt,
37 blockAt,
38 -- * Products 37 -- * Products
39 (<>),(#>),(<·>), 38 (<>),(#>),(<·>),
40 -- * Pretty printing 39 -- * Pretty printing
41 Disp(..), 40 Disp(..),
42 -- * Misc 41 -- * Misc
43 Dim, unDim, 42 Dim, unDim,
44 module Numeric.HMatrix 43 module Numeric.HMatrix
45) where 44) where
46 45
47 46
48import GHC.TypeLits 47import GHC.TypeLits
49import Numeric.HMatrix hiding ((<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——)) 48import Numeric.HMatrix hiding ((<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col)
50import qualified Numeric.HMatrix as LA 49import qualified Numeric.HMatrix as LA
51import Data.Packed.ST 50import Data.Packed.ST
51import Data.Proxy(Proxy)
52 52
53newtype Dim (n :: Nat) t = Dim t 53newtype Dim (n :: Nat) t = Dim t
54 deriving Show 54 deriving Show
@@ -56,7 +56,7 @@ newtype Dim (n :: Nat) t = Dim t
56unDim :: Dim n t -> t 56unDim :: Dim n t -> t
57unDim (Dim x) = x 57unDim (Dim x) = x
58 58
59data Proxy :: Nat -> * 59-- data Proxy :: Nat -> *
60 60
61 61
62lift1F 62lift1F
@@ -223,7 +223,7 @@ instance Disp (R n)
223 else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su) 223 else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su)
224 224
225-------------------------------------------------------------------------------- 225--------------------------------------------------------------------------------
226 226{-
227infixl 3 # 227infixl 3 #
228(#) :: L r c -> R c -> L (r+1) c 228(#) :: L r c -> R c -> L (r+1) c
229Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v)) 229Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v))
@@ -233,14 +233,31 @@ Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v))
233𝕞 = Dim (Dim (LA.konst 0 (0,d))) 233𝕞 = Dim (Dim (LA.konst 0 (0,d)))
234 where 234 where
235 d = fromIntegral . natVal $ (undefined :: Proxy n) 235 d = fromIntegral . natVal $ (undefined :: Proxy n)
236-}
237
238row :: R n -> L 1 n
239row (Dim v) = Dim (Dim (asRow v))
240
241col :: R n -> L n 1
242col = tr . row
236 243
237infixl 3 ¦ 244infixl 3 ¦
238(¦) :: L r c1 -> L r c2 -> L r (c1+c2) 245(¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2)
239Dim (Dim a) ¦ Dim (Dim b) = Dim (Dim (a LA.¦ b)) 246a ¦ b = rjoin (expk a) (expk b)
247 where
248 Dim (Dim a') `rjoin` Dim (Dim b') = Dim (Dim (a' LA.¦ b'))
240 249
241infixl 2 —— 250infixl 2 ——
242(——) :: L r1 c -> L r2 c -> L (r1+r2) c 251(——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c
243Dim (Dim a) —— Dim (Dim b) = Dim (Dim (a LA.—— b)) 252a —— b = cjoin (expk a) (expk b)
253 where
254 Dim (Dim a') `cjoin` Dim (Dim b') = Dim (Dim (a' LA.—— b'))
255
256expk :: (KnownNat n, KnownNat m) => L m n -> L m n
257expk x | singleton x = konst (d2 x `atIndex` (0,0))
258 | otherwise = x
259 where
260 singleton (d2 -> m) = rows m == 1 && cols m == 1
244 261
245 262
246{- 263{-
@@ -338,10 +355,4 @@ instance (KnownNat n', KnownNat m') => Testable (L n' m')
338 where 355 where
339 checkT _ = test 356 checkT _ = test
340 357
341{-
342do (snd test)
343fst test
344-}
345
346
347 358
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
index f821b57..b82c74f 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
@@ -41,22 +41,21 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
41 ap1 = a p 41 ap1 = a p
42 ap | sym = ap1 42 ap | sym = ap1
43 | otherwise = at ap1 43 | otherwise = at ap1
44 pap | sym = p ap1 44 pap | sym = p <·> ap1
45 | otherwise = norm2 ap1 ** 2 45 | otherwise = norm2 ap1 ** 2
46 alpha = r2 / pap 46 alpha = r2 / pap
47 dx = scale alpha p 47 dx = scale alpha p
48 x' = x + dx 48 x' = x + dx
49 r' = r - scale alpha ap 49 r' = r - scale alpha ap
50 r'2 = r' r' 50 r'2 = r' <·> r'
51 beta = r'2 / r2 51 beta = r'2 / r2
52 p' = r' + scale beta p 52 p' = r' + scale beta p
53 53
54 rdx = norm2 dx / max 1 (norm2 x) 54 rdx = norm2 dx / max 1 (norm2 x)
55 55
56conjugrad 56conjugrad
57 :: (Transposable m mt, Contraction m V V, Contraction mt V V) 57 :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState]
58 => Bool -> m -> V -> V -> R -> R -> [CGState] 58conjugrad sym a b = solveG (tr a !#>) (a !#>) (cg sym) b
59conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b
60 59
61solveG 60solveG
62 :: (V -> V) -> (V -> V) 61 :: (V -> V) -> (V -> V)
@@ -72,9 +71,9 @@ solveG mat ma meth rawb x0' ϵb ϵx
72 b = mat rawb 71 b = mat rawb
73 x0 = if x0' == 0 then konst 0 (dim b) else x0' 72 x0 = if x0' == 0 then konst 0 (dim b) else x0'
74 r0 = b - a x0 73 r0 = b - a x0
75 r20 = r0 r0 74 r20 = r0 <·> r0
76 p0 = r0 75 p0 = r0
77 nb2 = b b 76 nb2 = b <·> b
78 ok CGState {..} 77 ok CGState {..}
79 = cgr2 <nb2*ϵb**2 78 = cgr2 <nb2*ϵb**2
80 || cgdx < ϵx 79 || cgdx < ϵx
@@ -115,7 +114,7 @@ instance Testable GMatrix
115 sma = convo2 20 3 114 sma = convo2 20 3
116 x1 = vect [1..20] 115 x1 = vect [1..20]
117 x2 = vect [1..40] 116 x2 = vect [1..40]
118 sm = (mkSparse . mkCSR) sma 117 sm = mkSparse sma
119 dm = toDense sma 118 dm = toDense sma
120 119
121 s1 = sm !#> x1 120 s1 = sm !#> x1
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs
index 3c19c93..1b8a7b3 100644
--- a/packages/base/src/Numeric/Sparse.hs
+++ b/packages/base/src/Numeric/Sparse.hs
@@ -3,7 +3,7 @@
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4 4
5module Numeric.Sparse( 5module Numeric.Sparse(
6 GMatrix, CSR(..), mkCSR, 6 GMatrix(..), CSR(..), mkCSR, fromCSR,
7 mkSparse, mkDiagR, mkDense, 7 mkSparse, mkDiagR, mkDense,
8 AssocMatrix, 8 AssocMatrix,
9 toDense, 9 toDense,
@@ -95,9 +95,11 @@ mkDense m = Dense{..}
95 nRows = rows m 95 nRows = rows m
96 nCols = cols m 96 nCols = cols m
97 97
98mkSparse :: AssocMatrix -> GMatrix
99mkSparse = fromCSR . mkCSR
98 100
99mkSparse :: CSR -> GMatrix 101fromCSR :: CSR -> GMatrix
100mkSparse csr = SparseR {..} 102fromCSR csr = SparseR {..}
101 where 103 where
102 gmCSR @ CSR {..} = csr 104 gmCSR @ CSR {..} = csr
103 nRows = csrNRows 105 nRows = csrNRows
@@ -149,11 +151,6 @@ infixr 8 !#>
149(!#>) :: GMatrix -> Vector Double -> Vector Double 151(!#>) :: GMatrix -> Vector Double -> Vector Double
150(!#>) = gmXv 152(!#>) = gmXv
151 153
152
153instance Contraction GMatrix (Vector Double) (Vector Double)
154 where
155 contraction = gmXv
156
157-------------------------------------------------------------------------------- 154--------------------------------------------------------------------------------
158 155
159foreign import ccall unsafe "smXv" 156foreign import ccall unsafe "smXv"