summaryrefslogtreecommitdiff
path: root/packages/base/src/Internal/Matrix.hs
diff options
context:
space:
mode:
Diffstat (limited to 'packages/base/src/Internal/Matrix.hs')
-rw-r--r--packages/base/src/Internal/Matrix.hs382
1 files changed, 0 insertions, 382 deletions
diff --git a/packages/base/src/Internal/Matrix.hs b/packages/base/src/Internal/Matrix.hs
index 7c774ef..225b039 100644
--- a/packages/base/src/Internal/Matrix.hs
+++ b/packages/base/src/Internal/Matrix.hs
@@ -37,31 +37,6 @@ import Text.Printf
37 37
38----------------------------------------------------------------- 38-----------------------------------------------------------------
39 39
40data MatrixOrder = RowMajor | ColumnMajor deriving (Show,Eq)
41
42-- | Matrix representation suitable for BLAS\/LAPACK computations.
43
44data Matrix t = Matrix
45 { irows :: {-# UNPACK #-} !Int
46 , icols :: {-# UNPACK #-} !Int
47 , xRow :: {-# UNPACK #-} !Int
48 , xCol :: {-# UNPACK #-} !Int
49 , xdat :: {-# UNPACK #-} !(Vector t)
50 }
51
52
53rows :: Matrix t -> Int
54rows = irows
55{-# INLINE rows #-}
56
57cols :: Matrix t -> Int
58cols = icols
59{-# INLINE cols #-}
60
61size :: Matrix t -> (Int, Int)
62size m = (irows m, icols m)
63{-# INLINE size #-}
64
65rowOrder :: Matrix t -> Bool 40rowOrder :: Matrix t -> Bool
66rowOrder m = xCol m == 1 || cols m == 1 41rowOrder m = xCol m == 1 || cols m == 1
67{-# INLINE rowOrder #-} 42{-# INLINE rowOrder #-}
@@ -114,33 +89,6 @@ fmat m
114 | otherwise = extractAll ColumnMajor m 89 | otherwise = extractAll ColumnMajor m
115 90
116 91
117-- C-Haskell matrix adapters
118{-# INLINE amatr #-}
119amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r
120amatr x f g = unsafeWith (xdat x) (f . g r c)
121 where
122 r = fi (rows x)
123 c = fi (cols x)
124
125{-# INLINE amat #-}
126amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r
127amat x f g = unsafeWith (xdat x) (f . g r c sr sc)
128 where
129 r = fi (rows x)
130 c = fi (cols x)
131 sr = fi (xRow x)
132 sc = fi (xCol x)
133
134
135instance Storable t => TransArray (Matrix t)
136 where
137 type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b
138 type Trans (Matrix t) b = CInt -> CInt -> CInt -> CInt -> Ptr t -> b
139 apply = amat
140 {-# INLINE apply #-}
141 applyRaw = amatr
142 {-# INLINE applyRaw #-}
143
144infixr 1 # 92infixr 1 #
145(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r 93(#) :: TransArray c => c -> (b -> IO r) -> Trans c b -> IO r
146a # b = apply a b 94a # b = apply a b
@@ -240,22 +188,6 @@ atM' m i j = xdat m `at'` (i * (xRow m) + j * (xCol m))
240 188
241------------------------------------------------------------------ 189------------------------------------------------------------------
242 190
243matrixFromVector :: Storable t => MatrixOrder -> Int -> Int -> Vector t -> Matrix t
244matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
245matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
246matrixFromVector o r c v
247 | r * c == dim v = m
248 | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
249 where
250 m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 }
251 | otherwise = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r }
252
253-- allocates memory for a new matrix
254createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
255createMatrix ord r c = do
256 p <- createVector (r*c)
257 return (matrixFromVector ord r c p)
258
259{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = tr' . reshape r@ 191{- | Creates a matrix from a vector by grouping the elements in rows with the desired number of columns. (GNU-Octave groups by columns. To do it you can define @reshapeF r = tr' . reshape r@
260where r is the desired number of rows.) 192where r is the desired number of rows.)
261 193
@@ -286,101 +218,6 @@ liftMatrix2 f m1@(size->(r,c)) m2
286 218
287------------------------------------------------------------------ 219------------------------------------------------------------------
288 220
289{-
290-- | Supported matrix elements.
291class (Storable a) => Element a where
292 constantD :: a -> Int -> Vector a
293 extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
294 setRect :: Int -> Int -> Matrix a -> Matrix a -> IO ()
295 sortI :: Ord a => Vector a -> Vector CInt
296 sortV :: Ord a => Vector a -> Vector a
297 compareV :: Ord a => Vector a -> Vector a -> Vector CInt
298 selectV :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
299 remapM :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
300 rowOp :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
301 gemm :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
302 reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a -- see reorderVector for documentation
303
304
305instance Element Float where
306 constantD = constantAux cconstantF
307 extractR = extractAux c_extractF
308 setRect = setRectAux c_setRectF
309 sortI = sortIdxF
310 sortV = sortValF
311 compareV = compareF
312 selectV = selectF
313 remapM = remapF
314 rowOp = rowOpAux c_rowOpF
315 gemm = gemmg c_gemmF
316 reorderV = reorderAux c_reorderF
317
318instance Element Double where
319 constantD = constantAux cconstantR
320 extractR = extractAux c_extractD
321 setRect = setRectAux c_setRectD
322 sortI = sortIdxD
323 sortV = sortValD
324 compareV = compareD
325 selectV = selectD
326 remapM = remapD
327 rowOp = rowOpAux c_rowOpD
328 gemm = gemmg c_gemmD
329 reorderV = reorderAux c_reorderD
330
331instance Element (Complex Float) where
332 constantD = constantAux cconstantQ
333 extractR = extractAux c_extractQ
334 setRect = setRectAux c_setRectQ
335 sortI = undefined
336 sortV = undefined
337 compareV = undefined
338 selectV = selectQ
339 remapM = remapQ
340 rowOp = rowOpAux c_rowOpQ
341 gemm = gemmg c_gemmQ
342 reorderV = reorderAux c_reorderQ
343
344instance Element (Complex Double) where
345 constantD = constantAux cconstantC
346 extractR = extractAux c_extractC
347 setRect = setRectAux c_setRectC
348 sortI = undefined
349 sortV = undefined
350 compareV = undefined
351 selectV = selectC
352 remapM = remapC
353 rowOp = rowOpAux c_rowOpC
354 gemm = gemmg c_gemmC
355 reorderV = reorderAux c_reorderC
356
357instance Element (CInt) where
358 constantD = constantAux cconstantI
359 extractR = extractAux c_extractI
360 setRect = setRectAux c_setRectI
361 sortI = sortIdxI
362 sortV = sortValI
363 compareV = compareI
364 selectV = selectI
365 remapM = remapI
366 rowOp = rowOpAux c_rowOpI
367 gemm = gemmg c_gemmI
368 reorderV = reorderAux c_reorderI
369
370instance Element Z where
371 constantD = constantAux cconstantL
372 extractR = extractAux c_extractL
373 setRect = setRectAux c_setRectL
374 sortI = sortIdxL
375 sortV = sortValL
376 compareV = compareL
377 selectV = selectL
378 remapM = remapL
379 rowOp = rowOpAux c_rowOpL
380 gemm = gemmg c_gemmL
381 reorderV = reorderAux c_reorderL
382-}
383
384------------------------------------------------------------------- 221-------------------------------------------------------------------
385 222
386-- | reference to a rectangular slice of a matrix (no data copy) 223-- | reference to a rectangular slice of a matrix (no data copy)
@@ -435,12 +272,6 @@ repRows n x = fromRows (replicate n (flatten x))
435repCols :: Element t => Int -> Matrix t -> Matrix t 272repCols :: Element t => Int -> Matrix t -> Matrix t
436repCols n x = fromColumns (replicate n (flatten x)) 273repCols n x = fromColumns (replicate n (flatten x))
437 274
438shSize :: Matrix t -> [Char]
439shSize = shDim . size
440
441shDim :: (Show a, Show a1) => (a1, a) -> [Char]
442shDim (r,c) = "(" ++ show r ++"x"++ show c ++")"
443
444emptyM :: Storable t => Int -> Int -> Matrix t 275emptyM :: Storable t => Int -> Int -> Matrix t
445emptyM r c = matrixFromVector RowMajor r c (fromList[]) 276emptyM r c = matrixFromVector RowMajor r c (fromList[])
446 277
@@ -456,19 +287,6 @@ instance (Storable t, NFData t) => NFData (Matrix t)
456 287
457--------------------------------------------------------------- 288---------------------------------------------------------------
458 289
459extractAux :: (Eq t3, Eq t2, TransArray c, Storable a, Storable t1,
460 Storable t, Num t3, Num t2, Integral t1, Integral t)
461 => (t3 -> t2 -> CInt -> Ptr t1 -> CInt -> Ptr t
462 -> Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt))
463 -> MatrixOrder -> c -> t3 -> Vector t1 -> t2 -> Vector t -> IO (Matrix a)
464extractAux f ord m moder vr modec vc = do
465 let nr = if moder == 0 then fromIntegral $ vr@>1 - vr@>0 + 1 else dim vr
466 nc = if modec == 0 then fromIntegral $ vc@>1 - vc@>0 + 1 else dim vc
467 r <- createMatrix ord nr nc
468 (vr # vc # m #! r) (f moder modec) #|"extract"
469
470 return r
471
472type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt)))) 290type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
473 291
474foreign import ccall unsafe "extractD" c_extractD :: Extr Double 292foreign import ccall unsafe "extractD" c_extractD :: Extr Double
@@ -480,217 +298,17 @@ foreign import ccall unsafe "extractL" c_extractL :: Extr Z
480 298
481--------------------------------------------------------------- 299---------------------------------------------------------------
482 300
483setRectAux :: (TransArray c1, TransArray c)
484 => (CInt -> CInt -> Trans c1 (Trans c (IO CInt)))
485 -> Int -> Int -> c1 -> c -> IO ()
486setRectAux f i j m r = (m #! r) (f (fi i) (fi j)) #|"setRect"
487
488type SetRect x = I -> I -> x ::> x::> Ok
489
490foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double
491foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float
492foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double)
493foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float)
494foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I
495foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
496
497-------------------------------------------------------------------------------- 301--------------------------------------------------------------------------------
498 302
499sortG :: (Storable t, Storable a)
500 => (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt) -> Vector t -> Vector a
501sortG f v = unsafePerformIO $ do
502 r <- createVector (dim v)
503 (v #! r) f #|"sortG"
504 return r
505
506sortIdxD :: Vector Double -> Vector CInt
507sortIdxD = sortG c_sort_indexD
508sortIdxF :: Vector Float -> Vector CInt
509sortIdxF = sortG c_sort_indexF
510sortIdxI :: Vector CInt -> Vector CInt
511sortIdxI = sortG c_sort_indexI
512sortIdxL :: Vector Z -> Vector I
513sortIdxL = sortG c_sort_indexL
514
515sortValD :: Vector Double -> Vector Double
516sortValD = sortG c_sort_valD
517sortValF :: Vector Float -> Vector Float
518sortValF = sortG c_sort_valF
519sortValI :: Vector CInt -> Vector CInt
520sortValI = sortG c_sort_valI
521sortValL :: Vector Z -> Vector Z
522sortValL = sortG c_sort_valL
523
524foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
525foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float (CV CInt (IO CInt))
526foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt (CV CInt (IO CInt))
527foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok
528
529foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt))
530foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float (CV Float (IO CInt))
531foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt (CV CInt (IO CInt))
532foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
533
534-------------------------------------------------------------------------------- 303--------------------------------------------------------------------------------
535 304
536compareG :: (TransArray c, Storable t, Storable a)
537 => Trans c (CInt -> Ptr t -> CInt -> Ptr a -> IO CInt)
538 -> c -> Vector t -> Vector a
539compareG f u v = unsafePerformIO $ do
540 r <- createVector (dim v)
541 (u # v #! r) f #|"compareG"
542 return r
543
544compareD :: Vector Double -> Vector Double -> Vector CInt
545compareD = compareG c_compareD
546compareF :: Vector Float -> Vector Float -> Vector CInt
547compareF = compareG c_compareF
548compareI :: Vector CInt -> Vector CInt -> Vector CInt
549compareI = compareG c_compareI
550compareL :: Vector Z -> Vector Z -> Vector CInt
551compareL = compareG c_compareL
552
553foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
554foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float (CV CInt (IO CInt)))
555foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt (CV CInt (IO CInt)))
556foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
557
558-------------------------------------------------------------------------------- 305--------------------------------------------------------------------------------
559 306
560selectG :: (TransArray c, TransArray c1, TransArray c2, Storable t, Storable a)
561 => Trans c2 (Trans c1 (CInt -> Ptr t -> Trans c (CInt -> Ptr a -> IO CInt)))
562 -> c2 -> c1 -> Vector t -> c -> Vector a
563selectG f c u v w = unsafePerformIO $ do
564 r <- createVector (dim v)
565 (c # u # v # w #! r) f #|"selectG"
566 return r
567
568selectD :: Vector CInt -> Vector Double -> Vector Double -> Vector Double -> Vector Double
569selectD = selectG c_selectD
570selectF :: Vector CInt -> Vector Float -> Vector Float -> Vector Float -> Vector Float
571selectF = selectG c_selectF
572selectI :: Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt -> Vector CInt
573selectI = selectG c_selectI
574selectL :: Vector CInt -> Vector Z -> Vector Z -> Vector Z -> Vector Z
575selectL = selectG c_selectL
576selectC :: Vector CInt
577 -> Vector (Complex Double)
578 -> Vector (Complex Double)
579 -> Vector (Complex Double)
580 -> Vector (Complex Double)
581selectC = selectG c_selectC
582selectQ :: Vector CInt
583 -> Vector (Complex Float)
584 -> Vector (Complex Float)
585 -> Vector (Complex Float)
586 -> Vector (Complex Float)
587selectQ = selectG c_selectQ
588
589type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
590
591foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
592foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
593foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
594foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
595foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
596foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
597
598--------------------------------------------------------------------------- 307---------------------------------------------------------------------------
599
600remapG :: (TransArray c, TransArray c1, Storable t, Storable a)
601 => (CInt -> CInt -> CInt -> CInt -> Ptr t
602 -> Trans c1 (Trans c (CInt -> CInt -> CInt -> CInt -> Ptr a -> IO CInt)))
603 -> Matrix t -> c1 -> c -> Matrix a
604remapG f i j m = unsafePerformIO $ do
605 r <- createMatrix RowMajor (rows i) (cols i)
606 (i # j # m #! r) f #|"remapG"
607 return r
608
609remapD :: Matrix CInt -> Matrix CInt -> Matrix Double -> Matrix Double
610remapD = remapG c_remapD
611remapF :: Matrix CInt -> Matrix CInt -> Matrix Float -> Matrix Float
612remapF = remapG c_remapF
613remapI :: Matrix CInt -> Matrix CInt -> Matrix CInt -> Matrix CInt
614remapI = remapG c_remapI
615remapL :: Matrix CInt -> Matrix CInt -> Matrix Z -> Matrix Z
616remapL = remapG c_remapL
617remapC :: Matrix CInt
618 -> Matrix CInt
619 -> Matrix (Complex Double)
620 -> Matrix (Complex Double)
621remapC = remapG c_remapC
622remapQ :: Matrix CInt -> Matrix CInt -> Matrix (Complex Float) -> Matrix (Complex Float)
623remapQ = remapG c_remapQ
624
625type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
626
627foreign import ccall unsafe "remapD" c_remapD :: Rem Double
628foreign import ccall unsafe "remapF" c_remapF :: Rem Float
629foreign import ccall unsafe "remapI" c_remapI :: Rem CInt
630foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
631foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
632foreign import ccall unsafe "remapL" c_remapL :: Rem Z
633
634-------------------------------------------------------------------------------- 308--------------------------------------------------------------------------------
635
636rowOpAux :: (TransArray c, Storable a) =>
637 (CInt -> Ptr a -> CInt -> CInt -> CInt -> CInt -> Trans c (IO CInt))
638 -> Int -> a -> Int -> Int -> Int -> Int -> c -> IO ()
639rowOpAux f c x i1 i2 j1 j2 m = do
640 px <- newArray [x]
641 (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
642 free px
643
644type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
645
646foreign import ccall unsafe "rowop_double" c_rowOpD :: RowOp R
647foreign import ccall unsafe "rowop_float" c_rowOpF :: RowOp Float
648foreign import ccall unsafe "rowop_TCD" c_rowOpC :: RowOp C
649foreign import ccall unsafe "rowop_TCF" c_rowOpQ :: RowOp (Complex Float)
650foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I
651foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z
652foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I
653foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
654
655-------------------------------------------------------------------------------- 309--------------------------------------------------------------------------------
656 310
657gemmg :: (TransArray c1, TransArray c, TransArray c2, TransArray c3)
658 => Trans c3 (Trans c2 (Trans c1 (Trans c (IO CInt))))
659 -> c3 -> c2 -> c1 -> c -> IO ()
660gemmg f v m1 m2 m3 = (v # m1 # m2 #! m3) f #|"gemmg"
661
662type Tgemm x = x :> x ::> x ::> x ::> Ok
663
664foreign import ccall unsafe "gemm_double" c_gemmD :: Tgemm R
665foreign import ccall unsafe "gemm_float" c_gemmF :: Tgemm Float
666foreign import ccall unsafe "gemm_TCD" c_gemmC :: Tgemm C
667foreign import ccall unsafe "gemm_TCF" c_gemmQ :: Tgemm (Complex Float)
668foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I
669foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z
670foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I
671foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
672
673-------------------------------------------------------------------------------- 311--------------------------------------------------------------------------------
674
675reorderAux :: (TransArray c, Storable t, Storable a1, Storable t1, Storable a) =>
676 (CInt -> Ptr a -> CInt -> Ptr t1
677 -> Trans c (CInt -> Ptr t -> CInt -> Ptr a1 -> IO CInt))
678 -> Vector t1 -> c -> Vector t -> Vector a1
679reorderAux f s d v = unsafePerformIO $ do
680 k <- createVector (dim s)
681 r <- createVector (dim v)
682 (k # s # d # v #! r) f #| "reorderV"
683 return r
684
685type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt)))))
686
687foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
688foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
689foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt
690foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
691foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
692foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
693
694-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices, 312-- | Transpose an array with dimensions @dims@ by making a copy using @strides@. For example, for an array with 3 indices,
695-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@ 313-- @(reorderVector strides dims v) ! ((i * dims ! 1 + j) * dims ! 2 + k) == v ! (i * strides ! 0 + j * strides ! 1 + k * strides ! 2)@
696-- This function is intended to be used internally by tensor libraries. 314-- This function is intended to be used internally by tensor libraries.