Catalina Jerez
catalina.jerez@colorado.edu


Fit a CART model to the leading 4 PCs of summer season rainfall over India, using 5 leading PCs of summer global tropical SSTs as covariates.

  1. Fit a CART model for each PC separately.

  2. Estimate the modeled precipitation at all locations by multiplying the estimated PCs by the eigenvectors.

1 Library & functions

1.1 Libraries

# Run a garbage collection; the memory-usage table it returns is printed below
gc()
##           used (Mb) gc trigger (Mb) limit (Mb) max used (Mb)
## Ncells  543260 29.1    1202030 64.2         NA   700282 37.4
## Vcells 1005469  7.7    8388608 64.0      36864  1963597 15.0
# Remove every (non-hidden) object from the current environment so the
# analysis starts from an empty workspace
rm(list = ls())

# Load necessary libraries
spatialAnalysis.Lib     = c("maps", "rnaturalearth", "sf")
statisAnalysis.Lib      = c("caret", "DescTools")
dataManipulation.Lib    = c("dplyr", "reshape2", "tidyr") 
dataVisualization.Lib   = c("ggplot2", "ggrepel", "ggpubr", "RColorBrewer",
                             "ggraph", "igraph")
modeling.Lib            = c("hydroGOF", "xgboost", "rpart", "randomForest", "rpart.plot",
                            "randomForestSRC", "data.tree", "DiagrammeR", 
                            "mclust", "tree")
list.packages           = unique(c(spatialAnalysis.Lib, statisAnalysis.Lib,
                                   modeling.Lib,
                                   dataManipulation.Lib, dataVisualization.Lib))
# Load all required packages
sapply(list.packages, require, character.only = TRUE)
##            maps   rnaturalearth              sf           caret       DescTools 
##            TRUE            TRUE            TRUE            TRUE            TRUE 
##        hydroGOF         xgboost           rpart    randomForest      rpart.plot 
##            TRUE            TRUE            TRUE            TRUE            TRUE 
## randomForestSRC       data.tree      DiagrammeR          mclust            tree 
##            TRUE            TRUE            TRUE            TRUE            TRUE 
##           dplyr        reshape2           tidyr         ggplot2         ggrepel 
##            TRUE            TRUE            TRUE            TRUE            TRUE 
##          ggpubr    RColorBrewer          ggraph          igraph 
##            TRUE            TRUE            TRUE            TRUE

1.2 Functions

# tune functions ---------------------------------------------------------------
# Tune and fit an XGBoost regression model (squared-error objective).
#
# The number of boosting rounds is selected by k-fold cross-validation with
# early stopping (xgb.cv); the final model is then refit on all rows of `df`
# with that number of rounds. A failed attempt is retried, at most
# `max.retries` times in total.
#
# Arguments:
#   y            name (character) of the response column in `df`
#   df           data frame holding the response and the predictors
#   nsim         maximum number of boosting rounds tried by xgb.cv
#   eta, max_depth, min_child_weight
#                booster parameters forwarded to xgboost
#   subsample, colsample_bytree
#                accepted for interface compatibility; currently NOT forwarded
#                to xgboost (see the commented-out entries in `params`)
#   nfold        number of cross-validation folds
#   max.retries  maximum number of fitting attempts
#
# Returns the fitted model, with the selected number of rounds stored in
# `$best.trees`. Stops with an error if every attempt fails.
tune.xgb = function(y, df, nsim, eta = 0.1, max_depth = 3, subsample = 0.8, 
                    colsample_bytree = 0.8, min_child_weight = 1,
                    nfold = 5, max.retries = 5) {
  if (!(y %in% names(df))) {
    stop("Response column '", y, "' not found in df.")
  }

  # XGBoost parameters (identical for every attempt)
  params             = list(
    booster          = "gbtree",
    eta              = eta,
    max_depth        = max_depth,
    # subsample        = subsample,
    # colsample_bytree = colsample_bytree,
    min_child_weight = min_child_weight,
    objective        = "reg:squarederror"
  )

  # One complete attempt: build the DMatrix, cross-validate, refit
  fit.once = function() {
    # drop = FALSE keeps a data frame (and its column name) even when only
    # one predictor is left after removing the response
    df.train  = df[, !names(df) %in% y, drop = FALSE]
    df.y      = df[[y]]
    dtrain    = xgb.DMatrix(data = as.matrix(df.train), label = df.y)

    # Cross-validation with early stopping to pick the number of rounds
    xgb.cv.out = xgb.cv(params  = params, 
                        data    = dtrain, 
                        nrounds = nsim, 
                        nfold   = nfold, 
                        early_stopping_rounds = 10, 
                        verbose = 0)
    best.tree  = xgb.cv.out$best_iteration

    # Final model trained with the selected number of rounds
    fit = xgboost(params  = params, 
                  data    = dtrain, 
                  nrounds = best.tree, 
                  verbose = 0)
    fit$best.trees = best.tree
    fit
  }

  # The attempt counter lives in the loop itself. Incrementing it inside the
  # error handler would only change a handler-local copy and the loop would
  # never terminate on a persistent error.
  for (attempt in seq_len(max.retries)) {
    gbmXGB.train = tryCatch(fit.once(), error = function(e) {
      message(paste("Error occurred:", e$message, "- Retrying... (Attempt:", attempt, "of", max.retries, ")"))
      NULL
    })
    if (!is.null(gbmXGB.train)) {
      return(gbmXGB.train)
    }
  }

  stop("XGBoost training failed after ", max.retries, " attempts.")
}

