from time import time
from datetime import datetime

import numpy as np
import netCDF4 as nc
from netCDF4 import MFDataset, Dataset
import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation
from matplotlib.collections import Collection
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.img_tiles as cimgt

from pyschism.mesh import Hgrid

# Read the horizontal grid in lon/lat (EPSG:4326).
hgrid = Hgrid.open('hgrid.gr3', crs='EPSG:4326')


t0 = time()
# Read the SCHISM output files. The 2::3 slice keeps every third record,
# starting at index 2 -- presumably hourly output thinned to 3-hourly frames;
# verify against the run's output interval. [..., -1] takes the last vertical
# level, presumably the surface (frames are titled "Surface temperature").
ds = MFDataset('./Elsa/schout_*.nc')
temp = ds['temp'][2::3, :, -1]
print(f'temp size is {temp.shape}')
# Replace masked/fill entries with 0.0. np.ma.filled also accepts a plain
# ndarray (netCDF4 may return one when nothing is masked), whereas
# `temp[temp.mask] = ...` would raise AttributeError in that case.
temp = np.ma.filled(temp, 0.0)

# Frame timestamps, 3 hours apart, in [startdate, enddate). These must line
# up with the 2::3 record slice above.
startdate = datetime(2021, 7, 6, 3)
enddate = datetime(2021, 7, 10, 1)
times = np.arange(startdate, enddate, np.timedelta64(3, 'h'), dtype='datetime64[m]')
print(times)

# Map window for the sub-domain plot (longitude/latitude, degrees).
xmin, xmax = -85.25, -77.78
ymin, ymax = 23.69, 28.49

# Colour range for the temperature field and 256 evenly spaced contour
# levels spanning it.
vmin, vmax = 20.0, 30.0
levels = np.linspace(vmin, vmax, num=256)

# Inputs for the wet/dry mask. elev is read at the same records as temp;
# depth is read once per node. Later code treats elev + depth <= 1e-6 as
# "dry", so depth is presumably positive-down bathymetry -- confirm.
elev = ds['elev'][2::3, :]
depth = ds['depth'][:]
NP = len(depth)  # number of grid nodes
# Every value the script needs is now in memory; release the file handles
# that MFDataset holds open.
ds.close()

# Base map: Google satellite tiles under a lon/lat (PlateCarree) GeoAxes
# clipped to the sub-domain window.
imagery = cimgt.GoogleTiles(style='satellite')
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax.set_extent((xmin, xmax, ymin, ymax), crs=ccrs.PlateCarree())
# The second argument is forwarded to the tile source (its zoom level per
# cartopy's img_tiles API).
ax.add_image(imagery, 10)
# alpha=0 makes the grid lines transparent; presumably only the lon/lat
# tick labels are wanted.
ax.gridlines(draw_labels=True, alpha=0)

# Build one animation frame per output time:
#   1. mask triangles touching a dry node,
#   2. draw filled contours of surface temperature,
#   3. add a per-frame title annotation.
# The triangulation is fetched once (hgrid.triangulation may rebuild it on
# each access); only its mask changes between frames. Contours are computed
# when tricontourf is called, so reusing and re-masking it is safe.
triangulation = hgrid.triangulation
triangles = triangulation.triangles
# Frame count follows `times`, but never exceeds the records actually read
# (the original hard-coded 32 would IndexError on a shorter run).
nframes = min(len(times), temp.shape[0], elev.shape[0])
ims = []
for i in range(nframes):
    # A node is dry when elev + depth <= 1e-6. A triangle is hidden if any of
    # its nodes is dry. np.asarray matches the original np.where semantics on
    # masked input.
    dry = np.asarray(elev[i, :].flatten() + depth <= 1.e-6)
    triangulation.set_mask(dry[triangles].any(axis=1))

    print(f'plot temp at time {i}')
    val = temp[i, :].flatten()
    im = ax.tricontourf(triangulation, val, cmap='jet', levels=levels,
                        vmin=vmin, vmax=vmax, extend='both')
    ax.set_xlim([xmin, xmax])
    ax.set_ylim([ymin, ymax])
    # Matplotlib >= 3.8: ContourSet is itself a Collection and `.collections`
    # is deprecated (removed in 3.10). Older versions expose a list of artists.
    frame_artists = [im] if isinstance(im, Collection) else list(im.collections)
    an = ax.annotate(f'Surface temperature at {str(times[i])}', xy=(0.35, 1.1), xycoords='axes fraction')
    ims.append(frame_artists + [an])

# One colour bar for the whole animation. Every frame uses the same
# `levels`/vmin/vmax, so the last frame's contour set is representative.
cbar = fig.colorbar(im, ax=ax)
cbar_ticks = [20, 22, 24, 26, 28, 30]
cbar.set_ticks(cbar_ticks)
cbar.set_ticklabels([str(t) for t in cbar_ticks])
# Raw string: '\c' is an invalid escape sequence (SyntaxWarning on
# Python 3.12+). The resulting label text is unchanged.
cbar.set_label(r'Temperature ($^\circ$C)')

# Assemble the per-frame artist lists into a movie. fps controls playback
# speed in the saved file.
anim = ArtistAnimation(fig, ims, interval=600, blit=False, repeat=False)
anim.save('sst_subdomain_20210706-09_4days.mp4', writer='ffmpeg', fps=3)
print(f'It took {time()-t0} seconds to plot surface temperature')
