GithubHelp home page GithubHelp logo

ebprado / csp-bart Goto Github PK

View Code? Open in Web Editor NEW
4.0 2.0 2.0 8.9 MB

R scripts and data sets that can be used to reproduce the results presented in the paper "Accounting for shared covariates in semi-parametric Bayesian additive regression trees".

R 98.46% TeX 1.54%
semi-parametric-regression simulation real-dataset bayesian-additive-regression-trees generalised-linear-models

csp-bart's Introduction

Accounting for shared covariates in semi-parametric Bayesian additive regression trees

This repository houses R scripts and data sets that can be used to reproduce the results presented in the paper Accounting for shared covariates in semi-parametric Bayesian additive regression trees. arXiv (2022).

In addition, it provides an implementation of CSP-BART in the format of an R package named spbart.

Installation

library(devtools)
install_github("ebprado/CSP-BART/cspbart",ref='main')

Example

library(cspbart)
rm(list = ls())

# ---------------------------------
# CSP-BART for a continuous response
# ---------------------------------

# Simulate from Friedman equation
friedman_data = function(n, num_cov, sd_error){
  x = matrix(runif(n*num_cov),n,num_cov)
  y = 10*sin(pi*x[,1]*x[,2]) + 20*(x[,3]-.5)^2+10*x[,4]+5*x[,5] + rnorm(n, sd=sd_error)
  return(list(y = y,
              x = as.data.frame(x)))
}

n = 200
ncov = 5
var = 1
data = friedman_data(n, ncov, sqrt(var))

y = data$y # response variable
X1 = as.data.frame(cbind(y,data$x)) # all covariates + the response
X2 = data$x # all covariates

# Run the semi-parametric BART (WITHOUT intercept)
cspbart.fit = cspbart(formula = y ~ 0 + V4 + V5, x1 = X1, x2 = X2, ntrees = 10, nburn = 2000, npost = 1000)

# Run the semi-parametric BART (WITH intercept)
# cspbart.fit = cspbart(formula = y ~ V4 + V5, x1 = X1, x2 = X2, ntrees = 10, nburn = 2000, npost = 1000)

# Calculate the predicted values (yhat) and parameter estimates (betahat)
yhat = colMeans(spbart.fit$y_hat)
betahat = colMeans(spbart.fit$beta_hat)

# Predict on a new dataset
yhat_pred = predict(cspbart.fit, newdata_x1 = X1, newdata_x2 = X2, type = 'mean')
cor(yhat,yhat_pred) == 1

# Plot 
plot(y, yhat);abline(0,1)
plot(1:2, c(10,5), main = 'True versus estimates', ylim=c(3,12))
points(1:2, betahat, col=2, pch=2)

# -----------------------------
# CSP-BART for a binary response
# -----------------------------
n = 200
ncov = 5
var = 1
data = friedman_data(n, ncov, sqrt(var))

aux = data$y
y = ifelse(aux > median(aux), 1, 0)
X1 = as.data.frame(cbind(y,data$x)) # all covariates + the response
X2 = data$x # all covariates

# Run the semi-parametric BART (WITH intercept)
cspbart.fit = cspbart::cl_cspbart(formula = y ~ V4 + V5, x1 = X1, x2 = X2, ntrees = 1, nburn = 2000, npost = 1000)

# Calculate the predicted values (yhat) and parameter estimates (betahat)
yhat = colMeans(pnorm(cspbart.fit$y_hat))
betahat = colMeans(cspbart.fit$beta_hat)

# Predict on a new dataset
yhat_pred = predict(spbart.fit, newdata_x1 = X1, newdata_x2 = X2, type = 'mean')
cor(yhat,yhat_pred) == 1

csp-bart's People

Contributors

ebprado avatar keefe-murphy avatar

Stargazers

 avatar  avatar  avatar  avatar

Watchers

 avatar  avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.