# Cross-validated CART fit (caret wrapper around rpart).
#
# Arguments:
#   y      name (character) of the response column in `df`
#   df     data frame with the response and all predictors
#   nfold  number of cross-validation folds
#
# The complexity parameter `cp` is searched over 10 equally spaced values
# between 0.001 and 0.1, with RMSE as the selection metric.
# Returns the object produced by caret's train().
tune.cart <- function(y, df, nfold = 5) {
  # response regressed on every other column of df
  model.formula <- as.formula(sprintf("%s ~ .", y))

  train(
    model.formula,
    data      = df,
    method    = "rpart",
    trControl = trainControl(method = "cv", number = nfold),
    tuneGrid  = expand.grid(cp = seq(0.001, 0.1, length.out = 10)),
    metric    = "RMSE"
  )
}

# Tune and fit a random forest.
#
# `mtry` is tuned with randomForest::tuneRF (the value with the smallest
# out-of-bag error is kept) and a final forest is then grown with that `mtry`.
#
# Arguments:
#   y           name (character) of the response column in `df`
#   df          data frame with the response and the predictors
#   nsim        number of trees, used both while tuning and in the final fit
#   stepFactor  factor by which mtry is inflated/deflated at each tuning step
#   improve     minimum relative OOB improvement needed to keep searching
#   trace       print the tuning progress?
#
# Returns the final randomForest object (fitted with importance = TRUE).
tune.rf = function(y, df, nsim, stepFactor = 1.5, improve = 0.01, trace = TRUE) {
  # A missing response would make df[[y]] NULL, which randomForest interprets
  # as a request for unsupervised mode; fail early instead.
  if (!(y %in% names(df))) {
    stop("Response column '", y, "' not found in df.")
  }

  # drop = FALSE keeps a data frame even when a single predictor remains
  preds = df[, !(names(df) %in% y), drop = FALSE]
  yval  = df[[y]]

  # Tune mtry; the result has one row per tried mtry:
  # column 1 = mtry, column 2 = OOB error
  best.model = tuneRF(x          = preds, 
                      y          = yval,
                      ntreeTry   = nsim,
                      stepFactor = stepFactor,
                      improve    = improve,
                      trace      = trace,
                      plot       = FALSE)
  best.mtry   = best.model[which.min(best.model[, 2]), 1]

  # final RF model
  final.model = randomForest(x          = preds, 
                             y          = yval,
                             mtry       = best.mtry,
                             ntree      = nsim,
                             importance = TRUE)
  return(final.model)
}

