diff options
Diffstat (limited to 'packages/base/src/Numeric/LinearAlgebra/Static.hs')
-rw-r--r-- | packages/base/src/Numeric/LinearAlgebra/Static.hs | 193 |
1 files changed, 193 insertions, 0 deletions
diff --git a/packages/base/src/Numeric/LinearAlgebra/Static.hs b/packages/base/src/Numeric/LinearAlgebra/Static.hs new file mode 100644 index 0000000..f9e935d --- /dev/null +++ b/packages/base/src/Numeric/LinearAlgebra/Static.hs | |||
@@ -0,0 +1,193 @@ | |||
1 | {-# LANGUAGE DataKinds #-} | ||
2 | {-# LANGUAGE KindSignatures #-} | ||
3 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
4 | {-# LANGUAGE MultiParamTypeClasses #-} | ||
5 | {-# LANGUAGE FunctionalDependencies #-} | ||
6 | {-# LANGUAGE FlexibleContexts #-} | ||
7 | {-# LANGUAGE ScopedTypeVariables #-} | ||
8 | {-# LANGUAGE EmptyDataDecls #-} | ||
9 | {-# LANGUAGE Rank2Types #-} | ||
10 | {-# LANGUAGE FlexibleInstances #-} | ||
11 | {-# LANGUAGE TypeOperators #-} | ||
12 | {-# LANGUAGE ViewPatterns #-} | ||
13 | {-# LANGUAGE GADTs #-} | ||
14 | |||
15 | |||
16 | {- | | ||
17 | Module : Numeric.LinearAlgebra.Static | ||
18 | Copyright : (c) Alberto Ruiz 2006-14 | ||
19 | License : BSD3 | ||
20 | Stability : provisional | ||
21 | |||
22 | -} | ||
23 | |||
24 | module Numeric.LinearAlgebra.Static( | ||
25 | Dim(..), | ||
26 | R(..), C(..), | ||
27 | lift1F, lift2F, | ||
28 | vconcat, gvec2, gvec3, gvec4, gvect, gmat, | ||
29 | Sized(..), | ||
30 | singleV, singleM | ||
31 | ) where | ||
32 | |||
33 | |||
34 | import GHC.TypeLits | ||
35 | import Numeric.HMatrix as LA | ||
36 | import Data.Packed as D | ||
37 | import Data.Packed.ST | ||
38 | import Data.Proxy(Proxy) | ||
39 | import Foreign.Storable(Storable) | ||
40 | |||
41 | |||
42 | |||
43 | newtype R n = R (Dim n (Vector ℝ)) | ||
44 | deriving (Num,Fractional) | ||
45 | |||
46 | |||
47 | newtype C n = C (Dim n (Vector ℂ)) | ||
48 | deriving (Num,Fractional) | ||
49 | |||
50 | |||
51 | |||
52 | newtype Dim (n :: Nat) t = Dim t | ||
53 | deriving Show | ||
54 | |||
55 | lift1F | ||
56 | :: (c t -> c t) | ||
57 | -> Dim n (c t) -> Dim n (c t) | ||
58 | lift1F f (Dim v) = Dim (f v) | ||
59 | |||
60 | lift2F | ||
61 | :: (c t -> c t -> c t) | ||
62 | -> Dim n (c t) -> Dim n (c t) -> Dim n (c t) | ||
63 | lift2F f (Dim u) (Dim v) = Dim (f u v) | ||
64 | |||
65 | -------------------------------------------------------------------------------- | ||
66 | |||
67 | instance forall n t . (Num (Vector t), Numeric t )=> Num (Dim n (Vector t)) | ||
68 | where | ||
69 | (+) = lift2F (+) | ||
70 | (*) = lift2F (*) | ||
71 | (-) = lift2F (-) | ||
72 | abs = lift1F abs | ||
73 | signum = lift1F signum | ||
74 | negate = lift1F negate | ||
75 | fromInteger x = Dim (fromInteger x) | ||
76 | |||
77 | instance (Num (Matrix t), Numeric t) => Num (Dim m (Dim n (Matrix t))) | ||
78 | where | ||
79 | (+) = (lift2F . lift2F) (+) | ||
80 | (*) = (lift2F . lift2F) (*) | ||
81 | (-) = (lift2F . lift2F) (-) | ||
82 | abs = (lift1F . lift1F) abs | ||
83 | signum = (lift1F . lift1F) signum | ||
84 | negate = (lift1F . lift1F) negate | ||
85 | fromInteger x = Dim (Dim (fromInteger x)) | ||
86 | |||
87 | instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim n (Vector t)) | ||
88 | where | ||
89 | fromRational x = Dim (fromRational x) | ||
90 | (/) = lift2F (/) | ||
91 | |||
92 | instance (Num (Vector t), Num (Matrix t), Numeric t) => Fractional (Dim m (Dim n (Matrix t))) | ||
93 | where | ||
94 | fromRational x = Dim (Dim (fromRational x)) | ||
95 | (/) = (lift2F.lift2F) (/) | ||
96 | |||
97 | -------------------------------------------------------------------------------- | ||
98 | |||
99 | type V n t = Dim n (Vector t) | ||
100 | |||
101 | ud :: Dim n (Vector t) -> Vector t | ||
102 | ud (Dim v) = v | ||
103 | |||
104 | mkV :: forall (n :: Nat) t . t -> Dim n t | ||
105 | mkV = Dim | ||
106 | |||
107 | type M m n t = Dim m (Dim n (Matrix t)) | ||
108 | |||
109 | ud2 :: Dim m (Dim n (Matrix t)) -> Matrix t | ||
110 | ud2 (Dim (Dim m)) = m | ||
111 | |||
112 | mkM :: forall (m :: Nat) (n :: Nat) t . t -> Dim m (Dim n t) | ||
113 | mkM = Dim . Dim | ||
114 | |||
115 | |||
116 | vconcat :: forall n m t . (KnownNat n, KnownNat m, Numeric t) | ||
117 | => V n t -> V m t -> V (n+m) t | ||
118 | (ud -> u) `vconcat` (ud -> v) = mkV (vjoin [u', v']) | ||
119 | where | ||
120 | du = fromIntegral . natVal $ (undefined :: Proxy n) | ||
121 | dv = fromIntegral . natVal $ (undefined :: Proxy m) | ||
122 | u' | du > 1 && size u == 1 = LA.konst (u D.@> 0) du | ||
123 | | otherwise = u | ||
124 | v' | dv > 1 && size v == 1 = LA.konst (v D.@> 0) dv | ||
125 | | otherwise = v | ||
126 | |||
127 | |||
128 | gvec2 :: Storable t => t -> t -> V 2 t | ||
129 | gvec2 a b = mkV $ runSTVector $ do | ||
130 | v <- newUndefinedVector 2 | ||
131 | writeVector v 0 a | ||
132 | writeVector v 1 b | ||
133 | return v | ||
134 | |||
135 | gvec3 :: Storable t => t -> t -> t -> V 3 t | ||
136 | gvec3 a b c = mkV $ runSTVector $ do | ||
137 | v <- newUndefinedVector 3 | ||
138 | writeVector v 0 a | ||
139 | writeVector v 1 b | ||
140 | writeVector v 2 c | ||
141 | return v | ||
142 | |||
143 | |||
144 | gvec4 :: Storable t => t -> t -> t -> t -> V 4 t | ||
145 | gvec4 a b c d = mkV $ runSTVector $ do | ||
146 | v <- newUndefinedVector 4 | ||
147 | writeVector v 0 a | ||
148 | writeVector v 1 b | ||
149 | writeVector v 2 c | ||
150 | writeVector v 3 d | ||
151 | return v | ||
152 | |||
153 | |||
154 | gvect :: forall n t . (Show t, KnownNat n, Numeric t) => String -> [t] -> V n t | ||
155 | gvect st xs' | ||
156 | | ok = mkV v | ||
157 | | not (null rest) && null (tail rest) = abort (show xs') | ||
158 | | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") | ||
159 | | otherwise = abort (show xs) | ||
160 | where | ||
161 | (xs,rest) = splitAt d xs' | ||
162 | ok = size v == d && null rest | ||
163 | v = LA.fromList xs | ||
164 | d = fromIntegral . natVal $ (undefined :: Proxy n) | ||
165 | abort info = error $ st++" "++show d++" can't be created from elements "++info | ||
166 | |||
167 | |||
168 | gmat :: forall m n t . (Show t, KnownNat m, KnownNat n, Numeric t) => String -> [t] -> M m n t | ||
169 | gmat st xs' | ||
170 | | ok = mkM x | ||
171 | | not (null rest) && null (tail rest) = abort (show xs') | ||
172 | | not (null rest) = abort (init (show (xs++take 1 rest))++", ... ]") | ||
173 | | otherwise = abort (show xs) | ||
174 | where | ||
175 | (xs,rest) = splitAt (m'*n') xs' | ||
176 | v = LA.fromList xs | ||
177 | x = reshape n' v | ||
178 | ok = rem (size v) n' == 0 && size x == (m',n') && null rest | ||
179 | m' = fromIntegral . natVal $ (undefined :: Proxy m) :: Int | ||
180 | n' = fromIntegral . natVal $ (undefined :: Proxy n) :: Int | ||
181 | abort info = error $ st ++" "++show m' ++ " " ++ show n'++" can't be created from elements " ++ info | ||
182 | |||
183 | |||
184 | class Num t => Sized t s d | s -> t, s -> d | ||
185 | where | ||
186 | konst :: t -> s | ||
187 | extract :: s -> d | ||
188 | fromList :: [t] -> s | ||
189 | expand :: s -> d | ||
190 | |||
191 | singleV v = size v == 1 | ||
192 | singleM m = rows m == 1 && cols m == 1 | ||
193 | |||