diff --git a/basis/math/matrices/matrices-tests.factor b/basis/math/matrices/matrices-tests.factor index 7547e78652..2870385093 100644 --- a/basis/math/matrices/matrices-tests.factor +++ b/basis/math/matrices/matrices-tests.factor @@ -314,3 +314,66 @@ CONSTANT: test-points { test-points cov-matrix ] unit-test +{ + { + { 5 5 } + { 5 5 } + } +} [ + 2 2 5 +] unit-test + +{ + { + { 5 5 } + { 5 5 } + } +} [ + 2 2 [ 5 ] make-matrix +] unit-test + +{ + { + { 0 1 2 } + { 1 2 3 } + } +} [ + 2 3 [ + ] make-matrix-with-indices +] unit-test + +{ + { + { 0 1 } + { 0 1 } + } +} [ + 2 square-rows +] unit-test + +{ + { + { 0 0 } + { 1 1 } + } +} [ + 2 square-cols +] unit-test + +{ + { + { 5 6 } + { 5 6 } + } +} [ + { 5 6 } square-rows +] unit-test + +{ + { + { 5 5 } + { 6 6 } + } +} [ + { 5 6 } square-cols +] unit-test + diff --git a/basis/math/matrices/matrices.factor b/basis/math/matrices/matrices.factor index e5f3dcfe1b..bf9aba23a0 100644 --- a/basis/math/matrices/matrices.factor +++ b/basis/math/matrices/matrices.factor @@ -2,12 +2,19 @@ ! See http://factorcode.org/license.txt for BSD license. USING: accessors arrays columns kernel locals math math.bits math.functions math.order math.vectors sequences -sequences.private fry math.statistics ; +sequences.private fry math.statistics grouping +combinators.short-circuit math.ranges combinators.smart ; IN: math.matrices ! Matrices +: make-matrix ( m n quot -- matrix ) + '[ _ _ replicate ] replicate ; inline + +: ( m n element -- matrix ) + '[ _ _ ] replicate ; inline + : zero-matrix ( m n -- matrix ) - '[ _ 0 ] replicate ; + 0 ; inline : diagonal-matrix ( diagonal-seq -- matrix ) dup length dup zero-matrix @@ -169,38 +176,91 @@ IN: math.matrices : outer ( u v -- m ) [ n*v ] curry map ; -: row ( n m -- col ) +: row ( n matrix -- col ) nth ; inline -: rows ( seq m -- cols ) +: rows ( seq matrix -- cols ) '[ _ row ] map ; inline -: col ( n m -- col ) +: col ( n matrix -- col ) swap '[ _ swap nth ] map ; inline -: cols ( seq m -- cols ) +: cols ( seq matrix -- cols ) '[ _ col ] map ; inline -: matrix-map ( m quot -- ) +: set-index ( object pair matrix -- ) + [ first2 swap ] dip nth set-nth ; inline + +: set-indices ( object sequence matrix -- ) + '[ _ set-index ] with each ; inline + + +: matrix-map ( matrix quot -- ) '[ _ map ] map ; inline -: column-map ( m quot -- seq ) +: column-map ( matrix quot -- seq ) [ [ first length iota ] keep ] dip '[ _ col @ ] map ; inline -: cartesian-indices ( n -- matrix ) +: cartesian-square-indices ( n -- matrix ) iota dup cartesian-product ; inline -: cartesian-matrix-map ( m quot -- m' ) - [ [ first length cartesian-indices ] keep ] dip +: cartesian-matrix-map ( matrix quot -- matrix' ) + [ [ first length cartesian-square-indices ] keep ] dip '[ _ @ ] matrix-map ; inline -: cartesian-matrix-column-map ( m quot -- m' ) +: cartesian-matrix-column-map ( matrix quot -- matrix' ) [ cols first2 ] prepose cartesian-matrix-map ; inline -: cov-matrix-ddof ( m ddof -- cov ) +: cov-matrix-ddof ( matrix ddof -- cov ) '[ _ cov-ddof ] cartesian-matrix-column-map ; inline -: cov-matrix ( m -- cov ) 0 cov-matrix-ddof ; inline +: cov-matrix ( matrix -- cov ) 0 cov-matrix-ddof ; inline -: sample-cov-matrix ( m -- cov ) 1 cov-matrix-ddof ; inline +: sample-cov-matrix ( matrix -- cov ) 1 cov-matrix-ddof ; inline +GENERIC: square-rows ( object -- matrix ) +M: integer square-rows iota square-rows ; +M: sequence square-rows dup [ nip ] cartesian-map ; + +GENERIC: square-cols ( object -- matrix ) +M: integer square-cols iota square-cols ; +M: sequence square-cols dup [ drop ] cartesian-map ; + +: make-matrix-with-indices ( m n quot -- matrix ) + [ [ iota ] bi@ ] dip '[ @ ] cartesian-map ; inline + +: null-matrix? ( matrix -- ? ) empty? ; + +: well-formed-matrix? ( matrix -- ? ) + dup null-matrix? [ + drop t + ] [ + [ ] [ first length ] bi + '[ length _ = ] all? + ] if ; + +: dim ( matrix -- pair/f ) + [ 2 0 ] + [ [ length ] [ first length ] bi 2array ] if-empty ; + +: square-matrix? ( matrix -- ? ) + { [ well-formed-matrix? ] [ dim all-eq? ] } 1&& ; + +: matrix-coordinates ( dim -- coordinates ) + first2 [ iota ] bi@ cartesian-product ; inline + +: dimension-range ( matrix -- dim range ) + dim [ matrix-coordinates ] [ first [1,b] ] bi ; + +: upper-matrix-indices ( matrix -- matrix' ) + dimension-range [ tail-slice* >array ] 2map concat ; + +: lower-matrix-indices ( matrix -- matrix' ) + dimension-range [ head-slice >array ] 2map concat ; + + +: make-lower-matrix ( object m n -- matrix ) + zero-matrix [ lower-matrix-indices ] [ set-indices ] [ ] tri ; + +: make-upper-matrix ( object m n -- matrix ) + zero-matrix [ upper-matrix-indices ] [ set-indices ] [ ] tri ;