summaryrefslogtreecommitdiff
path: root/examples/devel/wrappers.hs
blob: f9e258a0a8892fc7c17634636babc9e300d58ba9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
{-# LANGUAGE ForeignFunctionInterface #-}

-- $ ghc -O2 --make wrappers.hs functions.c

import Numeric.LinearAlgebra
import Data.Packed.Development
import Foreign(Ptr,unsafePerformIO)
import Foreign.C.Types(CInt)

-----------------------------------------------------

-- | Exercise both wrappers: scale the vector [1..10] by 3 with 'myScale',
-- then run 'myDiag' on a 3x5 matrix filled with 1, 2, 3, ...,
-- printing each result.
main :: IO ()
main = do
    print $ myScale 3.0 (fromList [1..10])
    print $ myDiag $ (3><5) [1..]

-----------------------------------------------------

-- | Raw binding to the C routine @c_scale_vector@ (built from functions.c,
-- see the compile line at the top of this file). Wrapped by 'myScale'.
foreign import ccall safe "c_scale_vector"
    cScaleVector
        :: Double                 -- ^ scale
        -> CInt -> Ptr Double     -- ^ argument vector (size + data, as supplied by 'vec')
        -> CInt -> Ptr Double     -- ^ result vector (size + data, as supplied by 'vec')
        -> IO CInt                -- ^ exit code

-- | Scale the vector @x@ by @s@ using the C routine @c_scale_vector@.
--
-- A result vector of the same length as @x@ is allocated with
-- 'createVector' and filled by the foreign call; 'app2' hands both vectors
-- to C through the 'vec' adaptor.
--
-- 'unsafePerformIO' is used so the wrapper presents a pure interface. This
-- is sound only because the result buffer @y@ is freshly allocated and is
-- not visible to anyone until the foreign call has returned.
-- NOTE(review): also assumes c_scale_vector never writes to its input @x@
-- and has no other side effects -- confirm against functions.c.
myScale :: Double -> Vector Double -> Vector Double
myScale s x = unsafePerformIO $ do
    y <- createVector (dim x)
    app2 (cScaleVector s) vec x vec y "cScaleVector"
    return y

-----------------------------------------------------

-- | Raw binding to the C routine @c_diag@ (built from functions.c, see the
-- compile line at the top of this file). Wrapped by 'myDiag'.
foreign import ccall safe "c_diag"
    cDiag
        :: CInt                          -- ^ matrix order flag ('myDiag' passes 1 for RowMajor, else 0)
        -> CInt -> CInt -> Ptr Double    -- ^ argument matrix (two dimensions + data, as supplied by 'mat')
        -> CInt -> Ptr Double            -- ^ result1: vector (size + data, as supplied by 'vec')
        -> CInt -> CInt -> Ptr Double    -- ^ result2: matrix (two dimensions + data, as supplied by 'mat')
        -> IO CInt                       -- ^ exit code

-- | Run the C routine @c_diag@ on the matrix @m@ and return both of its
-- outputs:
--
-- * a vector of length @min (rows m) (cols m)@, and
-- * a matrix with the same dimensions and storage order as @m@.
--
-- The C side is told how @m@ is stored through an integer flag:
-- 1 for 'RowMajor', 0 for any other order.
--
-- 'unsafePerformIO' gives the wrapper a pure interface. This is sound only
-- because @y@ and @z@ are freshly allocated and are not exposed until the
-- foreign call has returned.
-- NOTE(review): also assumes c_diag never writes to its input @m@ -- confirm
-- against functions.c.
myDiag :: Matrix Double -> (Vector Double, Matrix Double)
myDiag m = unsafePerformIO $ do
    y <- createVector (min r c)
    z <- createMatrix (orderOf m) r c
    app3 (cDiag o) mat m vec y mat z "cDiag"
    return (y,z)
  where r = rows m
        c = cols m
        -- storage-order flag expected by c_diag
        o = case orderOf m of
              RowMajor -> 1
              _        -> 0