"""
Author: 
Email: 
Last Modified: Oct, 2021 

Description: This script is for parameter tuning only. The goal is to find parameter settings that produce
roughly 50 infections over a one-month period on the (temporal) UIHC graph.

Usage

Examples (Karate temporal graph; UIHC graph for a given year; sampled UIHC subgraph):
$ python parameter_tuning.py -name Karate_temporal
$ python parameter_tuning.py -name UIHC_HCP_patient_room -year 2011 -dose_response exponential
$ python parameter_tuning.py -name UIHC_HCP_patient_room -year 2011 -sampled True -dose_response exponential

"""
from utils.load_network import *
from simulator_load_sharing_temporal_v2 import *

import pandas as pd
import argparse
import math
import random as random
import copy
import timeit
import numpy as np
from tqdm import tqdm

def get_people_array_in_day0(G, node_name_to_idx_mapping, people_nodes_idx=None):
    """Return the indices of people nodes that have at least one neighbor in ``G``.

    Used on the day-0 snapshot so that infection seeds are drawn only from
    people who have contacts on the first day.

    Args:
        G: graph-like object whose ``degree`` iterates ``(node_name, degree)``
            pairs (e.g. a networkx degree view).
        node_name_to_idx_mapping: dict mapping node names to integer indices.
        people_nodes_idx: iterable of indices of people nodes. If omitted, the
            module-level ``people_nodes_idx`` (assigned in the ``__main__``
            block of this script) is used, as the original version did.

    Returns:
        1-D integer numpy array of node indices, in ``G.degree`` iteration order.

    Raises:
        NameError: if ``people_nodes_idx`` is omitted and no module-level
            ``people_nodes_idx`` is defined.
        KeyError: if a node with positive degree is missing from
            ``node_name_to_idx_mapping``.
    """
    if people_nodes_idx is None:
        # Backward compatibility: callers that do not pass the argument rely on
        # the global created by the script entry point.
        people_nodes_idx = globals().get("people_nodes_idx")
        if people_nodes_idx is None:
            raise NameError(
                "people_nodes_idx was not passed and is not defined at module level")

    # Build the lookup set once; `in` on a list/array is O(n) per node.
    people_idx_set = set(people_nodes_idx)

    day0_people_list = []
    for node_name, degree in G.degree:
        if degree > 0:
            node_idx = node_name_to_idx_mapping[node_name]
            if node_idx in people_idx_set:
                day0_people_list.append(node_idx)
    # Explicit dtype so an empty result is still an integer index array.
    return np.array(day0_people_list, dtype=int)


