#%%
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib
import seaborn as sns
import pvlib
from matplotlib.colors import LinearSegmentedColormap
from scipy import stats
import pyproj
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.linear_model import LinearRegression

# Save figures with a fully transparent background (alpha = 0).
plt.rcParams.update({
    "savefig.facecolor": (0.0, 0.0, 1.0, 0.0), 
})

# Matplotlib font settings
SMALL_SIZE = 16
MEDIUM_SIZE = 20
BIGGER_SIZE = 22

matplotlib.rc('font', size=SMALL_SIZE)          # controls default text sizes
matplotlib.rc('axes', titlesize=MEDIUM_SIZE)    # fontsize of the axes title
matplotlib.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
matplotlib.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
matplotlib.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
matplotlib.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
matplotlib.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

# Define constants and experiment settings
T0 = 273.16 # [K], triple point of water (note: 0 degC is 273.15 K)
p0 = 101325 # [Pa], surface pressure
g = 9.8 # [m/s^2], gravitational acceleration
c_p = 1004 # [J/K/kg], specific heat of dry air at constant pressure
L_v = 2.5e6 # [J/kg], latent heat of vaporisation
L_i = 2.8e6 # [J/kg], latent heat of sublimation
sigma = 5.670e-8 # [W/m^2/K^4], Stefan-Boltzmann constant
emis = 0.95 # Emissivity in GRASP
latitude = 52.727 # [deg N], Warmenhuizen
longitude = 4.775 # [deg E], Warmenhuizen
altitude = -2  # [m], example altitude
# Site used for all solar-position and clear-sky calculations in this script.
location = pvlib.location.Location(latitude, longitude, altitude=altitude)
#%% Define custom functions
def find_files_recursive(rootdir, filename):
    """Walk *rootdir* recursively and collect files whose name contains *filename*.

    Matching is a plain substring test on the file name (not the full path),
    so ``'data.csv'`` also matches ``'old_data.csv'``. Sub-directories and
    files are visited in sorted order, so the result does not depend on the
    file system's listing order.

    Parameters
    ----------
    rootdir : str or path-like
        Directory to search.
    filename : str
        Substring to look for in each file name.

    Returns
    -------
    file_paths : list of str
        Full paths of the matching files.
    roots : list of str
        Directory containing each match, parallel to *file_paths* (a directory
        appears once per matching file).
    """
    import os  # os is not imported at the top of this script

    file_paths, roots = [], []
    for root, dirs, files in os.walk(rootdir):
        dirs.sort()  # sorting in place fixes the order os.walk descends in
        for file in sorted(files):
            if filename in file:
                file_paths.append(os.path.join(root, file))
                roots.append(root)
    return file_paths, roots

def calc_CSI_CBH_CTH(dsold):
    """Return a copy of *dsold* with solar zenith, clear-sky index and cloud heights added.

    Added variables:
    - ``zenith``: solar zenith angle from the module-level pvlib ``location``,
      evaluated at ``dsold.time``.
    - ``CSI`` (only if ``Rswd`` exists): ``Rswd / Rswd_clear``, set to NaN where
      zenith >= 85 or CSI > 10.
    - ``CBH`` / ``CTH`` (only if ``ql`` or, failing that, ``qlf`` exists): the
      ``zf`` value of the lowest / highest level with positive liquid water,
      NaN for columns without any.

    The input dataset itself is not modified.
    """
    out = dsold.copy()
    out['zenith'] = ('time', location.get_solarposition(out.time.values)['zenith'])

    if 'Rswd' in out.variables:
        if len(out.Rswd.dims) > 1:
            # Relabel the dims so both radiation fields share ('time', 'yf', 'xf').
            for field in ('Rswd', 'Rswd_clear'):
                out[field] = (('time', 'yf', 'xf'), out[field].values)
        out['CSI'] = (out.Rswd / out.Rswd_clear)
        valid = (out.zenith < 85) & (out.CSI <= 10)
        out['CSI'] = xr.where(valid, out.CSI, np.nan)

    water = next((name for name in ('ql', 'qlf') if name in out.variables), None)
    if water is None:
        return out

    n_levels = len(out.zf)
    cloudy = out[water] > 0
    # argmax of a boolean profile is the first True level, but it is also 0 for a
    # cloud-free column - hence the extra check on the bottom level below.
    base_idx = cloudy.argmax(dim='zf')
    # The same trick on the vertically flipped profile finds the highest cloudy level.
    top_idx = n_levels - 1 - cloudy.isel(zf=slice(None, None, -1)).argmax(dim='zf')
    heights = out.zf.values

    no_base = (base_idx == 0) & (out[water].isel(zf=0) == 0)
    out['CBH'] = xr.where(no_base, np.nan, heights[base_idx])
    no_top = (top_idx == n_levels - 1) & (out[water].isel(zf=n_levels - 1) == 0)
    out['CTH'] = xr.where(no_top, np.nan, heights[top_idx])
    return out

