import datetime
import faulthandler
import functools
import gc
import inspect
import json
import math
import os
import random
import re
import sqlite3
import sys
import threading
import time
from bisect import bisect_left
from enum import Enum
from itertools import chain, combinations
from math import log, isnan, nan, floor, log10, gcd
from numbers import Number
from shutil import copyfile
from threading import RLock
from types import FunctionType
from typing import Union, Tuple, List, Optional, Dict, Any, Type

# noinspection PyUnresolvedReferences
from unittest import TestCase, mock

import cachetools
import hanging_threads
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy
import numpy as np
import pandas
import scipy.optimize
import scipy.stats
import tabulate
from scipy.ndimage import zoom

from lib import stack_tracer, print_exc_plus
from lib.my_logger import logging

X = Y = Z = float

class KnownIssue(Exception):
    """
    Raised for code that is known not to work and should not be used,
    but that is still too valuable to be deleted.
    """
    pass

def powerset(iterable):
    """powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"""
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))

def plot_with_conf(x, y_mean, y_conf, alpha=0.5, **kwargs):
    """Plot a mean curve with a shaded confidence band of half-width y_conf."""
    ax = kwargs.pop('ax', plt.gca())
    base_line, = ax.plot(x, y_mean, **kwargs)
    y_mean = np.array(y_mean)
    y_conf = np.array(y_conf)
    lb = y_mean - y_conf
    ub = y_mean + y_conf
    ax.fill_between(x, lb, ub, facecolor=base_line.get_color(), alpha=alpha)

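# A minimal usage sketch for plot_with_conf() (not part of the original module);
# the data and label below are illustrative only.
def _example_plot_with_conf():
    xs = list(range(10))
    means = [x ** 0.5 for x in xs]
    confs = [0.2] * len(xs)  # constant band half-width, for illustration
    plot_with_conf(xs, means, confs, label='sqrt(x)')
    plt.legend()
    plt.show()
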
def choice(sequence, probabilities):
    """Draw one element from `sequence`, weighted by `probabilities` (which must sum to 1)."""
    r = random.random()
    for idx, c in enumerate(sequence):
        r -= probabilities[idx]
        if r < 0:
            return c
    raise AssertionError('Probabilities must sum to 1')

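# Usage sketch (illustrative weights, assumed to sum to 1): returns 'a', 'b',
# or 'c' with probability 0.2, 0.5, and 0.3, respectively.
#   choice(['a', 'b', 'c'], [0.2, 0.5, 0.3])
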
def print_attributes(obj, include_methods=False, ignore=None):
    if ignore is None:
        ignore = []
    for attr in dir(obj):
        if attr in ignore:
            continue
        if attr.startswith('_'):
            continue
        value = getattr(obj, attr)
        if not include_methods and callable(value):
            continue
        print(attr, ':', value.__class__.__name__, ':', value)

def attr_dir(obj, include_methods=False, ignore=None):
    if ignore is None:
        ignore = []
    return {attr: getattr(obj, attr)
            for attr in dir(obj)
            if not attr.startswith('_') and (
                    include_methods or not callable(getattr(obj, attr))) and attr not in ignore}

def zoom_to_shape(a: np.ndarray, shape: Tuple, mode: str = 'smooth', verbose=1):
    from keras import backend

    a = np.array(a, dtype=backend.floatx())  # also does a copy
    shape_dim = len(a.shape)
    if len(a.shape) != len(shape):
        raise ValueError('The shapes must have the same dimension but were len({0}) = {1} (original) '
                         'and len({2}) = {3} desired.'.format(a.shape, len(a.shape), shape, len(shape)))
    if len(shape) == 0:
        return a
    zoom_factors = tuple(shape[idx] / a.shape[idx] for idx in range(shape_dim))

    def _current_index_in_old_array():
        return tuple(slice(0, length) if axis != current_axis else slice(current_pixel_index, current_pixel_index + 1)
                     for axis, length in enumerate(a.shape))

    def _current_pixel_shape():
        return tuple(length if axis != current_axis else 1
                     for axis, length in enumerate(a.shape))

    def _current_result_index():
        return tuple(
            slice(0, length) if axis != current_axis else slice(pixel_index_in_result, pixel_index_in_result + 1)
            for axis, length in enumerate(a.shape))

    def _current_result_shape():
        return tuple(orig_length if axis != current_axis else shape[axis]
                     for axis, orig_length in enumerate(a.shape))

    if mode == 'constant':
        result = zoom(a, zoom_factors)
        assert result.shape == shape
        return result
    elif mode == 'smooth':
        result = a
        for current_axis, zoom_factor in sorted(enumerate(zoom_factors), key=lambda x: x[1]):
            result = np.zeros(_current_result_shape(), dtype=backend.floatx())
            desired_length = shape[current_axis]
            current_pixel_index = 0
            current_pixel_part = 0  # how much of the current pixel is already read
            for pixel_index_in_result in range(desired_length):
                pixels_remaining = 1 / zoom_factor
                pixel_sum = np.zeros(_current_pixel_shape())
                while pixels_remaining + current_pixel_part > 1:
                    pixel_sum += (1 - current_pixel_part) * a[_current_index_in_old_array()]
                    current_pixel_index += 1
                    pixels_remaining -= (1 - current_pixel_part)
                    current_pixel_part = 0
                # the remaining pixel part
                try:
                    pixel_sum += pixels_remaining * a[_current_index_in_old_array()]
                except (IndexError, ValueError):
                    if verbose:
                        print('WARNING: Skipping {0} pixels because of numerical imprecision.'.format(pixels_remaining))
                else:
                    current_pixel_part += pixels_remaining
                # insert into the result
                pixel_sum *= zoom_factor
                result[_current_result_index()] = pixel_sum
            a = result
        assert result.shape == shape
        return result
    else:
        raise NotImplementedError('Mode not available.')

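# Usage sketch (illustrative; requires keras for the target float dtype):
# resample a 4x4 array to shape (2, 8) with the default 'smooth' mode; the
# result shape is asserted inside the function.
#   zoom_to_shape(np.ones((4, 4)), (2, 8))
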
def profile_wall_time_instead_if_profiling():
    try:
        import yappi
    except ModuleNotFoundError:
        return
    currently_profiling = len(yappi.get_func_stats())
    if currently_profiling and yappi.get_clock_type() != 'wall':
        yappi.stop()
        print('Profiling wall time instead of cpu time.')
        yappi.clear_stats()
        yappi.set_clock_type('wall')
        yappi.start()

def dummy_computation(*_args, **_kwargs):
    pass

def backup_file(filename):
    copyfile(filename, backup_file_path(filename))


def backup_file_path(filename):
    return filename + time.strftime('%Y%m%d') + '.bak'

# noinspection SpellCheckingInspection
def my_tabulate(data, tablefmt='pipe', **params):
    if data == [] and 'headers' in params:
        data = [[None for _ in params['headers']]]
    tabulate.MIN_PADDING = 0
    return tabulate.tabulate(data, tablefmt=tablefmt, **params)

def ce_loss(y_true, y_predicted):
    """Binary cross-entropy loss for a single probability prediction."""
    return -(y_true * log(y_predicted) + (1 - y_true) * log(1 - y_predicted))

class DontSaveResultsError(Exception):
    pass

def multinomial(n, bins):
    if bins == 0:
        if n > 0:
            raise ValueError('Cannot distribute to 0 bins.')
        return []
    from numpy.random import binomial
    remaining = n
    results = []
    for i in range(bins - 1):
        x = binomial(remaining, 1 / (bins - i))
        results.append(x)
        remaining -= x
    results.append(remaining)
    return results

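# Usage sketch: distribute 10 items over 3 bins uniformly at random; the result
# is random, e.g. [3, 4, 3], and always sums to 10.
#   multinomial(10, 3)
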
class UnknownTypeError(Exception):
    pass

# def shape_analysis(xs):
#     composed_dtypes = [list, tuple, np.ndarray, dict, set]
#     base_dtypes = [str, int, float, type, object]  # TODO add class and superclass of xs first element
#     all_dtypes = composed_dtypes + base_dtypes
#     if isinstance(xs, np.ndarray):
#         outer_brackets = ('[', ']')
#         shape = xs.shape
#         dtype = xs.dtype
#     elif isinstance(xs, tuple):
#         outer_brackets = ('(', ')')
#         shape = len(xs)
#         dtype = [t for t in all_dtypes if all(isinstance(x, t) for x in xs)][0]
#     elif isinstance(xs, list):
#         outer_brackets = ('[', ']')
#         shape = len(xs)
#         dtype = [t for t in all_dtypes if all(isinstance(x, t) for x in xs)][0]
#     elif isinstance(xs, dict) or isinstance(xs, set):
#         outer_brackets = ('{', '}')
#         shape = len(xs)
#         dtype = [t for t in all_dtypes if all(isinstance(x, t) for x in xs)][0]
#     elif any(isinstance(xs, t) for t in base_dtypes):
#         for t in base_dtypes:
#             if isinstance(xs, t):
#                 return str(t.__name__)
#         raise AssertionError('This should be unreachable.')
#     else:
#         raise UnknownTypeError('Unknown type:' + type(xs).__name__)
#
#     if shape and shape != '?':
#         return outer_brackets[0] + str(xs.shape) + ' * ' + str(dtype) + outer_brackets[1]
#     else:
#         return outer_brackets[0] + outer_brackets[1]

def beta_conf_interval_mle(data, conf=0.95):
    if len(data) <= 1:
        return 0, 1  # overestimates the interval
    if any(d < 0 or d > 1 or isnan(d) for d in data):
        return nan, nan
    if numpy.var(data) == 0:
        return numpy.mean(data), numpy.mean(data)
    epsilon = 1e-3
    alpha, beta, _, _ = scipy.stats.beta.fit(data, floc=-epsilon, fscale=1 + 2 * epsilon)
    lower, upper = scipy.stats.beta.interval(alpha=conf, a=alpha, b=beta)
    # clamp to the valid range [0, 1]
    lower = min(max(lower, 0), 1)
    upper = min(max(upper, 0), 1)
    return lower, upper

def gamma_conf_interval_mle(data, conf=0.95) -> Tuple[float, float]:
    if len(data) <= 1:
        return nan, nan
    if any(d < 0 or isnan(d) for d in data):
        return nan, nan
    if numpy.var(data) == 0:
        return numpy.mean(data).item(), 0
    alpha, _, scale = scipy.stats.gamma.fit(data, floc=0)
    lower, upper = scipy.stats.gamma.interval(alpha=conf, a=alpha, scale=scale)
    # clamp negative bounds to 0
    return max(lower, 0), max(upper, 0)

beta_quantile_cache = cachetools.LRUCache(maxsize=10)


@cachetools.cached(cache=beta_quantile_cache, key=lambda x1, p1, x2, p2, guess: (x1, x2, p1, p2))
def beta_parameters_quantiles(x1, p1, x2, p2, guess=(3, 3)):
    """Find parameters (a, b) for a beta random variable X such that P(X <= x1) = p1 and P(X <= x2) = p2."""

    def square(x):
        return x * x

    def objective(v):
        (a, b) = v
        temp = square(scipy.stats.beta.cdf(x1, a, b) - p1)
        temp += square(scipy.stats.beta.cdf(x2, a, b) - p2)
        return temp

    xopt = scipy.optimize.fmin(objective, guess, disp=False)
    return xopt[0], xopt[1]

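# Usage sketch (illustrative quantiles): recover shape parameters (a, b) of a
# beta distribution whose 25%- and 75%-quantiles are approximately 0.4 and 0.6.
#   a, b = beta_parameters_quantiles(0.4, 0.25, 0.6, 0.75)
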
def beta_conf_interval_quantile(data, conf=0.95, quantiles=(0.25, 0.75)):
    if len(data) <= 1:
        return 0, 1  # overestimates the interval
    mu = numpy.mean(data)
    v = numpy.var(data)
    data = numpy.array(data)
    if v == 0:
        return mu, mu
    lower = numpy.quantile(data, quantiles[0])
    upper = numpy.quantile(data, quantiles[1])
    # method-of-moments estimates as the starting point for the optimizer
    alpha_guess = mu ** 2 * ((1 - mu) / v - 1 / mu)
    beta_guess = alpha_guess * (1 / mu - 1)
    alpha, beta = beta_parameters_quantiles(lower, quantiles[0], upper, quantiles[1], (alpha_guess, beta_guess))
    return scipy.stats.beta.interval(alpha=conf, a=alpha, b=beta)

def beta_stats_quantile(data, quantiles=(0.25, 0.75)):
    if len(data) <= 1:
        return 0, 1  # overestimates the interval
    data = numpy.array(data)
    mu = numpy.mean(data)
    v = numpy.var(data)
    if v == 0:
        return mu, mu
    lower = numpy.quantile(data, quantiles[0])
    upper = numpy.quantile(data, quantiles[1])
    alpha_guess = mu ** 2 * ((1 - mu) / v - 1 / mu)
    beta_guess = alpha_guess * (1 / mu - 1)
    alpha, beta = beta_parameters_quantiles(lower, quantiles[0], upper, quantiles[1], (alpha_guess, beta_guess))
    return scipy.stats.beta.stats(a=alpha, b=beta)

def beta_stats_mle(data):
    if len(data) <= 1:
        return nan, nan
    if any(d < 0 or d > 1 or isnan(d) for d in data):
        return nan, nan
    if numpy.var(data) == 0:
        return numpy.mean(data), 0
    epsilon = 1e-4
    alpha, beta, _, _ = scipy.stats.beta.fit(data, floc=-epsilon, fscale=1 + 2 * epsilon)
    return scipy.stats.beta.stats(a=alpha, b=beta)

def gamma_stats_mle(data):
    if len(data) <= 1:
        return nan, nan
    if any(d < 0 or isnan(d) for d in data):
        return nan, nan
    if numpy.var(data) == 0:
        return numpy.mean(data), 0
    alpha, _, scale = scipy.stats.gamma.fit(data, floc=0)
    return scipy.stats.gamma.stats(a=alpha, scale=scale)

beta_stats = beta_stats_quantile
beta_conf_interval = beta_conf_interval_quantile
gamma_stats = gamma_stats_mle
gamma_conf_interval = gamma_conf_interval_mle

def split_df_list(df, target_column):
    """
    df = data frame to split
    target_column = the column containing JSON-encoded lists of values to split
    returns: a data frame with each entry for the target column separated, with each element moved into a new row.
    The values in the other columns are duplicated across the newly divided rows.
    SOURCE: https://gist.github.com/jlln/338b4b0b55bd6984f883
    """

    def split_list_to_rows(row, row_accumulator):
        split_row = json.loads(row[target_column])
        for s in split_row:
            new_row = row.to_dict()
            new_row[target_column] = s
            row_accumulator.append(new_row)

    new_rows = []
    df.apply(split_list_to_rows, axis=1, args=(new_rows,))
    new_df = pandas.DataFrame(new_rows)
    return new_df

try:
    import winsound as win_sound

    def beep(*args, **kwargs):
        win_sound.Beep(*args, **kwargs)
except ImportError:
    win_sound = None

    def beep(*_args, **_kwargs):
        pass

def round_to_digits(x, d):
    """Round x to d significant digits."""
    if x == 0:
        return 0
    if isnan(x):
        return nan
    try:
        return round(x, d - 1 - int(floor(log10(abs(x)))))
    except OverflowError:
        return x

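# Examples:
#   round_to_digits(123.456, 2)    -> 120.0
#   round_to_digits(0.0012345, 3)  -> 0.00123
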
def gc_if_memory_error(f, *args, **kwargs):
    try:
        return f(*args, **kwargs)
    except MemoryError:
        print('Starting garbage collector')
        gc.collect()
        return f(*args, **kwargs)

def assert_not_empty(x):
    assert len(x)
    return x

def validation_steps(validation_dataset_size, maximum_batch_size):
    batch_size = gcd(validation_dataset_size, maximum_batch_size)
    steps = validation_dataset_size // batch_size
    assert batch_size * steps == validation_dataset_size
    return batch_size, steps

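# Worked example: with 1000 validation samples and a maximum batch size of 64,
# gcd(1000, 64) = 8, so 125 batches of size 8 cover the dataset exactly.
#   validation_steps(1000, 64)  -> (8, 125)
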
def functional_dependency_trigger(connection: sqlite3.Connection,
                                  table_name: str,
                                  determining_columns: List[str],
                                  determined_columns: List[str],
                                  exist_ok: bool, ):
    cursor = connection.cursor()
    # possible_performance_improvements
    determined_columns = [c for c in determined_columns if c not in determining_columns]
    trigger_base_name = '_'.join([table_name] + determining_columns + ['determine'] + determined_columns)
    error_message = ','.join(determining_columns) + ' must uniquely identify ' + ','.join(determined_columns)

    # when inserting, check if there is already an entry with these values
    cursor.execute(f'''
    CREATE TRIGGER {'IF NOT EXISTS' if exist_ok else ''} {trigger_base_name}_after_insert
    BEFORE INSERT ON {table_name}
    WHEN EXISTS(SELECT * FROM {table_name}
                WHERE ({' AND '.join(f'NEW.{c} IS NOT NULL AND {c} = NEW.{c}' for c in determining_columns)})
                AND ({' OR '.join(f'{c} != NEW.{c}' for c in determined_columns)}))
    BEGIN SELECT RAISE(ROLLBACK, '{error_message}'); END
    ''')

    # when updating, check if there is already an entry with these values (only if changed)
    cursor.execute(f'''
    CREATE TRIGGER {'IF NOT EXISTS' if exist_ok else ''} {trigger_base_name}_after_update
    BEFORE UPDATE ON {table_name}
    WHEN EXISTS(SELECT * FROM {table_name}
                WHERE ({' AND '.join(f'NEW.{c} IS NOT NULL AND {c} = NEW.{c}' for c in determining_columns)})
                AND ({' OR '.join(f'{c} != NEW.{c}' for c in determined_columns)}))
    BEGIN SELECT RAISE(ROLLBACK, '{error_message}'); END
    ''')

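# A minimal usage sketch with an in-memory database (not part of the original
# module); the table and column names are illustrative only. After the call,
# inserting two rows with the same hash but different sizes raises an
# sqlite3.IntegrityError.
def _example_functional_dependency_trigger():
    connection = sqlite3.connect(':memory:')
    connection.execute('CREATE TABLE files (path TEXT, size INT, hash TEXT)')
    functional_dependency_trigger(connection,
                                  'files',
                                  determining_columns=['hash'],
                                  determined_columns=['size'],
                                  exist_ok=False)
    connection.commit()
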
def heatmap_from_points(x, y,
                        x_lim: Optional[Union[int, Tuple[int, int]]] = None,
                        y_lim: Optional[Union[int, Tuple[int, int]]] = None,
                        gridsize=30):
    if isinstance(x_lim, Number):
        x_lim = (x_lim, x_lim)
    if isinstance(y_lim, Number):
        y_lim = (y_lim, y_lim)
    plt.hexbin(x, y, gridsize=gridsize, cmap=matplotlib.cm.jet, bins=None)
    if x_lim is not None:
        plt.xlim(x_lim)
    if y_lim is not None:
        plt.ylim(y_lim)
    cb = plt.colorbar()
    cb.set_label('mean value')

def strptime(date_string, fmt):
    return datetime.datetime(*(time.strptime(date_string, fmt)[0:6]))

class PrintLineRLock(RLock().__class__):
    def __init__(self, *args, name='', **kwargs):
        # noinspection PyArgumentList
        super().__init__(*args, **kwargs)
        self.name = name

    def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
        print(f'Trying to acquire Lock {self.name}')
        # threading.RLock is a factory function, not a class, so the unbound
        # call RLock.acquire(self, ...) would fail; use super() instead
        result = super().acquire(blocking, timeout)
        print(f'Acquired Lock {self.name}')
        return result

    def release(self) -> None:
        print(f'Trying to release Lock {self.name}')
        result = super().release()
        print(f'Released Lock {self.name}')
        return result

    def __enter__(self, *args, **kwargs):
        print('Trying to enter Lock')
        # noinspection PyArgumentList
        super().__enter__(*args, **kwargs)
        print('Entered Lock')

    def __exit__(self, *args, **kwargs):
        print('Trying to exit Lock')
        super().__exit__(*args, **kwargs)
        print('Exited Lock')

def fixed_get_current_frames():
    """Return current threads prepared for further processing."""
    threads = {thread.ident: thread for thread in threading.enumerate()}
    return {
        thread_id: {
            'frame': hanging_threads.thread2list(frame),
            'time': None,
            'id': thread_id,
            'name': threads[thread_id].name,
            'object': threads[thread_id],
        } for thread_id, frame in sys._current_frames().items()
        if thread_id in threads  # otherwise KeyErrors might happen because of race conditions
    }


hanging_threads.get_current_frames = fixed_get_current_frames

class CallCounter:
    def __init__(self, f):
        self.f = f
        self.calls = 0
        self.__name__ = f.__name__

    def __call__(self, *args, **kwargs):
        self.calls += 1
        return self.f(*args, **kwargs)

    def __str__(self):
        return str(self.__dict__)

    def __repr__(self):
        return self.__class__.__name__ + repr(self.__dict__)

def test_with_timeout(timeout=2):
    def wrapper(f):
        from lib.threading_timer_decorator import exit_after
        f = exit_after(timeout)(f)

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            try:
                print(f'Running this test with timeout: {timeout}')
                return f(*args, **kwargs)
            except KeyboardInterrupt:
                raise AssertionError(f'Test took longer than {timeout} seconds')

        return wrapped

    return wrapper

def lru_cache_by_id(maxsize):
    return cachetools.cached(cachetools.LRUCache(maxsize=maxsize), key=id)

class EquivalenceRelation:
    def equivalent(self, a, b) -> bool:
        raise NotImplementedError('Abstract method')

    def equivalence_classes(self, xs: list):
        classes = []
        for x in xs:
            for c in classes:
                if self.equivalent(x, c[0]):
                    c.append(x)
                    break
            else:
                classes.append([x])
        return classes

    def check_reflexivity_on_dataset(self, xs):
        for x in xs:
            if not self.equivalent(x, x):
                return False
        return True

    def check_symmetry_on_dataset(self, xs):
        for x in xs:
            for y in xs:
                if x is y:
                    continue
                if self.equivalent(x, y) and not self.equivalent(y, x):
                    return False
        return True

    def check_axioms_on_dataset(self, xs):
        return (
                self.check_reflexivity_on_dataset(xs)
                and self.check_symmetry_on_dataset(xs)
                and self.check_transitivity_on_dataset(xs, assume_symmetry=True, assume_reflexivity=True)
        )

    def check_transitivity_on_dataset(self, xs, assume_symmetry=False, assume_reflexivity=False):
        for x_idx, x in enumerate(xs):
            for y_idx, y in enumerate(xs):
                if x is y:
                    continue
                if self.equivalent(x, y):
                    for z_idx, z in enumerate(xs):
                        if y is z:
                            continue
                        if assume_symmetry and x_idx > z_idx:
                            continue
                        if assume_reflexivity and x is z:
                            continue
                        if self.equivalent(y, z):
                            if not self.equivalent(x, z):
                                return False
        return True

    def match_lists(self, xs, ys, filter_minimum_size=0, filter_maximum_size=math.inf):
        xs = list(xs)
        ys = list(ys)
        if any(x is y for x in xs for y in ys):
            raise ValueError('Lists contain the same element. This is currently not supported.')
        classes = self.equivalence_classes([*xs, *ys])
        return [
            [
                (0 if any(x2 is x for x2 in xs) else 1, x)
                for x in c
            ]
            for c in classes[::-1]
            if filter_minimum_size <= len(c) <= filter_maximum_size
        ]

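# A minimal concrete relation as a usage sketch (not part of the original
# module): equivalence of integers modulo n.
class _ModuloEquivalence(EquivalenceRelation):
    def __init__(self, n):
        self.n = n

    def equivalent(self, a, b) -> bool:
        return a % self.n == b % self.n


# _ModuloEquivalence(3).equivalence_classes([0, 1, 2, 3, 4])
# groups the list into [[0, 3], [1, 4], [2]].
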
def iff_patch(patch: mock._patch):
    def decorator(f):
        def wrapped(*args, **kwargs):
            with patch:
                f(*args, **kwargs)
            try:
                f(*args, **kwargs)
            except Exception:
                pass
            else:
                raise AssertionError('Test did not fail without patch')

        return wrapped

    return decorator

def iff_not_patch(patch: mock._patch):
    def decorator(f):
        def wrapped(*args, **kwargs):
            f(*args, **kwargs)
            try:
                with patch:
                    f(*args, **kwargs)
            except Exception:
                pass
            else:
                raise AssertionError('Test did not fail with patch')

        return wrapped

    return decorator

EMAIL_CRASHES_TO = []
VOICE_CALL_ON_CRASH: List[Tuple[str, str]] = []

def list_logger(base_logging_function, store_in_list: list):
    def print_and_store(*args, **kwargs):
        base_logging_function(*args, **kwargs)
        store_in_list.extend(args)

    return print_and_store

def main_wrapper(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        import __main__
        # does not help much:
        # monitoring_thread = hanging_threads.start_monitoring(seconds_frozen=180, test_interval=1000)
        os.makedirs('logs', exist_ok=True)
        stack_tracer.trace_start('logs/' + os.path.split(__main__.__file__)[-1] + '.html', interval=5)
        faulthandler.enable()
        profile_wall_time_instead_if_profiling()
        # noinspection PyBroadException
        try:
            return f(*args, **kwargs)
        except KeyboardInterrupt:
            error_messages = []
            print_exc_plus.print_exc_plus(print=list_logger(logging.error, error_messages),
                                          serialize_to='logs/' + os.path.split(__main__.__file__)[-1] + '.dill')
        except:
            error_messages = []
            print_exc_plus.print_exc_plus(print=list_logger(logging.error, error_messages),
                                          serialize_to='logs/' + os.path.split(__main__.__file__)[-1] + '.dill')
            for recipient in EMAIL_CRASHES_TO:
                from jobs.sending_emails import send_mail
                send_mail.create_simple_mail_via_gmail(body='\n'.join(error_messages), filepath=None,
                                                       excel_name=None, to_mail=recipient,
                                                       subject='[python] Crash report')
            for to_number, from_number in VOICE_CALL_ON_CRASH:
                logging.info(f'Calling {to_number} to notify about the crash.')
                voice_call('This is a notification message that one of your python scripts has crashed. '
                           'If you are unsure about the origin of this call, please contact Eren Yilmaz.',
                           to_number, from_number)
        finally:
            logging.info('Terminated.')
            total_time = time.perf_counter() - start
            faulthandler.disable()
            stack_tracer.trace_stop()
            frequency = 2000
            duration = 500
            beep(frequency, duration)
            print('Total time', total_time)
            try:
                from algorithm_development.metatrader import ZeroMQ_Connector
                ZeroMQ_Connector.DWX_ZeroMQ_Connector.deactivate_all()
            except ImportError:
                pass

    return wrapper

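# Usage sketch (illustrative): wrap a script's entry point so that crashes are
# logged (and optionally reported by mail or voice call) and the total runtime
# is printed at the end.
#   @main_wrapper
#   def main():
#       ...
#
#   if __name__ == '__main__':
#       main()
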
def voice_call(msg, to_number, from_number):
    from twilio.rest import Client
    account_sid = 'AC63c459168c3e4fe34e462acb4f44f748'
    auth_token = 'b633bc0e945fe7cb737fdac395cc71d6'
    client = Client(account_sid, auth_token)
    call = client.calls.create(
        twiml=f'<Response><Say>{msg}</Say></Response>',
        from_=from_number,
        to=to_number,
    )
    print(call.sid)

def required_size_for_safe_rotation(base: Tuple[X, Y, Z], rotate_range_deg) -> Tuple[X, Y, Z]:
    if abs(rotate_range_deg) > 45:
        raise NotImplementedError
    if abs(rotate_range_deg) > 0:
        x_length = base[2] * math.sin(rotate_range_deg / 180 * math.pi) + base[1] * math.cos(
            rotate_range_deg / 180 * math.pi)
        y_length = base[2] * math.cos(rotate_range_deg / 180 * math.pi) + base[1] * math.sin(
            rotate_range_deg / 180 * math.pi)
        result = (base[0],
                  x_length,
                  y_length,)
    else:
        result = base
    return result

def round_to_closest_value(x, values, assume_sorted=False):
    if not assume_sorted:
        values = sorted(values)
    next_largest = bisect_left(values, x)  # binary search
    if next_largest == 0:
        return values[0]
    if next_largest == len(values):
        return values[-1]
    next_smallest = next_largest - 1
    smaller = values[next_smallest]
    larger = values[next_largest]
    if abs(smaller - x) < abs(larger - x):
        return smaller
    else:
        return larger

def binary_search(a, x, lo=0, hi=None):
    hi = hi if hi is not None else len(a)  # hi defaults to len(a)
    pos = bisect_left(a, x, lo, hi)  # find insertion position
    return pos if pos != hi and a[pos] == x else -1  # don't walk off the end

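# Examples on a sorted list:
#   binary_search([1, 3, 5, 7], 5)  -> 2
#   binary_search([1, 3, 5, 7], 4)  -> -1 (not found)
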
def ceil_to_closest_value(x, values):
    values = sorted(values)
    next_largest = bisect_left(values, x)  # binary search
    if next_largest < len(values):
        return values[next_largest]
    else:
        return values[-1]  # if there is no larger value, use the largest one

def print_progress_bar(iteration, total, prefix='Progress:', suffix='', decimals=1, length=50, fill='█',
                       print_eta=True):
    """
    Call in a loop to create a terminal progress bar
    @params:
        iteration - Required : current iteration (Int)
        total     - Required : total iterations (Int)
        prefix    - Optional : prefix string (Str)
        suffix    - Optional : suffix string (Str)
        decimals  - Optional : positive number of decimals in percent complete (Int)
        length    - Optional : character length of bar (Int)
        fill      - Optional : bar fill character (Str)
        print_eta - Optional : currently unused (Bool)
    """
    percent = ('{0:' + str(4 + decimals) + '.' + str(decimals) + 'f}').format(100 * (iteration / float(total)))
    filled_length = int(length * iteration // total)
    bar = fill * filled_length + '-' * (length - filled_length)
    if getattr(print_progress_bar, 'last_printed_value', None) == (prefix, bar, percent, suffix):
        return
    print_progress_bar.last_printed_value = (prefix, bar, percent, suffix)
    print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end='')
    # print a new line on completion
    if iteration == total:
        print()

def get_all_subclasses(klass):
    all_subclasses = []
    for subclass in klass.__subclasses__():
        all_subclasses.append(subclass)
        all_subclasses.extend(get_all_subclasses(subclass))
    return all_subclasses

def my_mac_address():
    """
    https://stackoverflow.com/a/160821
    """
    import uuid
    mac = uuid.getnode()
    if (mac >> 40) % 2:  # uuid.getnode may return a random id with the multicast bit set
        return None
    mac = uuid.UUID(int=mac).hex[-12:]
    return mac

def latin1_json(data):
    return json.dumps(data, ensure_ascii=False).encode('latin-1')

def l2_norm(v1, v2):
    """Euclidean distance between two vectors of equal length."""
    if len(v1) != len(v2):
        raise ValueError('Both vectors must be of the same size')
    return math.sqrt(sum((x1 - x2) * (x1 - x2) for x1, x2 in zip(v1, v2)))

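# Example: the classic 3-4-5 right triangle.
#   l2_norm([0, 0], [3, 4])  -> 5.0
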
def allow_additional_unused_keyword_arguments(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        allowed_kwargs = [param.name for param in inspect.signature(func).parameters.values()]
        allowed_kwargs = {a: kwargs[a] for a in kwargs if a in allowed_kwargs}
        return func(*args, **allowed_kwargs)

    return wrapper

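# Usage sketch (greet and its parameter are illustrative, not part of the
# original module): the decorated function silently ignores keyword arguments
# it does not declare.
#   @allow_additional_unused_keyword_arguments
#   def greet(name):
#       return 'Hello ' + name
#   greet(name='World', unused=42)  -> 'Hello World'
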
def copy_and_rename_method(func, new_name):
    funcdetails = [
        func.__code__,
        func.__globals__,
        func.__name__,
        func.__defaults__,
        func.__closure__,
    ]
    old_name = func.__name__
    # copy
    new_func = FunctionType(*funcdetails)
    assert new_func is not func
    # rename
    new_func.__name__ = new_name
    assert func.__name__ == old_name
    return new_func

def rename(new_name):
    def decorator(f):
        f.__name__ = new_name
        return f

    return decorator

class LogicError(Exception):
    pass

def round_time(dt=None, precision=60):
    """Round a datetime object to any time lapse in seconds.
    dt : datetime.datetime object, defaults to now.
    precision : closest number of seconds (or a timedelta) to round to, default 1 minute.
    Author: Thierry Husson 2012 - Use it as you want but don't blame me.
    """
    if dt is None:
        dt = datetime.datetime.now()
    if isinstance(precision, datetime.timedelta):
        precision = precision.total_seconds()
    seconds = (dt.replace(tzinfo=None) - dt.min).seconds
    rounding = (seconds + precision / 2) // precision * precision
    return dt + datetime.timedelta(seconds=rounding - seconds,
                                   microseconds=-dt.microsecond)

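# Example: round to the nearest full hour.
#   round_time(datetime.datetime(2013, 9, 2, 14, 37, 42), precision=3600)
#   -> datetime.datetime(2013, 9, 2, 15, 0)
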
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def shorten_name(name):
    name = re.sub(r'\s+', r' ', str(name))
    name = name.replace(', ', ',')
    name = name.replace(' ', '_')
    return re.sub(r'([A-Za-z])[a-z]*_?', r'\1', str(name))

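# Example: shorten_name('Gradient Boosting Machine')  -> 'GBM'
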
def array_analysis(a: numpy.ndarray):
    print(f' Shape: {a.shape}')
    mean = a.mean()
    print(f' Mean: {mean}')
    print(f' Std: {a.std()}')
    print(f' Min, Max: {a.min(), a.max()}')
    print(f' Mean absolute: {numpy.abs(a).mean()}')
    print(f' Mean square: {numpy.square(a).mean()}')
    print(f' Mean absolute difference from mean: {numpy.abs(a - mean).mean()}')
    print(f' Mean squared difference from mean: {numpy.square(a - mean).mean()}')
    nonzero = numpy.count_nonzero(a)
    print(f' Number of non-zeros: {nonzero}')
    print(f' Number of zeros: {numpy.prod(a.shape) - nonzero}')
    if 1 < a.shape[-1] <= 1000:
        # the last dim is probably the number of classes
        print(f' Class counts: {numpy.count_nonzero(a, axis=tuple(range(len(a.shape) - 1)))}')

def current_year_begin():
    return datetime.datetime(datetime.datetime.today().year, 1, 1).timestamp()


def current_day_begin():
    return datetime.datetime.today().timestamp() // (3600 * 24) * (3600 * 24)


def current_second_begin():
    return floor(datetime.datetime.today().timestamp())

def running_workers(executor):
    # count the executor's worker threads that are still alive
    return sum(1 for t in executor._threads if t.is_alive())


def queued_calls(executor):
    return len(executor._work_queue.queue)

def retry_on_error(max_tries=3, delay=0.5, backoff=2, only_error_classes=Exception):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for i in range(max_tries):
                try:
                    return func(*args, **kwargs)
                except only_error_classes as e:
                    if i == max_tries - 1:
                        raise
                    logging.error(f'Re-try after error in {func.__name__}: {type(e).__name__}, {e}')
                    time.sleep(delay * (backoff ** i))

        return wrapper

    return decorator

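# Usage sketch (fetch_data and retrying on ConnectionError are illustrative,
# not part of the original module):
#   @retry_on_error(max_tries=5, delay=1, backoff=2, only_error_classes=ConnectionError)
#   def fetch_data(url):
#       ...
# Each failed attempt logs the error and sleeps 1, 2, 4, ... seconds before
# retrying; the last attempt re-raises the exception.
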
class EBC:
    SUBCLASSES_BY_NAME: Dict[str, Type['EBC']] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        EBC.SUBCLASSES_BY_NAME[cls.__name__] = cls

    def __eq__(self, other):
        return isinstance(other, type(self)) and self.__dict__ == other.__dict__

    def __str__(self):
        return str(self.__dict__)

    def __repr__(self):
        return f'{type(self).__name__}(**' + str(self.__dict__) + ')'

    def to_json(self) -> Dict[str, Any]:
        result: Dict[str, Any] = {
            'type': type(self).__name__,
            **self.__dict__,
        }
        for k in result:
            if isinstance(result[k], EBC):
                result[k] = result[k].to_json()
            elif isinstance(result[k], numpy.ndarray):
                result[k] = result[k].tolist()
            elif isinstance(result[k], list):
                result[k] = [r.to_json() if isinstance(r, EBC) else r
                             for r in result[k]]
        return result

    @staticmethod
    def from_json(data: Dict[str, Any]):
        cls = EBC.SUBCLASSES_BY_NAME[data['type']]
        return class_from_json(cls, data)

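# A minimal round-trip sketch (not part of the original module); the class and
# its fields are illustrative only.
class _ExamplePoint(EBC):
    def __init__(self, x, y):
        self.x = x
        self.y = y


# _ExamplePoint(1, 2).to_json()
#   -> {'type': '_ExamplePoint', 'x': 1, 'y': 2}
# EBC.from_json({'type': '_ExamplePoint', 'x': 1, 'y': 2}) == _ExamplePoint(1, 2)
#   -> True
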
def class_from_json(cls, data: Dict[str, Any]):
    if isinstance(data, str):
        data = json.loads(data)
    # noinspection PyArgumentList
    try:
        return cls(**data)
    except TypeError as e:
        if "__init__() got an unexpected keyword argument 'type'" in str(e) or 'takes no arguments' in str(e):
            # probably this was from a to_json method
            if data['type'] != cls.__name__:
                t = data['type']
                logging.warning(f'Reconstructing a {cls.__name__} from a dict with type={t}')
            data = data.copy()
            del data['type']
            for k, v in data.items():
                if probably_serialized_from_ebc(v):
                    data[k] = EBC.SUBCLASSES_BY_NAME[v['type']].from_json(v)
                elif isinstance(v, list):
                    data[k] = [EBC.SUBCLASSES_BY_NAME[x['type']].from_json(x)
                               if probably_serialized_from_ebc(x)
                               else x
                               for x in v]
            return allow_additional_unused_keyword_arguments(cls)(**data)
        else:
            raise

def probably_serialized_from_ebc(data):
    return isinstance(data, dict) and 'type' in data and data['type'] in EBC.SUBCLASSES_BY_NAME

class EBE(Enum):
    def __int__(self):
        return self.value

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    @classmethod
    def from_name(cls, variable_name):
        return cls.__dict__[variable_name]

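# Usage sketch (illustrative enum, not part of the original module):
#   class Color(EBE):
#       RED = 1
#       GREEN = 2
#   int(Color.RED) -> 1, str(Color.RED) -> 'RED',
#   Color.from_name('RED') is Color.RED -> True
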
class Bunch(dict, EBC):
    def __init__(self, **kwargs):
        dict.__init__(self, kwargs)
        self.__dict__.update(kwargs)

    def add_method(self, m):
        setattr(self, m.__name__, functools.partial(m, self))

def floor_to_multiple_of(x, multiple_of):
    return math.floor(x / multiple_of) * multiple_of


def round_to_multiple_of(x, multiple_of):
    return round(x / multiple_of) * multiple_of


def ceil_to_multiple_of(x, multiple_of):
    return math.ceil(x / multiple_of) * multiple_of

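# Examples:
#   floor_to_multiple_of(7.3, 2.5)  -> 5.0
#   round_to_multiple_of(7.3, 2.5)  -> 7.5
#   ceil_to_multiple_of(7.3, 2.5)   -> 7.5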