diff options
Diffstat (limited to 'packages/base/src/Numeric/HMatrix.hs')
-rw-r--r-- | packages/base/src/Numeric/HMatrix.hs | 634 |
1 files changed, 465 insertions, 169 deletions
diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs index ec96bfc..421333a 100644 --- a/packages/base/src/Numeric/HMatrix.hs +++ b/packages/base/src/Numeric/HMatrix.hs | |||
@@ -1,201 +1,497 @@ | |||
1 | ----------------------------------------------------------------------------- | 1 | {-# LANGUAGE DataKinds #-} |
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE FlexibleContexts #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE EmptyDataDecls #-} | ||
9 | {-# LANGUAGE Rank2Types #-} | ||
10 | {-# LANGUAGE FlexibleInstances #-} | ||
11 | {-# LANGUAGE TypeOperators #-} | ||
12 | {-# LANGUAGE ViewPatterns #-} | ||
13 | {-# LANGUAGE GADTs #-} | ||
14 | {-# LANGUAGE OverlappingInstances #-} | ||
15 | {-# LANGUAGE TypeFamilies #-} | ||
16 | |||
17 | |||
2 | {- | | 18 | {- | |
3 | Module : Numeric.HMatrix | 19 | Module : Numeric.HMatrix |
4 | Copyright : (c) Alberto Ruiz 2006-14 | 20 | Copyright : (c) Alberto Ruiz 2006-14 |
5 | License : BSD3 | 21 | License : BSD3 |
6 | Maintainer : Alberto Ruiz | 22 | Stability : experimental |
7 | Stability : provisional | 23 | |
24 | Experimental interface for real arrays with statically checked dimensions. | ||
8 | 25 | ||
9 | -} | 26 | -} |
10 | ----------------------------------------------------------------------------- | ||
11 | module Numeric.HMatrix ( | ||
12 | |||
13 | -- * Basic types and data processing | ||
14 | module Numeric.LinearAlgebra.Data, | ||
15 | |||
16 | -- * Arithmetic and numeric classes | ||
17 | -- | | ||
18 | -- The standard numeric classes are defined elementwise: | ||
19 | -- | ||
20 | -- >>> vect [1,2,3] * vect [3,0,-2] | ||
21 | -- fromList [3.0,0.0,-6.0] | ||
22 | -- | ||
23 | -- >>> mat 3 [1..9] * ident 3 | ||
24 | -- (3><3) | ||
25 | -- [ 1.0, 0.0, 0.0 | ||
26 | -- , 0.0, 5.0, 0.0 | ||
27 | -- , 0.0, 0.0, 9.0 ] | ||
28 | -- | ||
29 | -- In arithmetic operations single-element vectors and matrices | ||
30 | -- (created from numeric literals or using 'scalar') automatically | ||
31 | -- expand to match the dimensions of the other operand: | ||
32 | -- | ||
33 | -- >>> 5 + 2*ident 3 :: Matrix Double | ||
34 | -- (3><3) | ||
35 | -- [ 7.0, 5.0, 5.0 | ||
36 | -- , 5.0, 7.0, 5.0 | ||
37 | -- , 5.0, 5.0, 7.0 ] | ||
38 | -- | ||
39 | -- >>> mat 3 [1..9] + mat 1 [10,20,30] | ||
40 | -- (3><3) | ||
41 | -- [ 11.0, 12.0, 13.0 | ||
42 | -- , 24.0, 25.0, 26.0 | ||
43 | -- , 37.0, 38.0, 39.0 ] | ||
44 | -- | ||
45 | 27 | ||
28 | module Numeric.HMatrix( | ||
29 | -- * Vector | ||
30 | R, | ||
31 | vec2, vec3, vec4, (&), (#), split, headTail, | ||
32 | vector, | ||
33 | linspace, range, dim, | ||
34 | -- * Matrix | ||
35 | L, Sq, build, | ||
36 | row, col, (¦),(——), splitRows, splitCols, | ||
37 | unrow, uncol, | ||
38 | |||
39 | eye, | ||
40 | diagR, diag, | ||
41 | blockAt, | ||
42 | matrix, | ||
46 | -- * Products | 43 | -- * Products |
47 | -- ** dot | 44 | (<>),(#>),(<·>), |
48 | (<·>), | ||
49 | -- ** matrix-vector | ||
50 | (#>), (!#>), | ||
51 | -- ** matrix-matrix | ||
52 | (<>), | ||
53 | -- | The matrix x matrix product is also implemented in the "Data.Monoid" instance, where | ||
54 | -- single-element matrices (created from numeric literals or using 'scalar') | ||
55 | -- are used for scaling. | ||
56 | -- | ||
57 | -- >>> import Data.Monoid as M | ||
58 | -- >>> let m = mat 3 [1..6] | ||
59 | -- >>> m M.<> 2 M.<> diagl[0.5,1,0] | ||
60 | -- (2><3) | ||
61 | -- [ 1.0, 4.0, 0.0 | ||
62 | -- , 4.0, 10.0, 0.0 ] | ||
63 | -- | ||
64 | -- 'mconcat' uses 'optimiseMult' to get the optimal association order. | ||
65 | |||
66 | |||
67 | -- ** other | ||
68 | outer, kronecker, cross, | ||
69 | scale, | ||
70 | sumElements, prodElements, | ||
71 | |||
72 | -- * Linear Systems | 45 | -- * Linear Systems |
73 | (<\>), | 46 | linSolve, (<\>), |
74 | linearSolve, | 47 | -- * Factorizations |
75 | linearSolveLS, | 48 | svd, svdTall, svdFlat, Eigen(..), |
76 | linearSolveSVD, | 49 | withNullspace, |
77 | luSolve, | 50 | -- * Misc |
78 | cholSolve, | 51 | Disp(..), |
79 | cgSolve, | 52 | withVector, withMatrix, |
80 | cgSolve', | 53 | toRows, toColumns, |
54 | Sized(..), Diag(..), Sym, sym, | ||
55 | module Numeric.LinearAlgebra.HMatrix | ||
56 | ) where | ||
81 | 57 | ||
82 | -- * Inverse and pseudoinverse | ||
83 | inv, pinv, pinvTol, | ||
84 | 58 | ||
85 | -- * Determinant and rank | 59 | import GHC.TypeLits |
86 | rcond, rank, | 60 | import Numeric.LinearAlgebra.HMatrix hiding ( |
87 | det, invlndet, | 61 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vector,matrix,linspace,toRows,toColumns, |
62 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build) | ||
63 | import qualified Numeric.LinearAlgebra.HMatrix as LA | ||
64 | import Data.Proxy(Proxy) | ||
65 | import Numeric.LinearAlgebra.Static | ||
66 | import Control.Arrow((***)) | ||
88 | 67 | ||
89 | -- * Norms | ||
90 | Normed(..), | ||
91 | norm_Frob, norm_nuclear, | ||
92 | 68 | ||
93 | -- * Nullspace and range | ||
94 | orth, | ||
95 | nullspace, null1, null1sym, | ||
96 | 69 | ||
97 | -- * SVD | ||
98 | svd, | ||
99 | fullSVD, | ||
100 | thinSVD, | ||
101 | compactSVD, | ||
102 | singularValues, | ||
103 | leftSV, rightSV, | ||
104 | 70 | ||
105 | -- * Eigensystems | ||
106 | eig, eigSH, eigSH', | ||
107 | eigenvalues, eigenvaluesSH, eigenvaluesSH', | ||
108 | geigSH', | ||
109 | 71 | ||
110 | -- * QR | 72 | ud1 :: R n -> Vector ℝ |
111 | qr, rq, qrRaw, qrgr, | 73 | ud1 (R (Dim v)) = v |
112 | 74 | ||
113 | -- * Cholesky | ||
114 | chol, cholSH, mbCholSH, | ||
115 | 75 | ||
116 | -- * Hessenberg | 76 | infixl 4 & |
117 | hess, | 77 | (&) :: forall n . KnownNat n |
78 | => R n -> ℝ -> R (n+1) | ||
79 | u & x = u # (konst x :: R 1) | ||
118 | 80 | ||
119 | -- * Schur | 81 | infixl 4 # |
120 | schur, | 82 | (#) :: forall n m . (KnownNat n, KnownNat m) |
83 | => R n -> R m -> R (n+m) | ||
84 | (R u) # (R v) = R (vconcat u v) | ||
121 | 85 | ||
122 | -- * LU | ||
123 | lu, luPacked, | ||
124 | 86 | ||
125 | -- * Matrix functions | ||
126 | expm, | ||
127 | sqrtm, | ||
128 | matFunc, | ||
129 | 87 | ||
130 | -- * Correlation and convolution | 88 | vec2 :: ℝ -> ℝ -> R 2 |
131 | corr, conv, corrMin, corr2, conv2, | 89 | vec2 a b = R (gvec2 a b) |
132 | 90 | ||
133 | -- * Random arrays | 91 | vec3 :: ℝ -> ℝ -> ℝ -> R 3 |
92 | vec3 a b c = R (gvec3 a b c) | ||
93 | |||
94 | |||
95 | vec4 :: ℝ -> ℝ -> ℝ -> ℝ -> R 4 | ||
96 | vec4 a b c d = R (gvec4 a b c d) | ||
97 | |||
98 | vector :: KnownNat n => [ℝ] -> R n | ||
99 | vector = fromList | ||
100 | |||
101 | matrix :: (KnownNat m, KnownNat n) => [ℝ] -> L m n | ||
102 | matrix = fromList | ||
103 | |||
104 | linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n | ||
105 | linspace (a,b) = mkR (LA.linspace d (a,b)) | ||
106 | where | ||
107 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
108 | |||
109 | range :: forall n . KnownNat n => R n | ||
110 | range = mkR (LA.linspace d (1,fromIntegral d)) | ||
111 | where | ||
112 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
113 | |||
114 | dim :: forall n . KnownNat n => R n | ||
115 | dim = mkR (scalar d) | ||
116 | where | ||
117 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
118 | |||
119 | |||
120 | -------------------------------------------------------------------------------- | ||
121 | |||
122 | |||
123 | ud2 :: L m n -> Matrix ℝ | ||
124 | ud2 (L (Dim (Dim x))) = x | ||
125 | |||
126 | |||
127 | -------------------------------------------------------------------------------- | ||
128 | -------------------------------------------------------------------------------- | ||
129 | |||
130 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n | ||
131 | diagR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) | ||
132 | where | ||
133 | ev = extract v | ||
134 | zeros = LA.konst x (max 0 ((min m' n') - size ev)) | ||
135 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
136 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
137 | |||
138 | diag :: KnownNat n => R n -> Sq n | ||
139 | diag = diagR 0 | ||
140 | |||
141 | eye :: KnownNat n => Sq n | ||
142 | eye = diag 1 | ||
143 | |||
144 | -------------------------------------------------------------------------------- | ||
145 | |||
146 | blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n | ||
147 | blockAt x r c a = mkL res | ||
148 | where | ||
149 | z = scalar x | ||
150 | z1 = LA.konst x (r,c) | ||
151 | z2 = LA.konst x (max 0 (m'-(ra+r)), max 0 (n'-(ca+c))) | ||
152 | ra = min (rows a) . max 0 $ m'-r | ||
153 | ca = min (cols a) . max 0 $ n'-c | ||
154 | sa = subMatrix (0,0) (ra, ca) a | ||
155 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
156 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
157 | res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] | ||
158 | |||
159 | |||
160 | |||
161 | |||
162 | |||
163 | -------------------------------------------------------------------------------- | ||
164 | |||
165 | |||
166 | row :: R n -> L 1 n | ||
167 | row = mkL . asRow . ud1 | ||
168 | |||
169 | --col :: R n -> L n 1 | ||
170 | col v = tr . row $ v | ||
171 | |||
172 | unrow :: L 1 n -> R n | ||
173 | unrow = mkR . head . LA.toRows . ud2 | ||
174 | |||
175 | --uncol :: L n 1 -> R n | ||
176 | uncol v = unrow . tr $ v | ||
177 | |||
178 | |||
179 | infixl 2 —— | ||
180 | (——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c | ||
181 | a —— b = mkL (extract a LA.—— extract b) | ||
182 | |||
183 | |||
184 | infixl 3 ¦ | ||
185 | -- (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) | ||
186 | a ¦ b = tr (tr a —— tr b) | ||
187 | |||
188 | |||
189 | type Sq n = L n n | ||
190 | --type CSq n = CL n n | ||
191 | |||
192 | type GL = (KnownNat n, KnownNat m) => L m n | ||
193 | type GSq = KnownNat n => Sq n | ||
194 | |||
195 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) | ||
196 | isKonst (unwrap -> x) | ||
197 | | singleM x = Just (x `atIndex` (0,0), (m',n')) | ||
198 | | otherwise = Nothing | ||
199 | where | ||
200 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
201 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
134 | 202 | ||
135 | Seed, RandDist(..), randomVector, rand, randn, gaussianSample, uniformSample, | ||
136 | 203 | ||
137 | -- * Misc | ||
138 | meanCov, peps, relativeError, haussholder, optimiseMult, udot, nullspaceSVD, orthSVD, ranksv, | ||
139 | ℝ,ℂ,iC, | ||
140 | -- * Auxiliary classes | ||
141 | Element, Container, Product, Numeric, LSDiv, | ||
142 | Complexable, RealElement, | ||
143 | RealOf, ComplexOf, SingleOf, DoubleOf, | ||
144 | IndexOf, | ||
145 | Field, | ||
146 | -- Normed, | ||
147 | Transposable, | ||
148 | CGState(..), | ||
149 | Testable(..) | ||
150 | ) where | ||
151 | 204 | ||
152 | import Numeric.LinearAlgebra.Data | ||
153 | |||
154 | import Numeric.Matrix() | ||
155 | import Numeric.Vector() | ||
156 | import Data.Packed.Numeric hiding ((<>)) | ||
157 | import Numeric.LinearAlgebra.Algorithms hiding (linearSolve,Normed,orth) | ||
158 | import qualified Numeric.LinearAlgebra.Algorithms as A | ||
159 | import Numeric.LinearAlgebra.Util | ||
160 | import Numeric.LinearAlgebra.Random | ||
161 | import Numeric.Sparse((!#>)) | ||
162 | import Numeric.LinearAlgebra.Util.CG | ||
163 | |||
164 | {- | dense matrix product | ||
165 | |||
166 | >>> let a = (3><5) [1..] | ||
167 | >>> a | ||
168 | (3><5) | ||
169 | [ 1.0, 2.0, 3.0, 4.0, 5.0 | ||
170 | , 6.0, 7.0, 8.0, 9.0, 10.0 | ||
171 | , 11.0, 12.0, 13.0, 14.0, 15.0 ] | ||
172 | |||
173 | >>> let b = (5><2) [1,3, 0,2, -1,5, 7,7, 6,0] | ||
174 | >>> b | ||
175 | (5><2) | ||
176 | [ 1.0, 3.0 | ||
177 | , 0.0, 2.0 | ||
178 | , -1.0, 5.0 | ||
179 | , 7.0, 7.0 | ||
180 | , 6.0, 0.0 ] | ||
181 | |||
182 | >>> a <> b | ||
183 | (3><2) | ||
184 | [ 56.0, 50.0 | ||
185 | , 121.0, 135.0 | ||
186 | , 186.0, 220.0 ] | ||
187 | 205 | ||
188 | -} | ||
189 | (<>) :: Numeric t => Matrix t -> Matrix t -> Matrix t | ||
190 | (<>) = mXm | ||
191 | infixr 8 <> | 206 | infixr 8 <> |
207 | (<>) :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => L m k -> L k n -> L m n | ||
208 | |||
209 | (isKonst -> Just (a,(_,k))) <> (isKonst -> Just (b,_)) = konst (a * b * fromIntegral k) | ||
210 | |||
211 | (isDiag -> Just (0,a,_)) <> (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) | ||
212 | where | ||
213 | v = a' * b' | ||
214 | n = min (size a) (size b) | ||
215 | a' = subVector 0 n a | ||
216 | b' = subVector 0 n b | ||
217 | |||
218 | (isDiag -> Just (0,a,_)) <> (extract -> b) = mkL (asColumn a * takeRows (size a) b) | ||
219 | |||
220 | (extract -> a) <> (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) | ||
221 | |||
222 | a <> b = mkL (extract a LA.<> extract b) | ||
223 | |||
224 | infixr 8 #> | ||
225 | (#>) :: (KnownNat m, KnownNat n) => L m n -> R n -> R m | ||
226 | (isDiag -> Just (0, w, _)) #> v = mkR (w * subVector 0 (size w) (extract v)) | ||
227 | m #> v = mkR (extract m LA.#> extract v) | ||
228 | |||
229 | |||
230 | infixr 8 <·> | ||
231 | (<·>) :: R n -> R n -> ℝ | ||
232 | (ud1 -> u) <·> (ud1 -> v) | ||
233 | | singleV u || singleV v = sumElements (u * v) | ||
234 | | otherwise = udot u v | ||
235 | |||
236 | -------------------------------------------------------------------------------- | ||
237 | |||
238 | {- | ||
239 | class Minim (n :: Nat) (m :: Nat) | ||
240 | where | ||
241 | type Mini n m :: Nat | ||
242 | |||
243 | instance forall (n :: Nat) . Minim n n | ||
244 | where | ||
245 | type Mini n n = n | ||
246 | |||
247 | |||
248 | instance forall (n :: Nat) (m :: Nat) . (n <= m+1) => Minim n m | ||
249 | where | ||
250 | type Mini n m = n | ||
251 | |||
252 | instance forall (n :: Nat) (m :: Nat) . (m <= n+1) => Minim n m | ||
253 | where | ||
254 | type Mini n m = m | ||
255 | -} | ||
256 | |||
257 | class Diag m d | m -> d | ||
258 | where | ||
259 | takeDiag :: m -> d | ||
260 | |||
261 | |||
262 | |||
263 | instance forall n . (KnownNat n) => Diag (L n n) (R n) | ||
264 | where | ||
265 | takeDiag m = mkR (LA.takeDiag (extract m)) | ||
266 | |||
267 | |||
268 | instance forall m n . (KnownNat m, KnownNat n, m <= n+1) => Diag (L m n) (R m) | ||
269 | where | ||
270 | takeDiag m = mkR (LA.takeDiag (extract m)) | ||
271 | |||
272 | |||
273 | instance forall m n . (KnownNat m, KnownNat n, n <= m+1) => Diag (L m n) (R n) | ||
274 | where | ||
275 | takeDiag m = mkR (LA.takeDiag (extract m)) | ||
276 | |||
277 | |||
278 | -------------------------------------------------------------------------------- | ||
279 | |||
280 | linSolve :: (KnownNat m, KnownNat n) => L m m -> L m n -> Maybe (L m n) | ||
281 | linSolve (extract -> a) (extract -> b) = fmap mkL (LA.linearSolve a b) | ||
282 | |||
283 | (<\>) :: (KnownNat m, KnownNat n, KnownNat r) => L m n -> L m r -> L n r | ||
284 | (extract -> a) <\> (extract -> b) = mkL (a LA.<\> b) | ||
285 | |||
286 | svd :: (KnownNat m, KnownNat n) => L m n -> (L m m, R n, L n n) | ||
287 | svd (extract -> m) = (mkL u, mkR s', mkL v) | ||
288 | where | ||
289 | (u,s,v) = LA.svd m | ||
290 | s' = vjoin [s, z] | ||
291 | z = LA.konst 0 (max 0 (cols m - size s)) | ||
292 | |||
293 | |||
294 | svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) | ||
295 | svdTall (extract -> m) = (mkL u, mkR s, mkL v) | ||
296 | where | ||
297 | (u,s,v) = LA.thinSVD m | ||
298 | |||
299 | |||
300 | svdFlat :: (KnownNat m, KnownNat n, m <= n) => L m n -> (L m m, R m, L n m) | ||
301 | svdFlat (extract -> m) = (mkL u, mkR s, mkL v) | ||
302 | where | ||
303 | (u,s,v) = LA.thinSVD m | ||
304 | |||
305 | -------------------------------------------------------------------------------- | ||
306 | |||
307 | class Eigen m l v | m -> l, m -> v | ||
308 | where | ||
309 | eigensystem :: m -> (l,v) | ||
310 | eigenvalues :: m -> l | ||
311 | |||
312 | newtype Sym n = Sym (Sq n) deriving Show | ||
313 | |||
314 | |||
315 | sym :: KnownNat n => Sq n -> Sym n | ||
316 | sym m = Sym $ (m + tr m)/2 | ||
317 | |||
318 | |||
319 | |||
320 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) | ||
321 | where | ||
322 | eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m | ||
323 | eigensystem (Sym (extract -> m)) = (mkR l, mkL v) | ||
324 | where | ||
325 | (l,v) = LA.eigSH' m | ||
326 | |||
327 | instance KnownNat n => Eigen (Sq n) (C n) (M n n) | ||
328 | where | ||
329 | eigenvalues (extract -> m) = mkC . LA.eigenvalues $ m | ||
330 | eigensystem (extract -> m) = (mkC l, mkM v) | ||
331 | where | ||
332 | (l,v) = LA.eig m | ||
333 | |||
334 | -------------------------------------------------------------------------------- | ||
335 | |||
336 | |||
337 | withNullspace | ||
338 | :: forall m n z . (KnownNat m, KnownNat n) | ||
339 | => L m n | ||
340 | -> (forall k . (KnownNat k) => L n k -> z) | ||
341 | -> z | ||
342 | withNullspace (LA.nullspace . extract -> a) f = | ||
343 | case someNatVal $ fromIntegral $ cols a of | ||
344 | Nothing -> error "static/dynamic mismatch" | ||
345 | Just (SomeNat (_ :: Proxy k)) -> f (mkL a :: L n k) | ||
346 | |||
347 | -------------------------------------------------------------------------------- | ||
348 | |||
349 | split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) | ||
350 | split (extract -> v) = ( mkR (subVector 0 p' v) , | ||
351 | mkR (subVector p' (size v - p') v) ) | ||
352 | where | ||
353 | p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int | ||
354 | |||
355 | |||
356 | headTail :: (KnownNat n, 1<=n) => R n -> (ℝ, R (n-1)) | ||
357 | headTail = ((!0) . extract *** id) . split | ||
358 | |||
359 | |||
360 | splitRows :: forall p m n . (KnownNat p, KnownNat m, KnownNat n, p<=m) => L m n -> (L p n, L (m-p) n) | ||
361 | splitRows (extract -> x) = ( mkL (takeRows p' x) , | ||
362 | mkL (dropRows p' x) ) | ||
363 | where | ||
364 | p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int | ||
365 | |||
366 | splitCols :: forall p m n. (KnownNat p, KnownNat m, KnownNat n, KnownNat (n-p), p<=n) => L m n -> (L m p, L m (n-p)) | ||
367 | splitCols = (tr *** tr) . splitRows . tr | ||
368 | |||
369 | |||
370 | toRows :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R n] | ||
371 | toRows (LA.toRows . extract -> vs) = map mkR vs | ||
372 | |||
373 | |||
374 | toColumns :: forall m n . (KnownNat m, KnownNat n) => L m n -> [R m] | ||
375 | toColumns (LA.toColumns . extract -> vs) = map mkR vs | ||
376 | |||
377 | |||
378 | splittest | ||
379 | = do | ||
380 | let v = range :: R 7 | ||
381 | a = snd (split v) :: R 4 | ||
382 | print $ a | ||
383 | print $ snd . headTail . snd . headTail $ v | ||
384 | print $ first (vec3 1 2 3) | ||
385 | print $ second (vec3 1 2 3) | ||
386 | print $ third (vec3 1 2 3) | ||
387 | print $ (snd $ splitRows eye :: L 4 6) | ||
388 | where | ||
389 | first v = fst . headTail $ v | ||
390 | second v = first . snd . headTail $ v | ||
391 | third v = first . snd . headTail . snd . headTail $ v | ||
392 | |||
393 | -------------------------------------------------------------------------------- | ||
394 | |||
395 | build | ||
396 | :: forall m n . (KnownNat n, KnownNat m) | ||
397 | => (ℝ -> ℝ -> ℝ) | ||
398 | -> L m n | ||
399 | build f = mkL $ LA.build (m',n') f | ||
400 | where | ||
401 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
402 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
192 | 403 | ||
193 | -- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, returning Nothing for a singular system. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'. | 404 | -------------------------------------------------------------------------------- |
194 | linearSolve m b = A.mbLinearSolve m b | ||
195 | 405 | ||
196 | -- | return an orthonormal basis of the null space of a matrix. See also 'nullspaceSVD'. | 406 | withVector |
197 | nullspace m = nullspaceSVD (Left (1*eps)) m (rightSV m) | 407 | :: forall z |
408 | . Vector ℝ | ||
409 | -> (forall n . (KnownNat n) => R n -> z) | ||
410 | -> z | ||
411 | withVector v f = | ||
412 | case someNatVal $ fromIntegral $ size v of | ||
413 | Nothing -> error "static/dynamic mismatch" | ||
414 | Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) | ||
415 | |||
416 | |||
417 | withMatrix | ||
418 | :: forall z | ||
419 | . Matrix ℝ | ||
420 | -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z) | ||
421 | -> z | ||
422 | withMatrix a f = | ||
423 | case someNatVal $ fromIntegral $ rows a of | ||
424 | Nothing -> error "static/dynamic mismatch" | ||
425 | Just (SomeNat (_ :: Proxy m)) -> | ||
426 | case someNatVal $ fromIntegral $ cols a of | ||
427 | Nothing -> error "static/dynamic mismatch" | ||
428 | Just (SomeNat (_ :: Proxy n)) -> | ||
429 | f (mkL a :: L m n) | ||
430 | |||
431 | |||
432 | -------------------------------------------------------------------------------- | ||
433 | |||
434 | test :: (Bool, IO ()) | ||
435 | test = (ok,info) | ||
436 | where | ||
437 | ok = extract (eye :: Sq 5) == ident 5 | ||
438 | && unwrap (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] | ||
439 | && unwrap (tm :: L 3 5) == LA.matrix 5 [1..15] | ||
440 | && thingS == thingD | ||
441 | && precS == precD | ||
442 | && withVector (LA.vector [1..15]) sumV == sumElements (LA.fromList [1..15]) | ||
443 | |||
444 | info = do | ||
445 | print $ u | ||
446 | print $ v | ||
447 | print (eye :: Sq 3) | ||
448 | print $ ((u & 5) + 1) <·> v | ||
449 | print (tm :: L 2 5) | ||
450 | print (tm <> sm :: L 2 3) | ||
451 | print thingS | ||
452 | print thingD | ||
453 | print precS | ||
454 | print precD | ||
455 | print $ withVector (LA.vector [1..15]) sumV | ||
456 | splittest | ||
457 | |||
458 | sumV w = w <·> konst 1 | ||
459 | |||
460 | u = vec2 3 5 | ||
461 | |||
462 | 𝕧 x = vector [x] :: R 1 | ||
463 | |||
464 | v = 𝕧 2 & 4 & 7 | ||
465 | |||
466 | -- mTm :: L n m -> Sq m | ||
467 | mTm a = tr a <> a | ||
468 | |||
469 | tm :: GL | ||
470 | tm = lmat 0 [1..] | ||
471 | |||
472 | lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n | ||
473 | lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z | ||
474 | where | ||
475 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
476 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
477 | |||
478 | sm :: GSq | ||
479 | sm = lmat 0 [1..] | ||
480 | |||
481 | thingS = (u & 1) <·> tr q #> q #> v | ||
482 | where | ||
483 | q = tm :: L 10 3 | ||
484 | |||
485 | thingD = vjoin [ud1 u, 1] LA.<·> tr m LA.#> m LA.#> ud1 v | ||
486 | where | ||
487 | m = LA.matrix 3 [1..30] | ||
488 | |||
489 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v | ||
490 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v | ||
491 | |||
492 | |||
493 | instance (KnownNat n', KnownNat m') => Testable (L n' m') | ||
494 | where | ||
495 | checkT _ = test | ||
198 | 496 | ||
199 | -- | return an orthonormal basis of the range space of a matrix. See also 'orthSVD'. | ||
200 | orth m = orthSVD (Left (1*eps)) m (leftSV m) | ||
201 | 497 | ||