# functions for graphs ----------------------------------------------------------
# Shared plotting constants: base font size and the black-and-white theme
# that the figures in this document build on.
text.size    <- 15
customize.bw <- local({
  # helpers used only while assembling the theme
  plain.text  <- function(...) element_text(size = text.size, ...)
  strong.text <- function(size) element_text(size = size, face = "bold")
  dotted.grid <- element_line(color = "#d9d9d9", linetype = "dotted")

  theme_bw(base_family = "Times") +
    theme(
      # axes and title
      axis.text.x        = plain.text(vjust = 0.5),
      axis.text.y        = plain.text(),
      axis.title         = strong.text(text.size + 2),
      plot.title         = strong.text(text.size + 2),

      # legend below the panel
      legend.position    = "bottom",
      legend.title       = strong.text(text.size),
      legend.text        = plain.text(),

      # thin black frame, dotted minor grid, no major grid
      panel.border       = element_rect(color = "#000000", fill = NA, linewidth = 0.5),
      panel.grid.minor.x = dotted.grid,
      panel.grid.minor.y = dotted.grid,
      panel.grid.major.x = element_blank(),
      panel.grid.major.y = element_blank()
    )
})

# Country outlines (medium-scale Natural Earth data, returned as sf), wrapped
# at the dateline and expressed in WGS84 longitude/latitude.
world.sf <- st_transform(
  st_wrap_dateline(
    ne_countries(scale = "medium", returnclass = "sf"),
    options = c("WRAPDATELINE=YES", "DATELINEOFFSET=180")
  ),
  crs = "+proj=longlat +datum=WGS84"
)

# Horizontal bar chart of variable importance.
#
# Arguments:
#   df.imp     data frame with columns `covariate` (name of the predictor)
#              and `importance` (its score, printed with a "%" suffix)
#   title      plot title
#   imp        TRUE  -> x-axis labelled "Relative importance (%)" (random forest)
#              FALSE -> x-axis labelled "Relative Influence (%)"  (gbm)
#   text.size  base font size for axis text
#
# Returns a ggplot object. Stops if a required column is missing.
varImp.plot = function(df.imp, title, imp = TRUE, text.size = 14){
  # Fail with a clear message instead of an obscure error deeper down
  missing.cols = setdiff(c("covariate", "importance"), names(df.imp))
  if (length(missing.cols) > 0) {
    stop("df.imp is missing required column(s): ",
         paste(missing.cols, collapse = ", "), call. = FALSE)
  }

  df.imp      = df.imp[order(-df.imp$importance), ]
  if(imp){ # for random forest
    imp.label = "Relative importance (%)"
  }else{ # gbm
    imp.label = "Relative Influence (%)"
  }
  
  # text written inside each bar, e.g. "PC1 = 12.34%"
  df.imp$label = paste0(df.imp$covariate, " = ", round(df.imp$importance, 2), "%")
  
  ggplot(df.imp, aes(x = reorder(covariate, importance), y = importance)) +
    geom_bar(stat = "identity", color = "#40004b", fill = "#c2a5cf") +
    # labels start at 5% of the longest bar so they sit inside the bars
    geom_text(aes(label = label, y = 0.05 * max(importance)),
              hjust = 0, color ="#40004b", size = 6, fontface = "bold") +  
    
    coord_flip() +  # Flip coordinates
    labs(title = title, x = NULL, y = imp.label) +
    theme_bw(base_family = "Times") +
    theme(axis.text.x  = element_text(size = text.size, vjust = 0.5),
          # covariate names are already printed inside the bars
          axis.text.y  = element_blank(),
          axis.ticks.y = element_blank(),
          axis.title         = element_text(size = text.size + 1, face = 'bold'),
          plot.title       = element_text(size = text.size + 1, face = 'bold', hjust = .5),
          panel.border       = element_rect(color = "#000000", fill = NA, linewidth = .5),
          panel.grid.minor.x = element_line(color = "#d9d9d9", linetype = "dotted"),
          panel.grid.minor.y = element_line(color = "#d9d9d9", linetype = "dotted"),
          panel.grid.major.x = element_blank(),
          panel.grid.major.y = element_blank() )
}

