From 1cfc81ba6a318b593598a9a038adaa73009f6530 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 17 Jun 2014 10:01:35 +0200 Subject: size and create --- packages/base/src/Numeric/LinearAlgebra/Static.hs | 95 ++++++++---------- .../src/Numeric/LinearAlgebra/Static/Internal.hs | 107 ++++++++++++--------- 2 files changed, 101 insertions(+), 101 deletions(-) (limited to 'packages/base/src/Numeric/LinearAlgebra') diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs index 388d165..213c42c 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs @@ -64,7 +64,7 @@ import GHC.TypeLits import Numeric.LinearAlgebra.HMatrix hiding ( (<>),(#>),(<·>),Konst(..),diag, disp,(¦),(——),row,col,vector,matrix,linspace,toRows,toColumns, (<\>),fromList,takeDiag,svd,eig,eigSH,eigSH',eigenvalues,eigenvaluesSH,eigenvaluesSH',build, - qr) + qr,size) import qualified Numeric.LinearAlgebra.HMatrix as LA import Data.Proxy(Proxy) import Numeric.LinearAlgebra.Static.Internal @@ -107,20 +107,20 @@ matrix :: (KnownNat m, KnownNat n) => [ℝ] -> L m n matrix = fromList linspace :: forall n . KnownNat n => (ℝ,ℝ) -> R n -linspace (a,b) = mkR (LA.linspace d (a,b)) +linspace (a,b) = v where - d = fromIntegral . natVal $ (undefined :: Proxy n) + v = mkR (LA.linspace (size v) (a,b)) range :: forall n . KnownNat n => R n -range = mkR (LA.linspace d (1,fromIntegral d)) +range = v where - d = fromIntegral . natVal $ (undefined :: Proxy n) + v = mkR (LA.linspace d (1,fromIntegral d)) + d = size v dim :: forall n . KnownNat n => R n -dim = mkR (scalar d) +dim = v where - d = fromIntegral . natVal $ (undefined :: Proxy n) - + v = mkR (scalar (fromIntegral $ size v)) -------------------------------------------------------------------------------- @@ -140,7 +140,7 @@ eye = diag 1 -------------------------------------------------------------------------------- blockAt :: forall m n . (KnownNat m, KnownNat n) => ℝ -> Int -> Int -> Matrix Double -> L m n -blockAt x r c a = mkL res +blockAt x r c a = res where z = scalar x z1 = LA.konst x (r,c) @@ -148,13 +148,8 @@ blockAt x r c a = mkL res ra = min (rows a) . max 0 $ m'-r ca = min (cols a) . max 0 $ n'-c sa = subMatrix (0,0) (ra, ca) a - m' = fromIntegral . natVal $ (undefined :: Proxy m) - n' = fromIntegral . natVal $ (undefined :: Proxy n) - res = fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] - - - - + (m',n') = size res + res = mkL $ fromBlocks [[z1,z,z],[z,sa,z],[z,z,z2]] -------------------------------------------------------------------------------- @@ -189,22 +184,15 @@ type GL = (KnownNat n, KnownNat m) => L m n type GSq = KnownNat n => Sq n isKonst :: forall m n . (KnownNat m, KnownNat n) => L m n -> Maybe (ℝ,(Int,Int)) -isKonst (unwrap -> x) - | singleM x = Just (x `atIndex` (0,0), (m',n')) +isKonst s@(unwrap -> x) + | singleM x = Just (x `atIndex` (0,0), (size s)) | otherwise = Nothing - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int - n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int isKonstC :: forall m n . (KnownNat m, KnownNat n) => M m n -> Maybe (ℂ,(Int,Int)) -isKonstC (unwrap -> x) - | singleM x = Just (x `atIndex` (0,0), (m',n')) +isKonstC s@(unwrap -> x) + | singleM x = Just (x `atIndex` (0,0), (size s)) | otherwise = Nothing - where - m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int - n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int - infixr 8 <> @@ -256,7 +244,7 @@ svd (extract -> m) = (mkL u, mkR s', mkL v) where (u,s,v) = LA.svd m s' = vjoin [s, z] - z = LA.konst 0 (max 0 (cols m - size s)) + z = LA.konst 0 (max 0 (cols m - LA.size s)) svdTall :: (KnownNat m, KnownNat n, n <= m) => L m n -> (L m n, R n, L n n) @@ -333,7 +321,7 @@ withCompactSVD -> (forall k . (KnownNat k) => (L m k, R k, L n k) -> z) -> z withCompactSVD (LA.compactSVD . extract -> (u,s,v)) f = - case someNatVal $ fromIntegral $ size s of + case someNatVal $ fromIntegral $ LA.size s of Nothing -> error "static/dynamic mismatch" Just (SomeNat (_ :: Proxy k)) -> f (mkL u :: L m k, mkR s :: R k, mkL v :: L n k) @@ -350,7 +338,7 @@ qr (extract -> x) = (mkL q, mkL r) split :: forall p n . (KnownNat p, KnownNat n, p<=n) => R n -> (R p, R (n-p)) split (extract -> v) = ( mkR (subVector 0 p' v) , - mkR (subVector p' (size v - p') v) ) + mkR (subVector p' (LA.size v - p') v) ) where p' = fromIntegral . natVal $ (undefined :: Proxy p) :: Int @@ -383,10 +371,9 @@ build :: forall m n . (KnownNat n, KnownNat m) => (ℝ -> ℝ -> ℝ) -> L m n -build f = mkL $ LA.build (m',n') f +build f = r where - m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int - n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + r = mkL $ LA.build (size r) f -------------------------------------------------------------------------------- @@ -396,7 +383,7 @@ withVector -> (forall n . (KnownNat n) => R n -> z) -> z withVector v f = - case someNatVal $ fromIntegral $ size v of + case someNatVal $ fromIntegral $ LA.size v of Nothing -> error "static/dynamic mismatch" Just (SomeNat (_ :: Proxy m)) -> f (mkR v :: R m) @@ -451,19 +438,19 @@ mulR (isKonst -> Just (a,(_,k))) (isKonst -> Just (b,_)) = konst (a * b * fromIn mulR (isDiag -> Just (0,a,_)) (isDiag -> Just (0,b,_)) = diagR 0 (mkR v :: R k) where v = a' * b' - n = min (size a) (size b) + n = min (LA.size a) (LA.size b) a' = subVector 0 n a b' = subVector 0 n b -mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (size a) b) +mulR (isDiag -> Just (0,a,_)) (extract -> b) = mkL (asColumn a * takeRows (LA.size a) b) -mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (size b) a * asRow b) +mulR (extract -> a) (isDiag -> Just (0,b,_)) = mkL (takeColumns (LA.size b) a * asRow b) mulR a b = mkL (extract a LA.<> extract b) appR :: (KnownNat m, KnownNat n) => L m n -> R n -> R m -appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (size w) (extract v)) +appR (isDiag -> Just (0, w, _)) v = mkR (w * subVector 0 (LA.size w) (extract v)) appR m v = mkR (extract m LA.#> extract v) @@ -489,19 +476,19 @@ mulC (isKonstC -> Just (a,(_,k))) (isKonstC -> Just (b,_)) = konst (a * b * from mulC (isDiagC -> Just (0,a,_)) (isDiagC -> Just (0,b,_)) = diagR 0 (mkC v :: C k) where v = a' * b' - n = min (size a) (size b) + n = min (LA.size a) (LA.size b) a' = subVector 0 n a b' = subVector 0 n b -mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (size a) b) +mulC (isDiagC -> Just (0,a,_)) (extract -> b) = mkM (asColumn a * takeRows (LA.size a) b) -mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (size b) a * asRow b) +mulC (extract -> a) (isDiagC -> Just (0,b,_)) = mkM (takeColumns (LA.size b) a * asRow b) mulC a b = mkM (extract a LA.<> extract b) appC :: (KnownNat m, KnownNat n) => M m n -> C n -> C m -appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (size w) (extract v)) +appC (isDiagC -> Just (0, w, _)) v = mkC (w * subVector 0 (LA.size w) (extract v)) appC m v = mkC (extract m LA.#> extract v) @@ -521,21 +508,21 @@ crossC (extract -> x) (extract -> y) = mkC (LA.fromList [z1, z2, z3]) -------------------------------------------------------------------------------- diagRectR :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℝ -> R k -> L m n -diagRectR x v = mkL (asRow (vjoin [scalar x, ev, zeros])) +diagRectR x v = r where + r = mkL (asRow (vjoin [scalar x, ev, zeros])) ev = extract v - zeros = LA.konst x (max 0 ((min m' n') - size ev)) - m' = fromIntegral . natVal $ (undefined :: Proxy m) - n' = fromIntegral . natVal $ (undefined :: Proxy n) + zeros = LA.konst x (max 0 ((min m' n') - LA.size ev)) + (m',n') = size r diagRectC :: forall m n k . (KnownNat m, KnownNat n, KnownNat k) => ℂ -> C k -> M m n -diagRectC x v = mkM (asRow (vjoin [scalar x, ev, zeros])) +diagRectC x v = r where + r = mkM (asRow (vjoin [scalar x, ev, zeros])) ev = extract v - zeros = LA.konst x (max 0 ((min m' n') - size ev)) - m' = fromIntegral . natVal $ (undefined :: Proxy m) - n' = fromIntegral . natVal $ (undefined :: Proxy n) + zeros = LA.konst x (max 0 ((min m' n') - LA.size ev)) + (m',n') = size r -------------------------------------------------------------------------------- @@ -578,10 +565,10 @@ test = (ok,info) tm = lmat 0 [1..] lmat :: forall m n . (KnownNat m, KnownNat n) => ℝ -> [ℝ] -> L m n - lmat z xs = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z + lmat z xs = r where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - n' = fromIntegral . natVal $ (undefined :: Proxy n) + r = mkL . reshape n' . LA.fromList . take (m'*n') $ xs ++ repeat z + (m',n') = size r sm :: GSq sm = lmat 0 [1..] @@ -595,7 +582,7 @@ test = (ok,info) m = LA.matrix 3 [1..30] precS = (1::Double) + (2::Double) * ((1 :: R 3) * (u & 6)) <·> konst 2 #> v - precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (size (ud1 u) +1, size (ud1 v)) LA.#> ud1 v + precD = 1 + 2 * vjoin[ud1 u, 6] LA.<·> LA.konst 2 (LA.size (ud1 u) +1, LA.size (ud1 v)) LA.#> ud1 v splittest diff --git a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs index 7968d77..339ef7d 100644 --- a/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs +++ b/packages/base/src/Numeric/LinearAlgebra/Static/Internal.hs @@ -7,13 +7,10 @@ {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE EmptyDataDecls #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE GADTs #-} - {- | Module : Numeric.LinearAlgebra.Static.Internal @@ -28,7 +25,7 @@ module Numeric.LinearAlgebra.Static.Internal where import GHC.TypeLits import qualified Numeric.LinearAlgebra.HMatrix as LA -import Numeric.LinearAlgebra.HMatrix hiding (konst) +import Numeric.LinearAlgebra.HMatrix hiding (konst,size) import Data.Packed as D import Data.Packed.ST import Data.Proxy(Proxy) @@ -83,7 +80,7 @@ ud :: Dim n (Vector t) -> Vector t ud (Dim v) = v mkV :: forall (n :: Nat) t . t -> Dim n t -mkV = Dim +mkV = Dim vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) @@ -92,9 +89,9 @@ vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) where du = fromIntegral . natVal $ (undefined :: Proxy n) dv = fromIntegral . natVal $ (undefined :: Proxy m) - u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du + u' | du > 1 && LA.size u == 1 = LA.konst (u D.@> 0) du | otherwise = u - v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv + v' | dv > 1 && LA.size v == 1 = LA.konst (v D.@> 0) dv | otherwise = v @@ -132,7 +129,7 @@ gvect st xs' | otherwise = abort (show xs) where (xs,rest) = splitAt d xs' - ok = size v == d && null rest + ok = LA.size v == d && null rest v = LA.fromList xs d = fromIntegral . natVal $ (undefined :: Proxy n) abort info = error $ st++" "++show d++" can't be created from elements "++info @@ -153,7 +150,7 @@ gmat st xs' (xs,rest) = splitAt (m'*n') xs' v = LA.fromList xs x = reshape n' v - ok = rem (size v) n' == 0 && size x == (m',n') && null rest + ok = rem (LA.size v) n' == 0 && LA.size x == (m',n') && null rest m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info @@ -162,66 +159,84 @@ gmat st xs' class Num t => Sized t s d | s -> t, s -> d where - konst :: t -> s - unwrap :: s -> d - fromList :: [t] -> s - extract :: s -> d - -singleV v = size v == 1 + konst :: t -> s + unwrap :: s -> d t + fromList :: [t] -> s + extract :: s -> d t + create :: d t -> Maybe s + size :: s -> IndexOf d + +singleV v = LA.size v == 1 singleM m = rows m == 1 && cols m == 1 -instance forall n. KnownNat n => Sized ℂ (C n) (Vector ℂ) +instance forall n. KnownNat n => Sized ℂ (C n) Vector where + size _ = fromIntegral . natVal $ (undefined :: Proxy n) konst x = mkC (LA.scalar x) unwrap (C (Dim v)) = v fromList xs = C (gvect "C" xs) - extract (unwrap -> v) - | singleV v = LA.konst (v!0) d + extract s@(unwrap -> v) + | singleV v = LA.konst (v!0) (size s) | otherwise = v - where - d = fromIntegral . natVal $ (undefined :: Proxy n) + create v + | LA.size v == size r = Just r + | otherwise = Nothing + where + r = mkC v :: C n -instance forall n. KnownNat n => Sized ℝ (R n) (Vector ℝ) +instance forall n. KnownNat n => Sized ℝ (R n) Vector where + size _ = fromIntegral . natVal $ (undefined :: Proxy n) konst x = mkR (LA.scalar x) unwrap (R (Dim v)) = v fromList xs = R (gvect "R" xs) - extract (unwrap -> v) - | singleV v = LA.konst (v!0) d + extract s@(unwrap -> v) + | singleV v = LA.konst (v!0) (size s) | otherwise = v - where - d = fromIntegral . natVal $ (undefined :: Proxy n) + create v + | LA.size v == size r = Just r + | otherwise = Nothing + where + r = mkR v :: R n -instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) (Matrix ℝ) +instance forall m n . (KnownNat m, KnownNat n) => Sized ℝ (L m n) Matrix where + size _ = ((fromIntegral . natVal) (undefined :: Proxy m) + ,(fromIntegral . natVal) (undefined :: Proxy n)) konst x = mkL (LA.scalar x) fromList xs = L (gmat "L" xs) unwrap (L (Dim (Dim m))) = m extract (isDiag -> Just (z,y,(m',n'))) = diagRect z y m' n' - extract (unwrap -> a) - | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') + extract s@(unwrap -> a) + | singleM a = LA.konst (a `atIndex` (0,0)) (size s) | otherwise = a + create x + | LA.size x == size r = Just r + | otherwise = Nothing where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - n' = fromIntegral . natVal $ (undefined :: Proxy n) + r = mkL x :: L m n -instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) (Matrix ℂ) +instance forall m n . (KnownNat m, KnownNat n) => Sized ℂ (M m n) Matrix where + size _ = ((fromIntegral . natVal) (undefined :: Proxy m) + ,(fromIntegral . natVal) (undefined :: Proxy n)) konst x = mkM (LA.scalar x) fromList xs = M (gmat "M" xs) unwrap (M (Dim (Dim m))) = m extract (isDiagC -> Just (z,y,(m',n'))) = diagRect z y m' n' - extract (unwrap -> a) - | singleM a = LA.konst (a `atIndex` (0,0)) (m',n') + extract s@(unwrap -> a) + | singleM a = LA.konst (a `atIndex` (0,0)) (size s) | otherwise = a + create x + | LA.size x == size r = Just r + | otherwise = Nothing where - m' = fromIntegral . natVal $ (undefined :: Proxy m) - n' = fromIntegral . natVal $ (undefined :: Proxy n) + r = mkM x :: M m n -------------------------------------------------------------------------------- @@ -254,8 +269,8 @@ isDiagg (Dim (Dim x)) n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int v = flatten x z = v `atIndex` 0 - y = subVector 1 (size v-1) v - ny = size y + y = subVector 1 (LA.size v-1) v + ny = LA.size y zeros = LA.konst 0 (max 0 (min m' n' - ny)) yz = vjoin [y,zeros] @@ -263,39 +278,37 @@ isDiagg (Dim (Dim x)) instance forall n . KnownNat n => Show (R n) where - show (R (Dim v)) + show s@(R (Dim v)) | singleV v = "("++show (v!0)++" :: R "++show d++")" | otherwise = "(vector"++ drop 8 (show v)++" :: R "++show d++")" where - d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + d = size s instance forall n . KnownNat n => Show (C n) where - show (C (Dim v)) + show s@(C (Dim v)) | singleV v = "("++show (v!0)++" :: C "++show d++")" | otherwise = "(vector"++ drop 8 (show v)++" :: C "++show d++")" where - d = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + d = size s instance forall m n . (KnownNat m, KnownNat n) => Show (L m n) where show (isDiag -> Just (z,y,(m',n'))) = printf "(diag %s %s :: L %d %d)" (show z) (drop 9 $ show y) m' n' - show (L (Dim (Dim x))) + show s@(L (Dim (Dim x))) | singleM x = printf "(%s :: L %d %d)" (show (x `atIndex` (0,0))) m' n' | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: L "++show m'++" "++show n'++")" where - m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int - n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + (m',n') = size s instance forall m n . (KnownNat m, KnownNat n) => Show (M m n) where show (isDiagC -> Just (z,y,(m',n'))) = printf "(diag %s %s :: M %d %d)" (show z) (drop 9 $ show y) m' n' - show (M (Dim (Dim x))) + show s@(M (Dim (Dim x))) | singleM x = printf "(%s :: M %d %d)" (show (x `atIndex` (0,0))) m' n' | otherwise = "(matrix"++ dropWhile (/='\n') (show x)++" :: M "++show m'++" "++show n'++")" where - m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int - n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int + (m',n') = size s -------------------------------------------------------------------------------- -- cgit v1.2.3