import math
from math import ceil

import scipy.stats

from db_setup.create_triggers import create_triggers
from db_setup.indices import create_indices
from db_setup.seeds import seed
from db_setup.tables import tables
from game import DB_NAME


class StDevFunc:
    """
    Running standard deviation (Welford's online algorithm).

    Can be registered as a SQLite aggregate (``step``/``finalize``) or called
    directly on an iterable. ``None`` (SQL NULL) inputs are skipped; every
    other input is converted with ``float``.

    ``finalize`` returns ``None`` when no value was consumed, ``0.0`` for a
    single value, and otherwise ``sqrt(S / n)`` with ``n`` the number of
    values, i.e. the population standard deviation.

    NOTE(review): the cited answer appears to divide by ``n - 1`` (sample
    standard deviation); confirm which variant callers expect.

    source: https://stackoverflow.com/a/24423341
    """

    def __init__(self):
        self.M = 0.0  # running mean
        self.S = 0.0  # running sum of squared deviations from the mean
        self.k = 1  # one more than the number of values consumed so far

    def step(self, value):
        if value is None:
            return
        x = float(value)
        mean_before = self.M
        delta = x - mean_before
        self.M = mean_before + delta / self.k
        # Uses both the old and the updated mean, per Welford's update rule.
        self.S += delta * (x - self.M)
        self.k += 1

    def finalize(self):
        seen = self.k - 1
        if seen == 0:
            return None
        if seen == 1:
            return 0.0
        return math.sqrt(self.S / seen)

    def __call__(self, items):
        for item in items:
            self.step(item)
        return self.finalize()


class MeanFunc:
    """
    Running arithmetic mean.

    Can be registered as a SQLite aggregate (``step``/``finalize``) or called
    directly on an iterable. ``None`` (SQL NULL) inputs are skipped; every
    other input is converted with ``float``.

    ``finalize`` returns ``None`` (SQL NULL) when no non-NULL value was seen,
    like SQLite's built-in ``AVG``, instead of raising ``ZeroDivisionError``.

    source: https://stackoverflow.com/a/24423341
    """

    def __init__(self):
        self.sum = 0  # sum of the non-NULL values seen so far
        self.count = 0  # number of non-NULL values seen so far

    def step(self, value):
        if value is None:
            return
        value = float(value)
        self.sum += value
        self.count += 1

    def finalize(self):
        if self.count == 0:
            # Nothing to average: report NULL rather than dividing by zero.
            return None
        return self.sum / self.count

    def __call__(self, items):
        for i in items:
            self.step(i)
        return self.finalize()


class MeanConfidenceIntervalSizeFunc:
    """
    Half-width of a two-sided Student-t confidence interval for the mean.

    Can be registered as a SQLite aggregate (``step``/``finalize``) or called
    directly on an iterable. If the result is ``h``, the interval is
    ``mean - h`` .. ``mean + h``.

    ``None`` (SQL NULL) inputs are ignored entirely: they count neither
    toward the mean nor toward ``n``. Every other input is converted with
    ``float``. The interval uses the *sample* standard deviation (divisor
    ``n - 1``) with ``n - 1`` degrees of freedom.

    ``finalize`` returns:

    * ``None`` (SQL NULL) when there are no non-NULL values,
    * ``math.inf`` for a single value,
    * ``0`` when all values are equal.

    :param confidence: two-sided confidence level. It defaults to 0.95, so
        zero-argument construction (as done by ``create_aggregate``) keeps
        the 95 % interval.
    """

    def __init__(self, confidence=0.95):
        self.confidence = confidence
        self.count = 0  # number of non-NULL values consumed
        self._mean = 0.0  # running mean (Welford's algorithm)
        self._m2 = 0.0  # running sum of squared deviations from the mean

    def step(self, value):
        if value is None:
            return
        value = float(value)
        self.count += 1
        delta = value - self._mean
        self._mean += delta / self.count
        self._m2 += delta * (value - self._mean)

    def finalize(self):
        if self.count == 0:
            return None  # SQL NULL
        if self.count == 1:
            return math.inf  # one value gives no spread estimate
        if self._m2 == 0:
            return 0
        dof = self.count - 1
        sample_std = math.sqrt(self._m2 / dof)
        # Upper critical value of the t distribution for a two-sided interval.
        t_crit = scipy.stats.t.ppf(0.5 + self.confidence / 2, dof)
        return float(t_crit * sample_std / math.sqrt(self.count))

    def __call__(self, items):
        for i in items:
            self.step(i)
        return self.finalize()

def str_to_float(s):
    """Parse a number written with ``'.'`` thousands and ``','`` decimal separators.

    Every ``'.'`` is removed, ``','`` becomes the decimal point, and a ``'%'``
    sign divides the result by 100. For example, ``'1.234,5'`` -> ``1234.5``
    and ``'12,5%'`` -> ``0.125``.

    ``None`` (SQL NULL) is returned unchanged. Values that are already
    ``int``/``float`` are converted with ``float`` directly; re-parsing their
    ``str`` would misread the decimal point as a thousands separator.

    :raises ValueError: if the text contains both ``'+'`` and ``'-'``, or is
        otherwise not a valid number.
    """
    if s is None:
        return None
    if isinstance(s, (int, float)):
        return float(s)
    text = s
    s = s.replace('.', '').replace(',', '.')
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if '+' in s and '-' in s:
        raise ValueError(f'ambiguous sign in number: {text!r}')
    v = float(s.replace('%', ''))
    if '%' in s:
        v /= 100
    return v


def create_functions(connection):
    connection.create_function('CEIL', 1, ceil)
    connection.create_function('POWER', 2, lambda b, e: b ** e)
    connection.create_aggregate('CONF', 1, MeanConfidenceIntervalSizeFunc)
    connection.create_aggregate('TOFLOAT', 1, str_to_float)


def set_pragmas(cursor):
    """Apply the application's PRAGMA settings through *cursor*.

    The settings enable foreign-key enforcement, switch the journal to
    write-ahead logging (WAL), and set ``synchronous`` to ``NORMAL``.
    """
    pragmas = (
        'PRAGMA foreign_keys=1',
        'PRAGMA journal_mode = WAL',
        'PRAGMA synchronous = NORMAL',
    )
    for statement in pragmas:
        cursor.execute(statement)


def setup(cursor):
    """Build the database schema and seed data through *cursor*.

    The stages run in this order: tables, triggers, indices, then seed data.
    This function does not commit by itself.
    """
    print('Database setup...')
    stages = (tables, create_triggers, create_indices, seed)
    for stage in stages:
        stage(cursor)


if __name__ == '__main__':
    # Script entry point: create the schema and seed data in the database
    # named by DB_NAME, then commit.
    import model  # project module, imported only when run as a script
    model.connect(DB_NAME)
    # NOTE(review): connect() is presumably what populates
    # model.current_cursor / model.current_connection; confirm in model.py.
    setup(model.current_cursor)
    model.current_connection.commit()