# tree function for rf
# adapted from \url{https://shiring.github.io/machine_learning/2017/03/16/rf_plot_ggraph}
#
# Draw one tree of a fitted random forest as a dendrogram.
#
# Arguments:
#   final_model  fitted randomForest object
#   tree_num     index of the tree to draw
#
# Internal nodes show the splitting variable and the split point; terminal
# nodes show the prediction. Returns a ggraph/ggplot object.
treeRF.plot = function(final_model, tree_num) {
  
  # get tree by index
  tree         = randomForest::getTree(final_model, k = tree_num, labelVar = TRUE)
  tree$rowname = rownames(tree)

  # With labelVar = TRUE a terminal node has no splitting variable (NA).
  # The `prediction` column cannot be used to find leaves: in a regression
  # forest it is numeric for every node and is never NA.
  is.leaf = is.na(tree$`split var`)

  # node labels, one entry per row of `tree`
  node.label  = gsub("_", " ", as.character(tree$`split var`))
  split.label = as.character(round(tree$`split point`, digits = 2))
  split.label[is.leaf] = NA      # leaves have no split point to show
  leaf.label  = if (is.numeric(tree$prediction)) {
    as.character(round(tree$prediction, 2))
  } else {
    as.character(tree$prediction)  # class labels cannot be rounded
  }
  leaf.label[!is.leaf] = NA      # only terminal nodes get a prediction label

  # prepare data frame for graph
  graph.frame = data.frame(from = rep(tree$rowname, 2),
                           to   = c(tree$`left daughter`, tree$`right daughter`))
  # convert to graph and delete the node "0" (the daughter id of every leaf)
  graph = graph_from_data_frame(graph.frame) %>%
    delete_vertices("0")
  
  # attach labels by vertex name rather than by assumed vertex order
  idx = match(V(graph)$name, tree$rowname)
  V(graph)$node_label = node.label[idx]
  V(graph)$leaf_label = leaf.label[idx]
  V(graph)$split      = split.label[idx]
  
  # plot
  ggraph(graph, 'dendrogram') + 
    geom_edge_link() +
    geom_node_point() +
    geom_node_text(aes(label = node_label), na.rm = TRUE, repel = TRUE) +
    geom_node_label(aes(label = split), vjust = 2.5, na.rm = TRUE, fill = "white") +
    geom_node_label(aes(label = leaf_label, fill = leaf_label), na.rm = TRUE, 
                    repel = TRUE, colour = "white", fontface = "bold", show.legend = FALSE) +
    customize.bw + 
    theme(panel.grid.minor = element_blank(),
          panel.grid.major = element_blank(),
          panel.background = element_blank(),
          plot.background = element_rect(fill = "white"),
          panel.border = element_blank(),
          axis.line = element_blank(),
          axis.text.x = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks = element_blank(),
          axis.title.x = element_blank(),
          axis.title.y = element_blank(),
          plot.title = element_text(size = 18))
}