def str_to_bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` converts every non-empty string -- including
    ``"False"`` -- to ``True``, so boolean options need an explicit parser.

    Args:
        value: the raw argument (e.g. ``"True"``, ``"no"``, ``"1"``) or an
            already-parsed ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == "__main__":
    import os  # local import: only the script entry point needs it

    parser = argparse.ArgumentParser(description='parameter tuning for the temporal load-sharing simulation')
    parser.add_argument('-name', '--name', type=str, default="Karate_temporal",
                        help='network to use. Karate_temporal | UIHC_Jan2010_patient_room_temporal | UIHC_HCP_patient_room | UVA_temporal')
    parser.add_argument('-year', '--year', type=int, default=2011,
                        help='2007 | 2011')
    # NOTE: `type=bool` would turn any non-empty string (even "False") into True.
    parser.add_argument('-sampled', '--sampled', type=str_to_bool, default=False,
                        help='set it True to use sampled data.')
    parser.add_argument('-dose_response', '--dose_response', type=str, default="exponential",
                        help='dose-response function')
    parser.add_argument('-seed', '--seed', type=int, default=None,
                        help='optional RNG seed for reproducible seed-node selection')
    args = parser.parse_args()

    name = args.name
    year = args.year
    sampled = args.sampled
    dose_response = args.dose_response

    if args.seed is not None:
        # Seeds the global RNGs so the seed-node draws below are reproducible.
        # Whether the simulator itself draws from these RNGs is not visible here.
        random.seed(args.seed)
        np.random.seed(args.seed)

    np.set_printoptions(suppress=True)
    # Initial parameters used to construct the Simulation instance; the tuning
    # grid below overrides contact_area, rho, d, q and pi.
    n_timesteps = 31
    rho = 0.4
    d = 0.1
    q = 8
    pi = 1.0  # infectivity; per the original notes f(x) = 1 - e^{-pi * load}
    contact_area = 150
    area_people = 2000  # area of patient. 2000cm^2
    area_location = 40000  # area of room. 40000cm^2

    flag_increase_area = True  # If True, increase the area of each node based on its max degree over graphs
    n_replicates = 10  # number of simulations on the same seed
    n_exp = 10  # number of different starting seed sets
    k = 1  # number of seeds

    ####################################################################
    print("Load network")
    if name == "Karate_temporal":
        G_over_time, people_nodes, people_nodes_idx, location_nodes_idx, area_array = load_karate_temporal_network(area_people, area_location, flag_increase_area)
    elif name == "UVA_temporal":
        G_over_time, people_nodes, people_nodes_idx, location_nodes_idx, area_array = load_UVA_temporal_network(area_people, area_location, flag_increase_area)
    elif name == "UIHC_Jan2010_patient_room_temporal":
        G_over_time, people_nodes, people_nodes_idx, location_nodes_idx, area_array = load_UIHC_Jan2010_patient_room_temporal_network(area_people, area_location, flag_increase_area)
    elif name == "UIHC_HCP_patient_room":
        # The output folder name encodes the year and whether the sampled subgraph is used.
        if sampled:
            name = "{}_{}_sampled".format(name, year)
        else:
            name = "{}_{}".format(name, year)
        # year = 2011 uses the non-overlap data.
        # sampled = True uses the subgraph sampled from the unit with the most CDI cases.
        G_over_time, people_nodes, people_nodes_idx, location_nodes_idx, area_array = load_UIHC_HCP_patient_room_temporal_network(year, sampled, area_people, area_location, flag_increase_area)
    else:
        parser.error("unknown network name: {}".format(name))

    node_name_to_idx_mapping = {node_name: node_idx for node_idx, node_name in enumerate(G_over_time[0].nodes())}

    day0_people_idx_array = get_people_array_in_day0(G_over_time[0], node_name_to_idx_mapping)
    if len(day0_people_idx_array) < k:
        raise ValueError(
            "need at least k={} people with a neighbor on day 0 to draw seeds, found {}".format(
                k, len(day0_people_idx_array)))

    # Intermediate and final results share one file (an earlier tuning run wrote tuning1.csv).
    output_path = "../tables/parameter_tuning/{}/tuning2.csv".format(name)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    ####################################################################
    # 0. Create simulation instance with empty seeds list
    simul = Simulation(G_over_time, [], people_nodes, area_array, contact_area, n_timesteps, rho, d, q, pi, dose_response, n_replicates)

    result_columns = [
        "trans-eff",
        "die-off",
        "shedding",
        "infectivity",
        "A(contact)",
        "Inf(min)",
        "Inf(avg)",
        "Inf(max)",
        "Recover_cnt(avg)",
    ]
    results = []  # one dict per evaluated parameter combination

    def build_results_df():
        """Return the results collected so far as a DataFrame (with headers even when empty)."""
        return pd.DataFrame(results, columns=result_columns)

    # Geometric grids (powers of e).
    parameter_list_contact_area = [2000*pow(math.e, -i) for i in range(5)]  # ~[2000.0, 735.8, 270.7, 99.6, 36.6]
    parameter_list_rho = [pow(math.e, -i) for i in range(5)]  # ~[1.0, 0.368, 0.135, 0.0498, 0.0183]
    parameter_list_d = [pow(math.e, -i) for i in range(5)]  # ~[1.0, 0.368, 0.135, 0.0498, 0.0183]
    parameter_list_q = [pow(math.e, i) for i in range(3)] + [pow(math.e, -i) for i in range(1, 3)]  # ~[1.0, 2.718, 7.389, 0.368, 0.135]
    parameter_list_pi = [pow(math.e, -i) for i in range(5)]  # ~[1.0, 0.368, 0.135, 0.0498, 0.0183]

    # Upper bound on evaluated combinations: some (contact_area, rho, d) triples are skipped below.
    total_combinations_of_parameters = len(parameter_list_contact_area) * len(parameter_list_rho) * len(parameter_list_d) * \
                                        len(parameter_list_q) * len(parameter_list_pi)
    loop_idx = 0
    for contact_area in parameter_list_contact_area:
        simul.contact_area = contact_area
        for rho in parameter_list_rho:
            simul.rho = rho
            for d in parameter_list_d:
                simul.d = d

                # Skip triples where d + rho * (contact_area / smallest node area) falls outside [0, 1];
                # such parameter sets may lead to negative load.
                if not 0 <= simul.d + simul.rho * (simul.contact_area / np.min(simul.area_array)) <= 1:
                    continue

                # BplusD depends on contact_area, rho and d, so rebuild it whenever one of them changes.
                simul.BplusD_over_time = simul.constructBplusD_over_time()
                for q in parameter_list_q:
                    simul.q = q
                    for pi in parameter_list_pi:
                        simul.pi = pi

                        loop_idx += 1
                        print("{}/{}...".format(loop_idx, total_combinations_of_parameters), end='\r', flush=True)

                        avg_inf_cnt_array = np.zeros(n_exp)
                        # Per seed set: total number of RECOVERED events, summed over time and over the n_replicates runs.
                        recover_event_cnt_array = np.zeros(n_exp)
                        for i in range(n_exp):
                            S_original = np.random.choice(a=day0_people_idx_array, size=k, replace=False)

                            seeds_array = np.zeros((simul.n_timesteps, simul.number_of_nodes), dtype=bool)
                            seeds_array[0, S_original] = True
                            simul.set_seeds(seeds_array)
                            simul.simulate()

                            # Indexed as [replicate, timestep, node] below (replicate axis assumed -- confirm in simulator).
                            infection_array = simul.infection_array

                            # 1. Number of distinct nodes infected at any timestep, per replicate; averaged over replicates.
                            len_P_array = infection_array.any(axis=1).sum(axis=1)
                            avg_inf_cnt_array[i] = np.mean(len_P_array)

                            # 2. RECOVERED events: infected at t, not infected at t+1 (a -1 step in the state).
                            state = infection_array[:, :n_timesteps, :].astype(int)
                            recover_event_cnt_array[i] = np.count_nonzero(np.diff(state, axis=1) == -1)

                        results.append({
                            "trans-eff": rho,
                            "die-off": d,
                            "shedding": q,
                            "infectivity": pi,
                            "A(contact)": contact_area,
                            "Inf(min)": np.min(avg_inf_cnt_array),
                            "Inf(avg)": np.mean(avg_inf_cnt_array),
                            "Inf(max)": np.max(avg_inf_cnt_array),
                            # Counts are totals over n_replicates runs, so divide for a per-replicate average
                            # (averaged across all n_exp seed sets, not just the last one).
                            "Recover_cnt(avg)": np.mean(recover_event_cnt_array) / n_replicates,
                        })

                # Save intermediate results whenever 'contact_area', 'rho' or 'd' change.
                df_results = build_results_df()
                print(df_results)
                df_results.to_csv(output_path, index=False)

    # Save final results (an empty table with headers if every combination was skipped).
    df_results = build_results_df()
    if df_results.empty:
        print("No parameter combination satisfied the load constraint; no results collected.")
    df_results.to_csv(output_path, index=False)
