diff options
author | Alberto Ruiz <aruiz@um.es> | 2016-05-28 13:51:05 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2016-05-28 13:51:05 +0200 |
commit | 9aade51bd0bb6339cfa8aca014bd96f801d9b19e (patch) | |
tree | 1c96ece927e8252501eadce0a228d4221984faa7 | |
parent | 42a88fbcb6bd1d2c4dc18fae5e962bd34fb316a1 (diff) | |
parent | cd6caa8f08e686fd4a90dae5f3414264aa2700a0 (diff) |
Merge pull request #191 from mstksg/static
[In Progress] Ongoing porting of functionality to Static module
-rw-r--r-- | packages/base/src/Internal/Static.hs | 23 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 72 |
2 files changed, 81 insertions, 14 deletions
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs index 381f3bc..f9dfff0 100644 --- a/packages/base/src/Internal/Static.hs +++ b/packages/base/src/Internal/Static.hs | |||
@@ -518,6 +518,18 @@ instance (KnownNat n, KnownNat m) => Floating (M n m) where | |||
518 | (**) = lift2MD (**) | 518 | (**) = lift2MD (**) |
519 | pi = M pi | 519 | pi = M pi |
520 | 520 | ||
521 | instance Additive (R n) where | ||
522 | add = (+) | ||
523 | |||
524 | instance Additive (C n) where | ||
525 | add = (+) | ||
526 | |||
527 | instance (KnownNat m, KnownNat n) => Additive (L m n) where | ||
528 | add = (+) | ||
529 | |||
530 | instance (KnownNat m, KnownNat n) => Additive (M m n) where | ||
531 | add = (+) | ||
532 | |||
521 | -------------------------------------------------------------------------------- | 533 | -------------------------------------------------------------------------------- |
522 | 534 | ||
523 | 535 | ||
@@ -555,6 +567,17 @@ instance KnownNat n => Disp (C n) | |||
555 | 567 | ||
556 | -------------------------------------------------------------------------------- | 568 | -------------------------------------------------------------------------------- |
557 | 569 | ||
570 | overMatL' :: (KnownNat m, KnownNat n) | ||
571 | => (LA.Matrix ℝ -> LA.Matrix ℝ) -> L m n -> L m n | ||
572 | overMatL' f = mkL . f . unwrap | ||
573 | {-# INLINE overMatL' #-} | ||
574 | |||
575 | overMatM' :: (KnownNat m, KnownNat n) | ||
576 | => (LA.Matrix ℂ -> LA.Matrix ℂ) -> M m n -> M m n | ||
577 | overMatM' f = mkM . f . unwrap | ||
578 | {-# INLINE overMatM' #-} | ||
579 | |||
580 | |||
558 | #else | 581 | #else |
559 | 582 | ||
560 | module Numeric.LinearAlgebra.Static.Internal where | 583 | module Numeric.LinearAlgebra.Static.Internal where |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 3e772b2..296f8c7 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -42,7 +42,7 @@ module Numeric.LinearAlgebra.Static( | |||
42 | blockAt, | 42 | blockAt, |
43 | matrix, | 43 | matrix, |
44 | -- * Complex | 44 | -- * Complex |
45 | C, M, Her, her, 𝑖, | 45 | ℂ, C, M, Her, her, 𝑖, |
46 | -- * Products | 46 | -- * Products |
47 | (<>),(#>),(<.>), | 47 | (<>),(#>),(<.>), |
48 | -- * Linear Systems | 48 | -- * Linear Systems |
@@ -78,6 +78,8 @@ import Data.Proxy(Proxy(..)) | |||
78 | import Internal.Static | 78 | import Internal.Static |
79 | import Control.Arrow((***)) | 79 | import Control.Arrow((***)) |
80 | import Text.Printf | 80 | import Text.Printf |
81 | import Data.Type.Equality ((:~:)(Refl)) | ||
82 | import Data.Bifunctor (first) | ||
81 | 83 | ||
82 | ud1 :: R n -> Vector ℝ | 84 | ud1 :: R n -> Vector ℝ |
83 | ud1 (R (Dim v)) = v | 85 | ud1 (R (Dim v)) = v |
@@ -444,11 +446,9 @@ exactLength | |||
444 | :: forall n m . (KnownNat n, KnownNat m) | 446 | :: forall n m . (KnownNat n, KnownNat m) |
445 | => R m | 447 | => R m |
446 | -> Maybe (R n) | 448 | -> Maybe (R n) |
447 | exactLength v | 449 | exactLength v = do |
448 | | natVal (Proxy :: Proxy n) == natVal (Proxy :: Proxy m) | 450 | Refl <- sameNat (Proxy :: Proxy n) (Proxy :: Proxy m) |
449 | = Just (mkR (unwrap v)) | 451 | return $ mkR (unwrap v) |
450 | | otherwise | ||
451 | = Nothing | ||
452 | 452 | ||
453 | withMatrix | 453 | withMatrix |
454 | :: forall z | 454 | :: forall z |
@@ -470,12 +470,10 @@ exactDims | |||
470 | :: forall n m j k . (KnownNat n, KnownNat m, KnownNat j, KnownNat k) | 470 | :: forall n m j k . (KnownNat n, KnownNat m, KnownNat j, KnownNat k) |
471 | => L m n | 471 | => L m n |
472 | -> Maybe (L j k) | 472 | -> Maybe (L j k) |
473 | exactDims m | 473 | exactDims m = do |
474 | | natVal (Proxy :: Proxy m) == natVal (Proxy :: Proxy j) | 474 | Refl <- sameNat (Proxy :: Proxy m) (Proxy :: Proxy j) |
475 | && natVal (Proxy :: Proxy n) == natVal (Proxy :: Proxy k) | 475 | Refl <- sameNat (Proxy :: Proxy n) (Proxy :: Proxy k) |
476 | = Just (mkL (unwrap m)) | 476 | return $ mkL (unwrap m) |
477 | | otherwise | ||
478 | = Nothing | ||
479 | 477 | ||
480 | randomVector | 478 | randomVector |
481 | :: forall n . KnownNat n | 479 | :: forall n . KnownNat n |
@@ -537,6 +535,10 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve | |||
537 | dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m | 535 | dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m |
538 | outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m | 536 | outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m |
539 | zipWithVector :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n | 537 | zipWithVector :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n |
538 | det :: forall n. KnownNat n => mat n n -> field | ||
539 | invlndet :: forall n. KnownNat n => mat n n -> (mat n n, (field, field)) | ||
540 | expm :: forall n. KnownNat n => mat n n -> mat n n | ||
541 | sqrtm :: forall n. KnownNat n => mat n n -> mat n n | ||
540 | 542 | ||
541 | 543 | ||
542 | instance Domain ℝ R L | 544 | instance Domain ℝ R L |
@@ -550,6 +552,10 @@ instance Domain ℝ R L | |||
550 | dmmap = mapL | 552 | dmmap = mapL |
551 | outer = outerR | 553 | outer = outerR |
552 | zipWithVector = zipWithR | 554 | zipWithVector = zipWithR |
555 | det = detL | ||
556 | invlndet = invlndetL | ||
557 | expm = expmL | ||
558 | sqrtm = sqrtmL | ||
553 | 559 | ||
554 | instance Domain ℂ C M | 560 | instance Domain ℂ C M |
555 | where | 561 | where |
@@ -562,6 +568,10 @@ instance Domain ℂ C M | |||
562 | dmmap = mapM' | 568 | dmmap = mapM' |
563 | outer = outerC | 569 | outer = outerC |
564 | zipWithVector = zipWithC | 570 | zipWithVector = zipWithC |
571 | det = detM | ||
572 | invlndet = invlndetM | ||
573 | expm = expmM | ||
574 | sqrtm = sqrtmM | ||
565 | 575 | ||
566 | -------------------------------------------------------------------------------- | 576 | -------------------------------------------------------------------------------- |
567 | 577 | ||
@@ -611,8 +621,19 @@ zipWithR :: KnownNat n => (ℝ -> ℝ -> ℝ) -> R n -> R n -> R n | |||
611 | zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y) | 621 | zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y) |
612 | 622 | ||
613 | mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m | 623 | mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m |
614 | mapL f (unwrap -> m) = mkL (LA.cmap f m) | 624 | mapL f = overMatL' (LA.cmap f) |
615 | 625 | ||
626 | detL :: KnownNat n => Sq n -> ℝ | ||
627 | detL = LA.det . unwrap | ||
628 | |||
629 | invlndetL :: KnownNat n => Sq n -> (L n n, (ℝ, ℝ)) | ||
630 | invlndetL = first mkL . LA.invlndet . unwrap | ||
631 | |||
632 | expmL :: KnownNat n => Sq n -> Sq n | ||
633 | expmL = overMatL' LA.expm | ||
634 | |||
635 | sqrtmL :: KnownNat n => Sq n -> Sq n | ||
636 | sqrtmL = overMatL' LA.sqrtm | ||
616 | 637 | ||
617 | -------------------------------------------------------------------------------- | 638 | -------------------------------------------------------------------------------- |
618 | 639 | ||
@@ -662,7 +683,19 @@ zipWithC :: KnownNat n => (ℂ -> ℂ -> ℂ) -> C n -> C n -> C n | |||
662 | zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y) | 683 | zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y) |
663 | 684 | ||
664 | mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m | 685 | mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m |
665 | mapM' f (unwrap -> m) = mkM (LA.cmap f m) | 686 | mapM' f = overMatM' (LA.cmap f) |
687 | |||
688 | detM :: KnownNat n => M n n -> ℂ | ||
689 | detM = LA.det . unwrap | ||
690 | |||
691 | invlndetM :: KnownNat n => M n n -> (M n n, (ℂ, ℂ)) | ||
692 | invlndetM = first mkM . LA.invlndet . unwrap | ||
693 | |||
694 | expmM :: KnownNat n => M n n -> M n n | ||
695 | expmM = overMatM' LA.expm | ||
696 | |||
697 | sqrtmM :: KnownNat n => M n n -> M n n | ||
698 | sqrtmM = overMatM' LA.sqrtm | ||
666 | 699 | ||
667 | 700 | ||
668 | -------------------------------------------------------------------------------- | 701 | -------------------------------------------------------------------------------- |
@@ -824,3 +857,14 @@ instance KnownNat n => Floating (Sym n) | |||
824 | sqrt = mkSym sqrt | 857 | sqrt = mkSym sqrt |
825 | (**) = mkSym2 (**) | 858 | (**) = mkSym2 (**) |
826 | pi = Sym pi | 859 | pi = Sym pi |
860 | |||
861 | instance KnownNat n => Additive (Sym n) where | ||
862 | add = (+) | ||
863 | |||
864 | instance KnownNat n => Transposable (Sym n) (Sym n) where | ||
865 | tr = id | ||
866 | tr' = id | ||
867 | |||
868 | instance KnownNat n => Transposable (Her n) (Her n) where | ||
869 | tr = id | ||
870 | tr' (Her m) = Her (tr' m) | ||