# Map of one skill metric per grid cell, one facet per model.
#
# Arguments:
#   data   long data frame with columns lon, lat, model, metric, value.winsor
#   skill  metric to draw (matched against the `metric` column); also used
#          in the plot title and as the colour-bar title
#
# Returns a ggplot object.
plot.spatialMetric <- function(data, skill) {
  # rows belonging to the requested metric
  shown <- subset(data, metric %in% skill)

  # diverging palette, reversed RdBu interpolated to 200 colours
  fill.palette <- colorRampPalette(rev(brewer.pal(9, "RdBu")))(200)

  # facet-strip and axis-title tweaks applied on top of the shared theme
  strip.theme <- theme(
    axis.title.x     = element_blank(),
    plot.subtitle    = element_text(size = text.size-2, face = "italic"),
    strip.background = element_blank(),
    strip.text       = element_text(size = text.size, face = "bold")
  )

  # data layers: metric tiles, green study-region box, country outlines
  map.layers <- ggplot(shown) +
    geom_tile(aes(x = lon, y = lat, fill = value.winsor)) +
    geom_rect(aes(xmin = 66, xmax = 100, ymin = 6, ymax = 39),
              inherit.aes = FALSE, color = "#41ab5d", fill = NA, linewidth = 1) +
    geom_sf(data = world.sf, fill = NA, color = "black",
            linewidth = 0.2, inherit.aes = FALSE)

  # labels, facets, coordinate system, scales and theme
  map.layers +
    labs(title = paste0(skill, " across grid cells"),
         x = "Longitude", y = "Latitude") +
    facet_wrap(~ model, ncol = 2) +
    coord_sf(xlim = c(65, 101), ylim = c(5, 40), expand = FALSE, crs = "EPSG:4326") +
    scale_fill_gradientn(skill,
                         colors = fill.palette,
                         guide  = guide_colorbar(barwidth = 15, barheight = 2)) +
    scale_y_continuous(breaks = seq(10, 40 , 10)) +
    scale_x_continuous(breaks = seq(60, 100, 10)) +
    customize.bw +
    strip.theme
}

2 Load data and pre-process

2.1 Rajeevan Gridded rainfall

Rainfall data for JJAS from 1950 to 2016 is read and reshaped into a matrix where each row represents one season (year) and each column a spatial location. Only valid (non-missing) data are retained.

path.data = "https://civil.colorado.edu/~balajir/CVEN6833/HWs/HW-2/Spring-2025-data-commands/"
# Global grid
# deg grid
# (nrows/ncols are not used below; nx and ny are derived from the coordinate
#  vectors instead)
nrows = 68
ncols = 65
ntime = 67    # Jun-Sep 1950 - 2016

# Lat - Long grid
ygrid = seq(6.5 , 38.5, by=0.5)
xgrid = seq(66.5, 100 , by=0.5)
ny    = length(ygrid)
nx    = length(xgrid)
# Coordinates of every grid cell, one row per cell with longitude varying
# fastest: column 1 = latitude, column 2 = longitude
xygrid = cbind(rep(ygrid, each = nx), rep(xgrid, times = ny))

# read India location valid points
rain.locs = read.table(paste0(path.data, "India-rain-locs.txt"))
# read data binary (4-byte reals)
# The byte order is stated explicitly: endian = "swap" means "opposite of this
# machine" and therefore only decodes the file correctly on little-endian hosts.
# NOTE(review): assumes the .r4 file is big-endian, which is what "swap"
# resolves to on a little-endian machine - confirm against the data provider.
n.expected = nx * ny * ntime
data.raw = readBin(paste0(path.data, "India-Rain-JJAS-05deg-1950-2016.r4"),
                       what   = "numeric",
                       n      = n.expected,
                       size   = 4,
                       endian = "big")
# readBin returns fewer values without complaint on a short/failed download,
# and array()/matrix() would then silently recycle them
if (length(data.raw) != n.expected) {
  stop("Rainfall file: expected ", n.expected, " values but read ", length(data.raw), ".")
}

# reshape to a 3-D array with one slice per season
# NOTE(review): rain.indices is used as a LINEAR index into each slice and into
# xygrid (longitude fastest), so only the element order matters here; confirm
# the (ny, nx) labelling before indexing data.array by [row, col].
data.array   = array(data.raw, dim = c(ny, nx, ntime))
rain.indices = rain.locs[, 3]
nsites       = length(rain.indices)
if (any(is.na(rain.indices)) || any(rain.indices < 1 | rain.indices > nx * ny)) {
  stop("India-rain-locs.txt contains grid indices outside 1..", nx * ny, ".")
}

# construct rain.avg (ntime rows x valid rain sites columns):
# one column per season in the flat matrix, keep the valid cells, transpose
rain.avg = t(matrix(data.raw, nrow = nx * ny, ncol = ntime)[rain.indices, , drop = FALSE])

