diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Data/Packed/Internal/Numeric.hs | 7 | ||||
-rw-r--r-- | packages/base/src/Data/Packed/Numeric.hs | 41 | ||||
-rw-r--r-- | packages/base/src/Numeric/HMatrix.hs | 63 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Data.hs | 11 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Real.hs | 337 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util.hs | 20 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | 86 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Util/Static.hs | 70 | ||||
-rw-r--r-- | packages/base/src/Numeric/Sparse.hs | 127 |
9 files changed, 535 insertions, 227 deletions
diff --git a/packages/base/src/Data/Packed/Internal/Numeric.hs b/packages/base/src/Data/Packed/Internal/Numeric.hs index 3c1c1d0..0205a17 100644 --- a/packages/base/src/Data/Packed/Internal/Numeric.hs +++ b/packages/base/src/Data/Packed/Internal/Numeric.hs | |||
@@ -3,6 +3,7 @@ | |||
3 | {-# LANGUAGE FlexibleContexts #-} | 3 | {-# LANGUAGE FlexibleContexts #-} |
4 | {-# LANGUAGE FlexibleInstances #-} | 4 | {-# LANGUAGE FlexibleInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE UndecidableInstances #-} | 7 | {-# LANGUAGE UndecidableInstances #-} |
7 | 8 | ||
8 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
@@ -692,12 +693,12 @@ condV f a b l e t = f a' b' l' e' t' | |||
692 | 693 | ||
693 | -------------------------------------------------------------------------------- | 694 | -------------------------------------------------------------------------------- |
694 | 695 | ||
695 | class Transposable t | 696 | class Transposable m mt | m -> mt, mt -> m |
696 | where | 697 | where |
697 | -- | (conjugate) transpose | 698 | -- | (conjugate) transpose |
698 | tr :: t -> t | 699 | tr :: m -> mt |
699 | 700 | ||
700 | instance (Container Vector t) => Transposable (Matrix t) | 701 | instance (Container Vector t) => Transposable (Matrix t) (Matrix t) |
701 | where | 702 | where |
702 | tr = ctrans | 703 | tr = ctrans |
703 | 704 | ||
diff --git a/packages/base/src/Data/Packed/Numeric.hs b/packages/base/src/Data/Packed/Numeric.hs index 01cf6c5..7d88cbc 100644 --- a/packages/base/src/Data/Packed/Numeric.hs +++ b/packages/base/src/Data/Packed/Numeric.hs | |||
@@ -32,7 +32,7 @@ module Data.Packed.Numeric ( | |||
32 | diag, ident, | 32 | diag, ident, |
33 | ctrans, | 33 | ctrans, |
34 | -- * Generic operations | 34 | -- * Generic operations |
35 | Container(..), | 35 | Container(..), Numeric, |
36 | -- add, mul, sub, divide, equal, scaleRecip, addConstant, | 36 | -- add, mul, sub, divide, equal, scaleRecip, addConstant, |
37 | scalar, conj, scale, arctan2, cmap, | 37 | scalar, conj, scale, arctan2, cmap, |
38 | atIndex, minIndex, maxIndex, minElement, maxElement, | 38 | atIndex, minIndex, maxIndex, minElement, maxElement, |
@@ -40,7 +40,7 @@ module Data.Packed.Numeric ( | |||
40 | step, cond, find, assoc, accum, | 40 | step, cond, find, assoc, accum, |
41 | Transposable(..), Linear(..), | 41 | Transposable(..), Linear(..), |
42 | -- * Matrix product | 42 | -- * Matrix product |
43 | Product(..), udot, dot, (◇), | 43 | Product(..), udot, dot, (◇), (<·>), (#>), |
44 | Mul(..), | 44 | Mul(..), |
45 | Contraction(..),(<.>), | 45 | Contraction(..),(<.>), |
46 | optimiseMult, | 46 | optimiseMult, |
@@ -96,7 +96,7 @@ linspace n (a,b) = addConstant a $ scale s $ fromList $ map fromIntegral [0 .. n | |||
96 | 96 | ||
97 | -------------------------------------------------------- | 97 | -------------------------------------------------------- |
98 | 98 | ||
99 | {- | Matrix product, matrix - vector product, and dot product (equivalent to 'contraction') | 99 | {- Matrix product, matrix - vector product, and dot product (equivalent to 'contraction') |
100 | 100 | ||
101 | (This operator can also be written using the unicode symbol ◇ (25c7).) | 101 | (This operator can also be written using the unicode symbol ◇ (25c7).) |
102 | 102 | ||
@@ -138,9 +138,8 @@ For complex vectors the first argument is conjugated: | |||
138 | >>> fromList [1,i,1-i] <.> complex a | 138 | >>> fromList [1,i,1-i] <.> complex a |
139 | fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0] | 139 | fromList [10.0 :+ 4.0,12.0 :+ 4.0,14.0 :+ 4.0,16.0 :+ 4.0] |
140 | -} | 140 | -} |
141 | infixl 7 <.> | 141 | |
142 | (<.>) :: Contraction a b c => a -> b -> c | 142 | |
143 | (<.>) = contraction | ||
144 | 143 | ||
145 | 144 | ||
146 | class Contraction a b c | a b -> c | 145 | class Contraction a b c | a b -> c |
@@ -160,6 +159,23 @@ instance (Container Vector t, Product t) => Contraction (Vector t) (Matrix t) (V | |||
160 | instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where | 159 | instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where |
161 | contraction = mXm | 160 | contraction = mXm |
162 | 161 | ||
162 | -------------------------------------------------------------------------------- | ||
163 | |||
164 | infixl 7 <.> | ||
165 | -- | An infix synonym for 'dot' | ||
166 | (<.>) :: Numeric t => Vector t -> Vector t -> t | ||
167 | (<.>) = dot | ||
168 | |||
169 | |||
170 | infixr 8 <·>, #> | ||
171 | -- | dot product | ||
172 | (<·>) :: Numeric t => Vector t -> Vector t -> t | ||
173 | (<·>) = dot | ||
174 | |||
175 | |||
176 | -- | matrix-vector product | ||
177 | (#>) :: Numeric t => Matrix t -> Vector t -> Vector t | ||
178 | (#>) = mXv | ||
163 | 179 | ||
164 | -------------------------------------------------------------------------------- | 180 | -------------------------------------------------------------------------------- |
165 | 181 | ||
@@ -286,3 +302,16 @@ meanCov x = (med,cov) where | |||
286 | 302 | ||
287 | -------------------------------------------------------------------------------- | 303 | -------------------------------------------------------------------------------- |
288 | 304 | ||
305 | class ( Container Vector t | ||
306 | , Container Matrix t | ||
307 | , Konst t Int Vector | ||
308 | , Konst t (Int,Int) Matrix | ||
309 | , Product t | ||
310 | ) => Numeric t | ||
311 | |||
312 | instance Numeric Double | ||
313 | instance Numeric (Complex Double) | ||
314 | instance Numeric Float | ||
315 | instance Numeric (Complex Float) | ||
316 | |||
317 | -------------------------------------------------------------------------------- | ||
diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs index d5c66fb..1c70ef6 100644 --- a/packages/base/src/Numeric/HMatrix.hs +++ b/packages/base/src/Numeric/HMatrix.hs | |||
@@ -10,16 +10,16 @@ Stability : provisional | |||
10 | ----------------------------------------------------------------------------- | 10 | ----------------------------------------------------------------------------- |
11 | module Numeric.HMatrix ( | 11 | module Numeric.HMatrix ( |
12 | 12 | ||
13 | -- * Basic types and data processing | 13 | -- * Basic types and data processing |
14 | module Numeric.LinearAlgebra.Data, | 14 | module Numeric.LinearAlgebra.Data, |
15 | 15 | ||
16 | -- * Arithmetic and numeric classes | 16 | -- * Arithmetic and numeric classes |
17 | -- | | 17 | -- | |
18 | -- The standard numeric classes are defined elementwise: | 18 | -- The standard numeric classes are defined elementwise: |
19 | -- | 19 | -- |
20 | -- >>> fromList [1,2,3] * fromList [3,0,-2 :: Double] | 20 | -- >>> fromList [1,2,3] * fromList [3,0,-2 :: Double] |
21 | -- fromList [3.0,0.0,-6.0] | 21 | -- fromList [3.0,0.0,-6.0] |
22 | -- | 22 | -- |
23 | -- >>> (3><3) [1..9] * ident 3 :: Matrix Double | 23 | -- >>> (3><3) [1..9] * ident 3 :: Matrix Double |
24 | -- (3><3) | 24 | -- (3><3) |
25 | -- [ 1.0, 0.0, 0.0 | 25 | -- [ 1.0, 0.0, 0.0 |
@@ -29,7 +29,7 @@ module Numeric.HMatrix ( | |||
29 | -- In arithmetic operations single-element vectors and matrices | 29 | -- In arithmetic operations single-element vectors and matrices |
30 | -- (created from numeric literals or using 'scalar') automatically | 30 | -- (created from numeric literals or using 'scalar') automatically |
31 | -- expand to match the dimensions of the other operand: | 31 | -- expand to match the dimensions of the other operand: |
32 | -- | 32 | -- |
33 | -- >>> 5 + 2*ident 3 :: Matrix Double | 33 | -- >>> 5 + 2*ident 3 :: Matrix Double |
34 | -- (3><3) | 34 | -- (3><3) |
35 | -- [ 7.0, 5.0, 5.0 | 35 | -- [ 7.0, 5.0, 5.0 |
@@ -37,13 +37,14 @@ module Numeric.HMatrix ( | |||
37 | -- , 5.0, 5.0, 7.0 ] | 37 | -- , 5.0, 5.0, 7.0 ] |
38 | -- | 38 | -- |
39 | 39 | ||
40 | -- * Matrix product | 40 | -- * Products |
41 | (<.>), | 41 | -- ** dot |
42 | 42 | (<·>), | |
43 | -- | The overloaded multiplication operators may need type annotations to remove | 43 | -- ** matrix-vector |
44 | -- ambiguity. In those cases we can also use the specific functions 'mXm', 'mXv', and 'dot'. | 44 | (#>),(!#>), |
45 | -- | 45 | -- ** matrix-matrix |
46 | -- The matrix x matrix product is also implemented in the "Data.Monoid" instance, where | 46 | (<>), |
47 | -- | The matrix x matrix product is also implemented in the "Data.Monoid" instance, where | ||
47 | -- single-element matrices (created from numeric literals or using 'scalar') | 48 | -- single-element matrices (created from numeric literals or using 'scalar') |
48 | -- are used for scaling. | 49 | -- are used for scaling. |
49 | -- | 50 | -- |
@@ -55,12 +56,12 @@ module Numeric.HMatrix ( | |||
55 | -- | 56 | -- |
56 | -- 'mconcat' uses 'optimiseMult' to get the optimal association order. | 57 | -- 'mconcat' uses 'optimiseMult' to get the optimal association order. |
57 | 58 | ||
58 | 59 | ||
59 | -- * Other products | 60 | -- ** other |
60 | outer, kronecker, cross, | 61 | outer, kronecker, cross, |
61 | scale, | 62 | scale, |
62 | sumElements, prodElements, | 63 | sumElements, prodElements, |
63 | 64 | ||
64 | -- * Linear Systems | 65 | -- * Linear Systems |
65 | (<\>), | 66 | (<\>), |
66 | linearSolve, | 67 | linearSolve, |
@@ -70,14 +71,14 @@ module Numeric.HMatrix ( | |||
70 | cholSolve, | 71 | cholSolve, |
71 | cgSolve, | 72 | cgSolve, |
72 | cgSolve', | 73 | cgSolve', |
73 | 74 | ||
74 | -- * Inverse and pseudoinverse | 75 | -- * Inverse and pseudoinverse |
75 | inv, pinv, pinvTol, | 76 | inv, pinv, pinvTol, |
76 | 77 | ||
77 | -- * Determinant and rank | 78 | -- * Determinant and rank |
78 | rcond, rank, ranksv, | 79 | rcond, rank, ranksv, |
79 | det, invlndet, | 80 | det, invlndet, |
80 | 81 | ||
81 | -- * Singular value decomposition | 82 | -- * Singular value decomposition |
82 | svd, | 83 | svd, |
83 | fullSVD, | 84 | fullSVD, |
@@ -85,7 +86,7 @@ module Numeric.HMatrix ( | |||
85 | compactSVD, | 86 | compactSVD, |
86 | singularValues, | 87 | singularValues, |
87 | leftSV, rightSV, | 88 | leftSV, rightSV, |
88 | 89 | ||
89 | -- * Eigensystems | 90 | -- * Eigensystems |
90 | eig, eigSH, eigSH', | 91 | eig, eigSH, eigSH', |
91 | eigenvalues, eigenvaluesSH, eigenvaluesSH', | 92 | eigenvalues, eigenvaluesSH, eigenvaluesSH', |
@@ -105,7 +106,7 @@ module Numeric.HMatrix ( | |||
105 | 106 | ||
106 | -- * LU | 107 | -- * LU |
107 | lu, luPacked, | 108 | lu, luPacked, |
108 | 109 | ||
109 | -- * Matrix functions | 110 | -- * Matrix functions |
110 | expm, | 111 | expm, |
111 | sqrtm, | 112 | sqrtm, |
@@ -116,7 +117,7 @@ module Numeric.HMatrix ( | |||
116 | nullVector, | 117 | nullVector, |
117 | nullspaceSVD, | 118 | nullspaceSVD, |
118 | null1, null1sym, | 119 | null1, null1sym, |
119 | 120 | ||
120 | orth, | 121 | orth, |
121 | 122 | ||
122 | -- * Norms | 123 | -- * Norms |
@@ -129,30 +130,36 @@ module Numeric.HMatrix ( | |||
129 | 130 | ||
130 | -- * Random arrays | 131 | -- * Random arrays |
131 | 132 | ||
132 | RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, | 133 | Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, |
133 | 134 | ||
134 | -- * Misc | 135 | -- * Misc |
135 | meanCov, peps, relativeError, haussholder, optimiseMult, dot, udot, mXm, mXv, smXv, (<>), (◇), Seed, checkT, | 136 | meanCov, peps, relativeError, haussholder, optimiseMult, udot, |
136 | -- * Auxiliary classes | 137 | -- * Auxiliary classes |
137 | Element, Container, Product, Numeric, Contraction, LSDiv, | 138 | Element, Container, Product, Contraction(..), Numeric, LSDiv, |
138 | Complexable, RealElement, | 139 | Complexable, RealElement, |
139 | RealOf, ComplexOf, SingleOf, DoubleOf, | 140 | RealOf, ComplexOf, SingleOf, DoubleOf, |
140 | IndexOf, | 141 | IndexOf, |
141 | Field, | 142 | Field, |
142 | Normed, | 143 | Normed, |
143 | CGMat, Transposable, | 144 | Transposable, |
144 | ℕ,ℤ,ℝ,ℂ,ℝn,ℂn, 𝑖, i_C --ℍ | 145 | CGState(..), |
146 | Testable(..), | ||
147 | ℕ,ℤ,ℝ,ℂ, 𝑖, i_C --ℍ | ||
145 | ) where | 148 | ) where |
146 | 149 | ||
147 | import Numeric.LinearAlgebra.Data | 150 | import Numeric.LinearAlgebra.Data |
148 | 151 | ||
149 | import Numeric.Matrix() | 152 | import Numeric.Matrix() |
150 | import Numeric.Vector() | 153 | import Numeric.Vector() |
151 | import Data.Packed.Numeric | 154 | import Data.Packed.Numeric hiding ((<>)) |
152 | import Numeric.LinearAlgebra.Algorithms | 155 | import Numeric.LinearAlgebra.Algorithms |
153 | import Numeric.LinearAlgebra.Util | 156 | import Numeric.LinearAlgebra.Util |
154 | import Numeric.LinearAlgebra.Random | 157 | import Numeric.LinearAlgebra.Random |
155 | import Numeric.Sparse(smXv) | 158 | import Numeric.Sparse((!#>)) |
156 | import Numeric.LinearAlgebra.Util.CG | 159 | import Numeric.LinearAlgebra.Util.CG |
157 | 160 | ||
161 | -- | matrix product | ||
162 | (<>) :: Numeric t => Matrix t -> Matrix t -> Matrix t | ||
163 | (<>) = mXm | ||
164 | infixr 8 <> | ||
158 | 165 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Data.hs b/packages/base/src/Numeric/LinearAlgebra/Data.hs index 3128a24..3417a5e 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Data.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Data.hs | |||
@@ -49,13 +49,9 @@ module Numeric.LinearAlgebra.Data( | |||
49 | find, maxIndex, minIndex, maxElement, minElement, atIndex, | 49 | find, maxIndex, minIndex, maxElement, minElement, atIndex, |
50 | 50 | ||
51 | -- * Sparse | 51 | -- * Sparse |
52 | SMatrix, AssocMatrix, mkCSR, toDense, | 52 | GMatrix, AssocMatrix, mkSparse, toDense, |
53 | mkDiag, | 53 | mkDiagR, dense, |
54 | |||
55 | -- * Static dimensions | ||
56 | 54 | ||
57 | Static, ddata, R, vect0, sScalar, vect2, vect3, (&), | ||
58 | |||
59 | -- * IO | 55 | -- * IO |
60 | disp, | 56 | disp, |
61 | loadMatrix, saveMatrix, | 57 | loadMatrix, saveMatrix, |
@@ -79,9 +75,8 @@ module Numeric.LinearAlgebra.Data( | |||
79 | import Data.Packed.Vector | 75 | import Data.Packed.Vector |
80 | import Data.Packed.Matrix | 76 | import Data.Packed.Matrix |
81 | import Data.Packed.Numeric | 77 | import Data.Packed.Numeric |
82 | import Numeric.LinearAlgebra.Util hiding ((&)) | 78 | import Numeric.LinearAlgebra.Util hiding ((&),(#)) |
83 | import Data.Complex | 79 | import Data.Complex |
84 | import Numeric.Sparse | 80 | import Numeric.Sparse |
85 | import Numeric.LinearAlgebra.Util.Static | ||
86 | 81 | ||
87 | 82 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs new file mode 100644 index 0000000..db15705 --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs | |||
@@ -0,0 +1,337 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE FlexibleContexts #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE EmptyDataDecls #-} | ||
9 | {-# LANGUAGE Rank2Types #-} | ||
10 | {-# LANGUAGE FlexibleInstances #-} | ||
11 | {-# LANGUAGE TypeOperators #-} | ||
12 | {-# LANGUAGE ViewPatterns #-} | ||
13 | {-# LANGUAGE GADTs #-} | ||
14 | |||
15 | |||
16 | {- | | ||
17 | Module : Numeric.LinearAlgebra.Real | ||
18 | Copyright : (c) Alberto Ruiz 2006-14 | ||
19 | License : BSD3 | ||
20 | Stability : provisional | ||
21 | |||
22 | Experimental interface for real arrays with statically checked dimensions. | ||
23 | |||
24 | -} | ||
25 | |||
26 | module Numeric.LinearAlgebra.Real( | ||
27 | -- * Vector | ||
28 | R, | ||
29 | vec2, vec3, vec4, 𝕧, (&), | ||
30 | -- * Matrix | ||
31 | L, Sq, | ||
32 | 𝕞, | ||
33 | (#),(¦),(——), | ||
34 | Konst(..), | ||
35 | eye, | ||
36 | diagR, diag, | ||
37 | blockAt, | ||
38 | -- * Products | ||
39 | (<>),(#>),(<·>), | ||
40 | -- * Pretty printing | ||
41 | Disp(..), | ||
42 | -- * Misc | ||
43 | Dim, unDim, | ||
44 | module Numeric.HMatrix | ||
45 | ) where | ||
46 | |||
47 | |||
48 | import GHC.TypeLits | ||
49 | import Numeric.HMatrix hiding ((<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——)) | ||
50 | import qualified Numeric.HMatrix as LA | ||
51 | import Data.Packed.ST | ||
52 | |||
53 | newtype Dim (n :: Nat) t = Dim t | ||
54 | deriving Show | ||
55 | |||
56 | unDim :: Dim n t -> t | ||
57 | unDim (Dim x) = x | ||
58 | |||
59 | data Proxy :: Nat -> * | ||
60 | |||
61 | |||
62 | lift1F | ||
63 | :: (c t -> c t) | ||
64 | -> Dim n (c t) -> Dim n (c t) | ||
65 | lift1F f (Dim v) = Dim (f v) | ||
66 | |||
67 | lift2F | ||
68 | :: (c t -> c t -> c t) | ||
69 | -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) | ||
70 | lift2F f (Dim u) (Dim v) = Dim (f u v) | ||
71 | |||
72 | |||
73 | |||
74 | type R n = Dim n (Vector ℝ) | ||
75 | |||
76 | type L m n = Dim m (Dim n (Matrix ℝ)) | ||
77 | |||
78 | |||
79 | infixl 4 & | ||
80 | (&) :: forall n . KnownNat n | ||
81 | => R n -> ℝ -> R (n+1) | ||
82 | Dim v & x = Dim (vjoin [v', scalar x]) | ||
83 | where | ||
84 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
85 | v' | d > 1 && size v == 1 = LA.konst (v!0) d | ||
86 | | otherwise = v | ||
87 | |||
88 | |||
89 | -- vect0 :: R 0 | ||
90 | -- vect0 = Dim (fromList[]) | ||
91 | |||
92 | 𝕧 :: ℝ -> R 1 | ||
93 | 𝕧 = Dim . scalar | ||
94 | |||
95 | |||
96 | vec2 :: ℝ -> ℝ -> R 2 | ||
97 | vec2 a b = Dim $ runSTVector $ do | ||
98 | v <- newUndefinedVector 2 | ||
99 | writeVector v 0 a | ||
100 | writeVector v 1 b | ||
101 | return v | ||
102 | |||
103 | vec3 :: ℝ -> ℝ -> ℝ -> R 3 | ||
104 | vec3 a b c = Dim $ runSTVector $ do | ||
105 | v <- newUndefinedVector 3 | ||
106 | writeVector v 0 a | ||
107 | writeVector v 1 b | ||
108 | writeVector v 2 c | ||
109 | return v | ||
110 | |||
111 | |||
112 | vec4 :: ℝ -> ℝ -> ℝ -> ℝ -> R 4 | ||
113 | vec4 a b c d = Dim $ runSTVector $ do | ||
114 | v <- newUndefinedVector 4 | ||
115 | writeVector v 0 a | ||
116 | writeVector v 1 b | ||
117 | writeVector v 2 c | ||
118 | writeVector v 3 d | ||
119 | return v | ||
120 | |||
121 | |||
122 | |||
123 | |||
124 | instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) | ||
125 | where | ||
126 | (+) = lift2F (+) | ||
127 | (*) = lift2F (*) | ||
128 | (-) = lift2F (-) | ||
129 | abs = lift1F abs | ||
130 | signum = lift1F signum | ||
131 | negate = lift1F negate | ||
132 | fromInteger x = Dim (fromInteger x) | ||
133 | |||
134 | instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) | ||
135 | where | ||
136 | (+) = (lift2F . lift2F) (+) | ||
137 | (*) = (lift2F . lift2F) (*) | ||
138 | (-) = (lift2F . lift2F) (-) | ||
139 | abs = (lift1F . lift1F) abs | ||
140 | signum = (lift1F . lift1F) signum | ||
141 | negate = (lift1F . lift1F) negate | ||
142 | fromInteger x = Dim (Dim (fromInteger x)) | ||
143 | |||
144 | -------------------------------------------------------------------------------- | ||
145 | |||
146 | class Konst t | ||
147 | where | ||
148 | konst :: ℝ -> t | ||
149 | |||
150 | instance forall n. KnownNat n => Konst (R n) | ||
151 | where | ||
152 | konst x = Dim (LA.konst x d) | ||
153 | where | ||
154 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
155 | |||
156 | instance forall m n . (KnownNat m, KnownNat n) => Konst (L m n) | ||
157 | where | ||
158 | konst x = Dim (Dim (LA.konst x (m',n'))) | ||
159 | where | ||
160 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
161 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
162 | |||
163 | -------------------------------------------------------------------------------- | ||
164 | |||
165 | diagR :: forall m n k . (KnownNat m, KnownNat n) => ℝ -> R k -> L m n | ||
166 | diagR x v = Dim (Dim (diagRect x (unDim v) m' n')) | ||
167 | where | ||
168 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
169 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
170 | |||
171 | diag :: KnownNat n => R n -> Sq n | ||
172 | diag = diagR 0 | ||
173 | |||
174 | -------------------------------------------------------------------------------- | ||
175 | |||
176 | blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n | ||
177 | blockAt x r c a = Dim (Dim res) | ||
178 | where | ||
179 | z = scalar x | ||
180 | z1 = LA.konst x (r,c) | ||
181 | z2 = LA.konst x (max 0 (m'-(ra+r)), max 0 (n'-(ca+c))) | ||
182 | ra = min (rows a) . max 0 $ m'-r | ||
183 | ca = min (cols a) . max 0 $ n'-c | ||
184 | sa = subMatrix (0,0) (ra, ca) a | ||
185 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
186 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
187 | res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] | ||
188 | |||
189 | {- | ||
190 | matrix :: (KnownNat m, KnownNat n) => Matrix Double -> L n m | ||
191 | matrix = blockAt 0 0 0 | ||
192 | -} | ||
193 | |||
194 | -------------------------------------------------------------------------------- | ||
195 | |||
196 | class Disp t | ||
197 | where | ||
198 | disp :: Int -> t -> IO () | ||
199 | |||
200 | instance Disp (L n m) | ||
201 | where | ||
202 | disp n (d2 -> a) = do | ||
203 | if rows a == 1 && cols a == 1 | ||
204 | then putStrLn $ "Const " ++ (last . words . LA.dispf n $ a) | ||
205 | else putStr "Dim " >> LA.disp n a | ||
206 | |||
207 | instance Disp (R n) | ||
208 | where | ||
209 | disp n (unDim -> v) = do | ||
210 | let su = LA.dispf n (asRow v) | ||
211 | if LA.size v == 1 | ||
212 | then putStrLn $ "Const " ++ (last . words $ su ) | ||
213 | else putStr "Dim " >> putStr (tail . dropWhile (/='x') $ su) | ||
214 | |||
215 | -------------------------------------------------------------------------------- | ||
216 | |||
217 | infixl 3 # | ||
218 | (#) :: L r c -> R c -> L (r+1) c | ||
219 | Dim (Dim m) # Dim v = Dim (Dim (m LA.—— asRow v)) | ||
220 | |||
221 | |||
222 | 𝕞 :: forall n . KnownNat n => L 0 n | ||
223 | 𝕞 = Dim (Dim (LA.konst 0 (0,d))) | ||
224 | where | ||
225 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
226 | |||
227 | infixl 3 ¦ | ||
228 | (¦) :: L r c1 -> L r c2 -> L r (c1+c2) | ||
229 | Dim (Dim a) ¦ Dim (Dim b) = Dim (Dim (a LA.¦ b)) | ||
230 | |||
231 | infixl 2 —— | ||
232 | (——) :: L r1 c -> L r2 c -> L (r1+r2) c | ||
233 | Dim (Dim a) —— Dim (Dim b) = Dim (Dim (a LA.—— b)) | ||
234 | |||
235 | |||
236 | {- | ||
237 | |||
238 | -} | ||
239 | |||
240 | type Sq n = L n n | ||
241 | |||
242 | type GL = (KnownNat n, KnownNat m) => L m n | ||
243 | type GSq = KnownNat n => Sq n | ||
244 | |||
245 | infixr 8 <> | ||
246 | (<>) :: L m k -> L k n -> L m n | ||
247 | (d2 -> a) <> (d2 -> b) = Dim (Dim (a LA.<> b)) | ||
248 | |||
249 | infixr 8 #> | ||
250 | (#>) :: L m n -> R n -> R m | ||
251 | (d2 -> m) #> (unDim -> v) = Dim (m LA.#> v) | ||
252 | |||
253 | infixr 8 <·> | ||
254 | (<·>) :: R n -> R n -> ℝ | ||
255 | (unDim -> u) <·> (unDim -> v) = udot u v | ||
256 | |||
257 | |||
258 | d2 :: forall c (n :: Nat) (n1 :: Nat). Dim n1 (Dim n c) -> c | ||
259 | d2 = unDim . unDim | ||
260 | |||
261 | |||
262 | instance Transposable (L m n) (L n m) | ||
263 | where | ||
264 | tr (Dim (Dim a)) = Dim (Dim (tr a)) | ||
265 | |||
266 | |||
267 | eye :: forall n . KnownNat n => Sq n | ||
268 | eye = Dim (Dim (ident d)) | ||
269 | where | ||
270 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
271 | |||
272 | |||
273 | -------------------------------------------------------------------------------- | ||
274 | |||
275 | test :: (Bool, IO ()) | ||
276 | test = (ok,info) | ||
277 | where | ||
278 | ok = d2 (eye :: Sq 5) == ident 5 | ||
279 | && d2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] | ||
280 | && d2 (tm :: L 3 5) == mat 5 [1..15] | ||
281 | && thingS == thingD | ||
282 | && precS == precD | ||
283 | |||
284 | info = do | ||
285 | print $ u | ||
286 | print $ v | ||
287 | print (eye :: Sq 3) | ||
288 | print $ ((u & 5) + 1) <·> v | ||
289 | print (tm :: L 2 5) | ||
290 | print (tm <> sm :: L 2 3) | ||
291 | print thingS | ||
292 | print thingD | ||
293 | print precS | ||
294 | print precD | ||
295 | |||
296 | u = vec2 3 5 | ||
297 | |||
298 | v = 𝕧 2 & 4 & 7 | ||
299 | |||
300 | mTm :: L n m -> Sq m | ||
301 | mTm a = tr a <> a | ||
302 | |||
303 | tm :: GL | ||
304 | tm = lmat 0 [1..] | ||
305 | |||
306 | lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n | ||
307 | lmat z xs = Dim . Dim . reshape n' . fromList . take (m'*n') $ xs ++ repeat z | ||
308 | where | ||
309 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
310 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
311 | |||
312 | sm :: GSq | ||
313 | sm = lmat 0 [1..] | ||
314 | |||
315 | thingS = (u & 1) <·> tr q #> q #> v | ||
316 | where | ||
317 | q = tm :: L 10 3 | ||
318 | |||
319 | thingD = vjoin [unDim u, 1] LA.<·> tr m LA.#> m LA.#> unDim v | ||
320 | where | ||
321 | m = mat 3 [1..30] | ||
322 | |||
323 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v | ||
324 | precD = 1 + 2 * vjoin[unDim u, 6] LA.<·> LA.konst 2 (size (unDim u) +1, size (unDim v)) LA.#> unDim v | ||
325 | |||
326 | |||
327 | instance (KnownNat n', KnownNat m') => Testable (L n' m') | ||
328 | where | ||
329 | checkT _ = test | ||
330 | |||
331 | {- | ||
332 | do (snd test) | ||
333 | fst test | ||
334 | -} | ||
335 | |||
336 | |||
337 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util.hs b/packages/base/src/Numeric/LinearAlgebra/Util.hs index a319785..47b1090 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util.hs | |||
@@ -32,7 +32,7 @@ module Numeric.LinearAlgebra.Util( | |||
32 | rand, randn, | 32 | rand, randn, |
33 | cross, | 33 | cross, |
34 | norm, | 34 | norm, |
35 | ℕ,ℤ,ℝ,ℂ,ℝn,ℂn,𝑖,i_C, --ℍ | 35 | ℕ,ℤ,ℝ,ℂ,𝑖,i_C, --ℍ |
36 | norm_1, norm_2, norm_0, norm_Inf, norm_Frob, norm_nuclear, | 36 | norm_1, norm_2, norm_0, norm_Inf, norm_Frob, norm_nuclear, |
37 | mnorm_1, mnorm_2, mnorm_0, mnorm_Inf, | 37 | mnorm_1, mnorm_2, mnorm_0, mnorm_Inf, |
38 | unitary, | 38 | unitary, |
@@ -70,8 +70,8 @@ type ℝ = Double | |||
70 | type ℕ = Int | 70 | type ℕ = Int |
71 | type ℤ = Int | 71 | type ℤ = Int |
72 | type ℂ = Complex Double | 72 | type ℂ = Complex Double |
73 | type ℝn = Vector ℝ | 73 | --type ℝn = Vector ℝ |
74 | type ℂn = Vector ℂ | 74 | --type ℂn = Vector ℂ |
75 | --newtype ℍ m = H m | 75 | --newtype ℍ m = H m |
76 | 76 | ||
77 | i_C, 𝑖 :: ℂ | 77 | i_C, 𝑖 :: ℂ |
@@ -84,7 +84,7 @@ i_C = 𝑖 | |||
84 | fromList [1.0,2.0,3.0,4.0,5.0] | 84 | fromList [1.0,2.0,3.0,4.0,5.0] |
85 | 85 | ||
86 | -} | 86 | -} |
87 | vect :: [ℝ] -> ℝn | 87 | vect :: [ℝ] -> Vector ℝ |
88 | vect = fromList | 88 | vect = fromList |
89 | 89 | ||
90 | {- | create a real matrix | 90 | {- | create a real matrix |
@@ -103,18 +103,6 @@ mat | |||
103 | mat c = reshape c . fromList | 103 | mat c = reshape c . fromList |
104 | 104 | ||
105 | 105 | ||
106 | |||
107 | class ( Container Vector t | ||
108 | , Container Matrix t | ||
109 | , Konst t Int Vector | ||
110 | , Konst t (Int,Int) Matrix | ||
111 | ) => Numeric t | ||
112 | |||
113 | instance Numeric Double | ||
114 | instance Numeric (Complex Double) | ||
115 | |||
116 | |||
117 | |||
118 | {- | print a real matrix with given number of digits after the decimal point | 106 | {- | print a real matrix with given number of digits after the decimal point |
119 | 107 | ||
120 | >>> disp 5 $ ident 2 / 3 | 108 | >>> disp 5 $ ident 2 / 3 |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs index 5e2ea84..50372f1 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Util/CG.hs | |||
@@ -3,11 +3,14 @@ | |||
3 | 3 | ||
4 | module Numeric.LinearAlgebra.Util.CG( | 4 | module Numeric.LinearAlgebra.Util.CG( |
5 | cgSolve, cgSolve', | 5 | cgSolve, cgSolve', |
6 | CGMat, CGState(..), R, V | 6 | CGState(..), R, V |
7 | ) where | 7 | ) where |
8 | 8 | ||
9 | import Data.Packed.Numeric | 9 | import Data.Packed.Numeric |
10 | import Numeric.Sparse | ||
10 | import Numeric.Vector() | 11 | import Numeric.Vector() |
12 | import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) | ||
13 | import Control.Arrow((***)) | ||
11 | 14 | ||
12 | {- | 15 | {- |
13 | import Util.Misc(debug, debugMat) | 16 | import Util.Misc(debug, debugMat) |
@@ -51,7 +54,7 @@ cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx | |||
51 | rdx = norm2 dx / max 1 (norm2 x) | 54 | rdx = norm2 dx / max 1 (norm2 x) |
52 | 55 | ||
53 | conjugrad | 56 | conjugrad |
54 | :: (Transposable m, Contraction m V V) | 57 | :: (Transposable m mt, Contraction m V V, Contraction mt V V) |
55 | => Bool -> m -> V -> V -> R -> R -> [CGState] | 58 | => Bool -> m -> V -> V -> R -> R -> [CGState] |
56 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b | 59 | conjugrad sym a b = solveG (tr a ◇) (a ◇) (cg sym) b |
57 | 60 | ||
@@ -82,27 +85,88 @@ takeUntil q xs = a++ take 1 b | |||
82 | where | 85 | where |
83 | (a,b) = break q xs | 86 | (a,b) = break q xs |
84 | 87 | ||
85 | class (Transposable m, Contraction m V V) => CGMat m | ||
86 | |||
87 | cgSolve | 88 | cgSolve |
88 | :: CGMat m | 89 | :: Bool -- ^ is symmetric |
89 | => Bool -- ^ is symmetric | 90 | -> GMatrix -- ^ coefficient matrix |
90 | -> m -- ^ coefficient matrix | ||
91 | -> Vector Double -- ^ right-hand side | 91 | -> Vector Double -- ^ right-hand side |
92 | -> Vector Double -- ^ solution | 92 | -> Vector Double -- ^ solution |
93 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 | 93 | cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 |
94 | where | 94 | where |
95 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) | 95 | n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) |
96 | 96 | ||
97 | cgSolve' | 97 | cgSolve' |
98 | :: CGMat m | 98 | :: Bool -- ^ symmetric |
99 | => Bool -- ^ symmetric | ||
100 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) | 99 | -> R -- ^ relative tolerance for the residual (e.g. 1E-4) |
101 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) | 100 | -> R -- ^ relative tolerance for δx (e.g. 1E-3) |
102 | -> Int -- ^ maximum number of iterations | 101 | -> Int -- ^ maximum number of iterations |
103 | -> m -- ^ coefficient matrix | 102 | -> GMatrix -- ^ coefficient matrix |
104 | -> V -- ^ initial solution | 103 | -> V -- ^ initial solution |
105 | -> V -- ^ right-hand side | 104 | -> V -- ^ right-hand side |
106 | -> [CGState] -- ^ solution | 105 | -> [CGState] -- ^ solution |
107 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es | 106 | cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es |
108 | 107 | ||
108 | |||
109 | -------------------------------------------------------------------------------- | ||
110 | |||
111 | instance Testable GMatrix | ||
112 | where | ||
113 | checkT _ = (ok,info) | ||
114 | where | ||
115 | sma = convo2 20 3 | ||
116 | x1 = vect [1..20] | ||
117 | x2 = vect [1..40] | ||
118 | sm = mkSparse sma | ||
119 | dm = toDense sma | ||
120 | |||
121 | s1 = sm !#> x1 | ||
122 | d1 = dm #> x1 | ||
123 | |||
124 | s2 = tr sm !#> x2 | ||
125 | d2 = tr dm #> x2 | ||
126 | |||
127 | sdia = mkDiagR 40 20 (vect [1..10]) | ||
128 | s3 = sdia !#> x1 | ||
129 | s4 = tr sdia !#> x2 | ||
130 | ddia = diagRect 0 (vect [1..10]) 40 20 | ||
131 | d3 = ddia #> x1 | ||
132 | d4 = tr ddia #> x2 | ||
133 | |||
134 | v = testb 40 | ||
135 | s5 = cgSolve False sm v | ||
136 | d5 = denseSolve dm v | ||
137 | |||
138 | info = do | ||
139 | print sm | ||
140 | disp (toDense sma) | ||
141 | print s1; print d1 | ||
142 | print s2; print d2 | ||
143 | print s3; print d3 | ||
144 | print s4; print d4 | ||
145 | print s5; print d5 | ||
146 | print $ relativeError Infinity s5 d5 | ||
147 | |||
148 | ok = s1==d1 | ||
149 | && s2==d2 | ||
150 | && s3==d3 | ||
151 | && s4==d4 | ||
152 | && relativeError Infinity s5 d5 < 1E-10 | ||
153 | |||
154 | disp = putStr . dispf 2 | ||
155 | |||
156 | vect = fromList :: [Double] -> Vector Double | ||
157 | |||
158 | convomat :: Int -> Int -> AssocMatrix | ||
159 | convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] | ||
160 | |||
161 | convo2 :: Int -> Int -> AssocMatrix | ||
162 | convo2 n k = m1 ++ m2 | ||
163 | where | ||
164 | m1 = convomat n k | ||
165 | m2 = map (((+n) *** id) *** id) m1 | ||
166 | |||
167 | testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) | ||
168 | |||
169 | denseSolve a = flatten . linearSolveLS a . asColumn | ||
170 | |||
171 | -- mkDiag v = mkDiagR (dim v) (dim v) v | ||
172 | |||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs deleted file mode 100644 index a3f8eb0..0000000 --- a/packages/base/src/Numeric/LinearAlgebra/Util/Static.hs +++ /dev/null | |||
@@ -1,70 +0,0 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FlexibleContexts #-} | ||
6 | {-# LANGUAGE ScopedTypeVariables #-} | ||
7 | {-# LANGUAGE EmptyDataDecls #-} | ||
8 | {-# LANGUAGE Rank2Types #-} | ||
9 | {-# LANGUAGE FlexibleInstances #-} | ||
10 | {-# LANGUAGE TypeOperators #-} | ||
11 | |||
12 | module Numeric.LinearAlgebra.Util.Static( | ||
13 | Static (ddata), | ||
14 | R, | ||
15 | vect0, sScalar, vect2, vect3, (&) | ||
16 | ) where | ||
17 | |||
18 | |||
19 | import GHC.TypeLits | ||
20 | import Data.Packed.Numeric | ||
21 | import Numeric.Vector() | ||
22 | import Numeric.LinearAlgebra.Util(Numeric,ℝ) | ||
23 | |||
24 | lift1F :: (Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t) | ||
25 | lift1F f (Static v) = Static (f v) | ||
26 | |||
27 | lift2F :: (Vector t -> Vector t -> Vector t) -> Static n (Vector t) -> Static n (Vector t) -> Static n (Vector t) | ||
28 | lift2F f (Static u) (Static v) = Static (f u v) | ||
29 | |||
30 | newtype Static (n :: Nat) t = Static { ddata :: t } deriving Show | ||
31 | |||
32 | type R n = Static n (Vector ℝ) | ||
33 | |||
34 | |||
35 | infixl 4 & | ||
36 | (&) :: R n -> ℝ -> R (n+1) | ||
37 | Static v & x = Static (vjoin [v, scalar x]) | ||
38 | |||
39 | vect0 :: R 0 | ||
40 | vect0 = Static (fromList[]) | ||
41 | |||
42 | sScalar :: ℝ -> R 1 | ||
43 | sScalar = Static . scalar | ||
44 | |||
45 | |||
46 | vect2 :: ℝ -> ℝ -> R 2 | ||
47 | vect2 x1 x2 = Static (fromList [x1,x2]) | ||
48 | |||
49 | vect3 :: ℝ -> ℝ -> ℝ -> R 3 | ||
50 | vect3 x1 x2 x3 = Static (fromList [x1,x2,x3]) | ||
51 | |||
52 | |||
53 | |||
54 | |||
55 | |||
56 | |||
57 | instance forall n t . (KnownNat n, Num (Vector t), Numeric t )=> Num (Static n (Vector t)) | ||
58 | where | ||
59 | (+) = lift2F add | ||
60 | (*) = lift2F mul | ||
61 | (-) = lift2F sub | ||
62 | abs = lift1F abs | ||
63 | signum = lift1F signum | ||
64 | negate = lift1F (scale (-1)) | ||
65 | fromInteger x = Static (konst (fromInteger x) d) | ||
66 | where | ||
67 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
68 | |||
69 | data Proxy :: Nat -> * | ||
70 | |||
diff --git a/packages/base/src/Numeric/Sparse.hs b/packages/base/src/Numeric/Sparse.hs index 2df4578..4d05bdc 100644 --- a/packages/base/src/Numeric/Sparse.hs +++ b/packages/base/src/Numeric/Sparse.hs | |||
@@ -3,11 +3,11 @@ | |||
3 | {-# LANGUAGE FlexibleInstances #-} | 3 | {-# LANGUAGE FlexibleInstances #-} |
4 | 4 | ||
5 | module Numeric.Sparse( | 5 | module Numeric.Sparse( |
6 | SMatrix(..), | 6 | GMatrix(..), |
7 | mkCSR, mkDiag, | 7 | mkSparse, mkDiagR, dense, |
8 | AssocMatrix, | 8 | AssocMatrix, |
9 | toDense, | 9 | toDense, |
10 | smXv | 10 | gmXv, (!#>) |
11 | )where | 11 | )where |
12 | 12 | ||
13 | import Data.Packed.Numeric | 13 | import Data.Packed.Numeric |
@@ -17,8 +17,7 @@ import Control.Arrow((***)) | |||
17 | import Control.Monad(when) | 17 | import Control.Monad(when) |
18 | import Data.List(groupBy, sort) | 18 | import Data.List(groupBy, sort) |
19 | import Foreign.C.Types(CInt(..)) | 19 | import Foreign.C.Types(CInt(..)) |
20 | import Numeric.LinearAlgebra.Util.CG(CGMat,cgSolve) | 20 | |
21 | import Numeric.LinearAlgebra.Algorithms(linearSolveLS, relativeError, NormType(..)) | ||
22 | import Data.Packed.Development | 21 | import Data.Packed.Development |
23 | import System.IO.Unsafe(unsafePerformIO) | 22 | import System.IO.Unsafe(unsafePerformIO) |
24 | import Foreign(Ptr) | 23 | import Foreign(Ptr) |
@@ -29,7 +28,7 @@ c ~!~ msg = when c (error msg) | |||
29 | 28 | ||
30 | type AssocMatrix = [((Int,Int),Double)] | 29 | type AssocMatrix = [((Int,Int),Double)] |
31 | 30 | ||
32 | data SMatrix | 31 | data GMatrix |
33 | = CSR | 32 | = CSR |
34 | { csrVals :: Vector Double | 33 | { csrVals :: Vector Double |
35 | , csrCols :: Vector CInt | 34 | , csrCols :: Vector CInt |
@@ -46,14 +45,26 @@ data SMatrix | |||
46 | } | 45 | } |
47 | | Diag | 46 | | Diag |
48 | { diagVals :: Vector Double | 47 | { diagVals :: Vector Double |
48 | , nRows :: Int | ||
49 | , nCols :: Int | ||
50 | } | ||
51 | | Dense | ||
52 | { gmDense :: Matrix Double | ||
49 | , nRows :: Int | 53 | , nRows :: Int |
50 | , nCols :: Int | 54 | , nCols :: Int |
51 | } | 55 | } |
52 | -- | Banded | 56 | -- | Banded |
53 | deriving Show | 57 | deriving Show |
54 | 58 | ||
55 | mkCSR :: AssocMatrix -> SMatrix | 59 | dense :: Matrix Double -> GMatrix |
56 | mkCSR sm' = CSR{..} | 60 | dense m = Dense{..} |
61 | where | ||
62 | gmDense = m | ||
63 | nRows = rows m | ||
64 | nCols = cols m | ||
65 | |||
66 | mkSparse :: AssocMatrix -> GMatrix | ||
67 | mkSparse sm' = CSR{..} | ||
57 | where | 68 | where |
58 | sm = sort sm' | 69 | sm = sort sm' |
59 | rws = map ((fromList *** fromList) | 70 | rws = map ((fromList *** fromList) |
@@ -78,37 +89,47 @@ mkDiagR r c v | |||
78 | nCols = c | 89 | nCols = c |
79 | diagVals = v | 90 | diagVals = v |
80 | 91 | ||
81 | mkDiag v = mkDiagR (dim v) (dim v) v | ||
82 | |||
83 | 92 | ||
84 | type IV t = CInt -> Ptr CInt -> t | 93 | type IV t = CInt -> Ptr CInt -> t |
85 | type V t = CInt -> Ptr Double -> t | 94 | type V t = CInt -> Ptr Double -> t |
86 | type SMxV = V (IV (IV (V (V (IO CInt))))) | 95 | type SMxV = V (IV (IV (V (V (IO CInt))))) |
87 | 96 | ||
88 | smXv :: SMatrix -> Vector Double -> Vector Double | 97 | gmXv :: GMatrix -> Vector Double -> Vector Double |
89 | smXv CSR{..} v = unsafePerformIO $ do | 98 | gmXv CSR{..} v = unsafePerformIO $ do |
90 | dim v /= nCols ~!~ printf "smXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | 99 | dim v /= nCols ~!~ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) |
91 | r <- createVector nRows | 100 | r <- createVector nRows |
92 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" | 101 | app5 c_smXv vec csrVals vec csrCols vec csrRows vec v vec r "CSRXv" |
93 | return r | 102 | return r |
94 | 103 | ||
95 | smXv CSC{..} v = unsafePerformIO $ do | 104 | gmXv CSC{..} v = unsafePerformIO $ do |
96 | dim v /= nCols ~!~ printf "smXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) | 105 | dim v /= nCols ~!~ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v) |
97 | r <- createVector nRows | 106 | r <- createVector nRows |
98 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" | 107 | app5 c_smTXv vec cscVals vec cscRows vec cscCols vec v vec r "CSCXv" |
99 | return r | 108 | return r |
100 | 109 | ||
101 | smXv Diag{..} v | 110 | gmXv Diag{..} v |
102 | | dim v == nCols | 111 | | dim v == nCols |
103 | = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals | 112 | = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals |
104 | , konst 0 (nRows - dim diagVals) ] | 113 | , konst 0 (nRows - dim diagVals) ] |
105 | | otherwise = error $ printf "smXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" | 114 | | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d" |
106 | nRows nCols (dim diagVals) (dim v) | 115 | nRows nCols (dim diagVals) (dim v) |
107 | 116 | ||
117 | gmXv Dense{..} v | ||
118 | | dim v == nCols | ||
119 | = mXv gmDense v | ||
120 | | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d" | ||
121 | nRows nCols (dim v) | ||
122 | |||
108 | 123 | ||
109 | instance Contraction SMatrix (Vector Double) (Vector Double) | 124 | -- | general matrix - vector product |
125 | infixr 8 !#> | ||
126 | (!#>) :: GMatrix -> Vector Double -> Vector Double | ||
127 | (!#>) = gmXv | ||
128 | |||
129 | |||
130 | instance Contraction GMatrix (Vector Double) (Vector Double) | ||
110 | where | 131 | where |
111 | contraction = smXv | 132 | contraction = gmXv |
112 | 133 | ||
113 | -------------------------------------------------------------------------------- | 134 | -------------------------------------------------------------------------------- |
114 | 135 | ||
@@ -127,75 +148,11 @@ toDense asm = assoc (r+1,c+1) 0 asm | |||
127 | 148 | ||
128 | 149 | ||
129 | 150 | ||
130 | instance Transposable SMatrix | 151 | instance Transposable GMatrix GMatrix |
131 | where | 152 | where |
132 | tr (CSR vs cs rs n m) = CSC vs cs rs m n | 153 | tr (CSR vs cs rs n m) = CSC vs cs rs m n |
133 | tr (CSC vs rs cs n m) = CSR vs rs cs m n | 154 | tr (CSC vs rs cs n m) = CSR vs rs cs m n |
134 | tr (Diag v n m) = Diag v m n | 155 | tr (Diag v n m) = Diag v m n |
156 | tr (Dense a n m) = Dense (tr a) m n | ||
135 | 157 | ||
136 | 158 | ||
137 | instance CGMat SMatrix | ||
138 | instance CGMat (Matrix Double) | ||
139 | |||
140 | -------------------------------------------------------------------------------- | ||
141 | |||
142 | instance Testable SMatrix | ||
143 | where | ||
144 | checkT _ = (ok,info) | ||
145 | where | ||
146 | sma = convo2 20 3 | ||
147 | x1 = vect [1..20] | ||
148 | x2 = vect [1..40] | ||
149 | sm = mkCSR sma | ||
150 | dm = toDense sma | ||
151 | |||
152 | s1 = sm ◇ x1 | ||
153 | d1 = dm ◇ x1 | ||
154 | |||
155 | s2 = tr sm ◇ x2 | ||
156 | d2 = tr dm ◇ x2 | ||
157 | |||
158 | sdia = mkDiagR 40 20 (vect [1..10]) | ||
159 | s3 = sdia ◇ x1 | ||
160 | s4 = tr sdia ◇ x2 | ||
161 | ddia = diagRect 0 (vect [1..10]) 40 20 | ||
162 | d3 = ddia ◇ x1 | ||
163 | d4 = tr ddia ◇ x2 | ||
164 | |||
165 | v = testb 40 | ||
166 | s5 = cgSolve False sm v | ||
167 | d5 = denseSolve dm v | ||
168 | |||
169 | info = do | ||
170 | print sm | ||
171 | disp (toDense sma) | ||
172 | print s1; print d1 | ||
173 | print s2; print d2 | ||
174 | print s3; print d3 | ||
175 | print s4; print d4 | ||
176 | print s5; print d5 | ||
177 | print $ relativeError Infinity s5 d5 | ||
178 | |||
179 | ok = s1==d1 | ||
180 | && s2==d2 | ||
181 | && s3==d3 | ||
182 | && s4==d4 | ||
183 | && relativeError Infinity s5 d5 < 1E-10 | ||
184 | |||
185 | disp = putStr . dispf 2 | ||
186 | |||
187 | vect = fromList :: [Double] -> Vector Double | ||
188 | |||
189 | convomat :: Int -> Int -> AssocMatrix | ||
190 | convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] | ||
191 | |||
192 | convo2 :: Int -> Int -> AssocMatrix | ||
193 | convo2 n k = m1 ++ m2 | ||
194 | where | ||
195 | m1 = convomat n k | ||
196 | m2 = map (((+n) *** id) *** id) m1 | ||
197 | |||
198 | testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) | ||
199 | |||
200 | denseSolve a = flatten . linearSolveLS a . asColumn | ||
201 | |||