summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Internal/Static.hs11
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs22
2 files changed, 31 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs
index 9ed4710..f9dfff0 100644
--- a/packages/base/src/Internal/Static.hs
+++ b/packages/base/src/Internal/Static.hs
@@ -567,6 +567,17 @@ instance KnownNat n => Disp (C n)
567 567
568-------------------------------------------------------------------------------- 568--------------------------------------------------------------------------------
569 569
570overMatL' :: (KnownNat m, KnownNat n)
571 => (LA.Matrix ℝ -> LA.Matrix ℝ) -> L m n -> L m n
572overMatL' f = mkL . f . unwrap
573{-# INLINE overMatL' #-}
574
575overMatM' :: (KnownNat m, KnownNat n)
576 => (LA.Matrix ℂ -> LA.Matrix ℂ) -> M m n -> M m n
577overMatM' f = mkM . f . unwrap
578{-# INLINE overMatM' #-}
579
580
570#else 581#else
571 582
572module Numeric.LinearAlgebra.Static.Internal where 583module Numeric.LinearAlgebra.Static.Internal where
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index 64c0f14..296f8c7 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -537,6 +537,8 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve
537 zipWithVector :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n 537 zipWithVector :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n
538 det :: forall n. KnownNat n => mat n n -> field 538 det :: forall n. KnownNat n => mat n n -> field
539 invlndet :: forall n. KnownNat n => mat n n -> (mat n n, (field, field)) 539 invlndet :: forall n. KnownNat n => mat n n -> (mat n n, (field, field))
540 expm :: forall n. KnownNat n => mat n n -> mat n n
541 sqrtm :: forall n. KnownNat n => mat n n -> mat n n
540 542
541 543
542instance Domain ℝ R L 544instance Domain ℝ R L
@@ -552,6 +554,8 @@ instance Domain ℝ R L
552 zipWithVector = zipWithR 554 zipWithVector = zipWithR
553 det = detL 555 det = detL
554 invlndet = invlndetL 556 invlndet = invlndetL
557 expm = expmL
558 sqrtm = sqrtmL
555 559
556instance Domain ℂ C M 560instance Domain ℂ C M
557 where 561 where
@@ -566,6 +570,8 @@ instance Domain ℂ C M
566 zipWithVector = zipWithC 570 zipWithVector = zipWithC
567 det = detM 571 det = detM
568 invlndet = invlndetM 572 invlndet = invlndetM
573 expm = expmM
574 sqrtm = sqrtmM
569 575
570-------------------------------------------------------------------------------- 576--------------------------------------------------------------------------------
571 577
@@ -615,7 +621,7 @@ zipWithR :: KnownNat n => (ℝ -> ℝ -> ℝ) -> R n -> R n -> R n
615zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y) 621zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y)
616 622
617mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m 623mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m
618mapL f (unwrap -> m) = mkL (LA.cmap f m) 624mapL f = overMatL' (LA.cmap f)
619 625
620detL :: KnownNat n => Sq n -> ℝ 626detL :: KnownNat n => Sq n -> ℝ
621detL = LA.det . unwrap 627detL = LA.det . unwrap
@@ -623,6 +629,12 @@ detL = LA.det . unwrap
623invlndetL :: KnownNat n => Sq n -> (L n n, (ℝ, ℝ)) 629invlndetL :: KnownNat n => Sq n -> (L n n, (ℝ, ℝ))
624invlndetL = first mkL . LA.invlndet . unwrap 630invlndetL = first mkL . LA.invlndet . unwrap
625 631
632expmL :: KnownNat n => Sq n -> Sq n
633expmL = overMatL' LA.expm
634
635sqrtmL :: KnownNat n => Sq n -> Sq n
636sqrtmL = overMatL' LA.sqrtm
637
626-------------------------------------------------------------------------------- 638--------------------------------------------------------------------------------
627 639
628mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n 640mulC :: forall m k n. (KnownNat m, KnownNat k, KnownNat n) => M m k -> M k n -> M m n
@@ -671,7 +683,7 @@ zipWithC :: KnownNat n => (ℂ -> ℂ -> ℂ) -> C n -> C n -> C n
671zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y) 683zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y)
672 684
673mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m 685mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m
674mapM' f (unwrap -> m) = mkM (LA.cmap f m) 686mapM' f = overMatM' (LA.cmap f)
675 687
676detM :: KnownNat n => M n n -> ℂ 688detM :: KnownNat n => M n n -> ℂ
677detM = LA.det . unwrap 689detM = LA.det . unwrap
@@ -679,6 +691,12 @@ detM = LA.det . unwrap
679invlndetM :: KnownNat n => M n n -> (M n n, (ℂ, ℂ)) 691invlndetM :: KnownNat n => M n n -> (M n n, (ℂ, ℂ))
680invlndetM = first mkM . LA.invlndet . unwrap 692invlndetM = first mkM . LA.invlndet . unwrap
681 693
694expmM :: KnownNat n => M n n -> M n n
695expmM = overMatM' LA.expm
696
697sqrtmM :: KnownNat n => M n n -> M n n
698sqrtmM = overMatM' LA.sqrtm
699
682 700
683-------------------------------------------------------------------------------- 701--------------------------------------------------------------------------------
684 702