from time import time
from datetime import datetime

import numpy as np
import netCDF4 as nc
from netCDF4 import MFDataset,Dataset
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from pyschism.mesh import Hgrid


# Load the horizontal grid from 'hgrid.gr3', declared with CRS EPSG:4326.
hgrid = Hgrid.open('hgrid.gr3', crs='EPSG:4326')


# Start the wall-clock timer; the elapsed time is printed after the movie
# has been written.
t0 = time()

# Open every matching output file as one multi-file dataset.
# (A single file can be read instead with Dataset('./Elsa/schout_20210705.nc').)
ds = MFDataset('./Elsa/schout_*.nc')
x = ds['SCHISM_hgrid_node_x'][:]
y = ds['SCHISM_hgrid_node_y'][:]

# Frame time stamps are generated at a fixed 3-hour spacing with minute
# resolution instead of being decoded from ds['time'] (e.g. via nc.num2date).
# NOTE(review): this assumes the records selected later with [2::3] line up
# with these dates -- confirm against the files' own time variable.
startdate = datetime(2021, 7, 6, 3)
enddate = datetime(2021, 7, 10, 1)
times = np.arange(
    startdate,
    enddate,
    np.timedelta64(3, 'h'),
    dtype='datetime64[m]',
)
print(times)

# Velocity components: every third record starting at index 2, last index of
# the third axis, components 0 (u) and 1 (v) of the last axis.
# Presumably the axes are (time, node, vertical level, component) and the
# last level is the surface -- TODO confirm against the output files.
u = ds['hvel'][2::3, :, -1, 0]
u[u.mask] = 0  # masked entries are replaced by zero
v = ds['hvel'][2::3, :, -1, 1]
v[v.mask] = 0

# The magnitude is taken BEFORE outliers are turned into NaN below, so M
# itself never receives NaN from that step.
M = np.hypot(u, v)

# Components with an absolute value above 10000 are blanked out.
u[np.where(abs(u) > 10000)] = np.nan
v[np.where(abs(v) > 10000)] = np.nan

# Extent of the node coordinates.
xmin, xmax = np.min(x), np.max(x)
ymin, ymax = np.min(y), np.max(y)

# Colour scale: 256 evenly spaced contour levels between vmin and vmax.
vmin, vmax = 0.0, 3.0
levels = np.linspace(vmin, vmax, 256)

# Inputs for the dry-node mask: elevation on the same record selection as
# the velocity, plus the per-node depth.
elev = ds['elev'][2::3, :]
depth = ds['depth'][:]
NP = depth.shape[0]  # number of nodes

# Figure and single axes that every animation frame is drawn into.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111)

# Dry-node flags for the first record: a node is flagged when elevation plus
# depth is at or below 1e-6.  The same nodes are blanked in u and v.
idry = np.zeros(NP)
idxs = np.where(elev[0, :].flatten() + depth <= 1.e-6)
idry[idxs] = 1
u[0, idxs] = np.nan
v[0, idxs] = np.nan

# Hide every triangle that has at least one flagged node.
triangulation = hgrid.triangulation
mask = np.any(np.where(idry[triangulation.triangles], True, False), axis=1)
triangulation.set_mask(mask)

# Initial filled contour of the velocity magnitude; the colorbar is created
# from this first frame only.
im = ax.tricontourf(
    triangulation,
    M[0, :],
    cmap='jet',
    levels=levels,
    vmin=vmin,
    vmax=vmax,
    extend='neither',
)
cbar = plt.colorbar(im)
cbar.set_ticks([0.4, 0.8, 1.2, 1.6, 2.0, 2.4, 2.8])
cbar.set_ticklabels(['0.4', '0.8', '1.2', '1.6', '2.0', '2.4', '2.8'])
cbar.set_label('Velocity (m/s)')

# Optional vector overlay (disabled):
#q = ax.quiver(x[::100], y[::100],u[0,::100], v[0,::100],alpha=0.3, pivot='mid', color='k', scale_units='width',headwidth=2,headlength=3,headaxislength=2.5) #, width=0.005)
#ax.quiverkey(q, 0.2, 0.7, 0.5, 'current 0.5 m/s',color='k', labelpos='E', coordinates='figure')

# Raw strings keep the mathtext "\circ" from being parsed as an invalid
# Python escape sequence; the resulting label text is unchanged.
ax.set_xlabel(r'Longitude ($^\circ$E)')
ax.set_ylabel(r'Latitude ($^\circ$N)')

def animate(i):
    """Redraw the velocity-magnitude contour plot for frame ``i``.

    ``i`` indexes the first axis of the module-level ``elev``, ``u``, ``v``
    and ``M`` arrays and the ``times`` array.  The function clears the shared
    axes, masks triangles that touch a dry node, and draws the new contours,
    axis labels and title.  It also writes NaN into ``u`` and ``v`` at the dry
    nodes of this frame.
    """
    # A node is dry when elevation plus depth is at or below 1e-6.
    idry = np.zeros(NP)
    idxs = np.where(elev[i, :].flatten() + depth <= 1.e-6)
    # Flag the dry nodes; without this assignment the element mask built
    # below is all-False and dry elements would be contoured.
    idry[idxs] = 1
    u[i, idxs] = np.nan
    v[i, idxs] = np.nan

    ax.clear()

    # Hide every triangle that has at least one dry node in this frame.
    triangulation = hgrid.triangulation
    mask = np.any(np.where(idry[triangulation.triangles], True, False), axis=1)
    triangulation.set_mask(mask)

    ax.tricontourf(
        triangulation,
        M[i, :],
        cmap='jet',
        levels=levels,
        vmin=vmin,
        vmax=vmax,
        extend='neither',
    )

    # Optional vector overlay (disabled):
    #q = ax.quiver(x[::100], y[::100],u[i,::100], v[i,::100],alpha=0.3, pivot='mid', color='k', scale_units='width',headwidth=2,headlength=3,headaxislength=2.5) #, width=0.005)
    #ax.quiverkey(q, 0.2, 0.7, 1.0, 'current 1.0 m/s',color='k', labelpos='E', coordinates='figure')

    # ax.clear() removed the labels, so they are set again on every frame.
    # Raw strings avoid the invalid "\c" escape; the label text is unchanged.
    ax.set_xlabel(r'Longitude ($^\circ$E)')
    ax.set_ylabel(r'Latitude ($^\circ$N)')

    ax.set_title(f'Surface current at {str(times[i])}')

# Render one frame per available record.  The count is derived from the data
# instead of being hard-coded, so changing the date range or running on a
# shorter dataset cannot index past the end of ``times`` or ``M``.
nframes = min(len(times), M.shape[0])
anim = FuncAnimation(fig, func=animate, frames=nframes, blit=False)

# Encode the movie with ffmpeg at 3 frames per second.
anim.save('current_20210706-09_4days.mp4', writer='ffmpeg', fps=3)
print(f'It took {time()-t0} seconds to plot surface current')

# All arrays were read into memory earlier, so the files can be released.
ds.close()
