Catalina Jerez
catalina.jerez@colorado.edu


For the All India Summer Rainfall, with the tropical SST PCs as covariates, fit these models:

  1. Classification and Regression Trees (CART)

  2. Gradient Boosting Machine (GBM)

  3. Random Forest (RF)

1 Library & functions

1.1 Libraries

gc()
##           used (Mb) gc trigger (Mb) limit (Mb) max used (Mb)
## Ncells  543335 29.1    1202244 64.3         NA   700282 37.4
## Vcells 1005921  7.7    8388608 64.0      36864  1963597 15.0
rm(list = ls())

# Load necessary libraries
spatialAnalysis.Lib     = c()
statisAnalysis.Lib      = c("caret")
dataManipulation.Lib    = c("dplyr", "reshape2", "tidyr") 
dataVisualization.Lib   = c("ggplot2", "ggrepel", "ggpubr", "RColorBrewer",
                             "ggraph", "igraph")
modeling.Lib            = c("hydroGOF", "xgboost", "rpart", "randomForest", "rpart.plot",
                            "randomForestSRC", "data.tree", "DiagrammeR", 
                            "mclust", "tree")
list.packages           = unique(c(spatialAnalysis.Lib, statisAnalysis.Lib,
                                   modeling.Lib,
                                   dataManipulation.Lib, dataVisualization.Lib))
# Load all required packages
sapply(list.packages, require, character.only = TRUE)
##           caret        hydroGOF         xgboost           rpart    randomForest 
##            TRUE            TRUE            TRUE            TRUE            TRUE 
##      rpart.plot randomForestSRC       data.tree      DiagrammeR          mclust 
##            TRUE            TRUE            TRUE            TRUE            TRUE 
##            tree           dplyr        reshape2           tidyr         ggplot2 
##            TRUE            TRUE            TRUE            TRUE            TRUE 
##         ggrepel          ggpubr    RColorBrewer          ggraph          igraph 
##            TRUE            TRUE            TRUE            TRUE            TRUE
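
If any entry in this table shows FALSE, a minimal sketch like the following (run before the sapply call above; it assumes all the packages are available on CRAN) installs whatever is missing:

# install any listed packages that are not yet present (assumes CRAN)
new.packages = setdiff(list.packages, rownames(installed.packages()))
if (length(new.packages) > 0) install.packages(new.packages)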

1.2 Functions

# tune functions ---------------------------------------------------------------
# tune xgboost
tune.xgb = function(y, df, nsim, eta = 0.1, max_depth = 3, subsample = 0.8, 
                    colsample_bytree = 0.8, min_child_weight = 1,
                    nfold = 5, max.retries = 5) {
  success = FALSE
  attempt = 0
  while (!success && attempt < max.retries) {
    tryCatch({
      # Build the xgboost training matrix
      df.train  = df[, !names(df) %in% y]  # excluding the target variable 
      df.y      = df[[y]]
      dtrain    = xgb.DMatrix(data = as.matrix(df.train), label = df.y)
      
      # XGBoost Parameters
      params             = list(
        booster          = "gbtree",
        eta              = eta,
        max_depth        = max_depth,
        # subsample        = subsample,
        # colsample_bytree = colsample_bytree,
        min_child_weight = min_child_weight,
        objective        = "reg:squarederror"
      )
      
      # Train with Cross-Validation
      xgb.cv.out = xgb.cv(params  = params, 
                          data    = dtrain, 
                          nrounds = nsim, 
                          nfold   = nfold, 
                          early_stopping_rounds = 10, 
                          verbose = 0)
      
      # Get Best Number of Trees
      best.tree    = xgb.cv.out$best_iteration
      
      # Final Model Training with Optimal Trees
      gbmXGB.train = xgboost(params  = params, 
                             data    = dtrain, 
                             nrounds = best.tree, 
                             verbose = 0)
      
      # Store Best Tree Count
      gbmXGB.train$best.trees = best.tree
      
      success = TRUE
    }, error = function(e) {
      # <<- updates the counter in the enclosing function environment;
      # a local assignment would leave `attempt` at 0 and retry forever
      attempt <<- attempt + 1
      message(paste("Error occurred:", e$message, "- Retrying... (Attempt:", attempt, "of", max.retries, ")"))
    })
  }
  
  if (!success) {
    stop("XGBoost training failed after ", max.retries, " attempts.")
  }
  return(gbmXGB.train)
}

# Tune CART (rpart + caret)
tune.cart   = function(y, df, nfold = 5) {
  ctrl.cv   = trainControl(method = "cv", number = nfold)
  cart.grid = expand.grid(cp = seq(0.001, 0.1, length.out = 10))
  model     = train(as.formula(paste(y, "~ .")),
                    data      = df,
                    method    = "rpart",
                    trControl = ctrl.cv,
                    tuneGrid  = cart.grid,
                    metric    = "RMSE")

  return(model)
}

# Tune Random Forest
tune.rf = function(y, df, nsim, stepFactor = 1.5, improve = 0.01, trace = TRUE) {
 #  y    = "rainfall"
 # df    = train.data[,-1]
 # nsim  = ntree
  preds = df[, !(names(df) %in% y)]
  yval  = df[[y]]

  # Tune mtry
  best.model = tuneRF(x          = preds, 
                      y          = yval,
                      ntreeTry   = nsim,
                      stepFactor = stepFactor,
                      improve    = improve,
                      trace      = trace,
                      plot       = FALSE)
  best.mtry   = best.model[which.min(best.model[, 2]), 1]
  # final RF model
  final.model = randomForest(x          = preds, 
                             y          = yval,
                             mtry       = best.mtry,
                             ntree      = nsim,
                             importance = TRUE)
  return(final.model)
}

# functions for graphs ----------------------------------------------------------
# constants 
text.size    = 15
customize.bw = theme_bw(base_family = "Times") +
  theme(
    axis.text.x        = element_text(size = text.size, vjust = 0.5),
    axis.text.y        = element_text(size = text.size),
    axis.title         = element_text(size = text.size + 2, face = "bold"),
    plot.title         = element_text(size = text.size + 2, face = "bold"),
    
    legend.position    = "bottom",
    legend.title       = element_text(size = text.size, face = "bold"),
    legend.text        = element_text(size = text.size),
    
    panel.border       = element_rect(color = "#000000", fill = NA, linewidth = 0.5),
    panel.grid.minor.x = element_line(color = "#d9d9d9", linetype = "dotted"),
    panel.grid.minor.y = element_line(color = "#d9d9d9", linetype = "dotted"),
    panel.grid.major.x = element_blank(),
    panel.grid.major.y = element_blank()
  )
# variable importance function
varImp.plot = function(df.imp, title, imp = T, text.size = 14){
  df.imp      = df.imp[order(-df.imp$importance), ]
  if(imp){ # for random forest
    imp.label = "Relative importance (%)"
  }else{ # for gbm
    imp.label = "Relative Influence (%)"
  }
  
  df.imp$label = paste0(df.imp$covariate, " = ", round(df.imp$importance, 2), "%")
  
  ggplot(df.imp, aes(x = reorder(covariate, importance), y = importance)) +
    geom_bar(stat = "identity", color = "#40004b", fill = "#c2a5cf") +
    geom_text(aes(label = label, y = 0.05 * max(importance)),
              hjust = 0, color ="#40004b", size = 6, fontface = "bold") +  
    
    coord_flip() +  # Flip coordinates
    labs(title = title, x = NULL, y = imp.label) +
    theme_bw(base_family = "Times") +
    theme(axis.text.x        = element_text(size = text.size, vjust = 0.5),
          axis.text.y        = element_blank(),
          axis.ticks.y       = element_blank(),
          axis.title         = element_text(size = text.size + 1, face = "bold"),
          plot.title         = element_text(size = text.size + 1, face = "bold", hjust = .5),
          panel.border       = element_rect(color = "#000000", fill = NA, linewidth = .5),
          panel.grid.minor.x = element_line(color = "#d9d9d9", linetype = "dotted"),
          panel.grid.minor.y = element_line(color = "#d9d9d9", linetype = "dotted"),
          panel.grid.major.x = element_blank(),
          panel.grid.major.y = element_blank() )
}

# tree function for rf
# adapted from https://shiring.github.io/machine_learning/2017/03/16/rf_plot_ggraph
treeRF.plot = function(final_model, tree_num) {
  
  # get tree by index
  tree = randomForest::getTree(final_model, k = tree_num, labelVar = TRUE) %>%
    tibble::rownames_to_column() %>%
    # set split points at leaf nodes to NA so their 0s are not plotted
    mutate(`split point` = ifelse(is.na(prediction), `split point`, NA))
  # prepare data frame for graph
  graph.frame = data.frame(from = rep(tree$rowname, 2),
                           to   = c(tree$`left daughter`, tree$`right daughter`))
  # convert to graph and delete the last node that we don't want to plot
  graph = graph_from_data_frame(graph.frame) %>%
    delete_vertices("0")
  
  # set node labels
  V(graph)$node_label = gsub("_", " ", as.character(tree$`split var`))
  V(graph)$leaf_label = as.character(round(tree$prediction, 2))
  V(graph)$split      = as.character(round(tree$`split point`, digits = 2))
  
  # plot
  ggraph(graph, 'dendrogram') + 
    geom_edge_link() +
    geom_node_point() +
    geom_node_text(aes(label = node_label), na.rm = TRUE, repel = TRUE) +
    geom_node_label(aes(label = split), vjust = 2.5, na.rm = TRUE, fill = "white") +
    geom_node_label(aes(label = leaf_label, fill = leaf_label), na.rm = TRUE, 
                    repel = TRUE, colour = "white", fontface = "bold", show.legend = FALSE) +
    customize.bw + 
    theme(panel.grid.minor = element_blank(),
          panel.grid.major = element_blank(),
          panel.background = element_blank(),
          plot.background = element_rect(fill = "white"),
          panel.border = element_blank(),
          axis.line = element_blank(),
          axis.text.x = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks = element_blank(),
          axis.title.x = element_blank(),
          axis.title.y = element_blank(),
          plot.title = element_text(size = 18))
}

2 Load data and pre-process

2.1 Rajeevan Gridded rainfall

Rainfall data for JJAS from 1950 to 2016 are read and reshaped into a matrix with one row per season (year) and one column per spatial location; only valid (non-missing) grid points are retained.

path.data = "https://civil.colorado.edu/~balajir/CVEN6833/HWs/HW-2/Spring-2025-data-commands/"
# India 0.5-deg grid
nrows = 68    # number of longitudes (= nx below)
ncols = 65    # number of latitudes  (= ny below)
ntime = 67    # JJAS seasons, Jun-Sep 1950 - 2016

# Lat - Long grid
ygrid = seq(6.5 , 38.5, by=0.5)
xgrid = seq(66.5, 100 , by=0.5)
ny    = length(ygrid)
nx    = length(xgrid)
xygrid= matrix(0, nrow = nx*ny, ncol=2)
i = 0
for(iy in 1:ny){
  for(ix in 1:nx){
  i           = i+1
  xygrid[i,1] = ygrid[iy]
  xygrid[i,2] = xgrid[ix]
  }
}
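
The nested loop fills xygrid with latitude in column 1 and longitude in column 2, longitude varying fastest. An equivalent vectorized construction (a sketch producing the same row order) is:

# longitude varies fastest within each latitude block, as in the loop above
xygrid = cbind(rep(ygrid, each = nx), rep(xgrid, times = ny))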

# read India location valid points
rain.locs = read.table(paste0(path.data, "India-rain-locs.txt"))
# read data binary
data.raw = readBin(paste0(path.data, "India-Rain-JJAS-05deg-1950-2016.r4"),
                       what   = "numeric",
                       n      = (nx * ny * ntime),
                       size   = 4,
                       endian = "swap")
# reshape to (ny, nx, ntime)
data.array   = array(data.raw, dim = c(ny, nx, ntime))
rain.indices = rain.locs[, 3]
nsites       = length(rain.indices)

# construct rain.avg (ntime rows × valid rain sites columns)
rain.avg     = matrix(NA, nrow = ntime, ncol = nsites)
for (t in 1:ntime) {
  rain.slice    = data.array[, , t]
  rain.avg[t, ] = rain.slice[rain.indices]
}

rain.mean = rowMeans(rain.avg)
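
Before modeling, a quick sanity check on the All India series can catch read or reshape errors; a minimal sketch (axis labels are illustrative; units as stored in the binary file):

summary(rain.mean)  # no NAs expected once only valid points are kept
plot(1950:2016, rain.mean, type = "l",
     xlab = "Year", ylab = "All India JJAS rainfall",
     main = "All India Summer Rainfall, 1950 - 2016")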

2.2 Sea Surface Temperature fields

SST data for the tropical oceans (the grid spans 0-358°E) are loaded and reshaped; only valid SST points are retained, based on the provided locations.

path.data = "https://civil.colorado.edu/~balajir/CVEN6833/HWs/HW-2/Spring-2025-data-commands/"
# grid SST (NOAA tropical)
ntime     = 67    # Jun-Sep 1950 - 2016
ygrid.sst = seq(-16, 16 , by = 2)
xgrid.sst = seq(0  , 358, by = 2)
ny.sst    = length(ygrid.sst)
nx.sst    = length(xgrid.sst)

# read SST valid points
sst.locs = read.table(paste0(path.data, "NOAA-trop-sst-locs.txt"))
# read SST binary
data.raw = readBin(paste0(path.data, "NOAA-Trop-JJAS-SST-1950-2016.r4"),
                       what   = "numeric",
                       n      = (nx.sst * ny.sst * ntime),
                       size   = 4,
                       endian = "swap")
# reshape to (ny, nx, ntime)
sst.array   = array(data.raw, dim = c(ny.sst, nx.sst, ntime))
sst.indices = sst.locs[, 3]
nsites      = length(sst.indices)

# construct sst.avg (ntime rows × valid SST sites columns)
sst.avg     = matrix(NA, nrow = ntime, ncol = nsites)
for (t in 1:ntime) {
  sst.slice    = sst.array[, , t]
  sst.avg[t, ] = sst.slice[sst.indices]
}
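
A one-line consistency check (a sketch) confirms that the reshaped SST matrix matches the valid-point list:

stopifnot(nrow(sst.avg) == ntime, ncol(sst.avg) == nrow(sst.locs))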

3 Fitting models

3.1 PCA on SSTs

sst.pca = prcomp(sst.avg, scale. = TRUE, center = TRUE)
sst.pcs = sst.pca$x[, 1:5]
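
We retain the first five PCs; a short sketch using only the prcomp output shows how much SST variance they capture:

# fraction of total SST variance explained by the leading PCs
var.frac = sst.pca$sdev^2 / sum(sst.pca$sdev^2)
round(cumsum(var.frac)[1:5], 3)  # cumulative variance of PC1 - PC5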

3.2 Combine predictors and target

set.seed(250504) # set the seed before partitioning so the split is reproducible
data       = data.frame(year = 1950:2016, rainfall = rain.mean, sst.pcs)
train.id   = createDataPartition(data$rainfall, p = 0.6, list = FALSE)
train.data = data[train.id, ]
test.data  = data[-train.id, ]

3.3 Train models

We’ll fit the models using the following packages:

  1. rpart: Recursive Partitioning and Regression Trees (Breiman et al., 1984) for CART.

  2. xgboost: eXtreme Gradient Boosting Training (Chen and Guestrin, 2016) for GBM.

  3. randomForest: Classification and Regression with Random Forest, and randomForestSRC: Fast Unified Random Forests for Survival, Regression, and Classification, for RF.

ntree = 500
# CART with rpart pkg
cart.model = tune.cart(y   = "rainfall", df    = train.data[,-1])
# GBM with xgboost pkg
gbm.model  = tune.xgb(y    = "rainfall",
                     df    = train.data[,-1], 
                     nsim  = ntree)

# RF with randomForest pkg
rf.model   = tune.rf(y    = "rainfall",
                     df    = train.data[,-1], 
                     nsim  = ntree)
## mtry = 1  OOB error = 5277.358 
## Searching left ...
## Searching right ...
# RF-SRC with randomForestSRC pkg
rfsrc.model = rfsrc(rainfall ~ ., 
                    data       = train.data[,-1], 
                    ntree      = ntree,
                    importance = TRUE,
                    bootstrap  = "by.root",
                    samptype   = "swr")

3.4 Variable importance

# CART .........................................................................
imp.cart     = varImp(cart.model)$importance
shifted      = imp.cart - min(imp.cart)
imp.cart     = 100 * shifted / sum(shifted) # rescale to percentages (sum = 100)
varImp.cart  = data.frame(
  covariate  = rownames(imp.cart),
  importance = imp.cart[,1] )
# varImp.plot(df.imp = varImp.cart, title = "CART", imp = T)

# GBM ..........................................................................
# relative influence (Friedman 2001)
imp.gbm    = xgb.importance(gbm.model$feature_names, model = gbm.model)
varImp.gbm = data.frame(covariate  = imp.gbm$Feature, 
                        importance = 100 * imp.gbm$Gain) # Gain sums to 1; convert to %
# varImp.plot(df.imp = varImp.gbm, title  = "GBM", imp = F)

# RF ...........................................................................
# We preferred %IncMSE (Percentage Increase in Mean Squared Error) 
# because it directly measures how much removing a variable affects predictions.
imp.rf     = varImp(rf.model, type = 1) 
shifted    = imp.rf[,1] - min(imp.rf[,1]) # make all values non-negative
imp.rf[,1] = 100 * shifted / sum(shifted) # rescale to percentages (sum = 100)
varImp.rf  = data.frame(covariate = rownames(imp.rf),
                       importance = imp.rf[,1] ) 
# varImp.plot(df.imp = varImp.rf, title  = "RF", imp = T)

# RF-SRC .......................................................................
imp.rfsrc    = rfsrc.model$importance
imp.rfsrc    = 100 * imp.rfsrc / sum(imp.rfsrc) # normalize so that the sum equals 100
varImp.rfsrc = data.frame(covariate  = names(imp.rfsrc),
                          importance = imp.rfsrc ) 
# varImp.plot(df.imp = varImp.rfsrc, title  = "RF-src", imp = T)

# ggarrange
ggarrange(varImp.plot(df.imp = varImp.cart , title = "CART"   , imp = T),
          varImp.plot(df.imp = varImp.gbm  , title  = "GBM"   , imp = F),
          varImp.plot(df.imp = varImp.rf   , title  = "RF"    , imp = T),
          varImp.plot(df.imp = varImp.rfsrc, title  = "RF-src", imp = T)
          )

Each model uses the five PCs as predictors, and all of them consistently identify PC2 (followed by PC1) as the most important variable for predicting rainfall; the sketch below compares the full rankings.
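
To compare the rankings side by side, a small sketch (reusing the varImp.* data frames built above, and assuming each one contains all five PCs) lists every model's ordering from most to least important:

# list each model's covariates from most to least important
sapply(list(CART = varImp.cart, GBM    = varImp.gbm,
            RF   = varImp.rf,   RF.src = varImp.rfsrc),
       function(d) as.character(d$covariate[order(-d$importance)]))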

3.5 Decision (best) Tree

# CART tree
rpart.plot(cart.model$finalModel, main = "CART Decision Tree", box.palette = "BuGn")

CART Decision Tree:

  • Shallow, uses \(PC5 < -2\) and \(PC1 \geq 29\) as first splits.

  • Very simple, which helps explain its underperformance.

# GBM with best tree
xgb.plot.tree(model = gbm.model, 
              trees = gbm.model$best.trees-1, 
              show_node_id = T,
              render = T)
xgbTree = xgb.plot.tree(model = gbm.model, trees = gbm.model$best.trees-1, 
              show_node_id = T, render = F)
# export_graph needs a file name to write to; this path is illustrative
export_graph(xgbTree, file_name = "gbm-best-tree.pdf")

GBM Tree:

  • Uses multiple PCs with deep branching and high gain values.

  • More nuanced and optimized.

# RF tree 
itree = 10
treeRF.plot(rf.model, itree)

itree = 10
tree  = randomForestSRC::get.tree(rfsrc.model, tree.id = itree, show.plots = T)
plot(tree)

RF Tree:

  • Splits on PCs like PC3, PC4, PC1 with leaf values showing different predicted rainfall.

  • Ensemble behavior is not fully captured by any single tree; the sketch below shows how much individual trees disagree.
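
A minimal sketch using randomForest's predict.all option shows the spread of per-tree predictions around the ensemble mean:

# per-tree predictions for the test years; the ensemble averages over them
pred.all = predict(rf.model, test.data[,-1], predict.all = TRUE)
tree.sd  = apply(pred.all$individual, 1, sd)  # spread across the ntree trees
summary(tree.sd)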

3.6 Predictions and metrics

pred.cart  = predict(cart.model$finalModel, test.data[,-1])
dtest      = xgb.DMatrix(data  = as.matrix(dplyr::select(test.data[,-1], -rainfall)), 
                        label = test.data$rainfall)
pred.gbm   = predict(gbm.model, newdata = dtest)
pred.rf    = predict(rf.model, test.data[,-1])
pred.rfsrc = predict(rfsrc.model, test.data[,-1])

df.pred    = data.frame(
  "year"   = test.data$year,
  "rain"   = test.data$rainfall,
  "CART"   = pred.cart,
  "GBM"    = pred.gbm,
  "RF"     = pred.rf,
  "RF.src" = pred.rfsrc$predicted)

# compute RMSE and R2
models      = c("CART", "GBM", "RF", "RF.src")
# hydroGOF::rmse(sim, obs) and caret::R2(pred, obs) take predictions first
rmse.values = sapply(df.pred[, models], function(pred) rmse(pred, df.pred$rain))
r2.values   = sapply(df.pred[, models], function(pred) R2(pred, df.pred$rain))
skill.df    = data.frame(
  Model = models,
  RMSE  = round(rmse.values, 2),
  R2    = round(as.numeric(r2.values), 2)
)
print(skill.df)
##         Model  RMSE    R2
## CART     CART 84.94 -0.02
## GBM       GBM 68.16  0.34
## RF         RF 70.98  0.29
## RF.src RF.src 72.05  0.26
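
As a consistency check on the RF test skill, the out-of-bag error stored by randomForest provides an internal estimate (rf.model$mse holds the running OOB MSE, so its last element corresponds to the full forest):

# OOB RMSE of the final forest; broadly comparable to the test RMSE above
sqrt(tail(rf.model$mse, 1))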

3.7 Observed vs Predicted

df.long        = tidyr::pivot_longer(df.pred, cols = models, names_to = "Model", values_to = "Predicted")
skill.df$label = paste0("RMSE = ", skill.df$RMSE, "\nR² = ", skill.df$R2)

label.df = skill.df %>% dplyr::select(Model, label)
df.long  = left_join(df.long, label.df, by = "Model")

ggplot(df.long, aes(x = rain, y = Predicted, fill = Model)) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  
  # point data
  geom_point(size = 3, alpha = 0.7, shape = 21) +
  labs(title = "Observed vs Predicted Rainfall", 
       x = "Observed", y = "Predicted") +
  
  # skill
  geom_text(data = label.df, aes(x = Inf, y = -Inf, label = label), 
            hjust = 1.1, vjust = -0.5, inherit.aes = FALSE) +
  
  facet_wrap(~ Model, scales = "free", ncol = 2) +
  # theme
  customize.bw + 
  theme(strip.text = element_text(size = text.size),
        strip.background = element_blank()) 

  • CART: large deviations; many predictions are far off the mark, consistent with its poor RMSE and \(R^2\).

  • GBM: points align closely with the 45° line, indicating accurate predictions.

  • RF and RF.src: slightly more scatter around the line, but still reasonable.
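
A quick follow-up sketch computes the mean bias (predicted minus observed) per model; values near zero indicate little systematic over- or under-prediction:

# mean bias (predicted - observed) per model; near zero means unbiased
round(colMeans(df.pred[, models] - df.pred$rain), 2)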

4 Final remarks

Criterion          Best Model
RMSE, \(R^2\)      GBM
Model Complexity   GBM > RF > RF.src > CART
Best Visual Fit    GBM
Simplicity         CART