util.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143
  1. import datetime
  2. import faulthandler
  3. import functools
  4. import gc
  5. import inspect
  6. import json
  7. import math
  8. import os
  9. import random
  10. import re
  11. import sqlite3
  12. import sys
  13. import threading
  14. import time
  15. from bisect import bisect_left
  16. from enum import Enum
  17. from itertools import chain, combinations
  18. from math import log, isnan, nan, floor, log10, gcd
  19. from numbers import Number
  20. from shutil import copyfile
  21. from threading import RLock
  22. from types import FunctionType
  23. from typing import Union, Tuple, List, Optional, Dict, Any, Type
  24. # noinspection PyUnresolvedReferences
  25. from unittest import TestCase, mock
  26. import cachetools
  27. import hanging_threads
  28. import matplotlib.cm
  29. import matplotlib.pyplot as plt
  30. import numpy
  31. import numpy as np
  32. import pandas
  33. import scipy.optimize
  34. import scipy.stats
  35. import tabulate
  36. from scipy.ndimage import zoom
# Aliases for ``float`` so that coordinate tuples can be annotated as Tuple[X, Y, Z].
X = Y = Z = float
  38. class KnownIssue(Exception):
  39. """
  40. This means the code is not working and should not be used but still too valuable to be deleted
  41. """
  42. pass
  43. def powerset(iterable):
  44. """powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"""
  45. s = list(iterable)
  46. return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
  47. def plot_with_conf(x, y_mean, y_conf, alpha=0.5, **kwargs):
  48. ax = kwargs.pop('ax', plt.gca())
  49. base_line, = ax.plot(x, y_mean, **kwargs)
  50. y_mean = np.array(y_mean)
  51. y_conf = np.array(y_conf)
  52. lb = y_mean - y_conf
  53. ub = y_mean + y_conf
  54. ax.fill_between(x, lb, ub, facecolor=base_line.get_color(), alpha=alpha)
  55. def choice(sequence, probabilities):
  56. # if sum(probabilities) != 1:
  57. # raise AssertionError('Probabilities must sum to 1')
  58. r = random.random()
  59. for idx, c in enumerate(sequence):
  60. r -= probabilities[idx]
  61. if r < 0:
  62. return c
  63. raise AssertionError('Probabilities must sum to 1')
  64. def print_attributes(obj, include_methods=False, ignore=None):
  65. if ignore is None:
  66. ignore = []
  67. for attr in dir(obj):
  68. if attr in ignore:
  69. continue
  70. if attr.startswith('_'):
  71. continue
  72. if not include_methods and callable(obj.__getattr__(attr)):
  73. continue
  74. print(attr, ':', obj.__getattr__(attr).__class__.__name__, ':', obj.__getattr__(attr))
  75. def attr_dir(obj, include_methods=False, ignore=None):
  76. if ignore is None:
  77. ignore = []
  78. return {attr: obj.__getattr__(attr)
  79. for attr in dir(obj)
  80. if not attr.startswith('_') and (
  81. include_methods or not callable(obj.__getattr__(attr))) and attr not in ignore}
  82. def zoom_to_shape(a: np.ndarray, shape: Tuple, mode: str = 'smooth', verbose=1):
  83. from keras import backend
  84. a = np.array(a, dtype=backend.floatx()) # also does a copy
  85. shape_dim = len(a.shape)
  86. if len(a.shape) != len(shape):
  87. raise ValueError('The shapes must have the same dimension but were len({0}) = {1} (original) '
  88. 'and len({2}) = {3} desired.'.format(a.shape, len(a.shape), shape, len(shape)))
  89. if len(shape) == 0:
  90. return a
  91. zoom_factors = tuple(shape[idx] / a.shape[idx] for idx in range(shape_dim))
  92. def _current_index_in_old_array():
  93. return tuple(slice(0, length) if axis != current_axis else slice(current_pixel_index, current_pixel_index + 1)
  94. for axis, length in enumerate(a.shape))
  95. def _current_pixel_shape():
  96. return tuple(length if axis != current_axis else 1
  97. for axis, length in enumerate(a.shape))
  98. def _current_result_index():
  99. return tuple(
  100. slice(0, length) if axis != current_axis else slice(pixel_index_in_result, pixel_index_in_result + 1)
  101. for axis, length in enumerate(a.shape))
  102. def _current_result_shape():
  103. return tuple(orig_length if axis != current_axis else shape[axis]
  104. for axis, orig_length in enumerate(a.shape))
  105. if mode == 'constant':
  106. result = zoom(a, zoom_factors)
  107. assert result.shape == shape
  108. return result
  109. elif mode == 'smooth':
  110. result = a
  111. for current_axis, zoom_factor in sorted(enumerate(zoom_factors), key=lambda x: x[1]):
  112. result = np.zeros(_current_result_shape(), dtype=backend.floatx())
  113. # current_length = a.shape[current_axis]
  114. desired_length = shape[current_axis]
  115. current_pixel_index = 0
  116. current_pixel_part = 0 # how much of the current pixel is already read
  117. for pixel_index_in_result in range(desired_length):
  118. pixels_remaining = 1 / zoom_factor
  119. pixel_sum = np.zeros(_current_pixel_shape())
  120. while pixels_remaining + current_pixel_part > 1:
  121. pixel_sum += (1 - current_pixel_part) * a[_current_index_in_old_array()]
  122. current_pixel_index += 1
  123. pixels_remaining -= (1 - current_pixel_part)
  124. current_pixel_part = 0
  125. # the remaining pixel_part
  126. try:
  127. pixel_sum += pixels_remaining * a[_current_index_in_old_array()]
  128. except (IndexError, ValueError):
  129. if verbose:
  130. print('WARNING: Skipping {0} pixels because of numerical imprecision.'.format(pixels_remaining))
  131. else:
  132. current_pixel_part += pixels_remaining
  133. # insert to result
  134. pixel_sum *= zoom_factor
  135. result[_current_result_index()] = pixel_sum
  136. a = result
  137. assert result.shape == shape
  138. return result
  139. else:
  140. return NotImplementedError('Mode not available.')
  141. def profile_wall_time_instead_if_profiling():
  142. try:
  143. import yappi
  144. except ModuleNotFoundError:
  145. return
  146. currently_profiling = len(yappi.get_func_stats())
  147. if currently_profiling and yappi.get_clock_type() != 'wall':
  148. yappi.stop()
  149. print('Profiling wall time instead of cpu time.')
  150. yappi.clear_stats()
  151. yappi.set_clock_type("wall")
  152. yappi.start()
  153. def dummy_computation(*_args, **_kwargs):
  154. pass
  155. def backup_file(filename):
  156. copyfile(filename, backup_file_path(filename))
  157. def backup_file_path(filename):
  158. return filename + time.strftime("%Y%m%d") + '.bak'
  159. # noinspection SpellCheckingInspection
  160. def my_tabulate(data, tablefmt='pipe', **params):
  161. if data == [] and 'headers' in params:
  162. data = [(None for _ in params['headers'])]
  163. tabulate.MIN_PADDING = 0
  164. return tabulate.tabulate(data, tablefmt=tablefmt, **params)
  165. def ce_loss(y_true, y_predicted):
  166. return -(y_true * log(y_predicted) + (1 - y_true) * log(1 - y_predicted))
  167. class DontSaveResultsError(Exception):
  168. pass
  169. def multinomial(n, bins):
  170. if bins == 0:
  171. if n > 0:
  172. raise ValueError('Cannot distribute to 0 bins.')
  173. return []
  174. remaining = n
  175. results = []
  176. for i in range(bins - 1):
  177. from numpy.random.mtrand import binomial
  178. x = binomial(remaining, 1 / (bins - i))
  179. results.append(x)
  180. remaining -= x
  181. results.append(remaining)
  182. return results
  183. class UnknownTypeError(Exception):
  184. pass
  185. # def shape_analysis(xs):
  186. # composed_dtypes = [list, tuple, np.ndarray, dict, set]
  187. # base_dtypes = [str, int, float, type, object] # TODO add class and superclass of xs first element
  188. # all_dtypes = composed_dtypes + base_dtypes
  189. # if isinstance(xs, np.ndarray):
  190. # outer_brackets = ('[', ']')
  191. # shape = xs.shape
  192. # dtype = xs.dtype
  193. # elif isinstance(xs, tuple):
  194. # outer_brackets = ('(', ')')
  195. # shape = len(xs)
  196. # dtype = [t for t in all_dtypes if all(isinstance(x, t) for x in xs)][0]
  197. # elif isinstance(xs, list):
  198. # outer_brackets = ('[', ']')
  199. # shape = len(xs)
  200. # dtype = [t for t in all_dtypes if all(isinstance(x, t) for x in xs)][0]
  201. # elif isinstance(xs, dict) or isinstance(xs, set):
  202. # outer_brackets = ('{', '}')
  203. # shape = len(xs)
  204. # dtype = [t for t in all_dtypes if all(isinstance(x, t) for x in xs)][0]
  205. # elif any(isinstance(xs, t) for t in base_dtypes):
  206. # for t in base_dtypes:
  207. # if isinstance(xs, t):
  208. # return str(t.__name__)
  209. # raise AssertionError('This should be unreachable.')
  210. # else:
  211. # raise UnknownTypeError('Unknown type:' + type(xs).__name__)
  212. #
  213. # if shape and shape != '?':
  214. # return outer_brackets[0] + str(xs.shape) + ' * ' + str(dtype) + outer_brackets[1]
  215. # else:
  216. # return outer_brackets[0] + outer_brackets[1]
  217. def beta_conf_interval_mle(data, conf=0.95):
  218. if len(data) <= 1:
  219. return 0, 1 # overestimates the interval
  220. if any(d < 0 or d > 1 or isnan(d) for d in data):
  221. return nan, nan
  222. if numpy.var(data) == 0:
  223. return numpy.mean(data), numpy.mean(data)
  224. epsilon = 1e-3
  225. # adjusted_data = data.copy()
  226. # for idx in range(len(adjusted_data)):
  227. # adjusted_data[idx] *= (1 - 2 * epsilon)
  228. # adjusted_data[idx] += epsilon
  229. alpha, beta, _, _ = scipy.stats.beta.fit(data, floc=-epsilon, fscale=1 + 2 * epsilon)
  230. lower, upper = scipy.stats.beta.interval(alpha=conf, a=alpha, b=beta)
  231. if lower < 0:
  232. lower = 0
  233. if upper < 0:
  234. upper = 0
  235. if lower > 1:
  236. lower = 1
  237. if upper > 1:
  238. upper = 1
  239. return lower, upper
  240. def gamma_conf_interval_mle(data, conf=0.95) -> Tuple[float, float]:
  241. if len(data) == 0:
  242. return nan, nan
  243. if len(data) == 1:
  244. return nan, nan
  245. if any(d < 0 or isnan(d) for d in data):
  246. return nan, nan
  247. if numpy.var(data) == 0:
  248. return numpy.mean(data).item(), 0
  249. alpha, _, scale = scipy.stats.gamma.fit(data, floc=0)
  250. lower, upper = scipy.stats.gamma.interval(alpha=conf, a=alpha, scale=scale)
  251. if lower < 0:
  252. lower = 0
  253. if upper < 0:
  254. upper = 0
  255. return lower, upper
  256. beta_quantile_cache = cachetools.LRUCache(maxsize=10)
  257. @cachetools.cached(cache=beta_quantile_cache, key=lambda x1, p1, x2, p2, guess: (x1, x2, p1, p2))
  258. def beta_parameters_quantiles(x1, p1, x2, p2, guess=(3, 3)):
  259. "Find parameters for a beta random variable X; so; that; P(X > x1) = p1 and P(X > x2) = p2.; "
  260. def square(x):
  261. return x * x
  262. def objective(v):
  263. (a, b) = v
  264. temp = square(scipy.stats.beta.cdf(x1, a, b) - p1)
  265. temp += square(scipy.stats.beta.cdf(x2, a, b) - p2)
  266. return temp
  267. xopt = scipy.optimize.fmin(objective, guess, disp=False)
  268. return (xopt[0], xopt[1])
  269. def beta_conf_interval_quantile(data, conf=0.95, quantiles=(0.25, 0.75)):
  270. if len(data) <= 1:
  271. return 0, 1 # overestimates the interval
  272. mu = numpy.mean(data)
  273. v = numpy.var(data)
  274. data = numpy.array(data)
  275. if v == 0:
  276. return mu, mu
  277. lower = numpy.quantile(data, quantiles[0])
  278. upper = numpy.quantile(data, quantiles[1])
  279. alpha_guess = mu ** 2 * ((1 - mu) / v - 1 / mu)
  280. beta_guess = alpha_guess * (1 / mu - 1)
  281. alpha, beta = beta_parameters_quantiles(lower, quantiles[0], upper, quantiles[1], (alpha_guess, beta_guess))
  282. return scipy.stats.beta.interval(alpha=conf, a=alpha, b=beta)
  283. def beta_stats_quantile(data, quantiles=(0.25, 0.75)):
  284. if len(data) <= 1:
  285. return 0, 1 # overestimates the interval
  286. data = numpy.array(data)
  287. mu = numpy.mean(data)
  288. v = numpy.var(data)
  289. if v == 0:
  290. return mu, mu
  291. lower = numpy.quantile(data, quantiles[0])
  292. upper = numpy.quantile(data, quantiles[1])
  293. alpha_guess = mu ** 2 * ((1 - mu) / v - 1 / mu)
  294. beta_guess = alpha_guess * (1 / mu - 1)
  295. alpha, beta = beta_parameters_quantiles(lower, quantiles[0], upper, quantiles[1], (alpha_guess, beta_guess))
  296. return scipy.stats.beta.stats(a=alpha, b=beta)
  297. def beta_stats_mle(data):
  298. if len(data) == 0:
  299. return nan, nan
  300. if len(data) == 1:
  301. return nan, nan
  302. if any(d < 0 or d > 1 or isnan(d) for d in data):
  303. return nan, nan
  304. if numpy.var(data) == 0:
  305. return numpy.mean(data), 0
  306. epsilon = 1e-4
  307. # adjusted_data = data.copy()
  308. # for idx in range(len(adjusted_data)):
  309. # adjusted_data[idx] *= (1 - 2 * epsilon)
  310. # adjusted_data[idx] += epsilon
  311. alpha, beta, _, _ = scipy.stats.beta.fit(data, floc=-epsilon, fscale=1 + 2 * epsilon)
  312. return scipy.stats.beta.stats(a=alpha, b=beta)
  313. def gamma_stats_mle(data):
  314. if len(data) == 0:
  315. return nan, nan
  316. if len(data) == 1:
  317. return nan, nan
  318. if any(d < 0 or isnan(d) for d in data):
  319. return nan, nan
  320. if numpy.var(data) == 0:
  321. return numpy.mean(data), 0
  322. alpha, _, scale = scipy.stats.gamma.fit(data, floc=0)
  323. return scipy.stats.gamma.stats(a=alpha, scale=scale)
