summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Static.hs
diff options
context:
space:
mode:
authorJustin Le <justin@jle.im>2016-05-25 08:37:10 -0700
committerJustin Le <justin@jle.im>2016-05-25 08:37:10 -0700
commit3ddb98ef66ed672c4da67e38ff2127cc912aefe7 (patch)
treee121e8b57dbb85e85cf5d8bc463ac2f1ca023ec4 /packages/base/src/Numeric/LinearAlgebra/Static.hs
parent79369ee4c72d3c4844c734219c8f6430b3b0c4ab (diff)
added determinate functions to Domain typeclass. Rationale is that these can be verified to be square and are therefore total, compared to the determinate function from the untyped packages.
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs18
1 files changed, 18 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index 8019558..2e7c462 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -79,6 +79,7 @@ import Internal.Static
79import Control.Arrow((***)) 79import Control.Arrow((***))
80import Text.Printf 80import Text.Printf
81import Data.Type.Equality ((:~:)(Refl)) 81import Data.Type.Equality ((:~:)(Refl))
82import Data.Bifunctor (first)
82 83
83ud1 :: R n -> Vector ℝ 84ud1 :: R n -> Vector ℝ
84ud1 (R (Dim v)) = v 85ud1 (R (Dim v)) = v
@@ -534,6 +535,8 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve
534 dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m 535 dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m
535 outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m 536 outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m
536 zipWithVector :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n 537 zipWithVector :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n
538 det :: forall n. KnownNat n => mat n n -> field
539 invlndet :: forall n. KnownNat n => mat n n -> (mat n n, (field, field))
537 540
538 541
539instance Domain ℝ R L 542instance Domain ℝ R L
@@ -547,6 +550,8 @@ instance Domain ℝ R L
547 dmmap = mapL 550 dmmap = mapL
548 outer = outerR 551 outer = outerR
549 zipWithVector = zipWithR 552 zipWithVector = zipWithR
553 det = detL
554 invlndet = invlndetL
550 555
551instance Domain ℂ C M 556instance Domain ℂ C M
552 where 557 where
@@ -559,6 +564,8 @@ instance Domain ℂ C M
559 dmmap = mapM' 564 dmmap = mapM'
560 outer = outerC 565 outer = outerC
561 zipWithVector = zipWithC 566 zipWithVector = zipWithC
567 det = detM
568 invlndet = invlndetM
562 569
563-------------------------------------------------------------------------------- 570--------------------------------------------------------------------------------
564 571
@@ -610,6 +617,11 @@ zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y)
610mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m 617mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m
611mapL f (unwrap -> m) = mkL (LA.cmap f m) 618mapL f (unwrap -> m) = mkL (LA.cmap f m)
612 619
620detL :: KnownNat n => Sq n -> ℝ
621detL = LA.det . unwrap
622
623invlndetL :: KnownNat n => Sq n -> (L n n, (ℝ, ℝ))
624invlndetL = first mkL . LA.invlndet . unwrap
613 625
614-------------------------------------------------------------------------------- 626--------------------------------------------------------------------------------
615 627
@@ -661,6 +673,12 @@ zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y)
661mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m 673mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m
662mapM' f (unwrap -> m) = mkM (LA.cmap f m) 674mapM' f (unwrap -> m) = mkM (LA.cmap f m)
663 675
676detM :: KnownNat n => M n n -> ℂ
677detM = LA.det . unwrap
678
679invlndetM :: KnownNat n => M n n -> (M n n, (ℂ, ℂ))
680invlndetM = first mkM . LA.invlndet . unwrap
681
664 682
665-------------------------------------------------------------------------------- 683--------------------------------------------------------------------------------
666 684