summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Crayne <joe@jerkface.net>2019-08-10 03:41:07 -0400
committerJoe Crayne <joe@jerkface.net>2019-08-10 03:41:07 -0400
commit7f23aabba933c8e7ef44dbe21e35fa8fa0300f49 (patch)
tree0357fec21a7c0e4a88a96d64ba7335625601e8e7
parent145a61cc82ab66853daed8b352cb283fdcc790c5 (diff)
Internal.Matrix builds.
-rw-r--r--packages/base/src/Internal/Matrix.hs41
-rw-r--r--packages/base/src/Internal/Specialized.hs24
2 files changed, 34 insertions, 31 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 225b039..05633de 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -7,6 +7,7 @@
7{-# LANGUAGE ViewPatterns #-} 7{-# LANGUAGE ViewPatterns #-}
8{-# LANGUAGE DeriveGeneric #-} 8{-# LANGUAGE DeriveGeneric #-}
9{-# LANGUAGE ConstrainedClassMethods #-} 9{-# LANGUAGE ConstrainedClassMethods #-}
10{-# LANGUAGE ConstraintKinds #-}
10 11
11-- | 12-- |
12-- Module : Internal.Matrix 13-- Module : Internal.Matrix
@@ -29,6 +30,7 @@ import Foreign.Marshal.Array(newArray)
29import Foreign.Ptr ( Ptr ) 30import Foreign.Ptr ( Ptr )
30import Foreign.Storable ( Storable ) 31import Foreign.Storable ( Storable )
31import Data.Complex ( Complex ) 32import Data.Complex ( Complex )
33import Data.Typeable
32import Foreign.C.Types ( CInt(..) ) 34import Foreign.C.Types ( CInt(..) )
33import Foreign.C.String ( CString, newCString ) 35import Foreign.C.String ( CString, newCString )
34import System.IO.Unsafe ( unsafePerformIO ) 36import System.IO.Unsafe ( unsafePerformIO )
@@ -77,13 +79,13 @@ trans m@Matrix { irows = r, icols = c, xRow = xr, xCol = xc } =
77 m { irows = c, icols = r, xRow = xc, xCol = xr } 79 m { irows = c, icols = r, xRow = xc, xCol = xr }
78 80
79 81
80cmat :: (Element t) => Matrix t -> Matrix t 82cmat :: (Typeable t) => Matrix t -> Matrix t
81cmat m 83cmat m
82 | rowOrder m = m 84 | rowOrder m = m
83 | otherwise = extractAll RowMajor m 85 | otherwise = extractAll RowMajor m
84 86
85 87
86fmat :: (Element t) => Matrix t -> Matrix t 88fmat :: (Typeable t) => Matrix t -> Matrix t
87fmat m 89fmat m
88 | colOrder m = m 90 | colOrder m = m
89 | otherwise = extractAll ColumnMajor m 91 | otherwise = extractAll ColumnMajor m
@@ -100,17 +102,19 @@ a #! b = a # b # id
100 102
101-------------------------------------------------------------------------------- 103--------------------------------------------------------------------------------
102 104
103copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t) 105copy :: Typeable t => MatrixOrder -> Matrix t -> IO (Matrix t)
104copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) 106copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1])
105 107
106extractAll :: Element t => MatrixOrder -> Matrix t -> Matrix t 108extractAll :: Typeable t => MatrixOrder -> Matrix t -> Matrix t
107extractAll ord m = unsafePerformIO (copy ord m) 109extractAll ord m = unsafePerformIO (copy ord m)
108 110
111type Element t = (Typeable t, Storable t)
112
109{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 113{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
110 114
111>>> flatten (ident 3) 115>>> flatten (ident 3)
112[1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0] 116[1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0]
113it :: (Num t, Element t) => Vector t 117it :: (Num t, Typeable t) => Vector t
114 118
115-} 119-}
116flatten :: Element t => Matrix t -> Vector t 120flatten :: Element t => Matrix t -> Vector t
@@ -120,7 +124,7 @@ flatten m
120 124
121 125
122-- | the inverse of 'Data.Packed.Matrix.fromLists' 126-- | the inverse of 'Data.Packed.Matrix.fromLists'
123toLists :: (Element t) => Matrix t -> [[t]] 127toLists :: Element t => Matrix t -> [[t]]
124toLists = map toList . toRows 128toLists = map toList . toRows
125 129
126 130
@@ -285,34 +289,11 @@ instance (Storable t, NFData t) => NFData (Matrix t)
285 d = dim v 289 d = dim v
286 v = xdat m 290 v = xdat m
287 291
288---------------------------------------------------------------
289
290type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
291
292foreign import ccall unsafe "extractD" c_extractD :: Extr Double
293foreign import ccall unsafe "extractF" c_extractF :: Extr Float
294foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double)
295foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
296foreign import ccall unsafe "extractI" c_extractI :: Extr CInt
297foreign import ccall unsafe "extractL" c_extractL :: Extr Z
298
299---------------------------------------------------------------
300
301--------------------------------------------------------------------------------
302
303--------------------------------------------------------------------------------
304
305--------------------------------------------------------------------------------
306
307---------------------------------------------------------------------------
308--------------------------------------------------------------------------------
309--------------------------------------------------------------------------------
310
311-------------------------------------------------------------------------------- 292--------------------------------------------------------------------------------
312-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, 293-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices,
313-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ 294-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@
314-- This function is intended to be used internally by tensor libraries. 295-- This function is intended to be used internally by tensor libraries.
315reorderVector :: Element a 296reorderVector :: Typeable a
316 => Vector CInt -- ^ @strides@: array strides 297 => Vector CInt -- ^ @strides@: array strides
317 -> Vector CInt -- ^ @dims@: array dimensions of new array @v@ 298 -> Vector CInt -- ^ @dims@: array dimensions of new array @v@
318 -> Vector a -- ^ @v@: flattened input array 299 -> Vector a -- ^ @v@: flattened input array
diff --git a/packages/base/src/Internal/Specialized.hs b/packages/base/src/Internal/Specialized.hs
index c063369..46587d2 100644
--- a/packages/base/src/Internal/Specialized.hs
+++ b/packages/base/src/Internal/Specialized.hs
@@ -10,7 +10,29 @@
10{-# LANGUAGE KindSignatures #-} 10{-# LANGUAGE KindSignatures #-}
11{-# LANGUAGE ViewPatterns #-} 11{-# LANGUAGE ViewPatterns #-}
12{-# LANGUAGE LambdaCase #-} 12{-# LANGUAGE LambdaCase #-}
13module Internal.Specialized where 13module Internal.Specialized
14 ( Mod(..)
15 , MatrixOrder(..)
16 , Matrix(..)
17 , createMatrix
18 , matrixFromVector
19 , cols
20 , rows
21 , size
22 , shSize
23 , shDim
24 , constantD
25 , extractR
26 , setRect
27 , sortI
28 , sortV
29 , compareV
30 , selectV
31 , remapM
32 , rowOp
33 , gemm
34 , reorderV
35 ) where
14 36
15import Control.Monad 37import Control.Monad
16import Control.DeepSeq ( NFData(..) ) 38import Control.DeepSeq ( NFData(..) )