diff --git a/extra/math/blas/matrices/matrices-docs.factor b/extra/math/blas/matrices/matrices-docs.factor index dc6a86017a..01e0997405 100644 --- a/extra/math/blas/matrices/matrices-docs.factor +++ b/extra/math/blas/matrices/matrices-docs.factor @@ -88,7 +88,7 @@ HELP: blas-matrix-base } "All of these subclasses share the same tuple layout:" { $list - { { $snippet "data" } " contains an alien pointer referencing or byte-array containing a packed, column-major array of float, double, float complex, or double complex values;" } + { { $snippet "underlying" } " contains an alien pointer referencing or byte-array containing a packed, column-major array of float, double, float complex, or double complex values;" } { { $snippet "ld" } " indicates the distance, in elements, between matrix columns;" } { { $snippet "rows" } " and " { $snippet "cols" } " indicate the number of significant rows and columns in the matrix;" } { "and " { $snippet "transpose" } ", if set to a true value, indicates that the matrix should be treated as transposed relative to its in-memory representation." } diff --git a/extra/math/blas/matrices/matrices.factor b/extra/math/blas/matrices/matrices.factor index 0899e2d079..c8a4ee6292 100755 --- a/extra/math/blas/matrices/matrices.factor +++ b/extra/math/blas/matrices/matrices.factor @@ -1,31 +1,13 @@ USING: accessors alien alien.c-types arrays byte-arrays combinators -combinators.lib combinators.short-circuit fry kernel locals macros +combinators.short-circuit fry kernel locals macros math math.blas.cblas math.blas.vectors math.blas.vectors.private -math.complex math.functions math.order multi-methods qualified -sequences sequences.merged sequences.private generalizations -shuffle symbols speicalized-arrays.float specialized-arrays.double ; -QUALIFIED: syntax +math.complex math.functions math.order functors words +sequences sequences.merged sequences.private shuffle symbols +specialized-arrays.direct.float specialized-arrays.direct.double +specialized-arrays.float specialized-arrays.double ; IN: math.blas.matrices -TUPLE: blas-matrix-base data ld rows cols transpose ; -TUPLE: float-blas-matrix < blas-matrix-base ; -TUPLE: double-blas-matrix < blas-matrix-base ; -TUPLE: float-complex-blas-matrix < blas-matrix-base ; -TUPLE: double-complex-blas-matrix < blas-matrix-base ; - -C: float-blas-matrix -C: double-blas-matrix -C: float-complex-blas-matrix -C: double-complex-blas-matrix - -METHOD: element-type { float-blas-matrix } - drop "float" ; -METHOD: element-type { double-blas-matrix } - drop "double" ; -METHOD: element-type { float-complex-blas-matrix } - drop "CBLAS_C" ; -METHOD: element-type { double-complex-blas-matrix } - drop "CBLAS_Z" ; +TUPLE: blas-matrix-base underlying ld rows cols transpose ; : Mtransposed? ( matrix -- ? ) transpose>> ; inline @@ -34,6 +16,11 @@ METHOD: element-type { double-complex-blas-matrix } : Mheight ( matrix -- height ) dup Mtransposed? [ cols>> ] [ rows>> ] if ; inline +GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y ) +GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A ) +GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A ) +GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C ) + ; -METHOD: (blas-matrix-like) { object object object object object double-blas-matrix } - drop ; -METHOD: (blas-matrix-like) { object object object object object float-complex-blas-matrix } - drop ; -METHOD: (blas-matrix-like) { object object object object object double-complex-blas-matrix } - drop ; - -METHOD: (blas-matrix-like) { object object object object object float-blas-vector } - drop ; -METHOD: (blas-matrix-like) { object object object object object double-blas-vector } - drop ; -METHOD: (blas-matrix-like) { object object object object object float-complex-blas-vector } - drop ; -METHOD: (blas-matrix-like) { object object object object object double-complex-blas-vector } - drop ; - -METHOD: (blas-vector-like) { object object object float-blas-matrix } - drop ; -METHOD: (blas-vector-like) { object object object double-blas-matrix } - drop ; -METHOD: (blas-vector-like) { object object object float-complex-blas-matrix } - drop ; -METHOD: (blas-vector-like) { object object object double-complex-blas-matrix } - drop ; - : (validate-gemv) ( A x y -- ) { [ drop [ Mwidth ] [ length>> ] bi* = ] [ nip [ Mheight ] [ length>> ] bi* = ] } 3&& - [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ] unless ; + [ "Mismatched matrix and vectors in matrix-vector multiplication" throw ] + unless ; -:: (prepare-gemv) ( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc y ) +:: (prepare-gemv) + ( alpha A x beta y >c-arg -- order A-trans m n alpha A-data A-ld x-data x-inc beta y-data y-inc + y ) A x y (validate-gemv) CblasColMajor A (blas-transpose) A rows>> A cols>> alpha >c-arg call - A data>> + A underlying>> A ld>> - x data>> + x underlying>> x inc>> beta >c-arg call - y data>> + y underlying>> y inc>> y ; inline @@ -96,19 +59,22 @@ METHOD: (blas-vector-like) { object object object double-complex-blas-matrix } [ nip [ length>> ] [ Mheight ] bi* = ] [ nipd [ length>> ] [ Mwidth ] bi* = ] } 3&& - [ "Mismatched vertices and matrix in vector outer product" throw ] unless ; + [ "Mismatched vertices and matrix in vector outer product" throw ] + unless ; -:: (prepare-ger) ( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld A ) +:: (prepare-ger) + ( alpha x y A >c-arg -- order m n alpha x-data x-inc y-data y-inc A-data A-ld + A ) x y A (validate-ger) CblasColMajor A rows>> A cols>> alpha >c-arg call - x data>> + x underlying>> x inc>> - y data>> + y underlying>> y inc>> - A data>> + A underlying>> A ld>> A f >>transpose ; inline @@ -117,9 +83,13 @@ METHOD: (blas-vector-like) { object object object double-complex-blas-matrix } [ drop [ Mwidth ] [ Mheight ] bi* = ] [ nip [ Mheight ] bi@ = ] [ nipd [ Mwidth ] bi@ = ] - } 3&& [ "Mismatched matrices in matrix multiplication" throw ] unless ; + } 3&& + [ "Mismatched matrices in matrix multiplication" throw ] + unless ; -:: (prepare-gemm) ( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld C ) +:: (prepare-gemm) + ( alpha A B beta C >c-arg -- order A-trans B-trans m n k alpha A-data A-ld B-data B-ld beta C-data C-ld + C ) A B C (validate-gemm) CblasColMajor A (blas-transpose) @@ -128,12 +98,12 @@ METHOD: (blas-vector-like) { object object object double-complex-blas-matrix } C cols>> A Mwidth alpha >c-arg call - A data>> + A underlying>> A ld>> - B data>> + B underlying>> B ld>> beta >c-arg call - C data>> + C underlying>> C ld>> C f >>transpose ; inline @@ -142,65 +112,22 @@ METHOD: (blas-vector-like) { object object object double-complex-blas-matrix } PRIVATE> -: >float-blas-matrix ( arrays -- matrix ) - [ >float-array underlying>> ] (>matrix) ; -: >double-blas-matrix ( arrays -- matrix ) - [ >double-array underlying>> ] (>matrix) ; -: >float-complex-blas-matrix ( arrays -- matrix ) - [ (flatten-complex-sequence) >float-array underlying>> ] (>matrix) - ; -: >double-complex-blas-matrix ( arrays -- matrix ) - [ (flatten-complex-sequence) >double-array underlying>> ] (>matrix) - ; - -GENERIC: n*M.V+n*V! ( alpha A x beta y -- y=alpha*A.x+b*y ) -GENERIC: n*V(*)V+M! ( alpha x y A -- A=alpha*x(*)y+A ) -GENERIC: n*V(*)Vconj+M! ( alpha x y A -- A=alpha*x(*)yconj+A ) -GENERIC: n*M.M+n*M! ( alpha A B beta C -- C=alpha*A.B+beta*C ) - -METHOD: n*M.V+n*V! { real float-blas-matrix float-blas-vector real float-blas-vector } - [ ] (prepare-gemv) [ cblas_sgemv ] dip ; -METHOD: n*M.V+n*V! { real double-blas-matrix double-blas-vector real double-blas-vector } - [ ] (prepare-gemv) [ cblas_dgemv ] dip ; -METHOD: n*M.V+n*V! { number float-complex-blas-matrix float-complex-blas-vector number float-complex-blas-vector } - [ (>c-complex) ] (prepare-gemv) [ cblas_cgemv ] dip ; -METHOD: n*M.V+n*V! { number double-complex-blas-matrix double-complex-blas-vector number double-complex-blas-vector } - [ (>z-complex) ] (prepare-gemv) [ cblas_zgemv ] dip ; - -METHOD: n*V(*)V+M! { real float-blas-vector float-blas-vector float-blas-matrix } - [ ] (prepare-ger) [ cblas_sger ] dip ; -METHOD: n*V(*)V+M! { real double-blas-vector double-blas-vector double-blas-matrix } - [ ] (prepare-ger) [ cblas_dger ] dip ; -METHOD: n*V(*)V+M! { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix } - [ (>c-complex) ] (prepare-ger) [ cblas_cgeru ] dip ; -METHOD: n*V(*)V+M! { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix } - [ (>z-complex) ] (prepare-ger) [ cblas_zgeru ] dip ; - -METHOD: n*V(*)Vconj+M! { real float-blas-vector float-blas-vector float-blas-matrix } - [ ] (prepare-ger) [ cblas_sger ] dip ; -METHOD: n*V(*)Vconj+M! { real double-blas-vector double-blas-vector double-blas-matrix } - [ ] (prepare-ger) [ cblas_dger ] dip ; -METHOD: n*V(*)Vconj+M! { number float-complex-blas-vector float-complex-blas-vector float-complex-blas-matrix } - [ (>c-complex) ] (prepare-ger) [ cblas_cgerc ] dip ; -METHOD: n*V(*)Vconj+M! { number double-complex-blas-vector double-complex-blas-vector double-complex-blas-matrix } - [ (>z-complex) ] (prepare-ger) [ cblas_zgerc ] dip ; - -METHOD: n*M.M+n*M! { real float-blas-matrix float-blas-matrix real float-blas-matrix } - [ ] (prepare-gemm) [ cblas_sgemm ] dip ; -METHOD: n*M.M+n*M! { real double-blas-matrix double-blas-matrix real double-blas-matrix } - [ ] (prepare-gemm) [ cblas_dgemm ] dip ; -METHOD: n*M.M+n*M! { number float-complex-blas-matrix float-complex-blas-matrix number float-complex-blas-matrix } - [ (>c-complex) ] (prepare-gemm) [ cblas_cgemm ] dip ; -METHOD: n*M.M+n*M! { number double-complex-blas-matrix double-complex-blas-matrix number double-complex-blas-matrix } - [ (>z-complex) ] (prepare-gemm) [ cblas_zgemm ] dip ; - ! XXX should do a dense clone -syntax:M: blas-matrix-base clone +M: blas-matrix-base clone [ - [ - { [ data>> ] [ ld>> ] [ cols>> ] [ element-type heap-size ] } cleave - * * memory>byte-array - ] [ { [ ld>> ] [ rows>> ] [ cols>> ] [ transpose>> ] } cleave ] bi + [ { + [ underlying>> ] + [ ld>> ] + [ cols>> ] + [ element-type heap-size ] + } cleave * * memory>byte-array ] + [ { + [ ld>> ] + [ rows>> ] + [ cols>> ] + [ transpose>> ] + } cleave ] + bi ] keep (blas-matrix-like) ; ! XXX try rounding stride to next 128 bit bound for better vectorizin' @@ -246,29 +173,31 @@ syntax:M: blas-matrix-base clone :: (Msub) ( matrix row col height width -- data ld rows cols ) matrix ld>> col * row + matrix element-type heap-size * - matrix data>> + matrix underlying>> matrix ld>> height width ; -: Msub ( matrix row col height width -- sub ) - 5 npick dup transpose>> - [ nip [ [ swap ] 2dip swap ] when (Msub) ] 2keep - swap (blas-matrix-like) ; +:: Msub ( matrix row col height width -- sub ) + matrix dup transpose>> + [ col row width height ] + [ row col height width ] if (Msub) + matrix transpose>> matrix (blas-matrix-like) ; -TUPLE: blas-matrix-rowcol-sequence parent inc rowcol-length rowcol-jump length ; +TUPLE: blas-matrix-rowcol-sequence + parent inc rowcol-length rowcol-jump length ; C: blas-matrix-rowcol-sequence INSTANCE: blas-matrix-rowcol-sequence sequence -syntax:M: blas-matrix-rowcol-sequence length +M: blas-matrix-rowcol-sequence length length>> ; -syntax:M: blas-matrix-rowcol-sequence nth-unsafe +M: blas-matrix-rowcol-sequence nth-unsafe { [ [ rowcol-jump>> ] [ parent>> element-type heap-size ] - [ parent>> data>> ] tri + [ parent>> underlying>> ] tri [ * * ] dip ] [ rowcol-length>> ] @@ -277,11 +206,11 @@ syntax:M: blas-matrix-rowcol-sequence nth-unsafe } cleave (blas-vector-like) ; : (Mcols) ( A -- columns ) - { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] } cleave - ; + { [ ] [ drop 1 ] [ rows>> ] [ ld>> ] [ cols>> ] } + cleave ; : (Mrows) ( A -- rows ) - { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] } cleave - ; + { [ ] [ ld>> ] [ cols>> ] [ drop 1 ] [ rows>> ] } + cleave ; : Mrows ( A -- rows ) dup transpose>> [ (Mcols) ] [ (Mrows) ] if ; @@ -300,11 +229,79 @@ syntax:M: blas-matrix-rowcol-sequence nth-unsafe recip swap n*M ; inline : Mtranspose ( matrix -- matrix^T ) - [ { [ data>> ] [ ld>> ] [ rows>> ] [ cols>> ] [ transpose>> not ] } cleave ] keep (blas-matrix-like) ; + [ { + [ underlying>> ] + [ ld>> ] [ rows>> ] + [ cols>> ] + [ transpose>> not ] + } cleave ] keep (blas-matrix-like) ; -syntax:M: blas-matrix-base equal? +M: blas-matrix-base equal? { [ [ Mwidth ] bi@ = ] [ [ Mcols ] bi@ [ = ] 2all? ] } 2&& ; +<< + +FUNCTOR: (define-blas-matrix) ( TYPE T U C -- ) + +VECTOR IS ${TYPE}-blas-vector + IS <${TYPE}-blas-vector> +>ARRAY IS >${TYPE}-array +TYPE>ARG IS ${TYPE}>arg +XGEMV IS cblas_${T}gemv +XGEMM IS cblas_${T}gemm +XGERU IS cblas_${T}ger${U} +XGERC IS cblas_${T}ger${C} + +MATRIX DEFINES ${TYPE}-blas-matrix + DEFINES <${TYPE}-blas-matrix> +>MATRIX DEFINES >${TYPE}-blas-matrix + +WHERE + +TUPLE: MATRIX < blas-matrix-base ; +: ( underlying ld rows cols transpose -- matrix ) + MATRIX boa ; inline + +M: MATRIX element-type + drop TYPE ; +M: MATRIX (blas-matrix-like) + drop execute ; +M: VECTOR (blas-matrix-like) + drop execute ; +M: MATRIX (blas-vector-like) + drop execute ; + +: >MATRIX ( arrays -- matrix ) + [ >ARRAY execute underlying>> ] (>matrix) + execute ; + +M: VECTOR n*M.V+n*V! + [ TYPE>ARG execute ] (prepare-gemv) + [ XGEMV execute ] dip ; +M: MATRIX n*M.M+n*M! + [ TYPE>ARG execute ] (prepare-gemm) + [ XGEMM execute ] dip ; +M: MATRIX n*V(*)V+M! + [ TYPE>ARG execute ] (prepare-ger) + [ XGERU execute ] dip ; +M: MATRIX n*V(*)Vconj+M! + [ TYPE>ARG execute ] (prepare-ger) + [ XGERC execute ] dip ; + +;FUNCTOR + + +: define-real-blas-matrix ( TYPE T -- ) + "" "" (define-blas-matrix) ; +: define-complex-blas-matrix ( TYPE T -- ) + "u" "c" (define-blas-matrix) ; + +"float" "s" define-real-blas-matrix +"double" "d" define-real-blas-matrix +"float-complex" "c" define-complex-blas-matrix +"double-complex" "z" define-complex-blas-matrix + +>> diff --git a/extra/math/blas/syntax/syntax.factor b/extra/math/blas/syntax/syntax.factor index 6b40910687..95f9f7bd08 100644 --- a/extra/math/blas/syntax/syntax.factor +++ b/extra/math/blas/syntax/syntax.factor @@ -1,4 +1,4 @@ -USING: kernel math.blas.matrices math.blas.vectors parser +USING: kernel math.blas.vectors math.blas.matrices parser arrays prettyprint.backend sequences ; IN: math.blas.syntax @@ -20,15 +20,23 @@ IN: math.blas.syntax : zmatrix{ \ } [ >double-complex-blas-matrix ] parse-literal ; parsing -M: float-blas-vector pprint-delims drop \ svector{ \ } ; -M: double-blas-vector pprint-delims drop \ dvector{ \ } ; -M: float-complex-blas-vector pprint-delims drop \ cvector{ \ } ; -M: double-complex-blas-vector pprint-delims drop \ zvector{ \ } ; +M: float-blas-vector pprint-delims + drop \ svector{ \ } ; +M: double-blas-vector pprint-delims + drop \ dvector{ \ } ; +M: float-complex-blas-vector pprint-delims + drop \ cvector{ \ } ; +M: double-complex-blas-vector pprint-delims + drop \ zvector{ \ } ; -M: float-blas-matrix pprint-delims drop \ smatrix{ \ } ; -M: double-blas-matrix pprint-delims drop \ dmatrix{ \ } ; -M: float-complex-blas-matrix pprint-delims drop \ cmatrix{ \ } ; -M: double-complex-blas-matrix pprint-delims drop \ zmatrix{ \ } ; +M: float-blas-matrix pprint-delims + drop \ smatrix{ \ } ; +M: double-blas-matrix pprint-delims + drop \ dmatrix{ \ } ; +M: float-complex-blas-matrix pprint-delims + drop \ cmatrix{ \ } ; +M: double-complex-blas-matrix pprint-delims + drop \ zmatrix{ \ } ; M: blas-vector-base >pprint-sequence ; M: blas-vector-base pprint* pprint-object ; diff --git a/extra/math/blas/vectors/vectors.factor b/extra/math/blas/vectors/vectors.factor index 56ec773c6a..41fe2b4740 100755 --- a/extra/math/blas/vectors/vectors.factor +++ b/extra/math/blas/vectors/vectors.factor @@ -119,6 +119,10 @@ M: blas-vector-base virtual-seq M: blas-vector-base virtual@ [ inc>> * ] [ nip (blas-direct-array) ] 2bi ; +: float>arg ( f -- f ) ; inline +: double>arg ( f -- f ) ; inline +: arg>float ( f -- f ) ; inline +: arg>double ( f -- f ) ; inline << @@ -195,8 +199,8 @@ FUNCTOR: (define-complex-helpers) ( TYPE -- ) DEFINES >COMPLEX-ARRAY DEFINES >${TYPE}-complex-array -ALIEN>COMPLEX DEFINES alien>${TYPE}-complex -COMPLEX>ALIEN DEFINES ${TYPE}-complex>alien +ARG>COMPLEX DEFINES arg>${TYPE}-complex +COMPLEX>ARG DEFINES ${TYPE}-complex>arg IS >ARRAY IS >${TYPE}-array @@ -206,9 +210,9 @@ WHERE 1 shift execute ; : >COMPLEX-ARRAY ( sequence -- sequence ) >ARRAY execute ; -: COMPLEX>ALIEN ( complex -- alien ) +: COMPLEX>ARG ( complex -- alien ) >rect 2array >ARRAY execute underlying>> ; -: ALIEN>COMPLEX ( alien -- complex ) +: ARG>COMPLEX ( alien -- complex ) 2 execute first2 rect> ; ;FUNCTOR @@ -223,28 +227,28 @@ XXNRM2 IS cblas_${S}${C}nrm2 XXASUM IS cblas_${S}${C}asum XAXPY IS cblas_${C}axpy XSCAL IS cblas_${C}scal -TYPE>ALIEN IS ${TYPE}>alien -ALIEN>TYPE IS alien>${TYPE} +TYPE>ARG IS ${TYPE}>arg +ARG>TYPE IS arg>${TYPE} WHERE M: VECTOR V. (prepare-dot) TYPE [ XDOTU_SUB execute ] keep - ALIEN>TYPE execute ; + ARG>TYPE execute ; M: VECTOR V.conj (prepare-dot) TYPE [ XDOTC_SUB execute ] keep - ALIEN>TYPE execute ; + ARG>TYPE execute ; M: VECTOR Vnorm (prepare-nrm2) XXNRM2 execute ; M: VECTOR Vasum (prepare-nrm2) XXASUM execute ; M: VECTOR n*V+V! - [ TYPE>ALIEN execute ] 2dip + [ TYPE>ARG execute ] 2dip (prepare-axpy) [ XAXPY execute ] dip ; M: VECTOR n*V! - [ TYPE>ALIEN execute ] dip + [ TYPE>ARG execute ] dip (prepare-scal) [ XSCAL execute ] dip ; ;FUNCTOR