from time import time
from datetime import datetime

import numpy as np
import netCDF4 as nc
from netCDF4 import Dataset, MFDataset
import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation
from matplotlib.artist import Artist
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt

from pyschism.mesh import Hgrid

# Read the model grid with lon/lat coordinates; its triangulation is used
# for the contour plots below.
hgrid = Hgrid.open('hgrid.gr3', crs='EPSG:4326')


t0 = time()
# Read the SCHISM output files as one multi-file dataset.
ds = MFDataset('./Elsa/schout_*.nc')
# ds = Dataset('./Elsa/schout_20210705.nc')  # single-file alternative
# Every 3rd record starting at index 2, all nodes, last vertical level
# (the level plotted as "surface" salinity).
salt = ds['salt'][2::3, :, -1]
print(f'salt size is {salt.shape}')
# Replace masked entries with 0 and drop the mask. np.ma.filled also accepts
# a plain ndarray, whereas indexing with `salt.mask` raises AttributeError
# when the reader returns an unmasked array.
salt = np.ma.filled(salt, 0.0)

# Frame timestamps: 3-hourly from startdate up to (excluding) enddate.
# NOTE(review): hard-coded to line up with the 2::3 record slice above;
# reading the file's own time axis would keep frame titles in sync with the
# data, e.g.
#   times = ds['time']
#   times2 = nc.num2date(times, units=times.units, only_use_cftime_datetimes=False)
startdate = datetime(2021, 7, 6, 3)
enddate = datetime(2021, 7, 10, 1)
times = np.arange(startdate, enddate, np.timedelta64(3, 'h'),
                  dtype='datetime64[m]')
print(times)

# Map window (lon/lat degrees) of the subdomain to plot.
xmin = -85.25
xmax = -77.78
ymin = 23.69
ymax = 28.49

# Salinity colour range and the 256 filled-contour levels spanning it.
vmin = 0.0
vmax = 34.0
levels = np.linspace(vmin, vmax, 256)

# Inputs for the per-frame wet/dry mask.
elev = ds['elev'][2::3, :]  # same record slice as salt
depth = ds['depth'][:]
NP = len(depth)  # number of grid nodes (depth is stored per node)
# Everything needed for plotting is now in memory; release the file handles
# held by the multi-file dataset.
ds.close()

# Satellite basemap on a PlateCarree (lon/lat) axes clipped to the map window.
imagery = cimgt.GoogleTiles(style='satellite')
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection=ccrs.PlateCarree())
ax.set_extent([xmin, xmax, ymin, ymax])
ax.add_image(imagery, 10)  # tile zoom level 10
# Raw strings: '\c' is an invalid escape sequence in a normal string literal
# (DeprecationWarning, SyntaxWarning on Python 3.12+); the text is unchanged.
ax.set_xlabel(r'Longitude ($^\circ$E)')
ax.set_ylabel(r'Latitude ($^\circ$N)')
# Fully transparent grid lines: only the lon/lat tick labels are shown.
ax.gridlines(draw_labels=True, alpha=0)

# Build one animation frame per output record: a filled contour of surface
# salinity over wet elements plus a timestamp annotation.
# The frame count follows the data instead of a fixed number, so a different
# set of output files cannot index past the end of any array.
nframes = min(len(times), salt.shape[0], elev.shape[0])
# The grid triangulation is loop-invariant; only its mask changes per frame.
triangulation = hgrid.triangulation
ims = []
for i in range(nframes):
    # A node is dry when total water depth (elev + depth) is <= ~0.
    # np.asarray yields a plain boolean array usable as an index.
    idry = np.asarray(elev[i, :].flatten() + depth <= 1.e-6)
    # Hide every triangle that touches at least one dry node.
    mask = idry[triangulation.triangles].any(axis=1)
    triangulation.set_mask(mask)

    print(f'plot salt at time {i}')
    val = salt[i, :].flatten()
    im = ax.tricontourf(triangulation, val, cmap='jet', levels=levels,
                        vmin=vmin, vmax=vmax, extend='both')
    # Matplotlib >= 3.8 returns a ContourSet that is itself an Artist
    # (`.collections` is deprecated there and removed in 3.10); older
    # versions expose the drawn artists through `.collections`.
    add_arts = [im] if isinstance(im, Artist) else list(im.collections)
    an = ax.annotate(f'Surface salinity at {str(times[i])}', xy=(0.35, 1.1),
                     xycoords='axes fraction')
    ims.append(add_arts + [an])

# Pin the map window once all contours are drawn.
ax.set_xlim([xmin, xmax])
ax.set_ylim([ymin, ymax])

# Colour bar taken from the last contour set (every frame uses the same
# levels), with a tick every 4 salinity units from 0 to 32.
cbar = fig.colorbar(im)
cbar_ticks = list(range(0, 33, 4))
cbar.set_ticks(cbar_ticks)
cbar.set_ticklabels([str(tick) for tick in cbar_ticks])
cbar.set_label('Salinity')

# Write the collected frames to an MP4 file with ffmpeg at 3 frames/second.
anim = ArtistAnimation(fig, ims, interval=600, blit=False, repeat=False)
anim.save('sss_subdomain_20210706-09_4days.mp4', writer='ffmpeg', fps=3)
elapsed = time() - t0
print(f'It took {elapsed} seconds to plot surface salinity')
