library(CCA)


# Clear environment
# NOTE(review): removes every object from the current environment of the
# session that runs this script; attached packages are not affected.
rm(list = ls())
# A negative `warn` makes R ignore all warnings. Nothing in the visible part of
# this script restores the previous setting.
options(warn = -1)

# Set working directory
# Use the folder of the document that is active in the RStudio editor, so the
# relative paths below ("HW2_Library.R", "data/...") resolve next to it.
# NOTE(review): depends on rstudioapi, i.e. presumably only works when run
# from inside RStudio -- confirm before running with Rscript.
script_dir <- dirname(rstudioapi::getActiveDocumentContext()$path)
setwd(script_dir)

# Load helper functions
# NOTE(review): the visible part of this script attaches only CCA, so this file
# presumably also attaches the packages providing brewer.pal() and
# image.plot() -- confirm.
source("HW2_Library.R")

# Color palette
# myPalette(n) returns n colors interpolated (in Lab space) along the reversed
# 9-class "RdBu" Brewer palette.
myPalette <- colorRampPalette(rev(brewer.pal(9, "RdBu")), space = "Lab")
# Read data ----
# India Rainfall Data
# The binary file holds one nrows x ncols field per time step; nrows and ncols
# must agree with the lengths of the x and y axes defined below.
nrows <- 68; ncols <- 65
ntime <- 67   # 1950–2016
ygrid <- seq(6.5, 38.5, by = 0.5)
xgrid <- seq(66.5, 100, by = 0.5)
ny <- length(ygrid); nx <- length(xgrid)
stopifnot(nx == nrows, ny == ncols)

# Coordinates of every grid cell as (y, x) rows, with x varying fastest so the
# row order matches the column-major layout of each nrows x ncols field.
xygrid <- cbind(rep(ygrid, each = nx), rep(xgrid, times = ny))

# Load rainfall binary data (4-byte reals, byte order swapped relative to the
# platform default)
data <- readBin("data/India-Rain-JJAS-05deg-1950-2016.r4", what = "numeric",
                n = nrows * ncols * ntime, size = 4, endian = "swap")
# array() silently recycles a short vector, so insist on a complete read.
stopifnot(length(data) == nrows * ncols * ntime)
data <- array(data, dim = c(nrows, ncols, ntime))

# Filter non-missing
# is.na() is TRUE for both NA and NaN; comparing numbers against the string
# "NaN" only works through character coercion.
index <- 1:(nx * ny)
data1 <- data[, , 1]
index1 <- index[!is.na(data1)]
xygrid1 <- xygrid[index1, ]
nsites <- length(index1)
valid_cells <- index1   # reference mask taken from the first time step

# Time series rainfall matrix (rows = time steps, columns = valid grid cells)
raindata <- matrix(NA_real_, nrow = ntime, ncol = nsites)
for (i in seq_len(ntime)) {
  data1 <- data[, , i]
  index1 <- index[!is.na(data1)]
  # A column of raindata must refer to the same cell at every time step, so
  # the set of valid cells is not allowed to change over time.
  if (!identical(index1, valid_cells)) {
    stop("Rainfall missing-value mask changes at time step ", i,
         call. = FALSE)
  }
  raindata[i, ] <- data1[index1]
}

# Filter zero-mean locations
# index2 holds positions in the full grid, index3 positions within raindata.
index <- seq_len(ncol(raindata))
xx <- colMeans(raindata)
index2 <- index1[xx > 0]
index3 <- index[xx > 0]

xygrid1 <- xygrid[index2, ]
rainavg <- raindata[, index3]
indexgrid <- index2
rm("data")

# SST Data
# The binary file holds one nrows_sst x ncols_sst field per time step; both
# must agree with the lengths of the x and y axes defined below.
nrows_sst <- 180; ncols_sst <- 17
ygrid_sst <- seq(-16, 16, by = 2)
xgrid_sst <- seq(0, 358, by = 2)
ny_sst <- length(ygrid_sst); nx_sst <- length(xgrid_sst)
stopifnot(nx_sst == nrows_sst, ny_sst == ncols_sst)

# Coordinates of every grid cell as (y, x) rows, with x varying fastest so the
# row order matches the column-major layout of each field.
xygrid_sst <- cbind(rep(ygrid_sst, each = nx_sst),
                    rep(xgrid_sst, times = ny_sst))

# Load SST binary data (4-byte reals, byte order swapped relative to the
# platform default); uses the same number of time steps as the rainfall data.
data_sst <- readBin("data/NOAA-Trop-JJAS-SST-1950-2016.r4", what = "numeric",
                    n = nrows_sst * ncols_sst * ntime, size = 4, endian = "swap")
# array() silently recycles a short vector, so insist on a complete read.
stopifnot(length(data_sst) == nrows_sst * ncols_sst * ntime)
data_sst <- array(data_sst, dim = c(nrows_sst, ncols_sst, ntime))

# Keep cells that are not missing and below 20.
# NOTE(review): the 20 threshold presumably screens out a large fill value
# used for land/missing cells -- confirm against the data description.
data1_sst <- data_sst[, , 1]
index_sst <- 1:(nx_sst * ny_sst)
index1_sst <- index_sst[!is.na(data1_sst) & data1_sst < 20]
xygrid1_sst <- xygrid_sst[index1_sst, ]
valid_cells_sst <- index1_sst   # reference mask taken from the first time step

# Time series SST matrix (rows = time steps, columns = valid grid cells)
nsites_sst <- length(index1_sst)
sstdata <- matrix(NA_real_, nrow = ntime, ncol = nsites_sst)

