diff options
Diffstat (limited to 'packages/base/src')
-rw-r--r-- | packages/base/src/Internal/Static.hs | 11 | ||||
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 22 |
2 files changed, 31 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs index 9ed4710..f9dfff0 100644 --- a/packages/base/src/Internal/Static.hs +++ b/packages/base/src/Internal/Static.hs | |||
@@ -567,6 +567,17 @@ instance KnownNat n => Disp (C n) | |||
567 | 567 | ||
568 | -------------------------------------------------------------------------------- | 568 | -------------------------------------------------------------------------------- |
569 | 569 | ||
570 | overMatL' :: (KnownNat m, KnownNat n) | ||
571 | => (LA.Matrix ℝ -> LA.Matrix ℝ) -> L m n -> L m n | ||
572 | overMatL' f = mkL . f . unwrap | ||
573 | {-# INLINE overMatL' #-} | ||
574 | |||
575 | overMatM' :: (KnownNat m, KnownNat n) | ||
576 | => (LA.Matrix ℂ -> LA.Matrix ℂ) -> M m n -> M m n | ||
577 | overMatM' f = mkM . f . unwrap | ||
578 | {-# INLINE overMatM' #-} | ||
579 | |||
580 | |||
570 | #else | 581 | #else |
571 | 582 | ||
572 | module Numeric.LinearAlgebra.Static.Internal where | 583 | module Numeric.LinearAlgebra.Static.Internal where |
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 64c0f14..296f8c7 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -537,6 +537,8 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve | |||
537 | zipWithVector :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n | 537 | zipWithVector :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n |
538 | det :: forall n. KnownNat n => mat n n -> field | 538 | det :: forall n. KnownNat n => mat n n -> field |
539 | invlndet :: forall n. KnownNat n => mat n n -> (mat n n, (field, field)) | 539 | invlndet :: forall n. KnownNat n => mat n n -> (mat n n, (field, field)) |
540 | expm :: forall n. KnownNat n => mat n n -> mat n n | ||
541 | sqrtm :: forall n. KnownNat n => mat n n -> mat n n | ||
540 | 542 | ||
541 | 543 | ||
542 | instance Domain ℝ R L | 544 | instance Domain ℝ R L |
@@ -552,6 +554,8 @@ instance Domain ℝ R L | |||
552 | zipWithVector = zipWithR | 554 | zipWithVector = zipWithR |
553 | det = detL | 555 | det = detL |
554 | invlndet = invlndetL | 556 | invlndet = invlndetL |
557 | expm = expmL | ||
558 | sqrtm = sqrtmL | ||
555 | 559 | ||
556 | instance Domain ℂ C M | 560 | instance Domain ℂ C M |
557 | where | 561 | where |
@@ -566,6 +570,8 @@ instance Domain ℂ C M | |||
566 | zipWithVector = zipWithC | 570 | zipWithVector = zipWithC |
567 | det = detM | 571 | det = detM |
568 | invlndet = invlndetM | 572 | invlndet = invlndetM |
573 | expm = expmM | ||
574 | sqrtm = sqrtmM | ||
569 | 575 | ||
570 | -------------------------------------------------------------------------------- | 576 | -------------------------------------------------------------------------------- |
571 | 577 | ||
@@ -615,7 +621,7 @@ zipWithR :: KnownNat n => (ℝ -> ℝ -> ℝ) -> R n -> R n -> R n | |||
615 | zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y) | 621 | zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y) |
616 | 622 | ||
617 | mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m | 623 | mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m |
618 | mapL f (unwrap -> m) = mkL (LA.cmap f m) | 624 | mapL f = overMatL' (LA.cmap f) |
619 | 625 | ||
620 | detL :: KnownNat n => Sq n -> ℝ | 626 | detL :: KnownNat n => Sq n -> ℝ |
621 | detL = LA.det . unwrap | 627 | detL = LA.det . unwrap |
@@ -623,6 +629,12 @@ detL = LA.det . unwrap | |||
623 | invlndetL :: KnownNat n => Sq n -> (L n n, (ℝ, ℝ)) | 629 | invlndetL :: KnownNat n => Sq n -> (L n n, (ℝ, ℝ)) |
624 | invlndetL = first mkL . LA.invlndet . unwrap | 630 | invlndetL = first mkL . LA.invlndet . unwrap |
625 | 631 | ||
632 | expmL :: KnownNat n => Sq n -> Sq n | ||
633 | expmL = overMatL' LA.expm | ||
634 | |||
635 | sqrtmL :: KnownNat n => Sq n -> Sq n | ||
636 | sqrtmL = overMatL' LA.sqrtm | ||
637 | |||
626 | -------------------------------------------------------------------------------- | 638 | -------------------------------------------------------------------------------- |
627 | 639 | ||
628 | mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n | 640 | mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n |
@@ -671,7 +683,7 @@ zipWithC :: KnownNat n => (ℂ -> ℂ -> ℂ) -> C n -> C n -> C n | |||
671 | zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y) | 683 | zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y) |
672 | 684 | ||
673 | mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m | 685 | mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m |
674 | mapM' f (unwrap -> m) = mkM (LA.cmap f m) | 686 | mapM' f = overMatM' (LA.cmap f) |
675 | 687 | ||
676 | detM :: KnownNat n => M n n -> ℂ | 688 | detM :: KnownNat n => M n n -> ℂ |
677 | detM = LA.det . unwrap | 689 | detM = LA.det . unwrap |
@@ -679,6 +691,12 @@ detM = LA.det . unwrap | |||
679 | invlndetM :: KnownNat n => M n n -> (M n n, (ℂ, ℂ)) | 691 | invlndetM :: KnownNat n => M n n -> (M n n, (ℂ, ℂ)) |
680 | invlndetM = first mkM . LA.invlndet . unwrap | 692 | invlndetM = first mkM . LA.invlndet . unwrap |
681 | 693 | ||
694 | expmM :: KnownNat n => M n n -> M n n | ||
695 | expmM = overMatM' LA.expm | ||
696 | |||
697 | sqrtmM :: KnownNat n => M n n -> M n n | ||
698 | sqrtmM = overMatM' LA.sqrtm | ||
699 | |||
682 | 700 | ||
683 | -------------------------------------------------------------------------------- | 701 | -------------------------------------------------------------------------------- |
684 | 702 | ||