summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Data/Packed/Internal/Numeric.hs7
-rw-r--r--packages/base/src/Data/Packed/Numeric.hs41
-rw-r--r--packages/base/src/Numeric/HMatrix.hs63
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Data.hs11
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Real.hs337
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util.hs20
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/CG.hs86
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Util/Static.hs70
-rw-r--r--packages/base/src/Numeric/Sparse.hs127
9 files changed, 535 insertions, 227 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs
index 3c1c1d0..0205a17 100644
--- a/packages/base/src/Data/Packed/Internal/Numeric.hs
+++ b/packages/base/src/Data/Packed/Internal/Numeric.hs
@@ -3,6 +3,7 @@
3{-# LANGUAGE FlexibleContexts #-} 3{-# LANGUAGE FlexibleContexts #-}
4{-# LANGUAGE FlexibleInstances #-} 4{-# LANGUAGE FlexibleInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-} 5{-# LANGUAGE MultiParamTypeClasses #-}
6{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-} 7{-# LANGUAGE UndecidableInstances #-}
7 8
8----------------------------------------------------------------------------- 9-----------------------------------------------------------------------------
@@ -692,12 +693,12 @@ condV f a b l e t = f a' b' l' e' t'
692 693
693-------------------------------------------------------------------------------- 694--------------------------------------------------------------------------------
694 695
695class Transposable t 696class Transposable m mt | m -> mt, mt -> m
696 where 697 where
697 -- | (conjugate) transpose 698 -- | (conjugate) transpose
698 tr :: t -> t 699 tr :: m -> mt
699 700
700instance (Container Vector t) => Transposable (Matrix t) 701instance (Container Vector t) => Transposable (Matrix t) (Matrix t)
701 where 702 where
702 tr = ctrans 703 tr = ctrans
703 704
diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs
index 01cf6c5..7d88cbc 100644
--- a/packages/base/src/Data/Packed/Numeric.hs
+++ b/packages/base/src/Data/Packed/Numeric.hs
@@ -32,7 +32,7 @@ module Data.Packed.Numeric (
32 diag, ident, 32 diag, ident,
33 ctrans, 33 ctrans,
34 -- * Generic operations 34 -- * Generic operations
35 Container(..), 35 Container(..), Numeric,
36 -- add, mul, sub, divide, equal, scaleRecip, addConstant, 36 -- add, mul, sub, divide, equal, scaleRecip, addConstant,
37 scalar, conj, scale, arctan2, cmap, 37 scalar, conj, scale, arctan2, cmap,
38 atIndex, minIndex, maxIndex, minElement, maxElement, 38 atIndex, minIndex, maxIndex, minElement, maxElement,
@@ -40,7 +40,7 @@ module Data.Packed.Numeric (
40 step, cond, find, assoc, accum, 40 step, cond, find, assoc, accum,
41 Transposable(..), Linear(..), 41 Transposable(..), Linear(..),
42 -- * Matrix product 42 -- * Matrix product
43 Product(..), udot, dot, (◇), 43 Product(..), udot, dot, (◇), (<·>), (#>),
44 Mul(..), 44 Mul(..),
45 Contraction(..),(<.>), 45 Contraction(..),(<.>),
46 optimiseMult, 46 optimiseMult,
@@ -96,7 +96,7 @@ linspace n (a,b) = addConstant a $ scale s $ fromList $ map fromIntegral [0 .. n
96 96
97-------------------------------------------------------- 97--------------------------------------------------------
98 98
99{- | Matrix product, matrix - vector product, and dot product (equivalent to 'contraction') 99{- Matrix product, matrix - vector product, and dot product (equivalent to 'contraction')
100 100
101(This operator can also be written using the unicode symbol ◇ (25c7).) 101(This operator can also be written using the unicode symbol ◇ (25c7).)
102 102
@@ -138,9 +138,8 @@ For complex vectors the first argument is conjugated:
138>>> fromList [1,i,1-i] <.> complex a 138>>> fromList [1,i,1-i] <.> complex a
139fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0] 139fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0]
140-} 140-}
141infixl 7 <.> 141
142(<.>) :: Contraction a b c => a -> b -> c 142
143(<.>) = contraction
144 143
145 144
146class Contraction a b c | a b -> c 145class Contraction a b c | a b -> c
@@ -160,6 +159,23 @@ instance (Container Vector t, Product t) => Contraction (Vector t) (Matrix t) (V
160instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where 159instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where
161 contraction = mXm 160 contraction = mXm
162 161
162--------------------------------------------------------------------------------
163
164infixl 7 <.>
165-- | An infix synonym for 'dot'
166(<.>) :: Numeric t => Vector t -> Vector t -> t
167(<.>) = dot
168
169
170infixr 8 <·>, #>
171-- | dot product
172(<·>) :: Numeric t => Vector t -> Vector t -> t
173(<·>) = dot
174
175
176-- | matrix-vector product
177(#>) :: Numeric t => Matrix t -> Vector t -> Vector t
178(#>) = mXv
163 179
164-------------------------------------------------------------------------------- 180--------------------------------------------------------------------------------
165 181
@@ -286,3 +302,16 @@ meanCov x = (med,cov) where
286 302
287-------------------------------------------------------------------------------- 303--------------------------------------------------------------------------------
288 304
305class ( Container Vector t
306 , Container Matrix t
307 , Konst t Int Vector
308 , Konst t (Int,Int) Matrix
309 , Product t
310 ) => Numeric t
311
312instance Numeric Double
313instance Numeric (Complex Double)
314instance Numeric Float
315instance Numeric (Complex Float)
316
317--------------------------------------------------------------------------------
diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs
index d5c66fb..1c70ef6 100644
--- a/packages/base/src/Numeric/HMatrix.hs
+++ b/packages/base/src/Numeric/HMatrix.hs
@@ -10,16 +10,16 @@ Stability : provisional
10----------------------------------------------------------------------------- 10-----------------------------------------------------------------------------
11module Numeric.HMatrix ( 11module Numeric.HMatrix (
12 12
13 -- * Basic types and data processing 13 -- * Basic types and data processing
14 module Numeric.LinearAlgebra.Data, 14 module Numeric.LinearAlgebra.Data,
15 15
16 -- * Arithmetic and numeric classes 16 -- * Arithmetic and numeric classes
17 -- | 17 -- |
18 -- The standard numeric classes are defined elementwise: 18 -- The standard numeric classes are defined elementwise:
19 -- 19 --
20 -- >>> fromList [1,2,3] * fromList [3,0,-2 :: Double] 20 -- >>> fromList [1,2,3] * fromList [3,0,-2 :: Double]
21 -- fromList [3.0,0.0,-6.0] 21 -- fromList [3.0,0.0,-6.0]
22 -- 22 --
23 -- >>> (3><3) [1..9] * ident 3 :: Matrix Double 23 -- >>> (3><3) [1..9] * ident 3 :: Matrix Double
24 -- (3><3) 24 -- (3><3)
25 -- [ 1.0, 0.0, 0.0 25 -- [ 1.0, 0.0, 0.0
@@ -29,7 +29,7 @@ module Numeric.HMatrix (
29 -- In arithmetic operations single-element vectors and matrices 29 -- In arithmetic operations single-element vectors and matrices
30 -- (created from numeric literals or using 'scalar') automatically 30 -- (created from numeric literals or using 'scalar') automatically
31 -- expand to match the dimensions of the other operand: 31 -- expand to match the dimensions of the other operand:
32 -- 32 --
33 -- >>> 5 + 2*ident 3 :: Matrix Double 33 -- >>> 5 + 2*ident 3 :: Matrix Double
34 -- (3><3) 34 -- (3><3)
35 -- [ 7.0, 5.0, 5.0 35 -- [ 7.0, 5.0, 5.0
@@ -37,13 +37,14 @@ module Numeric.HMatrix (
37 -- , 5.0, 5.0, 7.0 ] 37 -- , 5.0, 5.0, 7.0 ]
38 -- 38 --
39 39
40 -- * Matrix product 40 -- * Products
41 (<.>), 41 -- ** dot
42 42 (<·>),
43 -- | The overloaded multiplication operators may need type annotations to remove 43 -- ** matrix-vector
44 -- ambiguity. In those cases we can also use the specific functions 'mXm', 'mXv', and 'dot'. 44 (#>),(!#>),
45 -- 45 -- ** matrix-matrix
46 -- The matrix x matrix product is also implemented in the "Data.Monoid" instance, where 46 (<>),
47 -- | The matrix x matrix product is also implemented in the "Data.Monoid" instance, where
47 -- single-element matrices (created from numeric literals or using 'scalar') 48 -- single-element matrices (created from numeric literals or using 'scalar')
48 -- are used for scaling. 49 -- are used for scaling.
49 -- 50 --
@@ -55,12 +56,12 @@ module Numeric.HMatrix (
55 -- 56 --
56 -- 'mconcat' uses 'optimiseMult' to get the optimal association order. 57 -- 'mconcat' uses 'optimiseMult' to get the optimal association order.
57 58
58 59
59 -- * Other products 60 -- ** other
60 outer, kronecker, cross, 61 outer, kronecker, cross,
61 scale, 62 scale,
62 sumElements, prodElements, 63 sumElements, prodElements,
63 64
64 -- * Linear Systems 65 -- * Linear Systems
65 (<\>), 66 (<\>),
66 linearSolve, 67 linearSolve,
@@ -70,14 +71,14 @@ module Numeric.HMatrix (
70 cholSolve, 71 cholSolve,
71 cgSolve, 72 cgSolve,
72 cgSolve', 73 cgSolve',
73 74
74 -- * Inverse and pseudoinverse 75 -- * Inverse and pseudoinverse
75 inv, pinv, pinvTol, 76 inv, pinv, pinvTol,
76 77
77 -- * Determinant and rank 78 -- * Determinant and rank
78 rcond, rank, ranksv, 79 rcond, rank, ranksv,
79 det, invlndet, 80 det, invlndet,
80 81
81 -- * Singular value decomposition 82 -- * Singular value decomposition
82 svd, 83 svd,
83 fullSVD, 84 fullSVD,
@@ -85,7 +86,7 @@ module Numeric.HMatrix (
85 compactSVD, 86 compactSVD,
86 singularValues, 87 singularValues,
87 leftSV, rightSV, 88 leftSV, rightSV,
88 89
89 -- * Eigensystems 90 -- * Eigensystems
90 eig, eigSH, eigSH', 91 eig, eigSH, eigSH',
91 eigenvalues, eigenvaluesSH, eigenvaluesSH', 92 eigenvalues, eigenvaluesSH, eigenvaluesSH',
@@ -105,7 +106,7 @@ module Numeric.HMatrix (
105 106
106 -- * LU 107 -- * LU
107 lu, luPacked, 108 lu, luPacked,
108 109
109 -- * Matrix functions 110 -- * Matrix functions
110 expm, 111 expm,
111 sqrtm, 112 sqrtm,
@@ -116,7 +117,7 @@ module Numeric.HMatrix (
116 nullVector, 117 nullVector,
117 nullspaceSVD, 118 nullspaceSVD,
118 null1, null1sym, 119 null1, null1sym,
119 120
120 orth, 121 orth,
121 122
122 -- * Norms 123 -- * Norms
@@ -129,30 +130,36 @@ module Numeric.HMatrix (
129 130
130 -- * Random arrays 131 -- * Random arrays
131 132
132 RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, 133 Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample,
133 134
134 -- * Misc 135 -- * Misc
135 meanCov, peps, relativeError, haussholder, optimiseMult, dot, udot, mXm, mXv, smXv, (<>), (◇), Seed, checkT, 136 meanCov, peps, relativeError, haussholder, optimiseMult, udot,
136 -- * Auxiliary classes 137 -- * Auxiliary classes
137 Element, Container, Product, Numeric, Contraction, LSDiv, 138 Element, Container, Product, Contraction(..), Numeric, LSDiv,
138 Complexable, RealElement, 139 Complexable, RealElement,
139 RealOf, ComplexOf, SingleOf, DoubleOf, 140 RealOf, ComplexOf, SingleOf, DoubleOf,
140 IndexOf, 141 IndexOf,
141 Field, 142 Field,
142 Normed, 143 Normed,
143 CGMat, Transposable, 144 Transposable,
144 ℕ,ℤ,ℝ,ℂ,ℝn,ℂn, 𝑖, i_C --ℍ 145 CGState(..),
146 Testable(..),
147 ℕ,ℤ,ℝ,ℂ, 𝑖, i_C --ℍ
145) where 148) where
146 149
147import Numeric.LinearAlgebra.Data 150import Numeric.LinearAlgebra.Data
148 151
149import Numeric.Matrix() 152import Numeric.Matrix()
150import Numeric.Vector() 153import Numeric.Vector()
151import Data.Packed.Numeric 154import Data.Packed.Numeric hiding ((<>))
152import Numeric.LinearAlgebra.Algorithms 155import Numeric.LinearAlgebra.Algorithms
153import Numeric.LinearAlgebra.Util 156import Numeric.LinearAlgebra.Util
154import Numeric.LinearAlgebra.Random 157import Numeric.LinearAlgebra.Random
155import Numeric.Sparse(smXv) 158import Numeric.Sparse((!#>))
156import Numeric.LinearAlgebra.Util.CG 159import Numeric.LinearAlgebra.Util.CG
157 160
161-- | matrix product
162(<>) :: Numeric t => Matrix t -> Matrix t -> Matrix t
163(<>) = mXm
164infixr 8 <>
158 165
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs
index 3128a24..3417a5e 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Data.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs
@@ -49,13 +49,9 @@ module Numeric.LinearAlgebra.Data(
49 find, maxIndex, minIndex, maxElement, minElement, atIndex, 49 find, maxIndex, minIndex, maxElement, minElement, atIndex,
50 50
51 -- * Sparse 51 -- * Sparse
52 SMatrix, AssocMatrix, mkCSR, toDense, 52 GMatrix, AssocMatrix, mkSparse, toDense,
53 mkDiag, 53 mkDiagR, dense,
54
55 -- * Static dimensions
56 54
57 Static, ddata, R, vect0, sScalar, vect2, vect3, (&),
58
59 -- * IO 55 -- * IO
60 disp, 56 disp,
61 loadMatrix, saveMatrix, 57 loadMatrix, saveMatrix,
@@ -79,9 +75,8 @@ module Numeric.LinearAlgebra.Data(
79import Data.Packed.Vector 75import Data.Packed.Vector
80import Data.Packed.Matrix 76import Data.Packed.Matrix
81import Data.Packed.Numeric 77import Data.Packed.Numeric
82import Numeric.LinearAlgebra.Util hiding ((&)) 78import Numeric.LinearAlgebra.Util hiding ((&),(#))
83import Data.Complex 79import Data.Complex
84import Numeric.Sparse 80import Numeric.Sparse
85import Numeric.LinearAlgebra.Util.Static
86 81
87 82
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs
new file mode 100644
index 0000000..db15705
--- /dev/null
+++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs
@@ -0,0 +1,337 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE FlexibleContexts #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE EmptyDataDecls #-}
9{-# LANGUAGE Rank2Types #-}
10{-# LANGUAGE FlexibleInstances #-}
11{-# LANGUAGE TypeOperators #-}
12{-# LANGUAGE ViewPatterns #-}
13{-# LANGUAGE GADTs #-}
14
15
16{- |
17Module : Numeric.LinearAlgebra.Real
18Copyright : (c) Alberto Ruiz 2006-14
19License : BSD3
20Stability : provisional
21
22Experimental interface for real arrays with statically checked dimensions.
23
24-}
25
26module Numeric.LinearAlgebra.Real(
27 -- * Vector
28 R,
29 vec2, vec3, vec4, 𝕧, (&),
30 -- * Matrix
31 L, Sq,
32 𝕞,
33 (#),(¦),(——),
34 Konst(..),
35 eye,
36 diagR, diag,
37 blockAt,
38 -- * Products
39 (<>),(#>),(<·>),
40 -- * Pretty printing
41 Disp(..),
42 -- * Misc
43 Dim, unDim,
44 module Numeric.HMatrix
45) where
46
47
48import GHC.TypeLits
49import Numeric.HMatrix hiding ((<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——))
50import qualified Numeric.HMatrix as LA
51import Data.Packed.ST
52
53newtype Dim (n :: Nat) t = Dim t
54 deriving Show
55
56unDim :: Dim n t -> t
57unDim (Dim x) = x
58
59data Proxy :: Nat -> *
60
61
62lift1F
63 :: (c t -> c t)
64 -> Dim n (c t) -> Dim n (c t)
65lift1F f (Dim v) = Dim (f v)
66
67lift2F
68 :: (c t -> c t -> c t)
69 -> Dim n (c t) -> Dim n (c t) -> Dim n (c t)
70lift2F f (Dim u) (Dim v) = Dim (f u v)
71
72
73
74type R n = Dim n (Vector ℝ)
75
76type L m n = Dim m (Dim n (Matrix ℝ))
77
78
79infixl 4 &
80(&) :: forall n . KnownNat n
81 => R n -> ℝ -> R (n+1)
82Dim v & x = Dim (vjoin [v', scalar x])
83 where
84 d = fromIntegral . natVal $ (undefined :: Proxy n)
85 v' | d > 1 && size v == 1 = LA.konst (v!0) d
86 | otherwise = v
87
88
89-- vect0 :: R 0
90-- vect0 = Dim (fromList[])
91
92𝕧 :: ℝ -> R 1
93𝕧 = Dim . scalar
94
95
96vec2 :: ℝ -> ℝ -> R 2
97vec2 a b = Dim $ runSTVector $ do
98 v <- newUndefinedVector 2
99 writeVector v 0 a
100 writeVector v 1 b
101 return v
102
103vec3 :: ℝ -> ℝ -> ℝ -> R 3
104vec3 a b c = Dim $ runSTVector $ do
105 v <- newUndefinedVector 3
106 writeVector v 0 a
107 writeVector v 1 b
108 writeVector v 2 c
109 return v
110
111
112vec4 :: ℝ -> ℝ -> ℝ -> ℝ -> R 4
113vec4 a b c d = Dim $ runSTVector $ do
114 v <- newUndefinedVector 4
115 writeVector v 0 a
116 writeVector v 1 b
117 writeVector v 2 c
118 writeVector v 3 d
119 return v
120
121
122
123
124instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t))
125 where
126 (+) = lift2F (+)
127 (*) = lift2F (*)
128 (-) = lift2F (-)
129 abs = lift1F abs
130 signum = lift1F signum
131 negate = lift1F negate
132 fromInteger x = Dim (fromInteger x)
133
134instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
135 where
136 (+) = (lift2F . lift2F) (+)
137 (*) = (lift2F . lift2F) (*)
138 (-) = (lift2F . lift2F) (-)
139 abs = (lift1F . lift1F) abs
140 signum = (lift1F . lift1F) signum
141 negate = (lift1F . lift1F) negate
142 fromInteger x = Dim (Dim (fromInteger x))
143
144--------------------------------------------------------------------------------
145
146class Konst t
147 where
148 konst :: ℝ -> t
149
150instance forall n. KnownNat n => Konst (R n)
151 where
152 konst x = Dim (LA.konst x d)
153 where
154 d = fromIntegral . natVal $ (undefined :: Proxy n)
155
156instance forall m n . (KnownNat m, KnownNat n) => Konst (L m n)
157 where
158 konst x = Dim (Dim (LA.konst x (m',n')))
159 where
160 m' = fromIntegral . natVal $ (undefined :: Proxy m)
161 n' = fromIntegral . natVal $ (undefined :: Proxy n)
162
163--------------------------------------------------------------------------------
164
165diagR :: forall m n k . (KnownNat m, KnownNat n) => ℝ -> R k -> L m n
166diagR x v = Dim (Dim (diagRect x (unDim v) m' n'))
167 where
168 m' = fromIntegral . natVal $ (undefined :: Proxy m)
169 n' = fromIntegral . natVal $ (undefined :: Proxy n)
170
171diag :: KnownNat n => R n -> Sq n
172diag = diagR 0
173
174--------------------------------------------------------------------------------
175
176blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n
177blockAt x r c a = Dim (Dim res)
178 where
179 z = scalar x
180 z1 = LA.konst x (r,c)
181 z2 = LA.konst x (max 0 (m'-(ra+r)), max 0 (n'-(ca+c)))
182 ra = min (rows a) . max 0 $ m'-r
183 ca = min (cols a) . max 0 $ n'-c
184 sa = subMatrix (0,0) (ra, ca) a
185 m' = fromIntegral . natVal $ (undefined :: Proxy m)
186 n' = fromIntegral . natVal $ (undefined :: Proxy n)
187 res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]]
188
189{-
190matrix :: (KnownNat m, KnownNat n) => Matrix Double -> L n m
191matrix = blockAt 0 0 0
192-}
193
194--------------------------------------------------------------------------------
195
196class Disp t
197 where
198 disp :: Int -> t -> IO ()
199
200instance Disp (L n m)
201 where
202 disp n (d2 -> a) = do
203 if rows a == 1 && cols a == 1
204 then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a)
205 else putStr "Dim " >> LA.disp n a
206
207instance Disp (R n)
208 where
209 disp n (unDim -> v) = do
210 let su = LA.dispf n (asRow v)
211 if LA.size v == 1
212 then putStrLn $ "Const " ++ (last . words $ su )
213 else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su)
214
215--------------------------------------------------------------------------------
216
217infixl 3 #
218(#) :: L r c -> R c -> L (r+1) c
219Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v))
220
221
222𝕞 :: forall n . KnownNat n => L 0 n
223𝕞 = Dim (Dim (LA.konst 0 (0,d)))
224 where
225 d = fromIntegral . natVal $ (undefined :: Proxy n)
226
227infixl 3 ¦
228(¦) :: L r c1 -> L r c2 -> L r (c1+c2)
229Dim (Dim a) ¦ Dim (Dim b) = Dim (Dim (a LA.¦ b))
230
231infixl 2 ——
232(——) :: L r1 c -> L r2 c -> L (r1+r2) c
233Dim (Dim a) —— Dim (Dim b) = Dim (Dim (a LA.—— b))
234
235
236{-
237
238-}
239
240type Sq n = L n n
241
242type GL = (KnownNat n, KnownNat m) => L m n
243type GSq = KnownNat n => Sq n
244
245infixr 8 <>
246(<>) :: L m k -> L k n -> L m n
247(d2 -> a) <> (d2 -> b) = Dim (Dim (a LA.<> b))
248
249infixr 8 #>
250(#>) :: L m n -> R n -> R m
251(d2 -> m) #> (unDim -> v) = Dim (m LA.#> v)
252
253infixr 8 <·>
254(<·>) :: R n -> R n -> ℝ
255(unDim -> u) <·> (unDim -> v) = udot u v
256
257
258d2 :: forall c (n :: Nat) (n1 :: Nat). Dim n1 (Dim n c) -> c
259d2 = unDim . unDim
260
261
262instance Transposable (L m n) (L n m)
263 where
264 tr (Dim (Dim a)) = Dim (Dim (tr a))
265
266
267eye :: forall n . KnownNat n => Sq n
268eye = Dim (Dim (ident d))
269 where
270 d = fromIntegral . natVal $ (undefined :: Proxy n)
271
272
273--------------------------------------------------------------------------------
274
275test :: (Bool, IO ())
276test = (ok,info)
277 where
278 ok = d2 (eye :: Sq 5) == ident 5
279 && d2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..]
280 && d2 (tm :: L 3 5) == mat 5 [1..15]
281 && thingS == thingD
282 && precS == precD
283
284 info = do
285 print $ u
286 print $ v
287 print (eye :: Sq 3)
288 print $ ((u & 5) + 1) <·> v
289 print (tm :: L 2 5)
290 print (tm <> sm :: L 2 3)
291 print thingS
292 print thingD
293 print precS
294 print precD
295
296 u = vec2 3 5
297
298 v = 𝕧 2 & 4 & 7
299
300 mTm :: L n m -> Sq m
301 mTm a = tr a <> a
302
303 tm :: GL
304 tm = lmat 0 [1..]
305
306 lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n
307 lmat z xs = Dim . Dim . reshape n' . fromList . take (m'*n') $ xs ++ repeat z
308 where
309 m' = fromIntegral . natVal $ (undefined :: Proxy m)
310 n' = fromIntegral . natVal $ (undefined :: Proxy n)
311
312 sm :: GSq
313 sm = lmat 0 [1..]
314
315 thingS = (u & 1) <·> tr q #> q #> v
316 where
317 q = tm :: L 10 3
318
319 thingD = vjoin [unDim u, 1] LA.<·> tr m LA.#> m LA.#> unDim v
320 where
321 m = mat 3 [1..30]
322
323 precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v
324 precD = 1 + 2 * vjoin[unDim u, 6] LA.<·> LA.konst 2 (size (unDim u) +1, size (unDim v)) LA.#> unDim v
325
326
327instance (KnownNat n', KnownNat m') => Testable (L n' m')
328 where
329 checkT _ = test
330
331{-
332do (snd test)
333fst test
334-}
335
336
337
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs
index a319785..47b1090 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util.hs
@@ -32,7 +32,7 @@ module Numeric.LinearAlgebra.Util(
32 rand, randn, 32 rand, randn,
33 cross, 33 cross,
34 norm, 34 norm,
35 ℕ,ℤ,ℝ,ℂ,ℝn,ℂn,𝑖,i_C, --ℍ 35 ℕ,ℤ,ℝ,ℂ,𝑖,i_C, --ℍ
36 norm_1, norm_2, norm_0, norm_Inf, norm_Frob, norm_nuclear, 36 norm_1, norm_2, norm_0, norm_Inf, norm_Frob, norm_nuclear,
37 mnorm_1, mnorm_2, mnorm_0, mnorm_Inf, 37 mnorm_1, mnorm_2, mnorm_0, mnorm_Inf,
38 unitary, 38 unitary,
@@ -70,8 +70,8 @@ type ℝ = Double
70type ℕ = Int 70type ℕ = Int
71type ℤ = Int 71type ℤ = Int
72type ℂ = Complex Double 72type ℂ = Complex Double
73type ℝn = Vector ℝ 73--type ℝn = Vector ℝ
74type ℂn = Vector ℂ 74--type ℂn = Vector ℂ
75--newtype ℍ m = H m 75--newtype ℍ m = H m
76 76
77i_C, 𝑖 :: ℂ 77i_C, 𝑖 :: ℂ
@@ -84,7 +84,7 @@ i_C = 𝑖
84fromList [1.0,2.0,3.0,4.0,5.0] 84fromList [1.0,2.0,3.0,4.0,5.0]
85 85
86-} 86-}
87vect :: [ℝ] -> ℝn 87vect :: [ℝ] -> Vector
88vect = fromList 88vect = fromList
89 89
90{- | create a real matrix 90{- | create a real matrix
@@ -103,18 +103,6 @@ mat
103mat c = reshape c . fromList 103mat c = reshape c . fromList
104 104
105 105
106
107class ( Container Vector t
108 , Container Matrix t
109 , Konst t Int Vector
110 , Konst t (Int,Int) Matrix
111 ) => Numeric t
112
113instance Numeric Double
114instance Numeric (Complex Double)
115
116
117
118{- | print a real matrix with given number of digits after the decimal point 106{- | print a real matrix with given number of digits after the decimal point
119 107
120>>> disp 5 $ ident 2 / 3 108>>> disp 5 $ ident 2 / 3
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
index 5e2ea84..50372f1 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs
@@ -3,11 +3,14 @@
3 3
4module Numeric.LinearAlgebra.Util.CG( 4module Numeric.LinearAlgebra.Util.CG(
5 cgSolve, cgSolve', 5 cgSolve, cgSolve',
6 CGMat, CGState(..), R, V 6 CGState(..), R, V
7) where 7) where
8 8
9import Data.Packed.Numeric 9import Data.Packed.Numeric
10import Numeric.Sparse
10import Numeric.Vector() 11import Numeric.Vector()
12import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..))
13import Control.Arrow((***))
11 14
12{- 15{-
13import Util.Misc(debug, debugMat) 16import Util.Misc(debug, debugMat)
@@ -51,7 +54,7 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
51 rdx = norm2 dx / max 1 (norm2 x) 54 rdx = norm2 dx / max 1 (norm2 x)
52 55
53conjugrad 56conjugrad
54 :: (Transposable m, Contraction m V V) 57 :: (Transposable m mt, Contraction m V V, Contraction mt V V)
55 => Bool -> m -> V -> V -> R -> R -> [CGState] 58 => Bool -> m -> V -> V -> R -> R -> [CGState]
56conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b 59conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b
57 60
@@ -82,27 +85,88 @@ takeUntil q xs = a++ take 1 b
82 where 85 where
83 (a,b) = break q xs 86 (a,b) = break q xs
84 87
85class (Transposable m, Contraction m V V) => CGMat m
86
87cgSolve 88cgSolve
88 :: CGMat m 89 :: Bool -- ^ is symmetric
89 => Bool -- ^ is symmetric 90 -> GMatrix -- ^ coefficient matrix
90 -> m -- ^ coefficient matrix
91 -> Vector Double -- ^ right-hand side 91 -> Vector Double -- ^ right-hand side
92 -> Vector Double -- ^ solution 92 -> Vector Double -- ^ solution
93cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 93cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0
94 where 94 where
95 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) 95 n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double))
96 96
97cgSolve' 97cgSolve'
98 :: CGMat m 98 :: Bool -- ^ symmetric
99 => Bool -- ^ symmetric
100 -> R -- ^ relative tolerance for the residual (e.g. 1E-4) 99 -> R -- ^ relative tolerance for the residual (e.g. 1E-4)
101 -> R -- ^ relative tolerance for δx (e.g. 1E-3) 100 -> R -- ^ relative tolerance for δx (e.g. 1E-3)
102 -> Int -- ^ maximum number of iterations 101 -> Int -- ^ maximum number of iterations
103 -> m -- ^ coefficient matrix 102 -> GMatrix -- ^ coefficient matrix
104 -> V -- ^ initial solution 103 -> V -- ^ initial solution
105 -> V -- ^ right-hand side 104 -> V -- ^ right-hand side
106 -> [CGState] -- ^ solution 105 -> [CGState] -- ^ solution
107cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es 106cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es
108 107
108
109--------------------------------------------------------------------------------
110
111instance Testable GMatrix
112 where
113 checkT _ = (ok,info)
114 where
115 sma = convo2 20 3
116 x1 = vect [1..20]
117 x2 = vect [1..40]
118 sm = mkSparse sma
119 dm = toDense sma
120
121 s1 = sm !#> x1
122 d1 = dm #> x1
123
124 s2 = tr sm !#> x2
125 d2 = tr dm #> x2
126
127 sdia = mkDiagR 40 20 (vect [1..10])
128 s3 = sdia !#> x1
129 s4 = tr sdia !#> x2
130 ddia = diagRect 0 (vect [1..10]) 40 20
131 d3 = ddia #> x1
132 d4 = tr ddia #> x2
133
134 v = testb 40
135 s5 = cgSolve False sm v
136 d5 = denseSolve dm v
137
138 info = do
139 print sm
140 disp (toDense sma)
141 print s1; print d1
142 print s2; print d2
143 print s3; print d3
144 print s4; print d4
145 print s5; print d5
146 print $ relativeError Infinity s5 d5
147
148 ok = s1==d1
149 && s2==d2
150 && s3==d3
151 && s4==d4
152 && relativeError Infinity s5 d5 < 1E-10
153
154 disp = putStr . dispf 2
155
156 vect = fromList :: [Double] -> Vector Double
157
158 convomat :: Int -> Int -> AssocMatrix
159 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
160
161 convo2 :: Int -> Int -> AssocMatrix
162 convo2 n k = m1 ++ m2
163 where
164 m1 = convomat n k
165 m2 = map (((+n) *** id) *** id) m1
166
167 testb n = vect $ take n $ cycle ([0..10]++[9,8..1])
168
169 denseSolve a = flatten . linearSolveLS a . asColumn
170
171 -- mkDiag v = mkDiagR (dim v) (dim v) v
172
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs
deleted file mode 100644
index a3f8eb0..0000000
--- a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs
+++ /dev/null
@@ -1,70 +0,0 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FlexibleContexts #-}
6{-# LANGUAGE ScopedTypeVariables #-}
7{-# LANGUAGE EmptyDataDecls #-}
8{-# LANGUAGE Rank2Types #-}
9{-# LANGUAGE FlexibleInstances #-}
10{-# LANGUAGE TypeOperators #-}
11
12module Numeric.LinearAlgebra.Util.Static(
13 Static (ddata),
14 R,
15 vect0, sScalar, vect2, vect3, (&)
16) where
17
18
19import GHC.TypeLits
20import Data.Packed.Numeric
21import Numeric.Vector()
22import Numeric.LinearAlgebra.Util(Numeric,ℝ)
23
24lift1F :: (Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t)
25lift1F f (Static v) = Static (f v)
26
27lift2F :: (Vector t -> Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t) -> Static n (Vector t)
28lift2F f (Static u) (Static v) = Static (f u v)
29
30newtype Static (n :: Nat) t = Static { ddata :: t } deriving Show
31
32type R n = Static n (Vector ℝ)
33
34
35infixl 4 &
36(&) :: R n -> ℝ -> R (n+1)
37Static v & x = Static (vjoin [v, scalar x])
38
39vect0 :: R 0
40vect0 = Static (fromList[])
41
42sScalar :: ℝ -> R 1
43sScalar = Static . scalar
44
45
46vect2 :: ℝ -> ℝ -> R 2
47vect2 x1 x2 = Static (fromList [x1,x2])
48
49vect3 :: ℝ -> ℝ -> ℝ -> R 3
50vect3 x1 x2 x3 = Static (fromList [x1,x2,x3])
51
52
53
54
55
56
57instance forall n t . (KnownNat n, Num (Vector t), Numeric t )=> Num (Static n (Vector t))
58 where
59 (+) = lift2F add
60 (*) = lift2F mul
61 (-) = lift2F sub
62 abs = lift1F abs
63 signum = lift1F signum
64 negate = lift1F (scale (-1))
65 fromInteger x = Static (konst (fromInteger x) d)
66 where
67 d = fromIntegral . natVal $ (undefined :: Proxy n)
68
69data Proxy :: Nat -> *
70
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs
index 2df4578..4d05bdc 100644
--- a/packages/base/src/Numeric/Sparse.hs
+++ b/packages/base/src/Numeric/Sparse.hs
@@ -3,11 +3,11 @@
3{-# LANGUAGE FlexibleInstances #-} 3{-# LANGUAGE FlexibleInstances #-}
4 4
5module Numeric.Sparse( 5module Numeric.Sparse(
6 SMatrix(..), 6 GMatrix(..),
7 mkCSR, mkDiag, 7 mkSparse, mkDiagR, dense,
8 AssocMatrix, 8 AssocMatrix,
9 toDense, 9 toDense,
10 smXv 10 gmXv, (!#>)
11)where 11)where
12 12
13import Data.Packed.Numeric 13import Data.Packed.Numeric
@@ -17,8 +17,7 @@ import Control.Arrow((***))
17import Control.Monad(when) 17import Control.Monad(when)
18import Data.List(groupBy, sort) 18import Data.List(groupBy, sort)
19import Foreign.C.Types(CInt(..)) 19import Foreign.C.Types(CInt(..))
20import Numeric.LinearAlgebra.Util.CG(CGMat,cgSolve) 20
21import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..))
22import Data.Packed.Development 21import Data.Packed.Development
23import System.IO.Unsafe(unsafePerformIO) 22import System.IO.Unsafe(unsafePerformIO)
24import Foreign(Ptr) 23import Foreign(Ptr)
@@ -29,7 +28,7 @@ c ~!~ msg = when c (error msg)
29 28
30type AssocMatrix = [((Int,Int),Double)] 29type AssocMatrix = [((Int,Int),Double)]
31 30
32data SMatrix 31data GMatrix
33 = CSR 32 = CSR
34 { csrVals :: Vector Double 33 { csrVals :: Vector Double
35 , csrCols :: Vector CInt 34 , csrCols :: Vector CInt
@@ -46,14 +45,26 @@ data SMatrix
46 } 45 }
47 | Diag 46 | Diag
48 { diagVals :: Vector Double 47 { diagVals :: Vector Double
48 , nRows :: Int
49 , nCols :: Int
50 }
51 | Dense
52 { gmDense :: Matrix Double
49 , nRows :: Int 53 , nRows :: Int
50 , nCols :: Int 54 , nCols :: Int
51 } 55 }
52-- | Banded 56-- | Banded
53 deriving Show 57 deriving Show
54 58
55mkCSR :: AssocMatrix -> SMatrix 59dense :: Matrix Double -> GMatrix
56mkCSR sm' = CSR{..} 60dense m = Dense{..}
61 where
62 gmDense = m
63 nRows = rows m
64 nCols = cols m
65
66mkSparse :: AssocMatrix -> GMatrix
67mkSparse sm' = CSR{..}
57 where 68 where
58 sm = sort sm' 69 sm = sort sm'
59 rws = map ((fromList *** fromList) 70 rws = map ((fromList *** fromList)
@@ -78,37 +89,47 @@ mkDiagR r c v
78 nCols = c 89 nCols = c
79 diagVals = v 90 diagVals = v
80 91
81mkDiag v = mkDiagR (dim v) (dim v) v
82
83 92
84type IV t = CInt -> Ptr CInt -> t 93type IV t = CInt -> Ptr CInt -> t
85type V t = CInt -> Ptr Double -> t 94type V t = CInt -> Ptr Double -> t
86type SMxV = V (IV (IV (V (V (IO CInt))))) 95type SMxV = V (IV (IV (V (V (IO CInt)))))
87 96
88smXv :: SMatrix -> Vector Double -> Vector Double 97gmXv :: GMatrix -> Vector Double -> Vector Double
89smXv CSR{..} v = unsafePerformIO $ do 98gmXv CSR{..} v = unsafePerformIO $ do
90 dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) 99 dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
91 r <- createVector nRows 100 r <- createVector nRows
92 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" 101 app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv"
93 return r 102 return r
94 103
95smXv CSC{..} v = unsafePerformIO $ do 104gmXv CSC{..} v = unsafePerformIO $ do
96 dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) 105 dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v)
97 r <- createVector nRows 106 r <- createVector nRows
98 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" 107 app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv"
99 return r 108 return r
100 109
101smXv Diag{..} v 110gmXv Diag{..} v
102 | dim v == nCols 111 | dim v == nCols
103 = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals 112 = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals
104 , konst 0 (nRows - dim diagVals) ] 113 , konst 0 (nRows - dim diagVals) ]
105 | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" 114 | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
106 nRows nCols (dim diagVals) (dim v) 115 nRows nCols (dim diagVals) (dim v)
107 116
117gmXv Dense{..} v
118 | dim v == nCols
119 = mXv gmDense v
120 | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d"
121 nRows nCols (dim v)
122
108 123
109instance Contraction SMatrix (Vector Double) (Vector Double) 124-- | general matrix - vector product
125infixr 8 !#>
126(!#>) :: GMatrix -> Vector Double -> Vector Double
127(!#>) = gmXv
128
129
130instance Contraction GMatrix (Vector Double) (Vector Double)
110 where 131 where
111 contraction = smXv 132 contraction = gmXv
112 133
113-------------------------------------------------------------------------------- 134--------------------------------------------------------------------------------
114 135
@@ -127,75 +148,11 @@ toDense asm = assoc (r+1,c+1) 0 asm
127 148
128 149
129 150
130instance Transposable SMatrix 151instance Transposable GMatrix GMatrix
131 where 152 where
132 tr (CSR vs cs rs n m) = CSC vs cs rs m n 153 tr (CSR vs cs rs n m) = CSC vs cs rs m n
133 tr (CSC vs rs cs n m) = CSR vs rs cs m n 154 tr (CSC vs rs cs n m) = CSR vs rs cs m n
134 tr (Diag v n m) = Diag v m n 155 tr (Diag v n m) = Diag v m n
156 tr (Dense a n m) = Dense (tr a) m n
135 157
136 158
137instance CGMat SMatrix
138instance CGMat (Matrix Double)
139
140--------------------------------------------------------------------------------
141
142instance Testable SMatrix
143 where
144 checkT _ = (ok,info)
145 where
146 sma = convo2 20 3
147 x1 = vect [1..20]
148 x2 = vect [1..40]
149 sm = mkCSR sma
150 dm = toDense sma
151
152 s1 = sm ◇ x1
153 d1 = dm ◇ x1
154
155 s2 = tr sm ◇ x2
156 d2 = tr dm ◇ x2
157
158 sdia = mkDiagR 40 20 (vect [1..10])
159 s3 = sdia ◇ x1
160 s4 = tr sdia ◇ x2
161 ddia = diagRect 0 (vect [1..10]) 40 20
162 d3 = ddia ◇ x1
163 d4 = tr ddia ◇ x2
164
165 v = testb 40
166 s5 = cgSolve False sm v
167 d5 = denseSolve dm v
168
169 info = do
170 print sm
171 disp (toDense sma)
172 print s1; print d1
173 print s2; print d2
174 print s3; print d3
175 print s4; print d4
176 print s5; print d5
177 print $ relativeError Infinity s5 d5
178
179 ok = s1==d1
180 && s2==d2
181 && s3==d3
182 && s4==d4
183 && relativeError Infinity s5 d5 < 1E-10
184
185 disp = putStr . dispf 2
186
187 vect = fromList :: [Double] -> Vector Double
188
189 convomat :: Int -> Int -> AssocMatrix
190 convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]
191
192 convo2 :: Int -> Int -> AssocMatrix
193 convo2 n k = m1 ++ m2
194 where
195 m1 = convomat n k
196 m2 = map (((+n) *** id) *** id) m1
197
198 testb n = vect $ take n $ cycle ([0..10]++[9,8..1])
199
200 denseSolve a = flatten . linearSolveLS a . asColumn
201