summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Element.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Element.hs')
-rw-r--r--packages/base/src/Internal/Element.hs84
1 files changed, 42 insertions, 42 deletions
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs
index 2e330ee..80eda8d 100644
--- a/packages/base/src/Internal/Element.hs
+++ b/packages/base/src/Internal/Element.hs
@@ -33,14 +33,14 @@ import Data.List.Split(chunksOf)
33import Foreign.Storable(Storable) 33import Foreign.Storable(Storable)
34import System.IO.Unsafe(unsafePerformIO) 34import System.IO.Unsafe(unsafePerformIO)
35import Control.Monad(liftM) 35import Control.Monad(liftM)
36import Foreign.C.Types(CInt) 36import Data.Int
37 37
38------------------------------------------------------------------- 38-------------------------------------------------------------------
39 39
40 40
41import Data.Binary 41import Data.Binary
42 42
43instance (Binary (Vector a), Element a) => Binary (Matrix a) where 43instance (Binary (Vector a), Storable a) => Binary (Matrix a) where
44 put m = do 44 put m = do
45 put (cols m) 45 put (cols m)
46 put (flatten m) 46 put (flatten m)
@@ -52,7 +52,7 @@ instance (Binary (Vector a), Element a) => Binary (Matrix a) where
52 52
53------------------------------------------------------------------- 53-------------------------------------------------------------------
54 54
55instance (Show a, Element a) => (Show (Matrix a)) where 55instance (Show a, Storable a) => (Show (Matrix a)) where
56 show m | rows m == 0 || cols m == 0 = sizes m ++" []" 56 show m | rows m == 0 || cols m == 0 = sizes m ++" []"
57 show m = (sizes m++) . dsp . map (map show) . toLists $ m 57 show m = (sizes m++) . dsp . map (map show) . toLists $ m
58 58
@@ -70,7 +70,7 @@ dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unw
70 70
71------------------------------------------------------------------ 71------------------------------------------------------------------
72 72
73instance (Element a, Read a) => Read (Matrix a) where 73instance (Storable a, Read a) => Read (Matrix a) where
74 readsPrec _ s = [((rs><cs) . read $ listnums, rest)] 74 readsPrec _ s = [((rs><cs) . read $ listnums, rest)]
75 where (thing,rest) = breakAt ']' s 75 where (thing,rest) = breakAt ']' s
76 (dims,listnums) = breakAt ')' thing 76 (dims,listnums) = breakAt ')' thing
@@ -133,13 +133,13 @@ ppext (DropLast n) = printf "DropLast %d" n
133 133
134-} 134-}
135infixl 9 ?? 135infixl 9 ??
136(??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t 136(??) :: Storable t => Matrix t -> (Extractor,Extractor) -> Matrix t
137 137
138minEl :: Vector CInt -> CInt 138minEl :: Vector Int32 -> Int32
139minEl = toScalarI Min 139minEl = toScalarI Min
140maxEl :: Vector CInt -> CInt 140maxEl :: Vector Int32 -> Int32
141maxEl = toScalarI Max 141maxEl = toScalarI Max
142cmodi :: Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt -> Vector Foreign.C.Types.CInt 142cmodi :: Int32 -> Vector Int32 -> Vector Int32
143cmodi = vectorMapValI ModVS 143cmodi = vectorMapValI ModVS
144 144
145extractError :: Matrix t1 -> (Extractor, Extractor) -> t 145extractError :: Matrix t1 -> (Extractor, Extractor) -> t
@@ -181,7 +181,7 @@ m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n))
181m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) 181m ?? (DropLast n, e) = m ?? (Take (rows m - n), e)
182m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) 182m ?? (e, DropLast n) = m ?? (e, Take (cols m - n))
183 183
184m ?? (er,ec) = unsafePerformIO $ extractR (orderOf m) m moder rs modec cs 184m ?? (er,ec) = unsafePerformIO $ extractAux (orderOf m) m moder rs modec cs
185 where 185 where
186 (moder,rs) = mkExt (rows m) er 186 (moder,rs) = mkExt (rows m) er
187 (modec,cs) = mkExt (cols m) ec 187 (modec,cs) = mkExt (cols m) ec
@@ -209,14 +209,14 @@ common f = commonval . map f
209 209
210 210
211-- | creates a matrix from a vertical list of matrices 211-- | creates a matrix from a vertical list of matrices
212joinVert :: Element t => [Matrix t] -> Matrix t 212joinVert :: Storable t => [Matrix t] -> Matrix t
213joinVert [] = emptyM 0 0 213joinVert [] = emptyM 0 0
214joinVert ms = case common cols ms of 214joinVert ms = case common cols ms of
215 Nothing -> error "(impossible) joinVert on matrices with different number of columns" 215 Nothing -> error "(impossible) joinVert on matrices with different number of columns"
216 Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) 216 Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms)
217 217
218-- | creates a matrix from a horizontal list of matrices 218-- | creates a matrix from a horizontal list of matrices
219joinHoriz :: Element t => [Matrix t] -> Matrix t 219joinHoriz :: Storable t => [Matrix t] -> Matrix t
220joinHoriz ms = trans. joinVert . map trans $ ms 220joinHoriz ms = trans. joinVert . map trans $ ms
221 221
222{- | Create a matrix from blocks given as a list of lists of matrices. 222{- | Create a matrix from blocks given as a list of lists of matrices.
@@ -240,13 +240,13 @@ disp = putStr . dispf 2
2403 3 3 3 3 0 0 3 0 0 2403 3 3 3 3 0 0 3 0 0
241 241
242-} 242-}
243fromBlocks :: Element t => [[Matrix t]] -> Matrix t 243fromBlocks :: Storable t => [[Matrix t]] -> Matrix t
244fromBlocks = fromBlocksRaw . adaptBlocks 244fromBlocks = fromBlocksRaw . adaptBlocks
245 245
246fromBlocksRaw :: Element t => [[Matrix t]] -> Matrix t 246fromBlocksRaw :: Storable t => [[Matrix t]] -> Matrix t
247fromBlocksRaw mms = joinVert . map joinHoriz $ mms 247fromBlocksRaw mms = joinVert . map joinHoriz $ mms
248 248
249adaptBlocks :: Element t => [[Matrix t]] -> [[Matrix t]] 249adaptBlocks :: Storable t => [[Matrix t]] -> [[Matrix t]]
250adaptBlocks ms = ms' where 250adaptBlocks ms = ms' where
251 bc = case common length ms of 251 bc = case common length ms of
252 Just c -> c 252 Just c -> c
@@ -258,7 +258,7 @@ adaptBlocks ms = ms' where
258 258
259 g [Just nr,Just nc] m 259 g [Just nr,Just nc] m
260 | nr == r && nc == c = m 260 | nr == r && nc == c = m
261 | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc)) 261 | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantAux x (nr*nc))
262 | r == 1 = fromRows (replicate nr (flatten m)) 262 | r == 1 = fromRows (replicate nr (flatten m))
263 | otherwise = fromColumns (replicate nc (flatten m)) 263 | otherwise = fromColumns (replicate nc (flatten m))
264 where 264 where
@@ -288,7 +288,7 @@ adaptBlocks ms = ms' where
288 , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ] 288 , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ]
289 289
290-} 290-}
291diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t 291diagBlock :: (Storable t, Num t) => [Matrix t] -> Matrix t
292diagBlock ms = fromBlocks $ zipWith f ms [0..] 292diagBlock ms = fromBlocks $ zipWith f ms [0..]
293 where 293 where
294 f m k = take n $ replicate k z ++ m : repeat z 294 f m k = take n $ replicate k z ++ m : repeat z
@@ -299,13 +299,13 @@ diagBlock ms = fromBlocks $ zipWith f ms [0..]
299 299
300 300
301-- | Reverse rows 301-- | Reverse rows
302flipud :: Element t => Matrix t -> Matrix t 302flipud :: Storable t => Matrix t -> Matrix t
303flipud m = extractRows [r-1,r-2 .. 0] $ m 303flipud m = extractRows [r-1,r-2 .. 0] $ m
304 where 304 where
305 r = rows m 305 r = rows m
306 306
307-- | Reverse columns 307-- | Reverse columns
308fliprl :: Element t => Matrix t -> Matrix t 308fliprl :: Storable t => Matrix t -> Matrix t
309fliprl m = extractColumns [c-1,c-2 .. 0] $ m 309fliprl m = extractColumns [c-1,c-2 .. 0] $ m
310 where 310 where
311 c = cols m 311 c = cols m
@@ -330,7 +330,7 @@ diagRect z v r c = ST.runSTMatrix $ do
330 return m 330 return m
331 331
332-- | extracts the diagonal from a rectangular matrix 332-- | extracts the diagonal from a rectangular matrix
333takeDiag :: (Element t) => Matrix t -> Vector t 333takeDiag :: (Storable t) => Matrix t -> Vector t
334takeDiag m = fromList [flatten m @> (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] 334takeDiag m = fromList [flatten m @> (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]]
335 335
336------------------------------------------------------------ 336------------------------------------------------------------
@@ -363,32 +363,32 @@ r >< c = f where
363 363
364---------------------------------------------------------------- 364----------------------------------------------------------------
365 365
366takeRows :: Element t => Int -> Matrix t -> Matrix t 366takeRows :: Storable t => Int -> Matrix t -> Matrix t
367takeRows n mt = subMatrix (0,0) (n, cols mt) mt 367takeRows n mt = subMatrix (0,0) (n, cols mt) mt
368 368
369-- | Creates a matrix with the last n rows of another matrix 369-- | Creates a matrix with the last n rows of another matrix
370takeLastRows :: Element t => Int -> Matrix t -> Matrix t 370takeLastRows :: Storable t => Int -> Matrix t -> Matrix t
371takeLastRows n mt = subMatrix (rows mt - n, 0) (n, cols mt) mt 371takeLastRows n mt = subMatrix (rows mt - n, 0) (n, cols mt) mt
372 372
373dropRows :: Element t => Int -> Matrix t -> Matrix t 373dropRows :: Storable t => Int -> Matrix t -> Matrix t
374dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt 374dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt
375 375
376-- | Creates a copy of a matrix without the last n rows 376-- | Creates a copy of a matrix without the last n rows
377dropLastRows :: Element t => Int -> Matrix t -> Matrix t 377dropLastRows :: Storable t => Int -> Matrix t -> Matrix t
378dropLastRows n mt = subMatrix (0,0) (rows mt - n, cols mt) mt 378dropLastRows n mt = subMatrix (0,0) (rows mt - n, cols mt) mt
379 379
380takeColumns :: Element t => Int -> Matrix t -> Matrix t 380takeColumns :: Storable t => Int -> Matrix t -> Matrix t
381takeColumns n mt = subMatrix (0,0) (rows mt, n) mt 381takeColumns n mt = subMatrix (0,0) (rows mt, n) mt
382 382
383-- |Creates a matrix with the last n columns of another matrix 383-- |Creates a matrix with the last n columns of another matrix
384takeLastColumns :: Element t => Int -> Matrix t -> Matrix t 384takeLastColumns :: Storable t => Int -> Matrix t -> Matrix t
385takeLastColumns n mt = subMatrix (0, cols mt - n) (rows mt, n) mt 385takeLastColumns n mt = subMatrix (0, cols mt - n) (rows mt, n) mt
386 386
387dropColumns :: Element t => Int -> Matrix t -> Matrix t 387dropColumns :: Storable t => Int -> Matrix t -> Matrix t
388dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt 388dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt
389 389
390-- | Creates a copy of a matrix without the last n columns 390-- | Creates a copy of a matrix without the last n columns
391dropLastColumns :: Element t => Int -> Matrix t -> Matrix t 391dropLastColumns :: Storable t => Int -> Matrix t -> Matrix t
392dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt 392dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt
393 393
394---------------------------------------------------------------- 394----------------------------------------------------------------
@@ -402,7 +402,7 @@ dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt
402 , 5.0, 6.0 ] 402 , 5.0, 6.0 ]
403 403
404-} 404-}
405fromLists :: Element t => [[t]] -> Matrix t 405fromLists :: Storable t => [[t]] -> Matrix t
406fromLists = fromRows . map fromList 406fromLists = fromRows . map fromList
407 407
408-- | creates a 1-row matrix from a vector 408-- | creates a 1-row matrix from a vector
@@ -443,7 +443,7 @@ Hilbert matrix of order N:
443@hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@ 443@hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@
444 444
445-} 445-}
446buildMatrix :: Element a => Int -> Int -> ((Int, Int) -> a) -> Matrix a 446buildMatrix :: Storable a => Int -> Int -> ((Int, Int) -> a) -> Matrix a
447buildMatrix rc cc f = 447buildMatrix rc cc f =
448 fromLists $ map (map f) 448 fromLists $ map (map f)
449 $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)] 449 $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)]
@@ -458,11 +458,11 @@ fromArray2D m = (r><c) (elems m)
458 458
459 459
460-- | rearranges the rows of a matrix according to the order given in a list of integers. 460-- | rearranges the rows of a matrix according to the order given in a list of integers.
461extractRows :: Element t => [Int] -> Matrix t -> Matrix t 461extractRows :: Storable t => [Int] -> Matrix t -> Matrix t
462extractRows l m = m ?? (Pos (idxs l), All) 462extractRows l m = m ?? (Pos (idxs l), All)
463 463
464-- | rearranges the rows of a matrix according to the order given in a list of integers. 464-- | rearranges the rows of a matrix according to the order given in a list of integers.
465extractColumns :: Element t => [Int] -> Matrix t -> Matrix t 465extractColumns :: Storable t => [Int] -> Matrix t -> Matrix t
466extractColumns l m = m ?? (All, Pos (idxs l)) 466extractColumns l m = m ?? (All, Pos (idxs l))
467 467
468 468
@@ -476,13 +476,13 @@ extractColumns l m = m ?? (All, Pos (idxs l))
476 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ] 476 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ]
477 477
478-} 478-}
479repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t 479repmat :: (Storable t) => Matrix t -> Int -> Int -> Matrix t
480repmat m r c 480repmat m r c
481 | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m) 481 | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m)
482 | otherwise = fromBlocks $ replicate r $ replicate c $ m 482 | otherwise = fromBlocks $ replicate r $ replicate c $ m
483 483
484-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. 484-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix.
485liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 485liftMatrix2Auto :: (Storable t, Storable a, Storable b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
486liftMatrix2Auto f m1 m2 486liftMatrix2Auto f m1 m2
487 | compat' m1 m2 = lM f m1 m2 487 | compat' m1 m2 = lM f m1 m2
488 | ok = lM f m1' m2' 488 | ok = lM f m1' m2'
@@ -499,7 +499,7 @@ liftMatrix2Auto f m1 m2
499 m2' = conformMTo (r,c) m2 499 m2' = conformMTo (r,c) m2
500 500
501-- FIXME do not flatten if equal order 501-- FIXME do not flatten if equal order
502lM :: (Storable t, Element t1, Element t2) 502lM :: (Storable t, Storable t1, Storable t2)
503 => (Vector t1 -> Vector t2 -> Vector t) 503 => (Vector t1 -> Vector t2 -> Vector t)
504 -> Matrix t1 -> Matrix t2 -> Matrix t 504 -> Matrix t1 -> Matrix t2 -> Matrix t
505lM f m1 m2 = matrixFromVector 505lM f m1 m2 = matrixFromVector
@@ -520,7 +520,7 @@ compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
520 520
521------------------------------------------------------------ 521------------------------------------------------------------
522 522
523toBlockRows :: Element t => [Int] -> Matrix t -> [Matrix t] 523toBlockRows :: Storable t => [Int] -> Matrix t -> [Matrix t]
524toBlockRows [r] m 524toBlockRows [r] m
525 | r == rows m = [m] 525 | r == rows m = [m]
526toBlockRows rs m 526toBlockRows rs m
@@ -530,13 +530,13 @@ toBlockRows rs m
530 szs = map (* cols m) rs 530 szs = map (* cols m) rs
531 g k = (k><0)[] 531 g k = (k><0)[]
532 532
533toBlockCols :: Element t => [Int] -> Matrix t -> [Matrix t] 533toBlockCols :: Storable t => [Int] -> Matrix t -> [Matrix t]
534toBlockCols [c] m | c == cols m = [m] 534toBlockCols [c] m | c == cols m = [m]
535toBlockCols cs m = map trans . toBlockRows cs . trans $ m 535toBlockCols cs m = map trans . toBlockRows cs . trans $ m
536 536
537-- | Partition a matrix into blocks with the given numbers of rows and columns. 537-- | Partition a matrix into blocks with the given numbers of rows and columns.
538-- The remaining rows and columns are discarded. 538-- The remaining rows and columns are discarded.
539toBlocks :: (Element t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]] 539toBlocks :: (Storable t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]]
540toBlocks rs cs m 540toBlocks rs cs m
541 | ok = map (toBlockCols cs) . toBlockRows rs $ m 541 | ok = map (toBlockCols cs) . toBlockRows rs $ m
542 | otherwise = error $ "toBlocks: bad partition: "++show rs++" "++show cs 542 | otherwise = error $ "toBlocks: bad partition: "++show rs++" "++show cs
@@ -546,7 +546,7 @@ toBlocks rs cs m
546 546
547-- | Fully partition a matrix into blocks of the same size. If the dimensions are not 547-- | Fully partition a matrix into blocks of the same size. If the dimensions are not
548-- a multiple of the given size the last blocks will be smaller. 548-- a multiple of the given size the last blocks will be smaller.
549toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]] 549toBlocksEvery :: (Storable t) => Int -> Int -> Matrix t -> [[Matrix t]]
550toBlocksEvery r c m 550toBlocksEvery r c m
551 | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c 551 | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c
552 | otherwise = toBlocks rs cs m 552 | otherwise = toBlocks rs cs m
@@ -576,7 +576,7 @@ m[1,2] = 6
576 576
577-} 577-}
578mapMatrixWithIndexM_ 578mapMatrixWithIndexM_
579 :: (Element a, Num a, Monad m) => 579 :: (Storable a, Num a, Monad m) =>
580 ((Int, Int) -> a -> m ()) -> Matrix a -> m () 580 ((Int, Int) -> a -> m ()) -> Matrix a -> m ()
581mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m 581mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m
582 where 582 where
@@ -592,7 +592,7 @@ Just (3><3)
592 592
593-} 593-}
594mapMatrixWithIndexM 594mapMatrixWithIndexM
595 :: (Element a, Storable b, Monad m) => 595 :: (Storable a, Storable b, Monad m) =>
596 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) 596 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b)
597mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m 597mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m
598 where 598 where
@@ -608,11 +608,11 @@ mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . fla
608 608
609 -} 609 -}
610mapMatrixWithIndex 610mapMatrixWithIndex
611 :: (Element a, Storable b) => 611 :: (Storable a, Storable b) =>
612 ((Int, Int) -> a -> b) -> Matrix a -> Matrix b 612 ((Int, Int) -> a -> b) -> Matrix a -> Matrix b
613mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m 613mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m
614 where 614 where
615 c = cols m 615 c = cols m
616 616
617mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b 617mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b
618mapMatrix f = liftMatrix (mapVector f) 618mapMatrix f = liftMatrix (mapVector f)