import faulthandler
import functools
import inspect
import json
import os
import random
import time
from bisect import bisect_left
from datetime import datetime, timedelta
from math import floor, log10, sqrt, nan, inf

import scipy.stats
import tabulate
from numpy.random.mtrand import binomial

from lib import stack_tracer
from lib.print_exc_plus import print_exc_plus

# Alphabet used by random_chars(): the digits 1-9 followed by A-Z.
chars = [str(d) for d in range(1, 10)]
digits = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# Unnormalized weight of each digit in `chars`.
ps = [1. for _ in chars]
# English letter frequencies as (letter, absolute count, relative frequency in percent).
letter_dist = [("E", 21912, 12.02), ("T", 16587, 9.1), ("A", 14810, 8.12), ("O", 14003, 7.68), ("I", 13318, 7.31),
               ("N", 12666, 6.95), ("S", 11450, 6.28), ("R", 10977, 6.02), ("H", 10795, 5.92), ("D", 7874, 4.32),
               ("L", 7253, 3.98), ("U", 5246, 2.88), ("C", 4943, 2.71), ("M", 4761, 2.61), ("F", 4200, 2.3),
               ("Y", 3853, 2.11), ("W", 3819, 2.09), ("G", 3693, 2.03), ("P", 3316, 1.82), ("B", 2715, 1.49),
               ("V", 2019, 1.11), ("K", 1257, 0.69), ("X", 315, 0.17), ("Q", 205, 0.11), ("J", 188, 0.1),
               ("Z", 128, 0.07), ]
# Total weight of the digits alone (computed before the letters are appended).
sp = sum(ps)
for row in letter_dist:
    chars.append(row[0])
    # Letters are weighted by their relative frequency.
    ps.append(float(row[2]))
# Normalize the weights into a probability distribution over `chars`.
# The total is computed once instead of once per element.
_ps_total = sum(ps)
ps = [p / _ps_total for p in ps]


def choice(sequence, probabilities):
    """Pick one element of ``sequence`` at random, weighted by ``probabilities``.

    ``probabilities[i]`` is the probability of returning ``sequence[i]``; the
    weights are expected to sum to 1. Exact equality is not checked, because
    normalized float weights rarely sum to exactly 1.

    Floating-point rounding can make the weights sum to slightly less than 1
    (e.g. 0.9999999999999999). If the random draw lands in that tiny gap, the
    last element with a positive weight is returned instead of failing.

    :raises AssertionError: if the weights sum to noticeably less than 1 or no
        element has a positive weight.
    """
    # Residual below which a draw past the end counts as rounding error.
    tolerance = 1e-9
    r = random.random()
    last_positive = None
    found_positive = False
    for element, probability in zip(sequence, probabilities):
        r -= probability
        if r < 0:
            return element
        if probability > 0:
            last_positive = element
            found_positive = True
    if found_positive and r < tolerance:
        return last_positive
    raise AssertionError('Probabilities must sum to 1')


def multiple_choice(sequence, count):
    """Draw ``count`` elements from ``sequence`` uniformly at random without replacement.

    NOTE: the picked elements are removed from ``sequence`` in place, so
    afterwards the caller's list holds only the elements that were not picked.

    :param sequence: a mutable sequence (e.g. a list) supporting ``pop(index)``.
    :param count: number of elements to draw.
    :return: the picked elements, in the order they were drawn.
    :raises ValueError: if ``count`` exceeds ``len(sequence)``. The check runs
        before anything is removed, so ``sequence`` is left untouched.
    """
    if count > len(sequence):
        raise ValueError(f'Cannot choose {count} elements from a sequence of length {len(sequence)}')
    # len(sequence) shrinks by one per pop, so each draw is over the remaining elements.
    return [sequence.pop(random.randrange(len(sequence))) for _ in range(count)]


try:
    import winsound as win_sound
except ImportError:
    # winsound is only available on Windows; elsewhere beep() is a silent no-op.
    win_sound = None


def beep(*args, **kwargs):
    """Play a tone through ``winsound.Beep`` when winsound is available, else do nothing.

    All arguments are forwarded unchanged to ``winsound.Beep``.
    """
    if win_sound is not None:
        win_sound.Beep(*args, **kwargs)


