summaryrefslogtreecommitdiff
path: root/packages/base/src
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src')
-rw-r--r--packages/base/src/Internal/Static.hs23
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs72
2 files changed, 81 insertions, 14 deletions
diff --git a/packages/base/src/Internal/Static.hs b/packages/base/src/Internal/Static.hs
index 381f3bc..f9dfff0 100644
--- a/packages/base/src/Internal/Static.hs
+++ b/packages/base/src/Internal/Static.hs
@@ -518,6 +518,18 @@ instance (KnownNat n, KnownNat m) => Floating (M n m) where
518 (**) = lift2MD (**) 518 (**) = lift2MD (**)
519 pi = M pi 519 pi = M pi
520 520
521instance Additive (R n) where
522 add = (+)
523
524instance Additive (C n) where
525 add = (+)
526
527instance (KnownNat m, KnownNat n) => Additive (L m n) where
528 add = (+)
529
530instance (KnownNat m, KnownNat n) => Additive (M m n) where
531 add = (+)
532
521-------------------------------------------------------------------------------- 533--------------------------------------------------------------------------------
522 534
523 535
@@ -555,6 +567,17 @@ instance KnownNat n => Disp (C n)
555 567
556-------------------------------------------------------------------------------- 568--------------------------------------------------------------------------------
557 569
570overMatL' :: (KnownNat m, KnownNat n)
571 => (LA.Matrix ℝ -> LA.Matrix ℝ) -> L m n -> L m n
572overMatL' f = mkL . f . unwrap
573{-# INLINE overMatL' #-}
574
575overMatM' :: (KnownNat m, KnownNat n)
576 => (LA.Matrix ℂ -> LA.Matrix ℂ) -> M m n -> M m n
577overMatM' f = mkM . f . unwrap
578{-# INLINE overMatM' #-}
579
580
558#else 581#else
559 582
560module Numeric.LinearAlgebra.Static.Internal where 583module Numeric.LinearAlgebra.Static.Internal where
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
index 3e772b2..296f8c7 100644
--- a/packages/base/src/Numeric/LinearAlgebra/Static.hs
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -42,7 +42,7 @@ module Numeric.LinearAlgebra.Static(
42 blockAt, 42 blockAt,
43 matrix, 43 matrix,
44 -- * Complex 44 -- * Complex
45 C, M, Her, her, 𝑖, 45 ℂ, C, M, Her, her, 𝑖,
46 -- * Products 46 -- * Products
47 (<>),(#>),(<.>), 47 (<>),(#>),(<.>),
48 -- * Linear Systems 48 -- * Linear Systems
@@ -78,6 +78,8 @@ import Data.Proxy(Proxy(..))
78import Internal.Static 78import Internal.Static
79import Control.Arrow((***)) 79import Control.Arrow((***))
80import Text.Printf 80import Text.Printf
81import Data.Type.Equality ((:~:)(Refl))
82import Data.Bifunctor (first)
81 83
82ud1 :: R n -> Vector ℝ 84ud1 :: R n -> Vector ℝ
83ud1 (R (Dim v)) = v 85ud1 (R (Dim v)) = v
@@ -444,11 +446,9 @@ exactLength
444 :: forall n m . (KnownNat n, KnownNat m) 446 :: forall n m . (KnownNat n, KnownNat m)
445 => R m 447 => R m
446 -> Maybe (R n) 448 -> Maybe (R n)
447exactLength v 449exactLength v = do
448 | natVal (Proxy :: Proxy n) == natVal (Proxy :: Proxy m) 450 Refl <- sameNat (Proxy :: Proxy n) (Proxy :: Proxy m)
449 = Just (mkR (unwrap v)) 451 return $ mkR (unwrap v)
450 | otherwise
451 = Nothing
452 452
453withMatrix 453withMatrix
454 :: forall z 454 :: forall z
@@ -470,12 +470,10 @@ exactDims
470 :: forall n m j k . (KnownNat n, KnownNat m, KnownNat j, KnownNat k) 470 :: forall n m j k . (KnownNat n, KnownNat m, KnownNat j, KnownNat k)
471 => L m n 471 => L m n
472 -> Maybe (L j k) 472 -> Maybe (L j k)
473exactDims m 473exactDims m = do
474 | natVal (Proxy :: Proxy m) == natVal (Proxy :: Proxy j) 474 Refl <- sameNat (Proxy :: Proxy m) (Proxy :: Proxy j)
475 && natVal (Proxy :: Proxy n) == natVal (Proxy :: Proxy k) 475 Refl <- sameNat (Proxy :: Proxy n) (Proxy :: Proxy k)
476 = Just (mkL (unwrap m)) 476 return $ mkL (unwrap m)
477 | otherwise
478 = Nothing
479 477
480randomVector 478randomVector
481 :: forall n . KnownNat n 479 :: forall n . KnownNat n
@@ -537,6 +535,10 @@ class Domain field vec mat | mat -> vec field, vec -> mat field, field -> mat ve
537 dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m 535 dmmap :: forall n m. (KnownNat m, KnownNat n) => (field -> field) -> mat n m -> mat n m
538 outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m 536 outer :: forall n m. (KnownNat m, KnownNat n) => vec n -> vec m -> mat n m
539 zipWithVector :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n 537 zipWithVector :: forall n. KnownNat n => (field -> field -> field) -> vec n -> vec n -> vec n
538 det :: forall n. KnownNat n => mat n n -> field
539 invlndet :: forall n. KnownNat n => mat n n -> (mat n n, (field, field))
540 expm :: forall n. KnownNat n => mat n n -> mat n n
541 sqrtm :: forall n. KnownNat n => mat n n -> mat n n
540 542
541 543
542instance Domain ℝ R L 544instance Domain ℝ R L
@@ -550,6 +552,10 @@ instance Domain ℝ R L
550 dmmap = mapL 552 dmmap = mapL
551 outer = outerR 553 outer = outerR
552 zipWithVector = zipWithR 554 zipWithVector = zipWithR
555 det = detL
556 invlndet = invlndetL
557 expm = expmL
558 sqrtm = sqrtmL
553 559
554instance Domain ℂ C M 560instance Domain ℂ C M
555 where 561 where
@@ -562,6 +568,10 @@ instance Domain ℂ C M
562 dmmap = mapM' 568 dmmap = mapM'
563 outer = outerC 569 outer = outerC
564 zipWithVector = zipWithC 570 zipWithVector = zipWithC
571 det = detM
572 invlndet = invlndetM
573 expm = expmM
574 sqrtm = sqrtmM
565 575
566-------------------------------------------------------------------------------- 576--------------------------------------------------------------------------------
567 577
@@ -611,8 +621,19 @@ zipWithR :: KnownNat n => (ℝ -> ℝ -> ℝ) -> R n -> R n -> R n
611zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y) 621zipWithR f (extract -> x) (extract -> y) = mkR (LA.zipVectorWith f x y)
612 622
613mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m 623mapL :: (KnownNat n, KnownNat m) => (ℝ -> ℝ) -> L n m -> L n m
614mapL f (unwrap -> m) = mkL (LA.cmap f m) 624mapL f = overMatL' (LA.cmap f)
615 625
626detL :: KnownNat n => Sq n -> ℝ
627detL = LA.det . unwrap
628
629invlndetL :: KnownNat n => Sq n -> (L n n, (ℝ, ℝ))
630invlndetL = first mkL . LA.invlndet . unwrap
631
632expmL :: KnownNat n => Sq n -> Sq n
633expmL = overMatL' LA.expm
634
635sqrtmL :: KnownNat n => Sq n -> Sq n
636sqrtmL = overMatL' LA.sqrtm
616 637
617-------------------------------------------------------------------------------- 638--------------------------------------------------------------------------------
618 639
@@ -662,7 +683,19 @@ zipWithC :: KnownNat n => (ℂ -> ℂ -> ℂ) -> C n -> C n -> C n
662zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y) 683zipWithC f (extract -> x) (extract -> y) = mkC (LA.zipVectorWith f x y)
663 684
664mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m 685mapM' :: (KnownNat n, KnownNat m) => (ℂ -> ℂ) -> M n m -> M n m
665mapM' f (unwrap -> m) = mkM (LA.cmap f m) 686mapM' f = overMatM' (LA.cmap f)
687
688detM :: KnownNat n => M n n -> ℂ
689detM = LA.det . unwrap
690
691invlndetM :: KnownNat n => M n n -> (M n n, (ℂ, ℂ))
692invlndetM = first mkM . LA.invlndet . unwrap
693
694expmM :: KnownNat n => M n n -> M n n
695expmM = overMatM' LA.expm
696
697sqrtmM :: KnownNat n => M n n -> M n n
698sqrtmM = overMatM' LA.sqrtm
666 699
667 700
668-------------------------------------------------------------------------------- 701--------------------------------------------------------------------------------
@@ -824,3 +857,14 @@ instance KnownNat n => Floating (Sym n)
824 sqrt = mkSym sqrt 857 sqrt = mkSym sqrt
825 (**) = mkSym2 (**) 858 (**) = mkSym2 (**)
826 pi = Sym pi 859 pi = Sym pi
860
861instance KnownNat n => Additive (Sym n) where
862 add = (+)
863
864instance KnownNat n => Transposable (Sym n) (Sym n) where
865 tr = id
866 tr' = id
867
868instance KnownNat n => Transposable (Her n) (Her n) where
869 tr = id
870 tr' (Her m) = Her (tr' m)