diff options
Diffstat (limited to 'packages/base/src/Internal/Element.hs')
-rw-r--r-- | packages/base/src/Internal/Element.hs | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs index eb3a25b..2e330ee 100644 --- a/packages/base/src/Internal/Element.hs +++ b/packages/base/src/Internal/Element.hs | |||
@@ -4,6 +4,8 @@ | |||
4 | {-# LANGUAGE UndecidableInstances #-} | 4 | {-# LANGUAGE UndecidableInstances #-} |
5 | {-# LANGUAGE MultiParamTypeClasses #-} | 5 | {-# LANGUAGE MultiParamTypeClasses #-} |
6 | 6 | ||
7 | {-# OPTIONS_GHC -fno-warn-orphans #-} | ||
8 | |||
7 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
8 | -- | | 10 | -- | |
9 | -- Module : Data.Packed.Matrix | 11 | -- Module : Data.Packed.Matrix |
@@ -31,6 +33,7 @@ import Data.List.Split(chunksOf) | |||
31 | import Foreign.Storable(Storable) | 33 | import Foreign.Storable(Storable) |
32 | import System.IO.Unsafe(unsafePerformIO) | 34 | import System.IO.Unsafe(unsafePerformIO) |
33 | import Control.Monad(liftM) | 35 | import Control.Monad(liftM) |
36 | import Foreign.C.Types(CInt) | ||
34 | 37 | ||
35 | ------------------------------------------------------------------- | 38 | ------------------------------------------------------------------- |
36 | 39 | ||
@@ -53,8 +56,10 @@ instance (Show a, Element a) => (Show (Matrix a)) where | |||
53 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" | 56 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" |
54 | show m = (sizes m++) . dsp . map (map show) . toLists $ m | 57 | show m = (sizes m++) . dsp . map (map show) . toLists $ m |
55 | 58 | ||
59 | sizes :: Matrix t -> [Char] | ||
56 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" | 60 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" |
57 | 61 | ||
62 | dsp :: [[[Char]]] -> [Char] | ||
58 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | 63 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp |
59 | where | 64 | where |
60 | mt = transpose as | 65 | mt = transpose as |
@@ -73,6 +78,7 @@ instance (Element a, Read a) => Read (Matrix a) where | |||
73 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims | 78 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims |
74 | 79 | ||
75 | 80 | ||
81 | breakAt :: Eq a => a -> [a] -> ([a], [a]) | ||
76 | breakAt c l = (a++[c],tail b) where | 82 | breakAt c l = (a++[c],tail b) where |
77 | (a,b) = break (==c) l | 83 | (a,b) = break (==c) l |
78 | 84 | ||
@@ -88,7 +94,8 @@ data Extractor | |||
88 | | Drop Int | 94 | | Drop Int |
89 | | DropLast Int | 95 | | DropLast Int |
90 | deriving Show | 96 | deriving Show |
91 | 97 | ||
98 | ppext :: Extractor -> [Char] | ||
92 | ppext All = ":" | 99 | ppext All = ":" |
93 | ppext (Range a 1 c) = printf "%d:%d" a c | 100 | ppext (Range a 1 c) = printf "%d:%d" a c |
94 | ppext (Range a b c) = printf "%d:%d:%d" a b c | 101 | ppext (Range a b c) = printf "%d:%d:%d" a b c |
@@ -128,10 +135,14 @@ ppext (DropLast n) = printf "DropLast %d" n | |||
128 | infixl 9 ?? | 135 | infixl 9 ?? |
129 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t | 136 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t |
130 | 137 | ||
138 | minEl :: Vector CInt -> CInt | ||
131 | minEl = toScalarI Min | 139 | minEl = toScalarI Min |
140 | maxEl :: Vector CInt -> CInt | ||
132 | maxEl = toScalarI Max | 141 | maxEl = toScalarI Max |
142 | cmodi :: Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt | ||
133 | cmodi = vectorMapValI ModVS | 143 | cmodi = vectorMapValI ModVS |
134 | 144 | ||
145 | extractError :: Matrix t1 -> (Extractor, Extractor) -> t | ||
135 | extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) | 146 | extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) |
136 | 147 | ||
137 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) | 148 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) |
@@ -232,8 +243,10 @@ disp = putStr . dispf 2 | |||
232 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t | 243 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t |
233 | fromBlocks = fromBlocksRaw . adaptBlocks | 244 | fromBlocks = fromBlocksRaw . adaptBlocks |
234 | 245 | ||
246 | fromBlocksRaw :: Element t => [[Matrix t]] -> Matrix t | ||
235 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms | 247 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms |
236 | 248 | ||
249 | adaptBlocks :: Element t => [[Matrix t]] -> [[Matrix t]] | ||
237 | adaptBlocks ms = ms' where | 250 | adaptBlocks ms = ms' where |
238 | bc = case common length ms of | 251 | bc = case common length ms of |
239 | Just c -> c | 252 | Just c -> c |
@@ -486,6 +499,9 @@ liftMatrix2Auto f m1 m2 | |||
486 | m2' = conformMTo (r,c) m2 | 499 | m2' = conformMTo (r,c) m2 |
487 | 500 | ||
488 | -- FIXME do not flatten if equal order | 501 | -- FIXME do not flatten if equal order |
502 | lM :: (Storable t, Element t1, Element t2) | ||
503 | => (Vector t1 -> Vector t2 -> Vector t) | ||
504 | -> Matrix t1 -> Matrix t2 -> Matrix t | ||
489 | lM f m1 m2 = matrixFromVector | 505 | lM f m1 m2 = matrixFromVector |
490 | RowMajor | 506 | RowMajor |
491 | (max' (rows m1) (rows m2)) | 507 | (max' (rows m1) (rows m2)) |
@@ -504,6 +520,7 @@ compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | |||
504 | 520 | ||
505 | ------------------------------------------------------------ | 521 | ------------------------------------------------------------ |
506 | 522 | ||
523 | toBlockRows :: Element t => [Int] -> Matrix t -> [Matrix t] | ||
507 | toBlockRows [r] m | 524 | toBlockRows [r] m |
508 | | r == rows m = [m] | 525 | | r == rows m = [m] |
509 | toBlockRows rs m | 526 | toBlockRows rs m |
@@ -513,6 +530,7 @@ toBlockRows rs m | |||
513 | szs = map (* cols m) rs | 530 | szs = map (* cols m) rs |
514 | g k = (k><0)[] | 531 | g k = (k><0)[] |
515 | 532 | ||
533 | toBlockCols :: Element t => [Int] -> Matrix t -> [Matrix t] | ||
516 | toBlockCols [c] m | c == cols m = [m] | 534 | toBlockCols [c] m | c == cols m = [m] |
517 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m | 535 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m |
518 | 536 | ||
@@ -576,7 +594,7 @@ Just (3><3) | |||
576 | mapMatrixWithIndexM | 594 | mapMatrixWithIndexM |
577 | :: (Element a, Storable b, Monad m) => | 595 | :: (Element a, Storable b, Monad m) => |
578 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) | 596 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) |
579 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m | 597 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m |
580 | where | 598 | where |
581 | c = cols m | 599 | c = cols m |
582 | 600 | ||
@@ -598,4 +616,3 @@ mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m | |||
598 | 616 | ||
599 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b | 617 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b |
600 | mapMatrix f = liftMatrix (mapVector f) | 618 | mapMatrix f = liftMatrix (mapVector f) |
601 | |||