diff options
Diffstat (limited to 'packages/base/src/Internal/Element.hs')
-rw-r--r-- | packages/base/src/Internal/Element.hs | 604 |
1 files changed, 604 insertions, 0 deletions
diff --git a/packages/base/src/Internal/Element.hs b/packages/base/src/Internal/Element.hs new file mode 100644 index 0000000..a459678 --- /dev/null +++ b/packages/base/src/Internal/Element.hs | |||
@@ -0,0 +1,604 @@ | |||
1 | {-# LANGUAGE TypeFamilies #-} | ||
2 | {-# LANGUAGE FlexibleContexts #-} | ||
3 | {-# LANGUAGE FlexibleInstances #-} | ||
4 | {-# LANGUAGE UndecidableInstances #-} | ||
5 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
6 | {-# LANGUAGE CPP #-} | ||
7 | |||
8 | ----------------------------------------------------------------------------- | ||
9 | -- | | ||
10 | -- Module : Data.Packed.Matrix | ||
11 | -- Copyright : (c) Alberto Ruiz 2007-10 | ||
12 | -- License : BSD3 | ||
13 | -- Maintainer : Alberto Ruiz | ||
14 | -- Stability : provisional | ||
15 | -- | ||
16 | -- A Matrix representation suitable for numerical computations using LAPACK and GSL. | ||
17 | -- | ||
18 | -- This module provides basic functions for manipulation of structure. | ||
19 | |||
20 | ----------------------------------------------------------------------------- | ||
21 | |||
22 | module Internal.Element where | ||
23 | |||
24 | import Internal.Vector | ||
25 | import Internal.Matrix | ||
26 | import Internal.Vectorized | ||
27 | import qualified Internal.ST as ST | ||
28 | import Data.Array | ||
29 | import Text.Printf | ||
30 | import Data.List(transpose,intersperse) | ||
31 | import Data.List.Split(chunksOf) | ||
32 | import Foreign.Storable(Storable) | ||
33 | import System.IO.Unsafe(unsafePerformIO) | ||
34 | import Control.Monad(liftM) | ||
35 | |||
36 | ------------------------------------------------------------------- | ||
37 | |||
38 | #ifdef BINARY | ||
39 | |||
40 | import Data.Binary | ||
41 | |||
42 | instance (Binary (Vector a), Element a) => Binary (Matrix a) where | ||
43 | put m = do | ||
44 | put (cols m) | ||
45 | put (flatten m) | ||
46 | get = do | ||
47 | c <- get | ||
48 | v <- get | ||
49 | return (reshape c v) | ||
50 | |||
51 | #endif | ||
52 | |||
53 | ------------------------------------------------------------------- | ||
54 | |||
55 | instance (Show a, Element a) => (Show (Matrix a)) where | ||
56 | show m | rows m == 0 || cols m == 0 = sizes m ++" []" | ||
57 | show m = (sizes m++) . dsp . map (map show) . toLists $ m | ||
58 | |||
59 | sizes m = "("++show (rows m)++"><"++show (cols m)++")\n" | ||
60 | |||
61 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | ||
62 | where | ||
63 | mt = transpose as | ||
64 | longs = map (maximum . map length) mt | ||
65 | mtp = zipWith (\a b -> map (pad a) b) longs mt | ||
66 | pad n str = replicate (n - length str) ' ' ++ str | ||
67 | unwords' = concat . intersperse ", " | ||
68 | |||
69 | ------------------------------------------------------------------ | ||
70 | |||
71 | instance (Element a, Read a) => Read (Matrix a) where | ||
72 | readsPrec _ s = [((rs><cs) . read $ listnums, rest)] | ||
73 | where (thing,rest) = breakAt ']' s | ||
74 | (dims,listnums) = breakAt ')' thing | ||
75 | cs = read . init . fst. breakAt ')' . snd . breakAt '<' $ dims | ||
76 | rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims | ||
77 | |||
78 | |||
79 | breakAt c l = (a++[c],tail b) where | ||
80 | (a,b) = break (==c) l | ||
81 | |||
82 | -------------------------------------------------------------------------------- | ||
83 | -- | Specification of indexes for the operator '??'. | ||
84 | data Extractor | ||
85 | = All | ||
86 | | Range Int Int Int | ||
87 | | Pos (Vector I) | ||
88 | | PosCyc (Vector I) | ||
89 | | Take Int | ||
90 | | TakeLast Int | ||
91 | | Drop Int | ||
92 | | DropLast Int | ||
93 | deriving Show | ||
94 | |||
95 | ppext All = ":" | ||
96 | ppext (Range a 1 c) = printf "%d:%d" a c | ||
97 | ppext (Range a b c) = printf "%d:%d:%d" a b c | ||
98 | ppext (Pos v) = show (toList v) | ||
99 | ppext (PosCyc v) = "Cyclic"++show (toList v) | ||
100 | ppext (Take n) = printf "Take %d" n | ||
101 | ppext (Drop n) = printf "Drop %d" n | ||
102 | ppext (TakeLast n) = printf "TakeLast %d" n | ||
103 | ppext (DropLast n) = printf "DropLast %d" n | ||
104 | |||
105 | {- | General matrix slicing. | ||
106 | |||
107 | >>> m | ||
108 | (4><5) | ||
109 | [ 0, 1, 2, 3, 4 | ||
110 | , 5, 6, 7, 8, 9 | ||
111 | , 10, 11, 12, 13, 14 | ||
112 | , 15, 16, 17, 18, 19 ] | ||
113 | |||
114 | >>> m ?? (Take 3, DropLast 2) | ||
115 | (3><3) | ||
116 | [ 0, 1, 2 | ||
117 | , 5, 6, 7 | ||
118 | , 10, 11, 12 ] | ||
119 | |||
120 | >>> m ?? (Pos (idxs[2,1]), All) | ||
121 | (2><5) | ||
122 | [ 10, 11, 12, 13, 14 | ||
123 | , 5, 6, 7, 8, 9 ] | ||
124 | |||
125 | >>> m ?? (PosCyc (idxs[-7,80]), Range 4 (-2) 0) | ||
126 | (2><3) | ||
127 | [ 9, 7, 5 | ||
128 | , 4, 2, 0 ] | ||
129 | |||
130 | -} | ||
131 | infixl 9 ?? | ||
132 | (??) :: Element t => Matrix t -> (Extractor,Extractor) -> Matrix t | ||
133 | |||
134 | minEl = toScalarI Min | ||
135 | maxEl = toScalarI Max | ||
136 | cmodi = vectorMapValI ModVS | ||
137 | |||
138 | extractError m (e1,e2)= error $ printf "can't extract (%s,%s) from matrix %dx%d" (ppext e1::String) (ppext e2::String) (rows m) (cols m) | ||
139 | |||
140 | m ?? (Range a s b,e) | s /= 1 = m ?? (Pos (idxs [a,a+s .. b]), e) | ||
141 | m ?? (e,Range a s b) | s /= 1 = m ?? (e, Pos (idxs [a,a+s .. b])) | ||
142 | |||
143 | m ?? e@(Range a _ b,_) | a < 0 || b >= rows m = extractError m e | ||
144 | m ?? e@(_,Range a _ b) | a < 0 || b >= cols m = extractError m e | ||
145 | |||
146 | m ?? e@(Pos vs,_) | dim vs>0 && (minEl vs < 0 || maxEl vs >= fi (rows m)) = extractError m e | ||
147 | m ?? e@(_,Pos vs) | dim vs>0 && (minEl vs < 0 || maxEl vs >= fi (cols m)) = extractError m e | ||
148 | |||
149 | m ?? (All,All) = m | ||
150 | |||
151 | m ?? (Range a _ b,e) | a > b = m ?? (Take 0,e) | ||
152 | m ?? (e,Range a _ b) | a > b = m ?? (e,Take 0) | ||
153 | |||
154 | m ?? (Take n,e) | ||
155 | | n <= 0 = (0><cols m) [] ?? (All,e) | ||
156 | | n >= rows m = m ?? (All,e) | ||
157 | |||
158 | m ?? (e,Take n) | ||
159 | | n <= 0 = (rows m><0) [] ?? (e,All) | ||
160 | | n >= cols m = m ?? (e,All) | ||
161 | |||
162 | m ?? (Drop n,e) | ||
163 | | n <= 0 = m ?? (All,e) | ||
164 | | n >= rows m = (0><cols m) [] ?? (All,e) | ||
165 | |||
166 | m ?? (e,Drop n) | ||
167 | | n <= 0 = m ?? (e,All) | ||
168 | | n >= cols m = (rows m><0) [] ?? (e,All) | ||
169 | |||
170 | m ?? (TakeLast n, e) = m ?? (Drop (rows m - n), e) | ||
171 | m ?? (e, TakeLast n) = m ?? (e, Drop (cols m - n)) | ||
172 | |||
173 | m ?? (DropLast n, e) = m ?? (Take (rows m - n), e) | ||
174 | m ?? (e, DropLast n) = m ?? (e, Take (cols m - n)) | ||
175 | |||
176 | m ?? (er,ec) = unsafePerformIO $ extractR (orderOf m) m moder rs modec cs | ||
177 | where | ||
178 | (moder,rs) = mkExt (rows m) er | ||
179 | (modec,cs) = mkExt (cols m) ec | ||
180 | ran a b = (0, idxs [a,b]) | ||
181 | pos ks = (1, ks) | ||
182 | mkExt _ (Pos ks) = pos ks | ||
183 | mkExt n (PosCyc ks) | ||
184 | | n == 0 = mkExt n (Take 0) | ||
185 | | otherwise = pos (cmodi (fi n) ks) | ||
186 | mkExt _ (Range mn _ mx) = ran mn mx | ||
187 | mkExt _ (Take k) = ran 0 (k-1) | ||
188 | mkExt n (Drop k) = ran k (n-1) | ||
189 | mkExt n _ = ran 0 (n-1) -- All | ||
190 | |||
191 | -------------------------------------------------------------------------------- | ||
192 | |||
193 | -- | obtains the common value of a property of a list | ||
194 | common :: (Eq a) => (b->a) -> [b] -> Maybe a | ||
195 | common f = commonval . map f | ||
196 | where | ||
197 | commonval :: (Eq a) => [a] -> Maybe a | ||
198 | commonval [] = Nothing | ||
199 | commonval [a] = Just a | ||
200 | commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing | ||
201 | |||
202 | |||
203 | -- | creates a matrix from a vertical list of matrices | ||
204 | joinVert :: Element t => [Matrix t] -> Matrix t | ||
205 | joinVert [] = emptyM 0 0 | ||
206 | joinVert ms = case common cols ms of | ||
207 | Nothing -> error "(impossible) joinVert on matrices with different number of columns" | ||
208 | Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms) | ||
209 | |||
210 | -- | creates a matrix from a horizontal list of matrices | ||
211 | joinHoriz :: Element t => [Matrix t] -> Matrix t | ||
212 | joinHoriz ms = trans. joinVert . map trans $ ms | ||
213 | |||
214 | {- | Create a matrix from blocks given as a list of lists of matrices. | ||
215 | |||
216 | Single row-column components are automatically expanded to match the | ||
217 | corresponding common row and column: | ||
218 | |||
219 | @ | ||
220 | disp = putStr . dispf 2 | ||
221 | @ | ||
222 | |||
223 | >>> disp $ fromBlocks [[ident 5, 7, row[10,20]], [3, diagl[1,2,3], 0]] | ||
224 | 8x10 | ||
225 | 1 0 0 0 0 7 7 7 10 20 | ||
226 | 0 1 0 0 0 7 7 7 10 20 | ||
227 | 0 0 1 0 0 7 7 7 10 20 | ||
228 | 0 0 0 1 0 7 7 7 10 20 | ||
229 | 0 0 0 0 1 7 7 7 10 20 | ||
230 | 3 3 3 3 3 1 0 0 0 0 | ||
231 | 3 3 3 3 3 0 2 0 0 0 | ||
232 | 3 3 3 3 3 0 0 3 0 0 | ||
233 | |||
234 | -} | ||
235 | fromBlocks :: Element t => [[Matrix t]] -> Matrix t | ||
236 | fromBlocks = fromBlocksRaw . adaptBlocks | ||
237 | |||
238 | fromBlocksRaw mms = joinVert . map joinHoriz $ mms | ||
239 | |||
240 | adaptBlocks ms = ms' where | ||
241 | bc = case common length ms of | ||
242 | Just c -> c | ||
243 | Nothing -> error "fromBlocks requires rectangular [[Matrix]]" | ||
244 | rs = map (compatdim . map rows) ms | ||
245 | cs = map (compatdim . map cols) (transpose ms) | ||
246 | szs = sequence [rs,cs] | ||
247 | ms' = chunksOf bc $ zipWith g szs (concat ms) | ||
248 | |||
249 | g [Just nr,Just nc] m | ||
250 | | nr == r && nc == c = m | ||
251 | | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc)) | ||
252 | | r == 1 = fromRows (replicate nr (flatten m)) | ||
253 | | otherwise = fromColumns (replicate nc (flatten m)) | ||
254 | where | ||
255 | r = rows m | ||
256 | c = cols m | ||
257 | x = m@@>(0,0) | ||
258 | g _ _ = error "inconsistent dimensions in fromBlocks" | ||
259 | |||
260 | |||
261 | -------------------------------------------------------------------------------- | ||
262 | |||
263 | {- | create a block diagonal matrix | ||
264 | |||
265 | >>> disp 2 $ diagBlock [konst 1 (2,2), konst 2 (3,5), col [5,7]] | ||
266 | 7x8 | ||
267 | 1 1 0 0 0 0 0 0 | ||
268 | 1 1 0 0 0 0 0 0 | ||
269 | 0 0 2 2 2 2 2 0 | ||
270 | 0 0 2 2 2 2 2 0 | ||
271 | 0 0 2 2 2 2 2 0 | ||
272 | 0 0 0 0 0 0 0 5 | ||
273 | 0 0 0 0 0 0 0 7 | ||
274 | |||
275 | >>> diagBlock [(0><4)[], konst 2 (2,3)] :: Matrix Double | ||
276 | (2><7) | ||
277 | [ 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 | ||
278 | , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ] | ||
279 | |||
280 | -} | ||
281 | diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t | ||
282 | diagBlock ms = fromBlocks $ zipWith f ms [0..] | ||
283 | where | ||
284 | f m k = take n $ replicate k z ++ m : repeat z | ||
285 | n = length ms | ||
286 | z = (1><1) [0] | ||
287 | |||
288 | -------------------------------------------------------------------------------- | ||
289 | |||
290 | |||
291 | -- | Reverse rows | ||
292 | flipud :: Element t => Matrix t -> Matrix t | ||
293 | flipud m = extractRows [r-1,r-2 .. 0] $ m | ||
294 | where | ||
295 | r = rows m | ||
296 | |||
297 | -- | Reverse columns | ||
298 | fliprl :: Element t => Matrix t -> Matrix t | ||
299 | fliprl m = extractColumns [c-1,c-2 .. 0] $ m | ||
300 | where | ||
301 | c = cols m | ||
302 | |||
303 | ------------------------------------------------------------ | ||
304 | |||
305 | {- | creates a rectangular diagonal matrix: | ||
306 | |||
307 | >>> diagRect 7 (fromList [10,20,30]) 4 5 :: Matrix Double | ||
308 | (4><5) | ||
309 | [ 10.0, 7.0, 7.0, 7.0, 7.0 | ||
310 | , 7.0, 20.0, 7.0, 7.0, 7.0 | ||
311 | , 7.0, 7.0, 30.0, 7.0, 7.0 | ||
312 | , 7.0, 7.0, 7.0, 7.0, 7.0 ] | ||
313 | |||
314 | -} | ||
315 | diagRect :: (Storable t) => t -> Vector t -> Int -> Int -> Matrix t | ||
316 | diagRect z v r c = ST.runSTMatrix $ do | ||
317 | m <- ST.newMatrix z r c | ||
318 | let d = min r c `min` (dim v) | ||
319 | mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1] | ||
320 | return m | ||
321 | |||
322 | -- | extracts the diagonal from a rectangular matrix | ||
323 | takeDiag :: (Element t) => Matrix t -> Vector t | ||
324 | takeDiag m = fromList [flatten m @> (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]] | ||
325 | |||
326 | ------------------------------------------------------------ | ||
327 | |||
328 | {- | Create a matrix from a list of elements | ||
329 | |||
330 | >>> (2><3) [2, 4, 7+2*iC, -3, 11, 0] | ||
331 | (2><3) | ||
332 | [ 2.0 :+ 0.0, 4.0 :+ 0.0, 7.0 :+ 2.0 | ||
333 | , (-3.0) :+ (-0.0), 11.0 :+ 0.0, 0.0 :+ 0.0 ] | ||
334 | |||
335 | The input list is explicitly truncated, so that it can | ||
336 | safely be used with lists that are too long (like infinite lists). | ||
337 | |||
338 | >>> (2><3)[1..] | ||
339 | (2><3) | ||
340 | [ 1.0, 2.0, 3.0 | ||
341 | , 4.0, 5.0, 6.0 ] | ||
342 | |||
343 | This is the format produced by the instances of Show (Matrix a), which | ||
344 | can also be used for input. | ||
345 | |||
346 | -} | ||
347 | (><) :: (Storable a) => Int -> Int -> [a] -> Matrix a | ||
348 | r >< c = f where | ||
349 | f l | dim v == r*c = matrixFromVector RowMajor r c v | ||
350 | | otherwise = error $ "inconsistent list size = " | ||
351 | ++show (dim v) ++" in ("++show r++"><"++show c++")" | ||
352 | where v = fromList $ take (r*c) l | ||
353 | |||
354 | ---------------------------------------------------------------- | ||
355 | |||
356 | takeRows :: Element t => Int -> Matrix t -> Matrix t | ||
357 | takeRows n mt = subMatrix (0,0) (n, cols mt) mt | ||
358 | |||
359 | -- | Creates a matrix with the last n rows of another matrix | ||
360 | takeLastRows :: Element t => Int -> Matrix t -> Matrix t | ||
361 | takeLastRows n mt = subMatrix (rows mt - n, 0) (n, cols mt) mt | ||
362 | |||
363 | dropRows :: Element t => Int -> Matrix t -> Matrix t | ||
364 | dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt | ||
365 | |||
366 | -- | Creates a copy of a matrix without the last n rows | ||
367 | dropLastRows :: Element t => Int -> Matrix t -> Matrix t | ||
368 | dropLastRows n mt = subMatrix (0,0) (rows mt - n, cols mt) mt | ||
369 | |||
370 | takeColumns :: Element t => Int -> Matrix t -> Matrix t | ||
371 | takeColumns n mt = subMatrix (0,0) (rows mt, n) mt | ||
372 | |||
373 | -- |Creates a matrix with the last n columns of another matrix | ||
374 | takeLastColumns :: Element t => Int -> Matrix t -> Matrix t | ||
375 | takeLastColumns n mt = subMatrix (0, cols mt - n) (rows mt, n) mt | ||
376 | |||
377 | dropColumns :: Element t => Int -> Matrix t -> Matrix t | ||
378 | dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt | ||
379 | |||
380 | -- | Creates a copy of a matrix without the last n columns | ||
381 | dropLastColumns :: Element t => Int -> Matrix t -> Matrix t | ||
382 | dropLastColumns n mt = subMatrix (0,0) (rows mt, cols mt - n) mt | ||
383 | |||
384 | ---------------------------------------------------------------- | ||
385 | |||
386 | {- | Creates a 'Matrix' from a list of lists (considered as rows). | ||
387 | |||
388 | >>> fromLists [[1,2],[3,4],[5,6]] | ||
389 | (3><2) | ||
390 | [ 1.0, 2.0 | ||
391 | , 3.0, 4.0 | ||
392 | , 5.0, 6.0 ] | ||
393 | |||
394 | -} | ||
395 | fromLists :: Element t => [[t]] -> Matrix t | ||
396 | fromLists = fromRows . map fromList | ||
397 | |||
398 | -- | creates a 1-row matrix from a vector | ||
399 | -- | ||
400 | -- >>> asRow (fromList [1..5]) | ||
401 | -- (1><5) | ||
402 | -- [ 1.0, 2.0, 3.0, 4.0, 5.0 ] | ||
403 | -- | ||
404 | asRow :: Storable a => Vector a -> Matrix a | ||
405 | asRow = trans . asColumn | ||
406 | |||
407 | -- | creates a 1-column matrix from a vector | ||
408 | -- | ||
409 | -- >>> asColumn (fromList [1..5]) | ||
410 | -- (5><1) | ||
411 | -- [ 1.0 | ||
412 | -- , 2.0 | ||
413 | -- , 3.0 | ||
414 | -- , 4.0 | ||
415 | -- , 5.0 ] | ||
416 | -- | ||
417 | asColumn :: Storable a => Vector a -> Matrix a | ||
418 | asColumn v = reshape 1 v | ||
419 | |||
420 | |||
421 | |||
422 | {- | creates a Matrix of the specified size using the supplied function to | ||
423 | to map the row\/column position to the value at that row\/column position. | ||
424 | |||
425 | @> buildMatrix 3 4 (\\(r,c) -> fromIntegral r * fromIntegral c) | ||
426 | (3><4) | ||
427 | [ 0.0, 0.0, 0.0, 0.0, 0.0 | ||
428 | , 0.0, 1.0, 2.0, 3.0, 4.0 | ||
429 | , 0.0, 2.0, 4.0, 6.0, 8.0]@ | ||
430 | |||
431 | Hilbert matrix of order N: | ||
432 | |||
433 | @hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@ | ||
434 | |||
435 | -} | ||
436 | buildMatrix :: Element a => Int -> Int -> ((Int, Int) -> a) -> Matrix a | ||
437 | buildMatrix rc cc f = | ||
438 | fromLists $ map (map f) | ||
439 | $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)] | ||
440 | |||
441 | ----------------------------------------------------- | ||
442 | |||
443 | fromArray2D :: (Storable e) => Array (Int, Int) e -> Matrix e | ||
444 | fromArray2D m = (r><c) (elems m) | ||
445 | where ((r0,c0),(r1,c1)) = bounds m | ||
446 | r = r1-r0+1 | ||
447 | c = c1-c0+1 | ||
448 | |||
449 | |||
450 | -- | rearranges the rows of a matrix according to the order given in a list of integers. | ||
451 | extractRows :: Element t => [Int] -> Matrix t -> Matrix t | ||
452 | extractRows l m = m ?? (Pos (idxs l), All) | ||
453 | |||
454 | -- | rearranges the rows of a matrix according to the order given in a list of integers. | ||
455 | extractColumns :: Element t => [Int] -> Matrix t -> Matrix t | ||
456 | extractColumns l m = m ?? (All, Pos (idxs l)) | ||
457 | |||
458 | |||
459 | {- | creates matrix by repetition of a matrix a given number of rows and columns | ||
460 | |||
461 | >>> repmat (ident 2) 2 3 | ||
462 | (4><6) | ||
463 | [ 1.0, 0.0, 1.0, 0.0, 1.0, 0.0 | ||
464 | , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 | ||
465 | , 1.0, 0.0, 1.0, 0.0, 1.0, 0.0 | ||
466 | , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ] | ||
467 | |||
468 | -} | ||
469 | repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t | ||
470 | repmat m r c | ||
471 | | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m) | ||
472 | | otherwise = fromBlocks $ replicate r $ replicate c $ m | ||
473 | |||
474 | -- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix. | ||
475 | liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t | ||
476 | liftMatrix2Auto f m1 m2 | ||
477 | | compat' m1 m2 = lM f m1 m2 | ||
478 | | ok = lM f m1' m2' | ||
479 | | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2 | ||
480 | where | ||
481 | (r1,c1) = size m1 | ||
482 | (r2,c2) = size m2 | ||
483 | r = max r1 r2 | ||
484 | c = max c1 c2 | ||
485 | r0 = min r1 r2 | ||
486 | c0 = min c1 c2 | ||
487 | ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2 | ||
488 | m1' = conformMTo (r,c) m1 | ||
489 | m2' = conformMTo (r,c) m2 | ||
490 | |||
491 | -- FIXME do not flatten if equal order | ||
492 | lM f m1 m2 = matrixFromVector | ||
493 | RowMajor | ||
494 | (max' (rows m1) (rows m2)) | ||
495 | (max' (cols m1) (cols m2)) | ||
496 | (f (flatten m1) (flatten m2)) | ||
497 | where | ||
498 | max' 1 b = b | ||
499 | max' a 1 = a | ||
500 | max' a b = max a b | ||
501 | |||
502 | compat' :: Matrix a -> Matrix b -> Bool | ||
503 | compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2 | ||
504 | where | ||
505 | s1 = size m1 | ||
506 | s2 = size m2 | ||
507 | |||
508 | ------------------------------------------------------------ | ||
509 | |||
510 | toBlockRows [r] m | ||
511 | | r == rows m = [m] | ||
512 | toBlockRows rs m | ||
513 | | cols m > 0 = map (reshape (cols m)) (takesV szs (flatten m)) | ||
514 | | otherwise = map g rs | ||
515 | where | ||
516 | szs = map (* cols m) rs | ||
517 | g k = (k><0)[] | ||
518 | |||
519 | toBlockCols [c] m | c == cols m = [m] | ||
520 | toBlockCols cs m = map trans . toBlockRows cs . trans $ m | ||
521 | |||
522 | -- | Partition a matrix into blocks with the given numbers of rows and columns. | ||
523 | -- The remaining rows and columns are discarded. | ||
524 | toBlocks :: (Element t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]] | ||
525 | toBlocks rs cs m | ||
526 | | ok = map (toBlockCols cs) . toBlockRows rs $ m | ||
527 | | otherwise = error $ "toBlocks: bad partition: "++show rs++" "++show cs | ||
528 | ++ " "++shSize m | ||
529 | where | ||
530 | ok = sum rs <= rows m && sum cs <= cols m && all (>=0) rs && all (>=0) cs | ||
531 | |||
532 | -- | Fully partition a matrix into blocks of the same size. If the dimensions are not | ||
533 | -- a multiple of the given size the last blocks will be smaller. | ||
534 | toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]] | ||
535 | toBlocksEvery r c m | ||
536 | | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c | ||
537 | | otherwise = toBlocks rs cs m | ||
538 | where | ||
539 | (qr,rr) = rows m `divMod` r | ||
540 | (qc,rc) = cols m `divMod` c | ||
541 | rs = replicate qr r ++ if rr > 0 then [rr] else [] | ||
542 | cs = replicate qc c ++ if rc > 0 then [rc] else [] | ||
543 | |||
544 | ------------------------------------------------------------------- | ||
545 | |||
546 | -- Given a column number and a function taking matrix indexes, returns | ||
547 | -- a function which takes vector indexes (that can be used on the | ||
548 | -- flattened matrix). | ||
549 | mk :: Int -> ((Int, Int) -> t) -> (Int -> t) | ||
550 | mk c g = \k -> g (divMod k c) | ||
551 | |||
552 | {- | | ||
553 | |||
554 | >>> mapMatrixWithIndexM_ (\(i,j) v -> printf "m[%d,%d] = %.f\n" i j v :: IO()) ((2><3)[1 :: Double ..]) | ||
555 | m[0,0] = 1 | ||
556 | m[0,1] = 2 | ||
557 | m[0,2] = 3 | ||
558 | m[1,0] = 4 | ||
559 | m[1,1] = 5 | ||
560 | m[1,2] = 6 | ||
561 | |||
562 | -} | ||
563 | mapMatrixWithIndexM_ | ||
564 | :: (Element a, Num a, Monad m) => | ||
565 | ((Int, Int) -> a -> m ()) -> Matrix a -> m () | ||
566 | mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m | ||
567 | where | ||
568 | c = cols m | ||
569 | |||
570 | {- | | ||
571 | |||
572 | >>> mapMatrixWithIndexM (\(i,j) v -> Just $ 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double) | ||
573 | Just (3><3) | ||
574 | [ 100.0, 1.0, 2.0 | ||
575 | , 10.0, 111.0, 12.0 | ||
576 | , 20.0, 21.0, 122.0 ] | ||
577 | |||
578 | -} | ||
579 | mapMatrixWithIndexM | ||
580 | :: (Element a, Storable b, Monad m) => | ||
581 | ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b) | ||
582 | mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m | ||
583 | where | ||
584 | c = cols m | ||
585 | |||
586 | {- | | ||
587 | |||
588 | >>> mapMatrixWithIndex (\(i,j) v -> 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double) | ||
589 | (3><3) | ||
590 | [ 100.0, 1.0, 2.0 | ||
591 | , 10.0, 111.0, 12.0 | ||
592 | , 20.0, 21.0, 122.0 ] | ||
593 | |||
594 | -} | ||
595 | mapMatrixWithIndex | ||
596 | :: (Element a, Storable b) => | ||
597 | ((Int, Int) -> a -> b) -> Matrix a -> Matrix b | ||
598 | mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m | ||
599 | where | ||
600 | c = cols m | ||
601 | |||
602 | mapMatrix :: (Element a, Element b) => (a -> b) -> Matrix a -> Matrix b | ||
603 | mapMatrix f = liftMatrix (mapVector f) | ||
604 | |||