from datetime import datetime
import glob
import argparse
import json

import numpy as np
import pandas as pd
from netCDF4 import Dataset

def read_featureID_file(filename):
    """Read a text file containing one NWM feature ID per line.

    Parameters
    ----------
    filename : str
        Path to the file; each line holds a single feature ID.

    Returns
    -------
    list of str
        Feature IDs with the trailing newline removed.  Any other
        whitespace on the line is preserved (matches the original
        ``line.split('\\n')[0]`` behavior).
    """
    with open(filename) as f:
        # Iterate the file directly instead of materializing readlines();
        # rstrip('\n') is equivalent to split('\n')[0] for lines produced
        # by file iteration, but clearer.
        feature_ids = [line.rstrip('\n') for line in f]
    return feature_ids

def write_th_file(dataset, timeinterval, fname, issource=True):
    """Write a SCHISM time-history (*.th) file.

    Each row is the elapsed time followed by one flow value per element.
    Sink files (``issource=False``) have their values negated on output.

    Parameters
    ----------
    dataset : sequence of sequences of float
        One row of values per time step.
    timeinterval : sequence of float
        Elapsed seconds for each row.
    fname : str
        Output file path.
    issource : bool, optional
        True writes values as-is; False negates them.
    """
    sign = 1 if issource else -1
    rows = []
    for t, values in zip(timeinterval, dataset):
        fields = [f"{t:G}"]
        fields.extend(f'{sign * v: .4f}' for v in values)
        fields.append('\n')
        rows.append(" ".join(fields))

    with open(fname, 'w+') as fh:
        fh.writelines(rows)

def write_mth_file(temp, salinity, timeinterval, fname):
    """Write a SCHISM msource.th file.

    The file contains one block of temperature rows (one row per time
    step) followed by one block of salinity rows, each row being the
    elapsed time plus the constant per-source tracer values.

    Parameters
    ----------
    temp, salinity : sequence of float
        Per-source tracer values (constant in time).
    timeinterval : sequence of float
        Elapsed seconds for each row.
    fname : str
        Output file path.
    """
    rows = []
    # Temperature rows first, then salinity rows — same row layout.
    for tracer in (temp, salinity):
        for t in timeinterval:
            values = " ".join(f'{v: .4f}' for v in tracer)
            rows.append(f"{t:G} {values} \n")

    with open(fname, 'w+') as fh:
        fh.writelines(rows)

def get_aggregated_features(nc_feature_id, features):
    """Resolve NWM feature IDs to their positional indexes in a netCDF file.

    Parameters
    ----------
    nc_feature_id : numpy.ndarray
        1-D array of integer feature IDs as stored in the netCDF file.
    features : iterable of iterables
        One sequence of feature IDs (str or int) per source/sink element,
        e.g. the ``.values()`` of the json mapping dicts.

    Returns
    -------
    list of list of int
        For each element, the indexes of its feature IDs within
        ``nc_feature_id`` (same grouping as the input).

    Raises
    ------
    KeyError
        If a requested feature ID is not present in ``nc_feature_id``.
    """
    # Build the id -> index map once: O(n + m) overall instead of an
    # O(n) np.where scan for every requested feature.
    index_of = {int(fid): i for i, fid in enumerate(nc_feature_id)}

    aggregated_indexes = []
    for element_feats in features:
        aggregated_indexes.append([index_of[int(f)] for f in element_feats])
    return aggregated_indexes

def streamflow_lookup(file, indexes, threshold=-1e-5):
    """Sum streamflow over grouped feature indexes from one NWM netCDF file.

    Parameters
    ----------
    file : str
        Path to an NWM channel-output netCDF file.
    indexes : list of list of int
        Groups of positional indexes into the "streamflow" variable
        (as returned by get_aggregated_features).
    threshold : float, optional
        Values below this are treated as spurious and zeroed.

    Returns
    -------
    list
        One aggregated streamflow value per index group.
    """
    nc = Dataset(file)
    try:
        # Note: netCDF4 already applies scale_factor/add_offset and
        # normally returns a masked array.
        streamflow = nc["streamflow"][:]
        # Replace masked (missing) entries with zero before thresholding
        # so they cannot leak into the sums; guard the .mask access so a
        # plain ndarray return also works.
        if np.ma.isMaskedArray(streamflow):
            streamflow = streamflow.filled(0.0)
        # Zero out spurious (slightly) negative flows.
        streamflow[streamflow < threshold] = 0.0
        return [np.sum(streamflow[idxs]) for idxs in indexes]
    finally:
        # Close the dataset even if indexing/summing raises.
        nc.close()

