summaryrefslogtreecommitdiff
path: root/lib/Data/Packed/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Data/Packed/Internal/Matrix.hs')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs187
1 files changed, 187 insertions, 0 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
new file mode 100644
index 0000000..2c57c07
--- /dev/null
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -0,0 +1,187 @@
1{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Internal.Matrix
5-- Copyright : (c) Alberto Ruiz 2007
6-- License : GPL-style
7--
8-- Maintainer : Alberto Ruiz <aruiz@um.es>
9-- Stability : provisional
10-- Portability : portable (uses FFI)
11--
12-- Fundamental types
13--
14-----------------------------------------------------------------------------
15
16module Data.Packed.Internal.Matrix where
17
18import Data.Packed.Internal.Vector
19
20import Foreign hiding (xor)
21import Complex
22import Control.Monad(when)
23import Debug.Trace
24import Data.List(transpose,intersperse)
25import Data.Typeable
26import Data.Maybe(fromJust)
27
28debug x = trace (show x) x
29
30
31data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
32
33-- | 2D array
34data Matrix t = M { rows :: Int
35 , cols :: Int
36 , dat :: Vector t
37 , tdat :: Vector t
38 , isTrans :: Bool
39 , order :: MatrixOrder
40 } deriving Typeable
41
42xor a b = a && not b || b && not a
43
44fortran m = order m == ColumnMajor
45
46cdat m = if fortran m `xor` isTrans m then tdat m else dat m
47fdat m = if fortran m `xor` isTrans m then dat m else tdat m
48
49trans m = m { rows = cols m
50 , cols = rows m
51 , isTrans = not (isTrans m)
52 }
53
54type Mt t s = Int -> Int -> Ptr t -> s
55infixr 6 ::>
56type t ::> s = Mt t s
57
58mat d m f = f (rows m) (cols m) (ptr (d m))
59
60instance (Show a, Storable a) => (Show (Matrix a)) where
61 show m = (sizes++) . dsp . map (map show) . toLists $ m
62 where sizes = "("++show (rows m)++"><"++show (cols m)++")\n"
63
64partit :: Int -> [a] -> [[a]]
65partit _ [] = []
66partit n l = take n l : partit n (drop n l)
67
68toLists m | fortran m = transpose $ partit (rows m) . toList . dat $ m
69 | otherwise = partit (cols m) . toList . dat $ m
70
71dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
72 where
73 mt = transpose as
74 longs = map (maximum . map length) mt
75 mtp = zipWith (\a b -> map (pad a) b) longs mt
76 pad n str = replicate (n - length str) ' ' ++ str
77 unwords' = concat . intersperse ", "
78
79matrixFromVector RowMajor c v =
80 M { rows = r
81 , cols = c
82 , dat = v
83 , tdat = transdata c v r
84 , order = RowMajor
85 , isTrans = False
86 } where r = dim v `div` c -- TODO check mod=0
87
88matrixFromVector ColumnMajor c v =
89 M { rows = r
90 , cols = c
91 , dat = v
92 , tdat = transdata r v c
93 , order = ColumnMajor
94 , isTrans = False
95 } where r = dim v `div` c -- TODO check mod=0
96
97createMatrix order r c = do
98 p <- createVector (r*c)
99 return (matrixFromVector order c p)
100
101transdataG :: Storable a => Int -> Vector a -> Int -> Vector a
102transdataG c1 d c2 = fromList . concat . transpose . partit c1 . toList $ d
103
104transdataR :: Int -> Vector Double -> Int -> Vector Double
105transdataR = transdataAux ctransR
106
107transdataC :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double)
108transdataC = transdataAux ctransC
109
110transdataAux fun c1 d c2 = unsafePerformIO $ do
111 v <- createVector (dim d)
112 let r1 = dim d `div` c1
113 r2 = dim d `div` c2
114 fun r1 c1 (ptr d) r2 c2 (ptr v) // check "transdataAux" [d]
115 --putStrLn "---> transdataAux"
116 return v
117
118foreign import ccall safe "aux.h transR"
119 ctransR :: Double ::> Double ::> IO Int
120foreign import ccall safe "aux.h transC"
121 ctransC :: Complex Double ::> Complex Double ::> IO Int
122
123transdata :: Field a => Int -> Vector a -> Int -> Vector a
124transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
125 | isComp baseOf d = scast $ transdataC c1 (scast d) c2
126 | otherwise = transdataG c1 d c2
127
128--transdata :: Storable a => Int -> Vector a -> Int -> Vector a
129--transdata = transdataG
130--{-# RULES "transdataR" transdata=transdataR #-}
131--{-# RULES "transdataC" transdata=transdataC #-}
132
133-- | extracts the rows of a matrix as a list of vectors
134toRows :: Storable t => Matrix t -> [Vector t]
135toRows m = toRows' 0 where
136 v = cdat m
137 r = rows m
138 c = cols m
139 toRows' k | k == r*c = []
140 | otherwise = subVector k c v : toRows' (k+c)
141
142------------------------------------------------------------------
143
144dotL a b = sum (zipWith (*) a b)
145
146multiplyL a b = [[dotL x y | y <- transpose b] | x <- a]
147
148transL m = matrixFromVector RowMajor (rows m) $ transdataG (cols m) (cdat m) (rows m)
149
150multiplyG a b = matrixFromVector RowMajor (cols b) $ fromList $ concat $ multiplyL (toLists a) (toLists b)
151
152------------------------------------------------------------------
153
154gmatC m f | fortran m =
155 if (isTrans m)
156 then f 0 (rows m) (cols m) (ptr (dat m))
157 else f 1 (cols m) (rows m) (ptr (dat m))
158 | otherwise =
159 if isTrans m
160 then f 1 (cols m) (rows m) (ptr (dat m))
161 else f 0 (rows m) (cols m) (ptr (dat m))
162
163
164multiplyAux order fun a b = unsafePerformIO $ do
165 when (cols a /= rows b) $ error $ "inconsistent dimensions in contraction "++
166 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
167 r <- createMatrix order (rows a) (cols b)
168 fun // gmatC a // gmatC b // mat dat r // check "multiplyAux" [dat a, dat b]
169 return r
170
171foreign import ccall safe "aux.h multiplyR"
172 cmultiplyR :: Int -> Double ::> (Int -> Double ::> (Double ::> IO Int))
173
174foreign import ccall safe "aux.h multiplyC"
175 cmultiplyC :: Int -> Complex Double ::> (Int -> Complex Double ::> (Complex Double ::> IO Int))
176
177multiply :: (Num a, Field a) => MatrixOrder -> Matrix a -> Matrix a -> Matrix a
178multiply RowMajor a b = multiplyD RowMajor a b
179multiply ColumnMajor a b = trans $ multiplyT ColumnMajor a b
180
181multiplyT order a b = multiplyD order (trans b) (trans a)
182
183multiplyD order a b
184 | isReal (baseOf.dat) a = scast $ multiplyAux order cmultiplyR (scast a) (scast b)
185 | isComp (baseOf.dat) a = scast $ multiplyAux order cmultiplyC (scast a) (scast b)
186 | otherwise = multiplyG a b
187