def interp_time(ds, new_time):
    """Interpolate *ds* onto *new_time* via the midpoints of the time intervals.

    Both time axes are shifted back by half of their smallest time step, so
    interpolation happens between interval midpoints (NOTE(review): this
    presumes timestamps label the END of each averaging interval - confirm per
    dataset). Afterwards, the result is relabelled with the original
    *new_time* values.

    Parameters
    ----------
    ds : xarray object with a ``time`` dimension
        Data to interpolate (``interp`` default method, i.e. linear).
    new_time : xarray.DataArray along ``time``
        Target timestamps (interval-end labels).

    Returns
    -------
    Copy of *ds* interpolated to, and labelled with, *new_time*.
    """
    shifted = ds.copy()
    src_half_step = shifted.time.diff(dim='time').min() / 2
    shifted['time'] = shifted.time - src_half_step

    dst_half_step = new_time.diff(dim='time').min() / 2
    target_midpoints = new_time.copy() - dst_half_step

    result = shifted.interp(time=target_midpoints)
    result['time'] = new_time  # restore the interval-end labels
    return result

def transform_coord(ds, source_crs='epsg:32631', target_crs='epsg:4326', xdim='x', ydim='y'):
    """
    Attach geographic coordinates computed from a dataset's projected x/y coordinates.

    The data is NOT reprojected or reindexed. Its variables and dimensions are
    unchanged; two 2-D coordinates, ``lon`` and ``lat``, with dims
    ``(ydim, xdim)`` are added.

    Parameters:
    - ds: xarray.Dataset
        Input dataset with coordinates named *xdim* and *ydim* in source_crs.
    - source_crs: str, default 'epsg:32631'
        The source coordinate reference system (WGS 84 / UTM zone 31N by default).
    - target_crs: str, default 'epsg:4326'
        The target coordinate reference system. It is assumed to be geographic,
        because the results are stored as ``lon``/``lat``.
    - xdim, ydim: str, default 'x', 'y'
        Names of the easting-like and northing-like coordinates of *ds*.

    Returns:
    - xarray.Dataset
        Copy of *ds* with ``lon`` and ``lat`` coordinates added.
    """
    # always_xy=True fixes the axis order to (x/easting/lon, y/northing/lat) on both
    # input and output, whatever the CRS authority's axis order is (EPSG:4326 is
    # lat-first, and some projected CRSs are northing-first).
    transformer = pyproj.Transformer.from_crs(source_crs, target_crs, always_xy=True)

    x = ds[xdim].values
    y = ds[ydim].values

    # 1-D coordinate vectors -> 2-D grids of shape (len(y), len(x)), i.e. (ydim, xdim).
    if x.ndim == 1 and y.ndim == 1:
        x, y = np.meshgrid(x, y)

    lon, lat = transformer.transform(x, y)

    return ds.assign_coords(lon=((ydim, xdim), lon), lat=((ydim, xdim), lat))

def _global_ssim(x, y, data_range, k1=0.01, k2=0.03):
    """Return the single-window (global) SSIM of two equally long 1-D float arrays.

    Applies the SSIM formula of Wang et al. (2004) with one window spanning the
    whole series and population (ddof=0) statistics. This differs from the mean
    over sliding windows that image libraries report. Returns 1.0 when
    *data_range* is zero, i.e. both series are the same constant.
    """
    if data_range == 0:
        return 1.0
    c1 = (k1 * data_range) ** 2  # stabilises the luminance term
    c2 = (k2 * data_range) ** 2  # stabilises the contrast/structure term
    mu_x, mu_y = x.mean(), y.mean()
    cov_xy = np.mean((x - mu_x) * (y - mu_y))
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (x.var() + y.var() + c2)
    return numerator / denominator


