Status: Open
Labels: help wanted (Extra attention is needed), question (Further information is requested)
Description
In some simulation studies, stochtree (R package) is dramatically slower on Windows than on macOS or Linux. All machines have performant hardware, though the comparison is of course not completely apples-to-apples. Features that appear to trigger this performance gap:
- Data generating processes that encourage deep trees (because of deep interactions between features)
- Large sample sizes that support growing deep trees
- A large cutpoint grid in the grow-from-root algorithm (the `cutpoint_grid_size` parameter in the `stochtree::bart` and `stochtree::bcf` function signatures)
To reproduce this performance gap, run the following code on both Windows and macOS/Linux:
```r
# Load libraries
library(stochtree)
library(rnn)  # provides int2bin()

# Random seed
random_seed <- 1234
set.seed(random_seed)

# Fixed parameters
sample_size <- 500000
alpha <- 1.0
beta <- 0.1
ntree <- 50
num_iter <- 10
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 10
cutpoint_grid_size <- 10
min_samples_leaf <- 1
nu <- 3
lambda <- NULL
q <- 0.9
sigma2_init <- NULL
sample_tau <- FALSE
sample_sigma <- TRUE

# Initial DGP setup
n0 <- 500
p <- 10
n <- n0 * (2^p)
k <- 2
p1 <- 20
noise <- 0.1

# Full factorial covariate reference frame
xtemp <- as.data.frame(as.factor(rep(0:(2^p - 1), n0)))
xtemp1 <- rep(0:(2^p - 1), n0)
x <- t(sapply(xtemp1, function(j) as.numeric(int2bin(j, p))))
X_superset <- x * abs(rnorm(length(x))) - (1 - x) * abs(rnorm(length(x)))

# Generate outcome
M <- model.matrix(~ . - 1, data = xtemp)
M <- cbind(rep(1, n), M)
beta.true <- -10 * abs(rnorm(ncol(M)))
beta.true[1] <- 0.5
non_zero_betas <- c(1, sample(1:ncol(M), p1 - 1))
beta.true[-non_zero_betas] <- 0
Y <- M %*% beta.true + rnorm(n, 0, noise)
y_superset <- as.numeric(Y > 0)

# Downsample to desired n: a random subset of row indices, kept in ascending order
subset_inds <- sort(sample(1:nrow(X_superset), sample_size, replace = FALSE))
X <- X_superset[subset_inds, ]
y <- y_superset[subset_inds]

# Time a single bart() fit
system.time({
  bart_obj <- stochtree::bart(
    X_train = X, y_train = y,
    alpha = alpha, beta = beta,
    min_samples_leaf = min_samples_leaf,
    nu = nu, lambda = lambda, q = q,
    sigma2_init = sigma2_init,
    num_trees = ntree, num_gfr = num_gfr,
    num_burnin = num_burnin, num_mcmc = num_mcmc,
    cutpoint_grid_size = cutpoint_grid_size,
    sample_tau = sample_tau, sample_sigma = sample_sigma,
    random_seed = random_seed
  )
})
```
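To help isolate which of the features listed above drives the gap, one option is to time the same fit across a grid of a single parameter (for example `cutpoint_grid_size`) on each operating system and compare the resulting tables. The helper below is a minimal sketch in base R; `time_over_grid` is a hypothetical name and is not part of stochtree. It runs `fit_fn` once per grid value and records wall-clock seconds together with the OS name reported by `Sys.info()`.

```r
# Hypothetical helper (not part of stochtree): call fit_fn(g) once for each
# value g in grid and record the elapsed wall-clock time in seconds.
time_over_grid <- function(fit_fn, grid) {
  elapsed <- vapply(grid, function(g) {
    # system.time() returns a named vector; keep only the "elapsed" entry
    unname(system.time(fit_fn(g))[["elapsed"]])
  }, numeric(1))
  data.frame(
    value = grid,
    elapsed_sec = elapsed,
    os = Sys.info()[["sysname"]]  # e.g. "Windows", "Darwin", "Linux"
  )
}
```

As a usage sketch, after running the reproduction script, pass a wrapper that calls `stochtree::bart()` with the same arguments as above except `cutpoint_grid_size = g`, over a grid such as `c(10, 50, 100)`. Each run is a full fit, so a sweep multiplies the total runtime accordingly.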