#%%
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
#%%

def remove_spinup(ds):
    """Return *ds* without its spin-up period.

    The spin-up length is read from the ``spin_up`` entry of ``ds.attrs`` and
    parsed with ``pd.Timedelta``. Everything on the ``time`` coordinate from
    the first time stamp plus that length onwards is kept.

    Parameters
    ----------
    ds : object with ``attrs``, ``time`` and ``sel`` (e.g. an xarray Dataset)
        Dataset whose ``attrs['spin_up']`` holds a string that
        ``pd.Timedelta`` can parse.

    Returns
    -------
    The result of ``ds.sel`` restricted to the post-spin-up time range.
    """
    spinup_length = pd.Timedelta(ds.attrs['spin_up'])
    # First time label to keep; the slice is open-ended towards later times.
    first_kept = ds.time[0] + spinup_length
    return ds.sel(time=slice(first_kept, None))

# Root directory of the simulation output -- change this to your own path.
# Keep the trailing slash: the glob patterns below are appended to it directly.
path = '/home/marleen/Jonas_runs/2024/'

# Open the meso output files located three directory levels below `path` as
# multi-file datasets. remove_spinup is passed as the preprocess hook, and the
# 'dofdif' variable is dropped from each combined dataset.
meso_simdata = xr.open_mfdataset(
    path + '*/*/*/graspOutSimdata.les.nc',
    preprocess=remove_spinup,
)
meso_simdata = meso_simdata.drop_vars('dofdif')

meso_metmast = xr.open_mfdataset(
    path + '*/*/*/graspOutTFMetmast.les.nc',
    preprocess=remove_spinup,
)
meso_metmast = meso_metmast.drop_vars('dofdif')

# TODO: also open the LES datasets?

# Liquid water path: integrate ql * rhohydrof along 'zf'.
# 'zf' levels where the product contains NaN are dropped before integrating
# (per the original note, these are the heights where the density is NaN).
ql_times_density = meso_simdata['ql'] * meso_simdata['rhohydrof']
meso_simdata['lwp'] = ql_times_density.dropna('zf').integrate('zf')

# Quick-look plot of the resulting liquid water path.
meso_simdata['lwp'].plot()

# TODO: also calculate the ice water path?


# Upward shortwave radiation, derived as Rswd minus Rswa.
meso_simdata['Rswu'] = meso_simdata['Rswd'] - meso_simdata['Rswa']

# Net radiation: downward components (short- and longwave) minus the upward
# components. Terms are combined in the same left-to-right order as before.
meso_simdata['Rnet'] = (
    meso_simdata['Rswd']
    + meso_simdata['Rlwd']
    - meso_simdata['Rswu']
    - meso_simdata['Rlwu']
)

#%% Plot surface fluxes of a single day
# The day is the calendar date of the time step at index 30 (a time-step
# index, not a day number); all time steps falling on that date are kept.
dates = meso_simdata.time.dt.date
day = dates[30]
meso_temp = meso_simdata.where(dates == day, drop=True)

fig, axs = plt.subplots(2, 1, figsize=(10, 10), sharex=True)
radiation_ax, flux_ax = axs

# Top panel: the four radiation components plus net radiation in black.
for name in ('Rswd', 'Rlwd', 'Rswu', 'Rlwu'):
    radiation_ax.plot(meso_temp.time, meso_temp[name], label=name)
radiation_ax.plot(meso_temp.time, meso_temp['Rnet'], c='k', label='Rnet')
radiation_ax.legend()

# Bottom panel: net radiation in black next to the G, H and Q fluxes.
flux_ax.plot(meso_temp.time, meso_temp['Rnet'], c='k', label='Rnet')
for name, label in (
    ('G', 'Ground heat flux'),
    ('H', 'Sensible heat flux'),
    ('Q', 'Latent heat flux'),
):
    flux_ax.plot(meso_temp.time, meso_temp[name], label=label)
flux_ax.legend()

# %%