def skill_vars(x, y):
    """
    Computes various skill metrics between two arrays x and y.

    Both inputs are converted to flattened float arrays and compared element by
    element, by position. pandas indexes are ignored, not aligned.

    Args:
    x (numpy.ndarray or array-like): First array of data points.
    y (numpy.ndarray or array-like): Second array of data points.

    Returns:
    tuple: Tuple containing the following skill metrics:
        - RMSE (float): Root Mean Squared Error between x and y.
        - bias (float): Bias (mean difference) between y and x, i.e. mean(y - x).
        - biaserror (float): Error in bias estimation based on standard deviations of x and y.
        - MAE (float): Mean Absolute Error between x and y.
        - corcoef (float): Pearson correlation coefficient between x and y.
        - corcoeferror (float): Probable error of the correlation coefficient, 0.6745 * (1 - r**2) / sqrt(n).
        - sigma (float): Standard deviation of the residuals after removing the bias,
          so that RMSE**2 == sigma**2 + bias**2.
        - ssim (float): Single-window (global) Structural Similarity Index between x and y.
    All metrics are NaN when the inputs are empty.

    Raises:
    ValueError: if x and y do not contain the same number of elements.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must contain the same number of elements, got {x.size} and {y.size}"
        )
    n = x.size
    if n == 0:
        return (np.nan,) * 8

    diff = y - x

    # Root Mean Squared Error
    RMSE = np.sqrt(np.mean(np.square(diff)))

    # Bias (mean difference y - x)
    bias = diff.mean()

    # Error in bias estimation
    biaserror = np.std(x) / np.sqrt(n) + np.std(y) / np.sqrt(n)

    # Mean Absolute Error
    MAE = np.mean(np.abs(diff))

    # Pearson correlation coefficient
    corcoef = np.corrcoef(x, y)[0, 1]

    # Probable error of the correlation coefficient
    corcoeferror = 0.6745 * (1 - corcoef ** 2) / np.sqrt(n)

    # Residual spread with the bias removed: std of (y - x).
    sigma = np.sqrt(np.mean(np.square(diff - bias)))

    # Structural similarity over the combined value range of x and y
    data_range = np.max([x.max(), y.max()]) - np.min([x.min(), y.min()])
    ssim = _global_ssim(x, y, data_range)

    return RMSE, bias, biaserror, MAE, corcoef, corcoeferror, sigma, ssim

def seaborn_scatter(x, y, cvar, ylabel, xlabel="Observed clear-sky index [-]", axmin=None, axmax=None, cmap = 'viridis', hist=True, kde=True, statistics=True, binwidth=0.2):
    """Joint scatter plot of *y* against *x* with optional KDE, marginals and skill scores.

    Parameters
    ----------
    x, y : pandas.Series or xarray.DataArray
        Paired data; their ``.values`` are passed to ``skill_vars``.
    cvar : array-like
        Values used to colour the scatter points.
    ylabel, xlabel : str
        Axis labels.
    axmin, axmax : float, optional
        Shared limits of both axes. When None, the axes autoscale, and the 1:1
        line and skill box use the combined data range of x and y instead.
    cmap : str or matplotlib Colormap
        Colormap for the points. Its low end (0.01) also colours the KDE
        contours and the marginal histograms.
    hist, kde, statistics : bool
        Toggle the marginal histograms, the joint KDE contours, and the 1:1
        line plus the skill-score text box.
    binwidth : float
        Bin width of the marginal histograms.

    Returns
    -------
    (seaborn.JointGrid, matplotlib collection returned by ``Axes.scatter``)
    """
    cmap = plt.get_cmap(cmap)  # accepts a colormap name or a Colormap instance
    skills = skill_vars(x.values, y.values)
    RMSE, bias, corcoef = skills[0], skills[1], skills[4]
    Skills = f"$\\rho$ = {str(round(corcoef,3))} \nRMSE: {str(round(RMSE,3))}\nMBE: {str(round(bias,3))}"

    s = sns.JointGrid(x=x,
                      y=y,
                      height=10,
                      xlim=[axmin, axmax], ylim=[axmin, axmax],
                      space=0,
                      )
    scatter = s.ax_joint.scatter(x, y, c=cvar, cmap=cmap, alpha=0.5)
    if kde:
        # seaborn's kdeplot takes `color`; `c` would be forwarded to contour and ignored.
        s.plot_joint(sns.kdeplot, fill=False, color=cmap(0.01), cut=0)
    if hist:
        s.plot_marginals(sns.histplot, color=cmap(0.01), binwidth=binwidth)
    s.set_axis_labels(xlabel, ylabel)
    s.ax_joint.grid()
    if statistics:
        # Fall back to the data range so the 1:1 line and text box work without limits.
        lo = axmin if axmin is not None else np.nanmin([np.nanmin(x), np.nanmin(y)])
        hi = axmax if axmax is not None else np.nanmax([np.nanmax(x), np.nanmax(y)])
        s.ax_joint.plot(np.array([lo, hi]), np.array([lo, hi]), 'grey')
        s.ax_joint.text(lo + (hi - lo) * 0.05, hi - (hi - lo) * 0.15, Skills,
                        bbox=dict(facecolor='white', alpha=0.5))
    return s, scatter

#%% Open observations
# Each data folder holds a 10-min merged met file and a 10-min IWV/LWP file.
_, roots = find_files_recursive('/home/marleen/Reform/Data/Dijken/Data-REFORM', 'merged_data_10min.csv')
if not roots:
    raise FileNotFoundError("No 'merged_data_10min.csv' files found under the observation directory")

dfs1, dfs2 = [], []
# dict.fromkeys de-duplicates folders (keeping order) in case several file names in
# one folder contain the search string; otherwise that folder would be loaded twice.
for root in dict.fromkeys(roots):
    dfs1.append(pd.read_csv(root+'/merged_data_10min.csv', parse_dates=['TIMESTAMP'], index_col='TIMESTAMP'))
    dfs2.append(pd.read_csv(root+'/iwv_lwp_10min_avg.csv', parse_dates=['TIMESTAMP'], index_col='TIMESTAMP'))

ds1 = pd.concat(dfs1).to_xarray().drop_duplicates(dim='TIMESTAMP')
ds2 = pd.concat(dfs2).to_xarray().drop_duplicates(dim='TIMESTAMP')
ds_obs = (
    xr.merge([ds1, ds2])
    .sortby('TIMESTAMP')
    .rename({'TIMESTAMP': 'time', 'SR15D1Dn_Irr': 'Rswd'})
    .isel(time=slice(1, -1))  # drop the first and last time step
)

# Clear-sky GHI at the middle of each 10-min interval
# (NOTE(review): assumes the timestamps label the interval end - confirm).
times = pd.DatetimeIndex(ds_obs.time.values)+pd.to_timedelta(-5, 'min')
clearsky = location.get_clearsky(times, model='ineichen')
ds_obs['Rswd_clear'] = ('time', clearsky.ghi.values)
ds_obs = calc_CSI_CBH_CTH(ds_obs)

#%%
# Open solar park data (realised and forecast production per time step)
ds_park = pd.read_excel('/home/marleen/Reform/Data/Dijken/ForecastDeDijken.xlsx', index_col=0).to_xarray().rename({'#':'date'})
# Combine the separate date and clock-time columns into local timestamps, then convert
# them to UTC. ambiguous='infer' resolves the repeated hour at the autumn DST switch
# from the order of the rows instead of raising.
dt_strings = [f"{date} {time}" for date, time in zip(ds_park.date.values, ds_park.CODE.values)]
dt = (
    pd.to_datetime(dt_strings, format='%d-%m-%Y %H:%M:%S')
    .tz_localize('Europe/Berlin', ambiguous='infer')
    .tz_convert('UTC')
)
# Shift labels by 15 min (NOTE(review): presumably so each label marks the end of its
# 15-min interval - confirm against the data source).
ds_park['date'] = dt.values + pd.to_timedelta(15, 'min')
ds_park = (
    ds_park.rename({'date': 'time', '871685920003799408TM': 'realisation', '871685920003799408_ZON_DA': 'forecast'})
    .drop_vars('CODE')
    .sel(time=slice(ds_obs.time[0], ds_obs.time[-1]))  # keep only the period covered by the observations
)

# Interpolate observations to park data times and merge
ds = xr.merge([ds_park, interp_time(ds_obs, ds_park.time)])

# Solar geometry at the middle of each 15-min interval (computed once).
solar_position = location.get_solarposition(ds.time + pd.to_timedelta(-7.5, 'min'))
ds['zenith'] = ('time', solar_position['zenith'])
ds['azimuth'] = ('time', solar_position['azimuth'])
ds['declination'] = ('time', pvlib.solarposition.declination_spencer71(ds['time.dayofyear'].values))

# Drop night-time steps and steps with realised production below 100.
nightmask = ds.zenith > 90
zeromask = ds.realisation < 100

ds = ds.where(~nightmask & ~zeromask, drop=True).dropna(dim='time')

#%% Scatter SWD obs vs park production
# Observed downwelling shortwave radiation, skipping the first valid step and values <= 10.
y = ds.Rswd.dropna(dim='time')[1:]
y = y.where(y>10, drop=True)
x = ds.realisation.sel(time=y.time)
# Only the correlation is shown. Indexing (element 4 = Pearson coefficient) avoids
# rebinding module-level names such as the Stefan-Boltzmann constant `sigma`.
corcoef = skill_vars(x.values, y.values)[4]
Skills = f"$\\rho$ = {str(round(corcoef,3))}"

# Colour points by the local point density, estimated with a Gaussian KDE.
values = np.vstack([x,y])
cvar = stats.gaussian_kde(values)(values)
original_cmap = plt.get_cmap('autumn')
# Colormap built from the upper 30% (0.7-1.0) of 'autumn'.
# NOTE: currently unused - the scatter below uses original_cmap.
colors = original_cmap(np.linspace(0.7, 1, 30))
new_cmap = LinearSegmentedColormap.from_list('half_cmap', colors)

s = sns.scatterplot(x=x, y=y, hue=cvar, palette=original_cmap, alpha=0.5)
s.set_xlabel('Energy production [kW]')
s.set_ylabel(r'Observed $R_{\mathrm{SWD}}$ [W/m$^2$]')  # raw string: '\m' is not a valid escape
s.get_legend().set_visible(False)
s.set_title(Skills)

#%% Define variables to use in production model
# Dataset variable names, and the short labels used for them in the correlation keys.
variables = ['realisation', 'Rswd', 'zenith','azimuth', 'declination', 'Wind_Speed', 'Temperature_K_2', 'LWP_Corrected']#, 'IWV']
names = ['production', 'Rswd', 'zenith','azimuth', 'declination', 'Wind_Speed', 'Temperature', 'LWP']#, 'IWV']

ds_rf = ds[variables]

# Pearson correlation along time for every unordered pair (each variable with itself
# included), keyed "<label_a>_vs_<label_b>" in the order the variables are listed.
labelled = list(zip(variables, names))
corr_dict = {
    f"{label_a}_vs_{label_b}": xr.corr(ds_rf[var_a], ds_rf[var_b], dim='time')
    for pos, (var_a, label_a) in enumerate(labelled)
    for var_b, label_b in labelled[pos:]
}

# Gather all pairwise correlations in one xarray Dataset
corr_dataset = xr.Dataset(corr_dict)

#%%  Filter out curtailment
# Fit production ~ Rswd (no intercept) on all data, then drop points that fall far
# below that line; these are presumed to be curtailed.
df = ds_rf.to_dataframe()

X = df.iloc[:, 1:]  # Predictors: every column except the first ('realisation')
X_swd = X.iloc[:, 0].values.reshape(-1, 1)  # First predictor (Rswd) as a 2-D array for sklearn
Y = df.realisation  # Target: realised production (the first column)

# Train Linear Regression Model on all data
linear_model = LinearRegression(fit_intercept=False)
linear_model.fit(X_swd, Y)

# Predict and calculate residuals on all data
y_pred_full = linear_model.predict(X_swd)
residuals = Y - y_pred_full

# Keep rows whose relative residual is not more negative than the largest positive
# relative residual is positive; larger shortfalls are treated as curtailment.
relative_residuals = residuals / Y
mask = relative_residuals >= -relative_residuals.max()
X_filtered = X[mask]
Y_filtered = Y[mask]

# Report how much data is removed
print(f"Original data size: {X.shape[0]}")
print(f"Filtered data size: {X_filtered.shape[0]}")

# Predicted vs Actual plot: red = all points, green = points kept
plt.figure(figsize=(8, 6))
plt.scatter(Y, y_pred_full, color='red', edgecolor='k', alpha=0.7)
plt.scatter(Y_filtered, y_pred_full[mask.to_numpy()], color='green', edgecolor='k', alpha=0.7)
plt.plot([Y.min(), Y.max()], [Y.min(), Y.max()], color='red', lw=2)  # Line for perfect predictions
plt.title("Filter out curtailment")
plt.xlabel("Actual Values")
plt.ylabel("Predicted Values")
plt.show()

#%%
# Split filtered data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X_filtered, Y_filtered, test_size=0.2, random_state=42)


def _skills_text(label, observed, predicted):
    """Return a multi-line summary (rho, RMSE, MBE) of *predicted* against *observed*.

    Uses skill_vars; the other metrics it returns are discarded, so no
    module-level names (e.g. the constant `sigma`) are rebound.
    """
    RMSE, bias, _, _, corcoef, _, _, _ = skill_vars(observed, predicted)
    return f"{label}: \nrho = {str(round(corcoef,3))} \nRMSE: {str(round(RMSE,1))}\nMBE: {str(round(bias,1))}"


# Model 1: random forest on all predictors (Rswd included)
rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
rf_model.fit(X_train, y_train)
y_pred_rf = rf_model.predict(X_test)

Skills_rf = _skills_text("RF", y_test, y_pred_rf)
print(Skills_rf)

# Model 2: linear regression on Rswd, the first predictor column, only
X_train_linear = X_train.iloc[:, 0].values.reshape(-1, 1)
X_test_linear = X_test.iloc[:, 0].values.reshape(-1, 1)
X_train_rf = X_train.iloc[:, 1:]  # remaining predictors, used by the residual RF of model 3
X_test_rf = X_test.iloc[:, 1:]

# Refit linear_model (created with fit_intercept=False) on the training split only
linear_model.fit(X_train_linear, y_train)
y_pred_linear_test = linear_model.predict(X_test_linear)

Skills_lr = _skills_text("LR", y_test, y_pred_linear_test)
print(Skills_lr)

# Model 3: Rswd linear regression + random forest on its training residuals,
# using every predictor except Rswd
y_pred_linear = linear_model.predict(X_train_linear)
residuals_train = y_train - y_pred_linear

rfres_model = RandomForestRegressor(n_estimators=100, random_state=42)
rfres_model.fit(X_train_rf, residuals_train)
residuals_pred_rf = rfres_model.predict(X_test_rf)

# Final prediction = linear part + predicted residual correction
y_pred_final = y_pred_linear_test + residuals_pred_rf

Skills_final = _skills_text("LR + RF", y_test, y_pred_final)
print(Skills_final)

#%% Plot feature importances of the RF models and their performance on the test data


def _plot_importances(importances, features, title):
    """Bar chart of *importances* (aligned with *features*), sorted in descending order."""
    order = np.argsort(importances)[::-1]
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_title(title)
    ax.bar(range(features.shape[0]), importances[order], align="center")
    ax.set_xticks(range(features.shape[0]))
    ax.set_xticklabels(features[order], rotation=90)
    ax.set_xlabel("Feature")
    ax.set_ylabel("Importance Score")
    fig.tight_layout()
    return fig, ax


# RF on all predictors: drop the first entry (Rswd, the first predictor column)
_plot_importances(rf_model.feature_importances_[1:], X_train.columns[1:],
                  "Feature importances excl Rswd for only RF model")

# Residual RF, trained on the predictors other than Rswd
_plot_importances(rfres_model.feature_importances_, X_train_rf.columns,
                  "Feature importances for residual RF model")

# Predicted vs Actual plot
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(y_test, y_pred_linear_test, color='green', edgecolor='k', alpha=0.5, label='LR')
ax.scatter(y_test, y_pred_rf, color='orange', edgecolor='k', alpha=0.5, label='RF')
ax.scatter(y_test, y_pred_final, color='blue', edgecolor='k', alpha=0.5, label='LR + RF')
ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], color='red', lw=2)  # 1:1 line

ax.set_title("Performance on test data")
ax.set_xlabel("Observed production [kW]")
ax.set_ylabel("Predicted production [kW]")

# Add skill metrics to the right of the axes (axes coordinates)
ax.text(1.1, 0.6, Skills_lr, bbox=dict(facecolor='green', alpha=0.2), transform=ax.transAxes)
ax.text(1.1, 0.3, Skills_rf, bbox=dict(facecolor='orange', alpha=0.2), transform=ax.transAxes)
ax.text(1.1, 0.0, Skills_final, bbox=dict(facecolor='blue', alpha=0.2), transform=ax.transAxes)

ax.legend()

#%% Just a start at a physical model; the actual PV system specifications are not known yet

# Timestamps at the middle of each 15-min interval
times = pd.DatetimeIndex(ds.time.values)+pd.to_timedelta(-7.5, 'min')
zenith = ds.zenith

# Calculate clear sky radiation using the Ineichen model
clearsky = location.get_clearsky(times, model='ineichen')
ds['SWD_clear'] = ('time', clearsky.ghi.values)

# Placeholder PV system built from the SAM libraries
# (NOTE(review): generic module/inverter, not the park's actual hardware).
temperature_model_parameters = pvlib.temperature.TEMPERATURE_MODEL_PARAMETERS['sapm']['open_rack_glass_glass']
sandia_modules = pvlib.pvsystem.retrieve_sam('SandiaMod')
cec_inverters = pvlib.pvsystem.retrieve_sam('cecinverter')
sandia_module = sandia_modules['Canadian_Solar_CS5P_220M___2009_']
cec_inverter = cec_inverters['ABB__MICRO_0_25_I_OUTD_US_208__208V_']

pv_system = pvlib.pvsystem.PVSystem(surface_tilt=10, surface_azimuth=188,
                                    module_parameters=sandia_module,
                                    inverter_parameters=cec_inverter,
                                    temperature_model_parameters=temperature_model_parameters)

# Create ModelChain object
model_chain = pvlib.modelchain.ModelChain(pv_system, location, aoi_model='physical')

# Split the observed GHI into direct-normal and diffuse parts with the Erbs model (once).
irradiance_split = pvlib.irradiance.erbs(ds.Rswd, zenith, times)

# Prepare weather data
weather = pd.DataFrame({
    'ghi': ds.Rswd,
    'dni': irradiance_split['dni'],
    'dhi': irradiance_split['dhi'],
    # K -> degC: 0 degC is 273.15 K (T0 is the triple point, 273.16 K).
    # NOTE(review): assumes Average_Temperature is in kelvin.
    'temp_air': ds.Average_Temperature - 273.15,
    'wind_speed': ds.Wind_Speed,
    }, index=times)
weather_clear = pd.DataFrame({
    'ghi': clearsky.ghi,
    'dni': clearsky.dni,
    'dhi': clearsky.dhi,
    }, index=times)

#%%
# Run ModelChain on the observed weather and store the modelled AC power
model_chain.run_model(weather)
results = model_chain.results
ds['model_production'] = ('time', results.ac.values)

# Compare modelled and observed production on separate y-axes of a NEW figure, so
# the twin axis cannot attach to a figure left open by an earlier cell.
# NOTE(review): pv_system sets no modules_per_string/strings_per_inverter, so it
# presumably models a single module; magnitudes will differ from the park.
fig, ax_model = plt.subplots(figsize=(12, 5))
ds.model_production.plot(ax=ax_model, label='Modelled', color='tab:blue')
ax_model.set_ylabel('Modelled AC power [W]')

ax_obs = ax_model.twinx()
ds.realisation.plot(ax=ax_obs, label='Observed', linestyle='--', color='tab:orange')
ax_obs.set_ylabel('Observed production [kW]')

ax_model.set_xlabel('Time')
ax_model.set_title('Estimated Solar Energy Production')
# One legend for the lines on both axes
lines = ax_model.get_lines() + ax_obs.get_lines()
ax_model.legend(lines, [line.get_label() for line in lines])
plt.show()

# %%
