diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static.hs')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 84 |
1 files changed, 31 insertions, 53 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 3398e6a..843c727 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -1,5 +1,3 @@ | |||
1 | #if __GLASGOW_HASKELL__ >= 708 | ||
2 | |||
3 | {-# LANGUAGE DataKinds #-} | 1 | {-# LANGUAGE DataKinds #-} |
4 | {-# LANGUAGE KindSignatures #-} | 2 | {-# LANGUAGE KindSignatures #-} |
5 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | 3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} |
@@ -13,7 +11,6 @@ | |||
13 | {-# LANGUAGE TypeOperators #-} | 11 | {-# LANGUAGE TypeOperators #-} |
14 | {-# LANGUAGE ViewPatterns #-} | 12 | {-# LANGUAGE ViewPatterns #-} |
15 | {-# LANGUAGE GADTs #-} | 13 | {-# LANGUAGE GADTs #-} |
16 | {-# LANGUAGE OverlappingInstances #-} | ||
17 | {-# LANGUAGE TypeFamilies #-} | 14 | {-# LANGUAGE TypeFamilies #-} |
18 | 15 | ||
19 | 16 | ||
@@ -25,19 +22,19 @@ Stability : experimental | |||
25 | 22 | ||
26 | Experimental interface with statically checked dimensions. | 23 | Experimental interface with statically checked dimensions. |
27 | 24 | ||
28 | This module is under active development and the interface is subject to changes. | 25 | See code examples at http://dis.um.es/~alberto/hmatrix/static.html. |
29 | 26 | ||
30 | -} | 27 | -} |
31 | 28 | ||
32 | module Numeric.LinearAlgebra.Static( | 29 | module Numeric.LinearAlgebra.Static( |
33 | -- * Vector | 30 | -- * Vector |
34 | ℝ, R, | 31 | ℝ, R, |
35 | vec2, vec3, vec4, (&), (#), split, headTail, | 32 | vec2, vec3, vec4, (&), (#), split, headTail, |
36 | vector, | 33 | vector, |
37 | linspace, range, dim, | 34 | linspace, range, dim, |
38 | -- * Matrix | 35 | -- * Matrix |
39 | L, Sq, build, | 36 | L, Sq, build, |
40 | row, col, (¦),(——), splitRows, splitCols, | 37 | row, col, (|||),(===), splitRows, splitCols, |
41 | unrow, uncol, | 38 | unrow, uncol, |
42 | tr, | 39 | tr, |
43 | eye, | 40 | eye, |
@@ -47,7 +44,7 @@ module Numeric.LinearAlgebra.Static( | |||
47 | -- * Complex | 44 | -- * Complex |
48 | C, M, Her, her, 𝑖, | 45 | C, M, Her, her, 𝑖, |
49 | -- * Products | 46 | -- * Products |
50 | (<>),(#>),(<·>), | 47 | (<>),(#>),(<.>), |
51 | -- * Linear Systems | 48 | -- * Linear Systems |
52 | linSolve, (<\>), | 49 | linSolve, (<\>), |
53 | -- * Factorizations | 50 | -- * Factorizations |
@@ -58,26 +55,22 @@ module Numeric.LinearAlgebra.Static( | |||
58 | Disp(..), Domain(..), | 55 | Disp(..), Domain(..), |
59 | withVector, withMatrix, | 56 | withVector, withMatrix, |
60 | toRows, toColumns, | 57 | toRows, toColumns, |
61 | Sized(..), Diag(..), Sym, sym, mTm, unSym | 58 | Sized(..), Diag(..), Sym, sym, mTm, unSym, (<·>) |
62 | ) where | 59 | ) where |
63 | 60 | ||
64 | 61 | ||
65 | import GHC.TypeLits | 62 | import GHC.TypeLits |
66 | import Numeric.LinearAlgebra.HMatrix hiding ( | 63 | import Numeric.LinearAlgebra hiding ( |
67 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——), | 64 | (<>),(#>),(<.>),Konst(..),diag, disp,(===),(|||), |
68 | row,col,vector,matrix,linspace,toRows,toColumns, | 65 | row,col,vector,matrix,linspace,toRows,toColumns, |
69 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH', | 66 | (<\>),fromList,takeDiag,svd,eig,eigSH, |
70 | eigenvalues,eigenvaluesSH,eigenvaluesSH',build, | 67 | eigenvalues,eigenvaluesSH,build, |
71 | qr,size,app,mul,dot,chol) | 68 | qr,size,dot,chol,range,R,C,sym,mTm,unSym) |
72 | import qualified Numeric.LinearAlgebra.HMatrix as LA | 69 | import qualified Numeric.LinearAlgebra as LA |
73 | import Data.Proxy(Proxy) | 70 | import Data.Proxy(Proxy) |
74 | import Numeric.LinearAlgebra.Static.Internal | 71 | import Internal.Static |
75 | import Control.Arrow((***)) | 72 | import Control.Arrow((***)) |
76 | 73 | ||
77 | |||
78 | |||
79 | |||
80 | |||
81 | ud1 :: R n -> Vector ℝ | 74 | ud1 :: R n -> Vector ℝ |
82 | ud1 (R (Dim v)) = v | 75 | ud1 (R (Dim v)) = v |
83 | 76 | ||
@@ -171,21 +164,22 @@ unrow = mkR . head . LA.toRows . ud2 | |||
171 | uncol v = unrow . tr $ v | 164 | uncol v = unrow . tr $ v |
172 | 165 | ||
173 | 166 | ||
174 | infixl 2 —— | 167 | infixl 2 === |
175 | (——) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c | 168 | (===) :: (KnownNat r1, KnownNat r2, KnownNat c) => L r1 c -> L r2 c -> L (r1+r2) c |
176 | a —— b = mkL (extract a LA.—— extract b) | 169 | a === b = mkL (extract a LA.=== extract b) |
177 | 170 | ||
178 | 171 | ||
179 | infixl 3 ¦ | 172 | infixl 3 ||| |
180 | -- (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) | 173 | -- (|||) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) |
181 | a ¦ b = tr (tr a —— tr b) | 174 | a ||| b = tr (tr a === tr b) |
182 | 175 | ||
183 | 176 | ||
184 | type Sq n = L n n | 177 | type Sq n = L n n |
185 | --type CSq n = CL n n | 178 | --type CSq n = CL n n |
186 | 179 | ||
187 | type GL = forall n m. (KnownNat n, KnownNat m) => L m n | 180 | |
188 | type GSq = forall n. KnownNat n => Sq n | 181 | type GL = forall n m . (KnownNat n, KnownNat m) => L m n |
182 | type GSq = forall n . KnownNat n => Sq n | ||
189 | 183 | ||
190 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) | 184 | isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) |
191 | isKonst s@(unwrap -> x) | 185 | isKonst s@(unwrap -> x) |
@@ -213,6 +207,10 @@ infixr 8 <·> | |||
213 | (<·>) :: R n -> R n -> ℝ | 207 | (<·>) :: R n -> R n -> ℝ |
214 | (<·>) = dotR | 208 | (<·>) = dotR |
215 | 209 | ||
210 | infixr 8 <.> | ||
211 | (<.>) :: R n -> R n -> ℝ | ||
212 | (<.>) = dotR | ||
213 | |||
216 | -------------------------------------------------------------------------------- | 214 | -------------------------------------------------------------------------------- |
217 | 215 | ||
218 | class Diag m d | m -> d | 216 | class Diag m d | m -> d |
@@ -294,10 +292,10 @@ her m = Her $ (m + LA.tr m)/2 | |||
294 | 292 | ||
295 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) | 293 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) |
296 | where | 294 | where |
297 | eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m | 295 | eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH . LA.trustSym $ m |
298 | eigensystem (Sym (extract -> m)) = (mkR l, mkL v) | 296 | eigensystem (Sym (extract -> m)) = (mkR l, mkL v) |
299 | where | 297 | where |
300 | (l,v) = LA.eigSH' m | 298 | (l,v) = LA.eigSH . LA.trustSym $ m |
301 | 299 | ||
302 | instance KnownNat n => Eigen (Sq n) (C n) (M n n) | 300 | instance KnownNat n => Eigen (Sq n) (C n) (M n n) |
303 | where | 301 | where |
@@ -307,7 +305,7 @@ instance KnownNat n => Eigen (Sq n) (C n) (M n n) | |||
307 | (l,v) = LA.eig m | 305 | (l,v) = LA.eig m |
308 | 306 | ||
309 | chol :: KnownNat n => Sym n -> Sq n | 307 | chol :: KnownNat n => Sym n -> Sq n |
310 | chol (extract . unSym -> m) = mkL $ LA.cholSH m | 308 | chol (extract . unSym -> m) = mkL $ LA.chol $ LA.trustSym m |
311 | 309 | ||
312 | -------------------------------------------------------------------------------- | 310 | -------------------------------------------------------------------------------- |
313 | 311 | ||
@@ -502,7 +500,7 @@ appC m v = mkC (extract m LA.#> extract v) | |||
502 | dotC :: KnownNat n => C n -> C n -> ℂ | 500 | dotC :: KnownNat n => C n -> C n -> ℂ |
503 | dotC (unwrap -> u) (unwrap -> v) | 501 | dotC (unwrap -> u) (unwrap -> v) |
504 | | singleV u || singleV v = sumElements (conj u * v) | 502 | | singleV u || singleV v = sumElements (conj u * v) |
505 | | otherwise = u LA.<·> v | 503 | | otherwise = u LA.<.> v |
506 | 504 | ||
507 | 505 | ||
508 | crossC :: C 3 -> C 3 -> C 3 | 506 | crossC :: C 3 -> C 3 -> C 3 |
@@ -590,12 +588,12 @@ test = (ok,info) | |||
590 | where | 588 | where |
591 | q = tm :: L 10 3 | 589 | q = tm :: L 10 3 |
592 | 590 | ||
593 | thingD = vjoin [ud1 u, 1] LA.<·> tr m LA.#> m LA.#> ud1 v | 591 | thingD = vjoin [ud1 u, 1] LA.<.> tr m LA.#> m LA.#> ud1 v |
594 | where | 592 | where |
595 | m = LA.matrix 3 [1..30] | 593 | m = LA.matrix 3 [1..30] |
596 | 594 | ||
597 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v | 595 | precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v |
598 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (LA.size (ud1 u) +1, LA.size (ud1 v)) LA.#> ud1 v | 596 | precD = 1 + 2 * vjoin[ud1 u, 6] LA.<.> LA.konst 2 (LA.size (ud1 u) +1, LA.size (ud1 v)) LA.#> ud1 v |
599 | 597 | ||
600 | 598 | ||
601 | splittest | 599 | splittest |
@@ -618,23 +616,3 @@ instance (KnownNat n', KnownNat m') => Testable (L n' m') | |||
618 | where | 616 | where |
619 | checkT _ = test | 617 | checkT _ = test |
620 | 618 | ||
621 | #else | ||
622 | |||
623 | {- | | ||
624 | Module : Numeric.LinearAlgebra.Static | ||
625 | Copyright : (c) Alberto Ruiz 2014 | ||
626 | License : BSD3 | ||
627 | Stability : experimental | ||
628 | |||
629 | Experimental interface with statically checked dimensions. | ||
630 | |||
631 | This module requires GHC >= 7.8 | ||
632 | |||
633 | -} | ||
634 | |||
635 | module Numeric.LinearAlgebra.Static | ||
636 | {-# WARNING "This module requires GHC >= 7.8" #-} | ||
637 | where | ||
638 | |||
639 | #endif | ||
640 | |||