diff options
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r-- | lib/Data/Packed/Internal/Matrix.hs | 187 |
1 files changed, 187 insertions, 0 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs new file mode 100644 index 0000000..2c57c07 --- /dev/null +++ b/lib/Data/Packed/Internal/Matrix.hs | |||
@@ -0,0 +1,187 @@ | |||
1 | {-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-} | ||
2 | ----------------------------------------------------------------------------- | ||
3 | -- | | ||
4 | -- Module : Data.Packed.Internal.Matrix | ||
5 | -- Copyright : (c) Alberto Ruiz 2007 | ||
6 | -- License : GPL-style | ||
7 | -- | ||
8 | -- Maintainer : Alberto Ruiz <aruiz@um.es> | ||
9 | -- Stability : provisional | ||
10 | -- Portability : portable (uses FFI) | ||
11 | -- | ||
12 | -- Fundamental types | ||
13 | -- | ||
14 | ----------------------------------------------------------------------------- | ||
15 | |||
16 | module Data.Packed.Internal.Matrix where | ||
17 | |||
18 | import Data.Packed.Internal.Vector | ||
19 | |||
20 | import Foreign hiding (xor) | ||
21 | import Complex | ||
22 | import Control.Monad(when) | ||
23 | import Debug.Trace | ||
24 | import Data.List(transpose,intersperse) | ||
25 | import Data.Typeable | ||
26 | import Data.Maybe(fromJust) | ||
27 | |||
28 | debug x = trace (show x) x | ||
29 | |||
30 | |||
31 | data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq) | ||
32 | |||
33 | -- | 2D array | ||
34 | data Matrix t = M { rows :: Int | ||
35 | , cols :: Int | ||
36 | , dat :: Vector t | ||
37 | , tdat :: Vector t | ||
38 | , isTrans :: Bool | ||
39 | , order :: MatrixOrder | ||
40 | } deriving Typeable | ||
41 | |||
42 | xor a b = a && not b || b && not a | ||
43 | |||
44 | fortran m = order m == ColumnMajor | ||
45 | |||
46 | cdat m = if fortran m `xor` isTrans m then tdat m else dat m | ||
47 | fdat m = if fortran m `xor` isTrans m then dat m else tdat m | ||
48 | |||
49 | trans m = m { rows = cols m | ||
50 | , cols = rows m | ||
51 | , isTrans = not (isTrans m) | ||
52 | } | ||
53 | |||
54 | type Mt t s = Int -> Int -> Ptr t -> s | ||
55 | infixr 6 ::> | ||
56 | type t ::> s = Mt t s | ||
57 | |||
58 | mat d m f = f (rows m) (cols m) (ptr (d m)) | ||
59 | |||
60 | instance (Show a, Storable a) => (Show (Matrix a)) where | ||
61 | show m = (sizes++) . dsp . map (map show) . toLists $ m | ||
62 | where sizes = "("++show (rows m)++"><"++show (cols m)++")\n" | ||
63 | |||
64 | partit :: Int -> [a] -> [[a]] | ||
65 | partit _ [] = [] | ||
66 | partit n l = take n l : partit n (drop n l) | ||
67 | |||
68 | toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m | ||
69 | | otherwise = partit (cols m) . toList . dat $ m | ||
70 | |||
71 | dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp | ||
72 | where | ||
73 | mt = transpose as | ||
74 | longs = map (maximum . map length) mt | ||
75 | mtp = zipWith (\a b -> map (pad a) b) longs mt | ||
76 | pad n str = replicate (n - length str) ' ' ++ str | ||
77 | unwords' = concat . intersperse ", " | ||
78 | |||
79 | matrixFromVector RowMajor c v = | ||
80 | M { rows = r | ||
81 | , cols = c | ||
82 | , dat = v | ||
83 | , tdat = transdata c v r | ||
84 | , order = RowMajor | ||
85 | , isTrans = False | ||
86 | } where r = dim v `div` c -- TODO check mod=0 | ||
87 | |||
88 | matrixFromVector ColumnMajor c v = | ||
89 | M { rows = r | ||
90 | , cols = c | ||
91 | , dat = v | ||
92 | , tdat = transdata r v c | ||
93 | , order = ColumnMajor | ||
94 | , isTrans = False | ||
95 | } where r = dim v `div` c -- TODO check mod=0 | ||
96 | |||
97 | createMatrix order r c = do | ||
98 | p <- createVector (r*c) | ||
99 | return (matrixFromVector order c p) | ||
100 | |||
101 | transdataG :: Storable a => Int -> Vector a -> Int -> Vector a | ||
102 | transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d | ||
103 | |||
104 | transdataR :: Int -> Vector Double -> Int -> Vector Double | ||
105 | transdataR = transdataAux ctransR | ||
106 | |||
107 | transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) | ||
108 | transdataC = transdataAux ctransC | ||
109 | |||
110 | transdataAux fun c1 d c2 = unsafePerformIO $ do | ||
111 | v <- createVector (dim d) | ||
112 | let r1 = dim d `div` c1 | ||
113 | r2 = dim d `div` c2 | ||
114 | fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d] | ||
115 | --putStrLn "---> transdataAux" | ||
116 | return v | ||
117 | |||
118 | foreign import ccall safe "aux.h transR" | ||
119 | ctransR :: Double ::> Double ::> IO Int | ||
120 | foreign import ccall safe "aux.h transC" | ||
121 | ctransC :: Complex Double ::> Complex Double ::> IO Int | ||
122 | |||
123 | transdata :: Field a => Int -> Vector a -> Int -> Vector a | ||
124 | transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2 | ||
125 | | isComp baseOf d = scast $ transdataC c1 (scast d) c2 | ||
126 | | otherwise = transdataG c1 d c2 | ||
127 | |||
128 | --transdata :: Storable a => Int -> Vector a -> Int -> Vector a | ||
129 | --transdata = transdataG | ||
130 | --{-# RULES "transdataR" transdata=transdataR #-} | ||
131 | --{-# RULES "transdataC" transdata=transdataC #-} | ||
132 | |||
133 | -- | extracts the rows of a matrix as a list of vectors | ||
134 | toRows :: Storable t => Matrix t -> [Vector t] | ||
135 | toRows m = toRows' 0 where | ||
136 | v = cdat m | ||
137 | r = rows m | ||
138 | c = cols m | ||
139 | toRows' k | k == r*c = [] | ||
140 | | otherwise = subVector k c v : toRows' (k+c) | ||
141 | |||
142 | ------------------------------------------------------------------ | ||
143 | |||
144 | dotL a b = sum (zipWith (*) a b) | ||
145 | |||
146 | multiplyL a b = [[dotL x y | y <- transpose b] | x <- a] | ||
147 | |||
148 | transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m) | ||
149 | |||
150 | multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b) | ||
151 | |||
152 | ------------------------------------------------------------------ | ||
153 | |||
154 | gmatC m f | fortran m = | ||
155 | if (isTrans m) | ||
156 | then f 0 (rows m) (cols m) (ptr (dat m)) | ||
157 | else f 1 (cols m) (rows m) (ptr (dat m)) | ||
158 | | otherwise = | ||
159 | if isTrans m | ||
160 | then f 1 (cols m) (rows m) (ptr (dat m)) | ||
161 | else f 0 (rows m) (cols m) (ptr (dat m)) | ||
162 | |||
163 | |||
164 | multiplyAux order fun a b = unsafePerformIO $ do | ||
165 | when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++ | ||
166 | show (rows a,cols a) ++ " x " ++ show (rows b, cols b) | ||
167 | r <- createMatrix order (rows a) (cols b) | ||
168 | fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b] | ||
169 | return r | ||
170 | |||
171 | foreign import ccall safe "aux.h multiplyR" | ||
172 | cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int)) | ||
173 | |||
174 | foreign import ccall safe "aux.h multiplyC" | ||
175 | cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int)) | ||
176 | |||
177 | multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a | ||
178 | multiply RowMajor a b = multiplyD RowMajor a b | ||
179 | multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b | ||
180 | |||
181 | multiplyT order a b = multiplyD order (trans b) (trans a) | ||
182 | |||
183 | multiplyD order a b | ||
184 | | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b) | ||
185 | | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b) | ||
186 | | otherwise = multiplyG a b | ||
187 | |||