Renovate BLAS matrices

db4
Joe Groff 2008-12-04 16:08:01 -08:00
parent bd59b86ad6
commit ec76a0bfff
4 changed files with 167 additions and 158 deletions

View File

@ -88,7 +88,7 @@ HELP: blas-matrix-base
}
"All of these subclasses share the same tuple layout:"
{ $list
{ { $snippet "data" } " contains an alien pointer referencing or byte-array containing a packed, column-major array of float, double, float complex, or double complex values;" }
{ { $snippet "underlying" } " contains an alien pointer referencing or byte-array containing a packed, column-major array of float, double, float complex, or double complex values;" }
{ { $snippet "ld" } " indicates the distance, in elements, between matrix columns;" }
{ { $snippet "rows" } " and " { $snippet "cols" } " indicate the number of significant rows and columns in the matrix;" }
{ "and " { $snippet "transpose" } ", if set to a true value, indicates that the matrix should be treated as transposed relative to its in-memory representation." }

View File

@ -1,31 +1,13 @@
USING: accessors alien alien.c-types arrays byte-arrays combinators
combinators.lib combinators.short-circuit fry kernel locals macros
combinators.short-circuit fry kernel locals macros
math math.blas.cblas math.blas.vectors math.blas.vectors.private
math.complex math.functions math.order multi-methods qualified
sequences sequences.merged sequences.private generalizations
shuffle symbols speicalized-arrays.float specialized-arrays.double ;
QUALIFIED: syntax
math.complex math.functions math.order functors words
sequences sequences.merged sequences.private shuffle symbols
specialized-arrays.direct.float specialized-arrays.direct.double
specialized-arrays.float specialized-arrays.double ;
IN: math.blas.matrices
TUPLE: blas-matrix-base data ld rows cols transpose ;
TUPLE: float-blas-matrix < blas-matrix-base ;
TUPLE: double-blas-matrix < blas-matrix-base ;
TUPLE: float-complex-blas-matrix < blas-matrix-base ;
TUPLE: double-complex-blas-matrix < blas-matrix-base ;
C: <float-blas-matrix> float-blas-matrix
C: <double-blas-matrix> double-blas-matrix
C: <float-complex-blas-matrix> float-complex-blas-matrix
C: <double-complex-blas-matrix> double-complex-blas-matrix
METHOD: element-type { float-blas-matrix }
drop "float" ;
METHOD: element-type { double-blas-matrix }
drop "double" ;
METHOD: element-type { float-complex-blas-matrix }
drop "CBLAS_C" ;
METHOD: element-type { double-complex-blas-matrix }
drop "CBLAS_Z" ;
TUPLE: blas-matrix-base underlying ld rows cols transpose ;
: Mtransposed? ( matrix -- ? )
transpose>> ; inline
@ -34,6 +16,11 @@ METHOD: element-type { double-complex-blas-matrix }
: Mheight ( matrix -- height )
dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline
GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y )
GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
<PRIVATE
: (blas-transpose) ( matrix -- integer )
@ -41,53 +28,29 @@ METHOD: element-type { double-complex-blas-matrix }
GENERIC: (blas-matrix-like) ( data ld rows cols transpose exemplar -- matrix )
METHOD: (blas-matrix-like) { object object object object object float-blas-matrix }
drop <float-blas-matrix> ;
METHOD: (blas-matrix-like) { object object object object object double-blas-matrix }
drop <double-blas-matrix> ;
METHOD: (blas-matrix-like) { object object object object object float-complex-blas-matrix }
drop <float-complex-blas-matrix> ;
METHOD: (blas-matrix-like) { object object object object object double-complex-blas-matrix }
drop <double-complex-blas-matrix> ;
METHOD: (blas-matrix-like) { object object object object object float-blas-vector }
drop <float-blas-matrix> ;
METHOD: (blas-matrix-like) { object object object object object double-blas-vector }
drop <double-blas-matrix> ;
METHOD: (blas-matrix-like) { object object object object object float-complex-blas-vector }
drop <float-complex-blas-matrix> ;
METHOD: (blas-matrix-like) { object object object object object double-complex-blas-vector }
drop <double-complex-blas-matrix> ;
METHOD: (blas-vector-like) { object object object float-blas-matrix }
drop <float-blas-vector> ;
METHOD: (blas-vector-like) { object object object double-blas-matrix }
drop <double-blas-vector> ;
METHOD: (blas-vector-like) { object object object float-complex-blas-matrix }
drop <float-complex-blas-vector> ;
METHOD: (blas-vector-like) { object object object double-complex-blas-matrix }
drop <double-complex-blas-vector> ;
: (validate-gemv) ( A x y -- )
{
[ drop [ Mwidth ] [ length>> ] bi* = ]
[ nip [ Mheight ] [ length>> ] bi* = ]
} 3&&
[ "Mismatched matrix and vectors in matrix-vector multiplication" throw ] unless ;
[ "Mismatched matrix and vectors in matrix-vector multiplication" throw ]
unless ;
:: (prepare-gemv) ( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc y )
:: (prepare-gemv)
( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc
y )
A x y (validate-gemv)
CblasColMajor
A (blas-transpose)
A rows>>
A cols>>
alpha >c-arg call
A data>>
A underlying>>
A ld>>
x data>>
x underlying>>
x inc>>
beta >c-arg call
y data>>
y underlying>>
y inc>>
y ; inline
@ -96,19 +59,22 @@ METHOD: (blas-vector-like) { object object object double-complex-blas-matrix }
[ nip [ length>> ] [ Mheight ] bi* = ]
[ nipd [ length>> ] [ Mwidth ] bi* = ]
} 3&&
[ "Mismatched vertices and matrix in vector outer product" throw ] unless ;
[ "Mismatched vertices and matrix in vector outer product" throw ]
unless ;
:: (prepare-ger) ( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld A )
:: (prepare-ger)
( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld
A )
x y A (validate-ger)
CblasColMajor
A rows>>
A cols>>
alpha >c-arg call
x data>>
x underlying>>
x inc>>
y data>>
y underlying>>
y inc>>
A data>>
A underlying>>
A ld>>
A f >>transpose ; inline
@ -117,9 +83,13 @@ METHOD: (blas-vector-like) { object object object double-complex-blas-matrix }
[ drop [ Mwidth ] [ Mheight ] bi* = ]
[ nip [ Mheight ] bi@ = ]
[ nipd [ Mwidth ] bi@ = ]
} 3&& [ "Mismatched matrices in matrix multiplication" throw ] unless ;
} 3&&
[ "Mismatched matrices in matrix multiplication" throw ]
unless ;
:: (prepare-gemm) ( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld C )
:: (prepare-gemm)
( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld
C )
A B C (validate-gemm)
CblasColMajor
A (blas-transpose)
@ -128,12 +98,12 @@ METHOD: (blas-vector-like) { object object object double-complex-blas-matrix }
C cols>>
A Mwidth
alpha >c-arg call
A data>>
A underlying>>
A ld>>
B data>>
B underlying>>
B ld>>
beta >c-arg call
C data>>
C underlying>>
C ld>>
C f >>transpose ; inline
@ -142,65 +112,22 @@ METHOD: (blas-vector-like) { object object object double-complex-blas-matrix }
PRIVATE>
: >float-blas-matrix ( arrays -- matrix )
[ >float-array underlying>> ] (>matrix) <float-blas-matrix> ;
: >double-blas-matrix ( arrays -- matrix )
[ >double-array underlying>> ] (>matrix) <double-blas-matrix> ;
: >float-complex-blas-matrix ( arrays -- matrix )
[ (flatten-complex-sequence) >float-array underlying>> ] (>matrix)
<float-complex-blas-matrix> ;
: >double-complex-blas-matrix ( arrays -- matrix )
[ (flatten-complex-sequence) >double-array underlying>> ] (>matrix)
<double-complex-blas-matrix> ;
GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y )
GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A )
GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A )
GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C )
METHOD: n*M.V+n*V! { real float-blas-matrix float-blas-vector real float-blas-vector }
[ ] (prepare-gemv) [ cblas_sgemv ] dip ;
METHOD: n*M.V+n*V! { real double-blas-matrix double-blas-vector real double-blas-vector }
[ ] (prepare-gemv) [ cblas_dgemv ] dip ;
METHOD: n*M.V+n*V! { number float-complex-blas-matrix float-complex-blas-vector number float-complex-blas-vector }
[ (>c-complex) ] (prepare-gemv) [ cblas_cgemv ] dip ;
METHOD: n*M.V+n*V! { number double-complex-blas-matrix double-complex-blas-vector number double-complex-blas-vector }
[ (>z-complex) ] (prepare-gemv) [ cblas_zgemv ] dip ;
METHOD: n*V(*)V+M! { real float-blas-vector float-blas-vector float-blas-matrix }
[ ] (prepare-ger) [ cblas_sger ] dip ;
METHOD: n*V(*)V+M! { real double-blas-vector double-blas-vector double-blas-matrix }
[ ] (prepare-ger) [ cblas_dger ] dip ;
METHOD: n*V(*)V+M! { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
[ (>c-complex) ] (prepare-ger) [ cblas_cgeru ] dip ;
METHOD: n*V(*)V+M! { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
[ (>z-complex) ] (prepare-ger) [ cblas_zgeru ] dip ;
METHOD: n*V(*)Vconj+M! { real float-blas-vector float-blas-vector float-blas-matrix }
[ ] (prepare-ger) [ cblas_sger ] dip ;
METHOD: n*V(*)Vconj+M! { real double-blas-vector double-blas-vector double-blas-matrix }
[ ] (prepare-ger) [ cblas_dger ] dip ;
METHOD: n*V(*)Vconj+M! { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix }
[ (>c-complex) ] (prepare-ger) [ cblas_cgerc ] dip ;
METHOD: n*V(*)Vconj+M! { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix }
[ (>z-complex) ] (prepare-ger) [ cblas_zgerc ] dip ;
METHOD: n*M.M+n*M! { real float-blas-matrix float-blas-matrix real float-blas-matrix }
[ ] (prepare-gemm) [ cblas_sgemm ] dip ;
METHOD: n*M.M+n*M! { real double-blas-matrix double-blas-matrix real double-blas-matrix }
[ ] (prepare-gemm) [ cblas_dgemm ] dip ;
METHOD: n*M.M+n*M! { number float-complex-blas-matrix float-complex-blas-matrix number float-complex-blas-matrix }
[ (>c-complex) ] (prepare-gemm) [ cblas_cgemm ] dip ;
METHOD: n*M.M+n*M! { number double-complex-blas-matrix double-complex-blas-matrix number double-complex-blas-matrix }
[ (>z-complex) ] (prepare-gemm) [ cblas_zgemm ] dip ;
! XXX should do a dense clone
syntax:M: blas-matrix-base clone
M: blas-matrix-base clone
[
[
{ [ data>> ] [ ld>> ] [ cols>> ] [ element-type heap-size ] } cleave
* * memory>byte-array
] [ { [ ld>> ] [ rows>> ] [ cols>> ] [ transpose>> ] } cleave ] bi
[ {
[ underlying>> ]
[ ld>> ]
[ cols>> ]
[ element-type heap-size ]
} cleave * * memory>byte-array ]
[ {
[ ld>> ]
[ rows>> ]
[ cols>> ]
[ transpose>> ]
} cleave ]
bi
] keep (blas-matrix-like) ;
! XXX try rounding stride to next 128 bit bound for better vectorizin'
@ -246,29 +173,31 @@ syntax:M: blas-matrix-base clone
:: (Msub) ( matrix row col height width -- data ld rows cols )
matrix ld>> col * row + matrix element-type heap-size *
matrix data>> <displaced-alien>
matrix underlying>> <displaced-alien>
matrix ld>>
height
width ;
: Msub ( matrix row col height width -- sub )
5 npick dup transpose>>
[ nip [ [ swap ] 2dip swap ] when (Msub) ] 2keep
swap (blas-matrix-like) ;
:: Msub ( matrix row col height width -- sub )
matrix dup transpose>>
[ col row width height ]
[ row col height width ] if (Msub)
matrix transpose>> matrix (blas-matrix-like) ;
TUPLE: blas-matrix-rowcol-sequence parent inc rowcol-length rowcol-jump length ;
TUPLE: blas-matrix-rowcol-sequence
parent inc rowcol-length rowcol-jump length ;
C: <blas-matrix-rowcol-sequence> blas-matrix-rowcol-sequence
INSTANCE: blas-matrix-rowcol-sequence sequence
syntax:M: blas-matrix-rowcol-sequence length
M: blas-matrix-rowcol-sequence length
length>> ;
syntax:M: blas-matrix-rowcol-sequence nth-unsafe
M: blas-matrix-rowcol-sequence nth-unsafe
{
[
[ rowcol-jump>> ]
[ parent>> element-type heap-size ]
[ parent>> data>> ] tri
[ parent>> underlying>> ] tri
[ * * ] dip <displaced-alien>
]
[ rowcol-length>> ]
@ -277,11 +206,11 @@ syntax:M: blas-matrix-rowcol-sequence nth-unsafe
} cleave (blas-vector-like) ;
: (Mcols) ( A -- columns )
{ [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] } cleave
<blas-matrix-rowcol-sequence> ;
{ [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] }
cleave <blas-matrix-rowcol-sequence> ;
: (Mrows) ( A -- rows )
{ [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] } cleave
<blas-matrix-rowcol-sequence> ;
{ [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] }
cleave <blas-matrix-rowcol-sequence> ;
: Mrows ( A -- rows )
dup transpose>> [ (Mcols) ] [ (Mrows) ] if ;
@ -300,11 +229,79 @@ syntax:M: blas-matrix-rowcol-sequence nth-unsafe
recip swap n*M ; inline
: Mtranspose ( matrix -- matrix^T )
[ { [ data>> ] [ ld>> ] [ rows>> ] [ cols>> ] [ transpose>> not ] } cleave ] keep (blas-matrix-like) ;
[ {
[ underlying>> ]
[ ld>> ] [ rows>> ]
[ cols>> ]
[ transpose>> not ]
} cleave ] keep (blas-matrix-like) ;
syntax:M: blas-matrix-base equal?
M: blas-matrix-base equal?
{
[ [ Mwidth ] bi@ = ]
[ [ Mcols ] bi@ [ = ] 2all? ]
} 2&& ;
<<
FUNCTOR: (define-blas-matrix) ( TYPE T U C -- )
VECTOR IS ${TYPE}-blas-vector
<VECTOR> IS <${TYPE}-blas-vector>
>ARRAY IS >${TYPE}-array
TYPE>ARG IS ${TYPE}>arg
XGEMV IS cblas_${T}gemv
XGEMM IS cblas_${T}gemm
XGERU IS cblas_${T}ger${U}
XGERC IS cblas_${T}ger${C}
MATRIX DEFINES ${TYPE}-blas-matrix
<MATRIX> DEFINES <${TYPE}-blas-matrix>
>MATRIX DEFINES >${TYPE}-blas-matrix
WHERE
TUPLE: MATRIX < blas-matrix-base ;
: <MATRIX> ( underlying ld rows cols transpose -- matrix )
MATRIX boa ; inline
M: MATRIX element-type
drop TYPE ;
M: MATRIX (blas-matrix-like)
drop <MATRIX> execute ;
M: VECTOR (blas-matrix-like)
drop <MATRIX> execute ;
M: MATRIX (blas-vector-like)
drop <VECTOR> execute ;
: >MATRIX ( arrays -- matrix )
[ >ARRAY execute underlying>> ] (>matrix)
<MATRIX> execute ;
M: VECTOR n*M.V+n*V!
[ TYPE>ARG execute ] (prepare-gemv)
[ XGEMV execute ] dip ;
M: MATRIX n*M.M+n*M!
[ TYPE>ARG execute ] (prepare-gemm)
[ XGEMM execute ] dip ;
M: MATRIX n*V(*)V+M!
[ TYPE>ARG execute ] (prepare-ger)
[ XGERU execute ] dip ;
M: MATRIX n*V(*)Vconj+M!
[ TYPE>ARG execute ] (prepare-ger)
[ XGERC execute ] dip ;
;FUNCTOR
: define-real-blas-matrix ( TYPE T -- )
"" "" (define-blas-matrix) ;
: define-complex-blas-matrix ( TYPE T -- )
"u" "c" (define-blas-matrix) ;
"float" "s" define-real-blas-matrix
"double" "d" define-real-blas-matrix
"float-complex" "c" define-complex-blas-matrix
"double-complex" "z" define-complex-blas-matrix
>>

View File

@ -1,4 +1,4 @@
USING: kernel math.blas.matrices math.blas.vectors parser
USING: kernel math.blas.vectors math.blas.matrices parser
arrays prettyprint.backend sequences ;
IN: math.blas.syntax
@ -20,15 +20,23 @@ IN: math.blas.syntax
: zmatrix{
\ } [ >double-complex-blas-matrix ] parse-literal ; parsing
M: float-blas-vector pprint-delims drop \ svector{ \ } ;
M: double-blas-vector pprint-delims drop \ dvector{ \ } ;
M: float-complex-blas-vector pprint-delims drop \ cvector{ \ } ;
M: double-complex-blas-vector pprint-delims drop \ zvector{ \ } ;
M: float-blas-vector pprint-delims
drop \ svector{ \ } ;
M: double-blas-vector pprint-delims
drop \ dvector{ \ } ;
M: float-complex-blas-vector pprint-delims
drop \ cvector{ \ } ;
M: double-complex-blas-vector pprint-delims
drop \ zvector{ \ } ;
M: float-blas-matrix pprint-delims drop \ smatrix{ \ } ;
M: double-blas-matrix pprint-delims drop \ dmatrix{ \ } ;
M: float-complex-blas-matrix pprint-delims drop \ cmatrix{ \ } ;
M: double-complex-blas-matrix pprint-delims drop \ zmatrix{ \ } ;
M: float-blas-matrix pprint-delims
drop \ smatrix{ \ } ;
M: double-blas-matrix pprint-delims
drop \ dmatrix{ \ } ;
M: float-complex-blas-matrix pprint-delims
drop \ cmatrix{ \ } ;
M: double-complex-blas-matrix pprint-delims
drop \ zmatrix{ \ } ;
M: blas-vector-base >pprint-sequence ;
M: blas-vector-base pprint* pprint-object ;

View File

@ -119,6 +119,10 @@ M: blas-vector-base virtual-seq
M: blas-vector-base virtual@
[ inc>> * ] [ nip (blas-direct-array) ] 2bi ;
: float>arg ( f -- f ) ; inline
: double>arg ( f -- f ) ; inline
: arg>float ( f -- f ) ; inline
: arg>double ( f -- f ) ; inline
<<
@ -195,8 +199,8 @@ FUNCTOR: (define-complex-helpers) ( TYPE -- )
<DIRECT-COMPLEX-ARRAY> DEFINES <direct-${TYPE}-complex-array>
>COMPLEX-ARRAY DEFINES >${TYPE}-complex-array
ALIEN>COMPLEX DEFINES alien>${TYPE}-complex
COMPLEX>ALIEN DEFINES ${TYPE}-complex>alien
ARG>COMPLEX DEFINES arg>${TYPE}-complex
COMPLEX>ARG DEFINES ${TYPE}-complex>arg
<DIRECT-ARRAY> IS <direct-${TYPE}-array>
>ARRAY IS >${TYPE}-array
@ -206,9 +210,9 @@ WHERE
1 shift <DIRECT-ARRAY> execute <complex-sequence> ;
: >COMPLEX-ARRAY ( sequence -- sequence )
<complex-components> >ARRAY execute ;
: COMPLEX>ALIEN ( complex -- alien )
: COMPLEX>ARG ( complex -- alien )
>rect 2array >ARRAY execute underlying>> ;
: ALIEN>COMPLEX ( alien -- complex )
: ARG>COMPLEX ( alien -- complex )
2 <DIRECT-ARRAY> execute first2 rect> ;
;FUNCTOR
@ -223,28 +227,28 @@ XXNRM2 IS cblas_${S}${C}nrm2
XXASUM IS cblas_${S}${C}asum
XAXPY IS cblas_${C}axpy
XSCAL IS cblas_${C}scal
TYPE>ALIEN IS ${TYPE}>alien
ALIEN>TYPE IS alien>${TYPE}
TYPE>ARG IS ${TYPE}>arg
ARG>TYPE IS arg>${TYPE}
WHERE
M: VECTOR V.
(prepare-dot) TYPE <c-object>
[ XDOTU_SUB execute ] keep
ALIEN>TYPE execute ;
ARG>TYPE execute ;
M: VECTOR V.conj
(prepare-dot) TYPE <c-object>
[ XDOTC_SUB execute ] keep
ALIEN>TYPE execute ;
ARG>TYPE execute ;
M: VECTOR Vnorm
(prepare-nrm2) XXNRM2 execute ;
M: VECTOR Vasum
(prepare-nrm2) XXASUM execute ;
M: VECTOR n*V+V!
[ TYPE>ALIEN execute ] 2dip
[ TYPE>ARG execute ] 2dip
(prepare-axpy) [ XAXPY execute ] dip ;
M: VECTOR n*V!
[ TYPE>ALIEN execute ] dip
[ TYPE>ARG execute ] dip
(prepare-scal) [ XSCAL execute ] dip ;
;FUNCTOR