# spatial mean over the valid sites, one value per season
rain.mean = rowMeans(rain.avg)

2.2 Sea Surface Temperature fields

SST data for the global tropics are loaded and reshaped. Only valid SST points are retained based on the provided locations.

path.data = "https://civil.colorado.edu/~balajir/CVEN6833/HWs/HW-2/Spring-2025-data-commands/"
# grid SST (NOAA tropical), 2-degree spacing
ntime     = 67    # Jun-Sep 1950 - 2016
ygrid.sst = seq(-16, 16 , by = 2)
xgrid.sst = seq(0  , 358, by = 2)
ny.sst    = length(ygrid.sst)
nx.sst    = length(xgrid.sst)

# read SST valid points
sst.locs = read.table(paste0(path.data, "NOAA-trop-sst-locs.txt"))
# read SST binary (4-byte reals)
# The byte order is stated explicitly: endian = "swap" means "opposite of this
# machine" and therefore only decodes the file correctly on little-endian hosts.
# NOTE(review): assumes the .r4 file is big-endian, which is what "swap"
# resolves to on a little-endian machine - confirm against the data provider.
n.expected.sst = nx.sst * ny.sst * ntime
data.raw = readBin(paste0(path.data, "NOAA-Trop-JJAS-SST-1950-2016.r4"),
                       what   = "numeric",
                       n      = n.expected.sst,
                       size   = 4,
                       endian = "big")
# readBin returns fewer values without complaint on a short/failed download,
# and array()/matrix() would then silently recycle them
if (length(data.raw) != n.expected.sst) {
  stop("SST file: expected ", n.expected.sst, " values but read ", length(data.raw), ".")
}

# reshape to a 3-D array with one slice per season
# NOTE(review): sst.indices is used as a LINEAR index into each slice, so only
# the element order matters here; confirm the (ny, nx) labelling before
# indexing sst.array by [row, col].
sst.array   = array(data.raw, dim = c(ny.sst, nx.sst, ntime))
sst.indices = sst.locs[, 3]
nsites      = length(sst.indices)
if (any(is.na(sst.indices)) || any(sst.indices < 1 | sst.indices > nx.sst * ny.sst)) {
  stop("NOAA-trop-sst-locs.txt contains grid indices outside 1..", nx.sst * ny.sst, ".")
}

# construct sst.avg (ntime rows x valid SST sites columns):
# one column per season in the flat matrix, keep the valid cells, transpose
sst.avg = t(matrix(data.raw, nrow = nx.sst * ny.sst, ncol = ntime)[sst.indices, , drop = FALSE])

3 Multivariate forecasting

3.1 PCA on rainfall data and SST

# PCA of the rainfall field on centred, unit-variance columns
rain.pca   <- prcomp(x = rain.avg, center = TRUE, scale. = TRUE)
# scores of the 4 leading rainfall PCs (one row per season)
rain.pcs   <- rain.pca[["x"]][, 1:4]
# matching eigenvectors (one row per site), needed to map PCs back to the grid
rain.evecs <- rain.pca[["rotation"]][, 1:4]

# PCA of the SST field, same standardisation; keep the 5 leading PC scores
sst.pca    <- prcomp(x = sst.avg, center = TRUE, scale. = TRUE)
sst.pcs    <- sst.pca[["x"]][, 1:5]

3.2 Combine predictors and target

# Seed the RNG BEFORE drawing the random train/test partition. Seeding after
# createDataPartition() leaves the split (and every result that depends on it)
# different on each run.
set.seed(250405)

# One row per season: year, spatial-mean rainfall and the SST PC scores
data       = data.frame(year = 1950:2016, rainfall = rain.mean, sst.pcs)
# 60 % of the seasons for training, stratified on mean rainfall
train.id   = createDataPartition(data$rainfall, p = 0.6, list = FALSE)
train.data = data[train.id, ]
test.data  = data[-train.id, ]

3.3 Modeling: Predict each rain PC from SST PCs

