diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Numeric/HMatrix.hs | 2 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Real.hs | 97 |
2 files changed, 77 insertions, 22 deletions
diff --git a/packages/base/src/Numeric/HMatrix.hs b/packages/base/src/Numeric/HMatrix.hs index 7f27fd4..786fb6d 100644 --- a/packages/base/src/Numeric/HMatrix.hs +++ b/packages/base/src/Numeric/HMatrix.hs | |||
@@ -144,7 +144,7 @@ module Numeric.HMatrix ( | |||
144 | Transposable, | 144 | Transposable, |
145 | CGState(..), | 145 | CGState(..), |
146 | Testable(..), | 146 | Testable(..), |
147 | โ,โค,โ,โ, ๐, i_C --โ | 147 | โ,โค,โ,โ, i_C |
148 | ) where | 148 | ) where |
149 | 149 | ||
150 | import Numeric.LinearAlgebra.Data | 150 | import Numeric.LinearAlgebra.Data |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Real.hs b/packages/base/src/Numeric/LinearAlgebra/Real.hs index d03ca6e..0e54555 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Real.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Real.hs | |||
@@ -32,10 +32,10 @@ module Numeric.LinearAlgebra.Real( | |||
32 | vect, | 32 | vect, |
33 | linspace, range, dim, | 33 | linspace, range, dim, |
34 | -- * Matrix | 34 | -- * Matrix |
35 | L, Sq, | 35 | L, Sq, M, |
36 | row, col, (ยฆ),(โโ), | 36 | row, col, (ยฆ),(โโ), |
37 | unrow, uncol, | 37 | unrow, uncol, |
38 | 38 | ||
39 | eye, | 39 | eye, |
40 | diagR, diag, | 40 | diagR, diag, |
41 | blockAt, | 41 | blockAt, |
@@ -50,7 +50,7 @@ module Numeric.LinearAlgebra.Real( | |||
50 | Disp(..), | 50 | Disp(..), |
51 | -- * Misc | 51 | -- * Misc |
52 | withVector, withMatrix, | 52 | withVector, withMatrix, |
53 | Sized(..), Diag(..), Sym, sym, -- Her, her, | 53 | Sized(..), Diag(..), Sym, sym, Her, her, ๐, |
54 | module Numeric.HMatrix | 54 | module Numeric.HMatrix |
55 | ) where | 55 | ) where |
56 | 56 | ||
@@ -60,11 +60,13 @@ import Numeric.HMatrix hiding ( | |||
60 | (<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โโ),row,col,vect,mat,linspace, | 60 | (<>),(#>),(<ยท>),Konst(..),diag, disp,(ยฆ),(โโ),row,col,vect,mat,linspace, |
61 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH') | 61 | (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH') |
62 | import qualified Numeric.HMatrix as LA | 62 | import qualified Numeric.HMatrix as LA |
63 | import Data.Packed.Internal(mbCatch) | ||
64 | import Data.Proxy(Proxy) | 63 | import Data.Proxy(Proxy) |
65 | import Numeric.LinearAlgebra.Static | 64 | import Numeric.LinearAlgebra.Static |
66 | import Text.Printf | 65 | import Text.Printf |
67 | 66 | ||
67 | ๐ :: Sized โ s c => s | ||
68 | ๐ = konst i_C | ||
69 | |||
68 | instance forall n . KnownNat n => Show (R n) | 70 | instance forall n . KnownNat n => Show (R n) |
69 | where | 71 | where |
70 | show (ud1 -> v) | 72 | show (ud1 -> v) |
@@ -73,6 +75,15 @@ instance forall n . KnownNat n => Show (R n) | |||
73 | where | 75 | where |
74 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 76 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
75 | 77 | ||
78 | instance forall n . KnownNat n => Show (C n) | ||
79 | where | ||
80 | show (C (Dim v)) | ||
81 | | singleV v = "("++show (v!0)++" :: C "++show d++")" | ||
82 | | otherwise = "(fromList"++ drop 8 (show v)++" :: C "++show d++")" | ||
83 | where | ||
84 | d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
85 | |||
86 | |||
76 | 87 | ||
77 | ud1 :: R n -> Vector โ | 88 | ud1 :: R n -> Vector โ |
78 | ud1 (R (Dim v)) = v | 89 | ud1 (R (Dim v)) = v |
@@ -144,19 +155,30 @@ mkM x = M (Dim (Dim x)) | |||
144 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) | 155 | instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) |
145 | where | 156 | where |
146 | show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' | 157 | show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' |
147 | show (ud2 -> x) | 158 | show (ud2 -> x) |
148 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' | 159 | | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' |
149 | | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" | 160 | | otherwise = "(mat"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" |
150 | where | 161 | where |
151 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 162 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int |
152 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 163 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
153 | 164 | ||
165 | instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) | ||
166 | where | ||
167 | show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' | ||
168 | show (M (Dim (Dim x))) | ||
169 | | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' | ||
170 | | otherwise = "(fromList"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" | ||
171 | where | ||
172 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
173 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
174 | |||
175 | |||
154 | -------------------------------------------------------------------------------- | 176 | -------------------------------------------------------------------------------- |
155 | 177 | ||
156 | instance forall n. KnownNat n => Sized โ (C n) (Vector โ) | 178 | instance forall n. KnownNat n => Sized โ (C n) (Vector โ) |
157 | where | 179 | where |
158 | konst x = mkC (LA.scalar x) | 180 | konst x = mkC (LA.scalar x) |
159 | unwrap (C (Dim v)) = v | 181 | unwrap (C (Dim v)) = v |
160 | fromList xs = C (gvect "C" xs) | 182 | fromList xs = C (gvect "C" xs) |
161 | extract (unwrap -> v) | 183 | extract (unwrap -> v) |
162 | | singleV v = LA.konst (v!0) d | 184 | | singleV v = LA.konst (v!0) d |
@@ -240,7 +262,7 @@ blockAt x r c a = mkL res | |||
240 | 262 | ||
241 | mat :: forall m n . (KnownNat m, KnownNat n) => [โ] -> L m n | 263 | mat :: forall m n . (KnownNat m, KnownNat n) => [โ] -> L m n |
242 | mat xs = L (gmat "L" xs) | 264 | mat xs = L (gmat "L" xs) |
243 | 265 | ||
244 | -------------------------------------------------------------------------------- | 266 | -------------------------------------------------------------------------------- |
245 | 267 | ||
246 | class Disp t | 268 | class Disp t |
@@ -315,7 +337,7 @@ isKonst (unwrap -> x) | |||
315 | where | 337 | where |
316 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 338 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int |
317 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 339 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
318 | 340 | ||
319 | 341 | ||
320 | 342 | ||
321 | isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (โ, Vector โ, (Int,Int)) | 343 | isDiag :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (โ, Vector โ, (Int,Int)) |
@@ -329,7 +351,7 @@ isDiagg :: forall m n t . (Numeric t, KnownNat m, KnownNat n) => GM m n t -> May | |||
329 | isDiagg (Dim (Dim x)) | 351 | isDiagg (Dim (Dim x)) |
330 | | singleM x = Nothing | 352 | | singleM x = Nothing |
331 | | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) | 353 | | rows x == 1 && m' > 1 || cols x == 1 && n' > 1 = Just (z,yz,(m',n')) |
332 | | otherwise = Nothing | 354 | | otherwise = Nothing |
333 | where | 355 | where |
334 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | 356 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int |
335 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | 357 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int |
@@ -377,6 +399,12 @@ instance (KnownNat n, KnownNat m) => Transposable (L m n) (L n m) | |||
377 | tr a@(isDiag -> Just _) = mkL (extract a) | 399 | tr a@(isDiag -> Just _) = mkL (extract a) |
378 | tr (extract -> a) = mkL (tr a) | 400 | tr (extract -> a) = mkL (tr a) |
379 | 401 | ||
402 | instance (KnownNat n, KnownNat m) => Transposable (M m n) (M n m) | ||
403 | where | ||
404 | tr a@(isDiagC -> Just _) = mkM (extract a) | ||
405 | tr (extract -> a) = mkM (tr a) | ||
406 | |||
407 | |||
380 | -------------------------------------------------------------------------------- | 408 | -------------------------------------------------------------------------------- |
381 | 409 | ||
382 | adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b | 410 | adaptDiag f a@(isDiag -> Just _) b | isFull b = f (mkL (extract a)) b |
@@ -408,6 +436,33 @@ instance (KnownNat n, KnownNat m) => Fractional (L n m) | |||
408 | 436 | ||
409 | -------------------------------------------------------------------------------- | 437 | -------------------------------------------------------------------------------- |
410 | 438 | ||
439 | adaptDiagC f a@(isDiagC -> Just _) b | isFullC b = f (mkM (extract a)) b | ||
440 | adaptDiagC f a b@(isDiagC -> Just _) | isFullC a = f a (mkM (extract b)) | ||
441 | adaptDiagC f a b = f a b | ||
442 | |||
443 | isFullC m = isDiagC m == Nothing && not (singleM (unwrap m)) | ||
444 | |||
445 | lift1M f (M v) = M (f v) | ||
446 | lift2M f (M a) (M b) = M (f a b) | ||
447 | lift2MD f = adaptDiagC (lift2M f) | ||
448 | |||
449 | instance (KnownNat n, KnownNat m) => Num (M n m) | ||
450 | where | ||
451 | (+) = lift2MD (+) | ||
452 | (*) = lift2MD (*) | ||
453 | (-) = lift2MD (-) | ||
454 | abs = lift1M abs | ||
455 | signum = lift1M signum | ||
456 | negate = lift1M negate | ||
457 | fromInteger = M . Dim . Dim . fromInteger | ||
458 | |||
459 | instance (KnownNat n, KnownNat m) => Fractional (M n m) | ||
460 | where | ||
461 | fromRational = M . Dim . Dim . fromRational | ||
462 | (/) = lift2MD (/) | ||
463 | |||
464 | -------------------------------------------------------------------------------- | ||
465 | |||
411 | {- | 466 | {- |
412 | class Minim (n :: Nat) (m :: Nat) | 467 | class Minim (n :: Nat) (m :: Nat) |
413 | where | 468 | where |
@@ -481,16 +536,16 @@ class Eigen m l v | m -> l, m -> v | |||
481 | where | 536 | where |
482 | eigensystem :: m -> (l,v) | 537 | eigensystem :: m -> (l,v) |
483 | eigenvalues :: m -> l | 538 | eigenvalues :: m -> l |
484 | 539 | ||
485 | newtype Sym n = Sym (Sq n) deriving Show | 540 | newtype Sym n = Sym (Sq n) deriving Show |
486 | 541 | ||
487 | --newtype Her n = Her (CSq n) | 542 | newtype Her n = Her (M n n) |
488 | 543 | ||
489 | sym :: KnownNat n => Sq n -> Sym n | 544 | sym :: KnownNat n => Sq n -> Sym n |
490 | sym m = Sym $ (m + tr m)/2 | 545 | sym m = Sym $ (m + tr m)/2 |
491 | 546 | ||
492 | --her :: KnownNat n => CSq n -> Her n | 547 | her :: KnownNat n => M n n -> Her n |
493 | --her = undefined -- Her $ (m + tr m)/2 | 548 | her m = Her $ (m + tr m)/2 |
494 | 549 | ||
495 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) | 550 | instance KnownNat n => Eigen (Sym n) (R n) (L n n) |
496 | where | 551 | where |
@@ -505,12 +560,12 @@ instance KnownNat n => Eigen (Sq n) (C n) (M n n) | |||
505 | eigensystem (extract -> m) = (mkC l, mkM v) | 560 | eigensystem (extract -> m) = (mkC l, mkM v) |
506 | where | 561 | where |
507 | (l,v) = LA.eig m | 562 | (l,v) = LA.eig m |
508 | 563 | ||
509 | -------------------------------------------------------------------------------- | 564 | -------------------------------------------------------------------------------- |
510 | 565 | ||
511 | withVector | 566 | withVector |
512 | :: forall z | 567 | :: forall z |
513 | . Vector โ | 568 | . Vector โ |
514 | -> (forall n . (KnownNat n) => R n -> z) | 569 | -> (forall n . (KnownNat n) => R n -> z) |
515 | -> z | 570 | -> z |
516 | withVector v f = | 571 | withVector v f = |
@@ -521,16 +576,16 @@ withVector v f = | |||
521 | 576 | ||
522 | withMatrix | 577 | withMatrix |
523 | :: forall z | 578 | :: forall z |
524 | . Matrix โ | 579 | . Matrix โ |
525 | -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z) | 580 | -> (forall m n . (KnownNat m, KnownNat n) => L m n -> z) |
526 | -> z | 581 | -> z |
527 | withMatrix a f = | 582 | withMatrix a f = |
528 | case someNatVal $ fromIntegral $ rows a of | 583 | case someNatVal $ fromIntegral $ rows a of |
529 | Nothing -> error "static/dynamic mismatch" | 584 | Nothing -> error "static/dynamic mismatch" |
530 | Just (SomeNat (_ :: Proxy m)) -> | 585 | Just (SomeNat (_ :: Proxy m)) -> |
531 | case someNatVal $ fromIntegral $ cols a of | 586 | case someNatVal $ fromIntegral $ cols a of |
532 | Nothing -> error "static/dynamic mismatch" | 587 | Nothing -> error "static/dynamic mismatch" |
533 | Just (SomeNat (_ :: Proxy n)) -> | 588 | Just (SomeNat (_ :: Proxy n)) -> |
534 | f (mkL a :: L n m) | 589 | f (mkL a :: L n m) |
535 | 590 | ||
536 | -------------------------------------------------------------------------------- | 591 | -------------------------------------------------------------------------------- |
@@ -539,8 +594,8 @@ test :: (Bool, IO ()) | |||
539 | test = (ok,info) | 594 | test = (ok,info) |
540 | where | 595 | where |
541 | ok = extract (eye :: Sq 5) == ident 5 | 596 | ok = extract (eye :: Sq 5) == ident 5 |
542 | && ud2 (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] | 597 | && unwrap (mTm sm :: Sq 3) == tr ((3><3)[1..]) LA.<> (3><3)[1..] |
543 | && ud2 (tm :: L 3 5) == LA.mat 5 [1..15] | 598 | && unwrap (tm :: L 3 5) == LA.mat 5 [1..15] |
544 | && thingS == thingD | 599 | && thingS == thingD |
545 | && precS == precD | 600 | && precS == precD |
546 | && withVector (LA.vect [1..15]) sumV == sumElements (LA.fromList [1..15]) | 601 | && withVector (LA.vect [1..15]) sumV == sumElements (LA.fromList [1..15]) |
@@ -557,7 +612,7 @@ test = (ok,info) | |||
557 | print precS | 612 | print precS |
558 | print precD | 613 | print precD |
559 | print $ withVector (LA.vect [1..15]) sumV | 614 | print $ withVector (LA.vect [1..15]) sumV |
560 | 615 | ||
561 | sumV w = w <ยท> konst 1 | 616 | sumV w = w <ยท> konst 1 |
562 | 617 | ||
563 | u = vec2 3 5 | 618 | u = vec2 3 5 |