diff options
Diffstat (limited to 'packages/base/src/Internal/Element.hs')
-rw-r--r-- | packages/base/src/Internal/Element.hs | 490 |
1 files changed, 490 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs new file mode 100644 index 0000000..6fc2981 --- /dev/null +++ b/packages/base/src/Internal/Element.hs | |||
@@ -0,0 +1,490 @@ | |||
1 | {-# LANGUAGE TypeFamilies #-} | ||
2 | {-# LANGUAGE FlexibleContexts #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | {-# LANGUAGE UndecidableInstances #-} | ||
5 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
6 | {-# LANGUAGE CPP #-} | ||
7 | |||
8 | ----------------------------------------------------------------------------- | ||
9 | -- | | ||
10 | -- Module : Data.Packed.Matrix | ||
11 | -- Copyright : (c) Alberto Ruiz 2007-10 | ||
12 | -- License : BSD3 | ||
13 | -- Maintainer : Alberto Ruiz | ||
14 | -- Stability : provisional | ||
15 | -- | ||
16 | -- A Matrix representation suitable for numerical computations using LAPACK and GSL. | ||
17 | -- | ||
18 | -- This module provides basic functions for manipulation of structure. | ||
19 | |||
20 | ----------------------------------------------------------------------------- | ||
21 | |||
22 | module Internal.Element where | ||
23 | |||
24 | import Internal.Tools | ||
25 | import Internal.Vector | ||
26 | import Internal.Matrix | ||
27 | import qualified Internal.ST as ST | ||
28 | import Data.Array | ||
29 | |||
30 | import Data.Vector.Storable(fromList) | ||
31 | import Data.List(transpose,intersperse) | ||
32 | import Foreign.Storable(Storable) | ||
33 | import Control.Monad(liftM) | ||
34 | |||
35 | ------------------------------------------------------------------- | ||
36 | |||
37 | #ifdef BINARY | ||
38 | |||
39 | import Data.Binary | ||
40 | |||
41 | instance (Binary (Vector a), Element a) => Binary (Matrix a) where | ||
42 | put m = do | ||
43 | put (cols m) | ||
44 | put (flatten m) | ||
45 | get = do | ||
46 | c <- get | ||
47 | v <- get | ||
48 | return (reshape c v) | ||
49 | |||
50 | #endif | ||
51 | |||
52 | ------------------------------------------------------------------- | ||
53 | |||
54 | instance (Show a, Element a) => (Show (Matrix a)) where | ||
55 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" | ||
56 | show m = (sizes m++) . dsp . map (map show) . toLists $ m | ||
57 | |||
58 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" | ||
59 | |||
60 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | ||
61 | where | ||
62 | mt = transpose as | ||
63 | longs = map (maximum . map length) mt | ||
64 | mtp = zipWith (\a b -> map (pad a) b) longs mt | ||
65 | pad n str = replicate (n - length str) ' ' ++ str | ||
66 | unwords' = concat . intersperse ", " | ||
67 | |||
68 | ------------------------------------------------------------------ | ||
69 | |||
70 | instance (Element a, Read a) => Read (Matrix a) where | ||
71 | readsPrec _ s = [((rs><cs) . read $ listnums, rest)] | ||
72 | where (thing,rest) = breakAt ']' s | ||
73 | (dims,listnums) = breakAt ')' thing | ||
74 | cs = read . init . fst. breakAt ')' . snd . breakAt '<' $ dims | ||
75 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims | ||
76 | |||
77 | |||
78 | breakAt c l = (a++[c],tail b) where | ||
79 | (a,b) = break (==c) l | ||
80 | |||
81 | ------------------------------------------------------------------ | ||
82 | |||
83 | -- | creates a matrix from a vertical list of matrices | ||
84 | joinVert :: Element t => [Matrix t] -> Matrix t | ||
85 | joinVert [] = emptyM 0 0 | ||
86 | joinVert ms = case common cols ms of | ||
87 | Nothing -> error "(impossible) joinVert on matrices with different number of columns" | ||
88 | Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) | ||
89 | |||
90 | -- | creates a matrix from a horizontal list of matrices | ||
91 | joinHoriz :: Element t => [Matrix t] -> Matrix t | ||
92 | joinHoriz ms = trans. joinVert . map trans $ ms | ||
93 | |||
94 | {- | Create a matrix from blocks given as a list of lists of matrices. | ||
95 | |||
96 | Single row-column components are automatically expanded to match the | ||
97 | corresponding common row and column: | ||
98 | |||
99 | @ | ||
100 | disp = putStr . dispf 2 | ||
101 | @ | ||
102 | |||
103 | >>> disp $ fromBlocks [[ident 5, 7, row[10,20]], [3, diagl[1,2,3], 0]] | ||
104 | 8x10 | ||
105 | 1 0 0 0 0 7 7 7 10 20 | ||
106 | 0 1 0 0 0 7 7 7 10 20 | ||
107 | 0 0 1 0 0 7 7 7 10 20 | ||
108 | 0 0 0 1 0 7 7 7 10 20 | ||
109 | 0 0 0 0 1 7 7 7 10 20 | ||
110 | 3 3 3 3 3 1 0 0 0 0 | ||
111 | 3 3 3 3 3 0 2 0 0 0 | ||
112 | 3 3 3 3 3 0 0 3 0 0 | ||
113 | |||
114 | -} | ||
115 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t | ||
116 | fromBlocks = fromBlocksRaw . adaptBlocks | ||
117 | |||
118 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms | ||
119 | |||
120 | adaptBlocks ms = ms' where | ||
121 | bc = case common length ms of | ||
122 | Just c -> c | ||
123 | Nothing -> error "fromBlocks requires rectangular [[Matrix]]" | ||
124 | rs = map (compatdim . map rows) ms | ||
125 | cs = map (compatdim . map cols) (transpose ms) | ||
126 | szs = sequence [rs,cs] | ||
127 | ms' = splitEvery bc $ zipWith g szs (concat ms) | ||
128 | |||
129 | g [Just nr,Just nc] m | ||
130 | | nr == r && nc == c = m | ||
131 | | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc)) | ||
132 | | r == 1 = fromRows (replicate nr (flatten m)) | ||
133 | | otherwise = fromColumns (replicate nc (flatten m)) | ||
134 | where | ||
135 | r = rows m | ||
136 | c = cols m | ||
137 | x = m@@>(0,0) | ||
138 | g _ _ = error "inconsistent dimensions in fromBlocks" | ||
139 | |||
140 | |||
141 | -------------------------------------------------------------------------------- | ||
142 | |||
143 | {- | create a block diagonal matrix | ||
144 | |||
145 | >>> disp 2 $ diagBlock [konst 1 (2,2), konst 2 (3,5), col [5,7]] | ||
146 | 7x8 | ||
147 | 1 1 0 0 0 0 0 0 | ||
148 | 1 1 0 0 0 0 0 0 | ||
149 | 0 0 2 2 2 2 2 0 | ||
150 | 0 0 2 2 2 2 2 0 | ||
151 | 0 0 2 2 2 2 2 0 | ||
152 | 0 0 0 0 0 0 0 5 | ||
153 | 0 0 0 0 0 0 0 7 | ||
154 | |||
155 | >>> diagBlock [(0><4)[], konst 2 (2,3)] :: Matrix Double | ||
156 | (2><7) | ||
157 | [ 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 | ||
158 | , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ] | ||
159 | |||
160 | -} | ||
161 | diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t | ||
162 | diagBlock ms = fromBlocks $ zipWith f ms [0..] | ||
163 | where | ||
164 | f m k = take n $ replicate k z ++ m : repeat z | ||
165 | n = length ms | ||
166 | z = (1><1) [0] | ||
167 | |||
168 | -------------------------------------------------------------------------------- | ||
169 | |||
170 | |||
171 | -- | Reverse rows | ||
172 | flipud :: Element t => Matrix t -> Matrix t | ||
173 | flipud m = extractRows [r-1,r-2 .. 0] $ m | ||
174 | where | ||
175 | r = rows m | ||
176 | |||
177 | -- | Reverse columns | ||
178 | fliprl :: Element t => Matrix t -> Matrix t | ||
179 | fliprl m = extractColumns [c-1,c-2 .. 0] $ m | ||
180 | where | ||
181 | c = cols m | ||
182 | |||
183 | ------------------------------------------------------------ | ||
184 | |||
185 | {- | creates a rectangular diagonal matrix: | ||
186 | |||
187 | >>> diagRect 7 (fromList [10,20,30]) 4 5 :: Matrix Double | ||
188 | (4><5) | ||
189 | [ 10.0, 7.0, 7.0, 7.0, 7.0 | ||
190 | , 7.0, 20.0, 7.0, 7.0, 7.0 | ||
191 | , 7.0, 7.0, 30.0, 7.0, 7.0 | ||
192 | , 7.0, 7.0, 7.0, 7.0, 7.0 ] | ||
193 | |||
194 | -} | ||
195 | diagRect :: (Storable t) => t -> Vector t -> Int -> Int -> Matrix t | ||
196 | diagRect z v r c = ST.runSTMatrix $ do | ||
197 | m <- ST.newMatrix z r c | ||
198 | let d = min r c `min` (dim v) | ||
199 | mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1] | ||
200 | return m | ||
201 | |||
202 | -- | extracts the diagonal from a rectangular matrix | ||
203 | takeDiag :: (Element t) => Matrix t -> Vector t | ||
204 | takeDiag m = fromList [flatten m @> (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | ||
205 | |||
206 | ------------------------------------------------------------ | ||
207 | |||
208 | {- | create a general matrix | ||
209 | |||
210 | >>> (2><3) [2, 4, 7+2*𝑖, -3, 11, 0] | ||
211 | (2><3) | ||
212 | [ 2.0 :+ 0.0, 4.0 :+ 0.0, 7.0 :+ 2.0 | ||
213 | , (-3.0) :+ (-0.0), 11.0 :+ 0.0, 0.0 :+ 0.0 ] | ||
214 | |||
215 | The input list is explicitly truncated, so that it can | ||
216 | safely be used with lists that are too long (like infinite lists). | ||
217 | |||
218 | >>> (2><3)[1..] | ||
219 | (2><3) | ||
220 | [ 1.0, 2.0, 3.0 | ||
221 | , 4.0, 5.0, 6.0 ] | ||
222 | |||
223 | This is the format produced by the instances of Show (Matrix a), which | ||
224 | can also be used for input. | ||
225 | |||
226 | -} | ||
227 | (><) :: (Storable a) => Int -> Int -> [a] -> Matrix a | ||
228 | r >< c = f where | ||
229 | f l | dim v == r*c = matrixFromVector RowMajor r c v | ||
230 | | otherwise = error $ "inconsistent list size = " | ||
231 | ++show (dim v) ++" in ("++show r++"><"++show c++")" | ||
232 | where v = fromList $ take (r*c) l | ||
233 | |||
234 | ---------------------------------------------------------------- | ||
235 | |||
236 | -- | Creates a matrix with the first n rows of another matrix | ||
237 | takeRows :: Element t => Int -> Matrix t -> Matrix t | ||
238 | takeRows n mt = subMatrix (0,0) (n, cols mt) mt | ||
239 | -- | Creates a matrix with the last n rows of another matrix | ||
240 | takeLastRows :: Element t => Int -> Matrix t -> Matrix t | ||
241 | takeLastRows n mt = subMatrix (rows mt - n, 0) (n, cols mt) mt | ||
242 | -- | Creates a copy of a matrix without the first n rows | ||
243 | dropRows :: Element t => Int -> Matrix t -> Matrix t | ||
244 | dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt | ||
245 | -- | Creates a copy of a matrix without the last n rows | ||
246 | dropLastRows :: Element t => Int -> Matrix t -> Matrix t | ||
247 | dropLastRows n mt = subMatrix (0,0) (rows mt - n, cols mt) mt | ||
248 | -- |Creates a matrix with the first n columns of another matrix | ||
249 | takeColumns :: Element t => Int -> Matrix t -> Matrix t | ||
250 | takeColumns n mt = subMatrix (0,0) (rows mt, n) mt | ||
251 | -- |Creates a matrix with the last n columns of another matrix | ||
252 | takeLastColumns :: Element t => Int -> Matrix t -> Matrix t | ||
253 | takeLastColumns n mt = subMatrix (0, cols mt - n) (rows mt, n) mt | ||
254 | -- | Creates a copy of a matrix without the first n columns | ||
255 | dropColumns :: Element t => Int -> Matrix t -> Matrix t | ||
256 | dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt | ||
257 | -- | Creates a copy of a matrix without the last n columns | ||
258 | dropLastColumns :: Element t => Int -> Matrix t -> Matrix t | ||
259 | dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt | ||
260 | |||
261 | ---------------------------------------------------------------- | ||
262 | |||
263 | {- | Creates a 'Matrix' from a list of lists (considered as rows). | ||
264 | |||
265 | >>> fromLists [[1,2],[3,4],[5,6]] | ||
266 | (3><2) | ||
267 | [ 1.0, 2.0 | ||
268 | , 3.0, 4.0 | ||
269 | , 5.0, 6.0 ] | ||
270 | |||
271 | -} | ||
272 | fromLists :: Element t => [[t]] -> Matrix t | ||
273 | fromLists = fromRows . map fromList | ||
274 | |||
275 | -- | creates a 1-row matrix from a vector | ||
276 | -- | ||
277 | -- >>> asRow (fromList [1..5]) | ||
278 | -- (1><5) | ||
279 | -- [ 1.0, 2.0, 3.0, 4.0, 5.0 ] | ||
280 | -- | ||
281 | asRow :: Storable a => Vector a -> Matrix a | ||
282 | asRow = trans . asColumn | ||
283 | |||
284 | -- | creates a 1-column matrix from a vector | ||
285 | -- | ||
286 | -- >>> asColumn (fromList [1..5]) | ||
287 | -- (5><1) | ||
288 | -- [ 1.0 | ||
289 | -- , 2.0 | ||
290 | -- , 3.0 | ||
291 | -- , 4.0 | ||
292 | -- , 5.0 ] | ||
293 | -- | ||
294 | asColumn :: Storable a => Vector a -> Matrix a | ||
295 | asColumn v = reshape 1 v | ||
296 | |||
297 | |||
298 | |||
299 | {- | creates a Matrix of the specified size using the supplied function to | ||
300 | to map the row\/column position to the value at that row\/column position. | ||
301 | |||
302 | @> buildMatrix 3 4 (\\(r,c) -> fromIntegral r * fromIntegral c) | ||
303 | (3><4) | ||
304 | [ 0.0, 0.0, 0.0, 0.0, 0.0 | ||
305 | , 0.0, 1.0, 2.0, 3.0, 4.0 | ||
306 | , 0.0, 2.0, 4.0, 6.0, 8.0]@ | ||
307 | |||
308 | Hilbert matrix of order N: | ||
309 | |||
310 | @hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@ | ||
311 | |||
312 | -} | ||
313 | buildMatrix :: Element a => Int -> Int -> ((Int, Int) -> a) -> Matrix a | ||
314 | buildMatrix rc cc f = | ||
315 | fromLists $ map (map f) | ||
316 | $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)] | ||
317 | |||
318 | ----------------------------------------------------- | ||
319 | |||
320 | fromArray2D :: (Storable e) => Array (Int, Int) e -> Matrix e | ||
321 | fromArray2D m = (r><c) (elems m) | ||
322 | where ((r0,c0),(r1,c1)) = bounds m | ||
323 | r = r1-r0+1 | ||
324 | c = c1-c0+1 | ||
325 | |||
326 | |||
327 | -- | rearranges the rows of a matrix according to the order given in a list of integers. | ||
328 | extractRows :: Element t => [Int] -> Matrix t -> Matrix t | ||
329 | extractRows [] m = emptyM 0 (cols m) | ||
330 | extractRows l m = fromRows $ extract (toRows m) l | ||
331 | where | ||
332 | extract l' is = [l'!!i | i<- map verify is] | ||
333 | verify k | ||
334 | | k >= 0 && k < rows m = k | ||
335 | | otherwise = error $ "can't extract row " | ||
336 | ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m | ||
337 | |||
338 | -- | rearranges the rows of a matrix according to the order given in a list of integers. | ||
339 | extractColumns :: Element t => [Int] -> Matrix t -> Matrix t | ||
340 | extractColumns l m = trans . extractRows (map verify l) . trans $ m | ||
341 | where | ||
342 | verify k | ||
343 | | k >= 0 && k < cols m = k | ||
344 | | otherwise = error $ "can't extract column " | ||
345 | ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m | ||
346 | |||
347 | |||
348 | |||
349 | {- | creates matrix by repetition of a matrix a given number of rows and columns | ||
350 | |||
351 | >>> repmat (ident 2) 2 3 | ||
352 | (4><6) | ||
353 | [ 1.0, 0.0, 1.0, 0.0, 1.0, 0.0 | ||
354 | , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 | ||
355 | , 1.0, 0.0, 1.0, 0.0, 1.0, 0.0 | ||
356 | , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ] | ||
357 | |||
358 | -} | ||
359 | repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t | ||
360 | repmat m r c | ||
361 | | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m) | ||
362 | | otherwise = fromBlocks $ replicate r $ replicate c $ m | ||
363 | |||
364 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. | ||
365 | liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | ||
366 | liftMatrix2Auto f m1 m2 | ||
367 | | compat' m1 m2 = lM f m1 m2 | ||
368 | | ok = lM f m1' m2' | ||
369 | | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2 | ||
370 | where | ||
371 | (r1,c1) = size m1 | ||
372 | (r2,c2) = size m2 | ||
373 | r = max r1 r2 | ||
374 | c = max c1 c2 | ||
375 | r0 = min r1 r2 | ||
376 | c0 = min c1 c2 | ||
377 | ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2 | ||
378 | m1' = conformMTo (r,c) m1 | ||
379 | m2' = conformMTo (r,c) m2 | ||
380 | |||
381 | -- FIXME do not flatten if equal order | ||
382 | lM f m1 m2 = matrixFromVector | ||
383 | RowMajor | ||
384 | (max (rows m1) (rows m2)) | ||
385 | (max (cols m1) (cols m2)) | ||
386 | (f (flatten m1) (flatten m2)) | ||
387 | |||
388 | compat' :: Matrix a -> Matrix b -> Bool | ||
389 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | ||
390 | where | ||
391 | s1 = size m1 | ||
392 | s2 = size m2 | ||
393 | |||
394 | ------------------------------------------------------------ | ||
395 | |||
396 | toBlockRows [r] m | ||
397 | | r == rows m = [m] | ||
398 | toBlockRows rs m | ||
399 | | cols m > 0 = map (reshape (cols m)) (takesV szs (flatten m)) | ||
400 | | otherwise = map g rs | ||
401 | where | ||
402 | szs = map (* cols m) rs | ||
403 | g k = (k><0)[] | ||
404 | |||
405 | toBlockCols [c] m | c == cols m = [m] | ||
406 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m | ||
407 | |||
408 | -- | Partition a matrix into blocks with the given numbers of rows and columns. | ||
409 | -- The remaining rows and columns are discarded. | ||
410 | toBlocks :: (Element t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]] | ||
411 | toBlocks rs cs m | ||
412 | | ok = map (toBlockCols cs) . toBlockRows rs $ m | ||
413 | | otherwise = error $ "toBlocks: bad partition: "++show rs++" "++show cs | ||
414 | ++ " "++shSize m | ||
415 | where | ||
416 | ok = sum rs <= rows m && sum cs <= cols m && all (>=0) rs && all (>=0) cs | ||
417 | |||
418 | -- | Fully partition a matrix into blocks of the same size. If the dimensions are not | ||
419 | -- a multiple of the given size the last blocks will be smaller. | ||
420 | toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]] | ||
421 | toBlocksEvery r c m | ||
422 | | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c | ||
423 | | otherwise = toBlocks rs cs m | ||
424 | where | ||
425 | (qr,rr) = rows m `divMod` r | ||
426 | (qc,rc) = cols m `divMod` c | ||
427 | rs = replicate qr r ++ if rr > 0 then [rr] else [] | ||
428 | cs = replicate qc c ++ if rc > 0 then [rc] else [] | ||
429 | |||
430 | ------------------------------------------------------------------- | ||
431 | |||
432 | -- Given a column number and a function taking matrix indexes, returns | ||
433 | -- a function which takes vector indexes (that can be used on the | ||
434 | -- flattened matrix). | ||
435 | mk :: Int -> ((Int, Int) -> t) -> (Int -> t) | ||
436 | mk c g = \k -> g (divMod k c) | ||
437 | |||
438 | {- | | ||
439 | |||
440 | >>> mapMatrixWithIndexM_ (\(i,j) v -> printf "m[%d,%d] = %.f\n" i j v :: IO()) ((2><3)[1 :: Double ..]) | ||
441 | m[0,0] = 1 | ||
442 | m[0,1] = 2 | ||
443 | m[0,2] = 3 | ||
444 | m[1,0] = 4 | ||
445 | m[1,1] = 5 | ||
446 | m[1,2] = 6 | ||
447 | |||
448 | -} | ||
449 | mapMatrixWithIndexM_ | ||
450 | :: (Element a, Num a, Monad m) => | ||
451 | ((Int, Int) -> a -> m ()) -> Matrix a -> m () | ||
452 | mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m | ||
453 | where | ||
454 | c = cols m | ||
455 | |||
456 | {- | | ||
457 | |||
458 | >>> mapMatrixWithIndexM (\(i,j) v -> Just $ 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double) | ||
459 | Just (3><3) | ||
460 | [ 100.0, 1.0, 2.0 | ||
461 | , 10.0, 111.0, 12.0 | ||
462 | , 20.0, 21.0, 122.0 ] | ||
463 | |||
464 | -} | ||
465 | mapMatrixWithIndexM | ||
466 | :: (Element a, Storable b, Monad m) => | ||
467 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) | ||
468 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m | ||
469 | where | ||
470 | c = cols m | ||
471 | |||
472 | {- | | ||
473 | |||
474 | >>> mapMatrixWithIndex (\(i,j) v -> 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double) | ||
475 | (3><3) | ||
476 | [ 100.0, 1.0, 2.0 | ||
477 | , 10.0, 111.0, 12.0 | ||
478 | , 20.0, 21.0, 122.0 ] | ||
479 | |||
480 | -} | ||
481 | mapMatrixWithIndex | ||
482 | :: (Element a, Storable b) => | ||
483 | ((Int, Int) -> a -> b) -> Matrix a -> Matrix b | ||
484 | mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m | ||
485 | where | ||
486 | c = cols m | ||
487 | |||
488 | mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b | ||
489 | mapMatrix f = liftMatrix (mapVector f) | ||
490 | |||