from time import time
from datetime import datetime

import numpy as np
import netCDF4 as nc
from netCDF4 import MFDataset,Dataset
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from pyschism.mesh import Hgrid


# Read the horizontal grid; its triangulation is used for every contour plot.
hgrid = Hgrid.open('hgrid.gr3', crs='EPSG:4326')


t0 = time()
# Read all schout_*.nc files as one aggregated dataset.
ds = MFDataset('./Elsa/schout_*.nc')
# Every third record starting at index 2, all nodes, last index of the
# vertical axis.  Assumes temp is (time, node, level) with the surface at
# level -1 (the plot title calls it "surface temperature") -- TODO confirm.
temp = ds['temp'][2::3, :, -1]
print(f'temp size is {temp.shape}')
# Replace masked (missing) values with 0.0 so every node has a plottable
# value; 0.0 < vmin, so with extend='min' they show in the below-range colour.
# np.ma.filled also works if the reader returns a plain (unmasked) ndarray.
temp = np.ma.filled(temp, 0.0)
x = ds['SCHISM_hgrid_node_x'][:]
y = ds['SCHISM_hgrid_node_y'][:]

# Frame timestamps: every 3 h from startdate (inclusive) to enddate (exclusive).
# NOTE(review): these are hard-coded rather than read from ds['time'];
# confirm they line up with the [2::3] record slice used for temp/elev.
startdate = datetime(2021, 7, 6, 3)
enddate = datetime(2021, 7, 10, 1)
times = np.arange(startdate, enddate, np.timedelta64(3, 'h'), dtype='datetime64[m]')
print(times)

# Extent of the grid nodes.
xmin = np.min(x)
xmax = np.max(x)
ymin = np.min(y)
ymax = np.max(y)
# Colour scale: 256 evenly spaced contour levels between vmin and vmax (deg C).
vmin = 10.0
vmax = 30.0
levels = np.linspace(vmin, vmax, 256)

# Elevation on the same records as temp, plus node depths; together they
# decide which nodes are dry.
elev = ds['elev'][2::3, :]
depth = ds['depth'][:]
NP = len(depth)

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111)

# Dry mask for the first frame: a node is dry when elev + depth <= 1e-6,
# and any triangle with at least one dry node is hidden.
idry = np.zeros(NP)
idxs = np.where(elev[0, :].flatten() + depth <= 1.e-6)
idry[idxs] = 1
triangulation = hgrid.triangulation
mask = idry[triangulation.triangles].astype(bool).any(axis=1)
triangulation.set_mask(mask)

# First frame plus a colour bar that stays fixed for the whole animation.
# NOTE(review): with extend='min', values above vmax lie outside the contour
# levels and are left unfilled; consider extend='both' if SST can exceed 30.
im = ax.tricontourf(triangulation, temp[0, :], cmap='jet', levels=levels,
                    vmin=vmin, vmax=vmax, extend='min')
cbar = plt.colorbar(im)
cbar.set_ticks([10, 14, 18, 22, 26, 30])
cbar.set_ticklabels(['10', '14', '18', '22', '26', '30'])
# Raw strings: '\c' is an invalid escape in a normal literal (SyntaxWarning
# on Python 3.12+); the text handed to matplotlib is unchanged.
cbar.set_label(r'Temperature ($^\circ$C)')

ax.set_xlabel(r'Longitude ($^\circ$E)')
ax.set_ylabel(r'Latitude ($^\circ$N)')

def animate(i):
    """Redraw frame ``i`` of the surface-temperature animation.

    Recomputes the dry-node mask from ``elev[i]`` and ``depth``, clears the
    axes, draws filled contours of ``temp[i]``, and re-applies the axis
    labels and a title stamped with ``times[i]``.  Returns None; the
    animation is created with ``blit=False``, so no artists are returned.

    Parameters
    ----------
    i : int
        Frame index into the module-level ``temp``, ``elev`` and ``times``.
    """
    # A node is dry when elev + depth <= 1e-6; hide every triangle that has
    # at least one dry node.
    idry = np.zeros(NP)
    idxs = np.where(elev[i, :].flatten() + depth <= 1.e-6)
    idry[idxs] = 1

    triangulation = hgrid.triangulation
    mask = idry[triangulation.triangles].astype(bool).any(axis=1)
    triangulation.set_mask(mask)

    val = temp[i, :].flatten()
    # ax.clear() also wipes labels and title, so they are re-applied below.
    ax.clear()
    ax.tricontourf(triangulation, val, cmap='jet', levels=levels,
                   vmin=vmin, vmax=vmax, extend='min')

    # Raw strings avoid the invalid '\c' escape (SyntaxWarning on Python
    # 3.12+); the rendered text is unchanged.
    ax.set_xlabel(r'Longitude ($^\circ$E)')
    ax.set_ylabel(r'Latitude ($^\circ$N)')

    ax.set_title(f'Surface temperature at {str(times[i])}')

# One frame per available record.  The count is bounded by the timestamp
# list and by both data arrays so animate() never indexes past any of them.
# This replaces a hard-coded 32, which is what the 3-hourly times from
# startdate to enddate yield.
nframes = min(len(times), temp.shape[0], elev.shape[0])
anim = FuncAnimation(fig, func=animate, frames=np.arange(nframes), blit=False)

# Requires an ffmpeg executable that matplotlib can locate.
anim.save('sst_20210706-09_4days.mp4', writer='ffmpeg', fps=3)
print(f'It took {time()-t0} seconds to plot surface temperature')
