Catalina Jerez
catalina.jerez@colorado.edu
Fit a CART model to the leading 4 PCs of summer-season rainfall over India, using the 5 leading PCs of summer global tropical SSTs as covariates.
Fit a separate model for each rainfall PC.
Estimate the modeled precipitation at all locations by multiplying the PC estimates by the corresponding eigenvectors.
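Concretely, the reconstruction step multiplies the predicted PC scores by the transposed eigenvector matrix and then inverts the standardization. A toy sketch of this algebra with placeholder names (none of these objects are reused; the workspace is cleared before the analysis):
Z.hat = matrix(rnorm(67 * 4), 67, 4) # predicted PC scores (years x PCs)
E = matrix(rnorm(1250 * 4), 1250, 4) # eigenvectors (sites x PCs)
s = runif(1250, 1, 2); mu = runif(1250, 500, 900) # per-site scale and center from the PCA
rain.hat = sweep(sweep(Z.hat %*% t(E), 2, s, "*"), 2, mu, "+") # years x sites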
gc()
## used (Mb) gc trigger (Mb) limit (Mb) max used (Mb)
## Ncells 543260 29.1 1202030 64.2 NA 700282 37.4
## Vcells 1005469 7.7 8388608 64.0 36864 1963597 15.0
rm(list = ls())
# Load necessary libraries
spatialAnalysis.Lib = c("maps", "rnaturalearth", "sf")
statAnalysis.Lib = c("caret", "DescTools")
dataManipulation.Lib = c("dplyr", "reshape2", "tidyr")
dataVisualization.Lib = c("ggplot2", "ggrepel", "ggpubr", "RColorBrewer",
"ggraph", "igraph")
modeling.Lib = c("hydroGOF", "xgboost", "rpart", "randomForest", "rpart.plot",
"randomForestSRC", "data.tree", "DiagrammeR",
"mclust", "tree")
list.packages = unique(c(spatialAnalysis.Lib, statAnalysis.Lib,
modeling.Lib,
dataManipulation.Lib, dataVisualization.Lib))
# Load all required packages
sapply(list.packages, require, character.only = TRUE)
## maps rnaturalearth sf caret DescTools
## TRUE TRUE TRUE TRUE TRUE
## hydroGOF xgboost rpart randomForest rpart.plot
## TRUE TRUE TRUE TRUE TRUE
## randomForestSRC data.tree DiagrammeR mclust tree
## TRUE TRUE TRUE TRUE TRUE
## dplyr reshape2 tidyr ggplot2 ggrepel
## TRUE TRUE TRUE TRUE TRUE
## ggpubr RColorBrewer ggraph igraph
## TRUE TRUE TRUE TRUE
# tune functions ---------------------------------------------------------------
# tune xgboost
tune.xgb = function(y, df, nsim, eta = 0.1, max_depth = 3, subsample = 0.8,
colsample_bytree = 0.8, min_child_weight = 1,
nfold = 5, max.retries = 5) {
success = FALSE
attempt = 0
while (!success && attempt < max.retries) {
tryCatch({
# Prepare training matrix (covariates only; no bootstrap sampling is done here)
df.train = df[, !names(df) %in% y] # excluding the target variable
df.y = df[[y]]
dtrain = xgb.DMatrix(data = as.matrix(df.train), label = df.y)
# XGBoost Parameters
params = list(
booster = "gbtree",
eta = eta,
max_depth = max_depth,
# subsample = subsample,
# colsample_bytree = colsample_bytree,
min_child_weight = min_child_weight,
objective = "reg:squarederror"
)
# Train with Cross-Validation
xgb.cv.out = xgb.cv(params = params,
data = dtrain,
nrounds = nsim,
nfold = nfold,
early_stopping_rounds = 10,
verbose = 0)
# Get Best Number of Trees
best.tree = xgb.cv.out$best_iteration
# Final Model Training with Optimal Trees
gbmXGB.train = xgboost(params = params,
data = dtrain,
nrounds = best.tree,
verbose = 0)
# Store Best Tree Count
gbmXGB.train$best.trees = best.tree
success = TRUE
}, error = function(e) {
# use <<- so the retry counter updates in tune.xgb's scope; a plain `=`
# would only create a local copy inside the handler and loop forever
attempt <<- attempt + 1
message(paste("Error occurred:", e$message, "- Retrying... (Attempt:", attempt, "of", max.retries, ")"))
})
}
if (!success) {
stop("XGBoost training failed after ", max.retries, " attempts.")
}
return(gbmXGB.train)
}
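tune.xgb is defined for completeness but is not exercised in this section. A minimal usage sketch on synthetic data (df.demo and xgb.demo are illustrative names, not used elsewhere):
# illustrative demo: numeric target y with two noise covariates
df.demo = data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100))
xgb.demo = tune.xgb("y", df.demo, nsim = 200, nfold = 5)
# xgb.demo$best.trees holds the CV-selected number of boosting rounds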
# Tune CART (rpart + caret)
tune.cart = function(y, df, nfold = 5) {
ctrl.cv = trainControl(method = "cv", number = nfold)
cart.grid = expand.grid(cp = seq(0.001, 0.1, length.out = 10))
model = train(as.formula(paste(y, "~ .")),
data = df,
method = "rpart",
trControl = ctrl.cv,
tuneGrid = cart.grid,
metric = "RMSE")
return(model)
}
# Tune Random Forest
tune.rf = function(y, df, nsim, stepFactor = 1.5, improve = 0.01, trace = TRUE) {
# y = "rainfall"
# df = train.data[,-1]
# nsim = ntree
preds = df[, !(names(df) %in% y)]
yval = df[[y]]
# Tune mtry
best.model = tuneRF(x = preds,
y = yval,
ntreeTry = nsim,
stepFactor = stepFactor,
improve = improve,
trace = trace,
plot = FALSE)
best.mtry = best.model[which.min(best.model[, 2]), 1]
# final RF model
final.model = randomForest(x = preds,
y = yval,
mtry = best.mtry,
ntree = nsim,
importance = TRUE)
return(final.model)
}
# functions for graphs ----------------------------------------------------------
# constants
text.size = 15
customize.bw = theme_bw(base_family = "Times") +
theme(
axis.text.x = element_text(size = text.size, vjust = 0.5),
axis.text.y = element_text(size = text.size),
axis.title = element_text(size = text.size + 2, face = "bold"),
plot.title = element_text(size = text.size + 2, face = "bold"),
legend.position = "bottom",
legend.title = element_text(size = text.size, face = "bold"),
legend.text = element_text(size = text.size),
panel.border = element_rect(color = "#000000", fill = NA, linewidth = 0.5),
panel.grid.minor.x = element_line(color = "#d9d9d9", linetype = "dotted"),
panel.grid.minor.y = element_line(color = "#d9d9d9", linetype = "dotted"),
panel.grid.major.x = element_blank(),
panel.grid.major.y = element_blank()
)
# world map
world.sf = ne_countries(scale = "medium", returnclass = "sf") %>%
st_wrap_dateline(options = c("WRAPDATELINE=YES", "DATELINEOFFSET=180")) %>%
st_transform("+proj=longlat +datum=WGS84")
# variable importance function
varImp.plot = function(df.imp, title, imp = T, text.size = 14){
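# df.imp: data frame with columns `covariate` and `importance` (in percent)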
df.imp = df.imp[order(-df.imp$importance), ]
if(imp){ # for random forest
imp.label = paste0("Relative importance (%)")
}else{ # gbm
imp.label = paste0("Relative Influence (%)")
}
df.imp$label = paste0(df.imp$covariate, " = ", round(df.imp$importance, 2), "%")
ggplot(df.imp, aes(x = reorder(covariate, importance), y = importance)) +
geom_bar(stat = "identity", color = "#40004b", fill = "#c2a5cf") +
geom_text(aes(label = label, y = 0.05 * max(importance)),
hjust = 0, color ="#40004b", size = 6, fontface = "bold") +
coord_flip() + # Flip coordinates
labs(title = title, x = NULL, y = imp.label) +
theme_bw(base_family = "Times") +
theme(axis.text.x = element_text(size = text.size, vjust = 0.5),
axis.text.y = element_blank(),
axis.ticks.y = element_blank(),
axis.title = element_text(size = text.size + 1, face = 'bold'),
plot.title = element_text(size = text.size + 1, face = 'bold', hjust = .5),
panel.border = element_rect(color = "#000000", fill = NA, linewidth = .5),
panel.grid.minor.x = element_line(color = "#d9d9d9", linetype = "dotted"),
panel.grid.minor.y = element_line(color = "#d9d9d9", linetype = "dotted"),
panel.grid.major.x = element_blank(),
panel.grid.major.y = element_blank() )
}
# tree function for rf
# adapted from https://shiring.github.io/machine_learning/2017/03/16/rf_plot_ggraph
treeRF.plot = function(final_model, tree_num) {
# get tree by index
tree = randomForest::getTree(final_model, k = tree_num, labelVar = TRUE) %>%
tibble::rownames_to_column() %>%
# make leaf split points to NA, so the 0s won't get plotted
mutate(`split point` = ifelse(is.na(prediction), `split point`, NA))
# prepare data frame for graph
graph.frame = data.frame(from = rep(tree$rowname, 2),
to = c(tree$`left daughter`, tree$`right daughter`))
# convert to graph and delete the last node that we don't want to plot
graph = graph_from_data_frame(graph.frame) %>%
delete_vertices("0")
# set node labels
V(graph)$node_label = gsub("_", " ", as.character(tree$`split var`))
V(graph)$leaf_label = as.character(round(tree$prediction, 2))
V(graph)$split = as.character(round(tree$`split point`, digits = 2))
# plot
ggraph(graph, 'dendrogram') +
geom_edge_link() +
geom_node_point() +
geom_node_text(aes(label = node_label), na.rm = TRUE, repel = TRUE) +
geom_node_label(aes(label = split), vjust = 2.5, na.rm = TRUE, fill = "white") +
geom_node_label(aes(label = leaf_label, fill = leaf_label), na.rm = TRUE,
repel = TRUE, colour = "white", fontface = "bold", show.legend = FALSE) +
customize.bw +
theme(panel.grid.minor = element_blank(),
panel.grid.major = element_blank(),
panel.background = element_blank(),
plot.background = element_rect(fill = "white"),
panel.border = element_blank(),
axis.line = element_blank(),
axis.text.x = element_blank(),
axis.text.y = element_blank(),
axis.ticks = element_blank(),
axis.title.x = element_blank(),
axis.title.y = element_blank(),
plot.title = element_text(size = 18))
}
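treeRF.plot is not invoked below; once the RF models are fitted, a single tree could be drawn as, for example:
# treeRF.plot(final_model = models$RF[[1]], tree_num = 1)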
# plot spatial metrics
plot.spatialMetric = function(data, skill){
# data = data %>%
# dplyr::filter(metric == skill)
data = subset(data, metric %in% skill)
ggplot(data) +
geom_tile(aes(x = lon, y = lat, fill = value.winsor)) +
# regions
geom_rect(aes(xmin = 66, xmax = 100, ymin = 6, ymax = 39),
inherit.aes = FALSE, color = "#41ab5d", fill = NA, linewidth = 1) +
geom_sf(data = world.sf, fill = NA, color = "black",
linewidth = 0.2, inherit.aes = FALSE) +
# labs
labs(title = paste0(skill, " across grid cells"),
x = "Longitude", y = "Latitude") +
# scales
facet_wrap(~ model, ncol = 2) +
coord_sf(xlim = c(65, 101), ylim = c(5, 40), expand = FALSE, crs = "EPSG:4326") +
scale_fill_gradientn(skill,
colors = colorRampPalette(rev(brewer.pal(9, "RdBu")))(200),
# limits = fill_limits,
guide = guide_colorbar(barwidth = 15, barheight = 2) ) +
scale_y_continuous(breaks = seq(10, 40 , 10))+
scale_x_continuous(breaks = seq(60, 100, 10))+
# theme
customize.bw +
theme(axis.title.x = element_blank(),
plot.subtitle = element_text(size = text.size-2, face = "italic"),
strip.background = element_blank(),
strip.text = element_text(size = text.size, face = "bold"))
}
Rainfall data for JJAS (June-September) 1950-2016 are read and reshaped into a matrix with one row per season (year) and one column per spatial location. Only valid (non-missing) grid cells are retained.
path.data = "https://civil.colorado.edu/~balajir/CVEN6833/HWs/HW-2/Spring-2025-data-commands/"
# India grid (0.5 degree)
ntime = 67 # Jun-Sep seasons, 1950-2016
# Lat - Long grid
ygrid = seq(6.5 , 38.5, by=0.5)
xgrid = seq(66.5, 100 , by=0.5)
ny = length(ygrid)
nx = length(xgrid)
xygrid= matrix(0, nrow = nx*ny, ncol=2)
i = 0
for(iy in 1:ny){
for(ix in 1:nx){
i = i+1
xygrid[i,1] = ygrid[iy]
xygrid[i,2] = xgrid[ix]
}
}
# read India location valid points
rain.locs = read.table(paste0(path.data, "India-rain-locs.txt"))
# read data binary
data.raw = readBin(paste0(path.data, "India-Rain-JJAS-05deg-1950-2016.r4"),
what = "numeric",
n = (nx * ny * ntime),
size = 4,
endian = "swap")
# reshape to (ny, nx, ntime)
data.array = array(data.raw, dim = c(ny, nx, ntime))
rain.indices = rain.locs[, 3]
nsites = length(rain.indices)
# construct rain.avg (ntime rows × valid rain sites columns)
rain.avg = matrix(NA, nrow = ntime, ncol = nsites)
for (t in 1:ntime) {
rain.slice = data.array[, , t]
rain.avg[t, ] = rain.slice[rain.indices]
}
rain.mean = rowMeans(rain.avg)
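As a side note (not in the original script), the per-year extraction loop above can be written as one vectorized call with the same result:
# equivalent vectorized extraction: apply over the time dimension, then transpose
# rain.avg = t(apply(data.array, 3, function(m) m[rain.indices]))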
Global tropical SST data are loaded and reshaped. Only valid ocean points are retained, based on the provided locations.
path.data = "https://civil.colorado.edu/~balajir/CVEN6833/HWs/HW-2/Spring-2025-data-commands/"
# grid SST (NOAA tropical)
ntime = 67 # Jun-Sep 1950 - 2016
ygrid.sst = seq(-16, 16 , by = 2)
xgrid.sst = seq(0 , 358, by = 2)
ny.sst = length(ygrid.sst)
nx.sst = length(xgrid.sst)
# read SST valid points
sst.locs = read.table(paste0(path.data, "NOAA-trop-sst-locs.txt"))
# read SST binary
data.raw = readBin(paste0(path.data, "NOAA-Trop-JJAS-SST-1950-2016.r4"),
what = "numeric",
n = (nx.sst * ny.sst * ntime),
size = 4,
endian = "swap")
# reshape to (ny, nx, ntime)
sst.array = array(data.raw, dim = c(ny.sst, nx.sst, ntime))
sst.indices = sst.locs[, 3]
nsites = length(sst.indices)
# construct sst.avg (ntime rows × valid SST sites columns)
sst.avg = matrix(NA, nrow = ntime, ncol = nsites)
for (t in 1:ntime) {
sst.slice = sst.array[, , t]
sst.avg[t, ] = sst.slice[sst.indices]
}
rain.pca = prcomp(rain.avg, scale. = TRUE, center = TRUE)
rain.pcs = rain.pca$x[, 1:4] # use first 4 PCs
rain.evecs = rain.pca$rotation[, 1:4] # eigenvectors for reconstruction
sst.pca = prcomp(sst.avg, scale. = TRUE, center = TRUE)
sst.pcs = sst.pca$x[, 1:5]
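As a sanity check (not part of the original script), the cumulative variance explained by the retained PCs can be stored for inspection:
# cumulative fraction of variance captured by the retained PCs
cumvar.rain = summary(rain.pca)$importance["Cumulative Proportion", 1:4]
cumvar.sst = summary(sst.pca)$importance["Cumulative Proportion", 1:5]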
data = data.frame(year = 1950:2016, rainfall = rain.mean, sst.pcs)
set.seed(250405) # set the seed before partitioning so the split is reproducible
train.id = createDataPartition(data$rainfall, p = 0.6, list = FALSE)
train.data = data[train.id, ]
test.data = data[-train.id, ]
sst.pcs.df = as.data.frame(sst.pcs[, 1:5]) # use 5 SST PCs
models = list(CART = list(), RF = list())
preds = list(CART = matrix(NA, nrow = nrow(sst.pcs), ncol = 4),
RF = matrix(NA, nrow = nrow(sst.pcs), ncol = 4))
for (i in 1:4) {
df = cbind(PC = rain.pcs[, i], sst.pcs.df)
train.pcs = df[train.id, ]
# CART
cart = tune.cart("PC", df = train.pcs)
models$CART[[i]] = cart$finalModel
preds$CART[, i] = predict(cart$finalModel, sst.pcs.df)
# RF
rf = tune.rf("PC", df = train.pcs, nsim = 500)
models$RF[[i]] = rf
preds$RF[, i] = predict(rf, sst.pcs.df)
}
## mtry = 1 OOB error = 149.7205
## Searching left ...
## Searching right ...
## mtry = 1 OOB error = 63.94505
## Searching left ...
## Searching right ...
## mtry = 1 OOB error = 62.07333
## Searching left ...
## Searching right ...
## mtry = 1 OOB error = 41.59057
## Searching left ...
## Searching right ...
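The varImp.plot helper defined earlier can now be put to use; a sketch (not in the original) showing the relative importance of the SST PCs for the first rainfall PC under the RF model:
# relative importance (node-purity based, normalized to %) of SST PCs for rainfall PC1
imp1 = importance(models$RF[[1]])[, "IncNodePurity"]
df.imp1 = data.frame(covariate = names(imp1),
importance = 100 * imp1 / sum(imp1))
varImp.plot(df.imp1, title = "RF importance for rainfall PC1", imp = TRUE)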
# Reconstruct model rainfall fields
dim(rain.evecs) # nsites x 4
## [1] 1250 4
dim(preds$CART) # ntime x 4
## [1] 67 4
rain.recon = list()
# project the predicted PCs back onto the site eigenvectors, then invert the
# standardization: multiply by the original scale and add back the center
rain.recon$CART = sweep(preds$CART %*% t(rain.evecs), 2, rain.pca$scale, "*")
rain.recon$CART = sweep(rain.recon$CART, 2, rain.pca$center, "+")
rain.recon$RF = sweep(preds$RF %*% t(rain.evecs), 2, rain.pca$scale, "*")
rain.recon$RF = sweep(rain.recon$RF, 2, rain.pca$center, "+")
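A quick consistency check (not part of the original): the all-India mean of the reconstructed RF field should track the observed mean series rain.mean.
# correlation between the reconstructed and observed all-India mean series
recon.check = cor(rowMeans(rain.recon$RF), rain.mean)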
eval.metrics = function(predicted, true) {
sapply(seq_len(ncol(predicted)), function(i) {
cor.val = cor(predicted[, i], true[, i])
rmse.val = rmse(predicted[, i], true[, i])
c(Corr = cor.val, RMSE = rmse.val)
}) %>% t() %>% as.data.frame()
}
true.rain.test = rain.avg[-train.id, ]
rain.cart.test = rain.recon$CART[-train.id, ]
rain.rf.test = rain.recon$RF[-train.id, ]
metrics = list()
metrics$CART = eval.metrics(rain.cart.test, true.rain.test)
metrics$RF = eval.metrics(rain.rf.test, true.rain.test)
# Combine metrics
metrics$CART$model = "CART"
metrics$RF$model = "RF"
metrics.df = rbind(metrics$CART, metrics$RF)
metrics.long = tidyr::pivot_longer(metrics.df,
cols = c("Corr", "RMSE"),
names_to = "metric",
values_to = "value")
metrics.long = metrics.long %>%
dplyr::group_by(model, metric) %>%
dplyr::mutate(value.winsor = Winsorize(value,
val = quantile(value, probs = c(0.1, 0.9), na.rm = T) )) %>%
dplyr::ungroup()
ggplot(metrics.long, aes(x = model, y = value.winsor, fill = model)) +
geom_boxplot(alpha = 0.8, show.legend = F) +
facet_wrap(~ metric, scales = "free_y") +
labs(title = "Evaluation metrics",
subtitle = "Summer season rainfall over India",
y = "Value") +
# theme
customize.bw +
theme(axis.title.x = element_blank(),
plot.subtitle = element_text(size = text.size-2, face = "italic"),
strip.background = element_blank(),
strip.text = element_text(size = text.size, face = "bold"))
metrics.long %>%
group_by(model, metric) %>%
summarise(mean = round(mean(value), 2),
median = round(median(value), 2),
sd = round(sd(value), 2),
.groups = 'drop')
xy.coords = xygrid[rain.indices, ]
colnames(xy.coords) = c("lat", "lon")
metrics.spatial = rbind(
cbind(model = "CART", xy.coords, metrics$CART[, c("Corr", "RMSE")]),
cbind(model = "RF", xy.coords, metrics$RF[, c("Corr", "RMSE")]) )
spatial.long = tidyr::pivot_longer(metrics.spatial,
cols = c("Corr", "RMSE"),
names_to = "metric",
values_to = "value")
spatial.long = spatial.long %>%
mutate(lon = ifelse(lon > 180, lon - 360, lon)) %>%
dplyr::group_by(model, metric) %>%
dplyr::mutate(value.winsor = Winsorize(value,
val = quantile(value, probs = c(0.1, 0.9), na.rm = T) )) %>%
dplyr::ungroup()
plot.spatialMetric(data = spatial.long, skill = "Corr")
plot.spatialMetric(data = spatial.long, skill = "RMSE")
Correlation heatmap:
CART: sparse pockets with Corr > 0.2; most values hover near 0.
RF: several regions with Corr in the 0.3-0.4 range, indicating that the model captures local patterns.
RMSE heatmap:
CART: widespread high RMSE (> 300 mm), i.e., inconsistent performance across locations.
RF: lower RMSE over central and western India, indicating more reliable spatial prediction.
For both models, the rainfall fields were reconstructed via the inverse PCA transformation. The RF-predicted PCs, multiplied by the eigenvectors, produced fields closer to the observed rainfall than CART's, with lower RMSE and higher correlation at almost every grid point.
We therefore recommend Random Forest for spatio-temporal prediction of Indian summer rainfall from SST PCs; CART offers interpretability but falls short of RF in accuracy and generalization.