summaryrefslogtreecommitdiff
path: root/packages
diff options
context:
space:
mode:
Diffstat (limited to 'packages')
-rw-r--r--packages/base/src/Internal/Matrix.hs17
-rw-r--r--packages/base/src/Internal/ST.hs11
2 files changed, 8 insertions, 20 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index bdf2785..8597dcb 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -142,10 +142,9 @@ a # b = apply a b
142 142
143-------------------------------------------------------------------------------- 143--------------------------------------------------------------------------------
144 144
145extractAll ord m = unsafePerformIO $ 145copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1])
146 extractR ord m 146
147 0 (idxs[0,rows m-1]) 147extractAll ord m = unsafePerformIO (copy ord m)
148 0 (idxs[0,cols m-1])
149 148
150{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. 149{- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose.
151 150
@@ -204,16 +203,6 @@ toRows m
204 sub k = subVector (k*xRow m) (cols m) (xdat m) 203 sub k = subVector (k*xRow m) (cols m) (xdat m)
205 ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) 204 ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1])
206 205
207{-
208 | c == 0 = replicate r (fromList[])
209 | otherwise = toRows' 0
210 where
211 v = flatten m
212 r = rows m
213 c = cols m
214 toRows' k | k == r*c = []
215 | otherwise = subVector k c v : toRows' (k+c)
216-}
217 206
218-- | Creates a matrix from a list of vectors, as columns 207-- | Creates a matrix from a list of vectors, as columns
219fromColumns :: Element t => [Vector t] -> Matrix t 208fromColumns :: Element t => [Vector t] -> Matrix t
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs
index 23fda99..91c2a11 100644
--- a/packages/base/src/Internal/ST.hs
+++ b/packages/base/src/Internal/ST.hs
@@ -119,7 +119,7 @@ ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val
119 119
120newtype STMatrix s t = STMatrix (Matrix t) 120newtype STMatrix s t = STMatrix (Matrix t)
121 121
122thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) 122thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t)
123thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix 123thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix
124 124
125unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) 125unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t)
@@ -140,18 +140,17 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c
140modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () 140modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
141modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c 141modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c
142 142
143liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a 143liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a
144liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x 144liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x
145 145
146unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) 146unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t)
147unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x 147unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x
148 148
149 149
150freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) 150freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t)
151freezeMatrix m = liftSTMatrix id m 151freezeMatrix m = liftSTMatrix id m
152 152
153-- FIXME 153cloneMatrix m = copy (orderOf m) m
154cloneMatrix m = cloneVector (xdat m) >>= return . (\d' -> m{xdat = d'})
155 154
156{-# INLINE safeIndexM #-} 155{-# INLINE safeIndexM #-}
157safeIndexM f (STMatrix m) r c 156safeIndexM f (STMatrix m) r c
@@ -242,7 +241,7 @@ gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res
242 v = vjoin[pa,pb,pr] 241 v = vjoin[pa,pb,pr]
243 242
244 243
245mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) 244mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u)
246mutable f a = runST $ do 245mutable f a = runST $ do
247 x <- thawMatrix a 246 x <- thawMatrix a
248 info <- f (rows a, cols a) x 247 info <- f (rows a, cols a) x