diff options
Diffstat (limited to 'packages/base/src/Numeric')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Real.hs | 105 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 6 |
2 files changed, 84 insertions, 27 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index 2ff69c7..d03ca6e 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs | |||
@@ -45,7 +45,7 @@ module Numeric.LinearAlgebra.Real( | |||
45 | -- * Linear Systems | 45 | -- * Linear Systems |
46 | linSolve, (<\>), | 46 | linSolve, (<\>), |
47 | -- * Factorizations | 47 | -- * Factorizations |
48 | svd, svdTall, svdFlat, eig, | 48 | svd, svdTall, svdFlat, Eigen(..), |
49 | -- * Pretty printing | 49 | -- * Pretty printing |
50 | Disp(..), | 50 | Disp(..), |
51 | -- * Misc | 51 | -- * Misc |
@@ -58,8 +58,9 @@ module Numeric.LinearAlgebra.Real( | |||
58 | import GHC.TypeLits | 58 | import GHC.TypeLits |
59 | import Numeric.HMatrix hiding ( | 59 | import Numeric.HMatrix hiding ( |
60 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace, | 60 | (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vect,mat,linspace, |
61 | (<\>),fromList,takeDiag,svd,eig) | 61 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH') |
62 | import qualified Numeric.HMatrix as LA | 62 | import qualified Numeric.HMatrix as LA |
63 | import Data.Packed.Internal(mbCatch) | ||
63 | import Data.Proxy(Proxy) | 64 | import Data.Proxy(Proxy) |
64 | import Numeric.LinearAlgebra.Static | 65 | import Numeric.LinearAlgebra.Static |
65 | import Text.Printf | 66 | import Text.Printf |
@@ -80,6 +81,8 @@ ud1 (R (Dim v)) = v | |||
80 | mkR :: Vector ℝ -> R n | 81 | mkR :: Vector ℝ -> R n |
81 | mkR = R . Dim | 82 | mkR = R . Dim |
82 | 83 | ||
84 | mkC :: Vector ℂ -> C n | ||
85 | mkC = C . Dim | ||
83 | 86 | ||
84 | infixl 4 & | 87 | infixl 4 & |
85 | (&) :: forall n . KnownNat n | 88 | (&) :: forall n . KnownNat n |
@@ -126,17 +129,17 @@ dim = mkR (scalar d) | |||
126 | 129 | ||
127 | newtype L m n = L (Dim m (Dim n (Matrix ℝ))) | 130 | newtype L m n = L (Dim m (Dim n (Matrix ℝ))) |
128 | 131 | ||
129 | -- newtype CL m n = CL (Dim m (Dim n (Matrix ℂ))) | 132 | newtype M m n = M (Dim m (Dim n (Matrix ℂ))) |
130 | 133 | ||
131 | ud2 :: L m n -> Matrix ℝ | 134 | ud2 :: L m n -> Matrix ℝ |
132 | ud2 (L (Dim (Dim x))) = x | 135 | ud2 (L (Dim (Dim x))) = x |
133 | 136 | ||
134 | 137 | ||
135 | |||
136 | |||
137 | mkL :: Matrix ℝ -> L m n | 138 | mkL :: Matrix ℝ -> L m n |
138 | mkL x = L (Dim (Dim x)) | 139 | mkL x = L (Dim (Dim x)) |
139 | 140 | ||
141 | mkM :: Matrix ℂ -> M m n | ||
142 | mkM x = M (Dim (Dim x)) | ||
140 | 143 | ||
141 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) | 144 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) |
142 | where | 145 | where |
@@ -150,6 +153,18 @@ instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) | |||
150 | 153 | ||
151 | -------------------------------------------------------------------------------- | 154 | -------------------------------------------------------------------------------- |
152 | 155 | ||
156 | instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ) | ||
157 | where | ||
158 | konst x = mkC (LA.scalar x) | ||
159 | unwrap (C (Dim v)) = v | ||
160 | fromList xs = C (gvect "C" xs) | ||
161 | extract (unwrap -> v) | ||
162 | | singleV v = LA.konst (v!0) d | ||
163 | | otherwise = v | ||
164 | where | ||
165 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
166 | |||
167 | |||
153 | instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) | 168 | instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) |
154 | where | 169 | where |
155 | konst x = mkR (LA.scalar x) | 170 | konst x = mkR (LA.scalar x) |
@@ -162,11 +177,12 @@ instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) | |||
162 | d = fromIntegral . natVal $ (undefined :: Proxy n) | 177 | d = fromIntegral . natVal $ (undefined :: Proxy n) |
163 | 178 | ||
164 | 179 | ||
180 | |||
165 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) | 181 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) |
166 | where | 182 | where |
167 | konst x = mkL (LA.scalar x) | 183 | konst x = mkL (LA.scalar x) |
168 | unwrap = ud2 | ||
169 | fromList = mat | 184 | fromList = mat |
185 | unwrap = ud2 | ||
170 | extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' | 186 | extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' |
171 | extract (unwrap -> a) | 187 | extract (unwrap -> a) |
172 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | 188 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') |
@@ -175,6 +191,20 @@ instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) | |||
175 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | 191 | m' = fromIntegral . natVal $ (undefined :: Proxy m) |
176 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | 192 | n' = fromIntegral . natVal $ (undefined :: Proxy n) |
177 | 193 | ||
194 | |||
195 | instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ) | ||
196 | where | ||
197 | konst x = mkM (LA.scalar x) | ||
198 | fromList xs = M (gmat "M" xs) | ||
199 | unwrap (M (Dim (Dim m))) = m | ||
200 | extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' | ||
201 | extract (unwrap -> a) | ||
202 | | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') | ||
203 | | otherwise = a | ||
204 | where | ||
205 | m' = fromIntegral . natVal $ (undefined :: Proxy m) | ||
206 | n' = fromIntegral . natVal $ (undefined :: Proxy n) | ||
207 | |||
178 | -------------------------------------------------------------------------------- | 208 | -------------------------------------------------------------------------------- |
179 | 209 | ||
180 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n | 210 | diagR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n |
@@ -225,26 +255,41 @@ instance (KnownNat m, KnownNat n) => Disp (L m n) | |||
225 | let su = LA.dispf n a | 255 | let su = LA.dispf n a |
226 | printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) | 256 | printf "L %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) |
227 | 257 | ||
258 | instance (KnownNat m, KnownNat n) => Disp (M m n) | ||
259 | where | ||
260 | disp n x = do | ||
261 | let a = extract x | ||
262 | let su = LA.dispcf n a | ||
263 | printf "M %d %d" (rows a) (cols a) >> putStr (dropWhile (/='\n') $ su) | ||
264 | |||
265 | |||
228 | instance KnownNat n => Disp (R n) | 266 | instance KnownNat n => Disp (R n) |
229 | where | 267 | where |
230 | disp n v = do | 268 | disp n v = do |
231 | let su = LA.dispf n (asRow $ extract v) | 269 | let su = LA.dispf n (asRow $ extract v) |
232 | putStr "R " >> putStr (tail . dropWhile (/='x') $ su) | 270 | putStr "R " >> putStr (tail . dropWhile (/='x') $ su) |
233 | 271 | ||
272 | instance KnownNat n => Disp (C n) | ||
273 | where | ||
274 | disp n v = do | ||
275 | let su = LA.dispcf n (asRow $ extract v) | ||
276 | putStr "C " >> putStr (tail . dropWhile (/='x') $ su) | ||
277 | |||
278 | |||
234 | -------------------------------------------------------------------------------- | 279 | -------------------------------------------------------------------------------- |
235 | 280 | ||
236 | 281 | ||
237 | row :: R n -> L 1 n | 282 | row :: R n -> L 1 n |
238 | row = mkL . asRow . ud1 | 283 | row = mkL . asRow . ud1 |
239 | 284 | ||
240 | col :: R n -> L n 1 | 285 | --col :: R n -> L n 1 |
241 | col = tr . row | 286 | col v = tr . row $ v |
242 | 287 | ||
243 | unrow :: L 1 n -> R n | 288 | unrow :: L 1 n -> R n |
244 | unrow = mkR . head . toRows . ud2 | 289 | unrow = mkR . head . toRows . ud2 |
245 | 290 | ||
246 | uncol :: L n 1 -> R n | 291 | --uncol :: L n 1 -> R n |
247 | uncol = unrow . tr | 292 | uncol v = unrow . tr $ v |
248 | 293 | ||
249 | 294 | ||
250 | infixl 2 —— | 295 | infixl 2 —— |
@@ -253,7 +298,7 @@ a —— b = mkL (extract a LA.—— extract b) | |||
253 | 298 | ||
254 | 299 | ||
255 | infixl 3 ¦ | 300 | infixl 3 ¦ |
256 | (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) | 301 | -- (¦) :: (KnownNat r, KnownNat c1, KnownNat c2) => L r c1 -> L r c2 -> L r (c1+c2) |
257 | a ¦ b = tr (tr a —— tr b) | 302 | a ¦ b = tr (tr a —— tr b) |
258 | 303 | ||
259 | 304 | ||
@@ -274,7 +319,14 @@ isKonst (unwrap -> x) | |||
274 | 319 | ||
275 | 320 | ||
276 | isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) | 321 | isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ, Vector ℝ, (Int,Int)) |
277 | isDiag (unwrap -> x) | 322 | isDiag (L x) = isDiagg x |
323 | |||
324 | isDiagC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ, Vector ℂ, (Int,Int)) | ||
325 | isDiagC (M x) = isDiagg x | ||
326 | |||
327 | |||
328 | isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> Maybe (t, Vector t, (Int,Int)) | ||
329 | isDiagg (Dim (Dim x)) | ||
278 | | singleM x = Nothing | 330 | | singleM x = Nothing |
279 | | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) | 331 | | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) |
280 | | otherwise = Nothing | 332 | | otherwise = Nothing |
@@ -282,7 +334,7 @@ isDiag (unwrap -> x) | |||
282 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 334 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int |
283 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 335 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
284 | v = flatten x | 336 | v = flatten x |
285 | z = v!0 | 337 | z = v `atIndex` 0 |
286 | y = subVector 1 (size v-1) v | 338 | y = subVector 1 (size v-1) v |
287 | ny = size y | 339 | ny = size y |
288 | zeros = LA.konst 0 (max 0 (min m' n' - ny)) | 340 | zeros = LA.konst 0 (max 0 (min m' n' - ny)) |
@@ -320,9 +372,10 @@ infixr 8 <·> | |||
320 | | otherwise = udot u v | 372 | | otherwise = udot u v |
321 | 373 | ||
322 | 374 | ||
323 | instance Transposable (L m n) (L n m) | 375 | instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m) |
324 | where | 376 | where |
325 | tr (ud2 -> a) = mkL (tr a) | 377 | tr a@(isDiag -> Just _) = mkL (extract a) |
378 | tr (extract -> a) = mkL (tr a) | ||
326 | 379 | ||
327 | -------------------------------------------------------------------------------- | 380 | -------------------------------------------------------------------------------- |
328 | 381 | ||
@@ -424,11 +477,12 @@ svdFlat (extract -> m) = (mkL u, mkR s, mkL v) | |||
424 | 477 | ||
425 | -------------------------------------------------------------------------------- | 478 | -------------------------------------------------------------------------------- |
426 | 479 | ||
427 | class Eig m r | m -> r | 480 | class Eigen m l v | m -> l, m -> v |
428 | where | 481 | where |
429 | eig :: m -> r | 482 | eigensystem :: m -> (l,v) |
483 | eigenvalues :: m -> l | ||
430 | 484 | ||
431 | newtype Sym n = Sym (Sq n) | 485 | newtype Sym n = Sym (Sq n) deriving Show |
432 | 486 | ||
433 | --newtype Her n = Her (CSq n) | 487 | --newtype Her n = Her (CSq n) |
434 | 488 | ||
@@ -438,16 +492,19 @@ sym m = Sym $ (m + tr m)/2 | |||
438 | --her :: KnownNat n => CSq n -> Her n | 492 | --her :: KnownNat n => CSq n -> Her n |
439 | --her = undefined -- Her $ (m + tr m)/2 | 493 | --her = undefined -- Her $ (m + tr m)/2 |
440 | 494 | ||
441 | 495 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) | |
442 | instance KnownNat n => Eig (Sym n) (R n, Sq n) | ||
443 | where | 496 | where |
444 | eig (Sym (extract -> m)) = (mkR l, mkL v) | 497 | eigenvalues (Sym (extract -> m)) = mkR . LA.eigenvaluesSH' $ m |
498 | eigensystem (Sym (extract -> m)) = (mkR l, mkL v) | ||
445 | where | 499 | where |
446 | (l,v) = eigSH m | 500 | (l,v) = LA.eigSH' m |
447 | 501 | ||
448 | instance KnownNat n => Eig (Sq n) (C n) | 502 | instance KnownNat n => Eigen (Sq n) (C n) (M n n) |
449 | where | 503 | where |
450 | eig (extract -> m) = C . Dim . eigenvalues $ m | 504 | eigenvalues (extract -> m) = mkC . LA.eigenvalues $ m |
505 | eigensystem (extract -> m) = (mkC l, mkM v) | ||
506 | where | ||
507 | (l,v) = LA.eig m | ||
451 | 508 | ||
452 | -------------------------------------------------------------------------------- | 509 | -------------------------------------------------------------------------------- |
453 | 510 | ||
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 5caf6f8..6acd9a3 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -27,7 +27,7 @@ module Numeric.LinearAlgebra.Static( | |||
27 | lift1F, lift2F, | 27 | lift1F, lift2F, |
28 | vconcat, gvec2, gvec3, gvec4, gvect, gmat, | 28 | vconcat, gvec2, gvec3, gvec4, gvect, gmat, |
29 | Sized(..), | 29 | Sized(..), |
30 | singleV, singleM | 30 | singleV, singleM,GM |
31 | ) where | 31 | ) where |
32 | 32 | ||
33 | 33 | ||
@@ -105,7 +105,7 @@ ud (Dim v) = v | |||
105 | mkV :: forall (n :: Nat) t . t -> Dim n t | 105 | mkV :: forall (n :: Nat) t . t -> Dim n t |
106 | mkV = Dim | 106 | mkV = Dim |
107 | 107 | ||
108 | type M m n t = Dim m (Dim n (Matrix t)) | 108 | type GM m n t = Dim m (Dim n (Matrix t)) |
109 | 109 | ||
110 | --ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t | 110 | --ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t |
111 | --ud2 (Dim (Dim m)) = m | 111 | --ud2 (Dim (Dim m)) = m |
@@ -166,7 +166,7 @@ gvect st xs' | |||
166 | abort info = error $ st++" "++show d++" can't be created from elements "++info | 166 | abort info = error $ st++" "++show d++" can't be created from elements "++info |
167 | 167 | ||
168 | 168 | ||
169 | gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> M m n t | 169 | gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> GM m n t |
170 | gmat st xs' | 170 | gmat st xs' |
171 | | ok = mkM x | 171 | | ok = mkM x |
172 | | not (null rest) && null (tail rest) = abort (show xs') | 172 | | not (null rest) && null (tail rest) = abort (show xs') |