from time import time
from datetime import datetime

import numpy as np
import netCDF4 as nc
from netCDF4 import MFDataset,Dataset
import matplotlib.pyplot as plt

from matplotlib.animation import ArtistAnimation
from matplotlib.artist import Artist
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt

from pyschism.mesh import Hgrid


# Load the horizontal grid; its coordinates are declared as EPSG:4326.
hgrid = Hgrid.open('hgrid.gr3', crs='EPSG:4326')

# Wall-clock reference for the elapsed-time report printed at the end.
t0 = time()

# Open all files matching ./Elsa/schout_*.nc as a single multi-file dataset.
ds = MFDataset('./Elsa/schout_*.nc')

# Frame time stamps are generated on a fixed 3-hourly grid instead of being
# decoded from the dataset's own `time` variable.
# NOTE(review): assumes these stamps line up with the records selected by the
# `2::3` slices used further down -- confirm against ds['time'].
startdate = datetime(2021, 7, 6, 3)
enddate = datetime(2021, 7, 10, 1)
times = np.arange(
    startdate,
    enddate,
    np.timedelta64(3, 'h'),
    dtype='datetime64[m]',
)
print(times)

# Horizontal velocity components: every third record starting at index 2,
# last index of the third axis (NOTE(review): presumably the surface layer --
# confirm), entries 0 and 1 of the last axis.
u = ds['hvel'][2::3, :, -1, 0]
v = ds['hvel'][2::3, :, -1, 1]

# Masked entries are overwritten with zero.
u[u.mask] = 0
v[v.mask] = 0

# Speed magnitude.
# NOTE(review): M is computed *before* the outlier filter below, so values with
# |u| or |v| > 10000 still feed into M -- confirm that this is intended.
M = np.hypot(u, v)

# Replace implausibly large components with NaN in u and v themselves.
for component in (u, v):
    component[np.where(abs(component) > 10000)] = np.nan

# Plot window: x (longitude) and y (latitude) bounds.
xmin, xmax = -85.25, -77.78
ymin, ymax = 23.69, 28.49

# Colour-scale limits and the 256 evenly spaced contour levels spanning them.
vmin, vmax = 0.0, 3.0
levels = np.linspace(vmin, vmax, num=256)

# Inputs for the wet/dry mask: elevation at every third record starting at
# index 2 (the same record selection as the velocity), plus the depth array.
elev = ds['elev'][2::3, :]
depth = ds['depth'][:]

# Number of entries along the first axis of the depth array.
NP = depth.shape[0]

# Satellite imagery tiles used as the map background.
imagery = cimgt.GoogleTiles(style='satellite')

# Single map axes in plate-carree projection, cropped to the plot window.
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection=ccrs.PlateCarree())
ax.set_extent([xmin, xmax, ymin, ymax]) #, ccrs.PlateCarree())
ax.add_image(imagery, 10)  # 10 is forwarded to the tile source (zoom level)

# Raw strings keep the mathtext `\circ` intact without relying on Python
# passing the invalid escape sequence `\c` through unchanged (which triggers a
# DeprecationWarning, and a SyntaxWarning on Python 3.12+).
ax.set_xlabel(r'Longitude ($^\circ$E)')
ax.set_ylabel(r'Latitude ($^\circ$N)')

# Draw tick labels only; the grid lines themselves are fully transparent.
ax.gridlines(draw_labels=True, alpha=0)

# One entry per animation frame: the list of artists visible in that frame.
ims = []

# The triangulation object is the same for every frame; only its mask changes.
triangulation = hgrid.triangulation

# Never index past the shortest of the time stamps / elevation / speed records
# (instead of assuming a fixed count of 32 frames).
n_frames = min(len(times), len(elev), len(M))

for i in range(n_frames):
    # A node counts as dry when elevation + depth is at or below 1e-6.
    idry = np.zeros(NP, dtype=bool)
    idry[np.where(elev[i, :].flatten() + depth <= 1.e-6)] = True

    # Hide every triangle that touches at least one dry node.
    mask = idry[triangulation.triangles].any(axis=1)
    triangulation.set_mask(mask)

    print(f'plot uv at time {i}')
    val = M[i, :].flatten()
    im = ax.tricontourf(triangulation, val, cmap='jet', levels=levels,
                        vmin=vmin, vmax=vmax, extend='both')
    # Re-apply the plot window after each contour call.
    ax.set_xlim([xmin, xmax])
    ax.set_ylim([ymin, ymax])

    # Matplotlib >= 3.8 returns a contour set that is itself a single Artist
    # and deprecates `.collections` (removed in 3.10); older releases expose
    # one collection per level instead.
    if isinstance(im, Artist):
        add_arts = [im]
    else:
        add_arts = list(im.collections)

    an = ax.annotate(f'Surface velocity at {str(times[i])}', xy=(0.35, 1.1), xycoords='axes fraction')
    ims.append(add_arts + [an])

# Colour bar for the last contour set drawn, with fixed ticks whose positions
# are parsed from the label strings so the two can never drift apart.
cbar = plt.colorbar(im)
tick_labels = ['0.4','0.8','1.2','1.6','2.0','2.4','2.8']
cbar.set_ticks([float(label) for label in tick_labels])
cbar.set_ticklabels(tick_labels)
cbar.set_label('Velocity (m/s)')

# Build the animation from the collected per-frame artist lists and write it
# out through the ffmpeg writer at 3 frames per second.
anim = ArtistAnimation(fig, ims, interval=600, blit=False, repeat=False)
anim.save('current_subdomain_20210706-09_4days.mp4', writer='ffmpeg', fps=3)

# Report the wall-clock time elapsed since t0 was taken.
print(f'It took {time()-t0} seconds to plot surface current')