# Predictors: the 5 leading SST PCs as a data frame
sst.pcs.df <- as.data.frame(sst.pcs[, 1:5])

# Containers: one fitted model per rainfall PC and method, and one column of
# predictions (all seasons) per rainfall PC
models <- list(CART = list(), RF = list())
preds  <- list(
  CART = matrix(NA, nrow = nrow(sst.pcs), ncol = 4),
  RF   = matrix(NA, nrow = nrow(sst.pcs), ncol = 4)
)

for (i in seq_len(4)) {
  # response = i-th rainfall PC; models are fitted on the training seasons only
  df        <- cbind(PC = rain.pcs[, i], sst.pcs.df)
  train.pcs <- df[train.id, ]

  # CART: cross-validated cp, then predict every season
  cart               <- tune.cart("PC", df = train.pcs)
  models[["CART"]][[i]] <- cart[["finalModel"]]
  preds[["CART"]][, i]  <- predict(cart[["finalModel"]], sst.pcs.df)

  # Random forest: tuned mtry, 500 trees, then predict every season
  rf                  <- tune.rf("PC", df = train.pcs, nsim = 500)
  models[["RF"]][[i]] <- rf
  preds[["RF"]][, i]  <- predict(rf, sst.pcs.df)
}
## mtry = 1  OOB error = 149.7205 
## Searching left ...
## Searching right ...
## mtry = 1  OOB error = 63.94505 
## Searching left ...
## Searching right ...
## mtry = 1  OOB error = 62.07333 
## Searching left ...
## Searching right ...
## mtry = 1  OOB error = 41.59057 
## Searching left ...
## Searching right ...

3.4 Reconstruct model rainfall fields

# Reconstruct model rainfall fields
dim(rain.evecs)  # nsites x 4
## [1] 1250    4
dim(preds$CART)  # ntime x 4
## [1] 67  4

# For every model in `preds`:
#   1. project the predicted PC scores back onto the grid (scores %*% t(eigenvectors)),
#      which gives the field in standardised units;
#   2. undo the standardisation applied by prcomp(center = TRUE, scale. = TRUE):
#      multiply each site (column) by its scale, then add its centre.
# The result keeps the names of `preds` (CART, RF).
rain.recon = lapply(preds, function(pc.hat) {
  field.std = pc.hat %*% t(rain.evecs)
  field     = sweep(field.std, 2, rain.pca$scale, "*")
  sweep(field, 2, rain.pca$center, "+")
})

3.5 Evaluation per location (grid cell)

# Column-wise skill of a predicted field against the observed field.
#
# Arguments:
#   predicted  matrix (or data frame) of predictions, one column per location
#   true       observations with the same dimensions as `predicted`
#
# Returns a data frame with one row per column of the inputs and the columns
#   Corr  correlation (stats::cor) between predicted and observed values
#   RMSE  root-mean-square error as computed by rmse(predicted, true)
# Zero-column input yields a zero-row data frame. Stops if the two inputs do
# not have identical dimensions.
eval.metrics = function(predicted, true) {
  if (length(dim(predicted)) != 2L || !identical(dim(predicted), dim(true))) {
    stop("'predicted' and 'true' must be two-dimensional with identical dimensions.",
         call. = FALSE)
  }

  # vapply (unlike sapply) always returns a 2 x ncol matrix, also for 0 columns
  vals = vapply(seq_len(ncol(predicted)), function(i) {
    cor.val  = cor(predicted[, i], true[, i])
    rmse.val = rmse(predicted[, i], true[, i])
    as.numeric(c(cor.val, rmse.val))
  }, FUN.VALUE = numeric(2))

  out        = as.data.frame(t(vals))
  names(out) = c("Corr", "RMSE")
  out
}

# Hold-out seasons (every row not used for training): observed field and the
# two reconstructed fields
true.rain.test <- rain.avg[-train.id, ]
rain.cart.test <- rain.recon[["CART"]][-train.id, ]
rain.rf.test   <- rain.recon[["RF"]][-train.id, ]