# Default implementations: the beta helpers use the quantile-matching fit,
# the gamma helpers use the maximum-likelihood fit.
beta_stats = beta_stats_quantile
beta_conf_interval = beta_conf_interval_quantile
gamma_stats = gamma_stats_mle
gamma_conf_interval = gamma_conf_interval_mle
  328. def split_df_list(df, target_column):
  329. """
  330. df = data frame to split,
  331. target_column = the column containing the values to split
  332. separator = the symbol used to perform the split
  333. returns: a data frame with each entry for the target column separated, with each element moved into a new row.
  334. The values in the other columns are duplicated across the newly divided rows.
  335. SOURCE: https://gist.github.com/jlln/338b4b0b55bd6984f883
  336. """
  337. def split_list_to_rows(row, row_accumulator):
  338. split_row = json.loads(row[target_column])
  339. for s in split_row:
  340. new_row = row.to_dict()
  341. new_row[target_column] = s
  342. row_accumulator.append(new_row)
  343. new_rows = []
  344. df.apply(split_list_to_rows, axis=1, args=(new_rows,))
  345. new_df = pandas.DataFrame(new_rows)
  346. return new_df
  347. try:
  348. import winsound as win_sound
  349. def beep(*args, **kwargs):
  350. win_sound.Beep(*args, **kwargs)
  351. except ImportError:
  352. win_sound = None
  353. def beep(*_args, **_kwargs):
  354. pass
  355. def round_to_digits(x, d):
  356. if x == 0:
  357. return 0
  358. if isnan(x):
  359. return nan
  360. try:
  361. return round(x, d - 1 - int(floor(log10(abs(x)))))
  362. except OverflowError:
  363. return x
  364. def gc_if_memory_error(f, *args, **kwargs):
  365. try:
  366. return f(*args, **kwargs)
  367. except MemoryError:
  368. print('Starting garbage collector')
  369. gc.collect()
  370. return f(*args, **kwargs)
  371. def assert_not_empty(x):
  372. assert len(x)
  373. return x
  374. def validation_steps(validation_dataset_size, maximum_batch_size):
  375. batch_size = gcd(validation_dataset_size, maximum_batch_size)
  376. steps = validation_dataset_size // batch_size
  377. assert batch_size * steps == validation_dataset_size
  378. return batch_size, steps
  379. def functional_dependency_trigger(connection: sqlite3.Connection,
  380. table_name: str,
  381. determining_columns: List[str],
  382. determined_columns: List[str],
  383. exist_ok: bool, ):
  384. cursor = connection.cursor()
  385. # possible_performance_improvements
  386. determined_columns = [c for c in determined_columns if c not in determining_columns]
  387. trigger_base_name = '_'.join([table_name] + determining_columns + ['determine'] + determined_columns)
  388. error_message = ','.join(determining_columns) + ' must uniquely identify ' + ','.join(determined_columns)
  389. # when inserting check if there is already an entry with these values
  390. cursor.execute(f'''
  391. CREATE TRIGGER {'IF NOT EXISTS' if exist_ok else ''} {trigger_base_name}_after_insert
  392. BEFORE INSERT ON {table_name}
  393. WHEN EXISTS(SELECT * FROM {table_name}
  394. WHERE ({' AND '.join(f'NEW.{c} IS NOT NULL AND {c} = NEW.{c}' for c in determining_columns)})
  395. AND ({' OR '.join(f'{c} != NEW.{c}' for c in determined_columns)}))
  396. BEGIN SELECT RAISE(ROLLBACK, '{error_message}'); END
  397. ''')
  398. # when updating check if there is already an entry with these values (only if changed)
  399. cursor.execute(f'''
  400. CREATE TRIGGER {'IF NOT EXISTS' if exist_ok else ''} {trigger_base_name}_after_update
  401. BEFORE UPDATE ON {table_name}
  402. WHEN EXISTS(SELECT * FROM {table_name}
  403. WHERE ({' AND '.join(f'NEW.{c} IS NOT NULL AND {c} = NEW.{c}' for c in determining_columns)})
  404. AND ({' OR '.join(f'{c} != NEW.{c}' for c in determined_columns)}))
  405. BEGIN SELECT RAISE(ROLLBACK, '{error_message}'); END
  406. ''')
  407. def heatmap_from_points(x, y,
  408. x_lim: Optional[Union[int, Tuple[int, int]]] = None,
  409. y_lim: Optional[Union[int, Tuple[int, int]]] = None,
  410. gridsize=30):
  411. if isinstance(x_lim, Number):
  412. x_lim = (x_lim, x_lim)
  413. if isinstance(y_lim, Number):
  414. y_lim = (y_lim, y_lim)
  415. plt.hexbin(x, y, gridsize=gridsize, cmap=matplotlib.cm.jet, bins=None)
  416. if x_lim is not None:
  417. plt.xlim(x_lim)
  418. if y_lim is not None:
  419. plt.ylim(y_lim)
  420. cb = plt.colorbar()
  421. cb.set_label('mean value')
  422. def strptime(date_string, fmt):
  423. return datetime.datetime(*(time.strptime(date_string, fmt)[0:6]))
  424. class PrintLineRLock(RLock().__class__):
  425. def __init__(self, *args, name='', **kwargs):
  426. # noinspection PyArgumentList
  427. super().__init__(*args, **kwargs)
  428. self.name = name
  429. def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
  430. print(f'Trying to acquire Lock {self.name}')
  431. result = RLock.acquire(self, blocking, timeout)
  432. print(f'Acquired Lock {self.name}')
  433. return result
  434. def release(self) -> None:
  435. print(f'Trying to release Lock {self.name}')
  436. # noinspection PyNoneFunctionAssignment
  437. result = RLock.release(self)
  438. print(f'Released Lock {self.name}')
  439. return result
  440. def __enter__(self, *args, **kwargs):
  441. print('Trying to enter Lock')
  442. # noinspection PyArgumentList
  443. super().__enter__(*args, **kwargs)
  444. print('Entered Lock')
  445. def __exit__(self, *args, **kwargs):
  446. print('Trying to exit Lock')
  447. super().__exit__(*args, **kwargs)
  448. print('Exited Lock')
  449. def fixed_get_current_frames():
  450. """Return current threads prepared for
  451. further processing.
  452. """
  453. threads = {thread.ident: thread for thread in threading.enumerate()}
  454. return {
  455. thread_id: {
  456. 'frame': hanging_threads.thread2list(frame),
  457. 'time': None,
  458. 'id': thread_id,
  459. 'name': threads[thread_id].name,
  460. 'object': threads[thread_id]
  461. } for thread_id, frame in sys._current_frames().items()
  462. if thread_id in threads # otherwise keyerrors might happen because of race conditions
  463. }
  464. hanging_threads.get_current_frames = fixed_get_current_frames
  465. class CallCounter():
  466. def __init__(self, f):
  467. self.f = f
  468. self.calls = 0
  469. self.__name__ = f.__name__
  470. def __call__(self, *args, **kwargs):
  471. self.calls += 1
  472. return self.f(*args, **kwargs)
  473. def __str__(self):
  474. return str(self.__dict__)
  475. def __repr__(self):
  476. return self.__class__.__name__ + repr(self.__dict__)
  477. def test_with_timeout(timeout=2):
  478. def wrapper(f):
  479. from lib.threading_timer_decorator import exit_after
  480. f = exit_after(timeout)(f)
  481. @functools.wraps(f)
  482. def wrapped(*args, **kwargs):
  483. try:
  484. print(f'Running this test with timeout: {timeout}')
  485. return f(*args, **kwargs)
  486. except KeyboardInterrupt:
  487. raise AssertionError(f'Test took longer than {timeout} seconds')
  488. return wrapped
  489. return wrapper
  490. def lru_cache_by_id(maxsize):
  491. return cachetools.cached(cachetools.LRUCache(maxsize=maxsize), key=id)
  492. class EquivalenceRelation:
  493. def equivalent(self, a, b) -> bool:
  494. raise NotImplementedError('Abstract method')
  495. def equivalence_classes(self, xs: list):
  496. classes = []
  497. for x in xs:
  498. for c in classes:
  499. if self.equivalent(x, c[0]):
  500. c.append(x)
  501. break
  502. else:
  503. classes.append([x])
  504. return classes
  505. def check_reflexivity_on_dataset(self, xs):
  506. for x in xs:
  507. if not self.equivalent(x, x):
  508. return False
  509. return True
  510. def check_symmetry_on_dataset(self, xs):
  511. for x in xs:
  512. for y in xs:
  513. if x is y:
  514. continue
  515. if self.equivalent(x, y) and not self.equivalent(y, x):
  516. return False
  517. return True
  518. def check_axioms_on_dataset(self, xs):
  519. return (
  520. self.check_reflexivity_on_dataset(xs)
  521. and self.check_symmetry_on_dataset(xs)
  522. and self.check_transitivity_on_dataset(xs, assume_symmetry=True, assume_reflexivity=True)
  523. )
  524. def check_transitivity_on_dataset(self, xs, assume_symmetry=False, assume_reflexivity=False):
  525. for x_idx, x in enumerate(xs):
  526. for y_idx, y in enumerate(xs):
  527. if x is y:
  528. continue
  529. if self.equivalent(x, y):
  530. for z_idx, z in enumerate(xs):
  531. if y is z:
  532. continue
  533. if assume_symmetry and x_idx > z_idx:
  534. continue
  535. if assume_reflexivity and x is z:
  536. continue
  537. if self.equivalent(y, z):
  538. if not self.equivalent(x, z):
  539. return False
  540. return True
  541. def match_lists(self, xs, ys, filter_minimum_size=0, filter_maximum_size=math.inf):
  542. xs = list(xs)
  543. ys = list(ys)
  544. if any(x is y for x in xs for y in ys):
  545. raise ValueError('Lists contain the same element. This is currently not supported.')
  546. classes = self.equivalence_classes([*xs, *ys])
  547. return [
  548. [
  549. (0 if any(x2 is x for x2 in xs) else 1, x)
  550. for x in c
  551. ]
  552. for c in classes[::-1]
  553. if filter_minimum_size <= len(c) <= filter_maximum_size
  554. ]
  555. def iff_patch(patch: mock._patch):
  556. def decorator(f):
  557. def wrapped(*args, **kwargs):
  558. with patch:
  559. f(*args, **kwargs)
  560. try:
  561. f(*args, **kwargs)
  562. except:
  563. pass
  564. else:
  565. raise AssertionError('Test did not fail without patch')
  566. return wrapped
  567. return decorator
  568. def iff_not_patch(patch: mock._patch):
  569. def decorator(f):
  570. def wrapped(*args, **kwargs):
  571. f(*args, **kwargs)
  572. try:
  573. with patch:
  574. f(*args, **kwargs)
  575. except Exception as e:
  576. pass
  577. else:
  578. raise AssertionError('Test did not fail with patch')
  579. return wrapped
  580. return decorator
