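'''
Write an ADCIRC-style "maxele" netCDF file from one day of SCHISM forecast output.

Usage: python generate_adcirc_maxelev.py YYYY-MM-DD

Reads <forecast dir>/YYYYMMDD/schout_*.nc and hgrid.gr3, and writes
schout_maxele_YYYYMMDD.nc containing each node's maximum elevation and
the time at which it occurred.
'''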

import sys
from datetime import datetime, timedelta
from time import time 
import argparse

import numpy as np
import numpy.ma as ma
from netCDF4 import MFDataset, Dataset

from hgrid import Hgrid

# Sentinel written for missing/dry values (also the netCDF _FillValue).
FILL_VALUE = -99999.0

# Operational defaults; each can be overridden on the command line.
DEFAULT_INPUT_DIR = "/sciclone/schism10/hyu05/NOAA_NWM/oper_3D/fcst/"
DEFAULT_OUTPUT_DIR = '/sciclone/pscr/lcui01/ICOGS3D_dev/outputs_adcirc'
DEFAULT_HGRID = 'hgrid.gr3'


def split_into_triangles(elements):
    '''Return mesh connectivity as an (NE, 3) array of triangles.

    Each row of *elements* lists the node numbers of one element; unused
    vertex slots are masked. Triangles are copied as-is. Each quad
    (a, b, c, d) is split into (a, b, d) and (b, c, d), and the two are
    emitted consecutively in place of the quad, so element order is
    preserved. Rows with a valid-vertex count other than 3 or 4 are dropped.

    Accepts plain ndarrays and masked arrays whose mask has collapsed to
    ``nomask`` (e.g. an all-triangle or all-quad mesh).
    '''
    faces = ma.getdata(elements)
    # getmaskarray always yields a full boolean array, even for nomask.
    valid = ~ma.getmaskarray(elements)
    # Move each row's valid vertices to the front, keeping their order
    # (same result as compressing out the masked entries row by row).
    order = np.argsort(~valid, axis=1, kind='stable')
    faces = np.take_along_axis(faces, order, axis=1)
    nvert = valid.sum(axis=1)

    keep = (nvert == 3) | (nvert == 4)
    faces = faces[keep]
    is_quad = nvert[keep] == 4
    quads = np.flatnonzero(is_quad)

    # First triangle of every element: (a, b, c) for tris, (a, b, d) for quads.
    first = faces[:, :3].copy()
    if quads.size:
        first[quads, 2] = faces[quads, 3]

    # Output row of each element's first triangle: its own index plus one
    # extra row for every quad that precedes it.
    slot = np.arange(len(faces)) + np.cumsum(is_quad) - is_quad
    tris = np.empty((len(faces) + quads.size, 3), dtype=faces.dtype)
    tris[slot] = first
    if quads.size:
        tris[slot[quads] + 1] = faces[quads, 1:4]  # second triangle (b, c, d)
    return tris


def compute_max_elevation(elev, times, depth, fill_value=FILL_VALUE,
                          huge=100000, dry_tol=1e-6):
    '''Return (maximum elevation, time of maximum) for every node.

    elev  : array indexed (time, node). Modified in place: entries above
            *huge* (e.g. netCDF fill values) are replaced by *fill_value*
            before the maximum is taken.
    times : 1-D array of output times, one per row of *elev*.
    depth : per-node depth; nodes where max elevation + depth <= *dry_tol*
            are considered dry and get *fill_value* as their maximum.
    '''
    elev[np.where(elev > huge)] = fill_value
    maxelev = np.max(elev, axis=0)
    time_maxelev = times[np.argmax(elev, axis=0)]
    maxelev[np.where(maxelev + depth <= dry_tol)] = fill_value
    return maxelev, time_maxelev


def write_maxele(outfile, x, y, depth, tris, times, startdate, ocean_bnd,
                 maxelev, time_maxelev):
    '''Write the ADCIRC maxele-style netCDF4 file *outfile*.

    tris      : (NE, 3) triangle connectivity, written with start_index=1.
    times     : output times, written as seconds since *startdate* 00:00 UTC.
    ocean_bnd : open-boundary node indices (array); written +1, presumably
                converting 0-based indices to 1-based node numbers.
    '''
    NP = depth.shape[0]
    NE = len(tris)
    NV = 3
    nope = 1
    max_nvdll = len(ocean_bnd)
    max_nvell = max_nvdll
    mesh = 1
    base_date = f'{startdate.year}-{startdate.month}-{startdate.day} 00:00:00 UTC'
    bnd_nodes = ocean_bnd + 1

    with Dataset(outfile, "w", format="NETCDF4") as fout:
        #dimensions
        for name, size in (('time', None), ('node', NP), ('nele', NE),
                           ('nvertex', NV), ('nope', nope),
                           ('max_nvdll', max_nvdll), ('nbou', nope),
                           ('max_nvell', max_nvell), ('mesh', mesh)):
            fout.createDimension(name, size)

        #variables
        var = fout.createVariable('time', 'f8', ('time',))
        var.long_name = "Time"
        var.units = f'seconds since {base_date}'
        var.base_date = base_date
        var.standard_name = "time"
        var[:] = times

        var = fout.createVariable('x', 'f8', ('node',))
        var.long_name = "node x-coordinate"
        var.standard_name = "longitude"
        var.units = "degrees_east"
        var.positive = "east"
        var[:] = x

        var = fout.createVariable('y', 'f8', ('node',))
        var.long_name = "node y-coordinate"
        var.standard_name = "latitude"
        var.units = "degrees_north"
        var.positive = "north"
        var[:] = y

        var = fout.createVariable('element', 'i', ('nele', 'nvertex'))
        var.long_name = "element"
        var.standard_name = "face_node_connectivity"
        var.start_index = 1
        var.units = "nondimensional"
        var[:] = tris

        var = fout.createVariable('adcirc_mesh', 'i', ('mesh',))
        var.long_name = "mesh_topology"
        var.cf_role = "mesh_topology"
        var.topology_dimension = 2
        var.node_coordinates = "x y"
        var.face_node_connectivity = "element"
        var[:] = 1

        # Elevation-specified (open ocean) boundary tables.
        var = fout.createVariable('neta', 'i')
        var.long_name = "total number of elevation specified boundary nodes"
        var.units = "nondimensional"
        var[0] = max_nvdll

        var = fout.createVariable('nvdll', 'i', ('nope',))
        var.long_name = "number of nodes in each elevation specified boundary segment"
        var.units = "nondimensional"
        var[:] = max_nvdll

        var = fout.createVariable('max_nvdll', 'i')
        var[0] = max_nvdll

        var = fout.createVariable('ibtypee', 'i', ('nope',))
        var.long_name = "elevation boundary type"
        var.units = "nondimensional"
        var[:] = 5

        var = fout.createVariable('nbdv', 'i', ('max_nvdll', 'nope'))
        var.long_name = "node numbers on each elevation specified boundary segment"
        var.units = "nondimensional"
        var[:, 0] = bnd_nodes

        # Normal-flow boundary tables.
        # NOTE(review): these reuse the ocean boundary nodes; the land
        # boundary returned by Hgrid.read_hgrid is not written.
        var = fout.createVariable('nvel', 'i')
        var.long_name = "total number of normal flow specified boundary nodes"
        var.units = "nondimensional"
        var[0] = max_nvdll

        var = fout.createVariable('nvell', 'i', ('nbou',))
        var.long_name = "number of nodes in each normal flow specified boundary segment"
        var.units = "nondimensional"
        var[:] = max_nvdll

        var = fout.createVariable('max_nvell', 'i')
        var[0] = max_nvdll

        var = fout.createVariable('ibtype', 'i', ('nbou',))
        var.long_name = "type of normal flow (discharge) boundary type"
        var.units = "nondimensional"
        var[:] = 5

        var = fout.createVariable('nbvv', 'i', ('max_nvell', 'nbou'))
        var.long_name = "node numbers on normal flow (discharge) boundary type"
        var.units = "nondimensional"
        var[:, 0] = bnd_nodes

        var = fout.createVariable('depth', 'f8', ('node',))
        var.long_name = "distance below NAVD88"
        var.standard_name = "depth below NAVD88"
        var.coordinates = "time y x"
        var.location = "node"
        var.mesh = "adcirc_mesh"
        var.units = "m"
        var[:] = depth

        var = fout.createVariable('zeta_max', 'f8', ('node',), fill_value=FILL_VALUE)
        var.standard_name = "maximum_sea_surface_height_above_navd88"
        var.coordinates = "y x"
        var.location = "node"
        var.mesh = "adcirc_mesh"
        var.units = "m"
        var[:] = maxelev

        var = fout.createVariable('time_of_zeta_max', 'f8', ('node',), fill_value=FILL_VALUE)
        var.standard_name = "time_of_maximum_sea_surface_height_above_navd88"
        var.coordinates = "y x"
        var.location = "node"
        var.mesh = "adcirc_mesh"
        var.units = "sec"
        var[:] = time_maxelev

        fout.title = 'SCHISM Model output'
        fout.source = 'SCHISM model output version v10'
        fout.references = 'http://ccrm.vims.edu/schismweb/'


def main(argv=None):
    '''Command-line entry point; *argv* defaults to ``sys.argv[1:]``.'''
    t0 = time()

    argparser = argparse.ArgumentParser(
        description='Write an ADCIRC-style maxele file from SCHISM forecast output.')
    argparser.add_argument('date', type=datetime.fromisoformat, help='input file date')
    argparser.add_argument('--input-dir', default=DEFAULT_INPUT_DIR,
                           help='forecast root containing YYYYMMDD/schout_*.nc')
    argparser.add_argument('--output-dir', default=DEFAULT_OUTPUT_DIR,
                           help='directory for schout_maxele_YYYYMMDD.nc')
    argparser.add_argument('--hgrid', default=DEFAULT_HGRID,
                           help='hgrid.gr3 file providing the open boundary nodes')
    args = argparser.parse_args(argv)
    date = args.date
    datestr = date.strftime('%Y%m%d')

    ds = MFDataset(f"{args.input_dir}/{datestr}/schout_*.nc")
    try:
        # Slicing with [:] loads the data into memory, so it stays usable
        # after the dataset is closed.
        x = ds['SCHISM_hgrid_node_x'][:]
        y = ds['SCHISM_hgrid_node_y'][:]
        depth = ds['depth'][:]
        tris = split_into_triangles(ds['SCHISM_hgrid_face_nodes'][:, :])
        times = ds['time'][:]
        maxelev, time_maxelev = compute_max_elevation(ds['elev'][:, :], times, depth)
    finally:
        ds.close()

    ocean_bnd, _land_bnd = Hgrid().read_hgrid(args.hgrid)

    outfile = f"{args.output_dir}/schout_maxele_{datestr}.nc"
    write_maxele(outfile, x, y, depth, tris, times,
                 startdate=date - timedelta(days=1), ocean_bnd=ocean_bnd,
                 maxelev=maxelev, time_maxelev=time_maxelev)

    print(f'It took {time() - t0:.1f} s to write {outfile}')


if __name__ == '__main__':
    main()
