From a749785e839d14fadc47ab4c6e94afdd167bdd21 Mon Sep 17 00:00:00 2001 From: Alberto Ruiz Date: Tue, 26 Jun 2007 16:57:58 +0000 Subject: tensor refactorization --- lib/Data/Packed/Internal/Tensor.hs | 32 ++++++++++----- lib/Data/Packed/Matrix.hs | 4 +- lib/Data/Packed/Tensor.hs | 82 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 103 insertions(+), 15 deletions(-) (limited to 'lib/Data') diff --git a/lib/Data/Packed/Internal/Tensor.hs b/lib/Data/Packed/Internal/Tensor.hs index c4faf49..8296935 100644 --- a/lib/Data/Packed/Internal/Tensor.hs +++ b/lib/Data/Packed/Internal/Tensor.hs @@ -18,7 +18,8 @@ import Data.Packed.Internal.Common import Data.Packed.Internal.Vector import Data.Packed.Internal.Matrix import Foreign.Storable -import Data.List(sort,elemIndex,nub) +import Data.List(sort,elemIndex,nub,foldl1') +import GSL.Vector data IdxType = Covariant | Contravariant deriving (Show,Eq) @@ -79,10 +80,10 @@ concatRename l1 l2 = l1 ++ map ren l2 where ren idx = if {- s `elem` fs -} True then idx {idxName = idxName idx ++ "'"} else idx fs = map idxName l1 -prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t +--prod :: (Field t, Num t) => Tensor t -> Tensor t -> Tensor t prod (T d1 v1) (T d2 v2) = T (concatRename d1 d2) (outer' v1 v2) -contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t +--contraction :: (Field t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t contraction t1 n1 t2 n2 = if compatIdx t1 n1 t2 n2 then T (concatRename (tail d1) (tail d2)) (cdat m) @@ -91,16 +92,27 @@ contraction t1 n1 t2 n2 = (d2,m2) = putFirstIdx n2 t2 m = multiply RowMajor (trans m1) m2 -sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] -sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) +--sumT :: (Storable t, Enum t, Num t) => [Tensor t] -> [t] +--sumT ls = foldl (zipWith (+)) [0,0..] (map (toList.ten) ls) +--addT ts = T (dims (head ts)) (fromList $ sumT ts) -contract1 :: (Num t, Enum t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t -contract1 t name1 name2 = T d $ fromList $ sumT y +liftTensor f (T d v) = T d (f v) + +liftTensor2 f (T d1 v1) (T d2 v2) | compat d1 d2 = T d1 (f v1 v2) + | otherwise = error "liftTensor2 with incompatible tensors" + where compat a b = length a == length b + + +a |+| b = liftTensor2 add a b +addT l = foldl1' (|+|) l + +--contract1 :: (Num t, Field t) => Tensor t -> IdxName -> IdxName -> Tensor t +contract1 t name1 name2 = addT y where d = dims (head y) x = (map (flip parts name2) (parts t name1)) y = map head $ zipWith drop [0..] x -contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t +--contraction' :: (Field t, Enum t, Num t) => Tensor t -> IdxName -> Tensor t -> IdxName -> Tensor t contraction' t1 n1 t2 n2 = if compatIdx t1 n1 t2 n2 then contract1 (prod t1 t2) n1 (n2++"'") @@ -130,8 +142,8 @@ names t = sort $ map idxName (dims t) normal :: (Field t) => Tensor t -> Tensor t normal t = tridx (names t) t -contractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] -contractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] +possibleContractions :: (Num t, Field t) => Tensor t -> Tensor t -> [Tensor t] +possibleContractions t1 t2 = [ contraction t1 n1 t2 n2 | n1 <- names t1, n2 <- names t2, compatIdx t1 n1 t2 n2 ] -- sent to Haskell-Cafe by Sebastian Sylvan perms :: [t] -> [[t]] diff --git a/lib/Data/Packed/Matrix.hs b/lib/Data/Packed/Matrix.hs index 36bf32e..2033dc7 100644 --- a/lib/Data/Packed/Matrix.hs +++ b/lib/Data/Packed/Matrix.hs @@ -87,14 +87,14 @@ ident n = diag (constant 1 n) r >< c = f where f l | dim v == r*c = matrixFromVector RowMajor c v | otherwise = error $ "inconsistent list size = " - ++show (dim v) ++"in ("++show r++"><"++show c++")" + ++show (dim v) ++" in ("++show r++"><"++show c++")" where v = fromList l (>|<) :: (Field a) => Int -> Int -> [a] -> Matrix a r >|< c = f where f l | dim v == r*c = matrixFromVector ColumnMajor c v | otherwise = error $ "inconsistent list size = " - ++show (dim v) ++"in ("++show r++"><"++show c++")" + ++show (dim v) ++" in ("++show r++"><"++show c++")" where v = fromList l ---------------------------------------------------------------- diff --git a/lib/Data/Packed/Tensor.hs b/lib/Data/Packed/Tensor.hs index 75a9288..68ce9a5 100644 --- a/lib/Data/Packed/Tensor.hs +++ b/lib/Data/Packed/Tensor.hs @@ -12,9 +12,85 @@ -- ----------------------------------------------------------------------------- -module Data.Packed.Tensor ( - -) where +module Data.Packed.Tensor where +import Data.Packed.Matrix import Data.Packed.Internal import Complex +import Data.List(transpose,intersperse,sort,elemIndex,nub,foldl',foldl1') + +scalar x = T [] (fromList [x]) +tensorFromVector (tp,nm) v = T {dims = [IdxDesc (dim v) tp nm] + , ten = v} +tensorFromMatrix (tpr,nmr) (tpc,nmc) m = T {dims = [IdxDesc (rows m) tpr nmr,IdxDesc (cols m) tpc nmc] + , ten = cdat m} + +scsig t = scalar (signature (nms t)) `prod` t + where nms = map idxName . dims + +antisym' t = addT $ map (scsig . flip tridx t) (perms (names t)) + + +auxrename (T d v) = T d' v + where d' = [IdxDesc n c (show (pos q)) | IdxDesc n c q <- d] + pos n = i where Just i = elemIndex n nms + nms = map idxName d + +antisym t = T (dims t) (ten (antisym' (auxrename t))) + + +norper t = prod t (scalar (recip $ fromIntegral $ product [1 .. length (dims t)])) +antinorper t = prod t (scalar (fromIntegral $ product [1 .. length (dims t)])) + + +tvector n v = tensorFromVector (Contravariant,n) v +tcovector n v = tensorFromVector (Covariant,n) v + +wedge a b = antisym (prod (norper a) (norper b)) + +a /\ b = wedge a b + +a <*> b = normal $ prod a b + +normAT t = sqrt $ innerAT t t + +innerAT t1 t2 = dot (ten t1) (ten t2) / fromIntegral (fact $ length $ dims t1) + +fact n = product [1..n] + +leviCivita n = antisym $ foldl1 prod $ zipWith tcovector (map show [1,2..]) (toRows (ident n)) + +contractionF t1 t2 = contraction t1 n1 t2 n2 + where n1 = fn t1 + n2 = fn t2 + fn = idxName . head . dims + + +dualV vs = foldl' contractionF (leviCivita n) vs + where n = idxDim . head . dims . head $ vs + +raise (T d v) = T (map raise' d) v + where raise' idx@IdxDesc {idxType = Covariant } = idx {idxType = Contravariant} + raise' idx@IdxDesc {idxType = Contravariant } = idx {idxType = Covariant} + +dualMV t = prod (foldl' contract1b (lc <*> t) ds) (scalar (recip $ fromIntegral $ fact (length ds))) + where + lc = leviCivita n + nms1 = map idxName (dims lc) + nms2 = map ((++"'").idxName) (dims t) + ds = zip nms1 nms2 + n = idxDim . head . dims $ t + +contract1b t (n1,n2) = contract1 t n1 n2 + +contractions t pairs = foldl' contract1b t pairs + +asBase r n = filter (\x-> (x==nub x && x==sort x)) $ sequence $ replicate r [1..n] + +partF t i = part t (name,i) where name = idxName . head . dims $ t + +niceAS t = filter ((/=0.0).fst) $ zip vals base + where vals = map ((`at` 0).ten.foldl' partF t) (map (map pred) base) + base = asBase r n + r = length (dims t) + n = idxDim . head . dims $ t -- cgit v1.2.3