diff options
author | Alberto Ruiz <aruiz@um.es> | 2015-06-17 13:02:40 +0200 |
---|---|---|
committer | Alberto Ruiz <aruiz@um.es> | 2015-06-17 13:02:40 +0200 |
commit | 61d90ff66af8bfe53ef8cdda8dfe1e70463c213c (patch) | |
tree | 79ca6e2024731708a4378ff691b5317d8da11d11 /packages/base/src/Internal/Matrix.hs | |
parent | e7d2916f78b5c140738fc4f4f95c9b13c1768293 (diff) |
gemmm
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 75e92a5..8f8c219 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -274,6 +274,7 @@ class (Storable a) => Element a where | |||
274 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a | 274 | selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a |
275 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a | 275 | remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a |
276 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () | 276 | rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO () |
277 | gemm :: Vector a -> Vector I -> Matrix a -> Matrix a -> Matrix a -> IO () | ||
277 | 278 | ||
278 | 279 | ||
279 | instance Element Float where | 280 | instance Element Float where |
@@ -287,6 +288,7 @@ instance Element Float where | |||
287 | selectV = selectF | 288 | selectV = selectF |
288 | remapM = remapF | 289 | remapM = remapF |
289 | rowOp = rowOpAux c_rowOpF | 290 | rowOp = rowOpAux c_rowOpF |
291 | gemm = gemmg c_gemmF | ||
290 | 292 | ||
291 | instance Element Double where | 293 | instance Element Double where |
292 | transdata = transdataAux ctransR | 294 | transdata = transdataAux ctransR |
@@ -299,7 +301,7 @@ instance Element Double where | |||
299 | selectV = selectD | 301 | selectV = selectD |
300 | remapM = remapD | 302 | remapM = remapD |
301 | rowOp = rowOpAux c_rowOpD | 303 | rowOp = rowOpAux c_rowOpD |
302 | 304 | gemm = gemmg c_gemmD | |
303 | 305 | ||
304 | instance Element (Complex Float) where | 306 | instance Element (Complex Float) where |
305 | transdata = transdataAux ctransQ | 307 | transdata = transdataAux ctransQ |
@@ -312,7 +314,7 @@ instance Element (Complex Float) where | |||
312 | selectV = selectQ | 314 | selectV = selectQ |
313 | remapM = remapQ | 315 | remapM = remapQ |
314 | rowOp = rowOpAux c_rowOpQ | 316 | rowOp = rowOpAux c_rowOpQ |
315 | 317 | gemm = gemmg c_gemmQ | |
316 | 318 | ||
317 | instance Element (Complex Double) where | 319 | instance Element (Complex Double) where |
318 | transdata = transdataAux ctransC | 320 | transdata = transdataAux ctransC |
@@ -325,6 +327,7 @@ instance Element (Complex Double) where | |||
325 | selectV = selectC | 327 | selectV = selectC |
326 | remapM = remapC | 328 | remapM = remapC |
327 | rowOp = rowOpAux c_rowOpC | 329 | rowOp = rowOpAux c_rowOpC |
330 | gemm = gemmg c_gemmC | ||
328 | 331 | ||
329 | instance Element (CInt) where | 332 | instance Element (CInt) where |
330 | transdata = transdataAux ctransI | 333 | transdata = transdataAux ctransI |
@@ -337,6 +340,7 @@ instance Element (CInt) where | |||
337 | selectV = selectI | 340 | selectV = selectI |
338 | remapM = remapI | 341 | remapM = remapI |
339 | rowOp = rowOpAux c_rowOpI | 342 | rowOp = rowOpAux c_rowOpI |
343 | gemm = gemmg c_gemmI | ||
340 | 344 | ||
341 | instance Element Z where | 345 | instance Element Z where |
342 | transdata = transdataAux ctransL | 346 | transdata = transdataAux ctransL |
@@ -349,6 +353,7 @@ instance Element Z where | |||
349 | selectV = selectL | 353 | selectV = selectL |
350 | remapM = remapL | 354 | remapM = remapL |
351 | rowOp = rowOpAux c_rowOpL | 355 | rowOp = rowOpAux c_rowOpL |
356 | gemm = gemmg c_gemmL | ||
352 | 357 | ||
353 | ------------------------------------------------------------------- | 358 | ------------------------------------------------------------------- |
354 | 359 | ||
@@ -575,6 +580,21 @@ foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z | |||
575 | 580 | ||
576 | -------------------------------------------------------------------------------- | 581 | -------------------------------------------------------------------------------- |
577 | 582 | ||
583 | gemmg f u v m1 m2 m3 = app5 f vec u vec v omat m1 omat m2 omat m3 "gemmg" | ||
584 | |||
585 | type Tgemm x = x :> I :> x ::> x ::> x ::> Ok | ||
586 | |||
587 | foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R | ||
588 | foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float | ||
589 | foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C | ||
590 | foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float) | ||
591 | foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I | ||
592 | foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z | ||
593 | foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I | ||
594 | foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z | ||
595 | |||
596 | -------------------------------------------------------------------------------- | ||
597 | |||
578 | foreign import ccall unsafe "saveMatrix" c_saveMatrix | 598 | foreign import ccall unsafe "saveMatrix" c_saveMatrix |
579 | :: CString -> CString -> Double ..> Ok | 599 | :: CString -> CString -> Double ..> Ok |
580 | 600 | ||