diff options
Diffstat (limited to 'packages')
-rw-r--r-- | packages/base/src/Internal/Matrix.hs | 17 | ||||
-rw-r--r-- | packages/base/src/Internal/ST.hs | 11 |
2 files changed, 8 insertions, 20 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs index bdf2785..8597dcb 100644 --- a/packages/base/src/Internal/Matrix.hs +++ b/packages/base/src/Internal/Matrix.hs | |||
@@ -142,10 +142,9 @@ a # b = apply a b | |||
142 | 142 | ||
143 | -------------------------------------------------------------------------------- | 143 | -------------------------------------------------------------------------------- |
144 | 144 | ||
145 | extractAll ord m = unsafePerformIO $ | 145 | copy ord m = extractR ord m 0 (idxs[0,rows m-1]) 0 (idxs[0,cols m-1]) |
146 | extractR ord m | 146 | |
147 | 0 (idxs[0,rows m-1]) | 147 | extractAll ord m = unsafePerformIO (copy ord m) |
148 | 0 (idxs[0,cols m-1]) | ||
149 | 148 | ||
150 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. | 149 | {- | Creates a vector by concatenation of rows. If the matrix is ColumnMajor, this operation requires a transpose. |
151 | 150 | ||
@@ -204,16 +203,6 @@ toRows m | |||
204 | sub k = subVector (k*xRow m) (cols m) (xdat m) | 203 | sub k = subVector (k*xRow m) (cols m) (xdat m) |
205 | ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) | 204 | ext k = xdat $ unsafePerformIO $ extractR RowMajor m 1 (idxs[k]) 0 (idxs[0,cols m-1]) |
206 | 205 | ||
207 | {- | ||
208 | | c == 0 = replicate r (fromList[]) | ||
209 | | otherwise = toRows' 0 | ||
210 | where | ||
211 | v = flatten m | ||
212 | r = rows m | ||
213 | c = cols m | ||
214 | toRows' k | k == r*c = [] | ||
215 | | otherwise = subVector k c v : toRows' (k+c) | ||
216 | -} | ||
217 | 206 | ||
218 | -- | Creates a matrix from a list of vectors, as columns | 207 | -- | Creates a matrix from a list of vectors, as columns |
219 | fromColumns :: Element t => [Vector t] -> Matrix t | 208 | fromColumns :: Element t => [Vector t] -> Matrix t |
diff --git a/packages/base/src/Internal/ST.hs b/packages/base/src/Internal/ST.hs index 23fda99..91c2a11 100644 --- a/packages/base/src/Internal/ST.hs +++ b/packages/base/src/Internal/ST.hs | |||
@@ -119,7 +119,7 @@ ioWriteM m r c val = ioWriteV (xdat m) (r * xRow m + c * xCol m) val | |||
119 | 119 | ||
120 | newtype STMatrix s t = STMatrix (Matrix t) | 120 | newtype STMatrix s t = STMatrix (Matrix t) |
121 | 121 | ||
122 | thawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) | 122 | thawMatrix :: Element t => Matrix t -> ST s (STMatrix s t) |
123 | thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix | 123 | thawMatrix = unsafeIOToST . fmap STMatrix . cloneMatrix |
124 | 124 | ||
125 | unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) | 125 | unsafeThawMatrix :: Storable t => Matrix t -> ST s (STMatrix s t) |
@@ -140,18 +140,17 @@ unsafeWriteMatrix (STMatrix x) r c = unsafeIOToST . ioWriteM x r c | |||
140 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () | 140 | modifyMatrix :: (Storable t) => STMatrix s t -> Int -> Int -> (t -> t) -> ST s () |
141 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c | 141 | modifyMatrix x r c f = readMatrix x r c >>= return . f >>= unsafeWriteMatrix x r c |
142 | 142 | ||
143 | liftSTMatrix :: (Storable t) => (Matrix t -> a) -> STMatrix s t -> ST s a | 143 | liftSTMatrix :: (Element t) => (Matrix t -> a) -> STMatrix s t -> ST s a |
144 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x | 144 | liftSTMatrix f (STMatrix x) = unsafeIOToST . fmap f . cloneMatrix $ x |
145 | 145 | ||
146 | unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) | 146 | unsafeFreezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) |
147 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x | 147 | unsafeFreezeMatrix (STMatrix x) = unsafeIOToST . return $ x |
148 | 148 | ||
149 | 149 | ||
150 | freezeMatrix :: (Storable t) => STMatrix s t -> ST s (Matrix t) | 150 | freezeMatrix :: (Element t) => STMatrix s t -> ST s (Matrix t) |
151 | freezeMatrix m = liftSTMatrix id m | 151 | freezeMatrix m = liftSTMatrix id m |
152 | 152 | ||
153 | -- FIXME | 153 | cloneMatrix m = copy (orderOf m) m |
154 | cloneMatrix m = cloneVector (xdat m) >>= return . (\d' -> m{xdat = d'}) | ||
155 | 154 | ||
156 | {-# INLINE safeIndexM #-} | 155 | {-# INLINE safeIndexM #-} |
157 | safeIndexM f (STMatrix m) r c | 156 | safeIndexM f (STMatrix m) r c |
@@ -242,7 +241,7 @@ gemmm beta (slice->(r,pr)) alpha (slice->(a,pa)) (slice->(b,pb)) = res | |||
242 | v = vjoin[pa,pb,pr] | 241 | v = vjoin[pa,pb,pr] |
243 | 242 | ||
244 | 243 | ||
245 | mutable :: Storable t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) | 244 | mutable :: Element t => (forall s . (Int, Int) -> STMatrix s t -> ST s u) -> Matrix t -> (Matrix t,u) |
246 | mutable f a = runST $ do | 245 | mutable f a = runST $ do |
247 | x <- thawMatrix a | 246 | x <- thawMatrix a |
248 | info <- f (rows a, cols a) x | 247 | info <- f (rows a, cols a) x |