def main_wrapper(f):
    """Decorate a script entry point with crash diagnostics and timing.

    On each call the wrapper:

    * creates a ``logs`` directory and calls ``stack_tracer.trace_start`` with
      ``logs/<main script file name>.html`` and ``interval=5``,
    * enables ``faulthandler``,
    * calls ``profile_wall_time_instead_if_profiling()``.

    If ``f`` raises an ``Exception``, the traceback is printed with
    ``print_exc_plus`` and the process exits with status -1. Whether ``f``
    succeeds or fails, ``beep(2000, 500)`` is called and the total elapsed
    time is printed.

    :return: a wrapper that returns whatever ``f`` returns.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        import sys
        import __main__
        os.makedirs('logs', exist_ok=True)
        # __main__ has no __file__ in interactive sessions; use a fixed log name there.
        main_file = getattr(__main__, '__file__', None) or 'interactive'
        stack_tracer.trace_start('logs/' + os.path.split(main_file)[-1] + '.html', interval=5)
        faulthandler.enable()
        profile_wall_time_instead_if_profiling()

        # noinspection PyBroadException
        try:
            return f(*args, **kwargs)
        except Exception:
            print_exc_plus()
            # sys.exit rather than the `exit` builtin: that one is injected by the
            # `site` module and is missing under `python -S` or in frozen builds.
            sys.exit(-1)
        finally:
            total_time = time.perf_counter() - start
            frequency = 2000
            duration = 500
            beep(frequency, duration)
            print('Total time', total_time)

    return wrapper


def random_chars(count):
    """Return a string of ``count`` characters drawn independently from ``chars``.

    Each character is chosen via ``choice`` with the module-level weights ``ps``.
    """
    picked = [choice(chars, probabilities=ps) for _ in range(count)]
    return ''.join(picked)


def str2bool(v):
    """Interpret ``v`` as a boolean.

    ``v`` is converted with ``str()``, stripped and lower-cased, then matched
    against common spellings: 'yes', 'y', 'true', 't', '1' give True; 'no',
    'n', 'false', 'f', '0', '', 'null', 'none' give False.

    :raises ValueError: for any other value.
    """
    v = str(v).strip().lower()
    # Every literal needs its comma: adjacent string literals are silently
    # concatenated by Python ('y' "true" would become 'ytrue').
    if v in ('yes', 'y', 'true', 't', '1'):
        return True
    if v in ('no', 'n', 'false', 'f', '0', '', 'null', 'none'):
        return False
    raise ValueError('Can not convert `' + v + '` to bool')


def my_tabulate(data, **params):
    """Render ``data`` with ``tabulate.tabulate`` using zero minimum cell padding.

    An empty list combined with ``headers`` is replaced by a single row of
    ``None`` cells (one per header), presumably so the header line is still
    rendered. NOTE(review): confirm intent.

    NOTE: this sets the module-global ``tabulate.MIN_PADDING`` to 0, which
    also affects every later ``tabulate`` call in the process.

    :param data: any table data accepted by ``tabulate``.
    :param params: keyword arguments forwarded to ``tabulate.tabulate``.
    :return: the rendered table string.
    """
    # `data == []` compares element-wise for numpy arrays / DataFrames (which
    # tabulate accepts) and then fails in a boolean context, so test for an
    # empty list explicitly.
    if isinstance(data, list) and not data and 'headers' in params:
        # A concrete tuple row; a bare generator can only be consumed once.
        data = [tuple(None for _ in params['headers'])]
    tabulate.MIN_PADDING = 0
    return tabulate.tabulate(data, **params)


def yn_dialog(msg):
    """Ask a yes/no question on stdin until the user answers 'y' or 'n'.

    The prompt is ``msg`` followed by `` [y/n]: ``. Answers are compared
    ignoring case and surrounding whitespace, so 'Y' or ' n ' are accepted.

    :return: True for 'y', False for 'n'.
    """
    while True:
        answer = input(msg + ' [y/n]: ').strip().lower()
        if answer == 'y':
            return True
        if answer == 'n':
            return False
        print('Type in \'y\' or \'n\'!')


def round_to_closest_value(x, values):
    """Return the element of ``values`` nearest to ``x``.

    Values outside the range of ``values`` snap to the smallest/largest
    element. On an exact tie between two neighbours, the larger one wins.

    :raises IndexError: if ``values`` is empty.
    """
    ordered = sorted(values)
    pos = bisect_left(ordered, x)
    if pos == 0:
        return ordered[0]
    if pos == len(ordered):
        return ordered[-1]
    below, above = ordered[pos - 1], ordered[pos]
    return below if abs(below - x) < abs(above - x) else above


def binary_search(a, x, lo=0, hi=None):
    """Return the index of the first occurrence of ``x`` in sorted ``a[lo:hi]``, or -1.

    ``hi`` defaults to ``len(a)``. The returned index is relative to ``a``
    itself, not to the slice.
    """
    if hi is None:
        hi = len(a)
    pos = bisect_left(a, x, lo, hi)
    # Guard against pos == hi before indexing so we never read past the range.
    found = pos != hi and a[pos] == x
    return pos if found else -1


def ceil_to_closest_value(x, values):
    """Return the smallest element of ``values`` that is >= ``x``.

    If every element is smaller than ``x``, the largest element is returned.

    :raises IndexError: if ``values`` is empty.
    """
    ordered = sorted(values)
    # Clamp the insertion point to the last index so values above the range
    # fall back to the maximum.
    idx = min(bisect_left(ordered, x), len(ordered) - 1)
    return ordered[idx]


def upset_binomial(mu, p, factor):
    """Draw a binomial sample around ``mu`` and scale its deviation by ``factor``.

    The sample comes from ``binomial(mu / p, p)``, whose mean is roughly ``mu``
    (numpy truncates a non-integer trial count). The sample's deviation from
    ``mu`` is multiplied by ``factor`` and added back to ``mu``.

    :raises NotImplementedError: if ``factor`` > 1.
    """
    if factor > 1:
        raise NotImplementedError()
    sample = binomial(mu / p, p)
    deviation = sample - mu
    return deviation * factor + mu


def multinomial(n, bins):
    """Randomly split ``n`` items over ``bins`` equally likely bins.

    Uses sequential conditional binomials. While ``k`` bins are still open,
    the next bin receives ``binomial(remaining, 1 / k)``. The last bin gets
    whatever is left, so the counts always sum to ``n``.

    :return: a list of ``bins`` counts (``[]`` when ``bins`` is 0 and ``n`` <= 0).
    :raises ValueError: if ``bins`` is 0 but ``n`` > 0.
    """
    if bins == 0:
        if n > 0:
            raise ValueError('Cannot distribute to 0 bins.')
        return []
    counts = []
    left = n
    # k counts down from `bins` to 2: the number of bins still to be filled.
    for k in range(bins, 1, -1):
        share = binomial(left, 1 / k)
        counts.append(share)
        left -= share
    counts.append(left)
    return counts


def round_to_n(x, n):
    """Round ``x`` to ``n`` significant figures.

    Zero is returned unchanged (it has no leading significant digit).
    Negative numbers are rounded by magnitude and keep their sign.

    >>> round_to_n(1234.5, 2)
    1200.0
    """
    if x == 0:
        return x
    # Decimal position of the leading significant digit; abs() because log10
    # is undefined for non-positive numbers.
    leading_digit = floor(log10(abs(x)))
    return round(x, n - 1 - leading_digit)


def get_all_subclasses(klass):
    """Return every direct and indirect subclass of ``klass``.

    Subclasses are listed in depth-first pre-order (a class is listed before
    its own subclasses). Each class appears once, even when it is reachable
    through several parents, as with diamond inheritance.
    """
    seen = set()
    all_subclasses = []

    def _visit(cls):
        # Pre-order DFS over __subclasses__(), skipping classes already collected.
        for subclass in cls.__subclasses__():
            if subclass in seen:
                continue
            seen.add(subclass)
            all_subclasses.append(subclass)
            _visit(subclass)

    _visit(klass)
    return all_subclasses


def latin1_json(data):
    """Serialize ``data`` to JSON and return it encoded as Latin-1 bytes.

    Non-ASCII characters are written literally (``ensure_ascii=False``), so any
    character outside Latin-1 raises ``UnicodeEncodeError``.
    """
    text = json.dumps(data, ensure_ascii=False)
    return text.encode('latin-1')


def l2_norm(v1, v2):
    """Return the Euclidean (L2) distance between the vectors ``v1`` and ``v2``.

    :raises ValueError: if the vectors have different lengths.
    """
    if len(v1) != len(v2):
        raise ValueError('Both vectors must be of the same size')
    squared_diffs = ((a - b) * (a - b) for a, b in zip(v1, v2))
    return sqrt(sum(squared_diffs))


def allow_additional_unused_keyword_arguments(func):
    """Decorate ``func`` so that keyword arguments it cannot accept are silently dropped.

    Keyword arguments that can bind to a named parameter of ``func`` are passed
    through and all others are discarded. If ``func`` itself takes
    ``**kwargs``, every keyword argument is forwarded, since it can consume
    them all. Positional arguments are always forwarded unchanged.
    """
    # The signature is fixed at decoration time, so inspect it once rather than per call.
    parameters = inspect.signature(func).parameters.values()
    accepts_var_keyword = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters)
    # Only names that can be passed by keyword. Forwarding e.g. `args=` for a
    # *args parameter, or a positional-only name, would raise TypeError.
    keyword_names = frozenset(
        p.name for p in parameters
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    )

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if accepts_var_keyword:
            return func(*args, **kwargs)
        allowed_kwargs = {name: value for name, value in kwargs.items() if name in keyword_names}
        return func(*args, **allowed_kwargs)

    return wrapper


def rename(new_name):
    """Return a decorator that sets the decorated function's ``__name__`` to ``new_name``.

    The function object itself is returned (not wrapped). ``__qualname__`` is
    left unchanged.
    """
    def apply_name(func):
        func.__name__ = new_name
        return func

    return apply_name


def mean_confidence_interval_size(data, confidence=0.95):
    """Return the half-width of the Student-t confidence interval for the mean of ``data``.

    The interval is ``mean ± t * sem`` with ``len(data) - 1`` degrees of
    freedom, where ``t`` is the two-sided critical value for ``confidence``.
    This function returns ``t * sem``.

    Special cases: empty data gives nan, a single value gives inf, and zero
    standard error gives 0.
    """
    n = len(data)
    if n == 0:
        return nan
    if n == 1:
        return inf
    sem = scipy.stats.sem(data)
    if sem == 0:
        return 0
    # The half-width does not depend on where the interval is centred, so take
    # the upper critical value directly. This avoids needing a centre (the
    # former one, len(data) / sum(data), divided by zero whenever the data
    # summed to 0) and avoids subtracting two nearly equal numbers.
    return scipy.stats.t.ppf((1 + confidence) / 2, n - 1) * sem


class LogicError(Exception):
    """Exception type for logic errors; adds no behaviour beyond ``Exception``.

    Presumably raised when an internal invariant is violated. Confirm with callers.
    """


def round_time(dt=None, precision=60):
    """Round a datetime to the nearest multiple of ``precision`` seconds.

    Multiples are counted from midnight of ``dt``'s own day (wall-clock time;
    any tzinfo is ignored for the calculation but kept on the result). Halves
    round up. Only the time within the day is considered, so precisions
    longer than one day are not meaningful.

    :param dt: the datetime to round; defaults to ``datetime.now()``.
    :param precision: rounding step in seconds (int/float) or a ``timedelta``;
        defaults to one minute.
    :return: the rounded datetime.

    Based on code by Thierry Husson (2012): "Use it as you want but don't blame me."
    """
    if dt is None:
        dt = datetime.now()
    if isinstance(precision, timedelta):
        precision = precision.total_seconds()
    whole_seconds = (dt.replace(tzinfo=None) - dt.min).seconds
    # Include the fractional second, so that e.g. 12:00:00.7 rounds up when
    # precision is 1 second.
    seconds = whole_seconds + dt.microsecond / 1e6
    rounding = (seconds + precision / 2) // precision * precision
    # Drop the microseconds (already accounted for in `seconds`) and move from
    # the whole second to the rounded position.
    return dt.replace(microsecond=0) + timedelta(seconds=rounding - whole_seconds)


def profile_wall_time_instead_if_profiling():
    """Switch a running yappi profiler to wall-clock timing.

    Does nothing if yappi is not installed, is not running, or already uses
    the 'wall' clock. Otherwise it stops profiling and clears the stats
    collected with the other clock. It then sets the clock type to 'wall' and
    restarts profiling.
    """
    try:
        import yappi
    except ImportError:
        return
    # Ask yappi directly whether it is profiling. Counting collected function
    # stats is also non-zero after profiling was stopped, which would wrongly
    # (re)start the profiler here.
    if yappi.is_running() and yappi.get_clock_type() != 'wall':
        print('Changing yappi clock type to wall and restarting yappi.')
        yappi.stop()
        yappi.clear_stats()
        yappi.set_clock_type("wall")
        yappi.start()


def dummy_computation(_data):
    """No-op placeholder: ignores its argument and returns None."""


def current_year_begin():
    """Return the POSIX timestamp of local midnight on January 1st of the current year."""
    year = datetime.today().year
    new_year = datetime(year, month=1, day=1)
    return new_year.timestamp()


def current_day_begin():
    """Return the POSIX timestamp (float) of the start of the current day.

    NOTE: the timestamp is floored to a multiple of 86400 seconds, i.e. the
    day boundary is midnight UTC, not local midnight.
    """
    seconds_per_day = 3600 * 24
    now = datetime.today().timestamp()
    return (now // seconds_per_day) * seconds_per_day


def current_second_begin():
    """Return the current POSIX timestamp rounded down to a whole second, as an int."""
    now = datetime.today().timestamp()
    return floor(now)