if __name__ == '__main__':
    '''
    Usage: python extract2asci.py "yyyy-mm-dd" or "yyyy-mm-dd hh:mm:ss"
    Run this script in oper_3D/NWM/Combined/ directory. Inputs are in the same directory:
        1 sources_{conus, alaska, hawaii}_global.json
           sinks_{conus, alaska, hawaii}_global.json
        2 ./cached/nwm*.{conus, alaska, hawaii}.nc

    '''

    #input paramters 
    argparser = argparse.ArgumentParser()
    argparser.add_argument('date', type=datetime.fromisoformat, help='input file date')
    args=argparser.parse_args()
    startdate=args.date
    #startdate = datetime(2022, 3, 29, 0)

    #1. region name
    layers = ['conus', 'alaska', 'hawaii']


    sources_all = {}
    sinks_all = {}
    eid_sources = []
    eid_sinks = []
    times = []
    dates = []

    for layer in layers:
        print(f'layer is {layer}')
        fname_source = f'./sources_{layer}_v3.json'
        fname_sink = f'./sinks_{layer}_v3.json'
        sources_fid = json.load(open(fname_source))
        sinks_fid = json.load(open(fname_sink))

        #add to the final list
        eid_sources.extend(list(sources_fid.keys()))
        eid_sinks.extend(list(sinks_fid.keys()))


        files = glob.glob(f'./cached/nwm*.{layer}.nc')
        files.sort()
        print(f'file 0 is {files[0]}')
        nc_fid0 = Dataset(files[0])["feature_id"][:]
        src_idxs = get_aggregated_features(nc_fid0, sources_fid.values())
        snk_idxs = get_aggregated_features(nc_fid0, sinks_fid.values())

        sources = []
        sinks = []
        for fname in files:
            ds = Dataset(fname)
            ncfeatureid=ds['feature_id'][:]
            if not np.all(ncfeatureid == nc_fid0):
                print(f'Indexes of feature_id are changed in  {fname}')
                src_idxs=get_aggregated_features(ncfeatureid, sources_fid.values())
                snk_idxs=get_aggregated_features(ncfeatureid, sinks_fid.values())
                nc_fid0 = ncfeatureid

            sources.append(streamflow_lookup(fname, src_idxs))
            sinks.append(streamflow_lookup(fname, snk_idxs))

            model_time = datetime.strptime(ds.model_output_valid_time, "%Y-%m-%d_%H:%M:%S")
            if layer == 'conus':
                dates.append(str(model_time))
                times.append((model_time - startdate).total_seconds())
            ds.close()
        sources_all[layer] = np.array(sources)
        sinks_all[layer] = np.array(sinks)

    sources = np.concatenate((sources_all['conus'], sources_all['alaska'], sources_all['hawaii']), axis=1) 
    sinks = np.concatenate((sinks_all['conus'], sinks_all['alaska'], sinks_all['hawaii']), axis=1) 
    print(sources.shape)
    print(sinks.shape)

    df_nwm = pd.DataFrame(data=sources, columns=np.array(eid_sources), index=np.array(dates))

    #write to file
    write_th_file(sources, times, 'vsource.th', issource=True)
    write_th_file(sinks, times, 'vsink.th', issource=False)

    #write msource.th
    nsource = sources.shape[1]
    print(f'nsource is {nsource}')
    temp = np.full(nsource, -99999.0)
    salt = np.full(nsource, 30.0)
    write_mth_file(temp, salt, times, 'msource.th')

    nsink = np.array(eid_sinks).shape[0]
    #write source_sink.in
    with open('source_sink.in', 'w+') as f:
        f.write('{:<d} \n'.format(nsource))
        for eid in eid_sources:
            f.write('{:<d} \n'.format(int(eid)))
        f.write('\n')

        f.write('{:<d} \n'.format(nsink))
        for eid in eid_sinks:
            f.write('{:<d} \n'.format(int(eid)))
