diff options
Diffstat (limited to 'packages')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 41 | ||||
-rw-r--r-- | packages/base/src/Internal/Specialized.hs | 24 |
2 files changed, 34 insertions, 31 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index 225b039..05633de 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -7,6 +7,7 @@ | |||
7 | {-# LANGUAGE ViewPatterns #-} | 7 | {-# LANGUAGE ViewPatterns #-} |
8 | {-# LANGUAGE DeriveGeneric #-} | 8 | {-# LANGUAGE DeriveGeneric #-} |
9 | {-# LANGUAGE ConstrainedClassMethods #-} | 9 | {-# LANGUAGE ConstrainedClassMethods #-} |
10 | {-# LANGUAGE ConstraintKinds #-} | ||
10 | 11 | ||
11 | -- | | 12 | -- | |
12 | -- Module : Internal.Matrix | 13 | -- Module : Internal.Matrix |
@@ -29,6 +30,7 @@ import Foreign.Marshal.Array(newArray) | |||
29 | import Foreign.Ptr ( Ptr ) | 30 | import Foreign.Ptr ( Ptr ) |
30 | import Foreign.Storable ( Storable ) | 31 | import Foreign.Storable ( Storable ) |
31 | import Data.Complex ( Complex ) | 32 | import Data.Complex ( Complex ) |
33 | import Data.Typeable | ||
32 | import Foreign.C.Types ( CInt(..) ) | 34 | import Foreign.C.Types ( CInt(..) ) |
33 | import Foreign.C.String ( CString, newCString ) | 35 | import Foreign.C.String ( CString, newCString ) |
34 | import System.IO.Unsafe ( unsafePerformIO ) | 36 | import System.IO.Unsafe ( unsafePerformIO ) |
@@ -77,13 +79,13 @@ trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } = | |||
77 | m { irows = c, icols = r, xRow = xc, xCol = xr } | 79 | m { irows = c, icols = r, xRow = xc, xCol = xr } |
78 | 80 | ||
79 | 81 | ||
80 | cmat :: (Element t) => Matrix t -> Matrix t | 82 | cmat :: (Typeable t) => Matrix t -> Matrix t |
81 | cmat m | 83 | cmat m |
82 | | rowOrder m = m | 84 | | rowOrder m = m |
83 | | otherwise = extractAll RowMajor m | 85 | | otherwise = extractAll RowMajor m |
84 | 86 | ||
85 | 87 | ||
86 | fmat :: (Element t) => Matrix t -> Matrix t | 88 | fmat :: (Typeable t) => Matrix t -> Matrix t |
87 | fmat m | 89 | fmat m |
88 | | colOrder m = m | 90 | | colOrder m = m |
89 | | otherwise = extractAll ColumnMajor m | 91 | | otherwise = extractAll ColumnMajor m |
@@ -100,17 +102,19 @@ a #! b = a # b # id | |||
100 | 102 | ||
101 | -------------------------------------------------------------------------------- | 103 | -------------------------------------------------------------------------------- |
102 | 104 | ||
103 | copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) | 105 | copy :: Typeable t => MatrixOrder -> Matrix t -> IO (Matrix t) |
104 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) | 106 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) |
105 | 107 | ||
106 | extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t | 108 | extractAll :: Typeable t => MatrixOrder -> Matrix t -> Matrix t |
107 | extractAll ord m = unsafePerformIO (copy ord m) | 109 | extractAll ord m = unsafePerformIO (copy ord m) |
108 | 110 | ||
111 | type Element t = (Typeable t, Storable t) | ||
112 | |||
109 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 113 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
110 | 114 | ||
111 | >>> flatten (ident 3) | 115 | >>> flatten (ident 3) |
112 | [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] | 116 | [1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] |
113 | it :: (Num t, Element t) => Vector t | 117 | it :: (Num t, Typeable t) => Vector t |
114 | 118 | ||
115 | -} | 119 | -} |
116 | flatten :: Element t => Matrix t -> Vector t | 120 | flatten :: Element t => Matrix t -> Vector t |
@@ -120,7 +124,7 @@ flatten m | |||
120 | 124 | ||
121 | 125 | ||
122 | -- | the inverse of 'Data.Packed.Matrix.fromLists' | 126 | -- | the inverse of 'Data.Packed.Matrix.fromLists' |
123 | toLists :: (Element t) => Matrix t -> [[t]] | 127 | toLists :: Element t => Matrix t -> [[t]] |
124 | toLists = map toList . toRows | 128 | toLists = map toList . toRows |
125 | 129 | ||
126 | 130 | ||
@@ -285,34 +289,11 @@ instance (Storable t, NFData t) => NFData (Matrix t) | |||
285 | d = dim v | 289 | d = dim v |
286 | v = xdat m | 290 | v = xdat m |
287 | 291 | ||
288 | --------------------------------------------------------------- | ||
289 | |||
290 | type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) | ||
291 | |||
292 | foreign import ccall unsafe "extractD" c_extractD :: Extr Double | ||
293 | foreign import ccall unsafe "extractF" c_extractF :: Extr Float | ||
294 | foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double) | ||
295 | foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float) | ||
296 | foreign import ccall unsafe "extractI" c_extractI :: Extr CInt | ||
297 | foreign import ccall unsafe "extractL" c_extractL :: Extr Z | ||
298 | |||
299 | --------------------------------------------------------------- | ||
300 | |||
301 | -------------------------------------------------------------------------------- | ||
302 | |||
303 | -------------------------------------------------------------------------------- | ||
304 | |||
305 | -------------------------------------------------------------------------------- | ||
306 | |||
307 | --------------------------------------------------------------------------- | ||
308 | -------------------------------------------------------------------------------- | ||
309 | -------------------------------------------------------------------------------- | ||
310 | |||
311 | -------------------------------------------------------------------------------- | 292 | -------------------------------------------------------------------------------- |
312 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, | 293 | -- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, |
313 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ | 294 | -- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ |
314 | -- This function is intended to be used internally by tensor libraries. | 295 | -- This function is intended to be used internally by tensor libraries. |
315 | reorderVector :: Element a | 296 | reorderVector :: Typeable a |
316 | => Vector CInt -- ^ @strides@: array strides | 297 | => Vector CInt -- ^ @strides@: array strides |
317 | -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ | 298 | -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ |
318 | -> Vector a -- ^ @v@: flattened input array | 299 | -> Vector a -- ^ @v@: flattened input array |
diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs index c063369..46587d2 100644 --- a/packages/base/src/Internal/Specialized.hs +++ b/packages/base/src/Internal/Specialized.hs | |||
@@ -10,7 +10,29 @@ | |||
10 | {-# LANGUAGE KindSignatures #-} | 10 | {-# LANGUAGE KindSignatures #-} |
11 | {-# LANGUAGE ViewPatterns #-} | 11 | {-# LANGUAGE ViewPatterns #-} |
12 | {-# LANGUAGE LambdaCase #-} | 12 | {-# LANGUAGE LambdaCase #-} |
13 | module Internal.Specialized where | 13 | module Internal.Specialized |
14 | ( Mod(..) | ||
15 | , MatrixOrder(..) | ||
16 | , Matrix(..) | ||
17 | , createMatrix | ||
18 | , matrixFromVector | ||
19 | , cols | ||
20 | , rows | ||
21 | , size | ||
22 | , shSize | ||
23 | , shDim | ||
24 | , constantD | ||
25 | , extractR | ||
26 | , setRect | ||
27 | , sortI | ||
28 | , sortV | ||
29 | , compareV | ||
30 | , selectV | ||
31 | , remapM | ||
32 | , rowOp | ||
33 | , gemm | ||
34 | , reorderV | ||
35 | ) where | ||
14 | 36 | ||
15 | import Control.Monad | 37 | import Control.Monad |
16 | import Control.DeepSeq ( NFData(..) ) | 38 | import Control.DeepSeq ( NFData(..) ) |