# Per-location skill (Corr, RMSE) of each model on the hold-out seasons
metrics <- list(
  CART = eval.metrics(rain.cart.test, true.rain.test),
  RF   = eval.metrics(rain.rf.test, true.rain.test)
)

3.6 Boxplots of metrics

# Tag each metrics table with its model and stack them
metrics$CART$model <- "CART"
metrics$RF$model   <- "RF"
metrics.df         <- rbind(metrics$CART, metrics$RF)

# Long format (one row per location x model x metric), plus a copy of the
# value winsorised at the 10th/90th percentile within each model-metric group
metrics.long <- metrics.df %>%
  tidyr::pivot_longer(cols      = c("Corr", "RMSE"),
                      names_to  = "metric",
                      values_to = "value") %>%
  dplyr::group_by(model, metric) %>%
  dplyr::mutate(
    value.winsor = Winsorize(
      value,
      val = quantile(value, probs = c(0.1, 0.9), na.rm = TRUE)
    )
  ) %>%
  dplyr::ungroup()

# Boxplots of the winsorised metrics, one panel per metric (free y-axis),
# one box per model
ggplot(data = metrics.long,
       mapping = aes(x = model, y = value.winsor, fill = model)) +
  geom_boxplot(show.legend = FALSE, alpha = 0.8) +
  facet_wrap(~ metric, scales = "free_y") +
  labs(
    title    = "Evaluation metrics",
    subtitle = "Summer season rainfall over India",
    y        = "Value"
  ) +
  # shared theme, then panel-strip tweaks
  customize.bw +
  theme(
    axis.title.x     = element_blank(),
    plot.subtitle    = element_text(size = text.size-2, face = "italic"),
    strip.background = element_blank(),
    strip.text       = element_text(size = text.size, face = "bold")
  )

# Mean, median and standard deviation of the raw (non-winsorised) values for
# every model-metric pair, rounded to two decimals
metrics.long %>%
  dplyr::group_by(model, metric) %>%
  dplyr::summarise(
    mean    = round(mean(value), 2),
    median  = round(median(value), 2),
    sd      = round(sd(value), 2),
    .groups = "drop"
  )
  • Random Forest outperforms CART across both metrics, especially in correlation (double the median correlation, indicating better rainfall pattern capture).

3.7 Spatial heatmaps

# Coordinates (lat, lon) of the valid rainfall cells
xy.coords           <- xygrid[rain.indices, ]
colnames(xy.coords) <- c("lat", "lon")

# One table per model: model tag + coordinates + per-location metrics
metrics.spatial <- rbind(
  cbind(model = "CART", xy.coords, metrics$CART[c("Corr", "RMSE")]),
  cbind(model = "RF",   xy.coords, metrics$RF[c("Corr", "RMSE")])
)

# Long format; longitudes above 180 are shifted into [-180, 180]; values are
# winsorised at the 10th/90th percentile within each model-metric group
spatial.long <- metrics.spatial %>%
  tidyr::pivot_longer(cols      = c("Corr", "RMSE"),
                      names_to  = "metric",
                      values_to = "value") %>%
  mutate(lon = ifelse(lon > 180, lon - 360, lon)) %>%
  dplyr::group_by(model, metric) %>%
  dplyr::mutate(
    value.winsor = Winsorize(
      value,
      val = quantile(value, probs = c(0.1, 0.9), na.rm = TRUE)
    )
  ) %>%
  dplyr::ungroup()

# Correlation map
plot.spatialMetric(data = spatial.long, skill = "Corr")

# RMSE map
plot.spatialMetric(data = spatial.long, skill = "RMSE")

Correlation Heatmap:

  • CART: sparse pockets with Corr > 0.2, most values hover near 0.

  • RF: several regions with Corr > 0.3-0.4, showing significant local pattern learning.

RMSE Heatmap:

  • CART: widespread high RMSE (>300 mm), poor model consistency across locations.

  • RF: lower RMSE over central and western India, indicating more reliable spatial prediction.

4 Final remarks