from time import time
from datetime import datetime

import numpy as np
import netCDF4 as nc
from netCDF4 import Dataset, MFDataset
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from pyschism.mesh import Hgrid


# Read the horizontal grid; its triangulation is used for every contour plot.
hgrid = Hgrid.open('hgrid.gr3', crs='EPSG:4326')


t0 = time()
# Read the SCHISM output (every file matching the glob, aggregated by MFDataset).
ds = MFDataset('./Elsa/schout_*.nc')
# Every 3rd output record starting at record 2, last vertical level (plotted
# as surface salinity). NOTE(review): presumably hourly output, so frames are
# 3-hourly and line up with the hard-coded `times` below -- confirm.
salt = ds['salt'][2::3, :, -1]
print(f'salt size is {salt.shape}')
# Replace missing values with 0.0; being below vmin they fall into the
# extend='min' (under-range) colour. np.ma.filled also accepts a plain
# ndarray, so this works whether or not netCDF4 returned a masked array.
salt = np.ma.filled(salt, 0.0)
x = ds['SCHISM_hgrid_node_x'][:]
y = ds['SCHISM_hgrid_node_y'][:]
# Frame timestamps are hard-coded instead of read from ds['time'] (which could
# be decoded with nc.num2date); they must stay consistent with the [2::3] slice.
startdate = datetime(2021, 7, 6, 3)
enddate = datetime(2021, 7, 10, 1)
times = np.arange(startdate, enddate, np.timedelta64(3, 'h'), dtype='datetime64[m]')
print(times)

xmin = np.min(x)
xmax = np.max(x)
ymin = np.min(y)
ymax = np.max(y)
vmin = 30.0
vmax = 37.0
levels = np.linspace(vmin, vmax, 256)

elev = ds['elev'][2::3, :]
depth = ds['depth'][:]
NP = len(depth)

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111)

# Wet/dry mask for the first frame: a node is dry when elev + depth <= 1e-6,
# and any triangle with at least one dry vertex is hidden from the plot.
# np.asarray gives a plain boolean array (the raw comparison result).
dry_nodes = np.asarray(elev[0, :].flatten() + depth <= 1.e-6)
triangulation = hgrid.triangulation
mask = dry_nodes[triangulation.triangles].any(axis=1)
triangulation.set_mask(mask)

im = ax.tricontourf(triangulation, salt[0, :], cmap='jet', levels=levels,
                    vmin=vmin, vmax=vmax, extend='min')
cbar = plt.colorbar(im)
cbar.set_ticks([30.0, 31.5, 33.0, 34.5, 36.0])
cbar.set_ticklabels(['30.0', '31.5', '33.0', '34.5', '36.0'])
cbar.set_label('Salinity')

# Raw strings: '\c' is not a valid Python escape sequence (it warns on
# import and is slated to become an error); the rendered text is unchanged.
ax.set_xlabel(r'Longitude ($^\circ$E)')
ax.set_ylabel(r'Latitude ($^\circ$N)')

def animate(i):
    """Redraw ``ax`` with the surface-salinity field for output step *i*.

    Used as the FuncAnimation frame callback. Nodes whose total water depth
    (``elev + depth``) is <= 1e-6 are treated as dry, and every triangle with
    at least one dry vertex is masked out before contouring.

    Parameters
    ----------
    i : int
        Index into the time axis of the module-level ``salt`` and ``elev``
        arrays and into ``times``.
    """
    print(f'Plotting time {i}')

    # Per-node dry flag for this step (plain boolean array of the raw comparison).
    dry_nodes = np.asarray(elev[i, :].flatten() + depth <= 1.e-6)

    triangulation = hgrid.triangulation
    # Hide any triangle that touches a dry node.
    triangulation.set_mask(dry_nodes[triangulation.triangles].any(axis=1))

    ax.clear()
    ax.tricontourf(triangulation, salt[i, :].flatten(), cmap='jet', levels=levels,
                   vmin=vmin, vmax=vmax, extend='min')

    # ax.clear() wipes labels and title, so they are restored on every frame.
    # Raw strings: '\c' is not a valid Python escape sequence.
    ax.set_xlabel(r'Longitude ($^\circ$E)')
    ax.set_ylabel(r'Latitude ($^\circ$N)')

    ax.set_title(f'Surface salinity at {str(times[i])}')

# One frame per available step: bounded by the hard-coded timestamps and by the
# records actually read, so a shorter run cannot index past the end of salt/elev.
n_frames = min(len(times), salt.shape[0], elev.shape[0])
anim = FuncAnimation(fig, func=animate, frames=range(n_frames), blit=False)

anim.save('sss_20210706-09_4days.mp4', writer='ffmpeg', fps=3)
# salt/elev/depth/x/y were sliced out of the dataset above, so it can be released.
ds.close()
print(f'It took {time()-t0} seconds to plot surface salinity')
