summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Data/Packed/Internal/Matrix.hs10
-rw-r--r--lib/Data/Packed/Vector.hs3
-rw-r--r--lib/GSL.hs12
-rw-r--r--lib/GSL/Compat.hs115
-rw-r--r--lib/GSL/Matrix.hs2
5 files changed, 124 insertions, 18 deletions
diff --git a/lib/Data/Packed/Internal/Matrix.hs b/lib/Data/Packed/Internal/Matrix.hs
index dd33943..9895393 100644
--- a/lib/Data/Packed/Internal/Matrix.hs
+++ b/lib/Data/Packed/Internal/Matrix.hs
@@ -144,13 +144,17 @@ transdata c1 d c2 | isReal baseOf d = scast $ transdataR c1 (scast d) c2
144--{-# RULES "transdataC" transdata=transdataC #-} 144--{-# RULES "transdataC" transdata=transdataC #-}
145 145
146----------------------------------------------------------------- 146-----------------------------------------------------------------
147liftMatrix :: (Vector a -> Vector b) -> Matrix a -> Matrix b 147liftMatrix :: (Field a, Field b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
148liftMatrix f m = m { dat = f (dat m), tdat = f (tdat m) } -- check sizes 148liftMatrix f m = reshape (cols m) (f (cdat m))
149 149
150liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t 150liftMatrix2 :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
151liftMatrix2 f m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2)) -- check sizes 151liftMatrix2 f m1 m2 | compat m1 m2 = reshape (cols m1) (f (cdat m1) (cdat m2))
152 | otherwise = error "nonconformant matrices in liftMatrix2"
152------------------------------------------------------------------ 153------------------------------------------------------------------
153 154
155compat :: Matrix a -> Matrix b -> Bool
156compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
157
154dotL a b = sum (zipWith (*) a b) 158dotL a b = sum (zipWith (*) a b)
155 159
156multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a] 160multiplyL a b | ok = [[dotL x y | y <- transpose b] | x <- a]
diff --git a/lib/Data/Packed/Vector.hs b/lib/Data/Packed/Vector.hs
index 94f70be..27ba6a3 100644
--- a/lib/Data/Packed/Vector.hs
+++ b/lib/Data/Packed/Vector.hs
@@ -21,7 +21,8 @@ module Data.Packed.Vector (
21 toComplex, comp, 21 toComplex, comp,
22 conj, 22 conj,
23 dot, 23 dot,
24 linspace 24 linspace,
25 liftVector, liftVector2
25) where 26) where
26 27
27import Data.Packed.Internal 28import Data.Packed.Internal
diff --git a/lib/GSL.hs b/lib/GSL.hs
index ac3a2e8..bffbb62 100644
--- a/lib/GSL.hs
+++ b/lib/GSL.hs
@@ -25,15 +25,16 @@ module GSL.Special,
25module GSL.Fourier, 25module GSL.Fourier,
26module GSL.Polynomials, 26module GSL.Polynomials,
27module GSL.Minimization, 27module GSL.Minimization,
28module Data.Packed.Plot, 28module GSL.Matrix,
29module GSL.Compat 29module GSL.Compat,
30module Data.Packed.Plot
30 31
31) where 32) where
32 33
33import Data.Packed.Vector 34import Data.Packed.Vector hiding (constant)
34import Data.Packed.Matrix 35import Data.Packed.Matrix
35import Data.Packed.Tensor 36import Data.Packed.Tensor
36import LinearAlgebra.Algorithms 37import LinearAlgebra.Algorithms hiding (pnorm)
37import LAPACK 38import LAPACK
38import GSL.Integration 39import GSL.Integration
39import GSL.Differentiation 40import GSL.Differentiation
@@ -41,5 +42,6 @@ import GSL.Special
41import GSL.Fourier 42import GSL.Fourier
42import GSL.Polynomials 43import GSL.Polynomials
43import GSL.Minimization 44import GSL.Minimization
44import Data.Packed.Plot 45import GSL.Matrix
45import GSL.Compat 46import GSL.Compat
47import Data.Packed.Plot
diff --git a/lib/GSL/Compat.hs b/lib/GSL/Compat.hs
index 6a94191..2cae0c4 100644
--- a/lib/GSL/Compat.hs
+++ b/lib/GSL/Compat.hs
@@ -15,7 +15,8 @@ Creates reasonable numeric instances for Vectors and Matrices. In the context of
15----------------------------------------------------------------------------- 15-----------------------------------------------------------------------------
16 16
17module GSL.Compat( 17module GSL.Compat(
18 Mul,(<>), fromFile, readMatrix, size, dispR, dispC, format, gmap 18 Mul,(<>), readMatrix, size, dispR, dispC, format, gmap, Joinable, (<|>),(<->), GSL.Compat.constant,
19 vectorMax, vectorMin, fromArray2D, fromComplex, GSL.Compat.pnorm, scale
19) where 20) where
20 21
21import Data.Packed.Internal hiding (dsp) 22import Data.Packed.Internal hiding (dsp)
@@ -27,6 +28,8 @@ import LinearAlgebra.Algorithms
27import Complex 28import Complex
28import Numeric(showGFloat) 29import Numeric(showGFloat)
29import Data.List(transpose,intersperse) 30import Data.List(transpose,intersperse)
31import Foreign(Storable)
32import Data.Array
30 33
31 34
32adaptScalar f1 f2 f3 x y 35adaptScalar f1 f2 f3 x y
@@ -34,6 +37,15 @@ adaptScalar f1 f2 f3 x y
34 | dim y == 1 = f3 x (y@>0) 37 | dim y == 1 = f3 x (y@>0)
35 | otherwise = f2 x y 38 | otherwise = f2 x y
36 39
40liftMatrix2' :: (Field t) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
41liftMatrix2' f m1 m2 | compat' m1 m2 = reshape (max (cols m1) (cols m2)) (f (cdat m1) (cdat m2))
42 | otherwise = error "nonconformant matrices in liftMatrix2'"
43
44compat' :: Matrix a -> Matrix b -> Bool
45compat' m1 m2 = rows m1 == 1 && cols m1 == 1
46 || rows m2 == 1 && cols m2 == 1
47 || rows m1 == rows m2 && cols m1 == cols m2
48
37instance (Eq a, Field a) => Eq (Vector a) where 49instance (Eq a, Field a) => Eq (Vector a) where
38 a == b = dim a == dim b && toList a == toList b 50 a == b = dim a == dim b && toList a == toList b
39 51
@@ -49,9 +61,9 @@ instance (Eq a, Field a) => Eq (Matrix a) where
49 a == b = rows a == rows b && cols a == cols b && cdat a == cdat b && fdat a == fdat b 61 a == b = rows a == rows b && cols a == cols b && cdat a == cdat b && fdat a == fdat b
50 62
51instance (Num a, Field a) => Num (Matrix a) where 63instance (Num a, Field a) => Num (Matrix a) where
52 (+) = liftMatrix2 (+) 64 (+) = liftMatrix2' (+)
53 negate = liftMatrix negate 65 negate = liftMatrix negate
54 (*) = liftMatrix2 (*) 66 (*) = liftMatrix2' (*)
55 signum = liftMatrix signum 67 signum = liftMatrix signum
56 abs = liftMatrix abs 68 abs = liftMatrix abs
57 fromInteger = (1><1) . return . fromInteger 69 fromInteger = (1><1) . return . fromInteger
@@ -76,13 +88,13 @@ instance Fractional (Vector (Complex Double)) where
76 88
77instance Fractional (Matrix Double) where 89instance Fractional (Matrix Double) where
78 fromRational n = (1><1) [fromRational n] 90 fromRational n = (1><1) [fromRational n]
79 (/) = liftMatrix2 (/) 91 (/) = liftMatrix2' (/)
80 92
81------------------------------------------------------- 93-------------------------------------------------------
82 94
83instance Fractional (Matrix (Complex Double)) where 95instance Fractional (Matrix (Complex Double)) where
84 fromRational n = (1><1) [fromRational n] 96 fromRational n = (1><1) [fromRational n]
85 (/) = liftMatrix2 (/) 97 (/) = liftMatrix2' (/)
86 98
87--------------------------------------------------------- 99---------------------------------------------------------
88 100
@@ -122,7 +134,7 @@ instance Floating (Matrix Double) where
122 atanh = liftMatrix atanh 134 atanh = liftMatrix atanh
123 exp = liftMatrix exp 135 exp = liftMatrix exp
124 log = liftMatrix log 136 log = liftMatrix log
125 (**) = liftMatrix2 (**) 137 (**) = liftMatrix2' (**)
126 sqrt = liftMatrix sqrt 138 sqrt = liftMatrix sqrt
127 pi = (1><1) [pi] 139 pi = (1><1) [pi]
128------------------------------------------------------------- 140-------------------------------------------------------------
@@ -163,7 +175,7 @@ instance Floating (Matrix (Complex Double)) where
163 atanh = liftMatrix atanh 175 atanh = liftMatrix atanh
164 exp = liftMatrix exp 176 exp = liftMatrix exp
165 log = liftMatrix log 177 log = liftMatrix log
166 (**) = liftMatrix2 (**) 178 (**) = liftMatrix2' (**)
167 sqrt = liftMatrix sqrt 179 sqrt = liftMatrix sqrt
168 pi = (1><1) [pi] 180 pi = (1><1) [pi]
169 181
@@ -330,8 +342,11 @@ instance Mul (Matrix Double) (Complex Double) (Matrix (Complex Double)) where
330size :: Vector a -> Int 342size :: Vector a -> Int
331size = dim 343size = dim
332 344
345gmap :: (Storable a, Storable b) => (a->b) -> Vector a -> Vector b
333gmap f v = liftVector f v 346gmap f v = liftVector f v
334 347
348constant :: Double -> Int -> Vector Double
349constant = constantR
335 350
336-- shows a Double with n digits after the decimal point 351-- shows a Double with n digits after the decimal point
337shf :: (RealFloat a) => Int -> a -> String 352shf :: (RealFloat a) => Int -> a -> String
@@ -367,4 +382,88 @@ dispC d m = disp m (shfc d)
367 382
368-- | creates a matrix from a table of numbers. 383-- | creates a matrix from a table of numbers.
369readMatrix :: String -> Matrix Double 384readMatrix :: String -> Matrix Double
370readMatrix = fromLists . map (map read). map words . filter (not.null) . lines \ No newline at end of file 385readMatrix = fromLists . map (map read). map words . filter (not.null) . lines
386
387-------------------------------------------------------------
388
389class Joinable a b c | a b -> c where
390 joinH :: a -> b -> c
391 joinV :: a -> b -> c
392
393instance Joinable (Matrix Double) (Vector Double) (Matrix Double) where
394 joinH m v = fromBlocks [[m,reshape 1 v]]
395 joinV m v = fromBlocks [[m],[reshape (size v) v]]
396
397instance Joinable (Vector Double) (Matrix Double) (Matrix Double) where
398 joinH v m = fromBlocks [[reshape 1 v,m]]
399 joinV v m = fromBlocks [[reshape (size v) v],[m]]
400
401instance Joinable (Matrix Double) (Matrix Double) (Matrix Double) where
402 joinH m1 m2 = fromBlocks [[m1,m2]]
403 joinV m1 m2 = fromBlocks [[m1],[m2]]
404
405instance Joinable (Matrix (Complex Double)) (Vector (Complex Double)) (Matrix (Complex Double)) where
406 joinH m v = fromBlocks [[m,reshape 1 v]]
407 joinV m v = fromBlocks [[m],[reshape (size v) v]]
408
409instance Joinable (Vector (Complex Double)) (Matrix (Complex Double)) (Matrix (Complex Double)) where
410 joinH v m = fromBlocks [[reshape 1 v,m]]
411 joinV v m = fromBlocks [[reshape (size v) v],[m]]
412
413instance Joinable (Matrix (Complex Double)) (Matrix (Complex Double)) (Matrix (Complex Double)) where
414 joinH m1 m2 = fromBlocks [[m1,m2]]
415 joinV m1 m2 = fromBlocks [[m1],[m2]]
416
417infixl 3 <|>, <->
418
419{- | Horizontal concatenation of matrices and vectors:
420
421@\> 'ident' 3 \<-\> i\<\>'ident' 3 \<|\> 'fromList' [1..6]
422 1. 0. 0. 1.
423 0. 1. 0. 2.
424 0. 0. 1. 3.
4251.i 0. 0. 4.
426 0. 1.i 0. 5.
427 0. 0. 1.i 6.@
428-}
429(<|>) :: (Joinable a b c) => a -> b -> c
430a <|> b = joinH a b
431
432-- | Vertical concatenation of matrices and vectors.
433(<->) :: (Joinable a b c) => a -> b -> c
434a <-> b = joinV a b
435
436----------------------------------------------------------
437
438vectorMax = toScalarR Max
439
440vectorMin = toScalarR Min
441
442fromArray2D m = (r><c) (elems m)
443 where ((r0,c0),(r1,c1)) = bounds m
444 r = r1-r0+1
445 c = c1-c0+1
446
447-- | creates a complex vector from vectors with real and imaginary parts
448toComplexV :: (Vector Double, Vector Double) -> Vector (Complex Double)
449toComplexV (r,i) = asComplex $ flatten $ fromColumns [r,i]
450
451-- | extracts the real and imaginary parts of a complex vector
452fromComplexV :: Vector (Complex Double) -> (Vector Double, Vector Double)
453fromComplexV m = (a,b) where [a,b] = toColumns $ reshape 2 $ asReal m
454
455-- | creates a complex matrix from matrices with real and imaginary parts
456toComplexM :: (Matrix Double, Matrix Double) -> Matrix (Complex Double)
457toComplexM (r,i) = reshape (cols r) $ asComplex $ flatten $ fromColumns [flatten r, flatten i]
458
459-- | extracts the real and imaginary parts of a complex matrix
460fromComplexM :: Matrix (Complex Double) -> (Matrix Double, Matrix Double)
461fromComplexM m = (reshape c a, reshape c b)
462 where c = cols m
463 [a,b] = toColumns $ reshape 2 $ asReal $ flatten m
464
465fromComplex = fromComplexM
466
467pnorm 0 = LinearAlgebra.Algorithms.pnorm Infinity
468pnorm 1 = LinearAlgebra.Algorithms.pnorm PNorm1
469pnorm 2 = LinearAlgebra.Algorithms.pnorm PNorm2 \ No newline at end of file
diff --git a/lib/GSL/Matrix.hs b/lib/GSL/Matrix.hs
index 919c2d9..26c5e2a 100644
--- a/lib/GSL/Matrix.hs
+++ b/lib/GSL/Matrix.hs
@@ -19,7 +19,7 @@ module GSL.Matrix(
19 chol, 19 chol,
20 luSolveR, luSolveC, 20 luSolveR, luSolveC,
21 luR, luC, 21 luR, luC,
22 fromFile 22 fromFile, extractRows
23) where 23) where
24 24
25import Data.Packed.Internal 25import Data.Packed.Internal