summaryrefslogtreecommitdiff
path: root/packages/hmatrix/src
diff options
context:
space:
mode:
authorAlberto Ruiz <aruiz@um.es>2014-05-08 08:48:12 +0200
committerAlberto Ruiz <aruiz@um.es>2014-05-08 08:48:12 +0200
commit1925c123d7d8184a1d2ddc0a413e0fd2776e1083 (patch)
treefad79f909d9c3be53d68e6ebd67202650536d387 /packages/hmatrix/src
parenteb3f702d065a4a967bb754977233e6eec408fd1f (diff)
empty hmatrix-base
Diffstat (limited to 'packages/hmatrix/src')
-rw-r--r--packages/hmatrix/src/Data/Packed.hs29
-rw-r--r--packages/hmatrix/src/Data/Packed/Development.hs32
-rw-r--r--packages/hmatrix/src/Data/Packed/Foreign.hs100
-rw-r--r--packages/hmatrix/src/Data/Packed/Internal.hs26
-rw-r--r--packages/hmatrix/src/Data/Packed/Internal/Common.hs171
-rw-r--r--packages/hmatrix/src/Data/Packed/Internal/Matrix.hs491
-rw-r--r--packages/hmatrix/src/Data/Packed/Internal/Signatures.hs72
-rw-r--r--packages/hmatrix/src/Data/Packed/Internal/Vector.hs521
-rw-r--r--packages/hmatrix/src/Data/Packed/Matrix.hs490
-rw-r--r--packages/hmatrix/src/Data/Packed/Random.hs57
-rw-r--r--packages/hmatrix/src/Data/Packed/ST.hs179
-rw-r--r--packages/hmatrix/src/Data/Packed/Vector.hs96
-rw-r--r--packages/hmatrix/src/Graphics/Plot.hs184
-rw-r--r--packages/hmatrix/src/Numeric/Chain.hs140
-rw-r--r--packages/hmatrix/src/Numeric/Container.hs303
-rw-r--r--packages/hmatrix/src/Numeric/ContainerBoot.hs611
-rw-r--r--packages/hmatrix/src/Numeric/Conversion.hs91
-rw-r--r--packages/hmatrix/src/Numeric/GSL.hs43
-rw-r--r--packages/hmatrix/src/Numeric/GSL/Differentiation.hs87
-rw-r--r--packages/hmatrix/src/Numeric/GSL/Fitting.hs179
-rw-r--r--packages/hmatrix/src/Numeric/GSL/Fourier.hs47
-rw-r--r--packages/hmatrix/src/Numeric/GSL/Integration.hs250
-rw-r--r--packages/hmatrix/src/Numeric/GSL/Internal.hs76
-rw-r--r--packages/hmatrix/src/Numeric/GSL/Minimization.hs227
-rw-r--r--packages/hmatrix/src/Numeric/GSL/ODE.hs138
-rw-r--r--packages/hmatrix/src/Numeric/GSL/Polynomials.hs58
-rw-r--r--packages/hmatrix/src/Numeric/GSL/Root.hs199
-rw-r--r--packages/hmatrix/src/Numeric/GSL/Vector.hs328
-rw-r--r--packages/hmatrix/src/Numeric/GSL/gsl-aux.c1541
-rw-r--r--packages/hmatrix/src/Numeric/GSL/gsl-ode.c182
-rw-r--r--packages/hmatrix/src/Numeric/HMatrix.hs136
-rw-r--r--packages/hmatrix/src/Numeric/HMatrix/Data.hs69
-rw-r--r--packages/hmatrix/src/Numeric/HMatrix/Devel.hs69
-rw-r--r--packages/hmatrix/src/Numeric/IO.hs165
-rw-r--r--packages/hmatrix/src/Numeric/LinearAlgebra.hs30
-rw-r--r--packages/hmatrix/src/Numeric/LinearAlgebra/Algorithms.hs746
-rw-r--r--packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK.hs555
-rw-r--r--packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK/lapack-aux.c1489
-rw-r--r--packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK/lapack-aux.h60
-rw-r--r--packages/hmatrix/src/Numeric/LinearAlgebra/Util.hs295
-rw-r--r--packages/hmatrix/src/Numeric/LinearAlgebra/Util/Convolution.hs114
-rw-r--r--packages/hmatrix/src/Numeric/Matrix.hs98
-rw-r--r--packages/hmatrix/src/Numeric/Vector.hs158
43 files changed, 10932 insertions, 0 deletions
diff --git a/packages/hmatrix/src/Data/Packed.hs b/packages/hmatrix/src/Data/Packed.hs
new file mode 100644
index 0000000..957aab8
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed.hs
@@ -0,0 +1,29 @@
1-----------------------------------------------------------------------------
2{- |
3Module : Data.Packed
4Copyright : (c) Alberto Ruiz 2006-2010
5License : GPL-style
6
7Maintainer : Alberto Ruiz (aruiz at um dot es)
8Stability : provisional
9Portability : uses ffi
10
11Types for dense 'Vector' and 'Matrix' of 'Storable' elements.
12
13-}
14-----------------------------------------------------------------------------
15{-# OPTIONS_HADDOCK hide #-}
16
17module Data.Packed (
18 module Data.Packed.Vector,
19 module Data.Packed.Matrix,
20-- module Numeric.Conversion,
21-- module Data.Packed.Random,
22-- module Data.Complex
23) where
24
25import Data.Packed.Vector
26import Data.Packed.Matrix
27--import Data.Packed.Random
28--import Data.Complex
29--import Numeric.Conversion
diff --git a/packages/hmatrix/src/Data/Packed/Development.hs b/packages/hmatrix/src/Data/Packed/Development.hs
new file mode 100644
index 0000000..471e560
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Development.hs
@@ -0,0 +1,32 @@
1
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Development
5-- Copyright : (c) Alberto Ruiz 2009
6-- License : GPL
7--
8-- Maintainer : Alberto Ruiz <aruiz@um.es>
9-- Stability : provisional
10-- Portability : portable
11--
12-- The library can be easily extended with additional foreign functions
13-- using the tools in this module. Illustrative usage examples can be found
14-- in the @examples\/devel@ folder included in the package.
15--
16-----------------------------------------------------------------------------
17{-# OPTIONS_HADDOCK hide #-}
18
19module Data.Packed.Development (
20 createVector, createMatrix,
21 vec, mat,
22 app1, app2, app3, app4,
23 app5, app6, app7, app8, app9, app10,
24 MatrixOrder(..), orderOf, cmat, fmat,
25 matrixFromVector,
26 unsafeFromForeignPtr,
27 unsafeToForeignPtr,
28 check, (//),
29 at', atM'
30) where
31
32import Data.Packed.Internal
diff --git a/packages/hmatrix/src/Data/Packed/Foreign.hs b/packages/hmatrix/src/Data/Packed/Foreign.hs
new file mode 100644
index 0000000..1ec3694
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Foreign.hs
@@ -0,0 +1,100 @@
1{-# LANGUAGE MagicHash, UnboxedTuples #-}
2-- | FFI and hmatrix helpers.
3--
4-- Sample usage, to upload a perspective matrix to a shader.
5--
6-- @ glUniformMatrix4fv 0 1 (fromIntegral gl_TRUE) \`appMatrix\` perspective 0.01 100 (pi\/2) (4\/3)
7-- @
8--
9{-# OPTIONS_HADDOCK hide #-}
10module Data.Packed.Foreign
11 ( app
12 , appVector, appVectorLen
13 , appMatrix, appMatrixLen, appMatrixRaw, appMatrixRawLen
14 , unsafeMatrixToVector, unsafeMatrixToForeignPtr
15 ) where
16import Data.Packed.Internal
17import qualified Data.Vector.Storable as S
18import Foreign (Ptr, ForeignPtr, Storable)
19import Foreign.C.Types (CInt)
20import GHC.Base (IO(..), realWorld#)
21
22{-# INLINE unsafeInlinePerformIO #-}
23-- | If we use unsafePerformIO, it may not get inlined, so in a function that returns IO (which are all safe uses of app* in this module), there would be
24-- unecessary calls to unsafePerformIO or its internals.
25unsafeInlinePerformIO :: IO a -> a
26unsafeInlinePerformIO (IO f) = case f realWorld# of
27 (# _, x #) -> x
28
29{-# INLINE app #-}
30-- | Only useful since it is left associated with a precedence of 1, unlike 'Prelude.$', which is right associative.
31-- e.g.
32--
33-- @
34-- someFunction
35-- \`appMatrixLen\` m
36-- \`appVectorLen\` v
37-- \`app\` other
38-- \`app\` arguments
39-- \`app\` go here
40-- @
41--
42-- One could also write:
43--
44-- @
45-- (someFunction
46-- \`appMatrixLen\` m
47-- \`appVectorLen\` v)
48-- other
49-- arguments
50-- (go here)
51-- @
52--
53app :: (a -> b) -> a -> b
54app f = f
55
56{-# INLINE appVector #-}
57appVector :: Storable a => (Ptr a -> b) -> Vector a -> b
58appVector f x = unsafeInlinePerformIO (S.unsafeWith x (return . f))
59
60{-# INLINE appVectorLen #-}
61appVectorLen :: Storable a => (CInt -> Ptr a -> b) -> Vector a -> b
62appVectorLen f x = unsafeInlinePerformIO (S.unsafeWith x (return . f (fromIntegral (S.length x))))
63
64{-# INLINE appMatrix #-}
65appMatrix :: Element a => (Ptr a -> b) -> Matrix a -> b
66appMatrix f x = unsafeInlinePerformIO (S.unsafeWith (flatten x) (return . f))
67
68{-# INLINE appMatrixLen #-}
69appMatrixLen :: Element a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
70appMatrixLen f x = unsafeInlinePerformIO (S.unsafeWith (flatten x) (return . f r c))
71 where
72 r = fromIntegral (rows x)
73 c = fromIntegral (cols x)
74
75{-# INLINE appMatrixRaw #-}
76appMatrixRaw :: Storable a => (Ptr a -> b) -> Matrix a -> b
77appMatrixRaw f x = unsafeInlinePerformIO (S.unsafeWith (xdat x) (return . f))
78
79{-# INLINE appMatrixRawLen #-}
80appMatrixRawLen :: Element a => (CInt -> CInt -> Ptr a -> b) -> Matrix a -> b
81appMatrixRawLen f x = unsafeInlinePerformIO (S.unsafeWith (xdat x) (return . f r c))
82 where
83 r = fromIntegral (rows x)
84 c = fromIntegral (cols x)
85
86infixl 1 `app`
87infixl 1 `appVector`
88infixl 1 `appMatrix`
89infixl 1 `appMatrixRaw`
90
91{-# INLINE unsafeMatrixToVector #-}
92-- | This will disregard the order of the matrix, and simply return it as-is.
93-- If the order of the matrix is RowMajor, this function is identical to 'flatten'.
94unsafeMatrixToVector :: Matrix a -> Vector a
95unsafeMatrixToVector = xdat
96
97{-# INLINE unsafeMatrixToForeignPtr #-}
98unsafeMatrixToForeignPtr :: Storable a => Matrix a -> (ForeignPtr a, Int)
99unsafeMatrixToForeignPtr m = S.unsafeToForeignPtr0 (xdat m)
100
diff --git a/packages/hmatrix/src/Data/Packed/Internal.hs b/packages/hmatrix/src/Data/Packed/Internal.hs
new file mode 100644
index 0000000..537e51e
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Internal.hs
@@ -0,0 +1,26 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Data.Packed.Internal
4-- Copyright : (c) Alberto Ruiz 2007
5-- License : GPL-style
6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es>
8-- Stability : provisional
9-- Portability : portable
10--
11-- Reexports all internal modules
12--
13-----------------------------------------------------------------------------
14-- #hide
15
16module Data.Packed.Internal (
17 module Data.Packed.Internal.Common,
18 module Data.Packed.Internal.Signatures,
19 module Data.Packed.Internal.Vector,
20 module Data.Packed.Internal.Matrix,
21) where
22
23import Data.Packed.Internal.Common
24import Data.Packed.Internal.Signatures
25import Data.Packed.Internal.Vector
26import Data.Packed.Internal.Matrix
diff --git a/packages/hmatrix/src/Data/Packed/Internal/Common.hs b/packages/hmatrix/src/Data/Packed/Internal/Common.hs
new file mode 100644
index 0000000..edef3c2
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Internal/Common.hs
@@ -0,0 +1,171 @@
1{-# LANGUAGE CPP #-}
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Internal.Common
5-- Copyright : (c) Alberto Ruiz 2007
6-- License : GPL-style
7--
8-- Maintainer : Alberto Ruiz <aruiz@um.es>
9-- Stability : provisional
10-- Portability : portable (uses FFI)
11--
12-- Development utilities.
13--
14-----------------------------------------------------------------------------
15-- #hide
16
17module Data.Packed.Internal.Common(
18 Adapt,
19 app1, app2, app3, app4,
20 app5, app6, app7, app8, app9, app10,
21 (//), check, mbCatch,
22 splitEvery, common, compatdim,
23 fi,
24 table
25) where
26
27import Foreign
28import Control.Monad(when)
29import Foreign.C.String(peekCString)
30import Foreign.C.Types
31import Foreign.Storable.Complex()
32import Data.List(transpose,intersperse)
33import Control.Exception as E
34
35-- | @splitEvery 3 [1..9] == [[1,2,3],[4,5,6],[7,8,9]]@
36splitEvery :: Int -> [a] -> [[a]]
37splitEvery _ [] = []
38splitEvery k l = take k l : splitEvery k (drop k l)
39
40-- | obtains the common value of a property of a list
41common :: (Eq a) => (b->a) -> [b] -> Maybe a
42common f = commonval . map f where
43 commonval :: (Eq a) => [a] -> Maybe a
44 commonval [] = Nothing
45 commonval [a] = Just a
46 commonval (a:b:xs) = if a==b then commonval (b:xs) else Nothing
47
48-- | common value with \"adaptable\" 1
49compatdim :: [Int] -> Maybe Int
50compatdim [] = Nothing
51compatdim [a] = Just a
52compatdim (a:b:xs)
53 | a==b = compatdim (b:xs)
54 | a==1 = compatdim (b:xs)
55 | b==1 = compatdim (a:xs)
56 | otherwise = Nothing
57
58-- | Formatting tool
59table :: String -> [[String]] -> String
60table sep as = unlines . map unwords' $ transpose mtp where
61 mt = transpose as
62 longs = map (maximum . map length) mt
63 mtp = zipWith (\a b -> map (pad a) b) longs mt
64 pad n str = replicate (n - length str) ' ' ++ str
65 unwords' = concat . intersperse sep
66
67-- | postfix function application (@flip ($)@)
68(//) :: x -> (x -> y) -> y
69infixl 0 //
70(//) = flip ($)
71
72-- | specialized fromIntegral
73fi :: Int -> CInt
74fi = fromIntegral
75
76-- hmm..
77ww2 w1 o1 w2 o2 f = w1 o1 $ w2 o2 . f
78ww3 w1 o1 w2 o2 w3 o3 f = w1 o1 $ ww2 w2 o2 w3 o3 . f
79ww4 w1 o1 w2 o2 w3 o3 w4 o4 f = w1 o1 $ ww3 w2 o2 w3 o3 w4 o4 . f
80ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 f = w1 o1 $ ww4 w2 o2 w3 o3 w4 o4 w5 o5 . f
81ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 f = w1 o1 $ ww5 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 . f
82ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 f = w1 o1 $ ww6 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 . f
83ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 f = w1 o1 $ ww7 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 . f
84ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 f = w1 o1 $ ww8 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 . f
85ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 f = w1 o1 $ ww9 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 . f
86
87type Adapt f t r = t -> ((f -> r) -> IO()) -> IO()
88
89type Adapt1 f t1 = Adapt f t1 (IO CInt) -> t1 -> String -> IO()
90type Adapt2 f t1 r1 t2 = Adapt f t1 r1 -> t1 -> Adapt1 r1 t2
91type Adapt3 f t1 r1 t2 r2 t3 = Adapt f t1 r1 -> t1 -> Adapt2 r1 t2 r2 t3
92type Adapt4 f t1 r1 t2 r2 t3 r3 t4 = Adapt f t1 r1 -> t1 -> Adapt3 r1 t2 r2 t3 r3 t4
93type Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5 = Adapt f t1 r1 -> t1 -> Adapt4 r1 t2 r2 t3 r3 t4 r4 t5
94type Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 = Adapt f t1 r1 -> t1 -> Adapt5 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6
95type Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 = Adapt f t1 r1 -> t1 -> Adapt6 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7
96type Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 = Adapt f t1 r1 -> t1 -> Adapt7 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8
97type Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 = Adapt f t1 r1 -> t1 -> Adapt8 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9
98type Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10 = Adapt f t1 r1 -> t1 -> Adapt9 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10
99
100app1 :: f -> Adapt1 f t1
101app2 :: f -> Adapt2 f t1 r1 t2
102app3 :: f -> Adapt3 f t1 r1 t2 r2 t3
103app4 :: f -> Adapt4 f t1 r1 t2 r2 t3 r3 t4
104app5 :: f -> Adapt5 f t1 r1 t2 r2 t3 r3 t4 r4 t5
105app6 :: f -> Adapt6 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6
106app7 :: f -> Adapt7 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7
107app8 :: f -> Adapt8 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8
108app9 :: f -> Adapt9 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9
109app10 :: f -> Adapt10 f t1 r1 t2 r2 t3 r3 t4 r4 t5 r5 t6 r6 t7 r7 t8 r8 t9 r9 t10
110
111app1 f w1 o1 s = w1 o1 $ \a1 -> f // a1 // check s
112app2 f w1 o1 w2 o2 s = ww2 w1 o1 w2 o2 $ \a1 a2 -> f // a1 // a2 // check s
113app3 f w1 o1 w2 o2 w3 o3 s = ww3 w1 o1 w2 o2 w3 o3 $
114 \a1 a2 a3 -> f // a1 // a2 // a3 // check s
115app4 f w1 o1 w2 o2 w3 o3 w4 o4 s = ww4 w1 o1 w2 o2 w3 o3 w4 o4 $
116 \a1 a2 a3 a4 -> f // a1 // a2 // a3 // a4 // check s
117app5 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 s = ww5 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 $
118 \a1 a2 a3 a4 a5 -> f // a1 // a2 // a3 // a4 // a5 // check s
119app6 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 s = ww6 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 $
120 \a1 a2 a3 a4 a5 a6 -> f // a1 // a2 // a3 // a4 // a5 // a6 // check s
121app7 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 s = ww7 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 $
122 \a1 a2 a3 a4 a5 a6 a7 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // check s
123app8 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 s = ww8 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 $
124 \a1 a2 a3 a4 a5 a6 a7 a8 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // check s
125app9 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 s = ww9 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 $
126 \a1 a2 a3 a4 a5 a6 a7 a8 a9 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // check s
127app10 f w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 s = ww10 w1 o1 w2 o2 w3 o3 w4 o4 w5 o5 w6 o6 w7 o7 w8 o8 w9 o9 w10 o10 $
128 \a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 -> f // a1 // a2 // a3 // a4 // a5 // a6 // a7 // a8 // a9 // a10 // check s
129
130
131
132-- GSL error codes are <= 1024
133-- | error codes for the auxiliary functions required by the wrappers
134errorCode :: CInt -> String
135errorCode 2000 = "bad size"
136errorCode 2001 = "bad function code"
137errorCode 2002 = "memory problem"
138errorCode 2003 = "bad file"
139errorCode 2004 = "singular"
140errorCode 2005 = "didn't converge"
141errorCode 2006 = "the input matrix is not positive definite"
142errorCode 2007 = "not yet supported in this OS"
143errorCode n = "code "++show n
144
145
146-- | clear the fpu
147foreign import ccall unsafe "asm_finit" finit :: IO ()
148
149-- | check the error code
150check :: String -> IO CInt -> IO ()
151check msg f = do
152#if FINIT
153 finit
154#endif
155 err <- f
156 when (err/=0) $ if err > 1024
157 then (error (msg++": "++errorCode err)) -- our errors
158 else do -- GSL errors
159 ps <- gsl_strerror err
160 s <- peekCString ps
161 error (msg++": "++s)
162 return ()
163
164-- | description of GSL error codes
165foreign import ccall unsafe "gsl_strerror" gsl_strerror :: CInt -> IO (Ptr CChar)
166
167-- | Error capture and conversion to Maybe
168mbCatch :: IO x -> IO (Maybe x)
169mbCatch act = E.catch (Just `fmap` act) f
170 where f :: SomeException -> IO (Maybe x)
171 f _ = return Nothing
diff --git a/packages/hmatrix/src/Data/Packed/Internal/Matrix.hs b/packages/hmatrix/src/Data/Packed/Internal/Matrix.hs
new file mode 100644
index 0000000..9719fc0
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Internal/Matrix.hs
@@ -0,0 +1,491 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE BangPatterns #-}
5-----------------------------------------------------------------------------
6-- |
7-- Module : Data.Packed.Internal.Matrix
8-- Copyright : (c) Alberto Ruiz 2007
9-- License : GPL-style
10--
11-- Maintainer : Alberto Ruiz <aruiz@um.es>
12-- Stability : provisional
13-- Portability : portable (uses FFI)
14--
15-- Internal matrix representation
16--
17-----------------------------------------------------------------------------
18-- #hide
19
20module Data.Packed.Internal.Matrix(
21 Matrix(..), rows, cols, cdat, fdat,
22 MatrixOrder(..), orderOf,
23 createMatrix, mat,
24 cmat, fmat,
25 toLists, flatten, reshape,
26 Element(..),
27 trans,
28 fromRows, toRows, fromColumns, toColumns,
29 matrixFromVector,
30 subMatrix,
31 liftMatrix, liftMatrix2,
32 (@@>), atM',
33 saveMatrix,
34 singleton,
35 emptyM,
36 size, shSize, conformVs, conformMs, conformVTo, conformMTo
37) where
38
39import Data.Packed.Internal.Common
40import Data.Packed.Internal.Signatures
41import Data.Packed.Internal.Vector
42
43import Foreign.Marshal.Alloc(alloca, free)
44import Foreign.Marshal.Array(newArray)
45import Foreign.Ptr(Ptr, castPtr)
46import Foreign.Storable(Storable, peekElemOff, pokeElemOff, poke, sizeOf)
47import Data.Complex(Complex)
48import Foreign.C.Types
49import Foreign.C.String(newCString)
50import System.IO.Unsafe(unsafePerformIO)
51import Control.DeepSeq
52
53-----------------------------------------------------------------
54
55{- Design considerations for the Matrix Type
56 -----------------------------------------
57
58- we must easily handle both row major and column major order,
59 for bindings to LAPACK and GSL/C
60
61- we'd like to simplify redundant matrix transposes:
62 - Some of them arise from the order requirements of some functions
63 - some functions (matrix product) admit transposed arguments
64
65- maybe we don't really need this kind of simplification:
66 - more complex code
67 - some computational overhead
68 - only appreciable gain in code with a lot of redundant transpositions
69 and cheap matrix computations
70
71- we could carry both the matrix and its (lazily computed) transpose.
72 This may save some transpositions, but it is necessary to keep track of the
73 data which is actually computed to be used by functions like the matrix product
74 which admit both orders.
75
76- but if we need the transposed data and it is not in the structure, we must make
77 sure that we touch the same foreignptr that is used in the computation.
78
79- a reasonable solution is using two constructors for a matrix. Transposition just
80 "flips" the constructor. Actual data transposition is not done if followed by a
81 matrix product or another transpose.
82
83-}
84
85data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
86
87transOrder RowMajor = ColumnMajor
88transOrder ColumnMajor = RowMajor
89{- | Matrix representation suitable for GSL and LAPACK computations.
90
91The elements are stored in a continuous memory array.
92
93-}
94
95data Matrix t = Matrix { irows :: {-# UNPACK #-} !Int
96 , icols :: {-# UNPACK #-} !Int
97 , xdat :: {-# UNPACK #-} !(Vector t)
98 , order :: !MatrixOrder }
99-- RowMajor: preferred by C, fdat may require a transposition
100-- ColumnMajor: preferred by LAPACK, cdat may require a transposition
101
102cdat = xdat
103fdat = xdat
104
105rows :: Matrix t -> Int
106rows = irows
107
108cols :: Matrix t -> Int
109cols = icols
110
111orderOf :: Matrix t -> MatrixOrder
112orderOf = order
113
114
115-- | Matrix transpose.
116trans :: Matrix t -> Matrix t
117trans Matrix {irows = r, icols = c, xdat = d, order = o } = Matrix { irows = c, icols = r, xdat = d, order = transOrder o}
118
119cmat :: (Element t) => Matrix t -> Matrix t
120cmat m@Matrix{order = RowMajor} = m
121cmat Matrix {irows = r, icols = c, xdat = d, order = ColumnMajor } = Matrix { irows = r, icols = c, xdat = transdata r d c, order = RowMajor}
122
123fmat :: (Element t) => Matrix t -> Matrix t
124fmat m@Matrix{order = ColumnMajor} = m
125fmat Matrix {irows = r, icols = c, xdat = d, order = RowMajor } = Matrix { irows = r, icols = c, xdat = transdata c d r, order = ColumnMajor}
126
127-- C-Haskell matrix adapter
128-- mat :: Adapt (CInt -> CInt -> Ptr t -> r) (Matrix t) r
129
130mat :: (Storable t) => Matrix t -> (((CInt -> CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
131mat a f =
132 unsafeWith (xdat a) $ \p -> do
133 let m g = do
134 g (fi (rows a)) (fi (cols a)) p
135 f m
136
137{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
138
139>>> flatten (ident 3)
140fromList [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
141
142-}
143flatten :: Element t => Matrix t -> Vector t
144flatten = xdat . cmat
145
146{-
147type Mt t s = Int -> Int -> Ptr t -> s
148
149infixr 6 ::>
150type t ::> s = Mt t s
151-}
152
153-- | the inverse of 'Data.Packed.Matrix.fromLists'
154toLists :: (Element t) => Matrix t -> [[t]]
155toLists m = splitEvery (cols m) . toList . flatten $ m
156
157-- | Create a matrix from a list of vectors.
158-- All vectors must have the same dimension,
159-- or dimension 1, which is are automatically expanded.
160fromRows :: Element t => [Vector t] -> Matrix t
161fromRows [] = emptyM 0 0
162fromRows vs = case compatdim (map dim vs) of
163 Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show (map dim vs)
164 Just 0 -> emptyM r 0
165 Just c -> matrixFromVector RowMajor r c . vjoin . map (adapt c) $ vs
166 where
167 r = length vs
168 adapt c v
169 | c == 0 = fromList[]
170 | dim v == c = v
171 | otherwise = constantD (v@>0) c
172
173-- | extracts the rows of a matrix as a list of vectors
174toRows :: Element t => Matrix t -> [Vector t]
175toRows m
176 | c == 0 = replicate r (fromList[])
177 | otherwise = toRows' 0
178 where
179 v = flatten m
180 r = rows m
181 c = cols m
182 toRows' k | k == r*c = []
183 | otherwise = subVector k c v : toRows' (k+c)
184
185-- | Creates a matrix from a list of vectors, as columns
186fromColumns :: Element t => [Vector t] -> Matrix t
187fromColumns m = trans . fromRows $ m
188
189-- | Creates a list of vectors from the columns of a matrix
190toColumns :: Element t => Matrix t -> [Vector t]
191toColumns m = toRows . trans $ m
192
193-- | Reads a matrix position.
194(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
195infixl 9 @@>
196m@Matrix {irows = r, icols = c} @@> (i,j)
197 | safe = if i<0 || i>=r || j<0 || j>=c
198 then error "matrix indexing out of range"
199 else atM' m i j
200 | otherwise = atM' m i j
201{-# INLINE (@@>) #-}
202
203-- Unsafe matrix access without range checking
204atM' Matrix {icols = c, xdat = v, order = RowMajor} i j = v `at'` (i*c+j)
205atM' Matrix {irows = r, xdat = v, order = ColumnMajor} i j = v `at'` (j*r+i)
206{-# INLINE atM' #-}
207
208------------------------------------------------------------------
209
210matrixFromVector o r c v
211 | r * c == dim v = m
212 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
213 where
214 m = Matrix { irows = r, icols = c, xdat = v, order = o }
215
216-- allocates memory for a new matrix
217createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
218createMatrix ord r c = do
219 p <- createVector (r*c)
220 return (matrixFromVector ord r c p)
221
222{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = trans . reshape r@
223where r is the desired number of rows.)
224
225>>> reshape 4 (fromList [1..12])
226(3><4)
227 [ 1.0, 2.0, 3.0, 4.0
228 , 5.0, 6.0, 7.0, 8.0
229 , 9.0, 10.0, 11.0, 12.0 ]
230
231-}
232reshape :: Storable t => Int -> Vector t -> Matrix t
233reshape 0 v = matrixFromVector RowMajor 0 0 v
234reshape c v = matrixFromVector RowMajor (dim v `div` c) c v
235
236singleton x = reshape 1 (fromList [x])
237
238-- | application of a vector function on the flattened matrix elements
239liftMatrix :: (Storable a, Storable b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
240liftMatrix f Matrix { irows = r, icols = c, xdat = d, order = o } = matrixFromVector o r c (f d)
241
242-- | application of a vector function on the flattened matrices elements
243liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
244liftMatrix2 f m1 m2
245 | not (compat m1 m2) = error "nonconformant matrices in liftMatrix2"
246 | otherwise = case orderOf m1 of
247 RowMajor -> matrixFromVector RowMajor (rows m1) (cols m1) (f (xdat m1) (flatten m2))
248 ColumnMajor -> matrixFromVector ColumnMajor (rows m1) (cols m1) (f (xdat m1) ((xdat.fmat) m2))
249
250
251compat :: Matrix a -> Matrix b -> Bool
252compat m1 m2 = rows m1 == rows m2 && cols m1 == cols m2
253
254------------------------------------------------------------------
255
256{- | Supported matrix elements.
257
258 This class provides optimized internal
259 operations for selected element types.
260 It provides unoptimised defaults for any 'Storable' type,
261 so you can create instances simply as:
262 @instance Element Foo@.
263-}
264class (Storable a) => Element a where
265 subMatrixD :: (Int,Int) -- ^ (r0,c0) starting position
266 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
267 -> Matrix a -> Matrix a
268 subMatrixD = subMatrix'
269 transdata :: Int -> Vector a -> Int -> Vector a
270 transdata = transdataP -- transdata'
271 constantD :: a -> Int -> Vector a
272 constantD = constantP -- constant'
273
274
275instance Element Float where
276 transdata = transdataAux ctransF
277 constantD = constantAux cconstantF
278
279instance Element Double where
280 transdata = transdataAux ctransR
281 constantD = constantAux cconstantR
282
283instance Element (Complex Float) where
284 transdata = transdataAux ctransQ
285 constantD = constantAux cconstantQ
286
287instance Element (Complex Double) where
288 transdata = transdataAux ctransC
289 constantD = constantAux cconstantC
290
291-------------------------------------------------------------------
292
293transdata' :: Storable a => Int -> Vector a -> Int -> Vector a
294transdata' c1 v c2 =
295 if noneed
296 then v
297 else unsafePerformIO $ do
298 w <- createVector (r2*c2)
299 unsafeWith v $ \p ->
300 unsafeWith w $ \q -> do
301 let go (-1) _ = return ()
302 go !i (-1) = go (i-1) (c1-1)
303 go !i !j = do x <- peekElemOff p (i*c1+j)
304 pokeElemOff q (j*c2+i) x
305 go i (j-1)
306 go (r1-1) (c1-1)
307 return w
308 where r1 = dim v `div` c1
309 r2 = dim v `div` c2
310 noneed = dim v == 0 || r1 == 1 || c1 == 1
311
312-- {-# SPECIALIZE transdata' :: Int -> Vector Double -> Int -> Vector Double #-}
313-- {-# SPECIALIZE transdata' :: Int -> Vector (Complex Double) -> Int -> Vector (Complex Double) #-}
314
315-- I don't know how to specialize...
316-- The above pragmas only seem to work on top level defs
317-- Fortunately everything seems to work using the above class
318
319-- C versions, still a little faster:
320
321transdataAux fun c1 d c2 =
322 if noneed
323 then d
324 else unsafePerformIO $ do
325 v <- createVector (dim d)
326 unsafeWith d $ \pd ->
327 unsafeWith v $ \pv ->
328 fun (fi r1) (fi c1) pd (fi r2) (fi c2) pv // check "transdataAux"
329 return v
330 where r1 = dim d `div` c1
331 r2 = dim d `div` c2
332 noneed = dim d == 0 || r1 == 1 || c1 == 1
333
334transdataP :: Storable a => Int -> Vector a -> Int -> Vector a
335transdataP c1 d c2 =
336 if noneed
337 then d
338 else unsafePerformIO $ do
339 v <- createVector (dim d)
340 unsafeWith d $ \pd ->
341 unsafeWith v $ \pv ->
342 ctransP (fi r1) (fi c1) (castPtr pd) (fi sz) (fi r2) (fi c2) (castPtr pv) (fi sz) // check "transdataP"
343 return v
344 where r1 = dim d `div` c1
345 r2 = dim d `div` c2
346 sz = sizeOf (d @> 0)
347 noneed = dim d == 0 || r1 == 1 || c1 == 1
348
349foreign import ccall unsafe "transF" ctransF :: TFMFM
350foreign import ccall unsafe "transR" ctransR :: TMM
351foreign import ccall unsafe "transQ" ctransQ :: TQMQM
352foreign import ccall unsafe "transC" ctransC :: TCMCM
353foreign import ccall unsafe "transP" ctransP :: CInt -> CInt -> Ptr () -> CInt -> CInt -> CInt -> Ptr () -> CInt -> IO CInt
354
355----------------------------------------------------------------------
356
357constant' v n = unsafePerformIO $ do
358 w <- createVector n
359 unsafeWith w $ \p -> do
360 let go (-1) = return ()
361 go !k = pokeElemOff p k v >> go (k-1)
362 go (n-1)
363 return w
364
365-- C versions
366
367constantAux fun x n = unsafePerformIO $ do
368 v <- createVector n
369 px <- newArray [x]
370 app1 (fun px) vec v "constantAux"
371 free px
372 return v
373
374constantF :: Float -> Int -> Vector Float
375constantF = constantAux cconstantF
376foreign import ccall unsafe "constantF" cconstantF :: Ptr Float -> TF
377
378constantR :: Double -> Int -> Vector Double
379constantR = constantAux cconstantR
380foreign import ccall unsafe "constantR" cconstantR :: Ptr Double -> TV
381
382constantQ :: Complex Float -> Int -> Vector (Complex Float)
383constantQ = constantAux cconstantQ
384foreign import ccall unsafe "constantQ" cconstantQ :: Ptr (Complex Float) -> TQV
385
386constantC :: Complex Double -> Int -> Vector (Complex Double)
387constantC = constantAux cconstantC
388foreign import ccall unsafe "constantC" cconstantC :: Ptr (Complex Double) -> TCV
389
390constantP :: Storable a => a -> Int -> Vector a
391constantP a n = unsafePerformIO $ do
392 let sz = sizeOf a
393 v <- createVector n
394 unsafeWith v $ \p -> do
395 alloca $ \k -> do
396 poke k a
397 cconstantP (castPtr k) (fi n) (castPtr p) (fi sz) // check "constantP"
398 return v
399foreign import ccall unsafe "constantP" cconstantP :: Ptr () -> CInt -> Ptr () -> CInt -> IO CInt
400
401----------------------------------------------------------------------
402
403-- | Extracts a submatrix from a matrix.
404subMatrix :: Element a
405 => (Int,Int) -- ^ (r0,c0) starting position
406 -> (Int,Int) -- ^ (rt,ct) dimensions of submatrix
407 -> Matrix a -- ^ input matrix
408 -> Matrix a -- ^ result
409subMatrix (r0,c0) (rt,ct) m
410 | 0 <= r0 && 0 <= rt && r0+rt <= (rows m) &&
411 0 <= c0 && 0 <= ct && c0+ct <= (cols m) = subMatrixD (r0,c0) (rt,ct) m
412 | otherwise = error $ "wrong subMatrix "++
413 show ((r0,c0),(rt,ct))++" of "++show(rows m)++"x"++ show (cols m)
414
415subMatrix'' (r0,c0) (rt,ct) c v = unsafePerformIO $ do
416 w <- createVector (rt*ct)
417 unsafeWith v $ \p ->
418 unsafeWith w $ \q -> do
419 let go (-1) _ = return ()
420 go !i (-1) = go (i-1) (ct-1)
421 go !i !j = do x <- peekElemOff p ((i+r0)*c+j+c0)
422 pokeElemOff q (i*ct+j) x
423 go i (j-1)
424 go (rt-1) (ct-1)
425 return w
426
427subMatrix' (r0,c0) (rt,ct) (Matrix { icols = c, xdat = v, order = RowMajor}) = Matrix rt ct (subMatrix'' (r0,c0) (rt,ct) c v) RowMajor
428subMatrix' (r0,c0) (rt,ct) m = trans $ subMatrix' (c0,r0) (ct,rt) (trans m)
429
430--------------------------------------------------------------------------
431
432-- | Saves a matrix as 2D ASCII table.
433saveMatrix :: FilePath
434 -> String -- ^ format (%f, %g, %e)
435 -> Matrix Double
436 -> IO ()
437saveMatrix filename fmt m = do
438 charname <- newCString filename
439 charfmt <- newCString fmt
440 let o = if orderOf m == RowMajor then 1 else 0
441 app1 (matrix_fprintf charname charfmt o) mat m "matrix_fprintf"
442 free charname
443 free charfmt
444
445foreign import ccall unsafe "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM
446
447----------------------------------------------------------------------
448
449maxZ xs = if minimum xs == 0 then 0 else maximum xs
450
451conformMs ms = map (conformMTo (r,c)) ms
452 where
453 r = maxZ (map rows ms)
454 c = maxZ (map cols ms)
455
456
457conformVs vs = map (conformVTo n) vs
458 where
459 n = maxZ (map dim vs)
460
461conformMTo (r,c) m
462 | size m == (r,c) = m
463 | size m == (1,1) = matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
464 | size m == (r,1) = repCols c m
465 | size m == (1,c) = repRows r m
466 | otherwise = error $ "matrix " ++ shSize m ++ " cannot be expanded to (" ++ show r ++ "><"++ show c ++")"
467
468conformVTo n v
469 | dim v == n = v
470 | dim v == 1 = constantD (v@>0) n
471 | otherwise = error $ "vector of dim=" ++ show (dim v) ++ " cannot be expanded to dim=" ++ show n
472
473repRows n x = fromRows (replicate n (flatten x))
474repCols n x = fromColumns (replicate n (flatten x))
475
476size m = (rows m, cols m)
477
478shSize m = "(" ++ show (rows m) ++"><"++ show (cols m)++")"
479
480emptyM r c = matrixFromVector RowMajor r c (fromList[])
481
482----------------------------------------------------------------------
483
484instance (Storable t, NFData t) => NFData (Matrix t)
485 where
486 rnf m | d > 0 = rnf (v @> 0)
487 | otherwise = ()
488 where
489 d = dim v
490 v = xdat m
491
diff --git a/packages/hmatrix/src/Data/Packed/Internal/Signatures.hs b/packages/hmatrix/src/Data/Packed/Internal/Signatures.hs
new file mode 100644
index 0000000..2835720
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Internal/Signatures.hs
@@ -0,0 +1,72 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Data.Packed.Internal.Signatures
4-- Copyright : (c) Alberto Ruiz 2009
5-- License : GPL-style
6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es>
8-- Stability : provisional
9-- Portability : portable (uses FFI)
10--
11-- Signatures of the C functions.
12--
13-----------------------------------------------------------------------------
14
15module Data.Packed.Internal.Signatures where
16
17import Foreign.Ptr(Ptr)
18import Data.Complex(Complex)
19import Foreign.C.Types(CInt)
20
21type PF = Ptr Float --
22type PD = Ptr Double --
23type PQ = Ptr (Complex Float) --
24type PC = Ptr (Complex Double) --
25type TF = CInt -> PF -> IO CInt --
26type TFF = CInt -> PF -> TF --
27type TFV = CInt -> PF -> TV --
28type TVF = CInt -> PD -> TF --
29type TFFF = CInt -> PF -> TFF --
30type TV = CInt -> PD -> IO CInt --
31type TVV = CInt -> PD -> TV --
32type TVVV = CInt -> PD -> TVV --
33type TFM = CInt -> CInt -> PF -> IO CInt --
34type TFMFM = CInt -> CInt -> PF -> TFM --
35type TFMFMFM = CInt -> CInt -> PF -> TFMFM --
36type TM = CInt -> CInt -> PD -> IO CInt --
37type TMM = CInt -> CInt -> PD -> TM --
38type TVMM = CInt -> PD -> TMM --
39type TMVMM = CInt -> CInt -> PD -> TVMM --
40type TMMM = CInt -> CInt -> PD -> TMM --
41type TVM = CInt -> PD -> TM --
42type TVVM = CInt -> PD -> TVM --
43type TMV = CInt -> CInt -> PD -> TV --
44type TMMV = CInt -> CInt -> PD -> TMV --
45type TMVM = CInt -> CInt -> PD -> TVM --
46type TMMVM = CInt -> CInt -> PD -> TMVM --
47type TCM = CInt -> CInt -> PC -> IO CInt --
48type TCVCM = CInt -> PC -> TCM --
49type TCMCVCM = CInt -> CInt -> PC -> TCVCM --
50type TMCMCVCM = CInt -> CInt -> PD -> TCMCVCM --
51type TCMCMCVCM = CInt -> CInt -> PC -> TCMCVCM --
52type TCMCM = CInt -> CInt -> PC -> TCM --
53type TVCM = CInt -> PD -> TCM --
54type TCMVCM = CInt -> CInt -> PC -> TVCM --
55type TCMCMVCM = CInt -> CInt -> PC -> TCMVCM --
56type TCMCMCM = CInt -> CInt -> PC -> TCMCM --
57type TCV = CInt -> PC -> IO CInt --
58type TCVCV = CInt -> PC -> TCV --
59type TCVCVCV = CInt -> PC -> TCVCV --
60type TCVV = CInt -> PC -> TV --
61type TQV = CInt -> PQ -> IO CInt --
62type TQVQV = CInt -> PQ -> TQV --
63type TQVQVQV = CInt -> PQ -> TQVQV --
64type TQVF = CInt -> PQ -> TF --
65type TQM = CInt -> CInt -> PQ -> IO CInt --
66type TQMQM = CInt -> CInt -> PQ -> TQM --
67type TQMQMQM = CInt -> CInt -> PQ -> TQMQM --
68type TCMCV = CInt -> CInt -> PC -> TCV --
69type TVCV = CInt -> PD -> TCV --
70type TCVM = CInt -> PC -> TM --
71type TMCVM = CInt -> CInt -> PD -> TCVM --
72type TMMCVM = CInt -> CInt -> PD -> TMCVM --
diff --git a/packages/hmatrix/src/Data/Packed/Internal/Vector.hs b/packages/hmatrix/src/Data/Packed/Internal/Vector.hs
new file mode 100644
index 0000000..6d03438
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Internal/Vector.hs
@@ -0,0 +1,521 @@
1{-# LANGUAGE MagicHash, CPP, UnboxedTuples, BangPatterns, FlexibleContexts #-}
2-----------------------------------------------------------------------------
3-- |
4-- Module : Data.Packed.Internal.Vector
5-- Copyright : (c) Alberto Ruiz 2007
6-- License : GPL-style
7--
8-- Maintainer : Alberto Ruiz <aruiz@um.es>
9-- Stability : provisional
10-- Portability : portable (uses FFI)
11--
12-- Vector implementation
13--
14-----------------------------------------------------------------------------
15
16module Data.Packed.Internal.Vector (
17 Vector, dim,
18 fromList, toList, (|>),
19 vjoin, (@>), safe, at, at', subVector, takesV,
20 mapVector, mapVectorWithIndex, zipVectorWith, unzipVectorWith,
21 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_,
22 foldVector, foldVectorG, foldLoop, foldVectorWithIndex,
23 createVector, vec,
24 asComplex, asReal, float2DoubleV, double2FloatV,
25 stepF, stepD, condF, condD,
26 conjugateQ, conjugateC,
27 fwriteVector, freadVector, fprintfVector, fscanfVector,
28 cloneVector,
29 unsafeToForeignPtr,
30 unsafeFromForeignPtr,
31 unsafeWith
32) where
33
34import Data.Packed.Internal.Common
35import Data.Packed.Internal.Signatures
36import Foreign.Marshal.Alloc(free)
37import Foreign.Marshal.Array(peekArray, copyArray, advancePtr)
38import Foreign.ForeignPtr(ForeignPtr, castForeignPtr)
39import Foreign.Ptr(Ptr)
40import Foreign.Storable(Storable, peekElemOff, pokeElemOff, sizeOf)
41import Foreign.C.String
42import Foreign.C.Types
43import Data.Complex
44import Control.Monad(when)
45import System.IO.Unsafe(unsafePerformIO)
46
47#if __GLASGOW_HASKELL__ >= 605
48import GHC.ForeignPtr (mallocPlainForeignPtrBytes)
49#else
50import Foreign.ForeignPtr (mallocForeignPtrBytes)
51#endif
52
53import GHC.Base
54#if __GLASGOW_HASKELL__ < 612
55import GHC.IOBase hiding (liftIO)
56#endif
57
58import qualified Data.Vector.Storable as Vector
59import Data.Vector.Storable(Vector,
60 fromList,
61 unsafeToForeignPtr,
62 unsafeFromForeignPtr,
63 unsafeWith)
64
65
66-- | Number of elements
67dim :: (Storable t) => Vector t -> Int
68dim = Vector.length
69
70
71-- C-Haskell vector adapter
72-- vec :: Adapt (CInt -> Ptr t -> r) (Vector t) r
73vec :: (Storable t) => Vector t -> (((CInt -> Ptr t -> t1) -> t1) -> IO b) -> IO b
74vec x f = unsafeWith x $ \p -> do
75 let v g = do
76 g (fi $ dim x) p
77 f v
78{-# INLINE vec #-}
79
80
81-- allocates memory for a new vector
82createVector :: Storable a => Int -> IO (Vector a)
83createVector n = do
84 when (n < 0) $ error ("trying to createVector of negative dim: "++show n)
85 fp <- doMalloc undefined
86 return $ unsafeFromForeignPtr fp 0 n
87 where
88 --
89 -- Use the much cheaper Haskell heap allocated storage
90 -- for foreign pointer space we control
91 --
92 doMalloc :: Storable b => b -> IO (ForeignPtr b)
93 doMalloc dummy = do
94#if __GLASGOW_HASKELL__ >= 605
95 mallocPlainForeignPtrBytes (n * sizeOf dummy)
96#else
97 mallocForeignPtrBytes (n * sizeOf dummy)
98#endif
99
100{- | creates a Vector from a list:
101
102@> fromList [2,3,5,7]
1034 |> [2.0,3.0,5.0,7.0]@
104
105-}
106
107safeRead v = inlinePerformIO . unsafeWith v
108{-# INLINE safeRead #-}
109
110inlinePerformIO :: IO a -> a
111inlinePerformIO (IO m) = case m realWorld# of (# _, r #) -> r
112{-# INLINE inlinePerformIO #-}
113
114{- | extracts the Vector elements to a list
115
116>>> toList (linspace 5 (1,10))
117[1.0,3.25,5.5,7.75,10.0]
118
119-}
120toList :: Storable a => Vector a -> [a]
121toList v = safeRead v $ peekArray (dim v)
122
123{- | Create a vector from a list of elements and explicit dimension. The input
124 list is explicitly truncated if it is too long, so it may safely
125 be used, for instance, with infinite lists.
126
127>>> 5 |> [1..]
128fromList [1.0,2.0,3.0,4.0,5.0]
129
130-}
131(|>) :: (Storable a) => Int -> [a] -> Vector a
132infixl 9 |>
133n |> l = if length l' == n
134 then fromList l'
135 else error "list too short for |>"
136 where l' = take n l
137
138
139-- | access to Vector elements without range checking
140at' :: Storable a => Vector a -> Int -> a
141at' v n = safeRead v $ flip peekElemOff n
142{-# INLINE at' #-}
143
144--
145-- turn off bounds checking with -funsafe at configure time.
146-- ghc will optimise away the salways true case at compile time.
147--
148#if defined(UNSAFE)
149safe :: Bool
150safe = False
151#else
152safe = True
153#endif
154
155-- | access to Vector elements with range checking.
156at :: Storable a => Vector a -> Int -> a
157at v n
158 | safe = if n >= 0 && n < dim v
159 then at' v n
160 else error "vector index out of range"
161 | otherwise = at' v n
162{-# INLINE at #-}
163
164{- | takes a number of consecutive elements from a Vector
165
166>>> subVector 2 3 (fromList [1..10])
167fromList [3.0,4.0,5.0]
168
169-}
170subVector :: Storable t => Int -- ^ index of the starting element
171 -> Int -- ^ number of elements to extract
172 -> Vector t -- ^ source
173 -> Vector t -- ^ result
174subVector = Vector.slice
175
176
177{- | Reads a vector position:
178
179>>> fromList [0..9] @> 7
1807.0
181
182-}
183(@>) :: Storable t => Vector t -> Int -> t
184infixl 9 @>
185(@>) = at
186
187
188{- | concatenate a list of vectors
189
190>>> vjoin [fromList [1..5::Double], konst 1 3]
191fromList [1.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0]
192
193-}
194vjoin :: Storable t => [Vector t] -> Vector t
195vjoin [] = fromList []
196vjoin [v] = v
197vjoin as = unsafePerformIO $ do
198 let tot = sum (map dim as)
199 r <- createVector tot
200 unsafeWith r $ \ptr ->
201 joiner as tot ptr
202 return r
203 where joiner [] _ _ = return ()
204 joiner (v:cs) _ p = do
205 let n = dim v
206 unsafeWith v $ \pb -> copyArray p pb n
207 joiner cs 0 (advancePtr p n)
208
209
210{- | Extract consecutive subvectors of the given sizes.
211
212>>> takesV [3,4] (linspace 10 (1,10::Double))
213[fromList [1.0,2.0,3.0],fromList [4.0,5.0,6.0,7.0]]
214
215-}
216takesV :: Storable t => [Int] -> Vector t -> [Vector t]
217takesV ms w | sum ms > dim w = error $ "takesV " ++ show ms ++ " on dim = " ++ (show $ dim w)
218 | otherwise = go ms w
219 where go [] _ = []
220 go (n:ns) v = subVector 0 n v
221 : go ns (subVector n (dim v - n) v)
222
223---------------------------------------------------------------
224
225-- | transforms a complex vector into a real vector with alternating real and imaginary parts
226asReal :: (RealFloat a, Storable a) => Vector (Complex a) -> Vector a
227asReal v = unsafeFromForeignPtr (castForeignPtr fp) (2*i) (2*n)
228 where (fp,i,n) = unsafeToForeignPtr v
229
230-- | transforms a real vector into a complex vector with alternating real and imaginary parts
231asComplex :: (RealFloat a, Storable a) => Vector a -> Vector (Complex a)
232asComplex v = unsafeFromForeignPtr (castForeignPtr fp) (i `div` 2) (n `div` 2)
233 where (fp,i,n) = unsafeToForeignPtr v
234
235---------------------------------------------------------------
236
237float2DoubleV :: Vector Float -> Vector Double
238float2DoubleV v = unsafePerformIO $ do
239 r <- createVector (dim v)
240 app2 c_float2double vec v vec r "float2double"
241 return r
242
243double2FloatV :: Vector Double -> Vector Float
244double2FloatV v = unsafePerformIO $ do
245 r <- createVector (dim v)
246 app2 c_double2float vec v vec r "double2float2"
247 return r
248
249
250foreign import ccall unsafe "float2double" c_float2double:: TFV
251foreign import ccall unsafe "double2float" c_double2float:: TVF
252
253---------------------------------------------------------------
254
255stepF :: Vector Float -> Vector Float
256stepF v = unsafePerformIO $ do
257 r <- createVector (dim v)
258 app2 c_stepF vec v vec r "stepF"
259 return r
260
261stepD :: Vector Double -> Vector Double
262stepD v = unsafePerformIO $ do
263 r <- createVector (dim v)
264 app2 c_stepD vec v vec r "stepD"
265 return r
266
267foreign import ccall unsafe "stepF" c_stepF :: TFF
268foreign import ccall unsafe "stepD" c_stepD :: TVV
269
270---------------------------------------------------------------
271
272condF :: Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float -> Vector Float
273condF x y l e g = unsafePerformIO $ do
274 r <- createVector (dim x)
275 app6 c_condF vec x vec y vec l vec e vec g vec r "condF"
276 return r
277
278condD :: Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double -> Vector Double
279condD x y l e g = unsafePerformIO $ do
280 r <- createVector (dim x)
281 app6 c_condD vec x vec y vec l vec e vec g vec r "condD"
282 return r
283
284foreign import ccall unsafe "condF" c_condF :: CInt -> PF -> CInt -> PF -> CInt -> PF -> TFFF
285foreign import ccall unsafe "condD" c_condD :: CInt -> PD -> CInt -> PD -> CInt -> PD -> TVVV
286
287--------------------------------------------------------------------------------
288
289conjugateAux fun x = unsafePerformIO $ do
290 v <- createVector (dim x)
291 app2 fun vec x vec v "conjugateAux"
292 return v
293
294conjugateQ :: Vector (Complex Float) -> Vector (Complex Float)
295conjugateQ = conjugateAux c_conjugateQ
296foreign import ccall unsafe "conjugateQ" c_conjugateQ :: TQVQV
297
298conjugateC :: Vector (Complex Double) -> Vector (Complex Double)
299conjugateC = conjugateAux c_conjugateC
300foreign import ccall unsafe "conjugateC" c_conjugateC :: TCVCV
301
302--------------------------------------------------------------------------------
303
304cloneVector :: Storable t => Vector t -> IO (Vector t)
305cloneVector v = do
306 let n = dim v
307 r <- createVector n
308 let f _ s _ d = copyArray d s n >> return 0
309 app2 f vec v vec r "cloneVector"
310 return r
311
312------------------------------------------------------------------
313
314-- | map on Vectors
315mapVector :: (Storable a, Storable b) => (a-> b) -> Vector a -> Vector b
316mapVector f v = unsafePerformIO $ do
317 w <- createVector (dim v)
318 unsafeWith v $ \p ->
319 unsafeWith w $ \q -> do
320 let go (-1) = return ()
321 go !k = do x <- peekElemOff p k
322 pokeElemOff q k (f x)
323 go (k-1)
324 go (dim v -1)
325 return w
326{-# INLINE mapVector #-}
327
328-- | zipWith for Vectors
329zipVectorWith :: (Storable a, Storable b, Storable c) => (a-> b -> c) -> Vector a -> Vector b -> Vector c
330zipVectorWith f u v = unsafePerformIO $ do
331 let n = min (dim u) (dim v)
332 w <- createVector n
333 unsafeWith u $ \pu ->
334 unsafeWith v $ \pv ->
335 unsafeWith w $ \pw -> do
336 let go (-1) = return ()
337 go !k = do x <- peekElemOff pu k
338 y <- peekElemOff pv k
339 pokeElemOff pw k (f x y)
340 go (k-1)
341 go (n -1)
342 return w
343{-# INLINE zipVectorWith #-}
344
345-- | unzipWith for Vectors
346unzipVectorWith :: (Storable (a,b), Storable c, Storable d)
347 => ((a,b) -> (c,d)) -> Vector (a,b) -> (Vector c,Vector d)
348unzipVectorWith f u = unsafePerformIO $ do
349 let n = dim u
350 v <- createVector n
351 w <- createVector n
352 unsafeWith u $ \pu ->
353 unsafeWith v $ \pv ->
354 unsafeWith w $ \pw -> do
355 let go (-1) = return ()
356 go !k = do z <- peekElemOff pu k
357 let (x,y) = f z
358 pokeElemOff pv k x
359 pokeElemOff pw k y
360 go (k-1)
361 go (n-1)
362 return (v,w)
363{-# INLINE unzipVectorWith #-}
364
365foldVector :: Storable a => (a -> b -> b) -> b -> Vector a -> b
366foldVector f x v = unsafePerformIO $
367 unsafeWith v $ \p -> do
368 let go (-1) s = return s
369 go !k !s = do y <- peekElemOff p k
370 go (k-1::Int) (f y s)
371 go (dim v -1) x
372{-# INLINE foldVector #-}
373
374-- the zero-indexed index is passed to the folding function
375foldVectorWithIndex :: Storable a => (Int -> a -> b -> b) -> b -> Vector a -> b
376foldVectorWithIndex f x v = unsafePerformIO $
377 unsafeWith v $ \p -> do
378 let go (-1) s = return s
379 go !k !s = do y <- peekElemOff p k
380 go (k-1::Int) (f k y s)
381 go (dim v -1) x
382{-# INLINE foldVectorWithIndex #-}
383
384foldLoop f s0 d = go (d - 1) s0
385 where
386 go 0 s = f (0::Int) s
387 go !j !s = go (j - 1) (f j s)
388
389foldVectorG f s0 v = foldLoop g s0 (dim v)
390 where g !k !s = f k (at' v) s
391 {-# INLINE g #-} -- Thanks to Ryan Ingram (http://permalink.gmane.org/gmane.comp.lang.haskell.cafe/46479)
392{-# INLINE foldVectorG #-}
393
394-------------------------------------------------------------------
395
396-- | monadic map over Vectors
397-- the monad @m@ must be strict
398mapVectorM :: (Storable a, Storable b, Monad m) => (a -> m b) -> Vector a -> m (Vector b)
399mapVectorM f v = do
400 w <- return $! unsafePerformIO $! createVector (dim v)
401 mapVectorM' w 0 (dim v -1)
402 return w
403 where mapVectorM' w' !k !t
404 | k == t = do
405 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
406 y <- f x
407 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
408 | otherwise = do
409 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
410 y <- f x
411 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
412 mapVectorM' w' (k+1) t
413{-# INLINE mapVectorM #-}
414
415-- | monadic map over Vectors
416mapVectorM_ :: (Storable a, Monad m) => (a -> m ()) -> Vector a -> m ()
417mapVectorM_ f v = do
418 mapVectorM' 0 (dim v -1)
419 where mapVectorM' !k !t
420 | k == t = do
421 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
422 f x
423 | otherwise = do
424 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
425 _ <- f x
426 mapVectorM' (k+1) t
427{-# INLINE mapVectorM_ #-}
428
429-- | monadic map over Vectors with the zero-indexed index passed to the mapping function
430-- the monad @m@ must be strict
431mapVectorWithIndexM :: (Storable a, Storable b, Monad m) => (Int -> a -> m b) -> Vector a -> m (Vector b)
432mapVectorWithIndexM f v = do
433 w <- return $! unsafePerformIO $! createVector (dim v)
434 mapVectorM' w 0 (dim v -1)
435 return w
436 where mapVectorM' w' !k !t
437 | k == t = do
438 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
439 y <- f k x
440 return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
441 | otherwise = do
442 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
443 y <- f k x
444 _ <- return $! inlinePerformIO $! unsafeWith w' $! \q -> pokeElemOff q k y
445 mapVectorM' w' (k+1) t
446{-# INLINE mapVectorWithIndexM #-}
447
448-- | monadic map over Vectors with the zero-indexed index passed to the mapping function
449mapVectorWithIndexM_ :: (Storable a, Monad m) => (Int -> a -> m ()) -> Vector a -> m ()
450mapVectorWithIndexM_ f v = do
451 mapVectorM' 0 (dim v -1)
452 where mapVectorM' !k !t
453 | k == t = do
454 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
455 f k x
456 | otherwise = do
457 x <- return $! inlinePerformIO $! unsafeWith v $! \p -> peekElemOff p k
458 _ <- f k x
459 mapVectorM' (k+1) t
460{-# INLINE mapVectorWithIndexM_ #-}
461
462
463mapVectorWithIndex :: (Storable a, Storable b) => (Int -> a -> b) -> Vector a -> Vector b
464--mapVectorWithIndex g = head . mapVectorWithIndexM (\a b -> [g a b])
465mapVectorWithIndex f v = unsafePerformIO $ do
466 w <- createVector (dim v)
467 unsafeWith v $ \p ->
468 unsafeWith w $ \q -> do
469 let go (-1) = return ()
470 go !k = do x <- peekElemOff p k
471 pokeElemOff q k (f k x)
472 go (k-1)
473 go (dim v -1)
474 return w
475{-# INLINE mapVectorWithIndex #-}
476
477-------------------------------------------------------------------
478
479
480-- | Loads a vector from an ASCII file (the number of elements must be known in advance).
481fscanfVector :: FilePath -> Int -> IO (Vector Double)
482fscanfVector filename n = do
483 charname <- newCString filename
484 res <- createVector n
485 app1 (gsl_vector_fscanf charname) vec res "gsl_vector_fscanf"
486 free charname
487 return res
488
489foreign import ccall unsafe "vector_fscanf" gsl_vector_fscanf:: Ptr CChar -> TV
490
491-- | Saves the elements of a vector, with a given format (%f, %e, %g), to an ASCII file.
492fprintfVector :: FilePath -> String -> Vector Double -> IO ()
493fprintfVector filename fmt v = do
494 charname <- newCString filename
495 charfmt <- newCString fmt
496 app1 (gsl_vector_fprintf charname charfmt) vec v "gsl_vector_fprintf"
497 free charname
498 free charfmt
499
500foreign import ccall unsafe "vector_fprintf" gsl_vector_fprintf :: Ptr CChar -> Ptr CChar -> TV
501
502-- | Loads a vector from a binary file (the number of elements must be known in advance).
503freadVector :: FilePath -> Int -> IO (Vector Double)
504freadVector filename n = do
505 charname <- newCString filename
506 res <- createVector n
507 app1 (gsl_vector_fread charname) vec res "gsl_vector_fread"
508 free charname
509 return res
510
511foreign import ccall unsafe "vector_fread" gsl_vector_fread:: Ptr CChar -> TV
512
513-- | Saves the elements of a vector to a binary file.
514fwriteVector :: FilePath -> Vector Double -> IO ()
515fwriteVector filename v = do
516 charname <- newCString filename
517 app1 (gsl_vector_fwrite charname) vec v "gsl_vector_fwrite"
518 free charname
519
520foreign import ccall unsafe "vector_fwrite" gsl_vector_fwrite :: Ptr CChar -> TV
521
diff --git a/packages/hmatrix/src/Data/Packed/Matrix.hs b/packages/hmatrix/src/Data/Packed/Matrix.hs
new file mode 100644
index 0000000..d94d167
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Matrix.hs
@@ -0,0 +1,490 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE CPP #-}
6
7-----------------------------------------------------------------------------
8-- |
9-- Module : Data.Packed.Matrix
10-- Copyright : (c) Alberto Ruiz 2007-10
11-- License : GPL
12--
13-- Maintainer : Alberto Ruiz <aruiz@um.es>
14-- Stability : provisional
15--
16-- A Matrix representation suitable for numerical computations using LAPACK and GSL.
17--
18-- This module provides basic functions for manipulation of structure.
19
20-----------------------------------------------------------------------------
21{-# OPTIONS_HADDOCK hide #-}
22
23module Data.Packed.Matrix (
24 Matrix,
25 Element,
26 rows,cols,
27 (><),
28 trans,
29 reshape, flatten,
30 fromLists, toLists, buildMatrix,
31 (@@>),
32 asRow, asColumn,
33 fromRows, toRows, fromColumns, toColumns,
34 fromBlocks, diagBlock, toBlocks, toBlocksEvery,
35 repmat,
36 flipud, fliprl,
37 subMatrix, takeRows, dropRows, takeColumns, dropColumns,
38 extractRows, extractColumns,
39 diagRect, takeDiag,
40 mapMatrix, mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_,
41 liftMatrix, liftMatrix2, liftMatrix2Auto,fromArray2D
42) where
43
44import Data.Packed.Internal
45import qualified Data.Packed.ST as ST
46import Data.Array
47
48import Data.List(transpose,intersperse)
49import Foreign.Storable(Storable)
50import Control.Monad(liftM)
51
52-------------------------------------------------------------------
53
54#ifdef BINARY
55
56import Data.Binary
57import Control.Monad(replicateM)
58
59instance (Binary a, Element a, Storable a) => Binary (Matrix a) where
60 put m = do
61 let r = rows m
62 let c = cols m
63 put r
64 put c
65 mapM_ (\i -> mapM_ (\j -> put $ m @@> (i,j)) [0..(c-1)]) [0..(r-1)]
66 get = do
67 r <- get
68 c <- get
69 xs <- replicateM r $ replicateM c get
70 return $ fromLists xs
71
72#endif
73
74-------------------------------------------------------------------
75
76instance (Show a, Element a) => (Show (Matrix a)) where
77 show m | rows m == 0 || cols m == 0 = sizes m ++" []"
78 show m = (sizes m++) . dsp . map (map show) . toLists $ m
79
80sizes m = "("++show (rows m)++"><"++show (cols m)++")\n"
81
82dsp as = (++" ]") . (" ["++) . init . drop 2 . unlines . map (" , "++) . map unwords' $ transpose mtp
83 where
84 mt = transpose as
85 longs = map (maximum . map length) mt
86 mtp = zipWith (\a b -> map (pad a) b) longs mt
87 pad n str = replicate (n - length str) ' ' ++ str
88 unwords' = concat . intersperse ", "
89
90------------------------------------------------------------------
91
92instance (Element a, Read a) => Read (Matrix a) where
93 readsPrec _ s = [((rs><cs) . read $ listnums, rest)]
94 where (thing,rest) = breakAt ']' s
95 (dims,listnums) = breakAt ')' thing
96 cs = read . init . fst. breakAt ')' . snd . breakAt '<' $ dims
97 rs = read . snd . breakAt '(' .init . fst . breakAt '>' $ dims
98
99
100breakAt c l = (a++[c],tail b) where
101 (a,b) = break (==c) l
102
103------------------------------------------------------------------
104
105-- | creates a matrix from a vertical list of matrices
106joinVert :: Element t => [Matrix t] -> Matrix t
107joinVert [] = emptyM 0 0
108joinVert ms = case common cols ms of
109 Nothing -> error "(impossible) joinVert on matrices with different number of columns"
110 Just c -> matrixFromVector RowMajor (sum (map rows ms)) c $ vjoin (map flatten ms)
111
112-- | creates a matrix from a horizontal list of matrices
113joinHoriz :: Element t => [Matrix t] -> Matrix t
114joinHoriz ms = trans. joinVert . map trans $ ms
115
116{- | Create a matrix from blocks given as a list of lists of matrices.
117
118Single row-column components are automatically expanded to match the
119corresponding common row and column:
120
121@
122disp = putStr . dispf 2
123@
124
125>>> disp $ fromBlocks [[ident 5, 7, row[10,20]], [3, diagl[1,2,3], 0]]
1268x10
1271 0 0 0 0 7 7 7 10 20
1280 1 0 0 0 7 7 7 10 20
1290 0 1 0 0 7 7 7 10 20
1300 0 0 1 0 7 7 7 10 20
1310 0 0 0 1 7 7 7 10 20
1323 3 3 3 3 1 0 0 0 0
1333 3 3 3 3 0 2 0 0 0
1343 3 3 3 3 0 0 3 0 0
135
136-}
137fromBlocks :: Element t => [[Matrix t]] -> Matrix t
138fromBlocks = fromBlocksRaw . adaptBlocks
139
140fromBlocksRaw mms = joinVert . map joinHoriz $ mms
141
142adaptBlocks ms = ms' where
143 bc = case common length ms of
144 Just c -> c
145 Nothing -> error "fromBlocks requires rectangular [[Matrix]]"
146 rs = map (compatdim . map rows) ms
147 cs = map (compatdim . map cols) (transpose ms)
148 szs = sequence [rs,cs]
149 ms' = splitEvery bc $ zipWith g szs (concat ms)
150
151 g [Just nr,Just nc] m
152 | nr == r && nc == c = m
153 | r == 1 && c == 1 = matrixFromVector RowMajor nr nc (constantD x (nr*nc))
154 | r == 1 = fromRows (replicate nr (flatten m))
155 | otherwise = fromColumns (replicate nc (flatten m))
156 where
157 r = rows m
158 c = cols m
159 x = m@@>(0,0)
160 g _ _ = error "inconsistent dimensions in fromBlocks"
161
162
163--------------------------------------------------------------------------------
164
165{- | create a block diagonal matrix
166
167>>> disp 2 $ diagBlock [konst 1 (2,2), konst 2 (3,5), col [5,7]]
1687x8
1691 1 0 0 0 0 0 0
1701 1 0 0 0 0 0 0
1710 0 2 2 2 2 2 0
1720 0 2 2 2 2 2 0
1730 0 2 2 2 2 2 0
1740 0 0 0 0 0 0 5
1750 0 0 0 0 0 0 7
176
177>>> diagBlock [(0><4)[], konst 2 (2,3)] :: Matrix Double
178(2><7)
179 [ 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0
180 , 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0 ]
181
182-}
183diagBlock :: (Element t, Num t) => [Matrix t] -> Matrix t
184diagBlock ms = fromBlocks $ zipWith f ms [0..]
185 where
186 f m k = take n $ replicate k z ++ m : repeat z
187 n = length ms
188 z = (1><1) [0]
189
190--------------------------------------------------------------------------------
191
192
193-- | Reverse rows
194flipud :: Element t => Matrix t -> Matrix t
195flipud m = extractRows [r-1,r-2 .. 0] $ m
196 where
197 r = rows m
198
199-- | Reverse columns
200fliprl :: Element t => Matrix t -> Matrix t
201fliprl m = extractColumns [c-1,c-2 .. 0] $ m
202 where
203 c = cols m
204
205------------------------------------------------------------
206
207{- | creates a rectangular diagonal matrix:
208
209>>> diagRect 7 (fromList [10,20,30]) 4 5 :: Matrix Double
210(4><5)
211 [ 10.0, 7.0, 7.0, 7.0, 7.0
212 , 7.0, 20.0, 7.0, 7.0, 7.0
213 , 7.0, 7.0, 30.0, 7.0, 7.0
214 , 7.0, 7.0, 7.0, 7.0, 7.0 ]
215
216-}
217diagRect :: (Storable t) => t -> Vector t -> Int -> Int -> Matrix t
218diagRect z v r c = ST.runSTMatrix $ do
219 m <- ST.newMatrix z r c
220 let d = min r c `min` (dim v)
221 mapM_ (\k -> ST.writeMatrix m k k (v@>k)) [0..d-1]
222 return m
223
224-- | extracts the diagonal from a rectangular matrix
225takeDiag :: (Element t) => Matrix t -> Vector t
226takeDiag m = fromList [flatten m `at` (k*cols m+k) | k <- [0 .. min (rows m) (cols m) -1]]
227
228------------------------------------------------------------
229
230{- | An easy way to create a matrix:
231
232>>> (2><3)[2,4,7,-3,11,0]
233(2><3)
234 [ 2.0, 4.0, 7.0
235 , -3.0, 11.0, 0.0 ]
236
237This is the format produced by the instances of Show (Matrix a), which
238can also be used for input.
239
240The input list is explicitly truncated, so that it can
241safely be used with lists that are too long (like infinite lists).
242
243>>> (2><3)[1..]
244(2><3)
245 [ 1.0, 2.0, 3.0
246 , 4.0, 5.0, 6.0 ]
247
248
249-}
250(><) :: (Storable a) => Int -> Int -> [a] -> Matrix a
251r >< c = f where
252 f l | dim v == r*c = matrixFromVector RowMajor r c v
253 | otherwise = error $ "inconsistent list size = "
254 ++show (dim v) ++" in ("++show r++"><"++show c++")"
255 where v = fromList $ take (r*c) l
256
257----------------------------------------------------------------
258
259-- | Creates a matrix with the first n rows of another matrix
260takeRows :: Element t => Int -> Matrix t -> Matrix t
261takeRows n mt = subMatrix (0,0) (n, cols mt) mt
262-- | Creates a copy of a matrix without the first n rows
263dropRows :: Element t => Int -> Matrix t -> Matrix t
264dropRows n mt = subMatrix (n,0) (rows mt - n, cols mt) mt
265-- |Creates a matrix with the first n columns of another matrix
266takeColumns :: Element t => Int -> Matrix t -> Matrix t
267takeColumns n mt = subMatrix (0,0) (rows mt, n) mt
268-- | Creates a copy of a matrix without the first n columns
269dropColumns :: Element t => Int -> Matrix t -> Matrix t
270dropColumns n mt = subMatrix (0,n) (rows mt, cols mt - n) mt
271
272----------------------------------------------------------------
273
274{- | Creates a 'Matrix' from a list of lists (considered as rows).
275
276>>> fromLists [[1,2],[3,4],[5,6]]
277(3><2)
278 [ 1.0, 2.0
279 , 3.0, 4.0
280 , 5.0, 6.0 ]
281
282-}
283fromLists :: Element t => [[t]] -> Matrix t
284fromLists = fromRows . map fromList
285
286-- | creates a 1-row matrix from a vector
287--
288-- >>> asRow (fromList [1..5])
289-- (1><5)
290-- [ 1.0, 2.0, 3.0, 4.0, 5.0 ]
291--
292asRow :: Storable a => Vector a -> Matrix a
293asRow v = reshape (dim v) v
294
295-- | creates a 1-column matrix from a vector
296--
297-- >>> asColumn (fromList [1..5])
298-- (5><1)
299-- [ 1.0
300-- , 2.0
301-- , 3.0
302-- , 4.0
303-- , 5.0 ]
304--
305asColumn :: Storable a => Vector a -> Matrix a
306asColumn = trans . asRow
307
308
309
310{- | creates a Matrix of the specified size using the supplied function to
311 to map the row\/column position to the value at that row\/column position.
312
313@> buildMatrix 3 4 (\\(r,c) -> fromIntegral r * fromIntegral c)
314(3><4)
315 [ 0.0, 0.0, 0.0, 0.0, 0.0
316 , 0.0, 1.0, 2.0, 3.0, 4.0
317 , 0.0, 2.0, 4.0, 6.0, 8.0]@
318
319Hilbert matrix of order N:
320
321@hilb n = buildMatrix n n (\\(i,j)->1/(fromIntegral i + fromIntegral j +1))@
322
323-}
324buildMatrix :: Element a => Int -> Int -> ((Int, Int) -> a) -> Matrix a
325buildMatrix rc cc f =
326 fromLists $ map (map f)
327 $ map (\ ri -> map (\ ci -> (ri, ci)) [0 .. (cc - 1)]) [0 .. (rc - 1)]
328
329-----------------------------------------------------
330
331fromArray2D :: (Storable e) => Array (Int, Int) e -> Matrix e
332fromArray2D m = (r><c) (elems m)
333 where ((r0,c0),(r1,c1)) = bounds m
334 r = r1-r0+1
335 c = c1-c0+1
336
337
338-- | rearranges the rows of a matrix according to the order given in a list of integers.
339extractRows :: Element t => [Int] -> Matrix t -> Matrix t
340extractRows [] m = emptyM 0 (cols m)
341extractRows l m = fromRows $ extract (toRows m) l
342 where
343 extract l' is = [l'!!i | i<- map verify is]
344 verify k
345 | k >= 0 && k < rows m = k
346 | otherwise = error $ "can't extract row "
347 ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m
348
349-- | rearranges the rows of a matrix according to the order given in a list of integers.
350extractColumns :: Element t => [Int] -> Matrix t -> Matrix t
351extractColumns l m = trans . extractRows (map verify l) . trans $ m
352 where
353 verify k
354 | k >= 0 && k < cols m = k
355 | otherwise = error $ "can't extract column "
356 ++show k++" in list " ++ show l ++ " from matrix " ++ shSize m
357
358
359
360{- | creates matrix by repetition of a matrix a given number of rows and columns
361
362>>> repmat (ident 2) 2 3
363(4><6)
364 [ 1.0, 0.0, 1.0, 0.0, 1.0, 0.0
365 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0
366 , 1.0, 0.0, 1.0, 0.0, 1.0, 0.0
367 , 0.0, 1.0, 0.0, 1.0, 0.0, 1.0 ]
368
369-}
370repmat :: (Element t) => Matrix t -> Int -> Int -> Matrix t
371repmat m r c
372 | r == 0 || c == 0 = emptyM (r*rows m) (c*cols m)
373 | otherwise = fromBlocks $ replicate r $ replicate c $ m
374
375-- | A version of 'liftMatrix2' which automatically adapt matrices with a single row or column to match the dimensions of the other matrix.
376liftMatrix2Auto :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
377liftMatrix2Auto f m1 m2
378 | compat' m1 m2 = lM f m1 m2
379 | ok = lM f m1' m2'
380 | otherwise = error $ "nonconformable matrices in liftMatrix2Auto: " ++ shSize m1 ++ ", " ++ shSize m2
381 where
382 (r1,c1) = size m1
383 (r2,c2) = size m2
384 r = max r1 r2
385 c = max c1 c2
386 r0 = min r1 r2
387 c0 = min c1 c2
388 ok = r0 == 1 || r1 == r2 && c0 == 1 || c1 == c2
389 m1' = conformMTo (r,c) m1
390 m2' = conformMTo (r,c) m2
391
392-- FIXME do not flatten if equal order
393lM f m1 m2 = matrixFromVector
394 RowMajor
395 (max (rows m1) (rows m2))
396 (max (cols m1) (cols m2))
397 (f (flatten m1) (flatten m2))
398
399compat' :: Matrix a -> Matrix b -> Bool
400compat' m1 m2 = s1 == (1,1) || s2 == (1,1) || s1 == s2
401 where
402 s1 = size m1
403 s2 = size m2
404
405------------------------------------------------------------
406
407toBlockRows [r] m | r == rows m = [m]
408toBlockRows rs m = map (reshape (cols m)) (takesV szs (flatten m))
409 where szs = map (* cols m) rs
410
411toBlockCols [c] m | c == cols m = [m]
412toBlockCols cs m = map trans . toBlockRows cs . trans $ m
413
414-- | Partition a matrix into blocks with the given numbers of rows and columns.
415-- The remaining rows and columns are discarded.
416toBlocks :: (Element t) => [Int] -> [Int] -> Matrix t -> [[Matrix t]]
417toBlocks rs cs m = map (toBlockCols cs) . toBlockRows rs $ m
418
419-- | Fully partition a matrix into blocks of the same size. If the dimensions are not
420-- a multiple of the given size the last blocks will be smaller.
421toBlocksEvery :: (Element t) => Int -> Int -> Matrix t -> [[Matrix t]]
422toBlocksEvery r c m
423 | r < 1 || c < 1 = error $ "toBlocksEvery expects block sizes > 0, given "++show r++" and "++ show c
424 | otherwise = toBlocks rs cs m
425 where
426 (qr,rr) = rows m `divMod` r
427 (qc,rc) = cols m `divMod` c
428 rs = replicate qr r ++ if rr > 0 then [rr] else []
429 cs = replicate qc c ++ if rc > 0 then [rc] else []
430
431-------------------------------------------------------------------
432
433-- Given a column number and a function taking matrix indexes, returns
434-- a function which takes vector indexes (that can be used on the
435-- flattened matrix).
436mk :: Int -> ((Int, Int) -> t) -> (Int -> t)
437mk c g = \k -> g (divMod k c)
438
439{- |
440
441>>> mapMatrixWithIndexM_ (\(i,j) v -> printf "m[%d,%d] = %.f\n" i j v :: IO()) ((2><3)[1 :: Double ..])
442m[0,0] = 1
443m[0,1] = 2
444m[0,2] = 3
445m[1,0] = 4
446m[1,1] = 5
447m[1,2] = 6
448
449-}
450mapMatrixWithIndexM_
451 :: (Element a, Num a, Monad m) =>
452 ((Int, Int) -> a -> m ()) -> Matrix a -> m ()
453mapMatrixWithIndexM_ g m = mapVectorWithIndexM_ (mk c g) . flatten $ m
454 where
455 c = cols m
456
457{- |
458
459>>> mapMatrixWithIndexM (\(i,j) v -> Just $ 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double)
460Just (3><3)
461 [ 100.0, 1.0, 2.0
462 , 10.0, 111.0, 12.0
463 , 20.0, 21.0, 122.0 ]
464
465-}
466mapMatrixWithIndexM
467 :: (Element a, Storable b, Monad m) =>
468 ((Int, Int) -> a -> m b) -> Matrix a -> m (Matrix b)
469mapMatrixWithIndexM g m = liftM (reshape c) . mapVectorWithIndexM (mk c g) . flatten $ m
470 where
471 c = cols m
472
473{- |
474
475>>> mapMatrixWithIndex (\\(i,j) v -> 100*v + 10*fromIntegral i + fromIntegral j) (ident 3:: Matrix Double)
476(3><3)
477 [ 100.0, 1.0, 2.0
478 , 10.0, 111.0, 12.0
479 , 20.0, 21.0, 122.0 ]
480
481 -}
482mapMatrixWithIndex
483 :: (Element a, Storable b) =>
484 ((Int, Int) -> a -> b) -> Matrix a -> Matrix b
485mapMatrixWithIndex g m = reshape c . mapVectorWithIndex (mk c g) . flatten $ m
486 where
487 c = cols m
488
489mapMatrix :: (Storable a, Storable b) => (a -> b) -> Matrix a -> Matrix b
490mapMatrix f = liftMatrix (mapVector f)
diff --git a/packages/hmatrix/src/Data/Packed/Random.hs b/packages/hmatrix/src/Data/Packed/Random.hs
new file mode 100644
index 0000000..e8b0268
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Random.hs
@@ -0,0 +1,57 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Data.Packed.Vector
4-- Copyright : (c) Alberto Ruiz 2009
5-- License : GPL
6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es>
8-- Stability : provisional
9--
10-- Random vectors and matrices.
11--
12-----------------------------------------------------------------------------
13
14module Data.Packed.Random (
15 Seed,
16 RandDist(..),
17 randomVector,
18 gaussianSample,
19 uniformSample
20) where
21
22import Numeric.GSL.Vector
23import Data.Packed
24import Numeric.ContainerBoot
25import Numeric.LinearAlgebra.Algorithms
26
27
28type Seed = Int
29
30-- | Obtains a matrix whose rows are pseudorandom samples from a multivariate
31-- Gaussian distribution.
32gaussianSample :: Seed
33 -> Int -- ^ number of rows
34 -> Vector Double -- ^ mean vector
35 -> Matrix Double -- ^ covariance matrix
36 -> Matrix Double -- ^ result
37gaussianSample seed n med cov = m where
38 c = dim med
39 meds = konst' 1 n `outer` med
40 rs = reshape c $ randomVector seed Gaussian (c * n)
41 m = rs `mXm` cholSH cov `add` meds
42
43-- | Obtains a matrix whose rows are pseudorandom samples from a multivariate
44-- uniform distribution.
45uniformSample :: Seed
46 -> Int -- ^ number of rows
47 -> [(Double,Double)] -- ^ ranges for each column
48 -> Matrix Double -- ^ result
49uniformSample seed n rgs = m where
50 (as,bs) = unzip rgs
51 a = fromList as
52 cs = zipWith subtract as bs
53 d = dim a
54 dat = toRows $ reshape n $ randomVector seed Uniform (n*d)
55 am = konst' 1 n `outer` a
56 m = fromColumns (zipWith scale cs dat) `add` am
57
diff --git a/packages/hmatrix/src/Data/Packed/ST.hs b/packages/hmatrix/src/Data/Packed/ST.hs
new file mode 100644
index 0000000..1cef296
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/ST.hs
@@ -0,0 +1,179 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeOperators #-}
3{-# LANGUAGE Rank2Types #-}
4{-# LANGUAGE BangPatterns #-}
5-----------------------------------------------------------------------------
6-- |
7-- Module : Data.Packed.ST
8-- Copyright : (c) Alberto Ruiz 2008
9-- License : GPL-style
10--
11-- Maintainer : Alberto Ruiz <aruiz@um.es>
12-- Stability : provisional
13-- Portability : portable
14--
15-- In-place manipulation inside the ST monad.
16-- See examples/inplace.hs in the distribution.
17--
18-----------------------------------------------------------------------------
19{-# OPTIONS_HADDOCK hide #-}
20
21module Data.Packed.ST (
22 -- * Mutable Vectors
23 STVector, newVector, thawVector, freezeVector, runSTVector,
24 readVector, writeVector, modifyVector, liftSTVector,
25 -- * Mutable Matrices
26 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
27 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
28 -- * Unsafe functions
29 newUndefinedVector,
30 unsafeReadVector, unsafeWriteVector,
31 unsafeThawVector, unsafeFreezeVector,
32 newUndefinedMatrix,
33 unsafeReadMatrix, unsafeWriteMatrix,
34 unsafeThawMatrix, unsafeFreezeMatrix
35) where
36
37import Data.Packed.Internal
38
39import Control.Monad.ST(ST, runST)
40import Foreign.Storable(Storable, peekElemOff, pokeElemOff)
41
42#if MIN_VERSION_base(4,4,0)
43import Control.Monad.ST.Unsafe(unsafeIOToST)
44#else
45import Control.Monad.ST(unsafeIOToST)
46#endif
47
48{-# INLINE ioReadV #-}
49ioReadV :: Storable t => Vector t -> Int -> IO t
50ioReadV v k = unsafeWith v $ \s -> peekElemOff s k
51
52{-# INLINE ioWriteV #-}
53ioWriteV :: Storable t => Vector t -> Int -> t -> IO ()
54ioWriteV v k x = unsafeWith v $ \s -> pokeElemOff s k x
55
56newtype STVector s t = STVector (Vector t)
57
58thawVector :: Storable t => Vector t -> ST s (STVector s t)
59thawVector = unsafeIOToST . fmap STVector . cloneVector
60
61unsafeThawVector :: Storable t => Vector t -> ST s (STVector s t)
62unsafeThawVector = unsafeIOToST . return . STVector
63
64runSTVector :: Storable t => (forall s . ST s (STVector s t)) -> Vector t
65runSTVector st = runST (st >>= unsafeFreezeVector)
66
67{-# INLINE unsafeReadVector #-}
68unsafeReadVector :: Storable t => STVector s t -> Int -> ST s t
69unsafeReadVector (STVector x) = unsafeIOToST . ioReadV x
70
71{-# INLINE unsafeWriteVector #-}
72unsafeWriteVector :: Storable t => STVector s t -> Int -> t -> ST s ()
73unsafeWriteVector (STVector x) k = unsafeIOToST . ioWriteV x k
74
75{-# INLINE modifyVector #-}
76modifyVector :: (Storable t) => STVector s t -> Int -> (t -> t) -> ST s ()
77modifyVector x k f = readVector x k >>= return . f >>= unsafeWriteVector x k
78
79liftSTVector :: (Storable t) => (Vector t -> a) -> STVector s1 t -> ST s2 a
80liftSTVector f (STVector x) = unsafeIOToST . fmap f . cloneVector $ x
81
82freezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
83freezeVector v = liftSTVector id v
84
85unsafeFreezeVector :: (Storable t) => STVector s1 t -> ST s2 (Vector t)
86unsafeFreezeVector (STVector x) = unsafeIOToST . return $ x
87
88{-# INLINE safeIndexV #-}
89safeIndexV f (STVector v) k
90 | k < 0 || k>= dim v = error $ "out of range error in vector (dim="
91 ++show (dim v)++", pos="++show k++")"
92 | otherwise = f (STVector v) k
93
94{-# INLINE readVector #-}
95readVector :: Storable t => STVector s t -> Int -> ST s t
96readVector = safeIndexV unsafeReadVector
97
98{-# INLINE writeVector #-}
99writeVector :: Storable t => STVector s t -> Int -> t -> ST s ()
100writeVector = safeIndexV unsafeWriteVector
101
102newUndefinedVector :: Storable t => Int -> ST s (STVector s t)
103newUndefinedVector = unsafeIOToST . fmap STVector . createVector
104
105{-# INLINE newVector #-}
106newVector :: Storable t => t -> Int -> ST s (STVector s t)
107newVector x n = do
108 v <- newUndefinedVector n
109 let go (-1) = return v
110 go !k = unsafeWriteVector v k x >> go (k-1 :: Int)
111 go (n-1)
112
113-------------------------------------------------------------------------
114
115{-# INLINE ioReadM #-}
116ioReadM :: Storable t => Matrix t -> Int -> Int -> IO t
117ioReadM (Matrix _ nc cv RowMajor) r c = ioReadV cv (r*nc+c)
118ioReadM (Matrix nr _ fv ColumnMajor) r c = ioReadV fv (c*nr+r)
119
120{-# INLINE ioWriteM #-}
121ioWriteM :: Storable t => Matrix t -> Int -> Int -> t -> IO ()
122ioWriteM (Matrix _ nc cv RowMajor) r c val = ioWriteV cv (r*nc+c) val
123ioWriteM (Matrix nr _ fv ColumnMajor) r c val = ioWriteV fv (c*nr+r) val
124
125newtype STMatrix s t = STMatrix (Matrix t)
126
127thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
128thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
129
130unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
131unsafeThawMatrix = unsafeIOToST . return . STMatrix
132
133runSTMatrix :: Storable t => (forall s . ST s (STMatrix s t)) -> Matrix t
134runSTMatrix st = runST (st >>= unsafeFreezeMatrix)
135
136{-# INLINE unsafeReadMatrix #-}
137unsafeReadMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
138unsafeReadMatrix (STMatrix x) r = unsafeIOToST . ioReadM x r
139
140{-# INLINE unsafeWriteMatrix #-}
141unsafeWriteMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
142unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
143
144{-# INLINE modifyMatrix #-}
145modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
146modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
147
148liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s1 t -> ST s2 a
149liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
150
151unsafeFreezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
152unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
153
154freezeMatrix :: (Storable t) => STMatrix s1 t -> ST s2 (Matrix t)
155freezeMatrix m = liftSTMatrix id m
156
157cloneMatrix (Matrix r c d o) = cloneVector d >>= return . (\d' -> Matrix r c d' o)
158
159{-# INLINE safeIndexM #-}
160safeIndexM f (STMatrix m) r c
161 | r<0 || r>=rows m ||
162 c<0 || c>=cols m = error $ "out of range error in matrix (size="
163 ++show (rows m,cols m)++", pos="++show (r,c)++")"
164 | otherwise = f (STMatrix m) r c
165
166{-# INLINE readMatrix #-}
167readMatrix :: Storable t => STMatrix s t -> Int -> Int -> ST s t
168readMatrix = safeIndexM unsafeReadMatrix
169
170{-# INLINE writeMatrix #-}
171writeMatrix :: Storable t => STMatrix s t -> Int -> Int -> t -> ST s ()
172writeMatrix = safeIndexM unsafeWriteMatrix
173
174newUndefinedMatrix :: Storable t => MatrixOrder -> Int -> Int -> ST s (STMatrix s t)
175newUndefinedMatrix ord r c = unsafeIOToST $ fmap STMatrix $ createMatrix ord r c
176
177{-# NOINLINE newMatrix #-}
178newMatrix :: Storable t => t -> Int -> Int -> ST s (STMatrix s t)
179newMatrix v r c = unsafeThawMatrix $ reshape c $ runSTVector $ newVector v (r*c)
diff --git a/packages/hmatrix/src/Data/Packed/Vector.hs b/packages/hmatrix/src/Data/Packed/Vector.hs
new file mode 100644
index 0000000..b5a4318
--- /dev/null
+++ b/packages/hmatrix/src/Data/Packed/Vector.hs
@@ -0,0 +1,96 @@
1{-# LANGUAGE FlexibleContexts #-}
2{-# LANGUAGE CPP #-}
3-----------------------------------------------------------------------------
4-- |
5-- Module : Data.Packed.Vector
6-- Copyright : (c) Alberto Ruiz 2007-10
7-- License : GPL
8--
9-- Maintainer : Alberto Ruiz <aruiz@um.es>
10-- Stability : provisional
11--
12-- 1D arrays suitable for numeric computations using external libraries.
13--
14-- This module provides basic functions for manipulation of structure.
15--
16-----------------------------------------------------------------------------
17{-# OPTIONS_HADDOCK hide #-}
18
19module Data.Packed.Vector (
20 Vector,
21 fromList, (|>), toList, buildVector,
22 dim, (@>),
23 subVector, takesV, vjoin, join,
24 mapVector, mapVectorWithIndex, zipVector, zipVectorWith, unzipVector, unzipVectorWith,
25 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_,
26 foldLoop, foldVector, foldVectorG, foldVectorWithIndex
27) where
28
29import Data.Packed.Internal.Vector
30import Foreign.Storable
31
32-------------------------------------------------------------------
33
34#ifdef BINARY
35
36import Data.Binary
37import Control.Monad(replicateM)
38
39-- a 64K cache, with a Double taking 13 bytes in Bytestring,
40-- implies a chunk size of 5041
41chunk :: Int
42chunk = 5000
43
44chunks :: Int -> [Int]
45chunks d = let c = d `div` chunk
46 m = d `mod` chunk
47 in if m /= 0 then reverse (m:(replicate c chunk)) else (replicate c chunk)
48
49putVector v = do
50 let d = dim v
51 mapM_ (\i -> put $ v @> i) [0..(d-1)]
52
53getVector d = do
54 xs <- replicateM d get
55 return $! fromList xs
56
57instance (Binary a, Storable a) => Binary (Vector a) where
58 put v = do
59 let d = dim v
60 put d
61 mapM_ putVector $! takesV (chunks d) v
62 get = do
63 d <- get
64 vs <- mapM getVector $ chunks d
65 return $! vjoin vs
66
67#endif
68
69-------------------------------------------------------------------
70
71{- | creates a Vector of the specified length using the supplied function to
72 to map the index to the value at that index.
73
74@> buildVector 4 fromIntegral
754 |> [0.0,1.0,2.0,3.0]@
76
77-}
78buildVector :: Storable a => Int -> (Int -> a) -> Vector a
79buildVector len f =
80 fromList $ map f [0 .. (len - 1)]
81
82
83-- | zip for Vectors
84zipVector :: (Storable a, Storable b, Storable (a,b)) => Vector a -> Vector b -> Vector (a,b)
85zipVector = zipVectorWith (,)
86
87-- | unzip for Vectors
88unzipVector :: (Storable a, Storable b, Storable (a,b)) => Vector (a,b) -> (Vector a,Vector b)
89unzipVector = unzipVectorWith id
90
91-------------------------------------------------------------------
92
93{-# DEPRECATED join "use vjoin or Data.Vector.concat" #-}
94join :: Storable t => [Vector t] -> Vector t
95join = vjoin
96
diff --git a/packages/hmatrix/src/Graphics/Plot.hs b/packages/hmatrix/src/Graphics/Plot.hs
new file mode 100644
index 0000000..0ea41ac
--- /dev/null
+++ b/packages/hmatrix/src/Graphics/Plot.hs
@@ -0,0 +1,184 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Graphics.Plot
4-- Copyright : (c) Alberto Ruiz 2005-8
5-- License : GPL-style
6--
7-- Maintainer : Alberto Ruiz (aruiz at um dot es)
8-- Stability : provisional
9-- Portability : uses gnuplot and ImageMagick
10--
11-- This module is deprecated. It can be replaced by improved drawing tools
12-- available in the plot\\plot-gtk packages by Vivian McPhail or Gnuplot by Henning Thielemann.
13-----------------------------------------------------------------------------
14{-# OPTIONS_HADDOCK hide #-}
15
16module Graphics.Plot(
17
18 mplot,
19
20 plot, parametricPlot,
21
22 splot, mesh, meshdom,
23
24 matrixToPGM, imshow,
25
26 gnuplotX, gnuplotpdf, gnuplotWin
27
28) where
29
30import Numeric.Container
31import Data.List(intersperse)
32import System.Process (system)
33
34-- | From vectors x and y, it generates a pair of matrices to be used as x and y arguments for matrix functions.
35meshdom :: Vector Double -> Vector Double -> (Matrix Double , Matrix Double)
36meshdom r1 r2 = (outer r1 (constant 1 (dim r2)), outer (constant 1 (dim r1)) r2)
37
38
39{- | Draws a 3D surface representation of a real matrix.
40
41> > mesh $ build (10,10) (\\i j -> i + (j-5)^2)
42
43In certain versions you can interactively rotate the graphic using the mouse.
44
45-}
46mesh :: Matrix Double -> IO ()
47mesh m = gnuplotX (command++dat) where
48 command = "splot "++datafollows++" matrix with lines\n"
49 dat = prep $ toLists m
50
51{- | Draws the surface represented by the function f in the desired ranges and number of points, internally using 'mesh'.
52
53> > let f x y = cos (x + y)
54> > splot f (0,pi) (0,2*pi) 50
55
56-}
57splot :: (Matrix Double->Matrix Double->Matrix Double) -> (Double,Double) -> (Double,Double) -> Int -> IO ()
58splot f rx ry n = mesh z where
59 (x,y) = meshdom (linspace n rx) (linspace n ry)
60 z = f x y
61
62{- | plots several vectors against the first one
63
64> > let t = linspace 100 (-3,3) in mplot [t, sin t, exp (-t^2)]
65
66-}
67mplot :: [Vector Double] -> IO ()
68mplot m = gnuplotX (commands++dats) where
69 commands = if length m == 1 then command1 else commandmore
70 command1 = "plot "++datafollows++" with lines\n" ++ dat
71 commandmore = "plot " ++ plots ++ "\n"
72 plots = concat $ intersperse ", " (map cmd [2 .. length m])
73 cmd k = datafollows++" using 1:"++show k++" with lines"
74 dat = prep $ toLists $ fromColumns m
75 dats = concat (replicate (length m-1) dat)
76
77
78{- | Draws a list of functions over a desired range and with a desired number of points
79
80> > plot [sin, cos, sin.(3*)] (0,2*pi) 1000
81
82-}
83plot :: [Vector Double->Vector Double] -> (Double,Double) -> Int -> IO ()
84plot fs rx n = mplot (x: mapf fs x)
85 where x = linspace n rx
86 mapf gs y = map ($ y) gs
87
88{- | Draws a parametric curve. For instance, to draw a spiral we can do something like:
89
90> > parametricPlot (\t->(t * sin t, t * cos t)) (0,10*pi) 1000
91
92-}
93parametricPlot :: (Vector Double->(Vector Double,Vector Double)) -> (Double, Double) -> Int -> IO ()
94parametricPlot f rt n = mplot [fx, fy]
95 where t = linspace n rt
96 (fx,fy) = f t
97
98
99-- | writes a matrix to pgm image file
100matrixToPGM :: Matrix Double -> String
101matrixToPGM m = header ++ unlines (map unwords ll) where
102 c = cols m
103 r = rows m
104 header = "P2 "++show c++" "++show r++" "++show (round maxgray :: Int)++"\n"
105 maxgray = 255.0
106 maxval = maxElement m
107 minval = minElement m
108 scale' = if maxval == minval
109 then 0.0
110 else maxgray / (maxval - minval)
111 f x = show ( round ( scale' *(x - minval) ) :: Int )
112 ll = map (map f) (toLists m)
113
114-- | imshow shows a representation of a matrix as a gray level image using ImageMagick's display.
115imshow :: Matrix Double -> IO ()
116imshow m = do
117 _ <- system $ "echo \""++ matrixToPGM m ++"\"| display -antialias -resize 300 - &"
118 return ()
119
120----------------------------------------------------
121
122gnuplotX :: String -> IO ()
123gnuplotX command = do { _ <- system cmdstr; return()} where
124 cmdstr = "echo \""++command++"\" | gnuplot -persist"
125
126datafollows = "\\\"-\\\""
127
128prep = (++"e\n\n") . unlines . map (unwords . map show)
129
130
131gnuplotpdf :: String -> String -> [([[Double]], String)] -> IO ()
132gnuplotpdf title command ds = gnuplot (prelude ++ command ++" "++ draw) >> postproc where
133 prelude = "set terminal epslatex color; set output '"++title++".tex';"
134 (dats,defs) = unzip ds
135 draw = concat (intersperse ", " (map ("\"-\" "++) defs)) ++ "\n" ++
136 concatMap pr dats
137 postproc = do
138 _ <- system $ "epstopdf "++title++".eps"
139 mklatex
140 _ <- system $ "pdflatex "++title++"aux.tex > /dev/null"
141 _ <- system $ "pdfcrop "++title++"aux.pdf > /dev/null"
142 _ <- system $ "mv "++title++"aux-crop.pdf "++title++".pdf"
143 _ <- system $ "rm "++title++"aux.* "++title++".eps "++title++".tex"
144 return ()
145
146 mklatex = writeFile (title++"aux.tex") $
147 "\\documentclass{article}\n"++
148 "\\usepackage{graphics}\n"++
149 "\\usepackage{nopageno}\n"++
150 "\\usepackage{txfonts}\n"++
151 "\\renewcommand{\\familydefault}{phv}\n"++
152 "\\usepackage[usenames]{color}\n"++
153
154 "\\begin{document}\n"++
155
156 "\\begin{center}\n"++
157 " \\input{./"++title++".tex}\n"++
158 "\\end{center}\n"++
159
160 "\\end{document}"
161
162 pr = (++"e\n") . unlines . map (unwords . map show)
163
164 gnuplot cmd = do
165 writeFile "gnuplotcommand" cmd
166 _ <- system "gnuplot gnuplotcommand"
167 _ <- system "rm gnuplotcommand"
168 return ()
169
170gnuplotWin :: String -> String -> [([[Double]], String)] -> IO ()
171gnuplotWin title command ds = gnuplot (prelude ++ command ++" "++ draw) where
172 (dats,defs) = unzip ds
173 draw = concat (intersperse ", " (map ("\"-\" "++) defs)) ++ "\n" ++
174 concatMap pr dats
175
176 pr = (++"e\n") . unlines . map (unwords . map show)
177
178 prelude = "set title \""++title++"\";"
179
180 gnuplot cmd = do
181 writeFile "gnuplotcommand" cmd
182 _ <- system "gnuplot -persist gnuplotcommand"
183 _ <- system "rm gnuplotcommand"
184 return ()
diff --git a/packages/hmatrix/src/Numeric/Chain.hs b/packages/hmatrix/src/Numeric/Chain.hs
new file mode 100644
index 0000000..e1ab7da
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/Chain.hs
@@ -0,0 +1,140 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Numeric.Chain
4-- Copyright : (c) Vivian McPhail 2010
5-- License : GPL-style
6--
7-- Maintainer : Vivian McPhail <haskell.vivian.mcphail <at> gmail.com>
8-- Stability : provisional
9-- Portability : portable
10--
11-- optimisation of association order for chains of matrix multiplication
12--
13-----------------------------------------------------------------------------
14
15module Numeric.Chain (
16 optimiseMult,
17 ) where
18
19import Data.Maybe
20
21import Data.Packed.Matrix
22import Numeric.ContainerBoot
23
24import qualified Data.Array.IArray as A
25
26-----------------------------------------------------------------------------
27{- |
28 Provide optimal association order for a chain of matrix multiplications
29 and apply the multiplications.
30
31 The algorithm is the well-known O(n\^3) dynamic programming algorithm
32 that builds a pyramid of optimal associations.
33
34> m1, m2, m3, m4 :: Matrix Double
35> m1 = (10><15) [1..]
36> m2 = (15><20) [1..]
37> m3 = (20><5) [1..]
38> m4 = (5><10) [1..]
39
40> >>> optimiseMult [m1,m2,m3,m4]
41
42will perform @((m1 `multiply` (m2 `multiply` m3)) `multiply` m4)@
43
44The naive left-to-right multiplication would take @4500@ scalar multiplications
45whereas the optimised version performs @2750@ scalar multiplications. The complexity
46in this case is 32 (= 4^3/2) * (2 comparisons, 3 scalar multiplications, 3 scalar additions,
475 lookups, 2 updates) + a constant (= three table allocations)
48-}
49optimiseMult :: Product t => [Matrix t] -> Matrix t
50optimiseMult = chain
51
52-----------------------------------------------------------------------------
53
54type Matrices a = A.Array Int (Matrix a)
55type Sizes = A.Array Int (Int,Int)
56type Cost = A.Array Int (A.Array Int (Maybe Int))
57type Indexes = A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
58
59update :: A.Array Int (A.Array Int a) -> (Int,Int) -> a -> A.Array Int (A.Array Int a)
60update a (r,c) e = a A.// [(r,(a A.! r) A.// [(c,e)])]
61
62newWorkSpaceCost :: Int -> A.Array Int (A.Array Int (Maybe Int))
63newWorkSpaceCost n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
64 where subArray i = A.listArray (1,i) (repeat Nothing)
65
66newWorkSpaceIndexes :: Int -> A.Array Int (A.Array Int (Maybe ((Int,Int),(Int,Int))))
67newWorkSpaceIndexes n = A.array (1,n) $ map (\i -> (i, subArray i)) [1..n]
68 where subArray i = A.listArray (1,i) (repeat Nothing)
69
70matricesToSizes :: [Matrix a] -> Sizes
71matricesToSizes ms = A.listArray (1,length ms) $ map (\m -> (rows m,cols m)) ms
72
73chain :: Product a => [Matrix a] -> Matrix a
74chain [] = error "chain: zero matrices to multiply"
75chain [m] = m
76chain [ml,mr] = ml `multiply` mr
77chain ms = let ln = length ms
78 ma = A.listArray (1,ln) ms
79 mz = matricesToSizes ms
80 i = chain_cost mz
81 in chain_paren (ln,ln) i ma
82
83chain_cost :: Sizes -> Indexes
84chain_cost mz = let (_,u) = A.bounds mz
85 cost = newWorkSpaceCost u
86 ixes = newWorkSpaceIndexes u
87 (_,_,i) = foldl chain_cost' (mz,cost,ixes) (order u)
88 in i
89
90chain_cost' :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
91chain_cost' sci@(mz,cost,ixes) (r,c)
92 | c == 1 = let cost' = update cost (r,c) (Just 0)
93 ixes' = update ixes (r,c) (Just ((r,c),(r,c)))
94 in (mz,cost',ixes')
95 | otherwise = minimum_cost sci (r,c)
96
97minimum_cost :: (Sizes,Cost,Indexes) -> (Int,Int) -> (Sizes,Cost,Indexes)
98minimum_cost sci fu = foldl (smaller_cost fu) sci (fulcrum_order fu)
99
100smaller_cost :: (Int,Int) -> (Sizes,Cost,Indexes) -> ((Int,Int),(Int,Int)) -> (Sizes,Cost,Indexes)
101smaller_cost (r,c) (mz,cost,ixes) ix@((lr,lc),(rr,rc)) = let op_cost = fromJust ((cost A.! lr) A.! lc)
102 + fromJust ((cost A.! rr) A.! rc)
103 + fst (mz A.! (lr-lc+1))
104 * snd (mz A.! lc)
105 * snd (mz A.! rr)
106 cost' = (cost A.! r) A.! c
107 in case cost' of
108 Nothing -> let cost'' = update cost (r,c) (Just op_cost)
109 ixes'' = update ixes (r,c) (Just ix)
110 in (mz,cost'',ixes'')
111 Just ct -> if op_cost < ct then
112 let cost'' = update cost (r,c) (Just op_cost)
113 ixes'' = update ixes (r,c) (Just ix)
114 in (mz,cost'',ixes'')
115 else (mz,cost,ixes)
116
117
118fulcrum_order (r,c) = let fs' = zip (repeat r) [1..(c-1)]
119 in map (partner (r,c)) fs'
120
121partner (r,c) (a,b) = ((r-b, c-b), (a,b))
122
123order 0 = []
124order n = order (n-1) ++ zip (repeat n) [1..n]
125
126chain_paren :: Product a => (Int,Int) -> Indexes -> Matrices a -> Matrix a
127chain_paren (r,c) ixes ma = let ((lr,lc),(rr,rc)) = fromJust $ (ixes A.! r) A.! c
128 in if lr == rr && lc == rc then (ma A.! lr)
129 else (chain_paren (lr,lc) ixes ma) `multiply` (chain_paren (rr,rc) ixes ma)
130
131--------------------------------------------------------------------------
132
133{- TESTS -}
134
135-- optimal association is ((m1*(m2*m3))*m4)
136m1, m2, m3, m4 :: Matrix Double
137m1 = (10><15) [1..]
138m2 = (15><20) [1..]
139m3 = (20><5) [1..]
140m4 = (5><10) [1..] \ No newline at end of file
diff --git a/packages/hmatrix/src/Numeric/Container.hs b/packages/hmatrix/src/Numeric/Container.hs
new file mode 100644
index 0000000..7e46147
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/Container.hs
@@ -0,0 +1,303 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-}
7
8-----------------------------------------------------------------------------
9-- |
10-- Module : Numeric.Container
11-- Copyright : (c) Alberto Ruiz 2010-14
12-- License : GPL
13--
14-- Maintainer : Alberto Ruiz <aruiz@um.es>
15-- Stability : provisional
16-- Portability : portable
17--
18-- Basic numeric operations on 'Vector' and 'Matrix', including conversion routines.
19--
20-- The 'Container' class is used to define optimized generic functions which work
21-- on 'Vector' and 'Matrix' with real or complex elements.
22--
23-- Some of these functions are also available in the instances of the standard
24-- numeric Haskell classes provided by "Numeric.LinearAlgebra".
25--
26-----------------------------------------------------------------------------
27{-# OPTIONS_HADDOCK hide #-}
28
29module Numeric.Container (
30 -- * Basic functions
31 module Data.Packed,
32 konst, build,
33 constant, linspace,
34 diag, ident,
35 ctrans,
36 -- * Generic operations
37 Container(..),
38 -- * Matrix product
39 Product(..), udot,
40 Mul(..),
41 Contraction(..),
42 optimiseMult,
43 mXm,mXv,vXm,LSDiv(..), cdot, (·), dot, (<.>),
44 outer, kronecker,
45 -- * Random numbers
46 RandDist(..),
47 randomVector,
48 gaussianSample,
49 uniformSample,
50 meanCov,
51 -- * Element conversion
52 Convert(..),
53 Complexable(),
54 RealElement(),
55
56 RealOf, ComplexOf, SingleOf, DoubleOf,
57
58 IndexOf,
59 module Data.Complex,
60 -- * IO
61 dispf, disps, dispcf, vecdisp, latexFormat, format,
62 loadMatrix, saveMatrix, fromFile, fileDimensions,
63 readMatrix,
64 fscanfVector, fprintfVector, freadVector, fwriteVector,
65) where
66
67import Data.Packed
68import Data.Packed.Internal(constantD)
69import Numeric.ContainerBoot
70import Numeric.Chain
71import Numeric.IO
72import Data.Complex
73import Numeric.LinearAlgebra.Algorithms(Field,linearSolveSVD)
74import Data.Packed.Random
75
76------------------------------------------------------------------
77
78{- | creates a vector with a given number of equal components:
79
80@> constant 2 7
817 |> [2.0,2.0,2.0,2.0,2.0,2.0,2.0]@
82-}
83constant :: Element a => a -> Int -> Vector a
84-- constant x n = runSTVector (newVector x n)
85constant = constantD-- about 2x faster
86
87{- | Creates a real vector containing a range of values:
88
89>>> linspace 5 (-3,7::Double)
90fromList [-3.0,-0.5,2.0,4.5,7.0]@
91
92>>> linspace 5 (8,2+i) :: Vector (Complex Double)
93fromList [8.0 :+ 0.0,6.5 :+ 0.25,5.0 :+ 0.5,3.5 :+ 0.75,2.0 :+ 1.0]
94
95Logarithmic spacing can be defined as follows:
96
97@logspace n (a,b) = 10 ** linspace n (a,b)@
98-}
99linspace :: (Container Vector e) => Int -> (e, e) -> Vector e
100linspace 0 (a,b) = fromList[(a+b)/2]
101linspace n (a,b) = addConstant a $ scale s $ fromList $ map fromIntegral [0 .. n-1]
102 where s = (b-a)/fromIntegral (n-1)
103
104-- | dot product: @cdot u v = 'udot' ('conj' u) v@
105cdot :: (Container Vector t, Product t) => Vector t -> Vector t -> t
106cdot u v = udot (conj u) v
107
108--------------------------------------------------------
109
110class Contraction a b c | a b -> c, c -> a b
111 where
112 infixr 7 ×
113 {- | Matrix-matrix product, matrix-vector product, and unconjugated dot product
114
115(unicode 0x00d7, multiplication sign)
116
117Examples:
118
119>>> let a = (3><4) [1..] :: Matrix Double
120>>> let v = fromList [1,0,2,-1] :: Vector Double
121>>> let u = fromList [1,2,3] :: Vector Double
122
123>>> a
124(3><4)
125 [ 1.0, 2.0, 3.0, 4.0
126 , 5.0, 6.0, 7.0, 8.0
127 , 9.0, 10.0, 11.0, 12.0 ]
128
129matrix × matrix:
130
131>>> disp 2 (a × trans a)
1323x3
133 30 70 110
134 70 174 278
135110 278 446
136
137matrix × vector:
138
139>>> a × v
140fromList [3.0,11.0,19.0]
141
142unconjugated dot product:
143
144>>> fromList [1,i] × fromList[2*i+1,3]
1451.0 :+ 5.0
146
147(×) is right associative, so we can write:
148
149>>> u × a × v
15082.0 :: Double
151
152-}
153 (×) :: a -> b -> c
154
155instance Product t => Contraction (Matrix t) (Vector t) (Vector t) where
156 (×) = mXv
157
158instance Product t => Contraction (Matrix t) (Matrix t) (Matrix t) where
159 (×) = mXm
160
161instance Contraction (Vector Double) (Vector Double) Double where
162 (×) = udot
163
164instance Contraction (Vector Float) (Vector Float) Float where
165 (×) = udot
166
167instance Contraction (Vector (Complex Double)) (Vector (Complex Double)) (Complex Double) where
168 (×) = udot
169
170instance Contraction (Vector (Complex Float)) (Vector (Complex Float)) (Complex Float) where
171 (×) = udot
172
173
174-- | alternative function for the matrix product (×)
175mmul :: Contraction a b c => a -> b -> c
176mmul = (×)
177
178--------------------------------------------------------------------------------
179
180class Mul a b c | a b -> c where
181 infixl 7 <>
182 -- | Matrix-matrix, matrix-vector, and vector-matrix products.
183 (<>) :: Product t => a t -> b t -> c t
184
185instance Mul Matrix Matrix Matrix where
186 (<>) = mXm
187
188instance Mul Matrix Vector Vector where
189 (<>) m v = flatten $ m <> asColumn v
190
191instance Mul Vector Matrix Vector where
192 (<>) v m = flatten $ asRow v <> m
193
194--------------------------------------------------------------------------------
195
196class LSDiv c where
197 infixl 7 <\>
198 -- | least squares solution of a linear system, similar to the \\ operator of Matlab\/Octave (based on linearSolveSVD)
199 (<\>) :: Field t => Matrix t -> c t -> c t
200
201instance LSDiv Vector where
202 m <\> v = flatten (linearSolveSVD m (reshape 1 v))
203
204instance LSDiv Matrix where
205 (<\>) = linearSolveSVD
206
207--------------------------------------------------------
208
209{- | Dot product : @u · v = 'cdot' u v@
210
211 (unicode 0x00b7, middle dot, Alt-Gr .)
212
213>>> fromList [1,i] · fromList[2*i+1,3]
2141.0 :+ (-1.0)
215
216-}
217(·) :: (Container Vector t, Product t) => Vector t -> Vector t -> t
218infixl 7 ·
219u · v = cdot u v
220
221--------------------------------------------------------------------------------
222
223-- bidirectional type inference
224class Konst e d c | d -> c, c -> d
225 where
226 -- |
227 -- >>> konst 7 3 :: Vector Float
228 -- fromList [7.0,7.0,7.0]
229 --
230 -- >>> konst i (3::Int,4::Int)
231 -- (3><4)
232 -- [ 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0
233 -- , 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0
234 -- , 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0, 0.0 :+ 1.0 ]
235 --
236 konst :: e -> d -> c e
237
238instance Container Vector e => Konst e Int Vector
239 where
240 konst = konst'
241
242instance Container Vector e => Konst e (Int,Int) Matrix
243 where
244 konst = konst'
245
246--------------------------------------------------------------------------------
247
248class Build d f c e | d -> c, c -> d, f -> e, f -> d, f -> c, c e -> f, d e -> f
249 where
250 -- |
251 -- >>> build 5 (**2) :: Vector Double
252 -- fromList [0.0,1.0,4.0,9.0,16.0]
253 --
254 -- Hilbert matrix of order N:
255 --
256 -- >>> let hilb n = build (n,n) (\i j -> 1/(i+j+1)) :: Matrix Double
257 -- >>> putStr . dispf 2 $ hilb 3
258 -- 3x3
259 -- 1.00 0.50 0.33
260 -- 0.50 0.33 0.25
261 -- 0.33 0.25 0.20
262 --
263 build :: d -> f -> c e
264
265instance Container Vector e => Build Int (e -> e) Vector e
266 where
267 build = build'
268
269instance Container Matrix e => Build (Int,Int) (e -> e -> e) Matrix e
270 where
271 build = build'
272
273--------------------------------------------------------------------------------
274
275{- | Compute mean vector and covariance matrix of the rows of a matrix.
276
277>>> meanCov $ gaussianSample 666 1000 (fromList[4,5]) (diagl[2,3])
278(fromList [4.010341078059521,5.0197204699640405],
279(2><2)
280 [ 1.9862461923890056, -1.0127225830525157e-2
281 , -1.0127225830525157e-2, 3.0373954915729318 ])
282
283-}
284meanCov :: Matrix Double -> (Vector Double, Matrix Double)
285meanCov x = (med,cov) where
286 r = rows x
287 k = 1 / fromIntegral r
288 med = konst k r `vXm` x
289 meds = konst 1 r `outer` med
290 xc = x `sub` meds
291 cov = scale (recip (fromIntegral (r-1))) (trans xc `mXm` xc)
292
293--------------------------------------------------------------------------------
294
295{-# DEPRECATED dot "use udot" #-}
296dot :: Product e => Vector e -> Vector e -> e
297dot = udot
298
299-- | contraction operator, equivalent to (x)
300infixr 7 <.>
301(<.>) :: Contraction a b c => a -> b -> c
302(<.>) = (×)
303
diff --git a/packages/hmatrix/src/Numeric/ContainerBoot.hs b/packages/hmatrix/src/Numeric/ContainerBoot.hs
new file mode 100644
index 0000000..ea4262c
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/ContainerBoot.hs
@@ -0,0 +1,611 @@
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeFamilies #-}
3{-# LANGUAGE FlexibleContexts #-}
4{-# LANGUAGE FlexibleInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6{-# LANGUAGE UndecidableInstances #-}
7
8-----------------------------------------------------------------------------
9-- |
10-- Module : Numeric.ContainerBoot
11-- Copyright : (c) Alberto Ruiz 2010
12-- License : GPL-style
13--
14-- Maintainer : Alberto Ruiz <aruiz@um.es>
15-- Stability : provisional
16-- Portability : portable
17--
18-- Module to avoid cyclyc dependencies.
19--
20-----------------------------------------------------------------------------
21
22module Numeric.ContainerBoot (
23 -- * Basic functions
24 ident, diag, ctrans,
25 -- * Generic operations
26 Container(..),
27 -- * Matrix product and related functions
28 Product(..), udot,
29 mXm,mXv,vXm,
30 outer, kronecker,
31 -- * Element conversion
32 Convert(..),
33 Complexable(),
34 RealElement(),
35
36 RealOf, ComplexOf, SingleOf, DoubleOf,
37
38 IndexOf,
39 module Data.Complex
40) where
41
42import Data.Packed
43import Data.Packed.ST as ST
44import Numeric.Conversion
45import Data.Packed.Internal
46import Numeric.GSL.Vector
47import Data.Complex
48import Control.Applicative((<*>))
49
50import Numeric.LinearAlgebra.LAPACK(multiplyR,multiplyC,multiplyF,multiplyQ)
51
52-------------------------------------------------------------------
53
54type family IndexOf (c :: * -> *)
55
56type instance IndexOf Vector = Int
57type instance IndexOf Matrix = (Int,Int)
58
59type family ArgOf (c :: * -> *) a
60
61type instance ArgOf Vector a = a -> a
62type instance ArgOf Matrix a = a -> a -> a
63
64-------------------------------------------------------------------
65
66-- | Basic element-by-element functions for numeric containers
67class (Complexable c, Fractional e, Element e) => Container c e where
68 -- | create a structure with a single element
69 --
70 -- >>> let v = fromList [1..3::Double]
71 -- >>> v / scalar (norm2 v)
72 -- fromList [0.2672612419124244,0.5345224838248488,0.8017837257372732]
73 --
74 scalar :: e -> c e
75 -- | complex conjugate
76 conj :: c e -> c e
77 scale :: e -> c e -> c e
78 -- | scale the element by element reciprocal of the object:
79 --
80 -- @scaleRecip 2 (fromList [5,i]) == 2 |> [0.4 :+ 0.0,0.0 :+ (-2.0)]@
81 scaleRecip :: e -> c e -> c e
82 addConstant :: e -> c e -> c e
83 add :: c e -> c e -> c e
84 sub :: c e -> c e -> c e
85 -- | element by element multiplication
86 mul :: c e -> c e -> c e
87 -- | element by element division
88 divide :: c e -> c e -> c e
89 equal :: c e -> c e -> Bool
90 --
91 -- element by element inverse tangent
92 arctan2 :: c e -> c e -> c e
93 --
94 -- | cannot implement instance Functor because of Element class constraint
95 cmap :: (Element b) => (e -> b) -> c e -> c b
96 -- | constant structure of given size
97 konst' :: e -> IndexOf c -> c e
98 -- | create a structure using a function
99 --
100 -- Hilbert matrix of order N:
101 --
102 -- @hilb n = build' (n,n) (\\i j -> 1/(i+j+1))@
103 build' :: IndexOf c -> (ArgOf c e) -> c e
104 -- | indexing function
105 atIndex :: c e -> IndexOf c -> e
106 -- | index of min element
107 minIndex :: c e -> IndexOf c
108 -- | index of max element
109 maxIndex :: c e -> IndexOf c
110 -- | value of min element
111 minElement :: c e -> e
112 -- | value of max element
113 maxElement :: c e -> e
114 -- the C functions sumX/prodX are twice as fast as using foldVector
115 -- | the sum of elements (faster than using @fold@)
116 sumElements :: c e -> e
117 -- | the product of elements (faster than using @fold@)
118 prodElements :: c e -> e
119
120 -- | A more efficient implementation of @cmap (\\x -> if x>0 then 1 else 0)@
121 --
122 -- >>> step $ linspace 5 (-1,1::Double)
123 -- 5 |> [0.0,0.0,0.0,1.0,1.0]
124 --
125
126 step :: RealElement e => c e -> c e
127
128 -- | Element by element version of @case compare a b of {LT -> l; EQ -> e; GT -> g}@.
129 --
130 -- Arguments with any dimension = 1 are automatically expanded:
131 --
132 -- >>> cond ((1><4)[1..]) ((3><1)[1..]) 0 100 ((3><4)[1..]) :: Matrix Double
133 -- (3><4)
134 -- [ 100.0, 2.0, 3.0, 4.0
135 -- , 0.0, 100.0, 7.0, 8.0
136 -- , 0.0, 0.0, 100.0, 12.0 ]
137 --
138
139 cond :: RealElement e
140 => c e -- ^ a
141 -> c e -- ^ b
142 -> c e -- ^ l
143 -> c e -- ^ e
144 -> c e -- ^ g
145 -> c e -- ^ result
146
147 -- | Find index of elements which satisfy a predicate
148 --
149 -- >>> find (>0) (ident 3 :: Matrix Double)
150 -- [(0,0),(1,1),(2,2)]
151 --
152
153 find :: (e -> Bool) -> c e -> [IndexOf c]
154
155 -- | Create a structure from an association list
156 --
157 -- >>> assoc 5 0 [(3,7),(1,4)] :: Vector Double
158 -- fromList [0.0,4.0,0.0,7.0,0.0]
159 --
160 -- >>> assoc (2,3) 0 [((0,2),7),((1,0),2*i-3)] :: Matrix (Complex Double)
161 -- (2><3)
162 -- [ 0.0 :+ 0.0, 0.0 :+ 0.0, 7.0 :+ 0.0
163 -- , (-3.0) :+ 2.0, 0.0 :+ 0.0, 0.0 :+ 0.0 ]
164 --
165 assoc :: IndexOf c -- ^ size
166 -> e -- ^ default value
167 -> [(IndexOf c, e)] -- ^ association list
168 -> c e -- ^ result
169
170 -- | Modify a structure using an update function
171 --
172 -- >>> accum (ident 5) (+) [((1,1),5),((0,3),3)] :: Matrix Double
173 -- (5><5)
174 -- [ 1.0, 0.0, 0.0, 3.0, 0.0
175 -- , 0.0, 6.0, 0.0, 0.0, 0.0
176 -- , 0.0, 0.0, 1.0, 0.0, 0.0
177 -- , 0.0, 0.0, 0.0, 1.0, 0.0
178 -- , 0.0, 0.0, 0.0, 0.0, 1.0 ]
179 --
180 -- computation of histogram:
181 --
182 -- >>> accum (konst 0 7) (+) (map (flip (,) 1) [4,5,4,1,5,2,5]) :: Vector Double
183 -- fromList [0.0,1.0,1.0,0.0,2.0,3.0,0.0]
184 --
185
186 accum :: c e -- ^ initial structure
187 -> (e -> e -> e) -- ^ update function
188 -> [(IndexOf c, e)] -- ^ association list
189 -> c e -- ^ result
190
191--------------------------------------------------------------------------
192
193instance Container Vector Float where
194 scale = vectorMapValF Scale
195 scaleRecip = vectorMapValF Recip
196 addConstant = vectorMapValF AddConstant
197 add = vectorZipF Add
198 sub = vectorZipF Sub
199 mul = vectorZipF Mul
200 divide = vectorZipF Div
201 equal u v = dim u == dim v && maxElement (vectorMapF Abs (sub u v)) == 0.0
202 arctan2 = vectorZipF ATan2
203 scalar x = fromList [x]
204 konst' = constantD
205 build' = buildV
206 conj = id
207 cmap = mapVector
208 atIndex = (@>)
209 minIndex = emptyErrorV "minIndex" (round . toScalarF MinIdx)
210 maxIndex = emptyErrorV "maxIndex" (round . toScalarF MaxIdx)
211 minElement = emptyErrorV "minElement" (toScalarF Min)
212 maxElement = emptyErrorV "maxElement" (toScalarF Max)
213 sumElements = sumF
214 prodElements = prodF
215 step = stepF
216 find = findV
217 assoc = assocV
218 accum = accumV
219 cond = condV condF
220
221instance Container Vector Double where
222 scale = vectorMapValR Scale
223 scaleRecip = vectorMapValR Recip
224 addConstant = vectorMapValR AddConstant
225 add = vectorZipR Add
226 sub = vectorZipR Sub
227 mul = vectorZipR Mul
228 divide = vectorZipR Div
229 equal u v = dim u == dim v && maxElement (vectorMapR Abs (sub u v)) == 0.0
230 arctan2 = vectorZipR ATan2
231 scalar x = fromList [x]
232 konst' = constantD
233 build' = buildV
234 conj = id
235 cmap = mapVector
236 atIndex = (@>)
237 minIndex = emptyErrorV "minIndex" (round . toScalarR MinIdx)
238 maxIndex = emptyErrorV "maxIndex" (round . toScalarR MaxIdx)
239 minElement = emptyErrorV "minElement" (toScalarR Min)
240 maxElement = emptyErrorV "maxElement" (toScalarR Max)
241 sumElements = sumR
242 prodElements = prodR
243 step = stepD
244 find = findV
245 assoc = assocV
246 accum = accumV
247 cond = condV condD
248
249instance Container Vector (Complex Double) where
250 scale = vectorMapValC Scale
251 scaleRecip = vectorMapValC Recip
252 addConstant = vectorMapValC AddConstant
253 add = vectorZipC Add
254 sub = vectorZipC Sub
255 mul = vectorZipC Mul
256 divide = vectorZipC Div
257 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
258 arctan2 = vectorZipC ATan2
259 scalar x = fromList [x]
260 konst' = constantD
261 build' = buildV
262 conj = conjugateC
263 cmap = mapVector
264 atIndex = (@>)
265 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
266 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
267 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
268 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
269 sumElements = sumC
270 prodElements = prodC
271 step = undefined -- cannot match
272 find = findV
273 assoc = assocV
274 accum = accumV
275 cond = undefined -- cannot match
276
277instance Container Vector (Complex Float) where
278 scale = vectorMapValQ Scale
279 scaleRecip = vectorMapValQ Recip
280 addConstant = vectorMapValQ AddConstant
281 add = vectorZipQ Add
282 sub = vectorZipQ Sub
283 mul = vectorZipQ Mul
284 divide = vectorZipQ Div
285 equal u v = dim u == dim v && maxElement (mapVector magnitude (sub u v)) == 0.0
286 arctan2 = vectorZipQ ATan2
287 scalar x = fromList [x]
288 konst' = constantD
289 build' = buildV
290 conj = conjugateQ
291 cmap = mapVector
292 atIndex = (@>)
293 minIndex = emptyErrorV "minIndex" (minIndex . fst . fromComplex . (mul <*> conj))
294 maxIndex = emptyErrorV "maxIndex" (maxIndex . fst . fromComplex . (mul <*> conj))
295 minElement = emptyErrorV "minElement" (atIndex <*> minIndex)
296 maxElement = emptyErrorV "maxElement" (atIndex <*> maxIndex)
297 sumElements = sumQ
298 prodElements = prodQ
299 step = undefined -- cannot match
300 find = findV
301 assoc = assocV
302 accum = accumV
303 cond = undefined -- cannot match
304
305---------------------------------------------------------------
306
307instance (Container Vector a) => Container Matrix a where
308 scale x = liftMatrix (scale x)
309 scaleRecip x = liftMatrix (scaleRecip x)
310 addConstant x = liftMatrix (addConstant x)
311 add = liftMatrix2 add
312 sub = liftMatrix2 sub
313 mul = liftMatrix2 mul
314 divide = liftMatrix2 divide
315 equal a b = cols a == cols b && flatten a `equal` flatten b
316 arctan2 = liftMatrix2 arctan2
317 scalar x = (1><1) [x]
318 konst' v (r,c) = matrixFromVector RowMajor r c (konst' v (r*c))
319 build' = buildM
320 conj = liftMatrix conj
321 cmap f = liftMatrix (mapVector f)
322 atIndex = (@@>)
323 minIndex = emptyErrorM "minIndex of Matrix" $
324 \m -> divMod (minIndex $ flatten m) (cols m)
325 maxIndex = emptyErrorM "maxIndex of Matrix" $
326 \m -> divMod (maxIndex $ flatten m) (cols m)
327 minElement = emptyErrorM "minElement of Matrix" (atIndex <*> minIndex)
328 maxElement = emptyErrorM "maxElement of Matrix" (atIndex <*> maxIndex)
329 sumElements = sumElements . flatten
330 prodElements = prodElements . flatten
331 step = liftMatrix step
332 find = findM
333 assoc = assocM
334 accum = accumM
335 cond = condM
336
337
338emptyErrorV msg f v =
339 if dim v > 0
340 then f v
341 else error $ msg ++ " of Vector with dim = 0"
342
343emptyErrorM msg f m =
344 if rows m > 0 && cols m > 0
345 then f m
346 else error $ msg++" "++shSize m
347
348----------------------------------------------------
349
350-- | Matrix product and related functions
351class (Num e, Element e) => Product e where
352 -- | matrix product
353 multiply :: Matrix e -> Matrix e -> Matrix e
354 -- | sum of absolute value of elements (differs in complex case from @norm1@)
355 absSum :: Vector e -> RealOf e
356 -- | sum of absolute value of elements
357 norm1 :: Vector e -> RealOf e
358 -- | euclidean norm
359 norm2 :: Vector e -> RealOf e
360 -- | element of maximum magnitude
361 normInf :: Vector e -> RealOf e
362
363instance Product Float where
364 norm2 = emptyVal (toScalarF Norm2)
365 absSum = emptyVal (toScalarF AbsSum)
366 norm1 = emptyVal (toScalarF AbsSum)
367 normInf = emptyVal (maxElement . vectorMapF Abs)
368 multiply = emptyMul multiplyF
369
370instance Product Double where
371 norm2 = emptyVal (toScalarR Norm2)
372 absSum = emptyVal (toScalarR AbsSum)
373 norm1 = emptyVal (toScalarR AbsSum)
374 normInf = emptyVal (maxElement . vectorMapR Abs)
375 multiply = emptyMul multiplyR
376
377instance Product (Complex Float) where
378 norm2 = emptyVal (toScalarQ Norm2)
379 absSum = emptyVal (toScalarQ AbsSum)
380 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapQ Abs)
381 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapQ Abs)
382 multiply = emptyMul multiplyQ
383
384instance Product (Complex Double) where
385 norm2 = emptyVal (toScalarC Norm2)
386 absSum = emptyVal (toScalarC AbsSum)
387 norm1 = emptyVal (sumElements . fst . fromComplex . vectorMapC Abs)
388 normInf = emptyVal (maxElement . fst . fromComplex . vectorMapC Abs)
389 multiply = emptyMul multiplyC
390
391emptyMul m a b
392 | x1 == 0 && x2 == 0 || r == 0 || c == 0 = konst' 0 (r,c)
393 | otherwise = m a b
394 where
395 r = rows a
396 x1 = cols a
397 x2 = rows b
398 c = cols b
399
400emptyVal f v =
401 if dim v > 0
402 then f v
403 else 0
404
405-- FIXME remove unused C wrappers
406-- | (unconjugated) dot product
407udot :: Product e => Vector e -> Vector e -> e
408udot u v
409 | dim u == dim v = val (asRow u `multiply` asColumn v)
410 | otherwise = error $ "different dimensions "++show (dim u)++" and "++show (dim v)++" in dot product"
411 where
412 val m | dim u > 0 = m@@>(0,0)
413 | otherwise = 0
414
415----------------------------------------------------------
416
417-- synonym for matrix product
418mXm :: Product t => Matrix t -> Matrix t -> Matrix t
419mXm = multiply
420
421-- matrix - vector product
422mXv :: Product t => Matrix t -> Vector t -> Vector t
423mXv m v = flatten $ m `mXm` (asColumn v)
424
425-- vector - matrix product
426vXm :: Product t => Vector t -> Matrix t -> Vector t
427vXm v m = flatten $ (asRow v) `mXm` m
428
429{- | Outer product of two vectors.
430
431>>> fromList [1,2,3] `outer` fromList [5,2,3]
432(3><3)
433 [ 5.0, 2.0, 3.0
434 , 10.0, 4.0, 6.0
435 , 15.0, 6.0, 9.0 ]
436
437-}
438outer :: (Product t) => Vector t -> Vector t -> Matrix t
439outer u v = asColumn u `multiply` asRow v
440
441{- | Kronecker product of two matrices.
442
443@m1=(2><3)
444 [ 1.0, 2.0, 0.0
445 , 0.0, -1.0, 3.0 ]
446m2=(4><3)
447 [ 1.0, 2.0, 3.0
448 , 4.0, 5.0, 6.0
449 , 7.0, 8.0, 9.0
450 , 10.0, 11.0, 12.0 ]@
451
452>>> kronecker m1 m2
453(8><9)
454 [ 1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 0.0, 0.0, 0.0
455 , 4.0, 5.0, 6.0, 8.0, 10.0, 12.0, 0.0, 0.0, 0.0
456 , 7.0, 8.0, 9.0, 14.0, 16.0, 18.0, 0.0, 0.0, 0.0
457 , 10.0, 11.0, 12.0, 20.0, 22.0, 24.0, 0.0, 0.0, 0.0
458 , 0.0, 0.0, 0.0, -1.0, -2.0, -3.0, 3.0, 6.0, 9.0
459 , 0.0, 0.0, 0.0, -4.0, -5.0, -6.0, 12.0, 15.0, 18.0
460 , 0.0, 0.0, 0.0, -7.0, -8.0, -9.0, 21.0, 24.0, 27.0
461 , 0.0, 0.0, 0.0, -10.0, -11.0, -12.0, 30.0, 33.0, 36.0 ]
462
463-}
464kronecker :: (Product t) => Matrix t -> Matrix t -> Matrix t
465kronecker a b = fromBlocks
466 . splitEvery (cols a)
467 . map (reshape (cols b))
468 . toRows
469 $ flatten a `outer` flatten b
470
471-------------------------------------------------------------------
472
473
474class Convert t where
475 real :: Container c t => c (RealOf t) -> c t
476 complex :: Container c t => c t -> c (ComplexOf t)
477 single :: Container c t => c t -> c (SingleOf t)
478 double :: Container c t => c t -> c (DoubleOf t)
479 toComplex :: (Container c t, RealElement t) => (c t, c t) -> c (Complex t)
480 fromComplex :: (Container c t, RealElement t) => c (Complex t) -> (c t, c t)
481
482
483instance Convert Double where
484 real = id
485 complex = comp'
486 single = single'
487 double = id
488 toComplex = toComplex'
489 fromComplex = fromComplex'
490
491instance Convert Float where
492 real = id
493 complex = comp'
494 single = id
495 double = double'
496 toComplex = toComplex'
497 fromComplex = fromComplex'
498
499instance Convert (Complex Double) where
500 real = comp'
501 complex = id
502 single = single'
503 double = id
504 toComplex = toComplex'
505 fromComplex = fromComplex'
506
507instance Convert (Complex Float) where
508 real = comp'
509 complex = id
510 single = id
511 double = double'
512 toComplex = toComplex'
513 fromComplex = fromComplex'
514
515-------------------------------------------------------------------
516
517type family RealOf x
518
519type instance RealOf Double = Double
520type instance RealOf (Complex Double) = Double
521
522type instance RealOf Float = Float
523type instance RealOf (Complex Float) = Float
524
525type family ComplexOf x
526
527type instance ComplexOf Double = Complex Double
528type instance ComplexOf (Complex Double) = Complex Double
529
530type instance ComplexOf Float = Complex Float
531type instance ComplexOf (Complex Float) = Complex Float
532
533type family SingleOf x
534
535type instance SingleOf Double = Float
536type instance SingleOf Float = Float
537
538type instance SingleOf (Complex a) = Complex (SingleOf a)
539
540type family DoubleOf x
541
542type instance DoubleOf Double = Double
543type instance DoubleOf Float = Double
544
545type instance DoubleOf (Complex a) = Complex (DoubleOf a)
546
547type family ElementOf c
548
549type instance ElementOf (Vector a) = a
550type instance ElementOf (Matrix a) = a
551
552------------------------------------------------------------
553
554buildM (rc,cc) f = fromLists [ [f r c | c <- cs] | r <- rs ]
555 where rs = map fromIntegral [0 .. (rc-1)]
556 cs = map fromIntegral [0 .. (cc-1)]
557
558buildV n f = fromList [f k | k <- ks]
559 where ks = map fromIntegral [0 .. (n-1)]
560
561--------------------------------------------------------
562-- | conjugate transpose
563ctrans :: (Container Vector e, Element e) => Matrix e -> Matrix e
564ctrans = liftMatrix conj . trans
565
566-- | Creates a square matrix with a given diagonal.
567diag :: (Num a, Element a) => Vector a -> Matrix a
568diag v = diagRect 0 v n n where n = dim v
569
570-- | creates the identity matrix of given dimension
571ident :: (Num a, Element a) => Int -> Matrix a
572ident n = diag (constantD 1 n)
573
574--------------------------------------------------------
575
576findV p x = foldVectorWithIndex g [] x where
577 g k z l = if p z then k:l else l
578
579findM p x = map ((`divMod` cols x)) $ findV p (flatten x)
580
581assocV n z xs = ST.runSTVector $ do
582 v <- ST.newVector z n
583 mapM_ (\(k,x) -> ST.writeVector v k x) xs
584 return v
585
586assocM (r,c) z xs = ST.runSTMatrix $ do
587 m <- ST.newMatrix z r c
588 mapM_ (\((i,j),x) -> ST.writeMatrix m i j x) xs
589 return m
590
591accumV v0 f xs = ST.runSTVector $ do
592 v <- ST.thawVector v0
593 mapM_ (\(k,x) -> ST.modifyVector v k (f x)) xs
594 return v
595
596accumM m0 f xs = ST.runSTMatrix $ do
597 m <- ST.thawMatrix m0
598 mapM_ (\((i,j),x) -> ST.modifyMatrix m i j (f x)) xs
599 return m
600
601----------------------------------------------------------------------
602
603condM a b l e t = matrixFromVector RowMajor (rows a'') (cols a'') $ cond a' b' l' e' t'
604 where
605 args@(a'':_) = conformMs [a,b,l,e,t]
606 [a', b', l', e', t'] = map flatten args
607
608condV f a b l e t = f a' b' l' e' t'
609 where
610 [a', b', l', e', t'] = conformVs [a,b,l,e,t]
611
diff --git a/packages/hmatrix/src/Numeric/Conversion.hs b/packages/hmatrix/src/Numeric/Conversion.hs
new file mode 100644
index 0000000..8941451
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/Conversion.hs
@@ -0,0 +1,91 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE UndecidableInstances #-}
7
8-----------------------------------------------------------------------------
9-- |
10-- Module : Numeric.Conversion
11-- Copyright : (c) Alberto Ruiz 2010
12-- License : GPL-style
13--
14-- Maintainer : Alberto Ruiz <aruiz@um.es>
15-- Stability : provisional
16-- Portability : portable
17--
18-- Conversion routines
19--
20-----------------------------------------------------------------------------
21
22module Numeric.Conversion (
23 Complexable(..), RealElement,
24 module Data.Complex
25) where
26
27import Data.Packed.Internal.Vector
28import Data.Packed.Internal.Matrix
29import Data.Complex
30import Control.Arrow((***))
31
32-------------------------------------------------------------------
33
34-- | Supported single-double precision type pairs
35class (Element s, Element d) => Precision s d | s -> d, d -> s where
36 double2FloatG :: Vector d -> Vector s
37 float2DoubleG :: Vector s -> Vector d
38
39instance Precision Float Double where
40 double2FloatG = double2FloatV
41 float2DoubleG = float2DoubleV
42
43instance Precision (Complex Float) (Complex Double) where
44 double2FloatG = asComplex . double2FloatV . asReal
45 float2DoubleG = asComplex . float2DoubleV . asReal
46
47-- | Supported real types
48class (Element t, Element (Complex t), RealFloat t
49-- , RealOf t ~ t, RealOf (Complex t) ~ t
50 )
51 => RealElement t
52
53instance RealElement Double
54instance RealElement Float
55
56
57-- | Structures that may contain complex numbers
58class Complexable c where
59 toComplex' :: (RealElement e) => (c e, c e) -> c (Complex e)
60 fromComplex' :: (RealElement e) => c (Complex e) -> (c e, c e)
61 comp' :: (RealElement e) => c e -> c (Complex e)
62 single' :: Precision a b => c b -> c a
63 double' :: Precision a b => c a -> c b
64
65
66instance Complexable Vector where
67 toComplex' = toComplexV
68 fromComplex' = fromComplexV
69 comp' v = toComplex' (v,constantD 0 (dim v))
70 single' = double2FloatG
71 double' = float2DoubleG
72
73
74-- | creates a complex vector from vectors with real and imaginary parts
75toComplexV :: (RealElement a) => (Vector a, Vector a) -> Vector (Complex a)
76toComplexV (r,i) = asComplex $ flatten $ fromColumns [r,i]
77
78-- | the inverse of 'toComplex'
79fromComplexV :: (RealElement a) => Vector (Complex a) -> (Vector a, Vector a)
80fromComplexV z = (r,i) where
81 [r,i] = toColumns $ reshape 2 $ asReal z
82
83
84instance Complexable Matrix where
85 toComplex' = uncurry $ liftMatrix2 $ curry toComplex'
86 fromComplex' z = (reshape c *** reshape c) . fromComplex' . flatten $ z
87 where c = cols z
88 comp' = liftMatrix comp'
89 single' = liftMatrix single'
90 double' = liftMatrix double'
91
diff --git a/packages/hmatrix/src/Numeric/GSL.hs b/packages/hmatrix/src/Numeric/GSL.hs
new file mode 100644
index 0000000..5f39a3e
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL.hs
@@ -0,0 +1,43 @@
1{- |
2
3Module : Numeric.GSL
4Copyright : (c) Alberto Ruiz 2006-7
5License : GPL-style
6
7Maintainer : Alberto Ruiz (aruiz at um dot es)
8Stability : provisional
9Portability : uses -fffi and -fglasgow-exts
10
11This module reexports all available GSL functions.
12
13The GSL special functions are in the separate package hmatrix-special.
14
15-}
16
17module Numeric.GSL (
18 module Numeric.GSL.Integration
19, module Numeric.GSL.Differentiation
20, module Numeric.GSL.Fourier
21, module Numeric.GSL.Polynomials
22, module Numeric.GSL.Minimization
23, module Numeric.GSL.Root
24, module Numeric.GSL.ODE
25, module Numeric.GSL.Fitting
26, module Data.Complex
27, setErrorHandlerOff
28) where
29
30import Numeric.GSL.Integration
31import Numeric.GSL.Differentiation
32import Numeric.GSL.Fourier
33import Numeric.GSL.Polynomials
34import Numeric.GSL.Minimization
35import Numeric.GSL.Root
36import Numeric.GSL.ODE
37import Numeric.GSL.Fitting
38import Data.Complex
39
40
41-- | This action removes the GSL default error handler (which aborts the program), so that
42-- GSL errors can be handled by Haskell (using Control.Exception) and ghci doesn't abort.
43foreign import ccall unsafe "GSL/gsl-aux.h no_abort_on_error" setErrorHandlerOff :: IO ()
diff --git a/packages/hmatrix/src/Numeric/GSL/Differentiation.hs b/packages/hmatrix/src/Numeric/GSL/Differentiation.hs
new file mode 100644
index 0000000..93c5007
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/Differentiation.hs
@@ -0,0 +1,87 @@
1{-# OPTIONS #-}
2-----------------------------------------------------------------------------
3{- |
4Module : Numeric.GSL.Differentiation
5Copyright : (c) Alberto Ruiz 2006
6License : GPL-style
7
8Maintainer : Alberto Ruiz (aruiz at um dot es)
9Stability : provisional
10Portability : uses ffi
11
12Numerical differentiation.
13
14<http://www.gnu.org/software/gsl/manual/html_node/Numerical-Differentiation.html#Numerical-Differentiation>
15
16From the GSL manual: \"The functions described in this chapter compute numerical derivatives by finite differencing. An adaptive algorithm is used to find the best choice of finite difference and to estimate the error in the derivative.\"
17-}
18-----------------------------------------------------------------------------
19module Numeric.GSL.Differentiation (
20 derivCentral,
21 derivForward,
22 derivBackward
23) where
24
25import Foreign.C.Types
26import Foreign.Marshal.Alloc(malloc, free)
27import Foreign.Ptr(Ptr, FunPtr, freeHaskellFunPtr)
28import Foreign.Storable(peek)
29import Data.Packed.Internal(check,(//))
30import System.IO.Unsafe(unsafePerformIO)
31
32derivGen ::
33 CInt -- ^ type: 0 central, 1 forward, 2 backward
34 -> Double -- ^ initial step size
35 -> (Double -> Double) -- ^ function
36 -> Double -- ^ point where the derivative is taken
37 -> (Double, Double) -- ^ result and error
38derivGen c h f x = unsafePerformIO $ do
39 r <- malloc
40 e <- malloc
41 fp <- mkfun (\y _ -> f y)
42 c_deriv c fp x h r e // check "deriv"
43 vr <- peek r
44 ve <- peek e
45 let result = (vr,ve)
46 free r
47 free e
48 freeHaskellFunPtr fp
49 return result
50
51foreign import ccall safe "gsl-aux.h deriv"
52 c_deriv :: CInt -> FunPtr (Double -> Ptr () -> Double) -> Double -> Double
53 -> Ptr Double -> Ptr Double -> IO CInt
54
55
56{- | Adaptive central difference algorithm, /gsl_deriv_central/. For example:
57
58>>> let deriv = derivCentral 0.01
59>>> deriv sin (pi/4)
60(0.7071067812000676,1.0600063101654055e-10)
61>>> cos (pi/4)
620.7071067811865476
63
64-}
65derivCentral :: Double -- ^ initial step size
66 -> (Double -> Double) -- ^ function
67 -> Double -- ^ point where the derivative is taken
68 -> (Double, Double) -- ^ result and absolute error
69derivCentral = derivGen 0
70
71-- | Adaptive forward difference algorithm, /gsl_deriv_forward/. The function is evaluated only at points greater than x, and never at x itself. The derivative is returned in result and an estimate of its absolute error is returned in abserr. This function should be used if f(x) has a discontinuity at x, or is undefined for values less than x. A backward derivative can be obtained using a negative step.
72derivForward :: Double -- ^ initial step size
73 -> (Double -> Double) -- ^ function
74 -> Double -- ^ point where the derivative is taken
75 -> (Double, Double) -- ^ result and absolute error
76derivForward = derivGen 1
77
78-- | Adaptive backward difference algorithm, /gsl_deriv_backward/.
79derivBackward ::Double -- ^ initial step size
80 -> (Double -> Double) -- ^ function
81 -> Double -- ^ point where the derivative is taken
82 -> (Double, Double) -- ^ result and absolute error
83derivBackward = derivGen 2
84
85{- | conversion of Haskell functions into function pointers that can be used in the C side
86-}
87foreign import ccall safe "wrapper" mkfun:: (Double -> Ptr() -> Double) -> IO( FunPtr (Double -> Ptr() -> Double))
diff --git a/packages/hmatrix/src/Numeric/GSL/Fitting.hs b/packages/hmatrix/src/Numeric/GSL/Fitting.hs
new file mode 100644
index 0000000..c4f3a91
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/Fitting.hs
@@ -0,0 +1,179 @@
1{- |
2Module : Numeric.GSL.Fitting
3Copyright : (c) Alberto Ruiz 2010
4License : GPL
5
6Maintainer : Alberto Ruiz (aruiz at um dot es)
7Stability : provisional
8Portability : uses ffi
9
10Nonlinear Least-Squares Fitting
11
12<http://www.gnu.org/software/gsl/manual/html_node/Nonlinear-Least_002dSquares-Fitting.html>
13
14The example program in the GSL manual (see examples/fitting.hs):
15
16@
17dat = [
18 ([0.0],([6.0133918608118675],0.1)),
19 ([1.0],([5.5153769909966535],0.1)),
20 ([2.0],([5.261094606015287],0.1)),
21 ...
22 ([39.0],([1.0619821710802808],0.1))]
23
24expModel [a,lambda,b] [t] = [a * exp (-lambda * t) + b]
25
26expModelDer [a,lambda,b] [t] = [[exp (-lambda * t), -t * a * exp(-lambda*t) , 1]]
27
28(sol,path) = fitModelScaled 1E-4 1E-4 20 (expModel, expModelDer) dat [1,0,0]
29@
30
31>>> path
32(6><5)
33 [ 1.0, 76.45780563978782, 1.6465931240727802, 1.8147715267618197e-2, 0.6465931240727797
34 , 2.0, 37.683816318260355, 2.858760367632973, 8.092094813253975e-2, 1.4479636296208662
35 , 3.0, 9.5807893736187, 4.948995119561291, 0.11942927999921617, 1.0945766509238248
36 , 4.0, 5.630494933603935, 5.021755718065913, 0.10287787128056883, 1.0338835440862608
37 , 5.0, 5.443976278682909, 5.045204331329302, 0.10405523433131504, 1.019416067207375
38 , 6.0, 5.4439736648994685, 5.045357818922331, 0.10404905846029407, 1.0192487112786812 ]
39>>> sol
40[(5.045357818922331,6.027976702418132e-2),
41(0.10404905846029407,3.157045047172834e-3),
42(1.0192487112786812,3.782067731353722e-2)]
43
44-}
45-----------------------------------------------------------------------------
46
47module Numeric.GSL.Fitting (
48 -- * Levenberg-Marquardt
49 nlFitting, FittingMethod(..),
50 -- * Utilities
51 fitModelScaled, fitModel
52) where
53
54import Data.Packed.Internal
55import Numeric.LinearAlgebra
56import Numeric.GSL.Internal
57
58import Foreign.Ptr(FunPtr, freeHaskellFunPtr)
59import Foreign.C.Types
60import System.IO.Unsafe(unsafePerformIO)
61
62-------------------------------------------------------------------------
63
64data FittingMethod = LevenbergMarquardtScaled -- ^ Interface to gsl_multifit_fdfsolver_lmsder. This is a robust and efficient version of the Levenberg-Marquardt algorithm as implemented in the scaled lmder routine in minpack. Minpack was written by Jorge J. More, Burton S. Garbow and Kenneth E. Hillstrom.
65 | LevenbergMarquardt -- ^ This is an unscaled version of the lmder algorithm. The elements of the diagonal scaling matrix D are set to 1. This algorithm may be useful in circumstances where the scaled version of lmder converges too slowly, or the function is already scaled appropriately.
66 deriving (Enum,Eq,Show,Bounded)
67
68
69-- | Nonlinear multidimensional least-squares fitting.
70nlFitting :: FittingMethod
71 -> Double -- ^ absolute tolerance
72 -> Double -- ^ relative tolerance
73 -> Int -- ^ maximum number of iterations allowed
74 -> (Vector Double -> Vector Double) -- ^ function to be minimized
75 -> (Vector Double -> Matrix Double) -- ^ Jacobian
76 -> Vector Double -- ^ starting point
77 -> (Vector Double, Matrix Double) -- ^ solution vector and optimization path
78
79nlFitting method epsabs epsrel maxit fun jac xinit = nlFitGen (fi (fromEnum method)) fun jac xinit epsabs epsrel maxit
80
81nlFitGen m f jac xiv epsabs epsrel maxit = unsafePerformIO $ do
82 let p = dim xiv
83 n = dim (f xiv)
84 fp <- mkVecVecfun (aux_vTov (checkdim1 n p . f))
85 jp <- mkVecMatfun (aux_vTom (checkdim2 n p . jac))
86 rawpath <- createMatrix RowMajor maxit (2+p)
87 app2 (c_nlfit m fp jp epsabs epsrel (fi maxit) (fi n)) vec xiv mat rawpath "c_nlfit"
88 let it = round (rawpath @@> (maxit-1,0))
89 path = takeRows it rawpath
90 [sol] = toRows $ dropRows (it-1) path
91 freeHaskellFunPtr fp
92 freeHaskellFunPtr jp
93 return (subVector 2 p sol, path)
94
95foreign import ccall safe "nlfit"
96 c_nlfit:: CInt -> FunPtr TVV -> FunPtr TVM -> Double -> Double -> CInt -> CInt -> TVM
97
98-------------------------------------------------------
99
100checkdim1 n _p v
101 | dim v == n = v
102 | otherwise = error $ "Error: "++ show n
103 ++ " components expected in the result of the function supplied to nlFitting"
104
105checkdim2 n p m
106 | rows m == n && cols m == p = m
107 | otherwise = error $ "Error: "++ show n ++ "x" ++ show p
108 ++ " Jacobian expected in nlFitting"
109
110------------------------------------------------------------
111
112err (model,deriv) dat vsol = zip sol errs where
113 sol = toList vsol
114 c = max 1 (chi/sqrt (fromIntegral dof))
115 dof = length dat - (rows cov)
116 chi = norm2 (fromList $ cost (resMs model) dat sol)
117 js = fromLists $ jacobian (resDs deriv) dat sol
118 cov = inv $ trans js <> js
119 errs = toList $ scalar c * sqrt (takeDiag cov)
120
121
122
123-- | Higher level interface to 'nlFitting' 'LevenbergMarquardtScaled'. The optimization function and
124-- Jacobian are automatically built from a model f vs x = y and its derivatives, and a list of
125-- instances (x, (y,sigma)) to be fitted.
126
127fitModelScaled
128 :: Double -- ^ absolute tolerance
129 -> Double -- ^ relative tolerance
130 -> Int -- ^ maximum number of iterations allowed
131 -> ([Double] -> x -> [Double], [Double] -> x -> [[Double]]) -- ^ (model, derivatives)
132 -> [(x, ([Double], Double))] -- ^ instances
133 -> [Double] -- ^ starting point
134 -> ([(Double, Double)], Matrix Double) -- ^ (solution, error) and optimization path
135fitModelScaled epsabs epsrel maxit (model,deriv) dt xin = (err (model,deriv) dt sol, path) where
136 (sol,path) = nlFitting LevenbergMarquardtScaled epsabs epsrel maxit
137 (fromList . cost (resMs model) dt . toList)
138 (fromLists . jacobian (resDs deriv) dt . toList)
139 (fromList xin)
140
141
142
143-- | Higher level interface to 'nlFitting' 'LevenbergMarquardt'. The optimization function and
144-- Jacobian are automatically built from a model f vs x = y and its derivatives, and a list of
145-- instances (x,y) to be fitted.
146
147fitModel :: Double -- ^ absolute tolerance
148 -> Double -- ^ relative tolerance
149 -> Int -- ^ maximum number of iterations allowed
150 -> ([Double] -> x -> [Double], [Double] -> x -> [[Double]]) -- ^ (model, derivatives)
151 -> [(x, [Double])] -- ^ instances
152 -> [Double] -- ^ starting point
153 -> ([Double], Matrix Double) -- ^ solution and optimization path
154fitModel epsabs epsrel maxit (model,deriv) dt xin = (toList sol, path) where
155 (sol,path) = nlFitting LevenbergMarquardt epsabs epsrel maxit
156 (fromList . cost (resM model) dt . toList)
157 (fromLists . jacobian (resD deriv) dt . toList)
158 (fromList xin)
159
160cost model ds vs = concatMap (model vs) ds
161
162jacobian modelDer ds vs = concatMap (modelDer vs) ds
163
164-- | Model-to-residual for association pairs with sigma, to be used with 'fitModel'.
165resMs :: ([Double] -> x -> [Double]) -> [Double] -> (x, ([Double], Double)) -> [Double]
166resMs m v = \(x,(ys,s)) -> zipWith (g s) (m v x) ys where g s a b = (a-b)/s
167
168-- | Associated derivative for 'resMs'.
169resDs :: ([Double] -> x -> [[Double]]) -> [Double] -> (x, ([Double], Double)) -> [[Double]]
170resDs m v = \(x,(_,s)) -> map (map (/s)) (m v x)
171
172-- | Model-to-residual for association pairs, to be used with 'fitModel'. It is equivalent
173-- to 'resMs' with all sigmas = 1.
174resM :: ([Double] -> x -> [Double]) -> [Double] -> (x, [Double]) -> [Double]
175resM m v = \(x,ys) -> zipWith g (m v x) ys where g a b = a-b
176
177-- | Associated derivative for 'resM'.
178resD :: ([Double] -> x -> [[Double]]) -> [Double] -> (x, [Double]) -> [[Double]]
179resD m v = \(x,_) -> m v x
diff --git a/packages/hmatrix/src/Numeric/GSL/Fourier.hs b/packages/hmatrix/src/Numeric/GSL/Fourier.hs
new file mode 100644
index 0000000..86aedd6
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/Fourier.hs
@@ -0,0 +1,47 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2-----------------------------------------------------------------------------
3{- |
4Module : Numeric.GSL.Fourier
5Copyright : (c) Alberto Ruiz 2006
6License : GPL-style
7
8Maintainer : Alberto Ruiz (aruiz at um dot es)
9Stability : provisional
10Portability : uses ffi
11
12Fourier Transform.
13
14<http://www.gnu.org/software/gsl/manual/html_node/Fast-Fourier-Transforms.html#Fast-Fourier-Transforms>
15
16-}
17-----------------------------------------------------------------------------
18module Numeric.GSL.Fourier (
19 fft,
20 ifft
21) where
22
23import Data.Packed.Internal
24import Data.Complex
25import Foreign.C.Types
26import System.IO.Unsafe (unsafePerformIO)
27
28genfft code v = unsafePerformIO $ do
29 r <- createVector (dim v)
30 app2 (c_fft code) vec v vec r "fft"
31 return r
32
33foreign import ccall unsafe "gsl-aux.h fft" c_fft :: CInt -> TCVCV
34
35
36{- | Fast 1D Fourier transform of a 'Vector' @(@'Complex' 'Double'@)@ using /gsl_fft_complex_forward/. It uses the same scaling conventions as GNU Octave.
37
38>>> fft (fromList [1,2,3,4])
39fromList [10.0 :+ 0.0,(-2.0) :+ 2.0,(-2.0) :+ 0.0,(-2.0) :+ (-2.0)]
40
41-}
42fft :: Vector (Complex Double) -> Vector (Complex Double)
43fft = genfft 0
44
45-- | The inverse of 'fft', using /gsl_fft_complex_inverse/.
46ifft :: Vector (Complex Double) -> Vector (Complex Double)
47ifft = genfft 1
diff --git a/packages/hmatrix/src/Numeric/GSL/Integration.hs b/packages/hmatrix/src/Numeric/GSL/Integration.hs
new file mode 100644
index 0000000..5f0a415
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/Integration.hs
@@ -0,0 +1,250 @@
1{-# OPTIONS #-}
2-----------------------------------------------------------------------------
3{- |
4Module : Numeric.GSL.Integration
5Copyright : (c) Alberto Ruiz 2006
6License : GPL-style
7
8Maintainer : Alberto Ruiz (aruiz at um dot es)
9Stability : provisional
10Portability : uses ffi
11
12Numerical integration routines.
13
14<http://www.gnu.org/software/gsl/manual/html_node/Numerical-Integration.html#Numerical-Integration>
15-}
16-----------------------------------------------------------------------------
17
18module Numeric.GSL.Integration (
19 integrateQNG,
20 integrateQAGS,
21 integrateQAGI,
22 integrateQAGIU,
23 integrateQAGIL,
24 integrateCQUAD
25) where
26
27import Foreign.C.Types
28import Foreign.Marshal.Alloc(malloc, free)
29import Foreign.Ptr(Ptr, FunPtr, freeHaskellFunPtr)
30import Foreign.Storable(peek)
31import Data.Packed.Internal(check,(//))
32import System.IO.Unsafe(unsafePerformIO)
33
34eps = 1e-12
35
36{- | conversion of Haskell functions into function pointers that can be used in the C side
37-}
38foreign import ccall safe "wrapper" mkfun:: (Double -> Ptr() -> Double) -> IO( FunPtr (Double -> Ptr() -> Double))
39
40--------------------------------------------------------------------
41{- | Numerical integration using /gsl_integration_qags/ (adaptive integration with singularities). For example:
42
43>>> let quad = integrateQAGS 1E-9 1000
44>>> let f a x = x**(-0.5) * log (a*x)
45>>> quad (f 1) 0 1
46(-3.999999999999974,4.871658632055187e-13)
47
48-}
49
50integrateQAGS :: Double -- ^ precision (e.g. 1E-9)
51 -> Int -- ^ size of auxiliary workspace (e.g. 1000)
52 -> (Double -> Double) -- ^ function to be integrated on the interval (a,b)
53 -> Double -- ^ a
54 -> Double -- ^ b
55 -> (Double, Double) -- ^ result of the integration and error
56integrateQAGS prec n f a b = unsafePerformIO $ do
57 r <- malloc
58 e <- malloc
59 fp <- mkfun (\x _ -> f x)
60 c_integrate_qags fp a b eps prec (fromIntegral n) r e // check "integrate_qags"
61 vr <- peek r
62 ve <- peek e
63 let result = (vr,ve)
64 free r
65 free e
66 freeHaskellFunPtr fp
67 return result
68
69foreign import ccall safe "integrate_qags" c_integrate_qags
70 :: FunPtr (Double-> Ptr() -> Double) -> Double -> Double
71 -> Double -> Double -> CInt -> Ptr Double -> Ptr Double -> IO CInt
72
73-----------------------------------------------------------------
74{- | Numerical integration using /gsl_integration_qng/ (useful for fast integration of smooth functions). For example:
75
76>>> let quad = integrateQNG 1E-6
77>>> quad (\x -> 4/(1+x*x)) 0 1
78(3.141592653589793,3.487868498008632e-14)
79
80-}
81
82integrateQNG :: Double -- ^ precision (e.g. 1E-9)
83 -> (Double -> Double) -- ^ function to be integrated on the interval (a,b)
84 -> Double -- ^ a
85 -> Double -- ^ b
86 -> (Double, Double) -- ^ result of the integration and error
87integrateQNG prec f a b = unsafePerformIO $ do
88 r <- malloc
89 e <- malloc
90 fp <- mkfun (\x _ -> f x)
91 c_integrate_qng fp a b eps prec r e // check "integrate_qng"
92 vr <- peek r
93 ve <- peek e
94 let result = (vr,ve)
95 free r
96 free e
97 freeHaskellFunPtr fp
98 return result
99
100foreign import ccall safe "integrate_qng" c_integrate_qng
101 :: FunPtr (Double-> Ptr() -> Double) -> Double -> Double
102 -> Double -> Double -> Ptr Double -> Ptr Double -> IO CInt
103
104--------------------------------------------------------------------
105{- | Numerical integration using /gsl_integration_qagi/ (integration over the infinite integral -Inf..Inf using QAGS).
106For example:
107
108>>> let quad = integrateQAGI 1E-9 1000
109>>> let f a x = exp(-a * x^2)
110>>> quad (f 0.5)
111(2.5066282746310002,6.229215880648858e-11)
112
113-}
114
115integrateQAGI :: Double -- ^ precision (e.g. 1E-9)
116 -> Int -- ^ size of auxiliary workspace (e.g. 1000)
117 -> (Double -> Double) -- ^ function to be integrated on the interval (-Inf,Inf)
118 -> (Double, Double) -- ^ result of the integration and error
119integrateQAGI prec n f = unsafePerformIO $ do
120 r <- malloc
121 e <- malloc
122 fp <- mkfun (\x _ -> f x)
123 c_integrate_qagi fp eps prec (fromIntegral n) r e // check "integrate_qagi"
124 vr <- peek r
125 ve <- peek e
126 let result = (vr,ve)
127 free r
128 free e
129 freeHaskellFunPtr fp
130 return result
131
132foreign import ccall safe "integrate_qagi" c_integrate_qagi
133 :: FunPtr (Double-> Ptr() -> Double) -> Double -> Double
134 -> CInt -> Ptr Double -> Ptr Double -> IO CInt
135
136--------------------------------------------------------------------
137{- | Numerical integration using /gsl_integration_qagiu/ (integration over the semi-infinite integral a..Inf).
138For example:
139
140>>> let quad = integrateQAGIU 1E-9 1000
141>>> let f a x = exp(-a * x^2)
142>>> quad (f 0.5) 0
143(1.2533141373155001,3.114607940324429e-11)
144
145-}
146
147integrateQAGIU :: Double -- ^ precision (e.g. 1E-9)
148 -> Int -- ^ size of auxiliary workspace (e.g. 1000)
149 -> (Double -> Double) -- ^ function to be integrated on the interval (a,Inf)
150 -> Double -- ^ a
151 -> (Double, Double) -- ^ result of the integration and error
152integrateQAGIU prec n f a = unsafePerformIO $ do
153 r <- malloc
154 e <- malloc
155 fp <- mkfun (\x _ -> f x)
156 c_integrate_qagiu fp a eps prec (fromIntegral n) r e // check "integrate_qagiu"
157 vr <- peek r
158 ve <- peek e
159 let result = (vr,ve)
160 free r
161 free e
162 freeHaskellFunPtr fp
163 return result
164
165foreign import ccall safe "integrate_qagiu" c_integrate_qagiu
166 :: FunPtr (Double-> Ptr() -> Double) -> Double -> Double
167 -> Double -> CInt -> Ptr Double -> Ptr Double -> IO CInt
168
169--------------------------------------------------------------------
170{- | Numerical integration using /gsl_integration_qagil/ (integration over the semi-infinite integral -Inf..b).
171For example:
172
173>>> let quad = integrateQAGIL 1E-9 1000
174>>> let f a x = exp(-a * x^2)
175>>> quad (f 0.5) 0
176(1.2533141373155001,3.114607940324429e-11)
177
178-}
179
180integrateQAGIL :: Double -- ^ precision (e.g. 1E-9)
181 -> Int -- ^ size of auxiliary workspace (e.g. 1000)
182 -> (Double -> Double) -- ^ function to be integrated on the interval (a,Inf)
183 -> Double -- ^ b
184 -> (Double, Double) -- ^ result of the integration and error
185integrateQAGIL prec n f b = unsafePerformIO $ do
186 r <- malloc
187 e <- malloc
188 fp <- mkfun (\x _ -> f x)
189 c_integrate_qagil fp b eps prec (fromIntegral n) r e // check "integrate_qagil"
190 vr <- peek r
191 ve <- peek e
192 let result = (vr,ve)
193 free r
194 free e
195 freeHaskellFunPtr fp
196 return result
197
198foreign import ccall safe "gsl-aux.h integrate_qagil" c_integrate_qagil
199 :: FunPtr (Double-> Ptr() -> Double) -> Double -> Double
200 -> Double -> CInt -> Ptr Double -> Ptr Double -> IO CInt
201
202
203--------------------------------------------------------------------
204{- | Numerical integration using /gsl_integration_cquad/ (quadrature
205for general integrands). From the GSL manual:
206
207@CQUAD is a new doubly-adaptive general-purpose quadrature routine
208which can handle most types of singularities, non-numerical function
209values such as Inf or NaN, as well as some divergent integrals. It
210generally requires more function evaluations than the integration
211routines in QUADPACK, yet fails less often for difficult integrands.@
212
213For example:
214
215>>> let quad = integrateCQUAD 1E-12 1000
216>>> let f a x = exp(-a * x^2)
217>>> quad (f 0.5) 2 5
218(5.7025405463957006e-2,9.678874441303705e-16,95)
219
220Unlike other quadrature methods, integrateCQUAD also returns the
221number of function evaluations required.
222
223-}
224
225integrateCQUAD :: Double -- ^ precision (e.g. 1E-9)
226 -> Int -- ^ size of auxiliary workspace (e.g. 1000)
227 -> (Double -> Double) -- ^ function to be integrated on the interval (a, b)
228 -> Double -- ^ a
229 -> Double -- ^ b
230 -> (Double, Double, Int) -- ^ result of the integration, error and number of function evaluations performed
231integrateCQUAD prec n f a b = unsafePerformIO $ do
232 r <- malloc
233 e <- malloc
234 neval <- malloc
235 fp <- mkfun (\x _ -> f x)
236 c_integrate_cquad fp a b eps prec (fromIntegral n) r e neval // check "integrate_cquad"
237 vr <- peek r
238 ve <- peek e
239 vneval <- peek neval
240 let result = (vr,ve,vneval)
241 free r
242 free e
243 free neval
244 freeHaskellFunPtr fp
245 return result
246
247foreign import ccall safe "integrate_cquad" c_integrate_cquad
248 :: FunPtr (Double-> Ptr() -> Double) -> Double -> Double
249 -> Double -> Double -> CInt -> Ptr Double -> Ptr Double -> Ptr Int -> IO CInt
250
diff --git a/packages/hmatrix/src/Numeric/GSL/Internal.hs b/packages/hmatrix/src/Numeric/GSL/Internal.hs
new file mode 100644
index 0000000..69a9750
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/Internal.hs
@@ -0,0 +1,76 @@
1-- Module : Numeric.GSL.Internal
2-- Copyright : (c) Alberto Ruiz 2009
3-- License : GPL
4--
5-- Maintainer : Alberto Ruiz (aruiz at um dot es)
6-- Stability : provisional
7-- Portability : uses ffi
8--
9-- Auxiliary functions.
10--
11-- #hide
12
13module Numeric.GSL.Internal where
14
15import Data.Packed.Internal
16
17import Foreign.Marshal.Array(copyArray)
18import Foreign.Ptr(Ptr, FunPtr)
19import Foreign.C.Types
20import System.IO.Unsafe(unsafePerformIO)
21
22iv :: (Vector Double -> Double) -> (CInt -> Ptr Double -> Double)
23iv f n p = f (createV (fromIntegral n) copy "iv") where
24 copy n' q = do
25 copyArray q p (fromIntegral n')
26 return 0
27
28-- | conversion of Haskell functions into function pointers that can be used in the C side
29foreign import ccall safe "wrapper"
30 mkVecfun :: (CInt -> Ptr Double -> Double)
31 -> IO( FunPtr (CInt -> Ptr Double -> Double))
32
33foreign import ccall safe "wrapper"
34 mkVecVecfun :: TVV -> IO (FunPtr TVV)
35
36foreign import ccall safe "wrapper"
37 mkDoubleVecVecfun :: (Double -> TVV) -> IO (FunPtr (Double -> TVV))
38
39foreign import ccall safe "wrapper"
40 mkDoublefun :: (Double -> Double) -> IO (FunPtr (Double -> Double))
41
42aux_vTov :: (Vector Double -> Vector Double) -> TVV
43aux_vTov f n p nr r = g where
44 v = f x
45 x = createV (fromIntegral n) copy "aux_vTov"
46 copy n' q = do
47 copyArray q p (fromIntegral n')
48 return 0
49 g = do unsafeWith v $ \p' -> copyArray r p' (fromIntegral nr)
50 return 0
51
52foreign import ccall safe "wrapper"
53 mkVecMatfun :: TVM -> IO (FunPtr TVM)
54
55foreign import ccall safe "wrapper"
56 mkDoubleVecMatfun :: (Double -> TVM) -> IO (FunPtr (Double -> TVM))
57
58aux_vTom :: (Vector Double -> Matrix Double) -> TVM
59aux_vTom f n p rr cr r = g where
60 v = flatten $ f x
61 x = createV (fromIntegral n) copy "aux_vTov"
62 copy n' q = do
63 copyArray q p (fromIntegral n')
64 return 0
65 g = do unsafeWith v $ \p' -> copyArray r p' (fromIntegral $ rr*cr)
66 return 0
67
68createV n fun msg = unsafePerformIO $ do
69 r <- createVector n
70 app1 fun vec r msg
71 return r
72
73createMIO r c fun msg = do
74 res <- createMatrix RowMajor r c
75 app1 fun mat res msg
76 return res
diff --git a/packages/hmatrix/src/Numeric/GSL/Minimization.hs b/packages/hmatrix/src/Numeric/GSL/Minimization.hs
new file mode 100644
index 0000000..1879dab
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/Minimization.hs
@@ -0,0 +1,227 @@
1{-# LANGUAGE ForeignFunctionInterface #-}
2-----------------------------------------------------------------------------
3{- |
4Module : Numeric.GSL.Minimization
5Copyright : (c) Alberto Ruiz 2006-9
6License : GPL-style
7
8Maintainer : Alberto Ruiz (aruiz at um dot es)
9Stability : provisional
10Portability : uses ffi
11
12Minimization of a multidimensional function using some of the algorithms described in:
13
14<http://www.gnu.org/software/gsl/manual/html_node/Multidimensional-Minimization.html>
15
16The example in the GSL manual:
17
18@
19f [x,y] = 10*(x-1)^2 + 20*(y-2)^2 + 30
20
21main = do
22 let (s,p) = minimize NMSimplex2 1E-2 30 [1,1] f [5,7]
23 print s
24 print p
25@
26
27>>> main
28[0.9920430849306288,1.9969168063253182]
29 0.000 512.500 1.130 6.500 5.000
30 1.000 290.625 1.409 5.250 4.000
31 2.000 290.625 1.409 5.250 4.000
32 3.000 252.500 1.409 5.500 1.000
33 ...
3422.000 30.001 0.013 0.992 1.997
3523.000 30.001 0.008 0.992 1.997
36
37The path to the solution can be graphically shown by means of:
38
39@'Graphics.Plot.mplot' $ drop 3 ('toColumns' p)@
40
41Taken from the GSL manual:
42
43The vector Broyden-Fletcher-Goldfarb-Shanno (BFGS) algorithm is a quasi-Newton method which builds up an approximation to the second derivatives of the function f using the difference between successive gradient vectors. By combining the first and second derivatives the algorithm is able to take Newton-type steps towards the function minimum, assuming quadratic behavior in that region.
44
45The bfgs2 version of this minimizer is the most efficient version available, and is a faithful implementation of the line minimization scheme described in Fletcher's Practical Methods of Optimization, Algorithms 2.6.2 and 2.6.4. It supercedes the original bfgs routine and requires substantially fewer function and gradient evaluations. The user-supplied tolerance tol corresponds to the parameter \sigma used by Fletcher. A value of 0.1 is recommended for typical use (larger values correspond to less accurate line searches).
46
47The nmsimplex2 version is a new O(N) implementation of the earlier O(N^2) nmsimplex minimiser. It calculates the size of simplex as the rms distance of each vertex from the center rather than the mean distance, which has the advantage of allowing a linear update.
48
49-}
50
51-----------------------------------------------------------------------------
52module Numeric.GSL.Minimization (
53 minimize, minimizeV, MinimizeMethod(..),
54 minimizeD, minimizeVD, MinimizeMethodD(..),
55 uniMinimize, UniMinimizeMethod(..),
56
57 minimizeNMSimplex,
58 minimizeConjugateGradient,
59 minimizeVectorBFGS2
60) where
61
62
63import Data.Packed.Internal
64import Data.Packed.Matrix
65import Numeric.GSL.Internal
66
67import Foreign.Ptr(Ptr, FunPtr, freeHaskellFunPtr)
68import Foreign.C.Types
69import System.IO.Unsafe(unsafePerformIO)
70
71------------------------------------------------------------------------
72
73{-# DEPRECATED minimizeNMSimplex "use minimize NMSimplex2 eps maxit sizes f xi" #-}
74minimizeNMSimplex f xi szs eps maxit = minimize NMSimplex eps maxit szs f xi
75
76{-# DEPRECATED minimizeConjugateGradient "use minimizeD ConjugateFR eps maxit step tol f g xi" #-}
77minimizeConjugateGradient step tol eps maxit f g xi = minimizeD ConjugateFR eps maxit step tol f g xi
78
79{-# DEPRECATED minimizeVectorBFGS2 "use minimizeD VectorBFGS2 eps maxit step tol f g xi" #-}
80minimizeVectorBFGS2 step tol eps maxit f g xi = minimizeD VectorBFGS2 eps maxit step tol f g xi
81
82-------------------------------------------------------------------------
83
84data UniMinimizeMethod = GoldenSection
85 | BrentMini
86 | QuadGolden
87 deriving (Enum, Eq, Show, Bounded)
88
89-- | Onedimensional minimization.
90
91uniMinimize :: UniMinimizeMethod -- ^ The method used.
92 -> Double -- ^ desired precision of the solution
93 -> Int -- ^ maximum number of iterations allowed
94 -> (Double -> Double) -- ^ function to minimize
95 -> Double -- ^ guess for the location of the minimum
96 -> Double -- ^ lower bound of search interval
97 -> Double -- ^ upper bound of search interval
98 -> (Double, Matrix Double) -- ^ solution and optimization path
99
100uniMinimize method epsrel maxit fun xmin xl xu = uniMinimizeGen (fi (fromEnum method)) fun xmin xl xu epsrel maxit
101
102uniMinimizeGen m f xmin xl xu epsrel maxit = unsafePerformIO $ do
103 fp <- mkDoublefun f
104 rawpath <- createMIO maxit 4
105 (c_uniMinize m fp epsrel (fi maxit) xmin xl xu)
106 "uniMinimize"
107 let it = round (rawpath @@> (maxit-1,0))
108 path = takeRows it rawpath
109 [sol] = toLists $ dropRows (it-1) path
110 freeHaskellFunPtr fp
111 return (sol !! 1, path)
112
113
114foreign import ccall safe "uniMinimize"
115 c_uniMinize:: CInt -> FunPtr (Double -> Double) -> Double -> CInt -> Double -> Double -> Double -> TM
116
117data MinimizeMethod = NMSimplex
118 | NMSimplex2
119 deriving (Enum,Eq,Show,Bounded)
120
121-- | Minimization without derivatives
122minimize :: MinimizeMethod
123 -> Double -- ^ desired precision of the solution (size test)
124 -> Int -- ^ maximum number of iterations allowed
125 -> [Double] -- ^ sizes of the initial search box
126 -> ([Double] -> Double) -- ^ function to minimize
127 -> [Double] -- ^ starting point
128 -> ([Double], Matrix Double) -- ^ solution vector and optimization path
129
130-- | Minimization without derivatives (vector version)
131minimizeV :: MinimizeMethod
132 -> Double -- ^ desired precision of the solution (size test)
133 -> Int -- ^ maximum number of iterations allowed
134 -> Vector Double -- ^ sizes of the initial search box
135 -> (Vector Double -> Double) -- ^ function to minimize
136 -> Vector Double -- ^ starting point
137 -> (Vector Double, Matrix Double) -- ^ solution vector and optimization path
138
139minimize method eps maxit sz f xi = v2l $ minimizeV method eps maxit (fromList sz) (f.toList) (fromList xi)
140 where v2l (v,m) = (toList v, m)
141
142ww2 w1 o1 w2 o2 f = w1 o1 $ \a1 -> w2 o2 $ \a2 -> f a1 a2
143
144minimizeV method eps maxit szv f xiv = unsafePerformIO $ do
145 let n = dim xiv
146 fp <- mkVecfun (iv f)
147 rawpath <- ww2 vec xiv vec szv $ \xiv' szv' ->
148 createMIO maxit (n+3)
149 (c_minimize (fi (fromEnum method)) fp eps (fi maxit) // xiv' // szv')
150 "minimize"
151 let it = round (rawpath @@> (maxit-1,0))
152 path = takeRows it rawpath
153 sol = cdat $ dropColumns 3 $ dropRows (it-1) path
154 freeHaskellFunPtr fp
155 return (sol, path)
156
157
158foreign import ccall safe "gsl-aux.h minimize"
159 c_minimize:: CInt -> FunPtr (CInt -> Ptr Double -> Double) -> Double -> CInt -> TVVM
160
161----------------------------------------------------------------------------------
162
163
164data MinimizeMethodD = ConjugateFR
165 | ConjugatePR
166 | VectorBFGS
167 | VectorBFGS2
168 | SteepestDescent
169 deriving (Enum,Eq,Show,Bounded)
170
171-- | Minimization with derivatives.
172minimizeD :: MinimizeMethodD
173 -> Double -- ^ desired precision of the solution (gradient test)
174 -> Int -- ^ maximum number of iterations allowed
175 -> Double -- ^ size of the first trial step
176 -> Double -- ^ tol (precise meaning depends on method)
177 -> ([Double] -> Double) -- ^ function to minimize
178 -> ([Double] -> [Double]) -- ^ gradient
179 -> [Double] -- ^ starting point
180 -> ([Double], Matrix Double) -- ^ solution vector and optimization path
181
182-- | Minimization with derivatives (vector version)
183minimizeVD :: MinimizeMethodD
184 -> Double -- ^ desired precision of the solution (gradient test)
185 -> Int -- ^ maximum number of iterations allowed
186 -> Double -- ^ size of the first trial step
187 -> Double -- ^ tol (precise meaning depends on method)
188 -> (Vector Double -> Double) -- ^ function to minimize
189 -> (Vector Double -> Vector Double) -- ^ gradient
190 -> Vector Double -- ^ starting point
191 -> (Vector Double, Matrix Double) -- ^ solution vector and optimization path
192
193minimizeD method eps maxit istep tol f df xi = v2l $ minimizeVD
194 method eps maxit istep tol (f.toList) (fromList.df.toList) (fromList xi)
195 where v2l (v,m) = (toList v, m)
196
197
198minimizeVD method eps maxit istep tol f df xiv = unsafePerformIO $ do
199 let n = dim xiv
200 f' = f
201 df' = (checkdim1 n . df)
202 fp <- mkVecfun (iv f')
203 dfp <- mkVecVecfun (aux_vTov df')
204 rawpath <- vec xiv $ \xiv' ->
205 createMIO maxit (n+2)
206 (c_minimizeD (fi (fromEnum method)) fp dfp istep tol eps (fi maxit) // xiv')
207 "minimizeD"
208 let it = round (rawpath @@> (maxit-1,0))
209 path = takeRows it rawpath
210 sol = cdat $ dropColumns 2 $ dropRows (it-1) path
211 freeHaskellFunPtr fp
212 freeHaskellFunPtr dfp
213 return (sol,path)
214
215foreign import ccall safe "gsl-aux.h minimizeD"
216 c_minimizeD :: CInt
217 -> FunPtr (CInt -> Ptr Double -> Double)
218 -> FunPtr TVV
219 -> Double -> Double -> Double -> CInt
220 -> TVM
221
222---------------------------------------------------------------------
223
224checkdim1 n v
225 | dim v == n = v
226 | otherwise = error $ "Error: "++ show n
227 ++ " components expected in the result of the gradient supplied to minimizeD"
diff --git a/packages/hmatrix/src/Numeric/GSL/ODE.hs b/packages/hmatrix/src/Numeric/GSL/ODE.hs
new file mode 100644
index 0000000..9a29085
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/ODE.hs
@@ -0,0 +1,138 @@
1{- |
2Module : Numeric.GSL.ODE
3Copyright : (c) Alberto Ruiz 2010
4License : GPL
5
6Maintainer : Alberto Ruiz (aruiz at um dot es)
7Stability : provisional
8Portability : uses ffi
9
10Solution of ordinary differential equation (ODE) initial value problems.
11
12<http://www.gnu.org/software/gsl/manual/html_node/Ordinary-Differential-Equations.html>
13
14A simple example:
15
16@
17import Numeric.GSL.ODE
18import Numeric.LinearAlgebra
19import Numeric.LinearAlgebra.Util(mplot)
20
21xdot t [x,v] = [v, -0.95*x - 0.1*v]
22
23ts = linspace 100 (0,20 :: Double)
24
25sol = odeSolve xdot [10,0] ts
26
27main = mplot (ts : toColumns sol)
28@
29
30-}
31-----------------------------------------------------------------------------
32
33module Numeric.GSL.ODE (
34 odeSolve, odeSolveV, ODEMethod(..), Jacobian
35) where
36
37import Data.Packed.Internal
38import Numeric.GSL.Internal
39
40import Foreign.Ptr(FunPtr, nullFunPtr, freeHaskellFunPtr)
41import Foreign.C.Types
42import System.IO.Unsafe(unsafePerformIO)
43
44-------------------------------------------------------------------------
45
46type Jacobian = Double -> Vector Double -> Matrix Double
47
48-- | Stepping functions
49data ODEMethod = RK2 -- ^ Embedded Runge-Kutta (2, 3) method.
50 | RK4 -- ^ 4th order (classical) Runge-Kutta. The error estimate is obtained by halving the step-size. For more efficient estimate of the error, use the embedded methods.
51 | RKf45 -- ^ Embedded Runge-Kutta-Fehlberg (4, 5) method. This method is a good general-purpose integrator.
52 | RKck -- ^ Embedded Runge-Kutta Cash-Karp (4, 5) method.
53 | RK8pd -- ^ Embedded Runge-Kutta Prince-Dormand (8,9) method.
54 | RK2imp Jacobian -- ^ Implicit 2nd order Runge-Kutta at Gaussian points.
55 | RK4imp Jacobian -- ^ Implicit 4th order Runge-Kutta at Gaussian points.
56 | BSimp Jacobian -- ^ Implicit Bulirsch-Stoer method of Bader and Deuflhard. The method is generally suitable for stiff problems.
57 | RK1imp Jacobian -- ^ Implicit Gaussian first order Runge-Kutta. Also known as implicit Euler or backward Euler method. Error estimation is carried out by the step doubling method.
58 | MSAdams -- ^ A variable-coefficient linear multistep Adams method in Nordsieck form. This stepper uses explicit Adams-Bashforth (predictor) and implicit Adams-Moulton (corrector) methods in P(EC)^m functional iteration mode. Method order varies dynamically between 1 and 12.
59 | MSBDF Jacobian -- ^ A variable-coefficient linear multistep backward differentiation formula (BDF) method in Nordsieck form. This stepper uses the explicit BDF formula as predictor and implicit BDF formula as corrector. A modified Newton iteration method is used to solve the system of non-linear equations. Method order varies dynamically between 1 and 5. The method is generally suitable for stiff problems.
60
61
62-- | A version of 'odeSolveV' with reasonable default parameters and system of equations defined using lists.
63odeSolve
64 :: (Double -> [Double] -> [Double]) -- ^ xdot(t,x)
65 -> [Double] -- ^ initial conditions
66 -> Vector Double -- ^ desired solution times
67 -> Matrix Double -- ^ solution
68odeSolve xdot xi ts = odeSolveV RKf45 hi epsAbs epsRel (l2v xdot) (fromList xi) ts
69 where hi = (ts@>1 - ts@>0)/100
70 epsAbs = 1.49012e-08
71 epsRel = 1.49012e-08
72 l2v f = \t -> fromList . f t . toList
73
74-- | Evolution of the system with adaptive step-size control.
75odeSolveV
76 :: ODEMethod
77 -> Double -- ^ initial step size
78 -> Double -- ^ absolute tolerance for the state vector
79 -> Double -- ^ relative tolerance for the state vector
80 -> (Double -> Vector Double -> Vector Double) -- ^ xdot(t,x)
81 -> Vector Double -- ^ initial conditions
82 -> Vector Double -- ^ desired solution times
83 -> Matrix Double -- ^ solution
84odeSolveV RK2 = odeSolveV' 0 Nothing
85odeSolveV RK4 = odeSolveV' 1 Nothing
86odeSolveV RKf45 = odeSolveV' 2 Nothing
87odeSolveV RKck = odeSolveV' 3 Nothing
88odeSolveV RK8pd = odeSolveV' 4 Nothing
89odeSolveV (RK2imp jac) = odeSolveV' 5 (Just jac)
90odeSolveV (RK4imp jac) = odeSolveV' 6 (Just jac)
91odeSolveV (BSimp jac) = odeSolveV' 7 (Just jac)
92odeSolveV (RK1imp jac) = odeSolveV' 8 (Just jac)
93odeSolveV MSAdams = odeSolveV' 9 Nothing
94odeSolveV (MSBDF jac) = odeSolveV' 10 (Just jac)
95
96
97odeSolveV'
98 :: CInt
99 -> Maybe (Double -> Vector Double -> Matrix Double) -- ^ optional jacobian
100 -> Double -- ^ initial step size
101 -> Double -- ^ absolute tolerance for the state vector
102 -> Double -- ^ relative tolerance for the state vector
103 -> (Double -> Vector Double -> Vector Double) -- ^ xdot(t,x)
104 -> Vector Double -- ^ initial conditions
105 -> Vector Double -- ^ desired solution times
106 -> Matrix Double -- ^ solution
107odeSolveV' method mbjac h epsAbs epsRel f xiv ts = unsafePerformIO $ do
108 let n = dim xiv
109 fp <- mkDoubleVecVecfun (\t -> aux_vTov (checkdim1 n . f t))
110 jp <- case mbjac of
111 Just jac -> mkDoubleVecMatfun (\t -> aux_vTom (checkdim2 n . jac t))
112 Nothing -> return nullFunPtr
113 sol <- vec xiv $ \xiv' ->
114 vec (checkTimes ts) $ \ts' ->
115 createMIO (dim ts) n
116 (ode_c (method) h epsAbs epsRel fp jp // xiv' // ts' )
117 "ode"
118 freeHaskellFunPtr fp
119 return sol
120
121foreign import ccall safe "ode"
122 ode_c :: CInt -> Double -> Double -> Double -> FunPtr (Double -> TVV) -> FunPtr (Double -> TVM) -> TVVM
123
124-------------------------------------------------------
125
126checkdim1 n v
127 | dim v == n = v
128 | otherwise = error $ "Error: "++ show n
129 ++ " components expected in the result of the function supplied to odeSolve"
130
131checkdim2 n m
132 | rows m == n && cols m == n = m
133 | otherwise = error $ "Error: "++ show n ++ "x" ++ show n
134 ++ " Jacobian expected in odeSolve"
135
136checkTimes ts | dim ts > 1 && all (>0) (zipWith subtract ts' (tail ts')) = ts
137 | otherwise = error "odeSolve requires increasing times"
138 where ts' = toList ts
diff --git a/packages/hmatrix/src/Numeric/GSL/Polynomials.hs b/packages/hmatrix/src/Numeric/GSL/Polynomials.hs
new file mode 100644
index 0000000..290c615
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/Polynomials.hs
@@ -0,0 +1,58 @@
1{-# LANGUAGE CPP, ForeignFunctionInterface #-}
2-----------------------------------------------------------------------------
3{- |
4Module : Numeric.GSL.Polynomials
5Copyright : (c) Alberto Ruiz 2006
6License : GPL-style
7
8Maintainer : Alberto Ruiz (aruiz at um dot es)
9Stability : provisional
10Portability : uses ffi
11
12Polynomials.
13
14<http://www.gnu.org/software/gsl/manual/html_node/General-Polynomial-Equations.html#General-Polynomial-Equations>
15
16-}
17-----------------------------------------------------------------------------
18module Numeric.GSL.Polynomials (
19 polySolve
20) where
21
22import Data.Packed.Internal
23import Data.Complex
24import System.IO.Unsafe (unsafePerformIO)
25
26#if __GLASGOW_HASKELL__ >= 704
27import Foreign.C.Types (CInt(..))
28#endif
29
30{- | Solution of general polynomial equations, using /gsl_poly_complex_solve/.
31
32For example, the three solutions of x^3 + 8 = 0
33
34>>> polySolve [8,0,0,1]
35[(-2.0) :+ 0.0,1.0 :+ 1.7320508075688776,1.0 :+ (-1.7320508075688776)]
36
37
38The example in the GSL manual: To find the roots of x^5 -1 = 0:
39
40>>> polySolve [-1, 0, 0, 0, 0, 1]
41[(-0.8090169943749472) :+ 0.5877852522924731,
42(-0.8090169943749472) :+ (-0.5877852522924731),
430.30901699437494756 :+ 0.9510565162951535,
440.30901699437494756 :+ (-0.9510565162951535),
451.0000000000000002 :+ 0.0]
46
47-}
48polySolve :: [Double] -> [Complex Double]
49polySolve = toList . polySolve' . fromList
50
51polySolve' :: Vector Double -> Vector (Complex Double)
52polySolve' v | dim v > 1 = unsafePerformIO $ do
53 r <- createVector (dim v-1)
54 app2 c_polySolve vec v vec r "polySolve"
55 return r
56 | otherwise = error "polySolve on a polynomial of degree zero"
57
58foreign import ccall unsafe "gsl-aux.h polySolve" c_polySolve:: TVCV
diff --git a/packages/hmatrix/src/Numeric/GSL/Root.hs b/packages/hmatrix/src/Numeric/GSL/Root.hs
new file mode 100644
index 0000000..9d561c4
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/Root.hs
@@ -0,0 +1,199 @@
1{- |
2Module : Numeric.GSL.Root
3Copyright : (c) Alberto Ruiz 2009
4License : GPL
5
6Maintainer : Alberto Ruiz (aruiz at um dot es)
7Stability : provisional
8Portability : uses ffi
9
10Multidimensional root finding.
11
12<http://www.gnu.org/software/gsl/manual/html_node/Multidimensional-Root_002dFinding.html>
13
14The example in the GSL manual:
15
16>>> let rosenbrock a b [x,y] = [ a*(1-x), b*(y-x^2) ]
17>>> let (sol,path) = root Hybrids 1E-7 30 (rosenbrock 1 10) [-10,-5]
18>>> sol
19[1.0,1.0]
20>>> disp 3 path
2111x5
22 1.000 -10.000 -5.000 11.000 -1050.000
23 2.000 -3.976 24.827 4.976 90.203
24 3.000 -3.976 24.827 4.976 90.203
25 4.000 -3.976 24.827 4.976 90.203
26 5.000 -1.274 -5.680 2.274 -73.018
27 6.000 -1.274 -5.680 2.274 -73.018
28 7.000 0.249 0.298 0.751 2.359
29 8.000 0.249 0.298 0.751 2.359
30 9.000 1.000 0.878 -0.000 -1.218
3110.000 1.000 0.989 -0.000 -0.108
3211.000 1.000 1.000 0.000 0.000
33
34-}
35-----------------------------------------------------------------------------
36
37module Numeric.GSL.Root (
38 uniRoot, UniRootMethod(..),
39 uniRootJ, UniRootMethodJ(..),
40 root, RootMethod(..),
41 rootJ, RootMethodJ(..),
42) where
43
44import Data.Packed.Internal
45import Data.Packed.Matrix
46import Numeric.GSL.Internal
47import Foreign.Ptr(FunPtr, freeHaskellFunPtr)
48import Foreign.C.Types
49import System.IO.Unsafe(unsafePerformIO)
50
51-------------------------------------------------------------------------
52
53data UniRootMethod = Bisection
54 | FalsePos
55 | Brent
56 deriving (Enum, Eq, Show, Bounded)
57
58uniRoot :: UniRootMethod
59 -> Double
60 -> Int
61 -> (Double -> Double)
62 -> Double
63 -> Double
64 -> (Double, Matrix Double)
65uniRoot method epsrel maxit fun xl xu = uniRootGen (fi (fromEnum method)) fun xl xu epsrel maxit
66
67uniRootGen m f xl xu epsrel maxit = unsafePerformIO $ do
68 fp <- mkDoublefun f
69 rawpath <- createMIO maxit 4
70 (c_root m fp epsrel (fi maxit) xl xu)
71 "root"
72 let it = round (rawpath @@> (maxit-1,0))
73 path = takeRows it rawpath
74 [sol] = toLists $ dropRows (it-1) path
75 freeHaskellFunPtr fp
76 return (sol !! 1, path)
77
78foreign import ccall safe "root"
79 c_root:: CInt -> FunPtr (Double -> Double) -> Double -> CInt -> Double -> Double -> TM
80
81-------------------------------------------------------------------------
82data UniRootMethodJ = UNewton
83 | Secant
84 | Steffenson
85 deriving (Enum, Eq, Show, Bounded)
86
87uniRootJ :: UniRootMethodJ
88 -> Double
89 -> Int
90 -> (Double -> Double)
91 -> (Double -> Double)
92 -> Double
93 -> (Double, Matrix Double)
94uniRootJ method epsrel maxit fun dfun x = uniRootJGen (fi (fromEnum method)) fun
95 dfun x epsrel maxit
96
97uniRootJGen m f df x epsrel maxit = unsafePerformIO $ do
98 fp <- mkDoublefun f
99 dfp <- mkDoublefun df
100 rawpath <- createMIO maxit 2
101 (c_rootj m fp dfp epsrel (fi maxit) x)
102 "rootj"
103 let it = round (rawpath @@> (maxit-1,0))
104 path = takeRows it rawpath
105 [sol] = toLists $ dropRows (it-1) path
106 freeHaskellFunPtr fp
107 return (sol !! 1, path)
108
109foreign import ccall safe "rootj"
110 c_rootj :: CInt -> FunPtr (Double -> Double) -> FunPtr (Double -> Double)
111 -> Double -> CInt -> Double -> TM
112
113-------------------------------------------------------------------------
114
115data RootMethod = Hybrids
116 | Hybrid
117 | DNewton
118 | Broyden
119 deriving (Enum,Eq,Show,Bounded)
120
121-- | Nonlinear multidimensional root finding using algorithms that do not require
122-- any derivative information to be supplied by the user.
123-- Any derivatives needed are approximated by finite differences.
124root :: RootMethod
125 -> Double -- ^ maximum residual
126 -> Int -- ^ maximum number of iterations allowed
127 -> ([Double] -> [Double]) -- ^ function to minimize
128 -> [Double] -- ^ starting point
129 -> ([Double], Matrix Double) -- ^ solution vector and optimization path
130
131root method epsabs maxit fun xinit = rootGen (fi (fromEnum method)) fun xinit epsabs maxit
132
133rootGen m f xi epsabs maxit = unsafePerformIO $ do
134 let xiv = fromList xi
135 n = dim xiv
136 fp <- mkVecVecfun (aux_vTov (checkdim1 n . fromList . f . toList))
137 rawpath <- vec xiv $ \xiv' ->
138 createMIO maxit (2*n+1)
139 (c_multiroot m fp epsabs (fi maxit) // xiv')
140 "multiroot"
141 let it = round (rawpath @@> (maxit-1,0))
142 path = takeRows it rawpath
143 [sol] = toLists $ dropRows (it-1) path
144 freeHaskellFunPtr fp
145 return (take n $ drop 1 sol, path)
146
147
148foreign import ccall safe "multiroot"
149 c_multiroot:: CInt -> FunPtr TVV -> Double -> CInt -> TVM
150
151-------------------------------------------------------------------------
152
153data RootMethodJ = HybridsJ
154 | HybridJ
155 | Newton
156 | GNewton
157 deriving (Enum,Eq,Show,Bounded)
158
159-- | Nonlinear multidimensional root finding using both the function and its derivatives.
160rootJ :: RootMethodJ
161 -> Double -- ^ maximum residual
162 -> Int -- ^ maximum number of iterations allowed
163 -> ([Double] -> [Double]) -- ^ function to minimize
164 -> ([Double] -> [[Double]]) -- ^ Jacobian
165 -> [Double] -- ^ starting point
166 -> ([Double], Matrix Double) -- ^ solution vector and optimization path
167
168rootJ method epsabs maxit fun jac xinit = rootJGen (fi (fromEnum method)) fun jac xinit epsabs maxit
169
170rootJGen m f jac xi epsabs maxit = unsafePerformIO $ do
171 let xiv = fromList xi
172 n = dim xiv
173 fp <- mkVecVecfun (aux_vTov (checkdim1 n . fromList . f . toList))
174 jp <- mkVecMatfun (aux_vTom (checkdim2 n . fromLists . jac . toList))
175 rawpath <- vec xiv $ \xiv' ->
176 createMIO maxit (2*n+1)
177 (c_multirootj m fp jp epsabs (fi maxit) // xiv')
178 "multiroot"
179 let it = round (rawpath @@> (maxit-1,0))
180 path = takeRows it rawpath
181 [sol] = toLists $ dropRows (it-1) path
182 freeHaskellFunPtr fp
183 freeHaskellFunPtr jp
184 return (take n $ drop 1 sol, path)
185
186foreign import ccall safe "multirootj"
187 c_multirootj:: CInt -> FunPtr TVV -> FunPtr TVM -> Double -> CInt -> TVM
188
189-------------------------------------------------------
190
191checkdim1 n v
192 | dim v == n = v
193 | otherwise = error $ "Error: "++ show n
194 ++ " components expected in the result of the function supplied to root"
195
196checkdim2 n m
197 | rows m == n && cols m == n = m
198 | otherwise = error $ "Error: "++ show n ++ "x" ++ show n
199 ++ " Jacobian expected in rootJ"
diff --git a/packages/hmatrix/src/Numeric/GSL/Vector.hs b/packages/hmatrix/src/Numeric/GSL/Vector.hs
new file mode 100644
index 0000000..6204b8e
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/Vector.hs
@@ -0,0 +1,328 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Numeric.GSL.Vector
4-- Copyright : (c) Alberto Ruiz 2007
5-- License : GPL-style
6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es>
8-- Stability : provisional
9-- Portability : portable (uses FFI)
10--
11-- Low level interface to vector operations.
12--
13-----------------------------------------------------------------------------
14
15module Numeric.GSL.Vector (
16 sumF, sumR, sumQ, sumC,
17 prodF, prodR, prodQ, prodC,
18 dotF, dotR, dotQ, dotC,
19 FunCodeS(..), toScalarR, toScalarF, toScalarC, toScalarQ,
20 FunCodeV(..), vectorMapR, vectorMapC, vectorMapF, vectorMapQ,
21 FunCodeSV(..), vectorMapValR, vectorMapValC, vectorMapValF, vectorMapValQ,
22 FunCodeVV(..), vectorZipR, vectorZipC, vectorZipF, vectorZipQ,
23 RandDist(..), randomVector
24) where
25
26import Data.Packed.Internal.Common
27import Data.Packed.Internal.Signatures
28import Data.Packed.Internal.Vector
29
30import Data.Complex
31import Foreign.Marshal.Alloc(free)
32import Foreign.Marshal.Array(newArray)
33import Foreign.Ptr(Ptr)
34import Foreign.C.Types
35import System.IO.Unsafe(unsafePerformIO)
36import Control.Monad(when)
37
38fromei x = fromIntegral (fromEnum x) :: CInt
39
40data FunCodeV = Sin
41 | Cos
42 | Tan
43 | Abs
44 | ASin
45 | ACos
46 | ATan
47 | Sinh
48 | Cosh
49 | Tanh
50 | ASinh
51 | ACosh
52 | ATanh
53 | Exp
54 | Log
55 | Sign
56 | Sqrt
57 deriving Enum
58
59data FunCodeSV = Scale
60 | Recip
61 | AddConstant
62 | Negate
63 | PowSV
64 | PowVS
65 deriving Enum
66
67data FunCodeVV = Add
68 | Sub
69 | Mul
70 | Div
71 | Pow
72 | ATan2
73 deriving Enum
74
75data FunCodeS = Norm2
76 | AbsSum
77 | MaxIdx
78 | Max
79 | MinIdx
80 | Min
81 deriving Enum
82
83------------------------------------------------------------------
84
85-- | sum of elements
86sumF :: Vector Float -> Float
87sumF x = unsafePerformIO $ do
88 r <- createVector 1
89 app2 c_sumF vec x vec r "sumF"
90 return $ r @> 0
91
92-- | sum of elements
93sumR :: Vector Double -> Double
94sumR x = unsafePerformIO $ do
95 r <- createVector 1
96 app2 c_sumR vec x vec r "sumR"
97 return $ r @> 0
98
99-- | sum of elements
100sumQ :: Vector (Complex Float) -> Complex Float
101sumQ x = unsafePerformIO $ do
102 r <- createVector 1
103 app2 c_sumQ vec x vec r "sumQ"
104 return $ r @> 0
105
106-- | sum of elements
107sumC :: Vector (Complex Double) -> Complex Double
108sumC x = unsafePerformIO $ do
109 r <- createVector 1
110 app2 c_sumC vec x vec r "sumC"
111 return $ r @> 0
112
113foreign import ccall unsafe "gsl-aux.h sumF" c_sumF :: TFF
114foreign import ccall unsafe "gsl-aux.h sumR" c_sumR :: TVV
115foreign import ccall unsafe "gsl-aux.h sumQ" c_sumQ :: TQVQV
116foreign import ccall unsafe "gsl-aux.h sumC" c_sumC :: TCVCV
117
118-- | product of elements
119prodF :: Vector Float -> Float
120prodF x = unsafePerformIO $ do
121 r <- createVector 1
122 app2 c_prodF vec x vec r "prodF"
123 return $ r @> 0
124
125-- | product of elements
126prodR :: Vector Double -> Double
127prodR x = unsafePerformIO $ do
128 r <- createVector 1
129 app2 c_prodR vec x vec r "prodR"
130 return $ r @> 0
131
132-- | product of elements
133prodQ :: Vector (Complex Float) -> Complex Float
134prodQ x = unsafePerformIO $ do
135 r <- createVector 1
136 app2 c_prodQ vec x vec r "prodQ"
137 return $ r @> 0
138
139-- | product of elements
140prodC :: Vector (Complex Double) -> Complex Double
141prodC x = unsafePerformIO $ do
142 r <- createVector 1
143 app2 c_prodC vec x vec r "prodC"
144 return $ r @> 0
145
146foreign import ccall unsafe "gsl-aux.h prodF" c_prodF :: TFF
147foreign import ccall unsafe "gsl-aux.h prodR" c_prodR :: TVV
148foreign import ccall unsafe "gsl-aux.h prodQ" c_prodQ :: TQVQV
149foreign import ccall unsafe "gsl-aux.h prodC" c_prodC :: TCVCV
150
151-- | dot product
152dotF :: Vector Float -> Vector Float -> Float
153dotF x y = unsafePerformIO $ do
154 r <- createVector 1
155 app3 c_dotF vec x vec y vec r "dotF"
156 return $ r @> 0
157
158-- | dot product
159dotR :: Vector Double -> Vector Double -> Double
160dotR x y = unsafePerformIO $ do
161 r <- createVector 1
162 app3 c_dotR vec x vec y vec r "dotR"
163 return $ r @> 0
164
165-- | dot product
166dotQ :: Vector (Complex Float) -> Vector (Complex Float) -> Complex Float
167dotQ x y = unsafePerformIO $ do
168 r <- createVector 1
169 app3 c_dotQ vec x vec y vec r "dotQ"
170 return $ r @> 0
171
172-- | dot product
173dotC :: Vector (Complex Double) -> Vector (Complex Double) -> Complex Double
174dotC x y = unsafePerformIO $ do
175 r <- createVector 1
176 app3 c_dotC vec x vec y vec r "dotC"
177 return $ r @> 0
178
179foreign import ccall unsafe "gsl-aux.h dotF" c_dotF :: TFFF
180foreign import ccall unsafe "gsl-aux.h dotR" c_dotR :: TVVV
181foreign import ccall unsafe "gsl-aux.h dotQ" c_dotQ :: TQVQVQV
182foreign import ccall unsafe "gsl-aux.h dotC" c_dotC :: TCVCVCV
183
184------------------------------------------------------------------
185
186toScalarAux fun code v = unsafePerformIO $ do
187 r <- createVector 1
188 app2 (fun (fromei code)) vec v vec r "toScalarAux"
189 return (r `at` 0)
190
191vectorMapAux fun code v = unsafePerformIO $ do
192 r <- createVector (dim v)
193 app2 (fun (fromei code)) vec v vec r "vectorMapAux"
194 return r
195
196vectorMapValAux fun code val v = unsafePerformIO $ do
197 r <- createVector (dim v)
198 pval <- newArray [val]
199 app2 (fun (fromei code) pval) vec v vec r "vectorMapValAux"
200 free pval
201 return r
202
203vectorZipAux fun code u v = unsafePerformIO $ do
204 r <- createVector (dim u)
205 when (dim u > 0) $ app3 (fun (fromei code)) vec u vec v vec r "vectorZipAux"
206 return r
207
208---------------------------------------------------------------------
209
210-- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc.
211toScalarR :: FunCodeS -> Vector Double -> Double
212toScalarR oper = toScalarAux c_toScalarR (fromei oper)
213
214foreign import ccall unsafe "gsl-aux.h toScalarR" c_toScalarR :: CInt -> TVV
215
216-- | obtains different functions of a vector: norm1, norm2, max, min, posmax, posmin, etc.
217toScalarF :: FunCodeS -> Vector Float -> Float
218toScalarF oper = toScalarAux c_toScalarF (fromei oper)
219
220foreign import ccall unsafe "gsl-aux.h toScalarF" c_toScalarF :: CInt -> TFF
221
222-- | obtains different functions of a vector: only norm1, norm2
223toScalarC :: FunCodeS -> Vector (Complex Double) -> Double
224toScalarC oper = toScalarAux c_toScalarC (fromei oper)
225
226foreign import ccall unsafe "gsl-aux.h toScalarC" c_toScalarC :: CInt -> TCVV
227
228-- | obtains different functions of a vector: only norm1, norm2
229toScalarQ :: FunCodeS -> Vector (Complex Float) -> Float
230toScalarQ oper = toScalarAux c_toScalarQ (fromei oper)
231
232foreign import ccall unsafe "gsl-aux.h toScalarQ" c_toScalarQ :: CInt -> TQVF
233
234------------------------------------------------------------------
235
236-- | map of real vectors with given function
237vectorMapR :: FunCodeV -> Vector Double -> Vector Double
238vectorMapR = vectorMapAux c_vectorMapR
239
240foreign import ccall unsafe "gsl-aux.h mapR" c_vectorMapR :: CInt -> TVV
241
242-- | map of complex vectors with given function
243vectorMapC :: FunCodeV -> Vector (Complex Double) -> Vector (Complex Double)
244vectorMapC oper = vectorMapAux c_vectorMapC (fromei oper)
245
246foreign import ccall unsafe "gsl-aux.h mapC" c_vectorMapC :: CInt -> TCVCV
247
248-- | map of real vectors with given function
249vectorMapF :: FunCodeV -> Vector Float -> Vector Float
250vectorMapF = vectorMapAux c_vectorMapF
251
252foreign import ccall unsafe "gsl-aux.h mapF" c_vectorMapF :: CInt -> TFF
253
254-- | map of real vectors with given function
255vectorMapQ :: FunCodeV -> Vector (Complex Float) -> Vector (Complex Float)
256vectorMapQ = vectorMapAux c_vectorMapQ
257
258foreign import ccall unsafe "gsl-aux.h mapQ" c_vectorMapQ :: CInt -> TQVQV
259
260-------------------------------------------------------------------
261
262-- | map of real vectors with given function
263vectorMapValR :: FunCodeSV -> Double -> Vector Double -> Vector Double
264vectorMapValR oper = vectorMapValAux c_vectorMapValR (fromei oper)
265
266foreign import ccall unsafe "gsl-aux.h mapValR" c_vectorMapValR :: CInt -> Ptr Double -> TVV
267
268-- | map of complex vectors with given function
269vectorMapValC :: FunCodeSV -> Complex Double -> Vector (Complex Double) -> Vector (Complex Double)
270vectorMapValC = vectorMapValAux c_vectorMapValC
271
272foreign import ccall unsafe "gsl-aux.h mapValC" c_vectorMapValC :: CInt -> Ptr (Complex Double) -> TCVCV
273
274-- | map of real vectors with given function
275vectorMapValF :: FunCodeSV -> Float -> Vector Float -> Vector Float
276vectorMapValF oper = vectorMapValAux c_vectorMapValF (fromei oper)
277
278foreign import ccall unsafe "gsl-aux.h mapValF" c_vectorMapValF :: CInt -> Ptr Float -> TFF
279
280-- | map of complex vectors with given function
281vectorMapValQ :: FunCodeSV -> Complex Float -> Vector (Complex Float) -> Vector (Complex Float)
282vectorMapValQ oper = vectorMapValAux c_vectorMapValQ (fromei oper)
283
284foreign import ccall unsafe "gsl-aux.h mapValQ" c_vectorMapValQ :: CInt -> Ptr (Complex Float) -> TQVQV
285
286-------------------------------------------------------------------
287
288-- | elementwise operation on real vectors
289vectorZipR :: FunCodeVV -> Vector Double -> Vector Double -> Vector Double
290vectorZipR = vectorZipAux c_vectorZipR
291
292foreign import ccall unsafe "gsl-aux.h zipR" c_vectorZipR :: CInt -> TVVV
293
294-- | elementwise operation on complex vectors
295vectorZipC :: FunCodeVV -> Vector (Complex Double) -> Vector (Complex Double) -> Vector (Complex Double)
296vectorZipC = vectorZipAux c_vectorZipC
297
298foreign import ccall unsafe "gsl-aux.h zipC" c_vectorZipC :: CInt -> TCVCVCV
299
300-- | elementwise operation on real vectors
301vectorZipF :: FunCodeVV -> Vector Float -> Vector Float -> Vector Float
302vectorZipF = vectorZipAux c_vectorZipF
303
304foreign import ccall unsafe "gsl-aux.h zipF" c_vectorZipF :: CInt -> TFFF
305
306-- | elementwise operation on complex vectors
307vectorZipQ :: FunCodeVV -> Vector (Complex Float) -> Vector (Complex Float) -> Vector (Complex Float)
308vectorZipQ = vectorZipAux c_vectorZipQ
309
310foreign import ccall unsafe "gsl-aux.h zipQ" c_vectorZipQ :: CInt -> TQVQVQV
311
312-----------------------------------------------------------------------
313
314data RandDist = Uniform -- ^ uniform distribution in [0,1)
315 | Gaussian -- ^ normal distribution with mean zero and standard deviation one
316 deriving Enum
317
318-- | Obtains a vector of pseudorandom elements from the the mt19937 generator in GSL, with a given seed. Use randomIO to get a random seed.
319randomVector :: Int -- ^ seed
320 -> RandDist -- ^ distribution
321 -> Int -- ^ vector size
322 -> Vector Double
323randomVector seed dist n = unsafePerformIO $ do
324 r <- createVector n
325 app1 (c_random_vector (fi seed) ((fi.fromEnum) dist)) vec r "randomVector"
326 return r
327
328foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> TV
diff --git a/packages/hmatrix/src/Numeric/GSL/gsl-aux.c b/packages/hmatrix/src/Numeric/GSL/gsl-aux.c
new file mode 100644
index 0000000..410d157
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/gsl-aux.c
@@ -0,0 +1,1541 @@
1#include <gsl/gsl_complex.h>
2
3#define RVEC(A) int A##n, double*A##p
4#define RMAT(A) int A##r, int A##c, double* A##p
5#define KRVEC(A) int A##n, const double*A##p
6#define KRMAT(A) int A##r, int A##c, const double* A##p
7
8#define CVEC(A) int A##n, gsl_complex*A##p
9#define CMAT(A) int A##r, int A##c, gsl_complex* A##p
10#define KCVEC(A) int A##n, const gsl_complex*A##p
11#define KCMAT(A) int A##r, int A##c, const gsl_complex* A##p
12
13#define FVEC(A) int A##n, float*A##p
14#define FMAT(A) int A##r, int A##c, float* A##p
15#define KFVEC(A) int A##n, const float*A##p
16#define KFMAT(A) int A##r, int A##c, const float* A##p
17
18#define QVEC(A) int A##n, gsl_complex_float*A##p
19#define QMAT(A) int A##r, int A##c, gsl_complex_float* A##p
20#define KQVEC(A) int A##n, const gsl_complex_float*A##p
21#define KQMAT(A) int A##r, int A##c, const gsl_complex_float* A##p
22
23#include <gsl/gsl_blas.h>
24#include <gsl/gsl_math.h>
25#include <gsl/gsl_errno.h>
26#include <gsl/gsl_fft_complex.h>
27#include <gsl/gsl_integration.h>
28#include <gsl/gsl_deriv.h>
29#include <gsl/gsl_poly.h>
30#include <gsl/gsl_multimin.h>
31#include <gsl/gsl_multiroots.h>
32#include <gsl/gsl_min.h>
33#include <gsl/gsl_complex_math.h>
34#include <gsl/gsl_rng.h>
35#include <gsl/gsl_randist.h>
36#include <gsl/gsl_roots.h>
37#include <gsl/gsl_multifit_nlin.h>
38#include <string.h>
39#include <stdio.h>
40
41#define MACRO(B) do {B} while (0)
42#define ERROR(CODE) MACRO(return CODE;)
43#define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);})
44#define OK return 0;
45
46#define MIN(A,B) ((A)<(B)?(A):(B))
47#define MAX(A,B) ((A)>(B)?(A):(B))
48
49#ifdef DBG
50#define DEBUGMSG(M) printf("*** calling aux C function: %s\n",M);
51#else
52#define DEBUGMSG(M)
53#endif
54
55#define CHECK(RES,CODE) MACRO(if(RES) return CODE;)
56
57#ifdef DBG
58#define DEBUGMAT(MSG,X) printf(MSG" = \n"); gsl_matrix_fprintf(stdout,X,"%f"); printf("\n");
59#else
60#define DEBUGMAT(MSG,X)
61#endif
62
63#ifdef DBG
64#define DEBUGVEC(MSG,X) printf(MSG" = \n"); gsl_vector_fprintf(stdout,X,"%f"); printf("\n");
65#else
66#define DEBUGVEC(MSG,X)
67#endif
68
69#define DVVIEW(A) gsl_vector_view A = gsl_vector_view_array(A##p,A##n)
70#define DMVIEW(A) gsl_matrix_view A = gsl_matrix_view_array(A##p,A##r,A##c)
71#define CVVIEW(A) gsl_vector_complex_view A = gsl_vector_complex_view_array((double*)A##p,A##n)
72#define CMVIEW(A) gsl_matrix_complex_view A = gsl_matrix_complex_view_array((double*)A##p,A##r,A##c)
73#define KDVVIEW(A) gsl_vector_const_view A = gsl_vector_const_view_array(A##p,A##n)
74#define KDMVIEW(A) gsl_matrix_const_view A = gsl_matrix_const_view_array(A##p,A##r,A##c)
75#define KCVVIEW(A) gsl_vector_complex_const_view A = gsl_vector_complex_const_view_array((double*)A##p,A##n)
76#define KCMVIEW(A) gsl_matrix_complex_const_view A = gsl_matrix_complex_const_view_array((double*)A##p,A##r,A##c)
77
78#define FVVIEW(A) gsl_vector_float_view A = gsl_vector_float_view_array(A##p,A##n)
79#define FMVIEW(A) gsl_matrix_float_view A = gsl_matrix_float_view_array(A##p,A##r,A##c)
80#define QVVIEW(A) gsl_vector_complex_float_view A = gsl_vector_float_complex_view_array((float*)A##p,A##n)
81#define QMVIEW(A) gsl_matrix_complex_float_view A = gsl_matrix_float_complex_view_array((float*)A##p,A##r,A##c)
82#define KFVVIEW(A) gsl_vector_float_const_view A = gsl_vector_float_const_view_array(A##p,A##n)
83#define KFMVIEW(A) gsl_matrix_float_const_view A = gsl_matrix_float_const_view_array(A##p,A##r,A##c)
84#define KQVVIEW(A) gsl_vector_complex_float_const_view A = gsl_vector_complex_float_const_view_array((float*)A##p,A##n)
85#define KQMVIEW(A) gsl_matrix_complex_float_const_view A = gsl_matrix_complex_float_const_view_array((float*)A##p,A##r,A##c)
86
87#define V(a) (&a.vector)
88#define M(a) (&a.matrix)
89
90#define GCVEC(A) int A##n, gsl_complex*A##p
91#define KGCVEC(A) int A##n, const gsl_complex*A##p
92
93#define GQVEC(A) int A##n, gsl_complex_float*A##p
94#define KGQVEC(A) int A##n, const gsl_complex_float*A##p
95
96#define BAD_SIZE 2000
97#define BAD_CODE 2001
98#define MEM 2002
99#define BAD_FILE 2003
100
101
102void no_abort_on_error() {
103 gsl_set_error_handler_off();
104}
105
106
107int sumF(KFVEC(x),FVEC(r)) {
108 DEBUGMSG("sumF");
109 REQUIRES(rn==1,BAD_SIZE);
110 int i;
111 float res = 0;
112 for (i = 0; i < xn; i++) res += xp[i];
113 rp[0] = res;
114 OK
115}
116
117int sumR(KRVEC(x),RVEC(r)) {
118 DEBUGMSG("sumR");
119 REQUIRES(rn==1,BAD_SIZE);
120 int i;
121 double res = 0;
122 for (i = 0; i < xn; i++) res += xp[i];
123 rp[0] = res;
124 OK
125}
126
127int sumQ(KQVEC(x),QVEC(r)) {
128 DEBUGMSG("sumQ");
129 REQUIRES(rn==1,BAD_SIZE);
130 int i;
131 gsl_complex_float res;
132 res.dat[0] = 0;
133 res.dat[1] = 0;
134 for (i = 0; i < xn; i++) {
135 res.dat[0] += xp[i].dat[0];
136 res.dat[1] += xp[i].dat[1];
137 }
138 rp[0] = res;
139 OK
140}
141
142int sumC(KCVEC(x),CVEC(r)) {
143 DEBUGMSG("sumC");
144 REQUIRES(rn==1,BAD_SIZE);
145 int i;
146 gsl_complex res;
147 res.dat[0] = 0;
148 res.dat[1] = 0;
149 for (i = 0; i < xn; i++) {
150 res.dat[0] += xp[i].dat[0];
151 res.dat[1] += xp[i].dat[1];
152 }
153 rp[0] = res;
154 OK
155}
156
157int prodF(KFVEC(x),FVEC(r)) {
158 DEBUGMSG("prodF");
159 REQUIRES(rn==1,BAD_SIZE);
160 int i;
161 float res = 1;
162 for (i = 0; i < xn; i++) res *= xp[i];
163 rp[0] = res;
164 OK
165}
166
167int prodR(KRVEC(x),RVEC(r)) {
168 DEBUGMSG("prodR");
169 REQUIRES(rn==1,BAD_SIZE);
170 int i;
171 double res = 1;
172 for (i = 0; i < xn; i++) res *= xp[i];
173 rp[0] = res;
174 OK
175}
176
177int prodQ(KQVEC(x),QVEC(r)) {
178 DEBUGMSG("prodQ");
179 REQUIRES(rn==1,BAD_SIZE);
180 int i;
181 gsl_complex_float res;
182 float temp;
183 res.dat[0] = 1;
184 res.dat[1] = 0;
185 for (i = 0; i < xn; i++) {
186 temp = res.dat[0] * xp[i].dat[0] - res.dat[1] * xp[i].dat[1];
187 res.dat[1] = res.dat[0] * xp[i].dat[1] + res.dat[1] * xp[i].dat[0];
188 res.dat[0] = temp;
189 }
190 rp[0] = res;
191 OK
192}
193
194int prodC(KCVEC(x),CVEC(r)) {
195 DEBUGMSG("prodC");
196 REQUIRES(rn==1,BAD_SIZE);
197 int i;
198 gsl_complex res;
199 double temp;
200 res.dat[0] = 1;
201 res.dat[1] = 0;
202 for (i = 0; i < xn; i++) {
203 temp = res.dat[0] * xp[i].dat[0] - res.dat[1] * xp[i].dat[1];
204 res.dat[1] = res.dat[0] * xp[i].dat[1] + res.dat[1] * xp[i].dat[0];
205 res.dat[0] = temp;
206 }
207 rp[0] = res;
208 OK
209}
210
211int dotF(KFVEC(x), KFVEC(y), FVEC(r)) {
212 DEBUGMSG("dotF");
213 REQUIRES(xn==yn,BAD_SIZE);
214 REQUIRES(rn==1,BAD_SIZE);
215 DEBUGMSG("dotF");
216 KFVVIEW(x);
217 KFVVIEW(y);
218 gsl_blas_sdot(V(x),V(y),rp);
219 OK
220}
221
222int dotR(KRVEC(x), KRVEC(y), RVEC(r)) {
223 DEBUGMSG("dotR");
224 REQUIRES(xn==yn,BAD_SIZE);
225 REQUIRES(rn==1,BAD_SIZE);
226 DEBUGMSG("dotR");
227 KDVVIEW(x);
228 KDVVIEW(y);
229 gsl_blas_ddot(V(x),V(y),rp);
230 OK
231}
232
233int dotQ(KQVEC(x), KQVEC(y), QVEC(r)) {
234 DEBUGMSG("dotQ");
235 REQUIRES(xn==yn,BAD_SIZE);
236 REQUIRES(rn==1,BAD_SIZE);
237 DEBUGMSG("dotQ");
238 KQVVIEW(x);
239 KQVVIEW(y);
240 gsl_blas_cdotu(V(x),V(y),rp);
241 OK
242}
243
244int dotC(KCVEC(x), KCVEC(y), CVEC(r)) {
245 DEBUGMSG("dotC");
246 REQUIRES(xn==yn,BAD_SIZE);
247 REQUIRES(rn==1,BAD_SIZE);
248 DEBUGMSG("dotC");
249 KCVVIEW(x);
250 KCVVIEW(y);
251 gsl_blas_zdotu(V(x),V(y),rp);
252 OK
253}
254
255int toScalarR(int code, KRVEC(x), RVEC(r)) {
256 REQUIRES(rn==1,BAD_SIZE);
257 DEBUGMSG("toScalarR");
258 KDVVIEW(x);
259 double res;
260 switch(code) {
261 case 0: { res = gsl_blas_dnrm2(V(x)); break; }
262 case 1: { res = gsl_blas_dasum(V(x)); break; }
263 case 2: { res = gsl_vector_max_index(V(x)); break; }
264 case 3: { res = gsl_vector_max(V(x)); break; }
265 case 4: { res = gsl_vector_min_index(V(x)); break; }
266 case 5: { res = gsl_vector_min(V(x)); break; }
267 default: ERROR(BAD_CODE);
268 }
269 rp[0] = res;
270 OK
271}
272
273int toScalarF(int code, KFVEC(x), FVEC(r)) {
274 REQUIRES(rn==1,BAD_SIZE);
275 DEBUGMSG("toScalarF");
276 KFVVIEW(x);
277 float res;
278 switch(code) {
279 case 0: { res = gsl_blas_snrm2(V(x)); break; }
280 case 1: { res = gsl_blas_sasum(V(x)); break; }
281 case 2: { res = gsl_vector_float_max_index(V(x)); break; }
282 case 3: { res = gsl_vector_float_max(V(x)); break; }
283 case 4: { res = gsl_vector_float_min_index(V(x)); break; }
284 case 5: { res = gsl_vector_float_min(V(x)); break; }
285 default: ERROR(BAD_CODE);
286 }
287 rp[0] = res;
288 OK
289}
290
291
292int toScalarC(int code, KCVEC(x), RVEC(r)) {
293 REQUIRES(rn==1,BAD_SIZE);
294 DEBUGMSG("toScalarC");
295 KCVVIEW(x);
296 double res;
297 switch(code) {
298 case 0: { res = gsl_blas_dznrm2(V(x)); break; }
299 case 1: { res = gsl_blas_dzasum(V(x)); break; }
300 default: ERROR(BAD_CODE);
301 }
302 rp[0] = res;
303 OK
304}
305
306int toScalarQ(int code, KQVEC(x), FVEC(r)) {
307 REQUIRES(rn==1,BAD_SIZE);
308 DEBUGMSG("toScalarQ");
309 KQVVIEW(x);
310 float res;
311 switch(code) {
312 case 0: { res = gsl_blas_scnrm2(V(x)); break; }
313 case 1: { res = gsl_blas_scasum(V(x)); break; }
314 default: ERROR(BAD_CODE);
315 }
316 rp[0] = res;
317 OK
318}
319
320
321inline double sign(double x) {
322 if(x>0) {
323 return +1.0;
324 } else if (x<0) {
325 return -1.0;
326 } else {
327 return 0.0;
328 }
329}
330
331inline float float_sign(float x) {
332 if(x>0) {
333 return +1.0;
334 } else if (x<0) {
335 return -1.0;
336 } else {
337 return 0.0;
338 }
339}
340
341inline gsl_complex complex_abs(gsl_complex z) {
342 gsl_complex r;
343 r.dat[0] = gsl_complex_abs(z);
344 r.dat[1] = 0;
345 return r;
346}
347
348inline gsl_complex complex_signum(gsl_complex z) {
349 gsl_complex r;
350 double mag;
351 if (z.dat[0] == 0 && z.dat[1] == 0) {
352 r.dat[0] = 0;
353 r.dat[1] = 0;
354 } else {
355 mag = gsl_complex_abs(z);
356 r.dat[0] = z.dat[0]/mag;
357 r.dat[1] = z.dat[1]/mag;
358 }
359 return r;
360}
361
362#define OP(C,F) case C: { for(k=0;k<xn;k++) rp[k] = F(xp[k]); OK }
363#define OPV(C,E) case C: { for(k=0;k<xn;k++) rp[k] = E; OK }
364int mapR(int code, KRVEC(x), RVEC(r)) {
365 int k;
366 REQUIRES(xn == rn,BAD_SIZE);
367 DEBUGMSG("mapR");
368 switch (code) {
369 OP(0,sin)
370 OP(1,cos)
371 OP(2,tan)
372 OP(3,fabs)
373 OP(4,asin)
374 OP(5,acos)
375 OP(6,atan) /* atan2 mediante vectorZip */
376 OP(7,sinh)
377 OP(8,cosh)
378 OP(9,tanh)
379 OP(10,gsl_asinh)
380 OP(11,gsl_acosh)
381 OP(12,gsl_atanh)
382 OP(13,exp)
383 OP(14,log)
384 OP(15,sign)
385 OP(16,sqrt)
386 default: ERROR(BAD_CODE);
387 }
388}
389
390int mapF(int code, KFVEC(x), FVEC(r)) {
391 int k;
392 REQUIRES(xn == rn,BAD_SIZE);
393 DEBUGMSG("mapF");
394 switch (code) {
395 OP(0,sin)
396 OP(1,cos)
397 OP(2,tan)
398 OP(3,fabs)
399 OP(4,asin)
400 OP(5,acos)
401 OP(6,atan) /* atan2 mediante vectorZip */
402 OP(7,sinh)
403 OP(8,cosh)
404 OP(9,tanh)
405 OP(10,gsl_asinh)
406 OP(11,gsl_acosh)
407 OP(12,gsl_atanh)
408 OP(13,exp)
409 OP(14,log)
410 OP(15,sign)
411 OP(16,sqrt)
412 default: ERROR(BAD_CODE);
413 }
414}
415
416
417int mapCAux(int code, KGCVEC(x), GCVEC(r)) {
418 int k;
419 REQUIRES(xn == rn,BAD_SIZE);
420 DEBUGMSG("mapC");
421 switch (code) {
422 OP(0,gsl_complex_sin)
423 OP(1,gsl_complex_cos)
424 OP(2,gsl_complex_tan)
425 OP(3,complex_abs)
426 OP(4,gsl_complex_arcsin)
427 OP(5,gsl_complex_arccos)
428 OP(6,gsl_complex_arctan)
429 OP(7,gsl_complex_sinh)
430 OP(8,gsl_complex_cosh)
431 OP(9,gsl_complex_tanh)
432 OP(10,gsl_complex_arcsinh)
433 OP(11,gsl_complex_arccosh)
434 OP(12,gsl_complex_arctanh)
435 OP(13,gsl_complex_exp)
436 OP(14,gsl_complex_log)
437 OP(15,complex_signum)
438 OP(16,gsl_complex_sqrt)
439
440 // gsl_complex_arg
441 // gsl_complex_abs
442 default: ERROR(BAD_CODE);
443 }
444}
445
446int mapC(int code, KCVEC(x), CVEC(r)) {
447 return mapCAux(code, xn, (gsl_complex*)xp, rn, (gsl_complex*)rp);
448}
449
450
451gsl_complex_float complex_float_math_fun(gsl_complex (*cf)(gsl_complex), gsl_complex_float a)
452{
453 gsl_complex c;
454 gsl_complex r;
455
456 gsl_complex_float float_r;
457
458 c.dat[0] = a.dat[0];
459 c.dat[1] = a.dat[1];
460
461 r = (*cf)(c);
462
463 float_r.dat[0] = r.dat[0];
464 float_r.dat[1] = r.dat[1];
465
466 return float_r;
467}
468
469gsl_complex_float complex_float_math_op(gsl_complex (*cf)(gsl_complex,gsl_complex),
470 gsl_complex_float a,gsl_complex_float b)
471{
472 gsl_complex c1;
473 gsl_complex c2;
474 gsl_complex r;
475
476 gsl_complex_float float_r;
477
478 c1.dat[0] = a.dat[0];
479 c1.dat[1] = a.dat[1];
480
481 c2.dat[0] = b.dat[0];
482 c2.dat[1] = b.dat[1];
483
484 r = (*cf)(c1,c2);
485
486 float_r.dat[0] = r.dat[0];
487 float_r.dat[1] = r.dat[1];
488
489 return float_r;
490}
491
492#define OPC(C,F) case C: { for(k=0;k<xn;k++) rp[k] = complex_float_math_fun(&F,xp[k]); OK }
493#define OPCA(C,F,A,B) case C: { for(k=0;k<xn;k++) rp[k] = complex_float_math_op(&F,A,B); OK }
494int mapQAux(int code, KGQVEC(x), GQVEC(r)) {
495 int k;
496 REQUIRES(xn == rn,BAD_SIZE);
497 DEBUGMSG("mapQ");
498 switch (code) {
499 OPC(0,gsl_complex_sin)
500 OPC(1,gsl_complex_cos)
501 OPC(2,gsl_complex_tan)
502 OPC(3,complex_abs)
503 OPC(4,gsl_complex_arcsin)
504 OPC(5,gsl_complex_arccos)
505 OPC(6,gsl_complex_arctan)
506 OPC(7,gsl_complex_sinh)
507 OPC(8,gsl_complex_cosh)
508 OPC(9,gsl_complex_tanh)
509 OPC(10,gsl_complex_arcsinh)
510 OPC(11,gsl_complex_arccosh)
511 OPC(12,gsl_complex_arctanh)
512 OPC(13,gsl_complex_exp)
513 OPC(14,gsl_complex_log)
514 OPC(15,complex_signum)
515 OPC(16,gsl_complex_sqrt)
516
517 // gsl_complex_arg
518 // gsl_complex_abs
519 default: ERROR(BAD_CODE);
520 }
521}
522
523int mapQ(int code, KQVEC(x), QVEC(r)) {
524 return mapQAux(code, xn, (gsl_complex_float*)xp, rn, (gsl_complex_float*)rp);
525}
526
527
528int mapValR(int code, double* pval, KRVEC(x), RVEC(r)) {
529 int k;
530 double val = *pval;
531 REQUIRES(xn == rn,BAD_SIZE);
532 DEBUGMSG("mapValR");
533 switch (code) {
534 OPV(0,val*xp[k])
535 OPV(1,val/xp[k])
536 OPV(2,val+xp[k])
537 OPV(3,val-xp[k])
538 OPV(4,pow(val,xp[k]))
539 OPV(5,pow(xp[k],val))
540 default: ERROR(BAD_CODE);
541 }
542}
543
544int mapValF(int code, float* pval, KFVEC(x), FVEC(r)) {
545 int k;
546 float val = *pval;
547 REQUIRES(xn == rn,BAD_SIZE);
548 DEBUGMSG("mapValF");
549 switch (code) {
550 OPV(0,val*xp[k])
551 OPV(1,val/xp[k])
552 OPV(2,val+xp[k])
553 OPV(3,val-xp[k])
554 OPV(4,pow(val,xp[k]))
555 OPV(5,pow(xp[k],val))
556 default: ERROR(BAD_CODE);
557 }
558}
559
560int mapValCAux(int code, gsl_complex* pval, KGCVEC(x), GCVEC(r)) {
561 int k;
562 gsl_complex val = *pval;
563 REQUIRES(xn == rn,BAD_SIZE);
564 DEBUGMSG("mapValC");
565 switch (code) {
566 OPV(0,gsl_complex_mul(val,xp[k]))
567 OPV(1,gsl_complex_div(val,xp[k]))
568 OPV(2,gsl_complex_add(val,xp[k]))
569 OPV(3,gsl_complex_sub(val,xp[k]))
570 OPV(4,gsl_complex_pow(val,xp[k]))
571 OPV(5,gsl_complex_pow(xp[k],val))
572 default: ERROR(BAD_CODE);
573 }
574}
575
576int mapValC(int code, gsl_complex* val, KCVEC(x), CVEC(r)) {
577 return mapValCAux(code, val, xn, (gsl_complex*)xp, rn, (gsl_complex*)rp);
578}
579
580
581int mapValQAux(int code, gsl_complex_float* pval, KQVEC(x), GQVEC(r)) {
582 int k;
583 gsl_complex_float val = *pval;
584 REQUIRES(xn == rn,BAD_SIZE);
585 DEBUGMSG("mapValQ");
586 switch (code) {
587 OPCA(0,gsl_complex_mul,val,xp[k])
588 OPCA(1,gsl_complex_div,val,xp[k])
589 OPCA(2,gsl_complex_add,val,xp[k])
590 OPCA(3,gsl_complex_sub,val,xp[k])
591 OPCA(4,gsl_complex_pow,val,xp[k])
592 OPCA(5,gsl_complex_pow,xp[k],val)
593 default: ERROR(BAD_CODE);
594 }
595}
596
597int mapValQ(int code, gsl_complex_float* val, KQVEC(x), QVEC(r)) {
598 return mapValQAux(code, val, xn, (gsl_complex_float*)xp, rn, (gsl_complex_float*)rp);
599}
600
601
602#define OPZE(C,msg,E) case C: {DEBUGMSG(msg) for(k=0;k<an;k++) rp[k] = E(ap[k],bp[k]); OK }
603#define OPZV(C,msg,E) case C: {DEBUGMSG(msg) res = E(V(r),V(b)); CHECK(res,res); OK }
604int zipR(int code, KRVEC(a), KRVEC(b), RVEC(r)) {
605 REQUIRES(an == bn && an == rn, BAD_SIZE);
606 int k;
607 switch(code) {
608 OPZE(4,"zipR Pow",pow)
609 OPZE(5,"zipR ATan2",atan2)
610 }
611 KDVVIEW(a);
612 KDVVIEW(b);
613 DVVIEW(r);
614 gsl_vector_memcpy(V(r),V(a));
615 int res;
616 switch(code) {
617 OPZV(0,"zipR Add",gsl_vector_add)
618 OPZV(1,"zipR Sub",gsl_vector_sub)
619 OPZV(2,"zipR Mul",gsl_vector_mul)
620 OPZV(3,"zipR Div",gsl_vector_div)
621 default: ERROR(BAD_CODE);
622 }
623}
624
625
626int zipF(int code, KFVEC(a), KFVEC(b), FVEC(r)) {
627 REQUIRES(an == bn && an == rn, BAD_SIZE);
628 int k;
629 switch(code) {
630 OPZE(4,"zipF Pow",pow)
631 OPZE(5,"zipF ATan2",atan2)
632 }
633 KFVVIEW(a);
634 KFVVIEW(b);
635 FVVIEW(r);
636 gsl_vector_float_memcpy(V(r),V(a));
637 int res;
638 switch(code) {
639 OPZV(0,"zipF Add",gsl_vector_float_add)
640 OPZV(1,"zipF Sub",gsl_vector_float_sub)
641 OPZV(2,"zipF Mul",gsl_vector_float_mul)
642 OPZV(3,"zipF Div",gsl_vector_float_div)
643 default: ERROR(BAD_CODE);
644 }
645}
646
647
648int zipCAux(int code, KGCVEC(a), KGCVEC(b), GCVEC(r)) {
649 REQUIRES(an == bn && an == rn, BAD_SIZE);
650 int k;
651 switch(code) {
652 OPZE(0,"zipC Add",gsl_complex_add)
653 OPZE(1,"zipC Sub",gsl_complex_sub)
654 OPZE(2,"zipC Mul",gsl_complex_mul)
655 OPZE(3,"zipC Div",gsl_complex_div)
656 OPZE(4,"zipC Pow",gsl_complex_pow)
657 //OPZE(5,"zipR ATan2",atan2)
658 }
659 //KCVVIEW(a);
660 //KCVVIEW(b);
661 //CVVIEW(r);
662 //gsl_vector_memcpy(V(r),V(a));
663 //int res;
664 switch(code) {
665 default: ERROR(BAD_CODE);
666 }
667}
668
669
670int zipC(int code, KCVEC(a), KCVEC(b), CVEC(r)) {
671 return zipCAux(code, an, (gsl_complex*)ap, bn, (gsl_complex*)bp, rn, (gsl_complex*)rp);
672}
673
674
675#define OPCZE(C,msg,E) case C: {DEBUGMSG(msg) for(k=0;k<an;k++) rp[k] = complex_float_math_op(&E,ap[k],bp[k]); OK }
676int zipQAux(int code, KGQVEC(a), KGQVEC(b), GQVEC(r)) {
677 REQUIRES(an == bn && an == rn, BAD_SIZE);
678 int k;
679 switch(code) {
680 OPCZE(0,"zipQ Add",gsl_complex_add)
681 OPCZE(1,"zipQ Sub",gsl_complex_sub)
682 OPCZE(2,"zipQ Mul",gsl_complex_mul)
683 OPCZE(3,"zipQ Div",gsl_complex_div)
684 OPCZE(4,"zipQ Pow",gsl_complex_pow)
685 //OPZE(5,"zipR ATan2",atan2)
686 }
687 //KCVVIEW(a);
688 //KCVVIEW(b);
689 //CVVIEW(r);
690 //gsl_vector_memcpy(V(r),V(a));
691 //int res;
692 switch(code) {
693 default: ERROR(BAD_CODE);
694 }
695}
696
697
698int zipQ(int code, KQVEC(a), KQVEC(b), QVEC(r)) {
699 return zipQAux(code, an, (gsl_complex_float*)ap, bn, (gsl_complex_float*)bp, rn, (gsl_complex_float*)rp);
700}
701
702
703
704int fft(int code, KCVEC(X), CVEC(R)) {
705 REQUIRES(Xn == Rn,BAD_SIZE);
706 DEBUGMSG("fft");
707 int s = Xn;
708 gsl_fft_complex_wavetable * wavetable = gsl_fft_complex_wavetable_alloc (s);
709 gsl_fft_complex_workspace * workspace = gsl_fft_complex_workspace_alloc (s);
710 gsl_vector_const_view X = gsl_vector_const_view_array((double*)Xp, 2*Xn);
711 gsl_vector_view R = gsl_vector_view_array((double*)Rp, 2*Rn);
712 gsl_blas_dcopy(&X.vector,&R.vector);
713 if(code==0) {
714 gsl_fft_complex_forward ((double*)Rp, 1, s, wavetable, workspace);
715 } else {
716 gsl_fft_complex_inverse ((double*)Rp, 1, s, wavetable, workspace);
717 }
718 gsl_fft_complex_wavetable_free (wavetable);
719 gsl_fft_complex_workspace_free (workspace);
720 OK
721}
722
723
724int deriv(int code, double f(double, void*), double x, double h, double * result, double * abserr)
725{
726 gsl_function F;
727 F.function = f;
728 F.params = 0;
729
730 if(code==0) return gsl_deriv_central (&F, x, h, result, abserr);
731
732 if(code==1) return gsl_deriv_forward (&F, x, h, result, abserr);
733
734 if(code==2) return gsl_deriv_backward (&F, x, h, result, abserr);
735
736 return 0;
737}
738
739
740int integrate_qng(double f(double, void*), double a, double b, double aprec, double prec,
741 double *result, double*error) {
742 DEBUGMSG("integrate_qng");
743 gsl_function F;
744 F.function = f;
745 F.params = NULL;
746 size_t neval;
747 int res = gsl_integration_qng (&F, a,b, aprec, prec, result, error, &neval);
748 CHECK(res,res);
749 OK
750}
751
752int integrate_qags(double f(double,void*), double a, double b, double aprec, double prec, int w,
753 double *result, double* error) {
754 DEBUGMSG("integrate_qags");
755 gsl_integration_workspace * wk = gsl_integration_workspace_alloc (w);
756 gsl_function F;
757 F.function = f;
758 F.params = NULL;
759 int res = gsl_integration_qags (&F, a,b, aprec, prec, w,wk, result, error);
760 CHECK(res,res);
761 gsl_integration_workspace_free (wk);
762 OK
763}
764
765int integrate_qagi(double f(double,void*), double aprec, double prec, int w,
766 double *result, double* error) {
767 DEBUGMSG("integrate_qagi");
768 gsl_integration_workspace * wk = gsl_integration_workspace_alloc (w);
769 gsl_function F;
770 F.function = f;
771 F.params = NULL;
772 int res = gsl_integration_qagi (&F, aprec, prec, w,wk, result, error);
773 CHECK(res,res);
774 gsl_integration_workspace_free (wk);
775 OK
776}
777
778
779int integrate_qagiu(double f(double,void*), double a, double aprec, double prec, int w,
780 double *result, double* error) {
781 DEBUGMSG("integrate_qagiu");
782 gsl_integration_workspace * wk = gsl_integration_workspace_alloc (w);
783 gsl_function F;
784 F.function = f;
785 F.params = NULL;
786 int res = gsl_integration_qagiu (&F, a, aprec, prec, w,wk, result, error);
787 CHECK(res,res);
788 gsl_integration_workspace_free (wk);
789 OK
790}
791
792
793int integrate_qagil(double f(double,void*), double b, double aprec, double prec, int w,
794 double *result, double* error) {
795 DEBUGMSG("integrate_qagil");
796 gsl_integration_workspace * wk = gsl_integration_workspace_alloc (w);
797 gsl_function F;
798 F.function = f;
799 F.params = NULL;
800 int res = gsl_integration_qagil (&F, b, aprec, prec, w,wk, result, error);
801 CHECK(res,res);
802 gsl_integration_workspace_free (wk);
803 OK
804}
805
806int integrate_cquad(double f(double,void*), double a, double b, double aprec, double prec,
807 int w, double *result, double* error, int *neval) {
808 DEBUGMSG("integrate_cquad");
809 gsl_integration_cquad_workspace * wk = gsl_integration_cquad_workspace_alloc (w);
810 gsl_function F;
811 F.function = f;
812 F.params = NULL;
813 size_t * sneval = NULL;
814 int res = gsl_integration_cquad (&F, a, b, aprec, prec, wk, result, error, sneval);
815 *neval = *sneval;
816 CHECK(res,res);
817 gsl_integration_cquad_workspace_free (wk);
818 OK
819}
820
821
822int polySolve(KRVEC(a), CVEC(z)) {
823 DEBUGMSG("polySolve");
824 REQUIRES(an>1,BAD_SIZE);
825 gsl_poly_complex_workspace * w = gsl_poly_complex_workspace_alloc (an);
826 int res = gsl_poly_complex_solve ((double*)ap, an, w, (double*)zp);
827 CHECK(res,res);
828 gsl_poly_complex_workspace_free (w);
829 OK;
830}
831
832int vector_fscanf(char*filename, RVEC(a)) {
833 DEBUGMSG("gsl_vector_fscanf");
834 DVVIEW(a);
835 FILE * f = fopen(filename,"r");
836 CHECK(!f,BAD_FILE);
837 int res = gsl_vector_fscanf(f,V(a));
838 CHECK(res,res);
839 fclose (f);
840 OK
841}
842
843int vector_fprintf(char*filename, char*fmt, RVEC(a)) {
844 DEBUGMSG("gsl_vector_fprintf");
845 DVVIEW(a);
846 FILE * f = fopen(filename,"w");
847 CHECK(!f,BAD_FILE);
848 int res = gsl_vector_fprintf(f,V(a),fmt);
849 CHECK(res,res);
850 fclose (f);
851 OK
852}
853
854int vector_fread(char*filename, RVEC(a)) {
855 DEBUGMSG("gsl_vector_fread");
856 DVVIEW(a);
857 FILE * f = fopen(filename,"r");
858 CHECK(!f,BAD_FILE);
859 int res = gsl_vector_fread(f,V(a));
860 CHECK(res,res);
861 fclose (f);
862 OK
863}
864
865int vector_fwrite(char*filename, RVEC(a)) {
866 DEBUGMSG("gsl_vector_fwrite");
867 DVVIEW(a);
868 FILE * f = fopen(filename,"w");
869 CHECK(!f,BAD_FILE);
870 int res = gsl_vector_fwrite(f,V(a));
871 CHECK(res,res);
872 fclose (f);
873 OK
874}
875
876int matrix_fprintf(char*filename, char*fmt, int ro, RMAT(m)) {
877 DEBUGMSG("matrix_fprintf");
878 FILE * f = fopen(filename,"w");
879 CHECK(!f,BAD_FILE);
880 int i,j,sr,sc;
881 if (ro==1) { sr = mc; sc = 1;} else { sr = 1; sc = mr;}
882 #define AT(M,r,c) (M##p[(r)*sr+(c)*sc])
883 for (i=0; i<mr; i++) {
884 for (j=0; j<mc-1; j++) {
885 fprintf(f,fmt,AT(m,i,j));
886 fprintf(f," ");
887 }
888 fprintf(f,fmt,AT(m,i,j));
889 fprintf(f,"\n");
890 }
891 fclose (f);
892 OK
893}
894
895//---------------------------------------------------------------
896
897typedef double Trawfun(int, double*);
898
899double only_f_aux_min(const gsl_vector*x, void *pars) {
900 Trawfun * f = (Trawfun*) pars;
901 double* p = (double*)calloc(x->size,sizeof(double));
902 int k;
903 for(k=0;k<x->size;k++) {
904 p[k] = gsl_vector_get(x,k);
905 }
906 double res = f(x->size,p);
907 free(p);
908 return res;
909}
910
911double only_f_aux_root(double x, void *pars);
912int uniMinimize(int method, double f(double),
913 double epsrel, int maxit, double min,
914 double xl, double xu, RMAT(sol)) {
915 REQUIRES(solr == maxit && solc == 4,BAD_SIZE);
916 DEBUGMSG("minimize_only_f");
917 gsl_function my_func;
918 my_func.function = only_f_aux_root;
919 my_func.params = f;
920 size_t iter = 0;
921 int status;
922 const gsl_min_fminimizer_type *T;
923 gsl_min_fminimizer *s;
924 // Starting point
925 switch(method) {
926 case 0 : {T = gsl_min_fminimizer_goldensection; break; }
927 case 1 : {T = gsl_min_fminimizer_brent; break; }
928 case 2 : {T = gsl_min_fminimizer_quad_golden; break; }
929 default: ERROR(BAD_CODE);
930 }
931 s = gsl_min_fminimizer_alloc (T);
932 gsl_min_fminimizer_set (s, &my_func, min, xl, xu);
933 do {
934 double current_min, current_lo, current_hi;
935 status = gsl_min_fminimizer_iterate (s);
936 current_min = gsl_min_fminimizer_x_minimum (s);
937 current_lo = gsl_min_fminimizer_x_lower (s);
938 current_hi = gsl_min_fminimizer_x_upper (s);
939 solp[iter*solc] = iter + 1;
940 solp[iter*solc+1] = current_min;
941 solp[iter*solc+2] = current_lo;
942 solp[iter*solc+3] = current_hi;
943 iter++;
944 if (status) /* check if solver is stuck */
945 break;
946
947 status =
948 gsl_min_test_interval (current_lo, current_hi, 0, epsrel);
949 }
950 while (status == GSL_CONTINUE && iter < maxit);
951 int i;
952 for (i=iter; i<solr; i++) {
953 solp[i*solc+0] = iter;
954 solp[i*solc+1]=0.;
955 solp[i*solc+2]=0.;
956 solp[i*solc+3]=0.;
957 }
958 gsl_min_fminimizer_free(s);
959 OK
960}
961
962
963
964// this version returns info about intermediate steps
965int minimize(int method, double f(int, double*), double tolsize, int maxit,
966 KRVEC(xi), KRVEC(sz), RMAT(sol)) {
967 REQUIRES(xin==szn && solr == maxit && solc == 3+xin,BAD_SIZE);
968 DEBUGMSG("minimizeList (nmsimplex)");
969 gsl_multimin_function my_func;
970 // extract function from pars
971 my_func.f = only_f_aux_min;
972 my_func.n = xin;
973 my_func.params = f;
974 size_t iter = 0;
975 int status;
976 double size;
977 const gsl_multimin_fminimizer_type *T;
978 gsl_multimin_fminimizer *s = NULL;
979 // Initial vertex size vector
980 KDVVIEW(sz);
981 // Starting point
982 KDVVIEW(xi);
983 // Minimizer nmsimplex, without derivatives
984 switch(method) {
985 case 0 : {T = gsl_multimin_fminimizer_nmsimplex; break; }
986#ifdef GSL110
987 case 1 : {T = gsl_multimin_fminimizer_nmsimplex; break; }
988#else
989 case 1 : {T = gsl_multimin_fminimizer_nmsimplex2; break; }
990#endif
991 default: ERROR(BAD_CODE);
992 }
993 s = gsl_multimin_fminimizer_alloc (T, my_func.n);
994 gsl_multimin_fminimizer_set (s, &my_func, V(xi), V(sz));
995 do {
996 status = gsl_multimin_fminimizer_iterate (s);
997 size = gsl_multimin_fminimizer_size (s);
998
999 solp[iter*solc+0] = iter+1;
1000 solp[iter*solc+1] = s->fval;
1001 solp[iter*solc+2] = size;
1002
1003 int k;
1004 for(k=0;k<xin;k++) {
1005 solp[iter*solc+k+3] = gsl_vector_get(s->x,k);
1006 }
1007 iter++;
1008 if (status) break;
1009 status = gsl_multimin_test_size (size, tolsize);
1010 } while (status == GSL_CONTINUE && iter < maxit);
1011 int i,j;
1012 for (i=iter; i<solr; i++) {
1013 solp[i*solc+0] = iter;
1014 for(j=1;j<solc;j++) {
1015 solp[i*solc+j]=0.;
1016 }
1017 }
1018 gsl_multimin_fminimizer_free(s);
1019 OK
1020}
1021
1022// working with the gradient
1023
1024typedef struct {double (*f)(int, double*); int (*df)(int, double*, int, double*);} Tfdf;
1025
1026double f_aux_min(const gsl_vector*x, void *pars) {
1027 Tfdf * fdf = ((Tfdf*) pars);
1028 double* p = (double*)calloc(x->size,sizeof(double));
1029 int k;
1030 for(k=0;k<x->size;k++) {
1031 p[k] = gsl_vector_get(x,k);
1032 }
1033 double res = fdf->f(x->size,p);
1034 free(p);
1035 return res;
1036}
1037
1038
1039void df_aux_min(const gsl_vector * x, void * pars, gsl_vector * g) {
1040 Tfdf * fdf = ((Tfdf*) pars);
1041 double* p = (double*)calloc(x->size,sizeof(double));
1042 double* q = (double*)calloc(g->size,sizeof(double));
1043 int k;
1044 for(k=0;k<x->size;k++) {
1045 p[k] = gsl_vector_get(x,k);
1046 }
1047
1048 fdf->df(x->size,p,g->size,q);
1049
1050 for(k=0;k<x->size;k++) {
1051 gsl_vector_set(g,k,q[k]);
1052 }
1053 free(p);
1054 free(q);
1055}
1056
1057void fdf_aux_min(const gsl_vector * x, void * pars, double * f, gsl_vector * g) {
1058 *f = f_aux_min(x,pars);
1059 df_aux_min(x,pars,g);
1060}
1061
1062
1063int minimizeD(int method, double f(int, double*), int df(int, double*, int, double*),
1064 double initstep, double minimpar, double tolgrad, int maxit,
1065 KRVEC(xi), RMAT(sol)) {
1066 REQUIRES(solr == maxit && solc == 2+xin,BAD_SIZE);
1067 DEBUGMSG("minimizeWithDeriv (conjugate_fr)");
1068 gsl_multimin_function_fdf my_func;
1069 // extract function from pars
1070 my_func.f = f_aux_min;
1071 my_func.df = df_aux_min;
1072 my_func.fdf = fdf_aux_min;
1073 my_func.n = xin;
1074 Tfdf stfdf;
1075 stfdf.f = f;
1076 stfdf.df = df;
1077 my_func.params = &stfdf;
1078 size_t iter = 0;
1079 int status;
1080 const gsl_multimin_fdfminimizer_type *T;
1081 gsl_multimin_fdfminimizer *s = NULL;
1082 // Starting point
1083 KDVVIEW(xi);
1084 // conjugate gradient fr
1085 switch(method) {
1086 case 0 : {T = gsl_multimin_fdfminimizer_conjugate_fr; break; }
1087 case 1 : {T = gsl_multimin_fdfminimizer_conjugate_pr; break; }
1088 case 2 : {T = gsl_multimin_fdfminimizer_vector_bfgs; break; }
1089 case 3 : {T = gsl_multimin_fdfminimizer_vector_bfgs2; break; }
1090 case 4 : {T = gsl_multimin_fdfminimizer_steepest_descent; break; }
1091 default: ERROR(BAD_CODE);
1092 }
1093 s = gsl_multimin_fdfminimizer_alloc (T, my_func.n);
1094 gsl_multimin_fdfminimizer_set (s, &my_func, V(xi), initstep, minimpar);
1095 do {
1096 status = gsl_multimin_fdfminimizer_iterate (s);
1097 solp[iter*solc+0] = iter+1;
1098 solp[iter*solc+1] = s->f;
1099 int k;
1100 for(k=0;k<xin;k++) {
1101 solp[iter*solc+k+2] = gsl_vector_get(s->x,k);
1102 }
1103 iter++;
1104 if (status) break;
1105 status = gsl_multimin_test_gradient (s->gradient, tolgrad);
1106 } while (status == GSL_CONTINUE && iter < maxit);
1107 int i,j;
1108 for (i=iter; i<solr; i++) {
1109 solp[i*solc+0] = iter;
1110 for(j=1;j<solc;j++) {
1111 solp[i*solc+j]=0.;
1112 }
1113 }
1114 gsl_multimin_fdfminimizer_free(s);
1115 OK
1116}
1117
1118//---------------------------------------------------------------
1119
1120double only_f_aux_root(double x, void *pars) {
1121 double (*f)(double) = (double (*)(double)) pars;
1122 return f(x);
1123}
1124
1125int root(int method, double f(double),
1126 double epsrel, int maxit,
1127 double xl, double xu, RMAT(sol)) {
1128 REQUIRES(solr == maxit && solc == 4,BAD_SIZE);
1129 DEBUGMSG("root_only_f");
1130 gsl_function my_func;
1131 // extract function from pars
1132 my_func.function = only_f_aux_root;
1133 my_func.params = f;
1134 size_t iter = 0;
1135 int status;
1136 const gsl_root_fsolver_type *T;
1137 gsl_root_fsolver *s;
1138 // Starting point
1139 switch(method) {
1140 case 0 : {T = gsl_root_fsolver_bisection; printf("7\n"); break; }
1141 case 1 : {T = gsl_root_fsolver_falsepos; break; }
1142 case 2 : {T = gsl_root_fsolver_brent; break; }
1143 default: ERROR(BAD_CODE);
1144 }
1145 s = gsl_root_fsolver_alloc (T);
1146 gsl_root_fsolver_set (s, &my_func, xl, xu);
1147 do {
1148 double best, current_lo, current_hi;
1149 status = gsl_root_fsolver_iterate (s);
1150 best = gsl_root_fsolver_root (s);
1151 current_lo = gsl_root_fsolver_x_lower (s);
1152 current_hi = gsl_root_fsolver_x_upper (s);
1153 solp[iter*solc] = iter + 1;
1154 solp[iter*solc+1] = best;
1155 solp[iter*solc+2] = current_lo;
1156 solp[iter*solc+3] = current_hi;
1157 iter++;
1158 if (status) /* check if solver is stuck */
1159 break;
1160
1161 status =
1162 gsl_root_test_interval (current_lo, current_hi, 0, epsrel);
1163 }
1164 while (status == GSL_CONTINUE && iter < maxit);
1165 int i;
1166 for (i=iter; i<solr; i++) {
1167 solp[i*solc+0] = iter;
1168 solp[i*solc+1]=0.;
1169 solp[i*solc+2]=0.;
1170 solp[i*solc+3]=0.;
1171 }
1172 gsl_root_fsolver_free(s);
1173 OK
1174}
1175
1176typedef struct {
1177 double (*f)(double);
1178 double (*jf)(double);
1179} uniTfjf;
1180
1181double f_aux_uni(double x, void *pars) {
1182 uniTfjf * fjf = ((uniTfjf*) pars);
1183 return (fjf->f)(x);
1184}
1185
1186double jf_aux_uni(double x, void * pars) {
1187 uniTfjf * fjf = ((uniTfjf*) pars);
1188 return (fjf->jf)(x);
1189}
1190
1191void fjf_aux_uni(double x, void * pars, double * f, double * g) {
1192 *f = f_aux_uni(x,pars);
1193 *g = jf_aux_uni(x,pars);
1194}
1195
1196int rootj(int method, double f(double),
1197 double df(double),
1198 double epsrel, int maxit,
1199 double x, RMAT(sol)) {
1200 REQUIRES(solr == maxit && solc == 2,BAD_SIZE);
1201 DEBUGMSG("root_fjf");
1202 gsl_function_fdf my_func;
1203 // extract function from pars
1204 my_func.f = f_aux_uni;
1205 my_func.df = jf_aux_uni;
1206 my_func.fdf = fjf_aux_uni;
1207 uniTfjf stfjf;
1208 stfjf.f = f;
1209 stfjf.jf = df;
1210 my_func.params = &stfjf;
1211 size_t iter = 0;
1212 int status;
1213 const gsl_root_fdfsolver_type *T;
1214 gsl_root_fdfsolver *s;
1215 // Starting point
1216 switch(method) {
1217 case 0 : {T = gsl_root_fdfsolver_newton;; break; }
1218 case 1 : {T = gsl_root_fdfsolver_secant; break; }
1219 case 2 : {T = gsl_root_fdfsolver_steffenson; break; }
1220 default: ERROR(BAD_CODE);
1221 }
1222 s = gsl_root_fdfsolver_alloc (T);
1223
1224 gsl_root_fdfsolver_set (s, &my_func, x);
1225
1226 do {
1227 double x0;
1228 status = gsl_root_fdfsolver_iterate (s);
1229 x0 = x;
1230 x = gsl_root_fdfsolver_root(s);
1231 solp[iter*solc+0] = iter+1;
1232 solp[iter*solc+1] = x;
1233
1234 iter++;
1235 if (status) /* check if solver is stuck */
1236 break;
1237
1238 status =
1239 gsl_root_test_delta (x, x0, 0, epsrel);
1240 }
1241 while (status == GSL_CONTINUE && iter < maxit);
1242
1243 int i;
1244 for (i=iter; i<solr; i++) {
1245 solp[i*solc+0] = iter;
1246 solp[i*solc+1]=0.;
1247 }
1248 gsl_root_fdfsolver_free(s);
1249 OK
1250}
1251
1252
1253//---------------------------------------------------------------
1254
1255typedef void TrawfunV(int, double*, int, double*);
1256
1257int only_f_aux_multiroot(const gsl_vector*x, void *pars, gsl_vector*y) {
1258 TrawfunV * f = (TrawfunV*) pars;
1259 double* p = (double*)calloc(x->size,sizeof(double));
1260 double* q = (double*)calloc(y->size,sizeof(double));
1261 int k;
1262 for(k=0;k<x->size;k++) {
1263 p[k] = gsl_vector_get(x,k);
1264 }
1265 f(x->size,p,y->size,q);
1266 for(k=0;k<y->size;k++) {
1267 gsl_vector_set(y,k,q[k]);
1268 }
1269 free(p);
1270 free(q);
1271 return 0; //hmmm
1272}
1273
1274int multiroot(int method, void f(int, double*, int, double*),
1275 double epsabs, int maxit,
1276 KRVEC(xi), RMAT(sol)) {
1277 REQUIRES(solr == maxit && solc == 1+2*xin,BAD_SIZE);
1278 DEBUGMSG("root_only_f");
1279 gsl_multiroot_function my_func;
1280 // extract function from pars
1281 my_func.f = only_f_aux_multiroot;
1282 my_func.n = xin;
1283 my_func.params = f;
1284 size_t iter = 0;
1285 int status;
1286 const gsl_multiroot_fsolver_type *T;
1287 gsl_multiroot_fsolver *s;
1288 // Starting point
1289 KDVVIEW(xi);
1290 switch(method) {
1291 case 0 : {T = gsl_multiroot_fsolver_hybrids;; break; }
1292 case 1 : {T = gsl_multiroot_fsolver_hybrid; break; }
1293 case 2 : {T = gsl_multiroot_fsolver_dnewton; break; }
1294 case 3 : {T = gsl_multiroot_fsolver_broyden; break; }
1295 default: ERROR(BAD_CODE);
1296 }
1297 s = gsl_multiroot_fsolver_alloc (T, my_func.n);
1298 gsl_multiroot_fsolver_set (s, &my_func, V(xi));
1299
1300 do {
1301 status = gsl_multiroot_fsolver_iterate (s);
1302
1303 solp[iter*solc+0] = iter+1;
1304
1305 int k;
1306 for(k=0;k<xin;k++) {
1307 solp[iter*solc+k+1] = gsl_vector_get(s->x,k);
1308 }
1309 for(k=xin;k<2*xin;k++) {
1310 solp[iter*solc+k+1] = gsl_vector_get(s->f,k-xin);
1311 }
1312
1313 iter++;
1314 if (status) /* check if solver is stuck */
1315 break;
1316
1317 status =
1318 gsl_multiroot_test_residual (s->f, epsabs);
1319 }
1320 while (status == GSL_CONTINUE && iter < maxit);
1321
1322 int i,j;
1323 for (i=iter; i<solr; i++) {
1324 solp[i*solc+0] = iter;
1325 for(j=1;j<solc;j++) {
1326 solp[i*solc+j]=0.;
1327 }
1328 }
1329 gsl_multiroot_fsolver_free(s);
1330 OK
1331}
1332
1333// working with the jacobian
1334
1335typedef struct {int (*f)(int, double*, int, double *);
1336 int (*jf)(int, double*, int, int, double*);} Tfjf;
1337
1338int f_aux(const gsl_vector*x, void *pars, gsl_vector*y) {
1339 Tfjf * fjf = ((Tfjf*) pars);
1340 double* p = (double*)calloc(x->size,sizeof(double));
1341 double* q = (double*)calloc(y->size,sizeof(double));
1342 int k;
1343 for(k=0;k<x->size;k++) {
1344 p[k] = gsl_vector_get(x,k);
1345 }
1346 (fjf->f)(x->size,p,y->size,q);
1347 for(k=0;k<y->size;k++) {
1348 gsl_vector_set(y,k,q[k]);
1349 }
1350 free(p);
1351 free(q);
1352 return 0;
1353}
1354
1355int jf_aux(const gsl_vector * x, void * pars, gsl_matrix * jac) {
1356 Tfjf * fjf = ((Tfjf*) pars);
1357 double* p = (double*)calloc(x->size,sizeof(double));
1358 double* q = (double*)calloc((jac->size1)*(jac->size2),sizeof(double));
1359 int i,j,k;
1360 for(k=0;k<x->size;k++) {
1361 p[k] = gsl_vector_get(x,k);
1362 }
1363
1364 (fjf->jf)(x->size,p,jac->size1,jac->size2,q);
1365
1366 k=0;
1367 for(i=0;i<jac->size1;i++) {
1368 for(j=0;j<jac->size2;j++){
1369 gsl_matrix_set(jac,i,j,q[k++]);
1370 }
1371 }
1372 free(p);
1373 free(q);
1374 return 0;
1375}
1376
1377int fjf_aux(const gsl_vector * x, void * pars, gsl_vector * f, gsl_matrix * g) {
1378 f_aux(x,pars,f);
1379 jf_aux(x,pars,g);
1380 return 0;
1381}
1382
1383int multirootj(int method, int f(int, double*, int, double*),
1384 int jac(int, double*, int, int, double*),
1385 double epsabs, int maxit,
1386 KRVEC(xi), RMAT(sol)) {
1387 REQUIRES(solr == maxit && solc == 1+2*xin,BAD_SIZE);
1388 DEBUGMSG("root_fjf");
1389 gsl_multiroot_function_fdf my_func;
1390 // extract function from pars
1391 my_func.f = f_aux;
1392 my_func.df = jf_aux;
1393 my_func.fdf = fjf_aux;
1394 my_func.n = xin;
1395 Tfjf stfjf;
1396 stfjf.f = f;
1397 stfjf.jf = jac;
1398 my_func.params = &stfjf;
1399 size_t iter = 0;
1400 int status;
1401 const gsl_multiroot_fdfsolver_type *T;
1402 gsl_multiroot_fdfsolver *s;
1403 // Starting point
1404 KDVVIEW(xi);
1405 switch(method) {
1406 case 0 : {T = gsl_multiroot_fdfsolver_hybridsj;; break; }
1407 case 1 : {T = gsl_multiroot_fdfsolver_hybridj; break; }
1408 case 2 : {T = gsl_multiroot_fdfsolver_newton; break; }
1409 case 3 : {T = gsl_multiroot_fdfsolver_gnewton; break; }
1410 default: ERROR(BAD_CODE);
1411 }
1412 s = gsl_multiroot_fdfsolver_alloc (T, my_func.n);
1413
1414 gsl_multiroot_fdfsolver_set (s, &my_func, V(xi));
1415
1416 do {
1417 status = gsl_multiroot_fdfsolver_iterate (s);
1418
1419 solp[iter*solc+0] = iter+1;
1420
1421 int k;
1422 for(k=0;k<xin;k++) {
1423 solp[iter*solc+k+1] = gsl_vector_get(s->x,k);
1424 }
1425 for(k=xin;k<2*xin;k++) {
1426 solp[iter*solc+k+1] = gsl_vector_get(s->f,k-xin);
1427 }
1428
1429 iter++;
1430 if (status) /* check if solver is stuck */
1431 break;
1432
1433 status =
1434 gsl_multiroot_test_residual (s->f, epsabs);
1435 }
1436 while (status == GSL_CONTINUE && iter < maxit);
1437
1438 int i,j;
1439 for (i=iter; i<solr; i++) {
1440 solp[i*solc+0] = iter;
1441 for(j=1;j<solc;j++) {
1442 solp[i*solc+j]=0.;
1443 }
1444 }
1445 gsl_multiroot_fdfsolver_free(s);
1446 OK
1447}
1448
1449//-------------- non linear least squares fitting -------------------
1450
1451int nlfit(int method, int f(int, double*, int, double*),
1452 int jac(int, double*, int, int, double*),
1453 double epsabs, double epsrel, int maxit, int p,
1454 KRVEC(xi), RMAT(sol)) {
1455 REQUIRES(solr == maxit && solc == 2+xin,BAD_SIZE);
1456 DEBUGMSG("nlfit");
1457 const gsl_multifit_fdfsolver_type *T;
1458 gsl_multifit_fdfsolver *s;
1459 gsl_multifit_function_fdf my_f;
1460 // extract function from pars
1461 my_f.f = f_aux;
1462 my_f.df = jf_aux;
1463 my_f.fdf = fjf_aux;
1464 my_f.n = p;
1465 my_f.p = xin; // !!!!
1466 Tfjf stfjf;
1467 stfjf.f = f;
1468 stfjf.jf = jac;
1469 my_f.params = &stfjf;
1470 size_t iter = 0;
1471 int status;
1472
1473 KDVVIEW(xi);
1474 //DMVIEW(cov);
1475
1476 switch(method) {
1477 case 0 : { T = gsl_multifit_fdfsolver_lmsder; break; }
1478 case 1 : { T = gsl_multifit_fdfsolver_lmder; break; }
1479 default: ERROR(BAD_CODE);
1480 }
1481
1482 s = gsl_multifit_fdfsolver_alloc (T, my_f.n, my_f.p);
1483 gsl_multifit_fdfsolver_set (s, &my_f, V(xi));
1484
1485 do { status = gsl_multifit_fdfsolver_iterate (s);
1486
1487 solp[iter*solc+0] = iter+1;
1488 solp[iter*solc+1] = gsl_blas_dnrm2 (s->f);
1489
1490 int k;
1491 for(k=0;k<xin;k++) {
1492 solp[iter*solc+k+2] = gsl_vector_get(s->x,k);
1493 }
1494
1495 iter++;
1496 if (status) /* check if solver is stuck */
1497 break;
1498
1499 status = gsl_multifit_test_delta (s->dx, s->x, epsabs, epsrel);
1500 }
1501 while (status == GSL_CONTINUE && iter < maxit);
1502
1503 int i,j;
1504 for (i=iter; i<solr; i++) {
1505 solp[i*solc+0] = iter;
1506 for(j=1;j<solc;j++) {
1507 solp[i*solc+j]=0.;
1508 }
1509 }
1510
1511 //gsl_multifit_covar (s->J, 0.0, M(cov));
1512
1513 gsl_multifit_fdfsolver_free (s);
1514 OK
1515}
1516
1517
1518//////////////////////////////////////////////////////
1519
1520
1521#define RAN(C,F) case C: { for(k=0;k<rn;k++) { rp[k]= F(gen); }; OK }
1522
1523int random_vector(int seed, int code, RVEC(r)) {
1524 DEBUGMSG("random_vector")
1525 static gsl_rng * gen = NULL;
1526 if (!gen) { gen = gsl_rng_alloc (gsl_rng_mt19937);}
1527 gsl_rng_set (gen, seed);
1528 int k;
1529 switch (code) {
1530 RAN(0,gsl_rng_uniform)
1531 RAN(1,gsl_ran_ugaussian)
1532 default: ERROR(BAD_CODE);
1533 }
1534}
1535#undef RAN
1536
1537//////////////////////////////////////////////////////
1538
1539#include "gsl-ode.c"
1540
1541//////////////////////////////////////////////////////
diff --git a/packages/hmatrix/src/Numeric/GSL/gsl-ode.c b/packages/hmatrix/src/Numeric/GSL/gsl-ode.c
new file mode 100644
index 0000000..3f2771b
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/GSL/gsl-ode.c
@@ -0,0 +1,182 @@
1
2#ifdef GSLODE1
3
4////////////////////////////// ODE V1 //////////////////////////////////////////
5
6#include <gsl/gsl_odeiv.h>
7
8typedef struct {int n; int (*f)(double,int, const double*, int, double *); int (*j)(double,int, const double*, int, int, double*);} Tode;
9
10int odefunc (double t, const double y[], double f[], void *params) {
11 Tode * P = (Tode*) params;
12 (P->f)(t,P->n,y,P->n,f);
13 return GSL_SUCCESS;
14}
15
16int odejac (double t, const double y[], double *dfdy, double dfdt[], void *params) {
17 Tode * P = ((Tode*) params);
18 (P->j)(t,P->n,y,P->n,P->n,dfdy);
19 int j;
20 for (j=0; j< P->n; j++)
21 dfdt[j] = 0.0;
22 return GSL_SUCCESS;
23}
24
25
26int ode(int method, double h, double eps_abs, double eps_rel,
27 int f(double, int, const double*, int, double*),
28 int jac(double, int, const double*, int, int, double*),
29 KRVEC(xi), KRVEC(ts), RMAT(sol)) {
30
31 const gsl_odeiv_step_type * T;
32
33 switch(method) {
34 case 0 : {T = gsl_odeiv_step_rk2; break; }
35 case 1 : {T = gsl_odeiv_step_rk4; break; }
36 case 2 : {T = gsl_odeiv_step_rkf45; break; }
37 case 3 : {T = gsl_odeiv_step_rkck; break; }
38 case 4 : {T = gsl_odeiv_step_rk8pd; break; }
39 case 5 : {T = gsl_odeiv_step_rk2imp; break; }
40 case 6 : {T = gsl_odeiv_step_rk4imp; break; }
41 case 7 : {T = gsl_odeiv_step_bsimp; break; }
42 case 8 : { printf("Sorry: ODE rk1imp not available in this GSL version\n"); exit(0); }
43 case 9 : { printf("Sorry: ODE msadams not available in this GSL version\n"); exit(0); }
44 case 10: { printf("Sorry: ODE msbdf not available in this GSL version\n"); exit(0); }
45 default: ERROR(BAD_CODE);
46 }
47
48 gsl_odeiv_step * s = gsl_odeiv_step_alloc (T, xin);
49 gsl_odeiv_control * c = gsl_odeiv_control_y_new (eps_abs, eps_rel);
50 gsl_odeiv_evolve * e = gsl_odeiv_evolve_alloc (xin);
51
52 Tode P;
53 P.f = f;
54 P.j = jac;
55 P.n = xin;
56
57 gsl_odeiv_system sys = {odefunc, odejac, xin, &P};
58
59 double t = tsp[0];
60
61 double* y = (double*)calloc(xin,sizeof(double));
62 int i,j;
63 for(i=0; i< xin; i++) {
64 y[i] = xip[i];
65 solp[i] = xip[i];
66 }
67
68 for (i = 1; i < tsn ; i++)
69 {
70 double ti = tsp[i];
71 while (t < ti)
72 {
73 gsl_odeiv_evolve_apply (e, c, s,
74 &sys,
75 &t, ti, &h,
76 y);
77 // if (h < hmin) h = hmin;
78 }
79 for(j=0; j<xin; j++) {
80 solp[i*xin + j] = y[j];
81 }
82 }
83
84 free(y);
85 gsl_odeiv_evolve_free (e);
86 gsl_odeiv_control_free (c);
87 gsl_odeiv_step_free (s);
88 return 0;
89}
90
91#else
92
93///////////////////// ODE V2 ///////////////////////////////////////////////////
94
95#include <gsl/gsl_odeiv2.h>
96
97typedef struct {int n; int (*f)(double,int, const double*, int, double *); int (*j)(double,int, const double*, int, int, double*);} Tode;
98
99int odefunc (double t, const double y[], double f[], void *params) {
100 Tode * P = (Tode*) params;
101 (P->f)(t,P->n,y,P->n,f);
102 return GSL_SUCCESS;
103}
104
105int odejac (double t, const double y[], double *dfdy, double dfdt[], void *params) {
106 Tode * P = ((Tode*) params);
107 (P->j)(t,P->n,y,P->n,P->n,dfdy);
108 int j;
109 for (j=0; j< P->n; j++)
110 dfdt[j] = 0.0;
111 return GSL_SUCCESS;
112}
113
114
115int ode(int method, double h, double eps_abs, double eps_rel,
116 int f(double, int, const double*, int, double*),
117 int jac(double, int, const double*, int, int, double*),
118 KRVEC(xi), KRVEC(ts), RMAT(sol)) {
119
120 const gsl_odeiv2_step_type * T;
121
122 switch(method) {
123 case 0 : {T = gsl_odeiv2_step_rk2; break; }
124 case 1 : {T = gsl_odeiv2_step_rk4; break; }
125 case 2 : {T = gsl_odeiv2_step_rkf45; break; }
126 case 3 : {T = gsl_odeiv2_step_rkck; break; }
127 case 4 : {T = gsl_odeiv2_step_rk8pd; break; }
128 case 5 : {T = gsl_odeiv2_step_rk2imp; break; }
129 case 6 : {T = gsl_odeiv2_step_rk4imp; break; }
130 case 7 : {T = gsl_odeiv2_step_bsimp; break; }
131 case 8 : {T = gsl_odeiv2_step_rk1imp; break; }
132 case 9 : {T = gsl_odeiv2_step_msadams; break; }
133 case 10: {T = gsl_odeiv2_step_msbdf; break; }
134 default: ERROR(BAD_CODE);
135 }
136
137 Tode P;
138 P.f = f;
139 P.j = jac;
140 P.n = xin;
141
142 gsl_odeiv2_system sys = {odefunc, odejac, xin, &P};
143
144 gsl_odeiv2_driver * d =
145 gsl_odeiv2_driver_alloc_y_new (&sys, T, h, eps_abs, eps_rel);
146
147 double t = tsp[0];
148
149 double* y = (double*)calloc(xin,sizeof(double));
150 int i,j;
151 int status=0;
152 for(i=0; i< xin; i++) {
153 y[i] = xip[i];
154 solp[i] = xip[i];
155 }
156
157 for (i = 1; i < tsn ; i++)
158 {
159 double ti = tsp[i];
160
161 status = gsl_odeiv2_driver_apply (d, &t, ti, y);
162
163 if (status != GSL_SUCCESS) {
164 printf ("error in ode, return value=%d\n", status);
165 break;
166 }
167
168// printf ("%.5e %.5e %.5e\n", t, y[0], y[1]);
169
170 for(j=0; j<xin; j++) {
171 solp[i*xin + j] = y[j];
172 }
173 }
174
175 free(y);
176 gsl_odeiv2_driver_free (d);
177
178 return status;
179}
180
181#endif
182
diff --git a/packages/hmatrix/src/Numeric/HMatrix.hs b/packages/hmatrix/src/Numeric/HMatrix.hs
new file mode 100644
index 0000000..2e01454
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/HMatrix.hs
@@ -0,0 +1,136 @@
1-----------------------------------------------------------------------------
2{- |
3Module : Numeric.HMatrix
4Copyright : (c) Alberto Ruiz 2006-14
5License : GPL
6
7Maintainer : Alberto Ruiz
8Stability : provisional
9
10This module reexports the most common Linear Algebra functions.
11
12-}
13-----------------------------------------------------------------------------
14module Numeric.HMatrix (
15
16 -- * Basic types and data processing
17 module Numeric.HMatrix.Data,
18
19 -- | The standard numeric classes are defined elementwise:
20 --
21 -- >>> fromList [1,2,3] * fromList [3,0,-2 :: Double]
22 -- fromList [3.0,0.0,-6.0]
23 --
24 -- >>> (3><3) [1..9] * ident 3 :: Matrix Double
25 -- (3><3)
26 -- [ 1.0, 0.0, 0.0
27 -- , 0.0, 5.0, 0.0
28 -- , 0.0, 0.0, 9.0 ]
29 --
30 -- In arithmetic operations single-element vectors and matrices
31 -- (created from numeric literals or using 'scalar') automatically
32 -- expand to match the dimensions of the other operand:
33 --
34 -- >>> 5 + 2*ident 3 :: Matrix Double
35 -- (3><3)
36 -- [ 7.0, 5.0, 5.0
37 -- , 5.0, 7.0, 5.0
38 -- , 5.0, 5.0, 7.0 ]
39 --
40
41 -- * Products
42 (×),
43
44 -- | The matrix product is also implemented in the "Data.Monoid" instance for Matrix, where
45 -- single-element matrices (created from numeric literals or using 'scalar')
46 -- are used for scaling.
47 --
48 -- >>> let m = (2><3)[1..] :: Matrix Double
49 -- >>> m <> 2 <> diagl[0.5,1,0]
50 -- (2><3)
51 -- [ 1.0, 4.0, 0.0
52 -- , 4.0, 10.0, 0.0 ]
53 --
54 -- mconcat uses 'optimiseMult' to get the optimal association order.
55
56 (·), outer, kronecker, cross,
57 scale,
58 sumElements, prodElements, absSum,
59
60 -- * Linear Systems
61 (<\>),
62 linearSolve,
63 linearSolveLS,
64 linearSolveSVD,
65 luSolve,
66 cholSolve,
67
68 -- * Inverse and pseudoinverse
69 inv, pinv, pinvTol,
70
71 -- * Determinant and rank
72 rcond, rank, ranksv,
73 det, invlndet,
74
75 -- * Singular value decomposition
76 svd,
77 fullSVD,
78 thinSVD,
79 compactSVD,
80 singularValues,
81 leftSV, rightSV,
82
83 -- * Eigensystems
84 eig, eigSH, eigSH',
85 eigenvalues, eigenvaluesSH, eigenvaluesSH',
86 geigSH',
87
88 -- * QR
89 qr, rq, qrRaw, qrgr,
90
91 -- * Cholesky
92 chol, cholSH, mbCholSH,
93
94 -- * Hessenberg
95 hess,
96
97 -- * Schur
98 schur,
99
100 -- * LU
101 lu, luPacked,
102
103 -- * Matrix functions
104 expm,
105 sqrtm,
106 matFunc,
107
108 -- * Nullspace
109 nullspacePrec,
110 nullVector,
111 nullspaceSVD,
112 null1, null1sym,
113
114 orth,
115
116 -- * Norms
117 norm1, norm2, normInf, pnorm, NormType(..),
118
119 -- * Correlation and Convolution
120 corr, conv, corrMin, corr2, conv2,
121
122 -- * Random arrays
123 rand, randn, RandDist(..), randomVector, gaussianSample, uniformSample,
124
125 -- * Misc
126 meanCov, peps, relativeError, haussholder, optimiseMult, udot, cdot, (<.>)
127) where
128
129import Numeric.HMatrix.Data
130
131import Numeric.Matrix()
132import Numeric.Vector()
133import Numeric.Container
134import Numeric.LinearAlgebra.Algorithms
135import Numeric.LinearAlgebra.Util
136
diff --git a/packages/hmatrix/src/Numeric/HMatrix/Data.hs b/packages/hmatrix/src/Numeric/HMatrix/Data.hs
new file mode 100644
index 0000000..568dc05
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/HMatrix/Data.hs
@@ -0,0 +1,69 @@
1--------------------------------------------------------------------------------
2{- |
3Module : Numeric.HMatrix.Data
4Copyright : (c) Alberto Ruiz 2014
5License : GPL
6
7Maintainer : Alberto Ruiz
8Stability : provisional
9
10Basic data processing.
11
12-}
13--------------------------------------------------------------------------------
14
15module Numeric.HMatrix.Data(
16
17 -- * Vector
18 -- | 1D arrays are storable vectors from the vector package.
19
20 Vector, (|>), dim, (@>),
21
22 -- * Matrix
23 Matrix, (><), size, (@@>), trans, ctrans,
24
25 -- * Construction
26 scalar, konst, build, assoc, accum, linspace, -- ones, zeros,
27
28 -- * Diagonal
29 ident, diag, diagl, diagRect, takeDiag,
30
31 -- * Data manipulation
32 fromList, toList, subVector, takesV, vjoin,
33 flatten, reshape, asRow, asColumn, row, col,
34 fromRows, toRows, fromColumns, toColumns, fromLists, toLists, fromArray2D,
35 takeRows, dropRows, takeColumns, dropColumns, subMatrix, (?), (¿), fliprl, flipud,
36
37 -- * Block matrix
38 fromBlocks, (¦), (——), diagBlock, repmat, toBlocks, toBlocksEvery,
39
40 -- * Mapping functions
41 conj, cmap, step, cond,
42
43 -- * Find elements
44 find, maxIndex, minIndex, maxElement, minElement, atIndex,
45
46 -- * IO
47 disp, dispf, disps, dispcf, latexFormat, format,
48 loadMatrix, saveMatrix, fromFile, fileDimensions,
49 readMatrix,
50 fscanfVector, fprintfVector, freadVector, fwriteVector,
51
52-- * Conversion
53 Convert(..),
54
55 -- * Misc
56 arctan2,
57 rows, cols,
58 separable,
59
60 module Data.Complex
61
62) where
63
64import Data.Packed.Vector
65import Data.Packed.Matrix
66import Numeric.Container
67import Numeric.LinearAlgebra.Util
68import Data.Complex
69
diff --git a/packages/hmatrix/src/Numeric/HMatrix/Devel.hs b/packages/hmatrix/src/Numeric/HMatrix/Devel.hs
new file mode 100644
index 0000000..b921f44
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/HMatrix/Devel.hs
@@ -0,0 +1,69 @@
1--------------------------------------------------------------------------------
2{- |
3Module : Numeric.HMatrix.Devel
4Copyright : (c) Alberto Ruiz 2014
5License : GPL
6
7Maintainer : Alberto Ruiz
8Stability : provisional
9
10The library can be easily extended using the tools in this module.
11
12-}
13--------------------------------------------------------------------------------
14
15module Numeric.HMatrix.Devel(
16 -- * FFI helpers
17 -- | Sample usage, to upload a perspective matrix to a shader.
18 --
19 -- @ glUniformMatrix4fv 0 1 (fromIntegral gl_TRUE) \`appMatrix\` perspective 0.01 100 (pi\/2) (4\/3)
20 -- @
21 module Data.Packed.Foreign,
22
23 -- * FFI tools
24 -- | Illustrative usage examples can be found
25 -- in the @examples\/devel@ folder included in the package.
26 module Data.Packed.Development,
27
28 -- * ST
29 -- | In-place manipulation inside the ST monad.
30 -- See examples\/inplace.hs in the distribution.
31
32 -- ** Mutable Vectors
33 STVector, newVector, thawVector, freezeVector, runSTVector,
34 readVector, writeVector, modifyVector, liftSTVector,
35 -- ** Mutable Matrices
36 STMatrix, newMatrix, thawMatrix, freezeMatrix, runSTMatrix,
37 readMatrix, writeMatrix, modifyMatrix, liftSTMatrix,
38 -- ** Unsafe functions
39 newUndefinedVector,
40 unsafeReadVector, unsafeWriteVector,
41 unsafeThawVector, unsafeFreezeVector,
42 newUndefinedMatrix,
43 unsafeReadMatrix, unsafeWriteMatrix,
44 unsafeThawMatrix, unsafeFreezeMatrix,
45
46 -- * Special maps and zips
47 mapVectorWithIndex, zipVector, zipVectorWith, unzipVector, unzipVectorWith,
48 mapVectorM, mapVectorM_, mapVectorWithIndexM, mapVectorWithIndexM_,
49 foldLoop, foldVector, foldVectorG, foldVectorWithIndex,
50 mapMatrixWithIndex, mapMatrixWithIndexM, mapMatrixWithIndexM_,
51 liftMatrix, liftMatrix2, liftMatrix2Auto,
52
53 -- * Auxiliary classes
54 Element, Container, Product, Contraction, LSDiv,
55 Complexable(), RealElement(),
56 RealOf, ComplexOf, SingleOf, DoubleOf,
57 IndexOf,
58 Field, Normed
59) where
60
61import Data.Packed.Foreign
62import Data.Packed.Development
63import Data.Packed.ST
64import Numeric.Container(Container,Contraction,LSDiv,Product,
65 Complexable(),RealElement(),
66 RealOf, ComplexOf, SingleOf, DoubleOf, IndexOf)
67import Data.Packed
68import Numeric.LinearAlgebra.Algorithms(Field,Normed)
69
diff --git a/packages/hmatrix/src/Numeric/IO.hs b/packages/hmatrix/src/Numeric/IO.hs
new file mode 100644
index 0000000..836f352
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/IO.hs
@@ -0,0 +1,165 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Numeric.IO
4-- Copyright : (c) Alberto Ruiz 2010
5-- License : GPL
6--
7-- Maintainer : Alberto Ruiz <aruiz@um.es>
8-- Stability : provisional
9-- Portability : portable
10--
11-- Display, formatting and IO functions for numeric 'Vector' and 'Matrix'
12--
13-----------------------------------------------------------------------------
14
15module Numeric.IO (
16 dispf, disps, dispcf, vecdisp, latexFormat, format,
17 loadMatrix, saveMatrix, fromFile, fileDimensions,
18 readMatrix, fromArray2D,
19 fscanfVector, fprintfVector, freadVector, fwriteVector
20) where
21
22import Data.Packed
23import Data.Packed.Internal
24import System.Process(readProcess)
25import Text.Printf(printf)
26import Data.List(intersperse)
27import Data.Complex
28
29{- | Creates a string from a matrix given a separator and a function to show each entry. Using
30this function the user can easily define any desired display function:
31
32@import Text.Printf(printf)@
33
34@disp = putStr . format \" \" (printf \"%.2f\")@
35
36-}
37format :: (Element t) => String -> (t -> String) -> Matrix t -> String
38format sep f m = table sep . map (map f) . toLists $ m
39
40{- | Show a matrix with \"autoscaling\" and a given number of decimal places.
41
42>>> putStr . disps 2 $ 120 * (3><4) [1..]
433x4 E3
44 0.12 0.24 0.36 0.48
45 0.60 0.72 0.84 0.96
46 1.08 1.20 1.32 1.44
47
48-}
49disps :: Int -> Matrix Double -> String
50disps d x = sdims x ++ " " ++ formatScaled d x
51
52{- | Show a matrix with a given number of decimal places.
53
54>>> dispf 2 (1/3 + ident 3)
55"3x3\n1.33 0.33 0.33\n0.33 1.33 0.33\n0.33 0.33 1.33\n"
56
57>>> putStr . dispf 2 $ (3><4)[1,1.5..]
583x4
591.00 1.50 2.00 2.50
603.00 3.50 4.00 4.50
615.00 5.50 6.00 6.50
62
63>>> putStr . unlines . tail . lines . dispf 2 . asRow $ linspace 10 (0,1)
640.00 0.11 0.22 0.33 0.44 0.56 0.67 0.78 0.89 1.00
65
66-}
67dispf :: Int -> Matrix Double -> String
68dispf d x = sdims x ++ "\n" ++ formatFixed (if isInt x then 0 else d) x
69
70sdims x = show (rows x) ++ "x" ++ show (cols x)
71
72formatFixed d x = format " " (printf ("%."++show d++"f")) $ x
73
74isInt = all lookslikeInt . toList . flatten
75
76formatScaled dec t = "E"++show o++"\n" ++ ss
77 where ss = format " " (printf fmt. g) t
78 g x | o >= 0 = x/10^(o::Int)
79 | otherwise = x*10^(-o)
80 o | rows t == 0 || cols t == 0 = 0
81 | otherwise = floor $ maximum $ map (logBase 10 . abs) $ toList $ flatten t
82 fmt = '%':show (dec+3) ++ '.':show dec ++"f"
83
84{- | Show a vector using a function for showing matrices.
85
86>>> putStr . vecdisp (dispf 2) $ linspace 10 (0,1)
8710 |> 0.00 0.11 0.22 0.33 0.44 0.56 0.67 0.78 0.89 1.00
88
89-}
90vecdisp :: (Element t) => (Matrix t -> String) -> Vector t -> String
91vecdisp f v
92 = ((show (dim v) ++ " |> ") ++) . (++"\n")
93 . unwords . lines . tail . dropWhile (not . (`elem` " \n"))
94 . f . trans . reshape 1
95 $ v
96
97{- | Tool to display matrices with latex syntax.
98
99>>> latexFormat "bmatrix" (dispf 2 $ ident 2)
100"\\begin{bmatrix}\n1 & 0\n\\\\\n0 & 1\n\\end{bmatrix}"
101
102-}
103latexFormat :: String -- ^ type of braces: \"matrix\", \"bmatrix\", \"pmatrix\", etc.
104 -> String -- ^ Formatted matrix, with elements separated by spaces and newlines
105 -> String
106latexFormat del tab = "\\begin{"++del++"}\n" ++ f tab ++ "\\end{"++del++"}"
107 where f = unlines . intersperse "\\\\" . map unwords . map (intersperse " & " . words) . tail . lines
108
109-- | Pretty print a complex number with at most n decimal digits.
110showComplex :: Int -> Complex Double -> String
111showComplex d (a:+b)
112 | isZero a && isZero b = "0"
113 | isZero b = sa
114 | isZero a && isOne b = s2++"i"
115 | isZero a = sb++"i"
116 | isOne b = sa++s3++"i"
117 | otherwise = sa++s1++sb++"i"
118 where
119 sa = shcr d a
120 sb = shcr d b
121 s1 = if b<0 then "" else "+"
122 s2 = if b<0 then "-" else ""
123 s3 = if b<0 then "-" else "+"
124
125shcr d a | lookslikeInt a = printf "%.0f" a
126 | otherwise = printf ("%."++show d++"f") a
127
128
129lookslikeInt x = show (round x :: Int) ++".0" == shx || "-0.0" == shx
130 where shx = show x
131
132isZero x = show x `elem` ["0.0","-0.0"]
133isOne x = show x `elem` ["1.0","-1.0"]
134
135-- | Pretty print a complex matrix with at most n decimal digits.
136dispcf :: Int -> Matrix (Complex Double) -> String
137dispcf d m = sdims m ++ "\n" ++ format " " (showComplex d) m
138
139--------------------------------------------------------------------
140
141-- | reads a matrix from a string containing a table of numbers.
142readMatrix :: String -> Matrix Double
143readMatrix = fromLists . map (map read). map words . filter (not.null) . lines
144
145{- | obtains the number of rows and columns in an ASCII data file
146 (provisionally using unix's wc).
147-}
148fileDimensions :: FilePath -> IO (Int,Int)
149fileDimensions fname = do
150 wcres <- readProcess "wc" ["-w",fname] ""
151 contents <- readFile fname
152 let tot = read . head . words $ wcres
153 c = length . head . dropWhile null . map words . lines $ contents
154 if tot > 0
155 then return (tot `div` c, c)
156 else return (0,0)
157
158-- | Loads a matrix from an ASCII file formatted as a 2D table.
159loadMatrix :: FilePath -> IO (Matrix Double)
160loadMatrix file = fromFile file =<< fileDimensions file
161
162-- | Loads a matrix from an ASCII file (the number of rows and columns must be known in advance).
163fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double)
164fromFile filename (r,c) = reshape c `fmap` fscanfVector filename (r*c)
165
diff --git a/packages/hmatrix/src/Numeric/LinearAlgebra.hs b/packages/hmatrix/src/Numeric/LinearAlgebra.hs
new file mode 100644
index 0000000..1db860c
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/LinearAlgebra.hs
@@ -0,0 +1,30 @@
1-----------------------------------------------------------------------------
2{- |
3Module : Numeric.LinearAlgebra
4Copyright : (c) Alberto Ruiz 2006-10
5License : GPL-style
6
7Maintainer : Alberto Ruiz (aruiz at um dot es)
8Stability : provisional
9Portability : uses ffi
10
11This module reexports all normally required functions for Linear Algebra applications.
12
13It also provides instances of standard classes 'Show', 'Read', 'Eq',
14'Num', 'Fractional', and 'Floating' for 'Vector' and 'Matrix'.
15In arithmetic operations one-component vectors and matrices automatically
16expand to match the dimensions of the other operand.
17
18-}
19-----------------------------------------------------------------------------
20{-# OPTIONS_HADDOCK hide #-}
21
22module Numeric.LinearAlgebra (
23 module Numeric.Container,
24 module Numeric.LinearAlgebra.Algorithms
25) where
26
27import Numeric.Container
28import Numeric.LinearAlgebra.Algorithms
29import Numeric.Matrix()
30import Numeric.Vector()
diff --git a/packages/hmatrix/src/Numeric/LinearAlgebra/Algorithms.hs b/packages/hmatrix/src/Numeric/LinearAlgebra/Algorithms.hs
new file mode 100644
index 0000000..8c4b610
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/LinearAlgebra/Algorithms.hs
@@ -0,0 +1,746 @@
1{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
2{-# LANGUAGE CPP #-}
3{-# LANGUAGE MultiParamTypeClasses #-}
4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE TypeFamilies #-}
6
7-----------------------------------------------------------------------------
8{- |
9Module : Numeric.LinearAlgebra.Algorithms
10Copyright : (c) Alberto Ruiz 2006-9
11License : GPL-style
12
13Maintainer : Alberto Ruiz (aruiz at um dot es)
14Stability : provisional
15Portability : uses ffi
16
17High level generic interface to common matrix computations.
18
19Specific functions for particular base types can also be explicitly
20imported from "Numeric.LinearAlgebra.LAPACK".
21
22-}
23-----------------------------------------------------------------------------
24{-# OPTIONS_HADDOCK hide #-}
25
26
27module Numeric.LinearAlgebra.Algorithms (
28-- * Supported types
29 Field(),
30-- * Linear Systems
31 linearSolve,
32 luSolve,
33 cholSolve,
34 linearSolveLS,
35 linearSolveSVD,
36 inv, pinv, pinvTol,
37 det, invlndet,
38 rank, rcond,
39-- * Matrix factorizations
40-- ** Singular value decomposition
41 svd,
42 fullSVD,
43 thinSVD,
44 compactSVD,
45 singularValues,
46 leftSV, rightSV,
47-- ** Eigensystems
48 eig, eigSH, eigSH',
49 eigenvalues, eigenvaluesSH, eigenvaluesSH',
50 geigSH',
51-- ** QR
52 qr, rq, qrRaw, qrgr,
53-- ** Cholesky
54 chol, cholSH, mbCholSH,
55-- ** Hessenberg
56 hess,
57-- ** Schur
58 schur,
59-- ** LU
60 lu, luPacked,
61-- * Matrix functions
62 expm,
63 sqrtm,
64 matFunc,
65-- * Nullspace
66 nullspacePrec,
67 nullVector,
68 nullspaceSVD,
69 orth,
70-- * Norms
71 Normed(..), NormType(..),
72 relativeError,
73-- * Misc
74 eps, peps, i,
75-- * Util
76 haussholder,
77 unpackQR, unpackHess,
78 ranksv
79) where
80
81
82import Data.Packed.Internal hiding ((//))
83import Data.Packed.Matrix
84import Numeric.LinearAlgebra.LAPACK as LAPACK
85import Data.List(foldl1')
86import Data.Array
87import Numeric.ContainerBoot
88
89
90{- | Generic linear algebra functions for double precision real and complex matrices.
91
92(Single precision data can be converted using 'single' and 'double').
93
94-}
95class (Product t,
96 Convert t,
97 Container Vector t,
98 Container Matrix t,
99 Normed Matrix t,
100 Normed Vector t,
101 Floating t,
102 RealOf t ~ Double) => Field t where
103 svd' :: Matrix t -> (Matrix t, Vector Double, Matrix t)
104 thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t)
105 sv' :: Matrix t -> Vector Double
106 luPacked' :: Matrix t -> (Matrix t, [Int])
107 luSolve' :: (Matrix t, [Int]) -> Matrix t -> Matrix t
108 linearSolve' :: Matrix t -> Matrix t -> Matrix t
109 cholSolve' :: Matrix t -> Matrix t -> Matrix t
110 linearSolveSVD' :: Matrix t -> Matrix t -> Matrix t
111 linearSolveLS' :: Matrix t -> Matrix t -> Matrix t
112 eig' :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double))
113 eigSH'' :: Matrix t -> (Vector Double, Matrix t)
114 eigOnly :: Matrix t -> Vector (Complex Double)
115 eigOnlySH :: Matrix t -> Vector Double
116 cholSH' :: Matrix t -> Matrix t
117 mbCholSH' :: Matrix t -> Maybe (Matrix t)
118 qr' :: Matrix t -> (Matrix t, Vector t)
119 qrgr' :: Int -> (Matrix t, Vector t) -> Matrix t
120 hess' :: Matrix t -> (Matrix t, Matrix t)
121 schur' :: Matrix t -> (Matrix t, Matrix t)
122
123
124instance Field Double where
125 svd' = svdRd
126 thinSVD' = thinSVDRd
127 sv' = svR
128 luPacked' = luR
129 luSolve' (l_u,perm) = lusR l_u perm
130 linearSolve' = linearSolveR -- (luSolve . luPacked) ??
131 cholSolve' = cholSolveR
132 linearSolveLS' = linearSolveLSR
133 linearSolveSVD' = linearSolveSVDR Nothing
134 eig' = eigR
135 eigSH'' = eigS
136 eigOnly = eigOnlyR
137 eigOnlySH = eigOnlyS
138 cholSH' = cholS
139 mbCholSH' = mbCholS
140 qr' = qrR
141 qrgr' = qrgrR
142 hess' = unpackHess hessR
143 schur' = schurR
144
145instance Field (Complex Double) where
146#ifdef NOZGESDD
147 svd' = svdC
148 thinSVD' = thinSVDC
149#else
150 svd' = svdCd
151 thinSVD' = thinSVDCd
152#endif
153 sv' = svC
154 luPacked' = luC
155 luSolve' (l_u,perm) = lusC l_u perm
156 linearSolve' = linearSolveC
157 cholSolve' = cholSolveC
158 linearSolveLS' = linearSolveLSC
159 linearSolveSVD' = linearSolveSVDC Nothing
160 eig' = eigC
161 eigOnly = eigOnlyC
162 eigSH'' = eigH
163 eigOnlySH = eigOnlyH
164 cholSH' = cholH
165 mbCholSH' = mbCholH
166 qr' = qrC
167 qrgr' = qrgrC
168 hess' = unpackHess hessC
169 schur' = schurC
170
171--------------------------------------------------------------
172
173square m = rows m == cols m
174
175vertical m = rows m >= cols m
176
177exactHermitian m = m `equal` ctrans m
178
179--------------------------------------------------------------
180
181-- | Full singular value decomposition.
182svd :: Field t => Matrix t -> (Matrix t, Vector Double, Matrix t)
183svd = {-# SCC "svd" #-} svd'
184
185-- | A version of 'svd' which returns only the @min (rows m) (cols m)@ singular vectors of @m@.
186--
187-- If @(u,s,v) = thinSVD m@ then @m == u \<> diag s \<> trans v@.
188thinSVD :: Field t => Matrix t -> (Matrix t, Vector Double, Matrix t)
189thinSVD = {-# SCC "thinSVD" #-} thinSVD'
190
191-- | Singular values only.
192singularValues :: Field t => Matrix t -> Vector Double
193singularValues = {-# SCC "singularValues" #-} sv'
194
195-- | A version of 'svd' which returns an appropriate diagonal matrix with the singular values.
196--
197-- If @(u,d,v) = fullSVD m@ then @m == u \<> d \<> trans v@.
198fullSVD :: Field t => Matrix t -> (Matrix t, Matrix Double, Matrix t)
199fullSVD m = (u,d,v) where
200 (u,s,v) = svd m
201 d = diagRect 0 s r c
202 r = rows m
203 c = cols m
204
205-- | Similar to 'thinSVD', returning only the nonzero singular values and the corresponding singular vectors.
206compactSVD :: Field t => Matrix t -> (Matrix t, Vector Double, Matrix t)
207compactSVD m = (u', subVector 0 d s, v') where
208 (u,s,v) = thinSVD m
209 d = rankSVD (1*eps) m s `max` 1
210 u' = takeColumns d u
211 v' = takeColumns d v
212
213
214-- | Singular values and all right singular vectors.
215rightSV :: Field t => Matrix t -> (Vector Double, Matrix t)
216rightSV m | vertical m = let (_,s,v) = thinSVD m in (s,v)
217 | otherwise = let (_,s,v) = svd m in (s,v)
218
219-- | Singular values and all left singular vectors.
220leftSV :: Field t => Matrix t -> (Matrix t, Vector Double)
221leftSV m | vertical m = let (u,s,_) = svd m in (u,s)
222 | otherwise = let (u,s,_) = thinSVD m in (u,s)
223
224
225--------------------------------------------------------------
226
227-- | Obtains the LU decomposition of a matrix in a compact data structure suitable for 'luSolve'.
228luPacked :: Field t => Matrix t -> (Matrix t, [Int])
229luPacked = {-# SCC "luPacked" #-} luPacked'
230
231-- | Solution of a linear system (for several right hand sides) from the precomputed LU factorization obtained by 'luPacked'.
232luSolve :: Field t => (Matrix t, [Int]) -> Matrix t -> Matrix t
233luSolve = {-# SCC "luSolve" #-} luSolve'
234
235-- | Solve a linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition. For underconstrained or overconstrained systems use 'linearSolveLS' or 'linearSolveSVD'.
236-- It is similar to 'luSolve' . 'luPacked', but @linearSolve@ raises an error if called on a singular system.
237linearSolve :: Field t => Matrix t -> Matrix t -> Matrix t
238linearSolve = {-# SCC "linearSolve" #-} linearSolve'
239
240-- | Solve a symmetric or Hermitian positive definite linear system using a precomputed Cholesky decomposition obtained by 'chol'.
241cholSolve :: Field t => Matrix t -> Matrix t -> Matrix t
242cholSolve = {-# SCC "cholSolve" #-} cholSolve'
243
244-- | Minimum norm solution of a general linear least squares problem Ax=B using the SVD. Admits rank-deficient systems but it is slower than 'linearSolveLS'. The effective rank of A is determined by treating as zero those singular valures which are less than 'eps' times the largest singular value.
245linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t
246linearSolveSVD = {-# SCC "linearSolveSVD" #-} linearSolveSVD'
247
248
249-- | Least squared error solution of an overconstrained linear system, or the minimum norm solution of an underconstrained system. For rank-deficient systems use 'linearSolveSVD'.
250linearSolveLS :: Field t => Matrix t -> Matrix t -> Matrix t
251linearSolveLS = {-# SCC "linearSolveLS" #-} linearSolveLS'
252
253--------------------------------------------------------------
254
255-- | Eigenvalues and eigenvectors of a general square matrix.
256--
257-- If @(s,v) = eig m@ then @m \<> v == v \<> diag s@
258eig :: Field t => Matrix t -> (Vector (Complex Double), Matrix (Complex Double))
259eig = {-# SCC "eig" #-} eig'
260
261-- | Eigenvalues of a general square matrix.
262eigenvalues :: Field t => Matrix t -> Vector (Complex Double)
263eigenvalues = {-# SCC "eigenvalues" #-} eigOnly
264
265-- | Similar to 'eigSH' without checking that the input matrix is hermitian or symmetric. It works with the upper triangular part.
266eigSH' :: Field t => Matrix t -> (Vector Double, Matrix t)
267eigSH' = {-# SCC "eigSH'" #-} eigSH''
268
269-- | Similar to 'eigenvaluesSH' without checking that the input matrix is hermitian or symmetric. It works with the upper triangular part.
270eigenvaluesSH' :: Field t => Matrix t -> Vector Double
271eigenvaluesSH' = {-# SCC "eigenvaluesSH'" #-} eigOnlySH
272
273-- | Eigenvalues and Eigenvectors of a complex hermitian or real symmetric matrix.
274--
275-- If @(s,v) = eigSH m@ then @m == v \<> diag s \<> ctrans v@
276eigSH :: Field t => Matrix t -> (Vector Double, Matrix t)
277eigSH m | exactHermitian m = eigSH' m
278 | otherwise = error "eigSH requires complex hermitian or real symmetric matrix"
279
280-- | Eigenvalues of a complex hermitian or real symmetric matrix.
281eigenvaluesSH :: Field t => Matrix t -> Vector Double
282eigenvaluesSH m | exactHermitian m = eigenvaluesSH' m
283 | otherwise = error "eigenvaluesSH requires complex hermitian or real symmetric matrix"
284
285--------------------------------------------------------------
286
287-- | QR factorization.
288--
289-- If @(q,r) = qr m@ then @m == q \<> r@, where q is unitary and r is upper triangular.
290qr :: Field t => Matrix t -> (Matrix t, Matrix t)
291qr = {-# SCC "qr" #-} unpackQR . qr'
292
293qrRaw m = qr' m
294
295{- | generate a matrix with k orthogonal columns from the output of qrRaw
296-}
297qrgr n (a,t)
298 | dim t > min (cols a) (rows a) || n < 0 || n > dim t = error "qrgr expects k <= min(rows,cols)"
299 | otherwise = qrgr' n (a,t)
300
301-- | RQ factorization.
302--
303-- If @(r,q) = rq m@ then @m == r \<> q@, where q is unitary and r is upper triangular.
304rq :: Field t => Matrix t -> (Matrix t, Matrix t)
305rq m = {-# SCC "rq" #-} (r,q) where
306 (q',r') = qr $ trans $ rev1 m
307 r = rev2 (trans r')
308 q = rev2 (trans q')
309 rev1 = flipud . fliprl
310 rev2 = fliprl . flipud
311
312-- | Hessenberg factorization.
313--
314-- If @(p,h) = hess m@ then @m == p \<> h \<> ctrans p@, where p is unitary
315-- and h is in upper Hessenberg form (it has zero entries below the first subdiagonal).
316hess :: Field t => Matrix t -> (Matrix t, Matrix t)
317hess = hess'
318
319-- | Schur factorization.
320--
321-- If @(u,s) = schur m@ then @m == u \<> s \<> ctrans u@, where u is unitary
322-- and s is a Shur matrix. A complex Schur matrix is upper triangular. A real Schur matrix is
323-- upper triangular in 2x2 blocks.
324--
325-- \"Anything that the Jordan decomposition can do, the Schur decomposition
326-- can do better!\" (Van Loan)
327schur :: Field t => Matrix t -> (Matrix t, Matrix t)
328schur = schur'
329
330
331-- | Similar to 'cholSH', but instead of an error (e.g., caused by a matrix not positive definite) it returns 'Nothing'.
332mbCholSH :: Field t => Matrix t -> Maybe (Matrix t)
333mbCholSH = {-# SCC "mbCholSH" #-} mbCholSH'
334
335-- | Similar to 'chol', without checking that the input matrix is hermitian or symmetric. It works with the upper triangular part.
336cholSH :: Field t => Matrix t -> Matrix t
337cholSH = {-# SCC "cholSH" #-} cholSH'
338
339-- | Cholesky factorization of a positive definite hermitian or symmetric matrix.
340--
341-- If @c = chol m@ then @c@ is upper triangular and @m == ctrans c \<> c@.
342chol :: Field t => Matrix t -> Matrix t
343chol m | exactHermitian m = cholSH m
344 | otherwise = error "chol requires positive definite complex hermitian or real symmetric matrix"
345
346
347-- | Joint computation of inverse and logarithm of determinant of a square matrix.
348invlndet :: Field t
349 => Matrix t
350 -> (Matrix t, (t, t)) -- ^ (inverse, (log abs det, sign or phase of det))
351invlndet m | square m = (im,(ladm,sdm))
352 | otherwise = error $ "invlndet of nonsquare "++ shSize m ++ " matrix"
353 where
354 lp@(lup,perm) = luPacked m
355 s = signlp (rows m) perm
356 dg = toList $ takeDiag $ lup
357 ladm = sum $ map (log.abs) dg
358 sdm = s* product (map signum dg)
359 im = luSolve lp (ident (rows m))
360
361
362-- | Determinant of a square matrix. To avoid possible overflow or underflow use 'invlndet'.
363det :: Field t => Matrix t -> t
364det m | square m = {-# SCC "det" #-} s * (product $ toList $ takeDiag $ lup)
365 | otherwise = error $ "det of nonsquare "++ shSize m ++ " matrix"
366 where (lup,perm) = luPacked m
367 s = signlp (rows m) perm
368
369-- | Explicit LU factorization of a general matrix.
370--
371-- If @(l,u,p,s) = lu m@ then @m == p \<> l \<> u@, where l is lower triangular,
372-- u is upper triangular, p is a permutation matrix and s is the signature of the permutation.
373lu :: Field t => Matrix t -> (Matrix t, Matrix t, Matrix t, t)
374lu = luFact . luPacked
375
376-- | Inverse of a square matrix. See also 'invlndet'.
377inv :: Field t => Matrix t -> Matrix t
378inv m | square m = m `linearSolve` ident (rows m)
379 | otherwise = error $ "inv of nonsquare "++ shSize m ++ " matrix"
380
381
382-- | Pseudoinverse of a general matrix with default tolerance ('pinvTol' 1, similar to GNU-Octave).
383pinv :: Field t => Matrix t -> Matrix t
384pinv = pinvTol 1
385
386{- | @pinvTol r@ computes the pseudoinverse of a matrix with tolerance @tol=r*g*eps*(max rows cols)@, where g is the greatest singular value.
387
388@
389m = (3><3) [ 1, 0, 0
390 , 0, 1, 0
391 , 0, 0, 1e-10] :: Matrix Double
392@
393
394>>> pinv m
3951. 0. 0.
3960. 1. 0.
3970. 0. 10000000000.
398
399>>> pinvTol 1E8 m
4001. 0. 0.
4010. 1. 0.
4020. 0. 1.
403
404-}
405
406pinvTol :: Field t => Double -> Matrix t -> Matrix t
407pinvTol t m = conj v' `mXm` diag s' `mXm` ctrans u' where
408 (u,s,v) = thinSVD m
409 sl@(g:_) = toList s
410 s' = real . fromList . map rec $ sl
411 rec x = if x <= g*tol then x else 1/x
412 tol = (fromIntegral (max r c) * g * t * eps)
413 r = rows m
414 c = cols m
415 d = dim s
416 u' = takeColumns d u
417 v' = takeColumns d v
418
419
420-- | Numeric rank of a matrix from the SVD decomposition.
421rankSVD :: Element t
422 => Double -- ^ numeric zero (e.g. 1*'eps')
423 -> Matrix t -- ^ input matrix m
424 -> Vector Double -- ^ 'sv' of m
425 -> Int -- ^ rank of m
426rankSVD teps m s = ranksv teps (max (rows m) (cols m)) (toList s)
427
428-- | Numeric rank of a matrix from its singular values.
429ranksv :: Double -- ^ numeric zero (e.g. 1*'eps')
430 -> Int -- ^ maximum dimension of the matrix
431 -> [Double] -- ^ singular values
432 -> Int -- ^ rank of m
433ranksv teps maxdim s = k where
434 g = maximum s
435 tol = fromIntegral maxdim * g * teps
436 s' = filter (>tol) s
437 k = if g > teps then length s' else 0
438
439-- | The machine precision of a Double: @eps = 2.22044604925031e-16@ (the value used by GNU-Octave).
440eps :: Double
441eps = 2.22044604925031e-16
442
443
444-- | 1 + 0.5*peps == 1, 1 + 0.6*peps /= 1
445peps :: RealFloat x => x
446peps = x where x = 2.0 ** fromIntegral (1 - floatDigits x)
447
448
449-- | The imaginary unit: @i = 0.0 :+ 1.0@
450i :: Complex Double
451i = 0:+1
452
453-----------------------------------------------------------------------
454
455-- | The nullspace of a matrix from its precomputed SVD decomposition.
456nullspaceSVD :: Field t
457 => Either Double Int -- ^ Left \"numeric\" zero (eg. 1*'eps'),
458 -- or Right \"theoretical\" matrix rank.
459 -> Matrix t -- ^ input matrix m
460 -> (Vector Double, Matrix t) -- ^ 'rightSV' of m
461 -> [Vector t] -- ^ list of unitary vectors spanning the nullspace
462nullspaceSVD hint a (s,v) = vs where
463 tol = case hint of
464 Left t -> t
465 _ -> eps
466 k = case hint of
467 Right t -> t
468 _ -> rankSVD tol a s
469 vs = drop k $ toRows $ ctrans v
470
471
472-- | The nullspace of a matrix. See also 'nullspaceSVD'.
473nullspacePrec :: Field t
474 => Double -- ^ relative tolerance in 'eps' units (e.g., use 3 to get 3*'eps')
475 -> Matrix t -- ^ input matrix
476 -> [Vector t] -- ^ list of unitary vectors spanning the nullspace
477nullspacePrec t m = nullspaceSVD (Left (t*eps)) m (rightSV m)
478
479-- | The nullspace of a matrix, assumed to be one-dimensional, with machine precision.
480nullVector :: Field t => Matrix t -> Vector t
481nullVector = last . nullspacePrec 1
482
483orth :: Field t => Matrix t -> [Vector t]
484-- ^ Return an orthonormal basis of the range space of a matrix
485orth m = take r $ toColumns u
486 where
487 (u,s,_) = compactSVD m
488 r = ranksv eps (max (rows m) (cols m)) (toList s)
489
490------------------------------------------------------------------------
491
492-- many thanks, quickcheck!
493
494haussholder :: (Field a) => a -> Vector a -> Matrix a
495haussholder tau v = ident (dim v) `sub` (tau `scale` (w `mXm` ctrans w))
496 where w = asColumn v
497
498
499zh k v = fromList $ replicate (k-1) 0 ++ (1:drop k xs)
500 where xs = toList v
501
502zt 0 v = v
503zt k v = vjoin [subVector 0 (dim v - k) v, konst' 0 k]
504
505
506unpackQR :: (Field t) => (Matrix t, Vector t) -> (Matrix t, Matrix t)
507unpackQR (pq, tau) = {-# SCC "unpackQR" #-} (q,r)
508 where cs = toColumns pq
509 m = rows pq
510 n = cols pq
511 mn = min m n
512 r = fromColumns $ zipWith zt ([m-1, m-2 .. 1] ++ repeat 0) cs
513 vs = zipWith zh [1..mn] cs
514 hs = zipWith haussholder (toList tau) vs
515 q = foldl1' mXm hs
516
517unpackHess :: (Field t) => (Matrix t -> (Matrix t,Vector t)) -> Matrix t -> (Matrix t, Matrix t)
518unpackHess hf m
519 | rows m == 1 = ((1><1)[1],m)
520 | otherwise = (uH . hf) m
521
522uH (pq, tau) = (p,h)
523 where cs = toColumns pq
524 m = rows pq
525 n = cols pq
526 mn = min m n
527 h = fromColumns $ zipWith zt ([m-2, m-3 .. 1] ++ repeat 0) cs
528 vs = zipWith zh [2..mn] cs
529 hs = zipWith haussholder (toList tau) vs
530 p = foldl1' mXm hs
531
532--------------------------------------------------------------------------
533
534-- | Reciprocal of the 2-norm condition number of a matrix, computed from the singular values.
535rcond :: Field t => Matrix t -> Double
536rcond m = last s / head s
537 where s = toList (singularValues m)
538
539-- | Number of linearly independent rows or columns.
540rank :: Field t => Matrix t -> Int
541rank m = rankSVD eps m (singularValues m)
542
543{-
544expm' m = case diagonalize (complex m) of
545 Just (l,v) -> v `mXm` diag (exp l) `mXm` inv v
546 Nothing -> error "Sorry, expm not yet implemented for non-diagonalizable matrices"
547 where exp = vectorMapC Exp
548-}
549
550diagonalize m = if rank v == n
551 then Just (l,v)
552 else Nothing
553 where n = rows m
554 (l,v) = if exactHermitian m
555 then let (l',v') = eigSH m in (real l', v')
556 else eig m
557
558-- | Generic matrix functions for diagonalizable matrices. For instance:
559--
560-- @logm = matFunc log@
561--
562matFunc :: (Complex Double -> Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
563matFunc f m = case diagonalize m of
564 Just (l,v) -> v `mXm` diag (mapVector f l) `mXm` inv v
565 Nothing -> error "Sorry, matFunc requires a diagonalizable matrix"
566
567--------------------------------------------------------------
568
569golubeps :: Integer -> Integer -> Double
570golubeps p q = a * fromIntegral b / fromIntegral c where
571 a = 2^^(3-p-q)
572 b = fact p * fact q
573 c = fact (p+q) * fact (p+q+1)
574 fact n = product [1..n]
575
576epslist :: [(Int,Double)]
577epslist = [ (fromIntegral k, golubeps k k) | k <- [1..]]
578
579geps delta = head [ k | (k,g) <- epslist, g<delta]
580
581
582{- | Matrix exponential. It uses a direct translation of Algorithm 11.3.1 in Golub & Van Loan,
583 based on a scaled Pade approximation.
584-}
585expm :: Field t => Matrix t -> Matrix t
586expm = expGolub
587
588expGolub :: Field t => Matrix t -> Matrix t
589expGolub m = iterate msq f !! j
590 where j = max 0 $ floor $ logBase 2 $ pnorm Infinity m
591 a = m */ fromIntegral ((2::Int)^j)
592 q = geps eps -- 7 steps
593 eye = ident (rows m)
594 work (k,c,x,n,d) = (k',c',x',n',d')
595 where k' = k+1
596 c' = c * fromIntegral (q-k+1) / fromIntegral ((2*q-k+1)*k)
597 x' = a <> x
598 n' = n |+| (c' .* x')
599 d' = d |+| (((-1)^k * c') .* x')
600 (_,_,_,nf,df) = iterate work (1,1,eye,eye,eye) !! q
601 f = linearSolve df nf
602 msq x = x <> x
603
604 (<>) = multiply
605 v */ x = scale (recip x) v
606 (.*) = scale
607 (|+|) = add
608
609--------------------------------------------------------------
610
611{- | Matrix square root. Currently it uses a simple iterative algorithm described in Wikipedia.
612It only works with invertible matrices that have a real solution. For diagonalizable matrices you can try @matFunc sqrt@.
613
614@m = (2><2) [4,9
615 ,0,4] :: Matrix Double@
616
617>>> sqrtm m
618(2><2)
619 [ 2.0, 2.25
620 , 0.0, 2.0 ]
621
622-}
623sqrtm :: Field t => Matrix t -> Matrix t
624sqrtm = sqrtmInv
625
626sqrtmInv x = fst $ fixedPoint $ iterate f (x, ident (rows x))
627 where fixedPoint (a:b:rest) | pnorm PNorm1 (fst a |-| fst b) < peps = a
628 | otherwise = fixedPoint (b:rest)
629 fixedPoint _ = error "fixedpoint with impossible inputs"
630 f (y,z) = (0.5 .* (y |+| inv z),
631 0.5 .* (inv y |+| z))
632 (.*) = scale
633 (|+|) = add
634 (|-|) = sub
635
636------------------------------------------------------------------
637
638signlp r vals = foldl f 1 (zip [0..r-1] vals)
639 where f s (a,b) | a /= b = -s
640 | otherwise = s
641
642swap (arr,s) (a,b) | a /= b = (arr // [(a, arr!b),(b,arr!a)],-s)
643 | otherwise = (arr,s)
644
645fixPerm r vals = (fromColumns $ elems res, sign)
646 where v = [0..r-1]
647 s = toColumns (ident r)
648 (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals)
649
650triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]]
651 where el p q = if q-p>=h then v else 1 - v
652
653luFact (l_u,perm) | r <= c = (l ,u ,p, s)
654 | otherwise = (l',u',p, s)
655 where
656 r = rows l_u
657 c = cols l_u
658 tu = triang r c 0 1
659 tl = triang r c 0 0
660 l = takeColumns r (l_u |*| tl) |+| diagRect 0 (konst' 1 r) r r
661 u = l_u |*| tu
662 (p,s) = fixPerm r perm
663 l' = (l_u |*| tl) |+| diagRect 0 (konst' 1 c) r c
664 u' = takeRows c (l_u |*| tu)
665 (|+|) = add
666 (|*|) = mul
667
668---------------------------------------------------------------------------
669
670data NormType = Infinity | PNorm1 | PNorm2 | Frobenius
671
672class (RealFloat (RealOf t)) => Normed c t where
673 pnorm :: NormType -> c t -> RealOf t
674
675instance Normed Vector Double where
676 pnorm PNorm1 = norm1
677 pnorm PNorm2 = norm2
678 pnorm Infinity = normInf
679 pnorm Frobenius = norm2
680
681instance Normed Vector (Complex Double) where
682 pnorm PNorm1 = norm1
683 pnorm PNorm2 = norm2
684 pnorm Infinity = normInf
685 pnorm Frobenius = pnorm PNorm2
686
687instance Normed Vector Float where
688 pnorm PNorm1 = norm1
689 pnorm PNorm2 = norm2
690 pnorm Infinity = normInf
691 pnorm Frobenius = pnorm PNorm2
692
693instance Normed Vector (Complex Float) where
694 pnorm PNorm1 = norm1
695 pnorm PNorm2 = norm2
696 pnorm Infinity = normInf
697 pnorm Frobenius = pnorm PNorm2
698
699
700instance Normed Matrix Double where
701 pnorm PNorm1 = maximum . map (pnorm PNorm1) . toColumns
702 pnorm PNorm2 = (@>0) . singularValues
703 pnorm Infinity = pnorm PNorm1 . trans
704 pnorm Frobenius = pnorm PNorm2 . flatten
705
706instance Normed Matrix (Complex Double) where
707 pnorm PNorm1 = maximum . map (pnorm PNorm1) . toColumns
708 pnorm PNorm2 = (@>0) . singularValues
709 pnorm Infinity = pnorm PNorm1 . trans
710 pnorm Frobenius = pnorm PNorm2 . flatten
711
712instance Normed Matrix Float where
713 pnorm PNorm1 = maximum . map (pnorm PNorm1) . toColumns
714 pnorm PNorm2 = realToFrac . (@>0) . singularValues . double
715 pnorm Infinity = pnorm PNorm1 . trans
716 pnorm Frobenius = pnorm PNorm2 . flatten
717
718instance Normed Matrix (Complex Float) where
719 pnorm PNorm1 = maximum . map (pnorm PNorm1) . toColumns
720 pnorm PNorm2 = realToFrac . (@>0) . singularValues . double
721 pnorm Infinity = pnorm PNorm1 . trans
722 pnorm Frobenius = pnorm PNorm2 . flatten
723
724-- | Approximate number of common digits in the maximum element.
725relativeError :: (Normed c t, Container c t) => c t -> c t -> Int
726relativeError x y = dig (norm (x `sub` y) / norm x)
727 where norm = pnorm Infinity
728 dig r = round $ -logBase 10 (realToFrac r :: Double)
729
730----------------------------------------------------------------------
731
732-- | Generalized symmetric positive definite eigensystem Av = lBv,
733-- for A and B symmetric, B positive definite (conditions not checked).
734geigSH' :: Field t
735 => Matrix t -- ^ A
736 -> Matrix t -- ^ B
737 -> (Vector Double, Matrix t)
738geigSH' a b = (l,v')
739 where
740 u = cholSH b
741 iu = inv u
742 c = ctrans iu <> a <> iu
743 (l,v) = eigSH' c
744 v' = iu <> v
745 (<>) = mXm
746
diff --git a/packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK.hs b/packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK.hs
new file mode 100644
index 0000000..11394a6
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK.hs
@@ -0,0 +1,555 @@
1-----------------------------------------------------------------------------
2-- |
3-- Module : Numeric.LinearAlgebra.LAPACK
4-- Copyright : (c) Alberto Ruiz 2006-7
5-- License : GPL-style
6--
7-- Maintainer : Alberto Ruiz (aruiz at um dot es)
8-- Stability : provisional
9-- Portability : portable (uses FFI)
10--
11-- Functional interface to selected LAPACK functions (<http://www.netlib.org/lapack>).
12--
13-----------------------------------------------------------------------------
14{-# OPTIONS_HADDOCK hide #-}
15
16module Numeric.LinearAlgebra.LAPACK (
17 -- * Matrix product
18 multiplyR, multiplyC, multiplyF, multiplyQ,
19 -- * Linear systems
20 linearSolveR, linearSolveC,
21 lusR, lusC,
22 cholSolveR, cholSolveC,
23 linearSolveLSR, linearSolveLSC,
24 linearSolveSVDR, linearSolveSVDC,
25 -- * SVD
26 svR, svRd, svC, svCd,
27 svdR, svdRd, svdC, svdCd,
28 thinSVDR, thinSVDRd, thinSVDC, thinSVDCd,
29 rightSVR, rightSVC, leftSVR, leftSVC,
30 -- * Eigensystems
31 eigR, eigC, eigS, eigS', eigH, eigH',
32 eigOnlyR, eigOnlyC, eigOnlyS, eigOnlyH,
33 -- * LU
34 luR, luC,
35 -- * Cholesky
36 cholS, cholH, mbCholS, mbCholH,
37 -- * QR
38 qrR, qrC, qrgrR, qrgrC,
39 -- * Hessenberg
40 hessR, hessC,
41 -- * Schur
42 schurR, schurC
43) where
44
45import Data.Packed.Internal
46import Data.Packed.Matrix
47import Numeric.Conversion
48import Numeric.GSL.Vector(vectorMapValR, FunCodeSV(Scale))
49
50import Foreign.Ptr(nullPtr)
51import Foreign.C.Types
52import Control.Monad(when)
53import System.IO.Unsafe(unsafePerformIO)
54
55-----------------------------------------------------------------------------------
56
57foreign import ccall unsafe "multiplyR" dgemmc :: CInt -> CInt -> TMMM
58foreign import ccall unsafe "multiplyC" zgemmc :: CInt -> CInt -> TCMCMCM
59foreign import ccall unsafe "multiplyF" sgemmc :: CInt -> CInt -> TFMFMFM
60foreign import ccall unsafe "multiplyQ" cgemmc :: CInt -> CInt -> TQMQMQM
61
62isT Matrix{order = ColumnMajor} = 0
63isT Matrix{order = RowMajor} = 1
64
65tt x@Matrix{order = ColumnMajor} = x
66tt x@Matrix{order = RowMajor} = trans x
67
68multiplyAux f st a b = unsafePerformIO $ do
69 when (cols a /= rows b) $ error $ "inconsistent dimensions in matrix product "++
70 show (rows a,cols a) ++ " x " ++ show (rows b, cols b)
71 s <- createMatrix ColumnMajor (rows a) (cols b)
72 app3 (f (isT a) (isT b)) mat (tt a) mat (tt b) mat s st
73 return s
74
75-- | Matrix product based on BLAS's /dgemm/.
76multiplyR :: Matrix Double -> Matrix Double -> Matrix Double
77multiplyR a b = {-# SCC "multiplyR" #-} multiplyAux dgemmc "dgemmc" a b
78
79-- | Matrix product based on BLAS's /zgemm/.
80multiplyC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
81multiplyC a b = multiplyAux zgemmc "zgemmc" a b
82
83-- | Matrix product based on BLAS's /sgemm/.
84multiplyF :: Matrix Float -> Matrix Float -> Matrix Float
85multiplyF a b = multiplyAux sgemmc "sgemmc" a b
86
87-- | Matrix product based on BLAS's /cgemm/.
88multiplyQ :: Matrix (Complex Float) -> Matrix (Complex Float) -> Matrix (Complex Float)
89multiplyQ a b = multiplyAux cgemmc "cgemmc" a b
90
91-----------------------------------------------------------------------------
92foreign import ccall unsafe "svd_l_R" dgesvd :: TMMVM
93foreign import ccall unsafe "svd_l_C" zgesvd :: TCMCMVCM
94foreign import ccall unsafe "svd_l_Rdd" dgesdd :: TMMVM
95foreign import ccall unsafe "svd_l_Cdd" zgesdd :: TCMCMVCM
96
97-- | Full SVD of a real matrix using LAPACK's /dgesvd/.
98svdR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
99svdR = svdAux dgesvd "svdR" . fmat
100
101-- | Full SVD of a real matrix using LAPACK's /dgesdd/.
102svdRd :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
103svdRd = svdAux dgesdd "svdRdd" . fmat
104
105-- | Full SVD of a complex matrix using LAPACK's /zgesvd/.
106svdC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double))
107svdC = svdAux zgesvd "svdC" . fmat
108
109-- | Full SVD of a complex matrix using LAPACK's /zgesdd/.
110svdCd :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double))
111svdCd = svdAux zgesdd "svdCdd" . fmat
112
113svdAux f st x = unsafePerformIO $ do
114 u <- createMatrix ColumnMajor r r
115 s <- createVector (min r c)
116 v <- createMatrix ColumnMajor c c
117 app4 f mat x mat u vec s mat v st
118 return (u,s,trans v)
119 where r = rows x
120 c = cols x
121
122
123-- | Thin SVD of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'S\'.
124thinSVDR :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
125thinSVDR = thinSVDAux dgesvd "thinSVDR" . fmat
126
127-- | Thin SVD of a complex matrix, using LAPACK's /zgesvd/ with jobu == jobvt == \'S\'.
128thinSVDC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double))
129thinSVDC = thinSVDAux zgesvd "thinSVDC" . fmat
130
131-- | Thin SVD of a real matrix, using LAPACK's /dgesdd/ with jobz == \'S\'.
132thinSVDRd :: Matrix Double -> (Matrix Double, Vector Double, Matrix Double)
133thinSVDRd = thinSVDAux dgesdd "thinSVDRdd" . fmat
134
135-- | Thin SVD of a complex matrix, using LAPACK's /zgesdd/ with jobz == \'S\'.
136thinSVDCd :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double, Matrix (Complex Double))
137thinSVDCd = thinSVDAux zgesdd "thinSVDCdd" . fmat
138
139thinSVDAux f st x = unsafePerformIO $ do
140 u <- createMatrix ColumnMajor r q
141 s <- createVector q
142 v <- createMatrix ColumnMajor q c
143 app4 f mat x mat u vec s mat v st
144 return (u,s,trans v)
145 where r = rows x
146 c = cols x
147 q = min r c
148
149
150-- | Singular values of a real matrix, using LAPACK's /dgesvd/ with jobu == jobvt == \'N\'.
151svR :: Matrix Double -> Vector Double
152svR = svAux dgesvd "svR" . fmat
153
154-- | Singular values of a complex matrix, using LAPACK's /zgesvd/ with jobu == jobvt == \'N\'.
155svC :: Matrix (Complex Double) -> Vector Double
156svC = svAux zgesvd "svC" . fmat
157
158-- | Singular values of a real matrix, using LAPACK's /dgesdd/ with jobz == \'N\'.
159svRd :: Matrix Double -> Vector Double
160svRd = svAux dgesdd "svRd" . fmat
161
162-- | Singular values of a complex matrix, using LAPACK's /zgesdd/ with jobz == \'N\'.
163svCd :: Matrix (Complex Double) -> Vector Double
164svCd = svAux zgesdd "svCd" . fmat
165
166svAux f st x = unsafePerformIO $ do
167 s <- createVector q
168 app2 g mat x vec s st
169 return s
170 where r = rows x
171 c = cols x
172 q = min r c
173 g ra ca pa nb pb = f ra ca pa 0 0 nullPtr nb pb 0 0 nullPtr
174
175
176-- | Singular values and all right singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'N\' and jobvt == \'A\'.
177rightSVR :: Matrix Double -> (Vector Double, Matrix Double)
178rightSVR = rightSVAux dgesvd "rightSVR" . fmat
179
180-- | Singular values and all right singular vectors of a complex matrix, using LAPACK's /zgesvd/ with jobu == \'N\' and jobvt == \'A\'.
181rightSVC :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double))
182rightSVC = rightSVAux zgesvd "rightSVC" . fmat
183
184rightSVAux f st x = unsafePerformIO $ do
185 s <- createVector q
186 v <- createMatrix ColumnMajor c c
187 app3 g mat x vec s mat v st
188 return (s,trans v)
189 where r = rows x
190 c = cols x
191 q = min r c
192 g ra ca pa = f ra ca pa 0 0 nullPtr
193
194
195-- | Singular values and all left singular vectors of a real matrix, using LAPACK's /dgesvd/ with jobu == \'A\' and jobvt == \'N\'.
196leftSVR :: Matrix Double -> (Matrix Double, Vector Double)
197leftSVR = leftSVAux dgesvd "leftSVR" . fmat
198
199-- | Singular values and all left singular vectors of a complex matrix, using LAPACK's /zgesvd/ with jobu == \'A\' and jobvt == \'N\'.
200leftSVC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector Double)
201leftSVC = leftSVAux zgesvd "leftSVC" . fmat
202
203leftSVAux f st x = unsafePerformIO $ do
204 u <- createMatrix ColumnMajor r r
205 s <- createVector q
206 app3 g mat x mat u vec s st
207 return (u,s)
208 where r = rows x
209 c = cols x
210 q = min r c
211 g ra ca pa ru cu pu nb pb = f ra ca pa ru cu pu nb pb 0 0 nullPtr
212
213-----------------------------------------------------------------------------
214
215foreign import ccall unsafe "eig_l_R" dgeev :: TMMCVM
216foreign import ccall unsafe "eig_l_C" zgeev :: TCMCMCVCM
217foreign import ccall unsafe "eig_l_S" dsyev :: CInt -> TMVM
218foreign import ccall unsafe "eig_l_H" zheev :: CInt -> TCMVCM
219
220eigAux f st m = unsafePerformIO $ do
221 l <- createVector r
222 v <- createMatrix ColumnMajor r r
223 app3 g mat m vec l mat v st
224 return (l,v)
225 where r = rows m
226 g ra ca pa = f ra ca pa 0 0 nullPtr
227
228
229-- | Eigenvalues and right eigenvectors of a general complex matrix, using LAPACK's /zgeev/.
230-- The eigenvectors are the columns of v. The eigenvalues are not sorted.
231eigC :: Matrix (Complex Double) -> (Vector (Complex Double), Matrix (Complex Double))
232eigC = eigAux zgeev "eigC" . fmat
233
234eigOnlyAux f st m = unsafePerformIO $ do
235 l <- createVector r
236 app2 g mat m vec l st
237 return l
238 where r = rows m
239 g ra ca pa nl pl = f ra ca pa 0 0 nullPtr nl pl 0 0 nullPtr
240
241-- | Eigenvalues of a general complex matrix, using LAPACK's /zgeev/ with jobz == \'N\'.
242-- The eigenvalues are not sorted.
243eigOnlyC :: Matrix (Complex Double) -> Vector (Complex Double)
244eigOnlyC = eigOnlyAux zgeev "eigOnlyC" . fmat
245
246-- | Eigenvalues and right eigenvectors of a general real matrix, using LAPACK's /dgeev/.
247-- The eigenvectors are the columns of v. The eigenvalues are not sorted.
248eigR :: Matrix Double -> (Vector (Complex Double), Matrix (Complex Double))
249eigR m = (s', v'')
250 where (s,v) = eigRaux (fmat m)
251 s' = fixeig1 s
252 v' = toRows $ trans v
253 v'' = fromColumns $ fixeig (toList s') v'
254
255eigRaux :: Matrix Double -> (Vector (Complex Double), Matrix Double)
256eigRaux m = unsafePerformIO $ do
257 l <- createVector r
258 v <- createMatrix ColumnMajor r r
259 app3 g mat m vec l mat v "eigR"
260 return (l,v)
261 where r = rows m
262 g ra ca pa = dgeev ra ca pa 0 0 nullPtr
263
264fixeig1 s = toComplex' (subVector 0 r (asReal s), subVector r r (asReal s))
265 where r = dim s
266
267fixeig [] _ = []
268fixeig [_] [v] = [comp' v]
269fixeig ((r1:+i1):(r2:+i2):r) (v1:v2:vs)
270 | r1 == r2 && i1 == (-i2) = toComplex' (v1,v2) : toComplex' (v1,scale (-1) v2) : fixeig r vs
271 | otherwise = comp' v1 : fixeig ((r2:+i2):r) (v2:vs)
272 where scale = vectorMapValR Scale
273fixeig _ _ = error "fixeig with impossible inputs"
274
275
276-- | Eigenvalues of a general real matrix, using LAPACK's /dgeev/ with jobz == \'N\'.
277-- The eigenvalues are not sorted.
278eigOnlyR :: Matrix Double -> Vector (Complex Double)
279eigOnlyR = fixeig1 . eigOnlyAux dgeev "eigOnlyR" . fmat
280
281
282-----------------------------------------------------------------------------
283
284eigSHAux f st m = unsafePerformIO $ do
285 l <- createVector r
286 v <- createMatrix ColumnMajor r r
287 app3 f mat m vec l mat v st
288 return (l,v)
289 where r = rows m
290
291-- | Eigenvalues and right eigenvectors of a symmetric real matrix, using LAPACK's /dsyev/.
292-- The eigenvectors are the columns of v.
293-- The eigenvalues are sorted in descending order (use 'eigS'' for ascending order).
294eigS :: Matrix Double -> (Vector Double, Matrix Double)
295eigS m = (s', fliprl v)
296 where (s,v) = eigS' (fmat m)
297 s' = fromList . reverse . toList $ s
298
299-- | 'eigS' in ascending order
300eigS' :: Matrix Double -> (Vector Double, Matrix Double)
301eigS' = eigSHAux (dsyev 1) "eigS'" . fmat
302
303-- | Eigenvalues and right eigenvectors of a hermitian complex matrix, using LAPACK's /zheev/.
304-- The eigenvectors are the columns of v.
305-- The eigenvalues are sorted in descending order (use 'eigH'' for ascending order).
306eigH :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double))
307eigH m = (s', fliprl v)
308 where (s,v) = eigH' (fmat m)
309 s' = fromList . reverse . toList $ s
310
311-- | 'eigH' in ascending order
312eigH' :: Matrix (Complex Double) -> (Vector Double, Matrix (Complex Double))
313eigH' = eigSHAux (zheev 1) "eigH'" . fmat
314
315
316-- | Eigenvalues of a symmetric real matrix, using LAPACK's /dsyev/ with jobz == \'N\'.
317-- The eigenvalues are sorted in descending order.
318eigOnlyS :: Matrix Double -> Vector Double
319eigOnlyS = vrev . fst. eigSHAux (dsyev 0) "eigS'" . fmat
320
321-- | Eigenvalues of a hermitian complex matrix, using LAPACK's /zheev/ with jobz == \'N\'.
322-- The eigenvalues are sorted in descending order.
323eigOnlyH :: Matrix (Complex Double) -> Vector Double
324eigOnlyH = vrev . fst. eigSHAux (zheev 1) "eigH'" . fmat
325
326vrev = flatten . flipud . reshape 1
327
328-----------------------------------------------------------------------------
329foreign import ccall unsafe "linearSolveR_l" dgesv :: TMMM
330foreign import ccall unsafe "linearSolveC_l" zgesv :: TCMCMCM
331foreign import ccall unsafe "cholSolveR_l" dpotrs :: TMMM
332foreign import ccall unsafe "cholSolveC_l" zpotrs :: TCMCMCM
333
334linearSolveSQAux f st a b
335 | n1==n2 && n1==r = unsafePerformIO $ do
336 s <- createMatrix ColumnMajor r c
337 app3 f mat a mat b mat s st
338 return s
339 | otherwise = error $ st ++ " of nonsquare matrix"
340 where n1 = rows a
341 n2 = cols a
342 r = rows b
343 c = cols b
344
345-- | Solve a real linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /dgesv/. For underconstrained or overconstrained systems use 'linearSolveLSR' or 'linearSolveSVDR'. See also 'lusR'.
346linearSolveR :: Matrix Double -> Matrix Double -> Matrix Double
347linearSolveR a b = linearSolveSQAux dgesv "linearSolveR" (fmat a) (fmat b)
348
349-- | Solve a complex linear system (for square coefficient matrix and several right-hand sides) using the LU decomposition, based on LAPACK's /zgesv/. For underconstrained or overconstrained systems use 'linearSolveLSC' or 'linearSolveSVDC'. See also 'lusC'.
350linearSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
351linearSolveC a b = linearSolveSQAux zgesv "linearSolveC" (fmat a) (fmat b)
352
353
354-- | Solves a symmetric positive definite system of linear equations using a precomputed Cholesky factorization obtained by 'cholS'.
355cholSolveR :: Matrix Double -> Matrix Double -> Matrix Double
356cholSolveR a b = linearSolveSQAux dpotrs "cholSolveR" (fmat a) (fmat b)
357
358-- | Solves a Hermitian positive definite system of linear equations using a precomputed Cholesky factorization obtained by 'cholH'.
359cholSolveC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
360cholSolveC a b = linearSolveSQAux zpotrs "cholSolveC" (fmat a) (fmat b)
361
362-----------------------------------------------------------------------------------
363foreign import ccall unsafe "linearSolveLSR_l" dgels :: TMMM
364foreign import ccall unsafe "linearSolveLSC_l" zgels :: TCMCMCM
365foreign import ccall unsafe "linearSolveSVDR_l" dgelss :: Double -> TMMM
366foreign import ccall unsafe "linearSolveSVDC_l" zgelss :: Double -> TCMCMCM
367
368linearSolveAux f st a b = unsafePerformIO $ do
369 r <- createMatrix ColumnMajor (max m n) nrhs
370 app3 f mat a mat b mat r st
371 return r
372 where m = rows a
373 n = cols a
374 nrhs = cols b
375
376-- | Least squared error solution of an overconstrained real linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /dgels/. For rank-deficient systems use 'linearSolveSVDR'.
377linearSolveLSR :: Matrix Double -> Matrix Double -> Matrix Double
378linearSolveLSR a b = subMatrix (0,0) (cols a, cols b) $
379 linearSolveAux dgels "linearSolverLSR" (fmat a) (fmat b)
380
381-- | Least squared error solution of an overconstrained complex linear system, or the minimum norm solution of an underconstrained system, using LAPACK's /zgels/. For rank-deficient systems use 'linearSolveSVDC'.
382linearSolveLSC :: Matrix (Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
383linearSolveLSC a b = subMatrix (0,0) (cols a, cols b) $
384 linearSolveAux zgels "linearSolveLSC" (fmat a) (fmat b)
385
386-- | Minimum norm solution of a general real linear least squares problem Ax=B using the SVD, based on LAPACK's /dgelss/. Admits rank-deficient systems but it is slower than 'linearSolveLSR'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used.
387linearSolveSVDR :: Maybe Double -- ^ rcond
388 -> Matrix Double -- ^ coefficient matrix
389 -> Matrix Double -- ^ right hand sides (as columns)
390 -> Matrix Double -- ^ solution vectors (as columns)
391linearSolveSVDR (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $
392 linearSolveAux (dgelss rcond) "linearSolveSVDR" (fmat a) (fmat b)
393linearSolveSVDR Nothing a b = linearSolveSVDR (Just (-1)) (fmat a) (fmat b)
394
395-- | Minimum norm solution of a general complex linear least squares problem Ax=B using the SVD, based on LAPACK's /zgelss/. Admits rank-deficient systems but it is slower than 'linearSolveLSC'. The effective rank of A is determined by treating as zero those singular valures which are less than rcond times the largest singular value. If rcond == Nothing machine precision is used.
396linearSolveSVDC :: Maybe Double -- ^ rcond
397 -> Matrix (Complex Double) -- ^ coefficient matrix
398 -> Matrix (Complex Double) -- ^ right hand sides (as columns)
399 -> Matrix (Complex Double) -- ^ solution vectors (as columns)
400linearSolveSVDC (Just rcond) a b = subMatrix (0,0) (cols a, cols b) $
401 linearSolveAux (zgelss rcond) "linearSolveSVDC" (fmat a) (fmat b)
402linearSolveSVDC Nothing a b = linearSolveSVDC (Just (-1)) (fmat a) (fmat b)
403
404-----------------------------------------------------------------------------------
405foreign import ccall unsafe "chol_l_H" zpotrf :: TCMCM
406foreign import ccall unsafe "chol_l_S" dpotrf :: TMM
407
408cholAux f st a = do
409 r <- createMatrix ColumnMajor n n
410 app2 f mat a mat r st
411 return r
412 where n = rows a
413
414-- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/.
415cholH :: Matrix (Complex Double) -> Matrix (Complex Double)
416cholH = unsafePerformIO . cholAux zpotrf "cholH" . fmat
417
418-- | Cholesky factorization of a real symmetric positive definite matrix, using LAPACK's /dpotrf/.
419cholS :: Matrix Double -> Matrix Double
420cholS = unsafePerformIO . cholAux dpotrf "cholS" . fmat
421
422-- | Cholesky factorization of a complex Hermitian positive definite matrix, using LAPACK's /zpotrf/ ('Maybe' version).
423mbCholH :: Matrix (Complex Double) -> Maybe (Matrix (Complex Double))
424mbCholH = unsafePerformIO . mbCatch . cholAux zpotrf "cholH" . fmat
425
426-- | Cholesky factorization of a real symmetric positive definite matrix, using LAPACK's /dpotrf/ ('Maybe' version).
427mbCholS :: Matrix Double -> Maybe (Matrix Double)
428mbCholS = unsafePerformIO . mbCatch . cholAux dpotrf "cholS" . fmat
429
430-----------------------------------------------------------------------------------
431foreign import ccall unsafe "qr_l_R" dgeqr2 :: TMVM
432foreign import ccall unsafe "qr_l_C" zgeqr2 :: TCMCVCM
433
434-- | QR factorization of a real matrix, using LAPACK's /dgeqr2/.
435qrR :: Matrix Double -> (Matrix Double, Vector Double)
436qrR = qrAux dgeqr2 "qrR" . fmat
437
438-- | QR factorization of a complex matrix, using LAPACK's /zgeqr2/.
439qrC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double))
440qrC = qrAux zgeqr2 "qrC" . fmat
441
442qrAux f st a = unsafePerformIO $ do
443 r <- createMatrix ColumnMajor m n
444 tau <- createVector mn
445 app3 f mat a vec tau mat r st
446 return (r,tau)
447 where
448 m = rows a
449 n = cols a
450 mn = min m n
451
452foreign import ccall unsafe "c_dorgqr" dorgqr :: TMVM
453foreign import ccall unsafe "c_zungqr" zungqr :: TCMCVCM
454
455-- | build rotation from reflectors
456qrgrR :: Int -> (Matrix Double, Vector Double) -> Matrix Double
457qrgrR = qrgrAux dorgqr "qrgrR"
458-- | build rotation from reflectors
459qrgrC :: Int -> (Matrix (Complex Double), Vector (Complex Double)) -> Matrix (Complex Double)
460qrgrC = qrgrAux zungqr "qrgrC"
461
462qrgrAux f st n (a, tau) = unsafePerformIO $ do
463 res <- createMatrix ColumnMajor (rows a) n
464 app3 f mat (fmat a) vec (subVector 0 n tau') mat res st
465 return res
466 where
467 tau' = vjoin [tau, constantD 0 n]
468
469-----------------------------------------------------------------------------------
470foreign import ccall unsafe "hess_l_R" dgehrd :: TMVM
471foreign import ccall unsafe "hess_l_C" zgehrd :: TCMCVCM
472
473-- | Hessenberg factorization of a square real matrix, using LAPACK's /dgehrd/.
474hessR :: Matrix Double -> (Matrix Double, Vector Double)
475hessR = hessAux dgehrd "hessR" . fmat
476
477-- | Hessenberg factorization of a square complex matrix, using LAPACK's /zgehrd/.
478hessC :: Matrix (Complex Double) -> (Matrix (Complex Double), Vector (Complex Double))
479hessC = hessAux zgehrd "hessC" . fmat
480
481hessAux f st a = unsafePerformIO $ do
482 r <- createMatrix ColumnMajor m n
483 tau <- createVector (mn-1)
484 app3 f mat a vec tau mat r st
485 return (r,tau)
486 where m = rows a
487 n = cols a
488 mn = min m n
489
490-----------------------------------------------------------------------------------
491foreign import ccall unsafe "schur_l_R" dgees :: TMMM
492foreign import ccall unsafe "schur_l_C" zgees :: TCMCMCM
493
494-- | Schur factorization of a square real matrix, using LAPACK's /dgees/.
495schurR :: Matrix Double -> (Matrix Double, Matrix Double)
496schurR = schurAux dgees "schurR" . fmat
497
498-- | Schur factorization of a square complex matrix, using LAPACK's /zgees/.
499schurC :: Matrix (Complex Double) -> (Matrix (Complex Double), Matrix (Complex Double))
500schurC = schurAux zgees "schurC" . fmat
501
502schurAux f st a = unsafePerformIO $ do
503 u <- createMatrix ColumnMajor n n
504 s <- createMatrix ColumnMajor n n
505 app3 f mat a mat u mat s st
506 return (u,s)
507 where n = rows a
508
509-----------------------------------------------------------------------------------
510foreign import ccall unsafe "lu_l_R" dgetrf :: TMVM
511foreign import ccall unsafe "lu_l_C" zgetrf :: TCMVCM
512
513-- | LU factorization of a general real matrix, using LAPACK's /dgetrf/.
514luR :: Matrix Double -> (Matrix Double, [Int])
515luR = luAux dgetrf "luR" . fmat
516
517-- | LU factorization of a general complex matrix, using LAPACK's /zgetrf/.
518luC :: Matrix (Complex Double) -> (Matrix (Complex Double), [Int])
519luC = luAux zgetrf "luC" . fmat
520
521luAux f st a = unsafePerformIO $ do
522 lu <- createMatrix ColumnMajor n m
523 piv <- createVector (min n m)
524 app3 f mat a vec piv mat lu st
525 return (lu, map (pred.round) (toList piv))
526 where n = rows a
527 m = cols a
528
529-----------------------------------------------------------------------------------
530type TW a = CInt -> PD -> a
531type TQ a = CInt -> CInt -> PC -> a
532
533foreign import ccall unsafe "luS_l_R" dgetrs :: TMVMM
534foreign import ccall unsafe "luS_l_C" zgetrs :: TQ (TW (TQ (TQ (IO CInt))))
535
536-- | Solve a real linear system from a precomputed LU decomposition ('luR'), using LAPACK's /dgetrs/.
537lusR :: Matrix Double -> [Int] -> Matrix Double -> Matrix Double
538lusR a piv b = lusAux dgetrs "lusR" (fmat a) piv (fmat b)
539
540-- | Solve a real linear system from a precomputed LU decomposition ('luC'), using LAPACK's /zgetrs/.
541lusC :: Matrix (Complex Double) -> [Int] -> Matrix (Complex Double) -> Matrix (Complex Double)
542lusC a piv b = lusAux zgetrs "lusC" (fmat a) piv (fmat b)
543
544lusAux f st a piv b
545 | n1==n2 && n2==n =unsafePerformIO $ do
546 x <- createMatrix ColumnMajor n m
547 app4 f mat a vec piv' mat b mat x st
548 return x
549 | otherwise = error $ st ++ " on LU factorization of nonsquare matrix"
550 where n1 = rows a
551 n2 = cols a
552 n = rows b
553 m = cols b
554 piv' = fromList (map (fromIntegral.succ) piv) :: Vector Double
555
diff --git a/packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK/lapack-aux.c b/packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
new file mode 100644
index 0000000..e5e45ef
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK/lapack-aux.c
@@ -0,0 +1,1489 @@
1#include <stdio.h>
2#include <stdlib.h>
3#include <string.h>
4#include <math.h>
5#include <time.h>
6#include "lapack-aux.h"
7
8#define MACRO(B) do {B} while (0)
9#define ERROR(CODE) MACRO(return CODE;)
10#define REQUIRES(COND, CODE) MACRO(if(!(COND)) {ERROR(CODE);})
11
12#define MIN(A,B) ((A)<(B)?(A):(B))
13#define MAX(A,B) ((A)>(B)?(A):(B))
14
15// #define DBGL
16
17#ifdef DBGL
18#define DEBUGMSG(M) printf("\nLAPACK "M"\n");
19#else
20#define DEBUGMSG(M)
21#endif
22
23#define OK return 0;
24
25// #ifdef DBGL
26// #define DEBUGMSG(M) printf("LAPACK Wrapper "M"\n: "); size_t t0 = time(NULL);
27// #define OK MACRO(printf("%ld s\n",time(0)-t0); return 0;);
28// #else
29// #define DEBUGMSG(M)
30// #define OK return 0;
31// #endif
32
33#define TRACEMAT(M) {int q; printf(" %d x %d: ",M##r,M##c); \
34 for(q=0;q<M##r*M##c;q++) printf("%.1f ",M##p[q]); printf("\n");}
35
36#define CHECK(RES,CODE) MACRO(if(RES) return CODE;)
37
38#define BAD_SIZE 2000
39#define BAD_CODE 2001
40#define MEM 2002
41#define BAD_FILE 2003
42#define SINGULAR 2004
43#define NOCONVER 2005
44#define NODEFPOS 2006
45#define NOSPRTD 2007
46
47//---------------------------------------
48void asm_finit() {
49#ifdef i386
50
51// asm("finit");
52
53 static unsigned char buf[108];
54 asm("FSAVE %0":"=m" (buf));
55
56 #if FPUDEBUG
57 if(buf[8]!=255 || buf[9]!=255) { // print warning in red
58 printf("%c[;31mWarning: FPU TAG = %x %x\%c[0m\n",0x1B,buf[8],buf[9],0x1B);
59 }
60 #endif
61
62 #if NANDEBUG
63 asm("FRSTOR %0":"=m" (buf));
64 #endif
65
66#endif
67}
68
69//---------------------------------------
70
71#if NANDEBUG
72
73#define CHECKNANR(M,msg) \
74{ int k; \
75for(k=0; k<(M##r * M##c); k++) { \
76 if(M##p[k] != M##p[k]) { \
77 printf(msg); \
78 TRACEMAT(M) \
79 /*exit(1);*/ \
80 } \
81} \
82}
83
84#define CHECKNANC(M,msg) \
85{ int k; \
86for(k=0; k<(M##r * M##c); k++) { \
87 if( M##p[k].r != M##p[k].r \
88 || M##p[k].i != M##p[k].i) { \
89 printf(msg); \
90 /*exit(1);*/ \
91 } \
92} \
93}
94
95#else
96#define CHECKNANC(M,msg)
97#define CHECKNANR(M,msg)
98#endif
99
100//---------------------------------------
101
102//////////////////// real svd ////////////////////////////////////
103
104/* Subroutine */ int dgesvd_(char *jobu, char *jobvt, integer *m, integer *n,
105 doublereal *a, integer *lda, doublereal *s, doublereal *u, integer *
106 ldu, doublereal *vt, integer *ldvt, doublereal *work, integer *lwork,
107 integer *info);
108
109int svd_l_R(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) {
110 integer m = ar;
111 integer n = ac;
112 integer q = MIN(m,n);
113 REQUIRES(sn==q,BAD_SIZE);
114 REQUIRES(up==NULL || (ur==m && (uc==m || uc==q)),BAD_SIZE);
115 char* jobu = "A";
116 if (up==NULL) {
117 jobu = "N";
118 } else {
119 if (uc==q) {
120 jobu = "S";
121 }
122 }
123 REQUIRES(vp==NULL || (vc==n && (vr==n || vr==q)),BAD_SIZE);
124 char* jobvt = "A";
125 integer ldvt = n;
126 if (vp==NULL) {
127 jobvt = "N";
128 } else {
129 if (vr==q) {
130 jobvt = "S";
131 ldvt = q;
132 }
133 }
134 DEBUGMSG("svd_l_R");
135 double *B = (double*)malloc(m*n*sizeof(double));
136 CHECK(!B,MEM);
137 memcpy(B,ap,m*n*sizeof(double));
138 integer lwork = -1;
139 integer res;
140 // ask for optimal lwork
141 double ans;
142 dgesvd_ (jobu,jobvt,
143 &m,&n,B,&m,
144 sp,
145 up,&m,
146 vp,&ldvt,
147 &ans, &lwork,
148 &res);
149 lwork = ceil(ans);
150 double * work = (double*)malloc(lwork*sizeof(double));
151 CHECK(!work,MEM);
152 dgesvd_ (jobu,jobvt,
153 &m,&n,B,&m,
154 sp,
155 up,&m,
156 vp,&ldvt,
157 work, &lwork,
158 &res);
159 CHECK(res,res);
160 free(work);
161 free(B);
162 OK
163}
164
165// (alternative version)
166
167/* Subroutine */ int dgesdd_(char *jobz, integer *m, integer *n, doublereal *
168 a, integer *lda, doublereal *s, doublereal *u, integer *ldu,
169 doublereal *vt, integer *ldvt, doublereal *work, integer *lwork,
170 integer *iwork, integer *info);
171
172int svd_l_Rdd(KDMAT(a),DMAT(u), DVEC(s),DMAT(v)) {
173 integer m = ar;
174 integer n = ac;
175 integer q = MIN(m,n);
176 REQUIRES(sn==q,BAD_SIZE);
177 REQUIRES((up == NULL && vp == NULL)
178 || (ur==m && vc==n
179 && ((uc == q && vr == q)
180 || (uc == m && vc==n))),BAD_SIZE);
181 char* jobz = "A";
182 integer ldvt = n;
183 if (up==NULL) {
184 jobz = "N";
185 } else {
186 if (uc==q && vr == q) {
187 jobz = "S";
188 ldvt = q;
189 }
190 }
191 DEBUGMSG("svd_l_Rdd");
192 double *B = (double*)malloc(m*n*sizeof(double));
193 CHECK(!B,MEM);
194 memcpy(B,ap,m*n*sizeof(double));
195 integer* iwk = (integer*) malloc(8*q*sizeof(integer));
196 CHECK(!iwk,MEM);
197 integer lwk = -1;
198 integer res;
199 // ask for optimal lwk
200 double ans;
201 dgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,iwk,&res);
202 lwk = ans;
203 double * workv = (double*)malloc(lwk*sizeof(double));
204 CHECK(!workv,MEM);
205 dgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,workv,&lwk,iwk,&res);
206 CHECK(res,res);
207 free(iwk);
208 free(workv);
209 free(B);
210 OK
211}
212
213//////////////////// complex svd ////////////////////////////////////
214
215// not in clapack.h
216
217int zgesvd_(char *jobu, char *jobvt, integer *m, integer *n,
218 doublecomplex *a, integer *lda, doublereal *s, doublecomplex *u,
219 integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work,
220 integer *lwork, doublereal *rwork, integer *info);
221
222int svd_l_C(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) {
223 integer m = ar;
224 integer n = ac;
225 integer q = MIN(m,n);
226 REQUIRES(sn==q,BAD_SIZE);
227 REQUIRES(up==NULL || (ur==m && (uc==m || uc==q)),BAD_SIZE);
228 char* jobu = "A";
229 if (up==NULL) {
230 jobu = "N";
231 } else {
232 if (uc==q) {
233 jobu = "S";
234 }
235 }
236 REQUIRES(vp==NULL || (vc==n && (vr==n || vr==q)),BAD_SIZE);
237 char* jobvt = "A";
238 integer ldvt = n;
239 if (vp==NULL) {
240 jobvt = "N";
241 } else {
242 if (vr==q) {
243 jobvt = "S";
244 ldvt = q;
245 }
246 }DEBUGMSG("svd_l_C");
247 doublecomplex *B = (doublecomplex*)malloc(m*n*sizeof(doublecomplex));
248 CHECK(!B,MEM);
249 memcpy(B,ap,m*n*sizeof(doublecomplex));
250
251 double *rwork = (double*) malloc(5*q*sizeof(double));
252 CHECK(!rwork,MEM);
253 integer lwork = -1;
254 integer res;
255 // ask for optimal lwork
256 doublecomplex ans;
257 zgesvd_ (jobu,jobvt,
258 &m,&n,B,&m,
259 sp,
260 up,&m,
261 vp,&ldvt,
262 &ans, &lwork,
263 rwork,
264 &res);
265 lwork = ceil(ans.r);
266 doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
267 CHECK(!work,MEM);
268 zgesvd_ (jobu,jobvt,
269 &m,&n,B,&m,
270 sp,
271 up,&m,
272 vp,&ldvt,
273 work, &lwork,
274 rwork,
275 &res);
276 CHECK(res,res);
277 free(work);
278 free(rwork);
279 free(B);
280 OK
281}
282
283int zgesdd_ (char *jobz, integer *m, integer *n,
284 doublecomplex *a, integer *lda, doublereal *s, doublecomplex *u,
285 integer *ldu, doublecomplex *vt, integer *ldvt, doublecomplex *work,
286 integer *lwork, doublereal *rwork, integer* iwork, integer *info);
287
288int svd_l_Cdd(KCMAT(a),CMAT(u), DVEC(s),CMAT(v)) {
289 //printf("entro\n");
290 integer m = ar;
291 integer n = ac;
292 integer q = MIN(m,n);
293 REQUIRES(sn==q,BAD_SIZE);
294 REQUIRES((up == NULL && vp == NULL)
295 || (ur==m && vc==n
296 && ((uc == q && vr == q)
297 || (uc == m && vc==n))),BAD_SIZE);
298 char* jobz = "A";
299 integer ldvt = n;
300 if (up==NULL) {
301 jobz = "N";
302 } else {
303 if (uc==q && vr == q) {
304 jobz = "S";
305 ldvt = q;
306 }
307 }
308 DEBUGMSG("svd_l_Cdd");
309 doublecomplex *B = (doublecomplex*)malloc(m*n*sizeof(doublecomplex));
310 CHECK(!B,MEM);
311 memcpy(B,ap,m*n*sizeof(doublecomplex));
312 integer* iwk = (integer*) malloc(8*q*sizeof(integer));
313 CHECK(!iwk,MEM);
314 int lrwk;
315 if (0 && *jobz == 'N') {
316 lrwk = 5*q; // does not work, crash at free below
317 } else {
318 lrwk = 5*q*q + 7*q;
319 }
320 double *rwk = (double*)malloc(lrwk*sizeof(double));;
321 CHECK(!rwk,MEM);
322 //printf("%s %ld %d\n",jobz,q,lrwk);
323 integer lwk = -1;
324 integer res;
325 // ask for optimal lwk
326 doublecomplex ans;
327 zgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,&ans,&lwk,rwk,iwk,&res);
328 lwk = ans.r;
329 //printf("lwk = %ld\n",lwk);
330 doublecomplex * workv = (doublecomplex*)malloc(lwk*sizeof(doublecomplex));
331 CHECK(!workv,MEM);
332 zgesdd_ (jobz,&m,&n,B,&m,sp,up,&m,vp,&ldvt,workv,&lwk,rwk,iwk,&res);
333 //printf("res = %ld\n",res);
334 CHECK(res,res);
335 free(workv); // printf("freed workv\n");
336 free(rwk); // printf("freed rwk\n");
337 free(iwk); // printf("freed iwk\n");
338 free(B); // printf("freed B, salgo\n");
339 OK
340}
341
342//////////////////// general complex eigensystem ////////////
343
344/* Subroutine */ int zgeev_(char *jobvl, char *jobvr, integer *n,
345 doublecomplex *a, integer *lda, doublecomplex *w, doublecomplex *vl,
346 integer *ldvl, doublecomplex *vr, integer *ldvr, doublecomplex *work,
347 integer *lwork, doublereal *rwork, integer *info);
348
349int eig_l_C(KCMAT(a), CMAT(u), CVEC(s),CMAT(v)) {
350 integer n = ar;
351 REQUIRES(ac==n && sn==n, BAD_SIZE);
352 REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE);
353 char jobvl = up==NULL?'N':'V';
354 REQUIRES(vp==NULL || (vr==n && vc==n), BAD_SIZE);
355 char jobvr = vp==NULL?'N':'V';
356 DEBUGMSG("eig_l_C");
357 doublecomplex *B = (doublecomplex*)malloc(n*n*sizeof(doublecomplex));
358 CHECK(!B,MEM);
359 memcpy(B,ap,n*n*sizeof(doublecomplex));
360 double *rwork = (double*) malloc(2*n*sizeof(double));
361 CHECK(!rwork,MEM);
362 integer lwork = -1;
363 integer res;
364 // ask for optimal lwork
365 doublecomplex ans;
366 //printf("ask zgeev\n");
367 zgeev_ (&jobvl,&jobvr,
368 &n,B,&n,
369 sp,
370 up,&n,
371 vp,&n,
372 &ans, &lwork,
373 rwork,
374 &res);
375 lwork = ceil(ans.r);
376 //printf("ans = %d\n",lwork);
377 doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
378 CHECK(!work,MEM);
379 //printf("zgeev\n");
380 zgeev_ (&jobvl,&jobvr,
381 &n,B,&n,
382 sp,
383 up,&n,
384 vp,&n,
385 work, &lwork,
386 rwork,
387 &res);
388 CHECK(res,res);
389 free(work);
390 free(rwork);
391 free(B);
392 OK
393}
394
395
396
397//////////////////// general real eigensystem ////////////
398
399/* Subroutine */ int dgeev_(char *jobvl, char *jobvr, integer *n, doublereal *
400 a, integer *lda, doublereal *wr, doublereal *wi, doublereal *vl,
401 integer *ldvl, doublereal *vr, integer *ldvr, doublereal *work,
402 integer *lwork, integer *info);
403
404int eig_l_R(KDMAT(a),DMAT(u), CVEC(s),DMAT(v)) {
405 integer n = ar;
406 REQUIRES(ac==n && sn==n, BAD_SIZE);
407 REQUIRES(up==NULL || (ur==n && uc==n), BAD_SIZE);
408 char jobvl = up==NULL?'N':'V';
409 REQUIRES(vp==NULL || (vr==n && vc==n), BAD_SIZE);
410 char jobvr = vp==NULL?'N':'V';
411 DEBUGMSG("eig_l_R");
412 double *B = (double*)malloc(n*n*sizeof(double));
413 CHECK(!B,MEM);
414 memcpy(B,ap,n*n*sizeof(double));
415 integer lwork = -1;
416 integer res;
417 // ask for optimal lwork
418 double ans;
419 //printf("ask dgeev\n");
420 dgeev_ (&jobvl,&jobvr,
421 &n,B,&n,
422 (double*)sp, (double*)sp+n,
423 up,&n,
424 vp,&n,
425 &ans, &lwork,
426 &res);
427 lwork = ceil(ans);
428 //printf("ans = %d\n",lwork);
429 double * work = (double*)malloc(lwork*sizeof(double));
430 CHECK(!work,MEM);
431 //printf("dgeev\n");
432 dgeev_ (&jobvl,&jobvr,
433 &n,B,&n,
434 (double*)sp, (double*)sp+n,
435 up,&n,
436 vp,&n,
437 work, &lwork,
438 &res);
439 CHECK(res,res);
440 free(work);
441 free(B);
442 OK
443}
444
445
446//////////////////// symmetric real eigensystem ////////////
447
448/* Subroutine */ int dsyev_(char *jobz, char *uplo, integer *n, doublereal *a,
449 integer *lda, doublereal *w, doublereal *work, integer *lwork,
450 integer *info);
451
452int eig_l_S(int wantV,KDMAT(a),DVEC(s),DMAT(v)) {
453 integer n = ar;
454 REQUIRES(ac==n && sn==n, BAD_SIZE);
455 REQUIRES(vr==n && vc==n, BAD_SIZE);
456 char jobz = wantV?'V':'N';
457 DEBUGMSG("eig_l_S");
458 memcpy(vp,ap,n*n*sizeof(double));
459 integer lwork = -1;
460 char uplo = 'U';
461 integer res;
462 // ask for optimal lwork
463 double ans;
464 //printf("ask dsyev\n");
465 dsyev_ (&jobz,&uplo,
466 &n,vp,&n,
467 sp,
468 &ans, &lwork,
469 &res);
470 lwork = ceil(ans);
471 //printf("ans = %d\n",lwork);
472 double * work = (double*)malloc(lwork*sizeof(double));
473 CHECK(!work,MEM);
474 dsyev_ (&jobz,&uplo,
475 &n,vp,&n,
476 sp,
477 work, &lwork,
478 &res);
479 CHECK(res,res);
480 free(work);
481 OK
482}
483
484//////////////////// hermitian complex eigensystem ////////////
485
486/* Subroutine */ int zheev_(char *jobz, char *uplo, integer *n, doublecomplex
487 *a, integer *lda, doublereal *w, doublecomplex *work, integer *lwork,
488 doublereal *rwork, integer *info);
489
490int eig_l_H(int wantV,KCMAT(a),DVEC(s),CMAT(v)) {
491 integer n = ar;
492 REQUIRES(ac==n && sn==n, BAD_SIZE);
493 REQUIRES(vr==n && vc==n, BAD_SIZE);
494 char jobz = wantV?'V':'N';
495 DEBUGMSG("eig_l_H");
496 memcpy(vp,ap,2*n*n*sizeof(double));
497 double *rwork = (double*) malloc((3*n-2)*sizeof(double));
498 CHECK(!rwork,MEM);
499 integer lwork = -1;
500 char uplo = 'U';
501 integer res;
502 // ask for optimal lwork
503 doublecomplex ans;
504 //printf("ask zheev\n");
505 zheev_ (&jobz,&uplo,
506 &n,vp,&n,
507 sp,
508 &ans, &lwork,
509 rwork,
510 &res);
511 lwork = ceil(ans.r);
512 //printf("ans = %d\n",lwork);
513 doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
514 CHECK(!work,MEM);
515 zheev_ (&jobz,&uplo,
516 &n,vp,&n,
517 sp,
518 work, &lwork,
519 rwork,
520 &res);
521 CHECK(res,res);
522 free(work);
523 free(rwork);
524 OK
525}
526
527//////////////////// general real linear system ////////////
528
529/* Subroutine */ int dgesv_(integer *n, integer *nrhs, doublereal *a, integer
530 *lda, integer *ipiv, doublereal *b, integer *ldb, integer *info);
531
532int linearSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) {
533 integer n = ar;
534 integer nhrs = bc;
535 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
536 DEBUGMSG("linearSolveR_l");
537 double*AC = (double*)malloc(n*n*sizeof(double));
538 memcpy(AC,ap,n*n*sizeof(double));
539 memcpy(xp,bp,n*nhrs*sizeof(double));
540 integer * ipiv = (integer*)malloc(n*sizeof(integer));
541 integer res;
542 dgesv_ (&n,&nhrs,
543 AC, &n,
544 ipiv,
545 xp, &n,
546 &res);
547 if(res>0) {
548 return SINGULAR;
549 }
550 CHECK(res,res);
551 free(ipiv);
552 free(AC);
553 OK
554}
555
556//////////////////// general complex linear system ////////////
557
558/* Subroutine */ int zgesv_(integer *n, integer *nrhs, doublecomplex *a,
559 integer *lda, integer *ipiv, doublecomplex *b, integer *ldb, integer *
560 info);
561
562int linearSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) {
563 integer n = ar;
564 integer nhrs = bc;
565 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
566 DEBUGMSG("linearSolveC_l");
567 doublecomplex*AC = (doublecomplex*)malloc(n*n*sizeof(doublecomplex));
568 memcpy(AC,ap,n*n*sizeof(doublecomplex));
569 memcpy(xp,bp,n*nhrs*sizeof(doublecomplex));
570 integer * ipiv = (integer*)malloc(n*sizeof(integer));
571 integer res;
572 zgesv_ (&n,&nhrs,
573 AC, &n,
574 ipiv,
575 xp, &n,
576 &res);
577 if(res>0) {
578 return SINGULAR;
579 }
580 CHECK(res,res);
581 free(ipiv);
582 free(AC);
583 OK
584}
585
586//////// symmetric positive definite real linear system using Cholesky ////////////
587
588/* Subroutine */ int dpotrs_(char *uplo, integer *n, integer *nrhs,
589 doublereal *a, integer *lda, doublereal *b, integer *ldb, integer *
590 info);
591
592int cholSolveR_l(KDMAT(a),KDMAT(b),DMAT(x)) {
593 integer n = ar;
594 integer nhrs = bc;
595 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
596 DEBUGMSG("cholSolveR_l");
597 memcpy(xp,bp,n*nhrs*sizeof(double));
598 integer res;
599 dpotrs_ ("U",
600 &n,&nhrs,
601 (double*)ap, &n,
602 xp, &n,
603 &res);
604 CHECK(res,res);
605 OK
606}
607
608//////// Hermitian positive definite real linear system using Cholesky ////////////
609
610/* Subroutine */ int zpotrs_(char *uplo, integer *n, integer *nrhs,
611 doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb,
612 integer *info);
613
614int cholSolveC_l(KCMAT(a),KCMAT(b),CMAT(x)) {
615 integer n = ar;
616 integer nhrs = bc;
617 REQUIRES(n>=1 && ar==ac && ar==br,BAD_SIZE);
618 DEBUGMSG("cholSolveC_l");
619 memcpy(xp,bp,n*nhrs*sizeof(doublecomplex));
620 integer res;
621 zpotrs_ ("U",
622 &n,&nhrs,
623 (doublecomplex*)ap, &n,
624 xp, &n,
625 &res);
626 CHECK(res,res);
627 OK
628}
629
630//////////////////// least squares real linear system ////////////
631
632/* Subroutine */ int dgels_(char *trans, integer *m, integer *n, integer *
633 nrhs, doublereal *a, integer *lda, doublereal *b, integer *ldb,
634 doublereal *work, integer *lwork, integer *info);
635
636int linearSolveLSR_l(KDMAT(a),KDMAT(b),DMAT(x)) {
637 integer m = ar;
638 integer n = ac;
639 integer nrhs = bc;
640 integer ldb = xr;
641 REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE);
642 DEBUGMSG("linearSolveLSR_l");
643 double*AC = (double*)malloc(m*n*sizeof(double));
644 memcpy(AC,ap,m*n*sizeof(double));
645 if (m>=n) {
646 memcpy(xp,bp,m*nrhs*sizeof(double));
647 } else {
648 int k;
649 for(k = 0; k<nrhs; k++) {
650 memcpy(xp+ldb*k,bp+m*k,m*sizeof(double));
651 }
652 }
653 integer res;
654 integer lwork = -1;
655 double ans;
656 dgels_ ("N",&m,&n,&nrhs,
657 AC,&m,
658 xp,&ldb,
659 &ans,&lwork,
660 &res);
661 lwork = ceil(ans);
662 //printf("ans = %d\n",lwork);
663 double * work = (double*)malloc(lwork*sizeof(double));
664 dgels_ ("N",&m,&n,&nrhs,
665 AC,&m,
666 xp,&ldb,
667 work,&lwork,
668 &res);
669 if(res>0) {
670 return SINGULAR;
671 }
672 CHECK(res,res);
673 free(work);
674 free(AC);
675 OK
676}
677
678//////////////////// least squares complex linear system ////////////
679
680/* Subroutine */ int zgels_(char *trans, integer *m, integer *n, integer *
681 nrhs, doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb,
682 doublecomplex *work, integer *lwork, integer *info);
683
684int linearSolveLSC_l(KCMAT(a),KCMAT(b),CMAT(x)) {
685 integer m = ar;
686 integer n = ac;
687 integer nrhs = bc;
688 integer ldb = xr;
689 REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE);
690 DEBUGMSG("linearSolveLSC_l");
691 doublecomplex*AC = (doublecomplex*)malloc(m*n*sizeof(doublecomplex));
692 memcpy(AC,ap,m*n*sizeof(doublecomplex));
693 if (m>=n) {
694 memcpy(xp,bp,m*nrhs*sizeof(doublecomplex));
695 } else {
696 int k;
697 for(k = 0; k<nrhs; k++) {
698 memcpy(xp+ldb*k,bp+m*k,m*sizeof(doublecomplex));
699 }
700 }
701 integer res;
702 integer lwork = -1;
703 doublecomplex ans;
704 zgels_ ("N",&m,&n,&nrhs,
705 AC,&m,
706 xp,&ldb,
707 &ans,&lwork,
708 &res);
709 lwork = ceil(ans.r);
710 //printf("ans = %d\n",lwork);
711 doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
712 zgels_ ("N",&m,&n,&nrhs,
713 AC,&m,
714 xp,&ldb,
715 work,&lwork,
716 &res);
717 if(res>0) {
718 return SINGULAR;
719 }
720 CHECK(res,res);
721 free(work);
722 free(AC);
723 OK
724}
725
726//////////////////// least squares real linear system using SVD ////////////
727
728/* Subroutine */ int dgelss_(integer *m, integer *n, integer *nrhs,
729 doublereal *a, integer *lda, doublereal *b, integer *ldb, doublereal *
730 s, doublereal *rcond, integer *rank, doublereal *work, integer *lwork,
731 integer *info);
732
733int linearSolveSVDR_l(double rcond,KDMAT(a),KDMAT(b),DMAT(x)) {
734 integer m = ar;
735 integer n = ac;
736 integer nrhs = bc;
737 integer ldb = xr;
738 REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE);
739 DEBUGMSG("linearSolveSVDR_l");
740 double*AC = (double*)malloc(m*n*sizeof(double));
741 double*S = (double*)malloc(MIN(m,n)*sizeof(double));
742 memcpy(AC,ap,m*n*sizeof(double));
743 if (m>=n) {
744 memcpy(xp,bp,m*nrhs*sizeof(double));
745 } else {
746 int k;
747 for(k = 0; k<nrhs; k++) {
748 memcpy(xp+ldb*k,bp+m*k,m*sizeof(double));
749 }
750 }
751 integer res;
752 integer lwork = -1;
753 integer rank;
754 double ans;
755 dgelss_ (&m,&n,&nrhs,
756 AC,&m,
757 xp,&ldb,
758 S,
759 &rcond,&rank,
760 &ans,&lwork,
761 &res);
762 lwork = ceil(ans);
763 //printf("ans = %d\n",lwork);
764 double * work = (double*)malloc(lwork*sizeof(double));
765 dgelss_ (&m,&n,&nrhs,
766 AC,&m,
767 xp,&ldb,
768 S,
769 &rcond,&rank,
770 work,&lwork,
771 &res);
772 if(res>0) {
773 return NOCONVER;
774 }
775 CHECK(res,res);
776 free(work);
777 free(S);
778 free(AC);
779 OK
780}
781
782//////////////////// least squares complex linear system using SVD ////////////
783
784// not in clapack.h
785
786int zgelss_(integer *m, integer *n, integer *nhrs,
787 doublecomplex *a, integer *lda, doublecomplex *b, integer *ldb, doublereal *s,
788 doublereal *rcond, integer* rank,
789 doublecomplex *work, integer* lwork, doublereal* rwork,
790 integer *info);
791
792int linearSolveSVDC_l(double rcond, KCMAT(a),KCMAT(b),CMAT(x)) {
793 integer m = ar;
794 integer n = ac;
795 integer nrhs = bc;
796 integer ldb = xr;
797 REQUIRES(m>=1 && n>=1 && ar==br && xr==MAX(m,n) && xc == bc, BAD_SIZE);
798 DEBUGMSG("linearSolveSVDC_l");
799 doublecomplex*AC = (doublecomplex*)malloc(m*n*sizeof(doublecomplex));
800 double*S = (double*)malloc(MIN(m,n)*sizeof(double));
801 double*RWORK = (double*)malloc(5*MIN(m,n)*sizeof(double));
802 memcpy(AC,ap,m*n*sizeof(doublecomplex));
803 if (m>=n) {
804 memcpy(xp,bp,m*nrhs*sizeof(doublecomplex));
805 } else {
806 int k;
807 for(k = 0; k<nrhs; k++) {
808 memcpy(xp+ldb*k,bp+m*k,m*sizeof(doublecomplex));
809 }
810 }
811 integer res;
812 integer lwork = -1;
813 integer rank;
814 doublecomplex ans;
815 zgelss_ (&m,&n,&nrhs,
816 AC,&m,
817 xp,&ldb,
818 S,
819 &rcond,&rank,
820 &ans,&lwork,
821 RWORK,
822 &res);
823 lwork = ceil(ans.r);
824 //printf("ans = %d\n",lwork);
825 doublecomplex * work = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
826 zgelss_ (&m,&n,&nrhs,
827 AC,&m,
828 xp,&ldb,
829 S,
830 &rcond,&rank,
831 work,&lwork,
832 RWORK,
833 &res);
834 if(res>0) {
835 return NOCONVER;
836 }
837 CHECK(res,res);
838 free(work);
839 free(RWORK);
840 free(S);
841 free(AC);
842 OK
843}
844
845//////////////////// Cholesky factorization /////////////////////////
846
847/* Subroutine */ int zpotrf_(char *uplo, integer *n, doublecomplex *a,
848 integer *lda, integer *info);
849
850int chol_l_H(KCMAT(a),CMAT(l)) {
851 integer n = ar;
852 REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE);
853 DEBUGMSG("chol_l_H");
854 memcpy(lp,ap,n*n*sizeof(doublecomplex));
855 char uplo = 'U';
856 integer res;
857 zpotrf_ (&uplo,&n,lp,&n,&res);
858 CHECK(res>0,NODEFPOS);
859 CHECK(res,res);
860 doublecomplex zero = {0.,0.};
861 int r,c;
862 for (r=0; r<lr-1; r++) {
863 for(c=r+1; c<lc; c++) {
864 lp[r*lc+c] = zero;
865 }
866 }
867 OK
868}
869
870
871/* Subroutine */ int dpotrf_(char *uplo, integer *n, doublereal *a, integer *
872 lda, integer *info);
873
874int chol_l_S(KDMAT(a),DMAT(l)) {
875 integer n = ar;
876 REQUIRES(n>=1 && ac == n && lr==n && lc==n,BAD_SIZE);
877 DEBUGMSG("chol_l_S");
878 memcpy(lp,ap,n*n*sizeof(double));
879 char uplo = 'U';
880 integer res;
881 dpotrf_ (&uplo,&n,lp,&n,&res);
882 CHECK(res>0,NODEFPOS);
883 CHECK(res,res);
884 int r,c;
885 for (r=0; r<lr-1; r++) {
886 for(c=r+1; c<lc; c++) {
887 lp[r*lc+c] = 0.;
888 }
889 }
890 OK
891}
892
893//////////////////// QR factorization /////////////////////////
894
895/* Subroutine */ int dgeqr2_(integer *m, integer *n, doublereal *a, integer *
896 lda, doublereal *tau, doublereal *work, integer *info);
897
898int qr_l_R(KDMAT(a), DVEC(tau), DMAT(r)) {
899 integer m = ar;
900 integer n = ac;
901 integer mn = MIN(m,n);
902 REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE);
903 DEBUGMSG("qr_l_R");
904 double *WORK = (double*)malloc(n*sizeof(double));
905 CHECK(!WORK,MEM);
906 memcpy(rp,ap,m*n*sizeof(double));
907 integer res;
908 dgeqr2_ (&m,&n,rp,&m,taup,WORK,&res);
909 CHECK(res,res);
910 free(WORK);
911 OK
912}
913
914/* Subroutine */ int zgeqr2_(integer *m, integer *n, doublecomplex *a,
915 integer *lda, doublecomplex *tau, doublecomplex *work, integer *info);
916
917int qr_l_C(KCMAT(a), CVEC(tau), CMAT(r)) {
918 integer m = ar;
919 integer n = ac;
920 integer mn = MIN(m,n);
921 REQUIRES(m>=1 && n >=1 && rr== m && rc == n && taun == mn, BAD_SIZE);
922 DEBUGMSG("qr_l_C");
923 doublecomplex *WORK = (doublecomplex*)malloc(n*sizeof(doublecomplex));
924 CHECK(!WORK,MEM);
925 memcpy(rp,ap,m*n*sizeof(doublecomplex));
926 integer res;
927 zgeqr2_ (&m,&n,rp,&m,taup,WORK,&res);
928 CHECK(res,res);
929 free(WORK);
930 OK
931}
932
933/* Subroutine */ int dorgqr_(integer *m, integer *n, integer *k, doublereal *
934 a, integer *lda, doublereal *tau, doublereal *work, integer *lwork,
935 integer *info);
936
937int c_dorgqr(KDMAT(a), KDVEC(tau), DMAT(r)) {
938 integer m = ar;
939 integer n = MIN(ac,ar);
940 integer k = taun;
941 DEBUGMSG("c_dorgqr");
942 integer lwork = 8*n; // FIXME
943 double *WORK = (double*)malloc(lwork*sizeof(double));
944 CHECK(!WORK,MEM);
945 memcpy(rp,ap,m*k*sizeof(double));
946 integer res;
947 dorgqr_ (&m,&n,&k,rp,&m,(double*)taup,WORK,&lwork,&res);
948 CHECK(res,res);
949 free(WORK);
950 OK
951}
952
953/* Subroutine */ int zungqr_(integer *m, integer *n, integer *k,
954 doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex *
955 work, integer *lwork, integer *info);
956
957int c_zungqr(KCMAT(a), KCVEC(tau), CMAT(r)) {
958 integer m = ar;
959 integer n = MIN(ac,ar);
960 integer k = taun;
961 DEBUGMSG("z_ungqr");
962 integer lwork = 8*n; // FIXME
963 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
964 CHECK(!WORK,MEM);
965 memcpy(rp,ap,m*k*sizeof(doublecomplex));
966 integer res;
967 zungqr_ (&m,&n,&k,rp,&m,(doublecomplex*)taup,WORK,&lwork,&res);
968 CHECK(res,res);
969 free(WORK);
970 OK
971}
972
973
974//////////////////// Hessenberg factorization /////////////////////////
975
976/* Subroutine */ int dgehrd_(integer *n, integer *ilo, integer *ihi,
977 doublereal *a, integer *lda, doublereal *tau, doublereal *work,
978 integer *lwork, integer *info);
979
980int hess_l_R(KDMAT(a), DVEC(tau), DMAT(r)) {
981 integer m = ar;
982 integer n = ac;
983 integer mn = MIN(m,n);
984 REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE);
985 DEBUGMSG("hess_l_R");
986 integer lwork = 5*n; // fixme
987 double *WORK = (double*)malloc(lwork*sizeof(double));
988 CHECK(!WORK,MEM);
989 memcpy(rp,ap,m*n*sizeof(double));
990 integer res;
991 integer one = 1;
992 dgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res);
993 CHECK(res,res);
994 free(WORK);
995 OK
996}
997
998
999/* Subroutine */ int zgehrd_(integer *n, integer *ilo, integer *ihi,
1000 doublecomplex *a, integer *lda, doublecomplex *tau, doublecomplex *
1001 work, integer *lwork, integer *info);
1002
1003int hess_l_C(KCMAT(a), CVEC(tau), CMAT(r)) {
1004 integer m = ar;
1005 integer n = ac;
1006 integer mn = MIN(m,n);
1007 REQUIRES(m>=1 && n == m && rr== m && rc == n && taun == mn-1, BAD_SIZE);
1008 DEBUGMSG("hess_l_C");
1009 integer lwork = 5*n; // fixme
1010 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
1011 CHECK(!WORK,MEM);
1012 memcpy(rp,ap,m*n*sizeof(doublecomplex));
1013 integer res;
1014 integer one = 1;
1015 zgehrd_ (&n,&one,&n,rp,&n,taup,WORK,&lwork,&res);
1016 CHECK(res,res);
1017 free(WORK);
1018 OK
1019}
1020
1021//////////////////// Schur factorization /////////////////////////
1022
1023/* Subroutine */ int dgees_(char *jobvs, char *sort, L_fp select, integer *n,
1024 doublereal *a, integer *lda, integer *sdim, doublereal *wr,
1025 doublereal *wi, doublereal *vs, integer *ldvs, doublereal *work,
1026 integer *lwork, logical *bwork, integer *info);
1027
1028int schur_l_R(KDMAT(a), DMAT(u), DMAT(s)) {
1029 integer m = ar;
1030 integer n = ac;
1031 REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE);
1032 DEBUGMSG("schur_l_R");
1033 //int k;
1034 //printf("---------------------------\n");
1035 //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n");
1036 //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n");
1037 //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n");
1038 memcpy(sp,ap,n*n*sizeof(double));
1039 integer lwork = 6*n; // fixme
1040 double *WORK = (double*)malloc(lwork*sizeof(double));
1041 double *WR = (double*)malloc(n*sizeof(double));
1042 double *WI = (double*)malloc(n*sizeof(double));
1043 // WR and WI not really required in this call
1044 logical *BWORK = (logical*)malloc(n*sizeof(logical));
1045 integer res;
1046 integer sdim;
1047 dgees_ ("V","N",NULL,&n,sp,&n,&sdim,WR,WI,up,&n,WORK,&lwork,BWORK,&res);
1048 //printf("%p: ",ap); for(k=0;k<n*n;k++) printf("%f ",ap[k]); printf("\n");
1049 //printf("%p: ",up); for(k=0;k<n*n;k++) printf("%f ",up[k]); printf("\n");
1050 //printf("%p: ",sp); for(k=0;k<n*n;k++) printf("%f ",sp[k]); printf("\n");
1051 if(res>0) {
1052 return NOCONVER;
1053 }
1054 CHECK(res,res);
1055 free(WR);
1056 free(WI);
1057 free(BWORK);
1058 free(WORK);
1059 OK
1060}
1061
1062
1063/* Subroutine */ int zgees_(char *jobvs, char *sort, L_fp select, integer *n,
1064 doublecomplex *a, integer *lda, integer *sdim, doublecomplex *w,
1065 doublecomplex *vs, integer *ldvs, doublecomplex *work, integer *lwork,
1066 doublereal *rwork, logical *bwork, integer *info);
1067
1068int schur_l_C(KCMAT(a), CMAT(u), CMAT(s)) {
1069 integer m = ar;
1070 integer n = ac;
1071 REQUIRES(m>=1 && n==m && ur==n && uc==n && sr==n && sc==n, BAD_SIZE);
1072 DEBUGMSG("schur_l_C");
1073 memcpy(sp,ap,n*n*sizeof(doublecomplex));
1074 integer lwork = 6*n; // fixme
1075 doublecomplex *WORK = (doublecomplex*)malloc(lwork*sizeof(doublecomplex));
1076 doublecomplex *W = (doublecomplex*)malloc(n*sizeof(doublecomplex));
1077 // W not really required in this call
1078 logical *BWORK = (logical*)malloc(n*sizeof(logical));
1079 double *RWORK = (double*)malloc(n*sizeof(double));
1080 integer res;
1081 integer sdim;
1082 zgees_ ("V","N",NULL,&n,sp,&n,&sdim,W,
1083 up,&n,
1084 WORK,&lwork,RWORK,BWORK,&res);
1085 if(res>0) {
1086 return NOCONVER;
1087 }
1088 CHECK(res,res);
1089 free(W);
1090 free(BWORK);
1091 free(WORK);
1092 OK
1093}
1094
1095//////////////////// LU factorization /////////////////////////
1096
1097/* Subroutine */ int dgetrf_(integer *m, integer *n, doublereal *a, integer *
1098 lda, integer *ipiv, integer *info);
1099
1100int lu_l_R(KDMAT(a), DVEC(ipiv), DMAT(r)) {
1101 integer m = ar;
1102 integer n = ac;
1103 integer mn = MIN(m,n);
1104 REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE);
1105 DEBUGMSG("lu_l_R");
1106 integer* auxipiv = (integer*)malloc(mn*sizeof(integer));
1107 memcpy(rp,ap,m*n*sizeof(double));
1108 integer res;
1109 dgetrf_ (&m,&n,rp,&m,auxipiv,&res);
1110 if(res>0) {
1111 res = 0; // fixme
1112 }
1113 CHECK(res,res);
1114 int k;
1115 for (k=0; k<mn; k++) {
1116 ipivp[k] = auxipiv[k];
1117 }
1118 free(auxipiv);
1119 OK
1120}
1121
1122
1123/* Subroutine */ int zgetrf_(integer *m, integer *n, doublecomplex *a,
1124 integer *lda, integer *ipiv, integer *info);
1125
1126int lu_l_C(KCMAT(a), DVEC(ipiv), CMAT(r)) {
1127 integer m = ar;
1128 integer n = ac;
1129 integer mn = MIN(m,n);
1130 REQUIRES(m>=1 && n >=1 && ipivn == mn, BAD_SIZE);
1131 DEBUGMSG("lu_l_C");
1132 integer* auxipiv = (integer*)malloc(mn*sizeof(integer));
1133 memcpy(rp,ap,m*n*sizeof(doublecomplex));
1134 integer res;
1135 zgetrf_ (&m,&n,rp,&m,auxipiv,&res);
1136 if(res>0) {
1137 res = 0; // fixme
1138 }
1139 CHECK(res,res);
1140 int k;
1141 for (k=0; k<mn; k++) {
1142 ipivp[k] = auxipiv[k];
1143 }
1144 free(auxipiv);
1145 OK
1146}
1147
1148
1149//////////////////// LU substitution /////////////////////////
1150
1151/* Subroutine */ int dgetrs_(char *trans, integer *n, integer *nrhs,
1152 doublereal *a, integer *lda, integer *ipiv, doublereal *b, integer *
1153 ldb, integer *info);
1154
1155int luS_l_R(KDMAT(a), KDVEC(ipiv), KDMAT(b), DMAT(x)) {
1156 integer m = ar;
1157 integer n = ac;
1158 integer mrhs = br;
1159 integer nrhs = bc;
1160
1161 REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE);
1162 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1163 int k;
1164 for (k=0; k<n; k++) {
1165 auxipiv[k] = (integer)ipivp[k];
1166 }
1167 integer res;
1168 memcpy(xp,bp,mrhs*nrhs*sizeof(double));
1169 dgetrs_ ("N",&n,&nrhs,(/*no const (!?)*/ double*)ap,&m,auxipiv,xp,&mrhs,&res);
1170 CHECK(res,res);
1171 free(auxipiv);
1172 OK
1173}
1174
1175
1176/* Subroutine */ int zgetrs_(char *trans, integer *n, integer *nrhs,
1177 doublecomplex *a, integer *lda, integer *ipiv, doublecomplex *b,
1178 integer *ldb, integer *info);
1179
1180int luS_l_C(KCMAT(a), KDVEC(ipiv), KCMAT(b), CMAT(x)) {
1181 integer m = ar;
1182 integer n = ac;
1183 integer mrhs = br;
1184 integer nrhs = bc;
1185
1186 REQUIRES(m==n && m==mrhs && m==ipivn,BAD_SIZE);
1187 integer* auxipiv = (integer*)malloc(n*sizeof(integer));
1188 int k;
1189 for (k=0; k<n; k++) {
1190 auxipiv[k] = (integer)ipivp[k];
1191 }
1192 integer res;
1193 memcpy(xp,bp,mrhs*nrhs*sizeof(doublecomplex));
1194 zgetrs_ ("N",&n,&nrhs,(doublecomplex*)ap,&m,auxipiv,xp,&mrhs,&res);
1195 CHECK(res,res);
1196 free(auxipiv);
1197 OK
1198}
1199
1200//////////////////// Matrix Product /////////////////////////
1201
1202void dgemm_(char *, char *, integer *, integer *, integer *,
1203 double *, const double *, integer *, const double *,
1204 integer *, double *, double *, integer *);
1205
1206int multiplyR(int ta, int tb, KDMAT(a),KDMAT(b),DMAT(r)) {
1207 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1208 DEBUGMSG("dgemm_");
1209 CHECKNANR(a,"NaN multR Input\n")
1210 CHECKNANR(b,"NaN multR Input\n")
1211 integer m = ta?ac:ar;
1212 integer n = tb?br:bc;
1213 integer k = ta?ar:ac;
1214 integer lda = ar;
1215 integer ldb = br;
1216 integer ldc = rr;
1217 double alpha = 1;
1218 double beta = 0;
1219 dgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
1220 CHECKNANR(r,"NaN multR Output\n")
1221 OK
1222}
1223
1224void zgemm_(char *, char *, integer *, integer *, integer *,
1225 doublecomplex *, const doublecomplex *, integer *, const doublecomplex *,
1226 integer *, doublecomplex *, doublecomplex *, integer *);
1227
1228int multiplyC(int ta, int tb, KCMAT(a),KCMAT(b),CMAT(r)) {
1229 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1230 DEBUGMSG("zgemm_");
1231 CHECKNANC(a,"NaN multC Input\n")
1232 CHECKNANC(b,"NaN multC Input\n")
1233 integer m = ta?ac:ar;
1234 integer n = tb?br:bc;
1235 integer k = ta?ar:ac;
1236 integer lda = ar;
1237 integer ldb = br;
1238 integer ldc = rr;
1239 doublecomplex alpha = {1,0};
1240 doublecomplex beta = {0,0};
1241 zgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
1242 ap,&lda,
1243 bp,&ldb,&beta,
1244 rp,&ldc);
1245 CHECKNANC(r,"NaN multC Output\n")
1246 OK
1247}
1248
1249void sgemm_(char *, char *, integer *, integer *, integer *,
1250 float *, const float *, integer *, const float *,
1251 integer *, float *, float *, integer *);
1252
1253int multiplyF(int ta, int tb, KFMAT(a),KFMAT(b),FMAT(r)) {
1254 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1255 DEBUGMSG("sgemm_");
1256 integer m = ta?ac:ar;
1257 integer n = tb?br:bc;
1258 integer k = ta?ar:ac;
1259 integer lda = ar;
1260 integer ldb = br;
1261 integer ldc = rr;
1262 float alpha = 1;
1263 float beta = 0;
1264 sgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,ap,&lda,bp,&ldb,&beta,rp,&ldc);
1265 OK
1266}
1267
1268void cgemm_(char *, char *, integer *, integer *, integer *,
1269 complex *, const complex *, integer *, const complex *,
1270 integer *, complex *, complex *, integer *);
1271
1272int multiplyQ(int ta, int tb, KQMAT(a),KQMAT(b),QMAT(r)) {
1273 //REQUIRES(ac==br && ar==rr && bc==rc,BAD_SIZE);
1274 DEBUGMSG("cgemm_");
1275 integer m = ta?ac:ar;
1276 integer n = tb?br:bc;
1277 integer k = ta?ar:ac;
1278 integer lda = ar;
1279 integer ldb = br;
1280 integer ldc = rr;
1281 complex alpha = {1,0};
1282 complex beta = {0,0};
1283 cgemm_(ta?"T":"N",tb?"T":"N",&m,&n,&k,&alpha,
1284 ap,&lda,
1285 bp,&ldb,&beta,
1286 rp,&ldc);
1287 OK
1288}
1289
1290//////////////////// transpose /////////////////////////
1291
1292int transF(KFMAT(x),FMAT(t)) {
1293 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
1294 DEBUGMSG("transF");
1295 int i,j;
1296 for (i=0; i<tr; i++) {
1297 for (j=0; j<tc; j++) {
1298 tp[i*tc+j] = xp[j*xc+i];
1299 }
1300 }
1301 OK
1302}
1303
1304int transR(KDMAT(x),DMAT(t)) {
1305 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
1306 DEBUGMSG("transR");
1307 int i,j;
1308 for (i=0; i<tr; i++) {
1309 for (j=0; j<tc; j++) {
1310 tp[i*tc+j] = xp[j*xc+i];
1311 }
1312 }
1313 OK
1314}
1315
1316int transQ(KQMAT(x),QMAT(t)) {
1317 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
1318 DEBUGMSG("transQ");
1319 int i,j;
1320 for (i=0; i<tr; i++) {
1321 for (j=0; j<tc; j++) {
1322 tp[i*tc+j] = xp[j*xc+i];
1323 }
1324 }
1325 OK
1326}
1327
1328int transC(KCMAT(x),CMAT(t)) {
1329 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
1330 DEBUGMSG("transC");
1331 int i,j;
1332 for (i=0; i<tr; i++) {
1333 for (j=0; j<tc; j++) {
1334 tp[i*tc+j] = xp[j*xc+i];
1335 }
1336 }
1337 OK
1338}
1339
1340int transP(KPMAT(x), PMAT(t)) {
1341 REQUIRES(xr==tc && xc==tr,BAD_SIZE);
1342 REQUIRES(xs==ts,NOCONVER);
1343 DEBUGMSG("transP");
1344 int i,j;
1345 for (i=0; i<tr; i++) {
1346 for (j=0; j<tc; j++) {
1347 memcpy(tp+(i*tc+j)*xs,xp +(j*xc+i)*xs,xs);
1348 }
1349 }
1350 OK
1351}
1352
1353//////////////////// constant /////////////////////////
1354
1355int constantF(float * pval, FVEC(r)) {
1356 DEBUGMSG("constantF")
1357 int k;
1358 double val = *pval;
1359 for(k=0;k<rn;k++) {
1360 rp[k]=val;
1361 }
1362 OK
1363}
1364
1365int constantR(double * pval, DVEC(r)) {
1366 DEBUGMSG("constantR")
1367 int k;
1368 double val = *pval;
1369 for(k=0;k<rn;k++) {
1370 rp[k]=val;
1371 }
1372 OK
1373}
1374
1375int constantQ(complex* pval, QVEC(r)) {
1376 DEBUGMSG("constantQ")
1377 int k;
1378 complex val = *pval;
1379 for(k=0;k<rn;k++) {
1380 rp[k]=val;
1381 }
1382 OK
1383}
1384
1385int constantC(doublecomplex* pval, CVEC(r)) {
1386 DEBUGMSG("constantC")
1387 int k;
1388 doublecomplex val = *pval;
1389 for(k=0;k<rn;k++) {
1390 rp[k]=val;
1391 }
1392 OK
1393}
1394
1395int constantP(void* pval, PVEC(r)) {
1396 DEBUGMSG("constantP")
1397 int k;
1398 for(k=0;k<rn;k++) {
1399 memcpy(rp+k*rs,pval,rs);
1400 }
1401 OK
1402}
1403
1404//////////////////// float-double conversion /////////////////////////
1405
1406int float2double(FVEC(x),DVEC(y)) {
1407 DEBUGMSG("float2double")
1408 int k;
1409 for(k=0;k<xn;k++) {
1410 yp[k]=xp[k];
1411 }
1412 OK
1413}
1414
1415int double2float(DVEC(x),FVEC(y)) {
1416 DEBUGMSG("double2float")
1417 int k;
1418 for(k=0;k<xn;k++) {
1419 yp[k]=xp[k];
1420 }
1421 OK
1422}
1423
1424//////////////////// conjugate /////////////////////////
1425
1426int conjugateQ(KQVEC(x),QVEC(t)) {
1427 REQUIRES(xn==tn,BAD_SIZE);
1428 DEBUGMSG("conjugateQ");
1429 int k;
1430 for(k=0;k<xn;k++) {
1431 tp[k].r = xp[k].r;
1432 tp[k].i = -xp[k].i;
1433 }
1434 OK
1435}
1436
1437int conjugateC(KCVEC(x),CVEC(t)) {
1438 REQUIRES(xn==tn,BAD_SIZE);
1439 DEBUGMSG("conjugateC");
1440 int k;
1441 for(k=0;k<xn;k++) {
1442 tp[k].r = xp[k].r;
1443 tp[k].i = -xp[k].i;
1444 }
1445 OK
1446}
1447
1448//////////////////// step /////////////////////////
1449
1450int stepF(FVEC(x),FVEC(y)) {
1451 DEBUGMSG("stepF")
1452 int k;
1453 for(k=0;k<xn;k++) {
1454 yp[k]=xp[k]>0;
1455 }
1456 OK
1457}
1458
1459int stepD(DVEC(x),DVEC(y)) {
1460 DEBUGMSG("stepD")
1461 int k;
1462 for(k=0;k<xn;k++) {
1463 yp[k]=xp[k]>0;
1464 }
1465 OK
1466}
1467
1468//////////////////// cond /////////////////////////
1469
1470int condF(FVEC(x),FVEC(y),FVEC(lt),FVEC(eq),FVEC(gt),FVEC(r)) {
1471 REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE);
1472 DEBUGMSG("condF")
1473 int k;
1474 for(k=0;k<xn;k++) {
1475 rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]);
1476 }
1477 OK
1478}
1479
1480int condD(DVEC(x),DVEC(y),DVEC(lt),DVEC(eq),DVEC(gt),DVEC(r)) {
1481 REQUIRES(xn==yn && xn==ltn && xn==eqn && xn==gtn && xn==rn ,BAD_SIZE);
1482 DEBUGMSG("condD")
1483 int k;
1484 for(k=0;k<xn;k++) {
1485 rp[k] = xp[k]<yp[k]?ltp[k]:(xp[k]>yp[k]?gtp[k]:eqp[k]);
1486 }
1487 OK
1488}
1489
diff --git a/packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK/lapack-aux.h b/packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
new file mode 100644
index 0000000..a3f1899
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/LinearAlgebra/LAPACK/lapack-aux.h
@@ -0,0 +1,60 @@
1/*
2 * We have copied the definitions in f2c.h required
3 * to compile clapack.h, modified to support both
4 * 32 and 64 bit
5
6 http://opengrok.creo.hu/dragonfly/xref/src/contrib/gcc-3.4/libf2c/readme.netlib
7 http://www.ibm.com/developerworks/library/l-port64.html
8 */
9
10#ifdef _LP64
11typedef int integer;
12typedef unsigned int uinteger;
13typedef int logical;
14typedef long longint; /* system-dependent */
15typedef unsigned long ulongint; /* system-dependent */
16#else
17typedef long int integer;
18typedef unsigned long int uinteger;
19typedef long int logical;
20typedef long long longint; /* system-dependent */
21typedef unsigned long long ulongint; /* system-dependent */
22#endif
23
24typedef char *address;
25typedef short int shortint;
26typedef float real;
27typedef double doublereal;
28typedef struct { real r, i; } complex;
29typedef struct { doublereal r, i; } doublecomplex;
30typedef short int shortlogical;
31typedef char logical1;
32typedef char integer1;
33
34typedef logical (*L_fp)();
35typedef short ftnlen;
36
37/********************************************************/
38
39#define FVEC(A) int A##n, float*A##p
40#define DVEC(A) int A##n, double*A##p
41#define QVEC(A) int A##n, complex*A##p
42#define CVEC(A) int A##n, doublecomplex*A##p
43#define PVEC(A) int A##n, void* A##p, int A##s
44#define FMAT(A) int A##r, int A##c, float* A##p
45#define DMAT(A) int A##r, int A##c, double* A##p
46#define QMAT(A) int A##r, int A##c, complex* A##p
47#define CMAT(A) int A##r, int A##c, doublecomplex* A##p
48#define PMAT(A) int A##r, int A##c, void* A##p, int A##s
49
50#define KFVEC(A) int A##n, const float*A##p
51#define KDVEC(A) int A##n, const double*A##p
52#define KQVEC(A) int A##n, const complex*A##p
53#define KCVEC(A) int A##n, const doublecomplex*A##p
54#define KPVEC(A) int A##n, const void* A##p, int A##s
55#define KFMAT(A) int A##r, int A##c, const float* A##p
56#define KDMAT(A) int A##r, int A##c, const double* A##p
57#define KQMAT(A) int A##r, int A##c, const complex* A##p
58#define KCMAT(A) int A##r, int A##c, const doublecomplex* A##p
59#define KPMAT(A) int A##r, int A##c, const void* A##p, int A##s
60
diff --git a/packages/hmatrix/src/Numeric/LinearAlgebra/Util.hs b/packages/hmatrix/src/Numeric/LinearAlgebra/Util.hs
new file mode 100644
index 0000000..7d134bf
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/LinearAlgebra/Util.hs
@@ -0,0 +1,295 @@
1{-# LANGUAGE FlexibleContexts #-}
2-----------------------------------------------------------------------------
3{- |
4Module : Numeric.LinearAlgebra.Util
5Copyright : (c) Alberto Ruiz 2013
6License : GPL
7
8Maintainer : Alberto Ruiz (aruiz at um dot es)
9Stability : provisional
10
11-}
12-----------------------------------------------------------------------------
13{-# OPTIONS_HADDOCK hide #-}
14
15module Numeric.LinearAlgebra.Util(
16
17 -- * Convenience functions
18 size, disp,
19 zeros, ones,
20 diagl,
21 row,
22 col,
23 (&), (¦), (——), (#),
24 (?), (¿),
25 rand, randn,
26 cross,
27 norm,
28 unitary,
29 mt,
30 pairwiseD2,
31 rowOuters,
32 null1,
33 null1sym,
34 -- * Convolution
35 -- ** 1D
36 corr, conv, corrMin,
37 -- ** 2D
38 corr2, conv2, separable,
39 -- * Tools for the Kronecker product
40 --
41 -- | (see A. Fusiello, A matter of notation: Several uses of the Kronecker product in
42 -- 3d computer vision, Pattern Recognition Letters 28 (15) (2007) 2127-2132)
43
44 --
45 -- | @`vec` (a \<> x \<> b) == ('trans' b ` 'kronecker' ` a) \<> 'vec' x@
46 vec,
47 vech,
48 dup,
49 vtrans,
50 -- * Plot
51 mplot,
52 plot, parametricPlot,
53 splot, mesh, meshdom,
54 matrixToPGM, imshow,
55 gnuplotX, gnuplotpdf, gnuplotWin
56) where
57
58import Numeric.Container
59import Numeric.LinearAlgebra.Algorithms hiding (i)
60import Numeric.Matrix()
61import Numeric.Vector()
62
63import System.Random(randomIO)
64import Numeric.LinearAlgebra.Util.Convolution
65import Graphics.Plot
66
67
68{- | print a real matrix with given number of digits after the decimal point
69
70>>> disp 5 $ ident 2 / 3
712x2
720.33333 0.00000
730.00000 0.33333
74
75-}
76disp :: Int -> Matrix Double -> IO ()
77
78disp n = putStrLn . dispf n
79
80-- | pseudorandom matrix with uniform elements between 0 and 1
81randm :: RandDist
82 -> Int -- ^ rows
83 -> Int -- ^ columns
84 -> IO (Matrix Double)
85randm d r c = do
86 seed <- randomIO
87 return (reshape c $ randomVector seed d (r*c))
88
89-- | pseudorandom matrix with uniform elements between 0 and 1
90rand :: Int -> Int -> IO (Matrix Double)
91rand = randm Uniform
92
93{- | pseudorandom matrix with normal elements
94
95>>> x <- randn 3 5
96>>> disp 3 x
973x5
980.386 -1.141 0.491 -0.510 1.512
990.069 -0.919 1.022 -0.181 0.745
1000.313 -0.670 -0.097 -1.575 -0.583
101
102-}
103randn :: Int -> Int -> IO (Matrix Double)
104randn = randm Gaussian
105
106{- | create a real diagonal matrix from a list
107
108>>> diagl [1,2,3]
109(3><3)
110 [ 1.0, 0.0, 0.0
111 , 0.0, 2.0, 0.0
112 , 0.0, 0.0, 3.0 ]
113
114-}
115diagl :: [Double] -> Matrix Double
116diagl = diag . fromList
117
118-- | a real matrix of zeros
119zeros :: Int -- ^ rows
120 -> Int -- ^ columns
121 -> Matrix Double
122zeros r c = konst 0 (r,c)
123
124-- | a real matrix of ones
125ones :: Int -- ^ rows
126 -> Int -- ^ columns
127 -> Matrix Double
128ones r c = konst 1 (r,c)
129
130-- | concatenation of real vectors
131infixl 3 &
132(&) :: Vector Double -> Vector Double -> Vector Double
133a & b = vjoin [a,b]
134
135{- | horizontal concatenation of real matrices
136
137 (unicode 0x00a6, broken bar)
138
139>>> ident 3 ¦ konst 7 (3,4)
140(3><7)
141 [ 1.0, 0.0, 0.0, 7.0, 7.0, 7.0, 7.0
142 , 0.0, 1.0, 0.0, 7.0, 7.0, 7.0, 7.0
143 , 0.0, 0.0, 1.0, 7.0, 7.0, 7.0, 7.0 ]
144
145-}
146infixl 3 ¦
147(¦) :: Matrix Double -> Matrix Double -> Matrix Double
148a ¦ b = fromBlocks [[a,b]]
149
150-- | vertical concatenation of real matrices
151--
152-- (unicode 0x2014, em dash)
153(——) :: Matrix Double -> Matrix Double -> Matrix Double
154infixl 2 ——
155a —— b = fromBlocks [[a],[b]]
156
157(#) :: Matrix Double -> Matrix Double -> Matrix Double
158infixl 2 #
159a # b = fromBlocks [[a],[b]]
160
161-- | create a single row real matrix from a list
162row :: [Double] -> Matrix Double
163row = asRow . fromList
164
165-- | create a single column real matrix from a list
166col :: [Double] -> Matrix Double
167col = asColumn . fromList
168
169{- | extract rows
170
171>>> (20><4) [1..] ? [2,1,1]
172(3><4)
173 [ 9.0, 10.0, 11.0, 12.0
174 , 5.0, 6.0, 7.0, 8.0
175 , 5.0, 6.0, 7.0, 8.0 ]
176
177-}
178infixl 9 ?
179(?) :: Element t => Matrix t -> [Int] -> Matrix t
180(?) = flip extractRows
181
182{- | extract columns
183
184(unicode 0x00bf, inverted question mark, Alt-Gr ?)
185
186>>> (3><4) [1..] ¿ [3,0]
187(3><2)
188 [ 4.0, 1.0
189 , 8.0, 5.0
190 , 12.0, 9.0 ]
191
192-}
193infixl 9 ¿
194(¿) :: Element t => Matrix t -> [Int] -> Matrix t
195(¿)= flip extractColumns
196
197
198cross :: Vector Double -> Vector Double -> Vector Double
199-- ^ cross product (for three-element real vectors)
200cross x y | dim x == 3 && dim y == 3 = fromList [z1,z2,z3]
201 | otherwise = error $ "cross ("++show x++") ("++show y++")"
202 where
203 [x1,x2,x3] = toList x
204 [y1,y2,y3] = toList y
205 z1 = x2*y3-x3*y2
206 z2 = x3*y1-x1*y3
207 z3 = x1*y2-x2*y1
208
209norm :: Vector Double -> Double
210-- ^ 2-norm of real vector
211norm = pnorm PNorm2
212
213
214-- | Obtains a vector in the same direction with 2-norm=1
215unitary :: Vector Double -> Vector Double
216unitary v = v / scalar (norm v)
217
218-- | ('rows' &&& 'cols')
219size :: Matrix t -> (Int, Int)
220size m = (rows m, cols m)
221
222-- | trans . inv
223mt :: Matrix Double -> Matrix Double
224mt = trans . inv
225
226----------------------------------------------------------------------
227
228-- | Matrix of pairwise squared distances of row vectors
229-- (using the matrix product trick in blog.smola.org)
230pairwiseD2 :: Matrix Double -> Matrix Double -> Matrix Double
231pairwiseD2 x y | ok = x2 `outer` oy + ox `outer` y2 - 2* x <> trans y
232 | otherwise = error $ "pairwiseD2 with different number of columns: "
233 ++ show (size x) ++ ", " ++ show (size y)
234 where
235 ox = one (rows x)
236 oy = one (rows y)
237 oc = one (cols x)
238 one k = constant 1 k
239 x2 = x * x <> oc
240 y2 = y * y <> oc
241 ok = cols x == cols y
242
243--------------------------------------------------------------------------------
244
245-- | outer products of rows
246rowOuters :: Matrix Double -> Matrix Double -> Matrix Double
247rowOuters a b = a' * b'
248 where
249 a' = kronecker a (ones 1 (cols b))
250 b' = kronecker (ones 1 (cols a)) b
251
252--------------------------------------------------------------------------------
253
254-- | solution of overconstrained homogeneous linear system
255null1 :: Matrix Double -> Vector Double
256null1 = last . toColumns . snd . rightSV
257
258-- | solution of overconstrained homogeneous symmetric linear system
259null1sym :: Matrix Double -> Vector Double
260null1sym = last . toColumns . snd . eigSH'
261
262--------------------------------------------------------------------------------
263
264vec :: Element t => Matrix t -> Vector t
265-- ^ stacking of columns
266vec = flatten . trans
267
268
269vech :: Element t => Matrix t -> Vector t
270-- ^ half-vectorization (of the lower triangular part)
271vech m = vjoin . zipWith f [0..] . toColumns $ m
272 where
273 f k v = subVector k (dim v - k) v
274
275
276dup :: (Num t, Num (Vector t), Element t) => Int -> Matrix t
277-- ^ duplication matrix (@'dup' k \<> 'vech' m == 'vec' m@, for symmetric m of 'dim' k)
278dup k = trans $ fromRows $ map f es
279 where
280 rs = zip [0..] (toRows (ident (k^(2::Int))))
281 es = [(i,j) | j <- [0..k-1], i <- [0..k-1], i>=j ]
282 f (i,j) | i == j = g (k*j + i)
283 | otherwise = g (k*j + i) + g (k*i + j)
284 g j = v
285 where
286 Just v = lookup j rs
287
288
289vtrans :: Element t => Int -> Matrix t -> Matrix t
290-- ^ generalized \"vector\" transposition: @'vtrans' 1 == 'trans'@, and @'vtrans' ('rows' m) m == 'asColumn' ('vec' m)@
291vtrans p m | r == 0 = fromBlocks . map (map asColumn . takesV (replicate q p)) . toColumns $ m
292 | otherwise = error $ "vtrans " ++ show p ++ " of matrix with " ++ show (rows m) ++ " rows"
293 where
294 (q,r) = divMod (rows m) p
295
diff --git a/packages/hmatrix/src/Numeric/LinearAlgebra/Util/Convolution.hs b/packages/hmatrix/src/Numeric/LinearAlgebra/Util/Convolution.hs
new file mode 100644
index 0000000..82de476
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/LinearAlgebra/Util/Convolution.hs
@@ -0,0 +1,114 @@
1{-# LANGUAGE FlexibleContexts #-}
2-----------------------------------------------------------------------------
3{- |
4Module : Numeric.LinearAlgebra.Util.Convolution
5Copyright : (c) Alberto Ruiz 2012
6License : GPL
7
8Maintainer : Alberto Ruiz (aruiz at um dot es)
9Stability : provisional
10
11-}
12-----------------------------------------------------------------------------
13
14module Numeric.LinearAlgebra.Util.Convolution(
15 corr, conv, corrMin,
16 corr2, conv2, separable
17) where
18
19import Numeric.LinearAlgebra
20
21
22vectSS :: Element t => Int -> Vector t -> Matrix t
23vectSS n v = fromRows [ subVector k n v | k <- [0 .. dim v - n] ]
24
25
26corr :: Product t => Vector t -- ^ kernel
27 -> Vector t -- ^ source
28 -> Vector t
29{- ^ correlation
30
31>>> corr (fromList[1,2,3]) (fromList [1..10])
32fromList [14.0,20.0,26.0,32.0,38.0,44.0,50.0,56.0]
33
34-}
35corr ker v | dim ker <= dim v = vectSS (dim ker) v <> ker
36 | otherwise = error $ "corr: dim kernel ("++show (dim ker)++") > dim vector ("++show (dim v)++")"
37
38
39conv :: (Product t, Num t) => Vector t -> Vector t -> Vector t
40{- ^ convolution ('corr' with reversed kernel and padded input, equivalent to polynomial product)
41
42>>> conv (fromList[1,1]) (fromList [-1,1])
43fromList [-1.0,0.0,1.0]
44
45-}
46conv ker v = corr ker' v'
47 where
48 ker' = (flatten.fliprl.asRow) ker
49 v' | dim ker > 1 = vjoin [z,v,z]
50 | otherwise = v
51 z = constant 0 (dim ker -1)
52
53corrMin :: (Container Vector t, RealElement t, Product t)
54 => Vector t
55 -> Vector t
56 -> Vector t
57-- ^ similar to 'corr', using 'min' instead of (*)
58corrMin ker v = minEvery ss (asRow ker) <> ones
59 where
60 minEvery a b = cond a b a a b
61 ss = vectSS (dim ker) v
62 ones = konst 1 (dim ker)
63
64
65
66matSS :: Element t => Int -> Matrix t -> [Matrix t]
67matSS dr m = map (reshape c) [ subVector (k*c) n v | k <- [0 .. r - dr] ]
68 where
69 v = flatten m
70 c = cols m
71 r = rows m
72 n = dr*c
73
74
75corr2 :: Product a => Matrix a -> Matrix a -> Matrix a
76-- ^ 2D correlation
77corr2 ker mat = dims
78 . concatMap (map (udot ker' . flatten) . matSS c . trans)
79 . matSS r $ mat
80 where
81 r = rows ker
82 c = cols ker
83 ker' = flatten (trans ker)
84 rr = rows mat - r + 1
85 rc = cols mat - c + 1
86 dims | rr > 0 && rc > 0 = (rr >< rc)
87 | otherwise = error $ "corr2: dim kernel ("++sz ker++") > dim matrix ("++sz mat++")"
88 sz m = show (rows m)++"x"++show (cols m)
89
90conv2 :: (Num a, Product a, Container Vector a) => Matrix a -> Matrix a -> Matrix a
91-- ^ 2D convolution
92conv2 k m = corr2 (fliprl . flipud $ k) pm
93 where
94 pm | r == 0 && c == 0 = m
95 | r == 0 = fromBlocks [[z3,m,z3]]
96 | c == 0 = fromBlocks [[z2],[m],[z2]]
97 | otherwise = fromBlocks [[z1,z2,z1]
98 ,[z3, m,z3]
99 ,[z1,z2,z1]]
100 r = rows k - 1
101 c = cols k - 1
102 h = rows m
103 w = cols m
104 z1 = konst 0 (r,c)
105 z2 = konst 0 (r,w)
106 z3 = konst 0 (h,c)
107
108-- TODO: could be simplified using future empty arrays
109
110
111separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t
112-- ^ matrix computation implemented as separated vector operations by rows and columns.
113separable f = fromColumns . map f . toColumns . fromRows . map f . toRows
114
diff --git a/packages/hmatrix/src/Numeric/Matrix.hs b/packages/hmatrix/src/Numeric/Matrix.hs
new file mode 100644
index 0000000..e285ff2
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/Matrix.hs
@@ -0,0 +1,98 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6
7-----------------------------------------------------------------------------
8-- |
9-- Module : Numeric.Matrix
10-- Copyright : (c) Alberto Ruiz 2010
11-- License : GPL-style
12--
13-- Maintainer : Alberto Ruiz <aruiz@um.es>
14-- Stability : provisional
15-- Portability : portable
16--
17-- Provides instances of standard classes 'Show', 'Read', 'Eq',
18-- 'Num', 'Fractional', and 'Floating' for 'Matrix'.
19--
20-- In arithmetic operations one-component
21-- vectors and matrices automatically expand to match the dimensions of the other operand.
22
23-----------------------------------------------------------------------------
24
25module Numeric.Matrix (
26 ) where
27
28-------------------------------------------------------------------
29
30import Numeric.Container
31import qualified Data.Monoid as M
32import Data.List(partition)
33
34-------------------------------------------------------------------
35
36instance Container Matrix a => Eq (Matrix a) where
37 (==) = equal
38
39instance (Container Matrix a, Num (Vector a)) => Num (Matrix a) where
40 (+) = liftMatrix2Auto (+)
41 (-) = liftMatrix2Auto (-)
42 negate = liftMatrix negate
43 (*) = liftMatrix2Auto (*)
44 signum = liftMatrix signum
45 abs = liftMatrix abs
46 fromInteger = (1><1) . return . fromInteger
47
48---------------------------------------------------
49
50instance (Container Vector a, Fractional (Vector a), Num (Matrix a)) => Fractional (Matrix a) where
51 fromRational n = (1><1) [fromRational n]
52 (/) = liftMatrix2Auto (/)
53
54---------------------------------------------------------
55
56instance (Floating a, Container Vector a, Floating (Vector a), Fractional (Matrix a)) => Floating (Matrix a) where
57 sin = liftMatrix sin
58 cos = liftMatrix cos
59 tan = liftMatrix tan
60 asin = liftMatrix asin
61 acos = liftMatrix acos
62 atan = liftMatrix atan
63 sinh = liftMatrix sinh
64 cosh = liftMatrix cosh
65 tanh = liftMatrix tanh
66 asinh = liftMatrix asinh
67 acosh = liftMatrix acosh
68 atanh = liftMatrix atanh
69 exp = liftMatrix exp
70 log = liftMatrix log
71 (**) = liftMatrix2Auto (**)
72 sqrt = liftMatrix sqrt
73 pi = (1><1) [pi]
74
75--------------------------------------------------------------------------------
76
77isScalar m = rows m == 1 && cols m == 1
78
79adaptScalarM f1 f2 f3 x y
80 | isScalar x = f1 (x @@>(0,0) ) y
81 | isScalar y = f3 x (y @@>(0,0) )
82 | otherwise = f2 x y
83
84instance (Container Vector t, Eq t, Num (Vector t), Product t) => M.Monoid (Matrix t)
85 where
86 mempty = 1
87 mappend = adaptScalarM scale mXm (flip scale)
88
89 mconcat xs = work (partition isScalar xs)
90 where
91 work (ss,[]) = product ss
92 work (ss,ms) = scale' (product ss) (optimiseMult ms)
93 scale' x m
94 | isScalar x && x00 == 1 = m
95 | otherwise = scale x00 m
96 where
97 x00 = x @@> (0,0)
98
diff --git a/packages/hmatrix/src/Numeric/Vector.hs b/packages/hmatrix/src/Numeric/Vector.hs
new file mode 100644
index 0000000..3f480a0
--- /dev/null
+++ b/packages/hmatrix/src/Numeric/Vector.hs
@@ -0,0 +1,158 @@
1{-# LANGUAGE TypeFamilies #-}
2{-# LANGUAGE FlexibleContexts #-}
3{-# LANGUAGE FlexibleInstances #-}
4{-# LANGUAGE UndecidableInstances #-}
5{-# LANGUAGE MultiParamTypeClasses #-}
6-----------------------------------------------------------------------------
7-- |
8-- Module : Numeric.Vector
9-- Copyright : (c) Alberto Ruiz 2011
10-- License : GPL-style
11--
12-- Maintainer : Alberto Ruiz <aruiz@um.es>
13-- Stability : provisional
14-- Portability : portable
15--
16-- Provides instances of standard classes 'Show', 'Read', 'Eq',
17-- 'Num', 'Fractional', and 'Floating' for 'Vector'.
18--
19-----------------------------------------------------------------------------
20
21module Numeric.Vector () where
22
23import Numeric.GSL.Vector
24import Numeric.Container
25
26-------------------------------------------------------------------
27
28adaptScalar f1 f2 f3 x y
29 | dim x == 1 = f1 (x@>0) y
30 | dim y == 1 = f3 x (y@>0)
31 | otherwise = f2 x y
32
33------------------------------------------------------------------
34
35instance Num (Vector Float) where
36 (+) = adaptScalar addConstant add (flip addConstant)
37 negate = scale (-1)
38 (*) = adaptScalar scale mul (flip scale)
39 signum = vectorMapF Sign
40 abs = vectorMapF Abs
41 fromInteger = fromList . return . fromInteger
42
43instance Num (Vector Double) where
44 (+) = adaptScalar addConstant add (flip addConstant)
45 negate = scale (-1)
46 (*) = adaptScalar scale mul (flip scale)
47 signum = vectorMapR Sign
48 abs = vectorMapR Abs
49 fromInteger = fromList . return . fromInteger
50
51instance Num (Vector (Complex Double)) where
52 (+) = adaptScalar addConstant add (flip addConstant)
53 negate = scale (-1)
54 (*) = adaptScalar scale mul (flip scale)
55 signum = vectorMapC Sign
56 abs = vectorMapC Abs
57 fromInteger = fromList . return . fromInteger
58
59instance Num (Vector (Complex Float)) where
60 (+) = adaptScalar addConstant add (flip addConstant)
61 negate = scale (-1)
62 (*) = adaptScalar scale mul (flip scale)
63 signum = vectorMapQ Sign
64 abs = vectorMapQ Abs
65 fromInteger = fromList . return . fromInteger
66
67---------------------------------------------------
68
69instance (Container Vector a, Num (Vector a)) => Fractional (Vector a) where
70 fromRational n = fromList [fromRational n]
71 (/) = adaptScalar f divide g where
72 r `f` v = scaleRecip r v
73 v `g` r = scale (recip r) v
74
75-------------------------------------------------------
76
77instance Floating (Vector Float) where
78 sin = vectorMapF Sin
79 cos = vectorMapF Cos
80 tan = vectorMapF Tan
81 asin = vectorMapF ASin
82 acos = vectorMapF ACos
83 atan = vectorMapF ATan
84 sinh = vectorMapF Sinh
85 cosh = vectorMapF Cosh
86 tanh = vectorMapF Tanh
87 asinh = vectorMapF ASinh
88 acosh = vectorMapF ACosh
89 atanh = vectorMapF ATanh
90 exp = vectorMapF Exp
91 log = vectorMapF Log
92 sqrt = vectorMapF Sqrt
93 (**) = adaptScalar (vectorMapValF PowSV) (vectorZipF Pow) (flip (vectorMapValF PowVS))
94 pi = fromList [pi]
95
96-------------------------------------------------------------
97
98instance Floating (Vector Double) where
99 sin = vectorMapR Sin
100 cos = vectorMapR Cos
101 tan = vectorMapR Tan
102 asin = vectorMapR ASin
103 acos = vectorMapR ACos
104 atan = vectorMapR ATan
105 sinh = vectorMapR Sinh
106 cosh = vectorMapR Cosh
107 tanh = vectorMapR Tanh
108 asinh = vectorMapR ASinh
109 acosh = vectorMapR ACosh
110 atanh = vectorMapR ATanh
111 exp = vectorMapR Exp
112 log = vectorMapR Log
113 sqrt = vectorMapR Sqrt
114 (**) = adaptScalar (vectorMapValR PowSV) (vectorZipR Pow) (flip (vectorMapValR PowVS))
115 pi = fromList [pi]
116
117-------------------------------------------------------------
118
119instance Floating (Vector (Complex Double)) where
120 sin = vectorMapC Sin
121 cos = vectorMapC Cos
122 tan = vectorMapC Tan
123 asin = vectorMapC ASin
124 acos = vectorMapC ACos
125 atan = vectorMapC ATan
126 sinh = vectorMapC Sinh
127 cosh = vectorMapC Cosh
128 tanh = vectorMapC Tanh
129 asinh = vectorMapC ASinh
130 acosh = vectorMapC ACosh
131 atanh = vectorMapC ATanh
132 exp = vectorMapC Exp
133 log = vectorMapC Log
134 sqrt = vectorMapC Sqrt
135 (**) = adaptScalar (vectorMapValC PowSV) (vectorZipC Pow) (flip (vectorMapValC PowVS))
136 pi = fromList [pi]
137
138-----------------------------------------------------------
139
140instance Floating (Vector (Complex Float)) where
141 sin = vectorMapQ Sin
142 cos = vectorMapQ Cos
143 tan = vectorMapQ Tan
144 asin = vectorMapQ ASin
145 acos = vectorMapQ ACos
146 atan = vectorMapQ ATan
147 sinh = vectorMapQ Sinh
148 cosh = vectorMapQ Cosh
149 tanh = vectorMapQ Tanh
150 asinh = vectorMapQ ASinh
151 acosh = vectorMapQ ACosh
152 atanh = vectorMapQ ATanh
153 exp = vectorMapQ Exp
154 log = vectorMapQ Log
155 sqrt = vectorMapQ Sqrt
156 (**) = adaptScalar (vectorMapValQ PowSV) (vectorZipQ Pow) (flip (vectorMapValQ PowVS))
157 pi = fromList [pi]
158