diff --git a/extra/machine-learning/rebalancing/rebalancing.factor b/extra/machine-learning/rebalancing/rebalancing.factor index c8c77f6f41..d95029f197 100644 --- a/extra/machine-learning/rebalancing/rebalancing.factor +++ b/extra/machine-learning/rebalancing/rebalancing.factor @@ -22,12 +22,14 @@ MEMO: probabilities-seq ( seq -- seq' ) : stratified-sample ( stratified-sequences probability-sequence -- elt ) probabilities-quot call swap nth random ; inline +: equal-stratified-sample ( stratified-sequences -- elt ) + random random ; inline + : balance-labels ( X y n -- X' y' ) [ dup [ ] collect-index-by - values dup length equal-probabilities - '[ - _ _ _ _ stratified-sample + values '[ + _ _ _ equal-stratified-sample '[ _ swap nth ] bi@ 2array ] ] dip swap replicate [ keys ] [ values ] bi ;