summaryrefslogtreecommitdiff
path: root/packages/base/src/Numeric/LinearAlgebra/Static.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static.hs')
-rw-r--r--packages/base/src/Numeric/LinearAlgebra/Static.hs193
1 files changed, 193 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs
new file mode 100644
index 0000000..f9e935d
--- /dev/null
+++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs
@@ -0,0 +1,193 @@
1{-# LANGUAGE DataKinds #-}
2{-# LANGUAGE KindSignatures #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE MultiParamTypeClasses #-}
5{-# LANGUAGE FunctionalDependencies #-}
6{-# LANGUAGE FlexibleContexts #-}
7{-# LANGUAGE ScopedTypeVariables #-}
8{-# LANGUAGE EmptyDataDecls #-}
9{-# LANGUAGE Rank2Types #-}
10{-# LANGUAGE FlexibleInstances #-}
11{-# LANGUAGE TypeOperators #-}
12{-# LANGUAGE ViewPatterns #-}
13{-# LANGUAGE GADTs #-}
14
15
16{- |
17Module : Numeric.LinearAlgebra.Static
18Copyright : (c) Alberto Ruiz 2006-14
19License : BSD3
20Stability : provisional
21
22-}
23
24module Numeric.LinearAlgebra.Static(
25 Dim(..),
26 R(..), C(..),
27 lift1F, lift2F,
28 vconcat, gvec2, gvec3, gvec4, gvect, gmat,
29 Sized(..),
30 singleV, singleM
31) where
32
33
34import GHC.TypeLits
35import Numeric.HMatrix as LA
36import Data.Packed as D
37import Data.Packed.ST
38import Data.Proxy(Proxy)
39import Foreign.Storable(Storable)
40
41
42
43newtype R n = R (Dim n (Vector ℝ))
44 deriving (Num,Fractional)
45
46
47newtype C n = C (Dim n (Vector ℂ))
48 deriving (Num,Fractional)
49
50
51
52newtype Dim (n :: Nat) t = Dim t
53 deriving Show
54
55lift1F
56 :: (c t -> c t)
57 -> Dim n (c t) -> Dim n (c t)
58lift1F f (Dim v) = Dim (f v)
59
60lift2F
61 :: (c t -> c t -> c t)
62 -> Dim n (c t) -> Dim n (c t) -> Dim n (c t)
63lift2F f (Dim u) (Dim v) = Dim (f u v)
64
65--------------------------------------------------------------------------------
66
67instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t))
68 where
69 (+) = lift2F (+)
70 (*) = lift2F (*)
71 (-) = lift2F (-)
72 abs = lift1F abs
73 signum = lift1F signum
74 negate = lift1F negate
75 fromInteger x = Dim (fromInteger x)
76
77instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t)))
78 where
79 (+) = (lift2F . lift2F) (+)
80 (*) = (lift2F . lift2F) (*)
81 (-) = (lift2F . lift2F) (-)
82 abs = (lift1F . lift1F) abs
83 signum = (lift1F . lift1F) signum
84 negate = (lift1F . lift1F) negate
85 fromInteger x = Dim (Dim (fromInteger x))
86
87instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t))
88 where
89 fromRational x = Dim (fromRational x)
90 (/) = lift2F (/)
91
92instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t)))
93 where
94 fromRational x = Dim (Dim (fromRational x))
95 (/) = (lift2F.lift2F) (/)
96
97--------------------------------------------------------------------------------
98
99type V n t = Dim n (Vector t)
100
101ud :: Dim n (Vector t) -> Vector t
102ud (Dim v) = v
103
104mkV :: forall (n :: Nat) t . t -> Dim n t
105mkV = Dim
106
107type M m n t = Dim m (Dim n (Matrix t))
108
109ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t
110ud2 (Dim (Dim m)) = m
111
112mkM :: forall (m :: Nat) (n :: Nat) t . t -> Dim m (Dim n t)
113mkM = Dim . Dim
114
115
116vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t)
117 => V n t -> V m t -> V (n+m) t
118(ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v'])
119 where
120 du = fromIntegral . natVal $ (undefined :: Proxy n)
121 dv = fromIntegral . natVal $ (undefined :: Proxy m)
122 u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du
123 | otherwise = u
124 v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv
125 | otherwise = v
126
127
128gvec2 :: Storable t => t -> t -> V 2 t
129gvec2 a b = mkV $ runSTVector $ do
130 v <- newUndefinedVector 2
131 writeVector v 0 a
132 writeVector v 1 b
133 return v
134
135gvec3 :: Storable t => t -> t -> t -> V 3 t
136gvec3 a b c = mkV $ runSTVector $ do
137 v <- newUndefinedVector 3
138 writeVector v 0 a
139 writeVector v 1 b
140 writeVector v 2 c
141 return v
142
143
144gvec4 :: Storable t => t -> t -> t -> t -> V 4 t
145gvec4 a b c d = mkV $ runSTVector $ do
146 v <- newUndefinedVector 4
147 writeVector v 0 a
148 writeVector v 1 b
149 writeVector v 2 c
150 writeVector v 3 d
151 return v
152
153
154gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t
155gvect st xs'
156 | ok = mkV v
157 | not (null rest) && null (tail rest) = abort (show xs')
158 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
159 | otherwise = abort (show xs)
160 where
161 (xs,rest) = splitAt d xs'
162 ok = size v == d && null rest
163 v = LA.fromList xs
164 d = fromIntegral . natVal $ (undefined :: Proxy n)
165 abort info = error $ st++" "++show d++" can't be created from elements "++info
166
167
168gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> M m n t
169gmat st xs'
170 | ok = mkM x
171 | not (null rest) && null (tail rest) = abort (show xs')
172 | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]")
173 | otherwise = abort (show xs)
174 where
175 (xs,rest) = splitAt (m'*n') xs'
176 v = LA.fromList xs
177 x = reshape n' v
178 ok = rem (size v) n' == 0 && size x == (m',n') && null rest
179 m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int
180 n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int
181 abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info
182
183
184class Num t => Sized t s d | s -> t, s -> d
185 where
186 konst :: t -> s
187 extract :: s -> d
188 fromList :: [t] -> s
189 expand :: s -> d
190
191singleV v = size v == 1
192singleM m = rows m == 1 && cols m == 1
193