# Recipients that main_wrapper sends a crash report e-mail to when the wrapped function raises.
EMAIL_CRASHES_TO = []
# (to_number, from_number) pairs that main_wrapper passes to voice_call when the wrapped function raises.
VOICE_CALL_ON_CRASH: List[Tuple[str, str]] = []
  583. def list_logger(base_logging_function, store_in_list: list):
  584. def print_and_store(*args, **kwargs):
  585. base_logging_function(*args, **kwargs)
  586. store_in_list.extend(args)
  587. return print_and_store
def main_wrapper(f):
    """Decorator for the entry point of a script: adds crash reporting and cleanup around ``f``.

    Before ``f`` runs, the directory ``logs`` is created, stack tracing and the faulthandler are
    started and, if yappi is profiling, the profiler is switched to wall time. Exceptions raised by
    ``f`` are written to the log and to a ``.dill`` file and are NOT re-raised, so the wrapper
    returns None in that case. For exceptions other than KeyboardInterrupt the addresses in
    EMAIL_CRASHES_TO and the numbers in VOICE_CALL_ON_CRASH are notified as well. In every case the
    tracing is stopped, a beep is played and the total run time is printed.

    NOTE(review): ``stack_tracer``, ``print_exc_plus`` and ``logging`` are not among the imports
    visible at the top of this file - confirm that they are imported elsewhere.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        # import lib.stack_tracer
        import __main__
        # does not help much
        # monitoring_thread = hanging_threads.start_monitoring(seconds_frozen=180, test_interval=1000)
        os.makedirs('logs', exist_ok=True)
        # the trace file is named after the script that was started
        stack_tracer.trace_start('logs/' + os.path.split(__main__.__file__)[-1] + '.html', interval=5)
        faulthandler.enable()
        profile_wall_time_instead_if_profiling()
        # noinspection PyBroadException
        try:
            return f(*args, **kwargs)
        except KeyboardInterrupt:
            # a manual interruption is logged but nobody is notified
            error_messages = []
            print_exc_plus.print_exc_plus(print=list_logger(logging.error, error_messages),
                                          serialize_to='logs/' + os.path.split(__main__.__file__)[-1] + '.dill')
        except:
            # everything else counts as a crash: log it, then notify by e-mail and voice call
            error_messages = []
            print_exc_plus.print_exc_plus(print=list_logger(logging.error, error_messages),
                                          serialize_to='logs/' + os.path.split(__main__.__file__)[-1] + '.dill')
            for recipient in EMAIL_CRASHES_TO:
                from jobs.sending_emails import send_mail
                send_mail.create_simple_mail_via_gmail(body='\n'.join(error_messages), filepath=None, excel_name=None, to_mail=recipient, subject='[python] Crash report')
            for to_number, from_number in VOICE_CALL_ON_CRASH:
                # NOTE(review): the message names from_number, the call below goes to to_number
                logging.info(f'Calling {from_number} to notify about the crash.')
                voice_call('This is a notification message that one of your python scripts has crashed. If you are unsure about the origin of this call, please contact Eren Yilmaz.',
                           to_number, from_number)
        finally:
            # cleanup that runs after success, interruption and crash alike
            logging.info('Terminated.')
            total_time = time.perf_counter() - start
            faulthandler.disable()
            stack_tracer.trace_stop()
            frequency = 2000
            duration = 500
            beep(frequency, duration)
            print('Total time', total_time)
            try:
                # optional project module; it is skipped if it cannot be imported
                from algorithm_development.metatrader import ZeroMQ_Connector
                ZeroMQ_Connector.DWX_ZeroMQ_Connector.deactivate_all()
            except ImportError:
                pass

    return wrapper
  633. def voice_call(msg, to_number, from_number):
  634. from twilio.rest import Client
  635. account_sid = 'AC63c459168c3e4fe34e462acb4f44f748'
  636. auth_token = 'b633bc0e945fe7cb737fdac395cc71d6'
  637. client = Client(account_sid, auth_token)
  638. call = client.calls.create(
  639. twiml=f'<Response><Say>{msg}</Say></Response>',
  640. from_=from_number,
  641. to=to_number,
  642. )
  643. print(call.sid)
  644. def required_size_for_safe_rotation(base: Tuple[X, Y, Z], rotate_range_deg) -> Tuple[X, Y, Z]:
  645. if abs(rotate_range_deg) > 45:
  646. raise NotImplementedError
  647. if abs(rotate_range_deg) > 0:
  648. x_length = base[2] * math.sin(rotate_range_deg / 180 * math.pi) + base[1] * math.cos(
  649. rotate_range_deg / 180 * math.pi)
  650. y_length = base[2] * math.cos(rotate_range_deg / 180 * math.pi) + base[1] * math.sin(
  651. rotate_range_deg / 180 * math.pi)
  652. result = (base[0],
  653. x_length,
  654. y_length,)
  655. else:
  656. result = base
  657. return result
  658. def round_to_closest_value(x, values, assume_sorted=False):
  659. if not assume_sorted:
  660. values = sorted(values)
  661. next_largest = bisect_left(values, x) # binary search
  662. if next_largest == 0:
  663. return values[0]
  664. if next_largest == len(values):
  665. return values[-1]
  666. next_smallest = next_largest - 1
  667. smaller = values[next_smallest]
  668. larger = values[next_largest]
  669. if abs(smaller - x) < abs(larger - x):
  670. return smaller
  671. else:
  672. return larger
  673. def binary_search(a, x, lo=0, hi=None):
  674. hi = hi if hi is not None else len(a) # hi defaults to len(a)
  675. pos = bisect_left(a, x, lo, hi) # find insertion position
  676. return pos if pos != hi and a[pos] == x else -1 # don't walk off the end
  677. def ceil_to_closest_value(x, values):
  678. values = sorted(values)
  679. next_largest = bisect_left(values, x) # binary search
  680. if next_largest < len(values):
  681. return values[next_largest]
  682. else:
  683. return values[-1] # if there is no larger value use the largest one
  684. def print_progress_bar(iteration, total, prefix='Progress:', suffix='', decimals=1, length=50, fill='█',
  685. print_eta=True):
  686. """
  687. Call in a loop to create terminal progress bar
  688. @params:
  689. iteration - Required : current iteration (Int)
  690. total - Required : total iterations (Int)
  691. prefix - Optional : prefix string (Str)
  692. suffix - Optional : suffix string (Str)
  693. decimals - Optional : positive number of decimals in percent complete (Int)
  694. length - Optional : character length of bar (Int)
  695. fill - Optional : bar fill character (Str)
  696. """
  697. percent = ("{0:" + str(4 + decimals) + "." + str(decimals) + "f}").format(100 * (iteration / float(total)))
  698. filled_length = int(length * iteration // total)
  699. bar = fill * filled_length + '-' * (length - filled_length)
  700. if getattr(print_progress_bar, 'last_printed_value', None) == (prefix, bar, percent, suffix):
  701. return
  702. print_progress_bar.last_printed_value = (prefix, bar, percent, suffix)
  703. print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end='')
  704. # Print New Line on Complete
  705. if iteration == total:
  706. print()
  707. def get_all_subclasses(klass):
  708. all_subclasses = []
  709. for subclass in klass.__subclasses__():
  710. all_subclasses.append(subclass)
  711. all_subclasses.extend(get_all_subclasses(subclass))
  712. return all_subclasses
  713. def my_mac_address():
  714. """
  715. https://stackoverflow.com/a/160821
  716. """
  717. import uuid
  718. mac = uuid.getnode()
  719. if (mac >> 40) % 2:
  720. return None
  721. mac = uuid.UUID(int=mac).hex[-12:]
  722. return mac
  723. def latin1_json(data):
  724. return json.dumps(data, ensure_ascii=False).encode('latin-1')
  725. def l2_norm(v1, v2):
  726. if len(v1) != len(v2):
  727. raise ValueError('Both vectors must be of the same size')
  728. return math.sqrt(sum([(x1 - x2) * (x1 - x2) for x1, x2 in zip(v1, v2)]))
  729. def allow_additional_unused_keyword_arguments(func):
  730. @functools.wraps(func)
  731. def wrapper(*args, **kwargs):
  732. allowed_kwargs = [param.name for param in inspect.signature(func).parameters.values()]
  733. allowed_kwargs = {a: kwargs[a] for a in kwargs if a in allowed_kwargs}
  734. return func(*args, **allowed_kwargs)
  735. return wrapper
  736. def copy_and_rename_method(func, new_name):
  737. funcdetails = [
  738. func.__code__,
  739. func.__globals__,
  740. func.__name__,
  741. func.__defaults__,
  742. func.__closure__
  743. ]
  744. old_name = func.__name__
  745. # copy
  746. # new_func = dill.loads(dill.dumps(func))
  747. new_func = FunctionType(*funcdetails)
  748. assert new_func is not funcdetails
  749. # rename
  750. new_func.__name__ = new_name
  751. assert func.__name__ is old_name
  752. return new_func
  753. def rename(new_name):
  754. def decorator(f):
  755. f.__name__ = new_name
  756. return f
  757. return decorator
  758. class LogicError(Exception):
  759. pass
  760. def round_time(dt=None, precision=60):
  761. """Round a datetime object to any time lapse in seconds
  762. dt : datetime.datetime object, default now.
  763. roundTo : Closest number of seconds to round to, default 1 minute.
  764. Author: Thierry Husson 2012 - Use it as you want but don't blame me.
  765. """
  766. if dt is None:
  767. dt = datetime.datetime.now()
  768. if isinstance(precision, datetime.timedelta):
  769. precision = precision.total_seconds()
  770. seconds = (dt.replace(tzinfo=None) - dt.min).seconds
  771. rounding = (seconds + precision / 2) // precision * precision
  772. return dt + datetime.timedelta(seconds=rounding - seconds,
  773. microseconds=dt.microsecond)
  774. def chunks(lst, n):
  775. """Yield successive n-sized chunks from lst."""
  776. for i in range(0, len(lst), n):
  777. yield lst[i:i + n]
  778. def shorten_name(name):
  779. name = re.sub(r'\s+', r' ', str(name))
  780. name = name.replace(', ', ',')
  781. name = name.replace(', ', ',')
  782. name = name.replace(' ', '_')
  783. return re.sub(r'([A-Za-z])[a-z]*_?', r'\1', str(name))
  784. def array_analysis(a: numpy.ndarray):
  785. print(f' Shape: {a.shape}')
  786. mean = a.mean()
  787. print(f' Mean: {mean}')
  788. print(f' Std: {a.std()}')
  789. print(f' Min, Max: {a.min(), a.max()}')
  790. print(f' Mean absolute: {numpy.abs(a).mean()}')
  791. print(f' Mean square: {numpy.square(a).mean()}')
  792. print(f' Mean absolute difference from mean: {numpy.abs(a - mean).mean()}')
  793. print(f' Mean squared difference from mean: {numpy.square(a - mean).mean()}')
  794. nonzero = numpy.count_nonzero(a)
  795. print(f' Number of non-zeros: {nonzero}')
  796. print(f' Number of zeros: {numpy.prod(a.shape) - nonzero}')
  797. if a.shape[-1] > 1 and a.shape[-1] <= 1000:
  798. # last dim is probably the number of classes
  799. print(f' Class counts: {numpy.count_nonzero(a, axis=tuple(range(len(a.shape) - 1)))}')
  800. def current_year_begin():
  801. return datetime.datetime(datetime.datetime.today().year, 1, 1).timestamp()
  802. def current_day_begin():
  803. return datetime.datetime.today().timestamp() // (3600 * 24) * (3600 * 24)
  804. def current_second_begin():
  805. return floor(datetime.datetime.today().timestamp())
  806. def running_workers(executor):
  807. print(next(iter(executor._threads)).__dict__)
  808. return sum(1 for t in executor._threads
  809. if t == 1)
  810. def queued_calls(executor):
  811. return len(executor._work_queue.queue)
  812. def retry_on_error(max_tries=3, delay=0.5, backoff=2, only_error_classes=Exception):
  813. def decorator(func):
  814. @functools.wraps(func)
  815. def wrapper(*args, **kwargs):
  816. for i in range(max_tries):
  817. try:
  818. return func(*args, **kwargs)
  819. except only_error_classes as e:
  820. if i == max_tries - 1:
  821. raise
  822. logging.error(f'Re-try after error in {func.__name__}: {type(e).__name__}, {e}')
  823. time.sleep(delay * (backoff ** i))
  824. return wrapper
  825. return decorator
  826. class EBC:
  827. SUBCLASSES_BY_NAME: Dict[str, Type['EBC']] = {}
  828. def __init_subclass__(cls, **kwargs):
  829. super().__init_subclass__(**kwargs)
  830. EBC.SUBCLASSES_BY_NAME[cls.__name__] = cls
  831. def __eq__(self, other):
  832. return isinstance(other, type(self)) and self.__dict__ == other.__dict__
  833. def __str__(self):
  834. return str(self.__dict__)
  835. def __repr__(self):
  836. return f'{type(self).__name__}(**' + str(self.__dict__) + ')'
  837. def to_json(self) -> Dict[str, Any]:
  838. result: Dict[str, Any] = {
  839. 'type': type(self).__name__,
  840. **self.__dict__,
  841. }
  842. for k in result:
  843. if isinstance(result[k], EBC):
  844. result[k] = result[k].to_json()
  845. elif isinstance(result[k], numpy.ndarray):
  846. result[k] = result[k].tolist()
  847. elif isinstance(result[k], list):
  848. result[k] = [r.to_json() if isinstance(r, EBC) else r
  849. for r in result[k]]
  850. return result
  851. @staticmethod
  852. def from_json(data: Dict[str, Any]):
  853. cls = EBC.SUBCLASSES_BY_NAME[data['type']]
  854. return class_from_json(cls, data)
def class_from_json(cls, data: Dict[str, Any]):
    """
    Build an instance of ``cls`` from ``data``, which is a dict or a JSON string encoding one.

    First ``cls(**data)`` is tried as is. If that fails with a TypeError whose message says
    that the ``'type'`` keyword is unexpected or that the constructor takes no arguments,
    ``data`` is treated as the output of a ``to_json`` method: the ``'type'`` entry is removed,
    nested payloads recognised by ``probably_serialized_from_ebc`` (directly or inside a list)
    are rebuilt through the registered subclass, and the constructor is called through
    ``allow_additional_unused_keyword_arguments``. Any other TypeError is re-raised.
    """
    if isinstance(data, str):
        data = json.loads(data)
    # noinspection PyArgumentList
    try:
        return cls(**data)
    except TypeError as e:
        # NOTE(review): this branch is selected by matching the interpreter's error message text
        if "__init__() got an unexpected keyword argument 'type'" in str(e) or 'takes no arguments' in str(e):
            # probably this was from a to_json method
            # (data['type'] is read unconditionally here, so a dict without it raises KeyError)
            if data['type'] != cls.__name__:
                t = data['type']
                logging.warning(f'Reconstructing a {cls.__name__} from a dict with type={t}')
            # work on a copy so the caller's dict keeps its 'type' entry
            data = data.copy()
            del data['type']
            # only values of existing keys are replaced, so the dict size stays constant while iterating
            for k,v in data.items():
                if probably_serialized_from_ebc(v):
                    data[k] = EBC.SUBCLASSES_BY_NAME[v['type']].from_json(v)
                elif isinstance(v, list):
                    # rebuild serialized objects inside the list, keep every other item as is
                    data[k] = [EBC.SUBCLASSES_BY_NAME[x['type']].from_json(x)
                               if probably_serialized_from_ebc(x)
                               else x
                               for x in v]
            return allow_additional_unused_keyword_arguments(cls)(**data)
        else:
            raise
  880. def probably_serialized_from_ebc(data):
  881. return isinstance(data, dict) and 'type' in data and data['type'] in EBC.SUBCLASSES_BY_NAME
  882. class EBE(Enum):
  883. def __int__(self):
  884. return self.value
  885. def __str__(self):
  886. return self.name
  887. def __repr__(self):
  888. return self.name
  889. @classmethod
  890. def from_name(cls, variable_name):
  891. return cls.__dict__[variable_name]
  892. class Bunch(dict, EBC):
  893. def __init__(self, **kwargs):
  894. dict.__init__(self, kwargs)
  895. self.__dict__.update(kwargs)
  896. def add_method(self, m):
  897. setattr(self, m.__name__, functools.partial(m, self))
  898. def floor_to_multiple_of(x, multiple_of):
  899. return math.floor(x / multiple_of) * multiple_of
  900. def round_to_multiple_of(x, multiple_of):
  901. return round(x / multiple_of) * multiple_of
  902. def ceil_to_multiple_of(x, multiple_of):
  903. return math.ceil(x / multiple_of) * multiple_of