for (i in seq_len(ntime)) {
  data1_sst <- data_sst[, , i]
  index1_sst <- index_sst[!is.na(data1_sst) & data1_sst < 20]
  # A column of sstdata must refer to the same cell at every time step, so
  # the set of valid cells is not allowed to change over time.
  if (!identical(index1_sst, valid_cells_sst)) {
    stop("SST valid-cell mask changes at time step ", i, call. = FALSE)
  }
  sstdata[i, ] <- data1_sst[index1_sst]
}

sstannavg <- sstdata
indexgrid_sst <- index1_sst
rm("data_sst")
# Perform PCA ----
# Rainfall PCA
# Standardize each grid cell, then decompose the resulting covariance matrix
# (the correlation matrix of rainavg); the left singular vectors are the EOFs.
rainscale <- scale(rainavg)
# scale() returns NaN for a constant column, which would only surface later as
# an obscure svd() error, so fail here with a clear message instead.
if (anyNA(rainscale)) {
  stop("rainavg has missing values or constant columns; cannot standardize",
       call. = FALSE)
}
zs <- var(rainscale)
zsvd <- svd(zs)
rainpcs <- rainscale %*% zsvd$u   # PCs: standardized data projected on EOFs
prec_eof <- zsvd$u                # EOFs

# SST PCA
sstscale <- scale(sstannavg)
if (anyNA(sstscale)) {
  stop("sstannavg has missing values or constant columns; cannot standardize",
       call. = FALSE)
}
zs_sst <- var(sstscale)
zsvd_sst <- svd(zs_sst)
sstpcs <- sstscale %*% zsvd_sst$u  # PCs: standardized data projected on EOFs
# Canonical Correlation Analysis (CCA) ----
# Number of leading PCs of each field that enter the CCA
npc <- 4
rain_cca <- rainpcs[, seq_len(npc)]
sst_cca <- sstpcs[, seq_len(npc)]

cca_result <- cc(sst_cca, rain_cca)
U <- sst_cca %*% cca_result$xcoef  # Canonical variates of SST
V <- rain_cca %*% cca_result$ycoef  # Canonical variates of Rain

# Regress V on U
# Each rainfall canonical variate is regressed on all SST canonical variates;
# V_hat holds the fitted values.
V_hat <- matrix(0, nrow = nrow(U), ncol = npc)
for (i in seq_len(npc)) {
  fit <- lm(V[, i] ~ U)
  V_hat[, i] <- predict(fit)
}

# Reconstruct rainfall PCs
# V = rain_cca %*% ycoef, so going back from variates to PCs needs the inverse
# of ycoef. ycoef is normalized so that t(ycoef) %*% cov(rain_cca) %*% ycoef is
# the identity; it is not orthogonal, hence its transpose is not its inverse.
rain_pcs_hat <- V_hat %*% solve(cca_result$ycoef)
N <- nrow(rainpcs)
N1 <- ncol(rainpcs) - npc
# Pad the discarded PCs with zeros so the matrix lines up with the full EOF set.
rain_pcs_full <- cbind(rain_pcs_hat, matrix(0, ncol = N1, nrow = N))

# Back-transform to rainfall
# E keeps only the first npc EOFs; the remaining columns stay zero.
E <- matrix(0, nrow = dim(rainscale)[2], ncol = dim(rainscale)[2])
E[, seq_len(npc)] <- prec_eof[, seq_len(npc)]
prec_pred_cca <- rain_pcs_full %*% t(E)
# Undo the per-cell standardization applied before the PCA. Transposing first
# puts cells in rows so the length-nsites vectors recycle cell by cell.
precMean <- colMeans(rainavg)
precSd <- apply(rainavg, 2, sd)
prec_pred_cca <- t(t(prec_pred_cca) * precSd + precMean)
# Evaluate Performance (R² and RMSE Maps) ----
# Compute correlation and RMSE at each grid
# xcor_cca is the Pearson correlation between predicted and observed rainfall
# per grid cell (the diagonal of the cross-correlation matrix). It is not
# squared, so the plots below label it as a correlation rather than R^2.
xcor_cca <- diag(cor(prec_pred_cca, rainavg))
xrmse_cca <- sqrt(colMeans((prec_pred_cca - rainavg)^2))

# Load rainfall locations
# NOTE(review): column 3 is used as the position of each rainfall cell in the
# full nx * ny grid -- presumably in the same order as the columns of rainavg;
# confirm against the file.
rainlocs <- read.table("data/India-rain-locs.txt")
# A length mismatch would only raise a warning, which this script may have
# silenced, and would misplace values on the map.
stopifnot(nrow(rainlocs) == length(xcor_cca))

# Spatial plot
# Scatter the per-cell values into the full grid; cells without data stay NaN.
zfull <- rep(NaN, nx * ny)
zfull[rainlocs[, 3]] <- xcor_cca
zmat <- matrix(zfull, nrow = nx, ncol = ny)

image.plot(xgrid, ygrid, zmat, ylim = range(5, 40), col = myPalette(200), zlim = c(0, 0.6))
contour(xgrid, ygrid, zmat, add = TRUE, nlevels = 6, lwd = 2)
title(main = "Correlation between Observed and CCA Predicted Precipitation")
maps::map('world', wrap = c(0, 360), add = TRUE, resolution = 0, lwd = 2)
grid()

# Boxplots
# Side-by-side panels; the mfrow setting stays in effect for later plots.
par(mfrow = c(1, 2))
boxplot(xcor_cca,
        main = "CCA Correlation (r)",
        col = "skyblue", border = "blue")

boxplot(xrmse_cca,
        main = "CCA RMSE",
        col = "salmon", border = "red")