import datetime
import faulthandler
import functools
import gc
import inspect
import json
import math
import os
import random
import re
import sqlite3
import sys
import threading
import time
import typing
from bisect import bisect_left
from enum import Enum
from itertools import chain, combinations
from math import log, isnan, nan, floor, log10, gcd
from numbers import Number
from shutil import copyfile
from subprocess import check_output, CalledProcessError, PIPE
from threading import RLock
from types import FunctionType
from typing import Union, Tuple, List, Optional, Dict, Any, Type, Callable, Literal

# noinspection PyUnresolvedReferences
from unittest import TestCase, mock

import cachetools
import hanging_threads
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy
import numpy as np
import pandas
import scipy.optimize
import scipy.stats
import tabulate
from scipy.ndimage import zoom

from tool_lib import stack_tracer, print_exc_plus
from tool_lib.my_logger import logging

X = Y = Z = float


class KnownIssue(Exception):
    """
    This means the code is not working and should not be used, but it is still too valuable to be deleted.
    """
    pass

def powerset(iterable, largest_first=False):
    """powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"""
    s = list(iterable)
    sizes = range(len(s) + 1)
    if largest_first:
        sizes = reversed(sizes)
    return chain.from_iterable(combinations(s, r) for r in sizes)

def plot_with_conf(x, y_mean, y_conf, alpha=0.5, **kwargs):
    ax = kwargs.pop('ax', plt.gca())
    base_line, = ax.plot(x, y_mean, **kwargs)
    y_mean = np.array(y_mean)
    y_conf = np.array(y_conf)
    lb = y_mean - y_conf
    ub = y_mean + y_conf
    ax.fill_between(x, lb, ub, facecolor=base_line.get_color(), alpha=alpha)

def choice(sequence, probabilities):
    # if sum(probabilities) != 1:
    #     raise AssertionError('Probabilities must sum to 1')
    r = random.random()
    for idx, c in enumerate(sequence):
        r -= probabilities[idx]
        if r < 0:
            return c
    raise AssertionError('Probabilities must sum to 1')

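# Example (illustrative): choice('abc', [0.2, 0.5, 0.3]) returns 'a' with
# probability 0.2, 'b' with probability 0.5 and 'c' with probability 0.3.
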
def local_timezone():
    return datetime.datetime.now(datetime.timezone(datetime.timedelta(0))).astimezone().tzinfo

def print_attributes(obj, include_methods=False, ignore=None):
    if ignore is None:
        ignore = []
    for attr in dir(obj):
        if attr in ignore:
            continue
        if attr.startswith('_'):
            continue
        value = getattr(obj, attr)  # getattr works for all objects, unlike obj.__getattr__
        if not include_methods and callable(value):
            continue
        print(attr, ':', value.__class__.__name__, ':', value)


def attr_dir(obj, include_methods=False, ignore=None):
    if ignore is None:
        ignore = []
    return {attr: getattr(obj, attr)
            for attr in dir(obj)
            if not attr.startswith('_') and (
                    include_methods or not callable(getattr(obj, attr))) and attr not in ignore}

def zoom_to_shape(a: np.ndarray, shape: Tuple, mode: str = 'smooth', verbose=1):
    from keras import backend
    a = np.array(a, dtype=backend.floatx())  # also does a copy
    shape_dim = len(a.shape)
    if len(a.shape) != len(shape):
        raise ValueError('The shapes must have the same dimension but were len({0}) = {1} (original) '
                         'and len({2}) = {3} (desired).'.format(a.shape, len(a.shape), shape, len(shape)))
    if len(shape) == 0:
        return a
    zoom_factors = tuple(shape[idx] / a.shape[idx] for idx in range(shape_dim))

    def _current_index_in_old_array():
        return tuple(slice(0, length) if axis != current_axis else slice(current_pixel_index, current_pixel_index + 1)
                     for axis, length in enumerate(a.shape))

    def _current_pixel_shape():
        return tuple(length if axis != current_axis else 1
                     for axis, length in enumerate(a.shape))

    def _current_result_index():
        return tuple(
            slice(0, length) if axis != current_axis else slice(pixel_index_in_result, pixel_index_in_result + 1)
            for axis, length in enumerate(a.shape))

    def _current_result_shape():
        return tuple(orig_length if axis != current_axis else shape[axis]
                     for axis, orig_length in enumerate(a.shape))

    if mode == 'constant':
        result = zoom(a, zoom_factors)
        assert result.shape == shape
        return result
    elif mode == 'smooth':
        result = a
        for current_axis, zoom_factor in sorted(enumerate(zoom_factors), key=lambda x: x[1]):
            result = np.zeros(_current_result_shape(), dtype=backend.floatx())
            # current_length = a.shape[current_axis]
            desired_length = shape[current_axis]
            current_pixel_index = 0
            current_pixel_part = 0  # how much of the current pixel is already read
            for pixel_index_in_result in range(desired_length):
                pixels_remaining = 1 / zoom_factor
                pixel_sum = np.zeros(_current_pixel_shape())
                while pixels_remaining + current_pixel_part > 1:
                    pixel_sum += (1 - current_pixel_part) * a[_current_index_in_old_array()]
                    current_pixel_index += 1
                    pixels_remaining -= (1 - current_pixel_part)
                    current_pixel_part = 0
                # the remaining pixel_part
                try:
                    pixel_sum += pixels_remaining * a[_current_index_in_old_array()]
                except (IndexError, ValueError):
                    if verbose:
                        print('WARNING: Skipping {0} pixels because of numerical imprecision.'
                              .format(pixels_remaining))
                else:
                    current_pixel_part += pixels_remaining
                # insert into result
                pixel_sum *= zoom_factor
                result[_current_result_index()] = pixel_sum
            a = result
        assert result.shape == shape
        return result
    else:
        raise NotImplementedError('Mode not available.')

def dummy_computation(*_args, **_kwargs):
    pass


def backup_file(filename):
    copyfile(filename, backup_file_path(filename))


def backup_file_path(filename):
    return filename + time.strftime("%Y%m%d") + '.bak'

# noinspection SpellCheckingInspection
def my_tabulate(data, tablefmt='pipe', **params):
    if data == [] and 'headers' in params:
        data = [[None] * len(params['headers'])]  # one empty row, so that the headers are still shown
    tabulate.MIN_PADDING = 0
    return tabulate.tabulate(data, tablefmt=tablefmt, **params)

def ce_loss(y_true, y_predicted):
    return -(y_true * log(y_predicted) + (1 - y_true) * log(1 - y_predicted))


class DontSaveResultsError(Exception):
    pass

def multinomial(n, bins):
    if bins == 0:
        if n > 0:
            raise ValueError('Cannot distribute to 0 bins.')
        return []
    from numpy.random.mtrand import binomial
    remaining = n
    results = []
    for i in range(bins - 1):
        x = binomial(remaining, 1 / (bins - i))
        results.append(x)
        remaining -= x
    results.append(remaining)
    return results

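# Example (illustrative): multinomial(10, 3) might return [4, 3, 3]; the sampled
# counts are random, but they always sum to n, here 10.
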
class UnknownTypeError(Exception):
    pass

# def shape_analysis(xs):
#     composed_dtypes = [list, tuple, np.ndarray, dict, set]
#     base_dtypes = [str, int, float, type, object]  # TODO add class and superclass of xs first element
#     all_dtypes = composed_dtypes + base_dtypes
#     if isinstance(xs, np.ndarray):
#         outer_brackets = ('[', ']')
#         shape = xs.shape
#         dtype = xs.dtype
#     elif isinstance(xs, tuple):
#         outer_brackets = ('(', ')')
#         shape = len(xs)
#         dtype = [t for t in all_dtypes if all(isinstance(x, t) for x in xs)][0]
#     elif isinstance(xs, list):
#         outer_brackets = ('[', ']')
#         shape = len(xs)
#         dtype = [t for t in all_dtypes if all(isinstance(x, t) for x in xs)][0]
#     elif isinstance(xs, dict) or isinstance(xs, set):
#         outer_brackets = ('{', '}')
#         shape = len(xs)
#         dtype = [t for t in all_dtypes if all(isinstance(x, t) for x in xs)][0]
#     elif any(isinstance(xs, t) for t in base_dtypes):
#         for t in base_dtypes:
#             if isinstance(xs, t):
#                 return str(t.__name__)
#         raise AssertionError('This should be unreachable.')
#     else:
#         raise UnknownTypeError('Unknown type:' + type(xs).__name__)
#
#     if shape and shape != '?':
#         return outer_brackets[0] + str(xs.shape) + ' * ' + str(dtype) + outer_brackets[1]
#     else:
#         return outer_brackets[0] + outer_brackets[1]

def beta_conf_interval_mle(data, conf=0.95):
    if len(data) <= 1:
        return 0, 1  # overestimates the interval
    if any(d < 0 or d > 1 or isnan(d) for d in data):
        return nan, nan
    if numpy.var(data) == 0:
        return numpy.mean(data), numpy.mean(data)
    epsilon = 1e-3
    # adjusted_data = data.copy()
    # for idx in range(len(adjusted_data)):
    #     adjusted_data[idx] *= (1 - 2 * epsilon)
    #     adjusted_data[idx] += epsilon
    alpha, beta, _, _ = scipy.stats.beta.fit(data, floc=-epsilon, fscale=1 + 2 * epsilon)
    lower, upper = scipy.stats.beta.interval(alpha=conf, a=alpha, b=beta)
    if lower < 0:
        lower = 0
    if upper < 0:
        upper = 0
    if lower > 1:
        lower = 1
    if upper > 1:
        upper = 1
    return lower, upper

def gamma_conf_interval_mle(data, conf=0.95) -> Tuple[float, float]:
    if len(data) <= 1:
        return nan, nan
    if any(d < 0 or isnan(d) for d in data):
        return nan, nan
    if numpy.var(data) == 0:
        return numpy.mean(data).item(), 0
    alpha, _, scale = scipy.stats.gamma.fit(data, floc=0)
    lower, upper = scipy.stats.gamma.interval(alpha=conf, a=alpha, scale=scale)
    if lower < 0:
        lower = 0
    if upper < 0:
        upper = 0
    return lower, upper

beta_quantile_cache = cachetools.LRUCache(maxsize=10)


@cachetools.cached(cache=beta_quantile_cache, key=lambda x1, p1, x2, p2, guess: (x1, x2, p1, p2))
def beta_parameters_quantiles(x1, p1, x2, p2, guess=(3, 3)):
    """Find parameters for a beta random variable X so that P(X < x1) = p1 and P(X < x2) = p2."""

    def square(x):
        return x * x

    def objective(v):
        (a, b) = v
        temp = square(scipy.stats.beta.cdf(x1, a, b) - p1)
        temp += square(scipy.stats.beta.cdf(x2, a, b) - p2)
        return temp

    xopt = scipy.optimize.fmin(objective, guess, disp=False)
    return xopt[0], xopt[1]

def beta_conf_interval_quantile(data, conf=0.95, quantiles=(0.25, 0.75)):
    if len(data) <= 1:
        return 0, 1  # overestimates the interval
    mu = numpy.mean(data)
    v = numpy.var(data)
    data = numpy.array(data)
    if v == 0:
        return mu, mu
    lower = numpy.quantile(data, quantiles[0])
    upper = numpy.quantile(data, quantiles[1])
    alpha_guess = mu ** 2 * ((1 - mu) / v - 1 / mu)
    beta_guess = alpha_guess * (1 / mu - 1)
    alpha, beta = beta_parameters_quantiles(lower, quantiles[0], upper, quantiles[1], (alpha_guess, beta_guess))
    return scipy.stats.beta.interval(alpha=conf, a=alpha, b=beta)

def beta_stats_quantile(data, quantiles=(0.25, 0.75)):
    if len(data) <= 1:
        return 0, 1  # overestimates the interval
    data = numpy.array(data)
    mu = numpy.mean(data)
    v = numpy.var(data)
    if v == 0:
        return mu, mu
    lower = numpy.quantile(data, quantiles[0])
    upper = numpy.quantile(data, quantiles[1])
    alpha_guess = mu ** 2 * ((1 - mu) / v - 1 / mu)
    beta_guess = alpha_guess * (1 / mu - 1)
    alpha, beta = beta_parameters_quantiles(lower, quantiles[0], upper, quantiles[1], (alpha_guess, beta_guess))
    return scipy.stats.beta.stats(a=alpha, b=beta)

def beta_stats_mle(data):
    if len(data) <= 1:
        return nan, nan
    if any(d < 0 or d > 1 or isnan(d) for d in data):
        return nan, nan
    if numpy.var(data) == 0:
        return numpy.mean(data), 0
    epsilon = 1e-4
    # adjusted_data = data.copy()
    # for idx in range(len(adjusted_data)):
    #     adjusted_data[idx] *= (1 - 2 * epsilon)
    #     adjusted_data[idx] += epsilon
    alpha, beta, _, _ = scipy.stats.beta.fit(data, floc=-epsilon, fscale=1 + 2 * epsilon)
    return scipy.stats.beta.stats(a=alpha, b=beta)

def gamma_stats_mle(data):
    if len(data) <= 1:
        return nan, nan
    if any(d < 0 or isnan(d) for d in data):
        return nan, nan
    if numpy.var(data) == 0:
        return numpy.mean(data), 0
    alpha, _, scale = scipy.stats.gamma.fit(data, floc=0)
    return scipy.stats.gamma.stats(a=alpha, scale=scale)


beta_stats = beta_stats_quantile
beta_conf_interval = beta_conf_interval_quantile
gamma_stats = gamma_stats_mle
gamma_conf_interval = gamma_conf_interval_mle

def split_df_list(df, target_column):
    """
    df = data frame to split,
    target_column = the column containing the JSON-encoded lists of values to split
    returns: a data frame with each entry for the target column separated, with each element moved into a new row.
    The values in the other columns are duplicated across the newly divided rows.
    SOURCE: https://gist.github.com/jlln/338b4b0b55bd6984f883
    """

    def split_list_to_rows(row, row_accumulator):
        split_row = json.loads(row[target_column])
        for s in split_row:
            new_row = row.to_dict()
            new_row[target_column] = s
            row_accumulator.append(new_row)

    new_rows = []
    df.apply(split_list_to_rows, axis=1, args=(new_rows,))
    new_df = pandas.DataFrame(new_rows)
    return new_df

try:
    import winsound as win_sound

    def beep(*args, **kwargs):
        win_sound.Beep(*args, **kwargs)
except ImportError:
    win_sound = None

    def beep(*_args, **_kwargs):
        pass

def round_to_digits(x, d):
    if x == 0:
        return 0
    if isnan(x):
        return nan
    try:
        return round(x, d - 1 - int(floor(log10(abs(x)))))
    except OverflowError:
        return x

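# Example (illustrative): round_to_digits(0.012345, 3) == 0.0123 and
# round_to_digits(98765, 2) == 99000 (rounding to significant digits, not to decimals).
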
def gc_if_memory_error(f, *args, **kwargs):
    try:
        return f(*args, **kwargs)
    except MemoryError:
        print('Starting garbage collector')
        gc.collect()
        return f(*args, **kwargs)


def assert_not_empty(x):
    assert len(x)
    return x

def validation_steps(validation_dataset_size, maximum_batch_size):
    batch_size = gcd(validation_dataset_size, maximum_batch_size)
    steps = validation_dataset_size // batch_size
    assert batch_size * steps == validation_dataset_size
    return batch_size, steps

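# Example (illustrative): validation_steps(1000, 64) == (8, 125), i.e. the largest
# batch size that divides the dataset size evenly (gcd(1000, 64) == 8) and the
# corresponding number of steps.
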
def functional_dependency_trigger(connection: sqlite3.Connection,
                                  table_name: str,
                                  determining_columns: List[str],
                                  determined_columns: List[str],
                                  exist_ok: bool, ):
    cursor = connection.cursor()
    # skip columns that are trivially determined because they are part of the determining columns
    determined_columns = [c for c in determined_columns if c not in determining_columns]
    trigger_base_name = '_'.join([table_name] + determining_columns + ['determine'] + determined_columns)
    error_message = ','.join(determining_columns) + ' must uniquely identify ' + ','.join(determined_columns)

    # when inserting, check if there is already an entry with these values
    cursor.execute(f'''
    CREATE TRIGGER {'IF NOT EXISTS' if exist_ok else ''} {trigger_base_name}_after_insert
    BEFORE INSERT ON {table_name}
    WHEN EXISTS(SELECT * FROM {table_name}
                WHERE ({' AND '.join(f'NEW.{c} IS NOT NULL AND {c} = NEW.{c}' for c in determining_columns)})
                AND ({' OR '.join(f'{c} != NEW.{c}' for c in determined_columns)}))
    BEGIN SELECT RAISE(ROLLBACK, '{error_message}'); END
    ''')

    # when updating, check if there is already an entry with these values (only if changed)
    cursor.execute(f'''
    CREATE TRIGGER {'IF NOT EXISTS' if exist_ok else ''} {trigger_base_name}_after_update
    BEFORE UPDATE ON {table_name}
    WHEN EXISTS(SELECT * FROM {table_name}
                WHERE ({' AND '.join(f'NEW.{c} IS NOT NULL AND {c} = NEW.{c}' for c in determining_columns)})
                AND ({' OR '.join(f'{c} != NEW.{c}' for c in determined_columns)}))
    BEGIN SELECT RAISE(ROLLBACK, '{error_message}'); END
    ''')

def heatmap_from_points(x, y,
                        x_lim: Optional[Union[int, Tuple[int, int]]] = None,
                        y_lim: Optional[Union[int, Tuple[int, int]]] = None,
                        gridsize=30):
    if isinstance(x_lim, Number):
        x_lim = (x_lim, x_lim)
    if isinstance(y_lim, Number):
        y_lim = (y_lim, y_lim)

    plt.hexbin(x, y, gridsize=gridsize, cmap=matplotlib.cm.jet, bins=None)
    if x_lim is not None:
        plt.xlim(x_lim)
    if y_lim is not None:
        plt.ylim(y_lim)
    cb = plt.colorbar()
    cb.set_label('mean value')

def strptime(date_string, fmt):
    return datetime.datetime(*(time.strptime(date_string, fmt)[0:6]))

class PrintLineRLock(RLock().__class__):
    def __init__(self, *args, name='', **kwargs):
        # noinspection PyArgumentList
        super().__init__(*args, **kwargs)
        self.name = name

    def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
        print(f'Trying to acquire Lock {self.name}')
        result = super().acquire(blocking, timeout)
        print(f'Acquired Lock {self.name}')
        return result

    def release(self) -> None:
        print(f'Trying to release Lock {self.name}')
        # noinspection PyNoneFunctionAssignment
        result = super().release()
        print(f'Released Lock {self.name}')
        return result

    def __enter__(self, *args, **kwargs):
        print('Trying to enter Lock')
        # noinspection PyArgumentList
        super().__enter__(*args, **kwargs)
        print('Entered Lock')

    def __exit__(self, *args, **kwargs):
        print('Trying to exit Lock')
        super().__exit__(*args, **kwargs)
        print('Exited Lock')

def fixed_get_current_frames():
    """Return current threads prepared for
    further processing.
    """
    threads = {thread.ident: thread for thread in threading.enumerate()}
    return {
        thread_id: {
            'frame': hanging_threads.thread2list(frame),
            'time': None,
            'id': thread_id,
            'name': threads[thread_id].name,
            'object': threads[thread_id]
        } for thread_id, frame in sys._current_frames().items()
        if thread_id in threads  # otherwise KeyErrors might happen because of race conditions
    }


hanging_threads.get_current_frames = fixed_get_current_frames

class CallCounter:
    def __init__(self, f):
        self.f = f
        self.calls = 0
        self.__name__ = f.__name__

    def __call__(self, *args, **kwargs):
        self.calls += 1
        return self.f(*args, **kwargs)

    def __str__(self):
        return str(self.__dict__)

    def __repr__(self):
        return self.__class__.__name__ + repr(self.__dict__)

def test_with_timeout(timeout=2):
    def wrapper(f):
        from lib.threading_timer_decorator import exit_after
        f = exit_after(timeout)(f)

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            try:
                print(f'Running this test with timeout: {timeout}')
                return f(*args, **kwargs)
            except KeyboardInterrupt:
                raise AssertionError(f'Test took longer than {timeout} seconds')

        return wrapped

    return wrapper

def lru_cache_by_id(maxsize):
    return cachetools.cached(cachetools.LRUCache(maxsize=maxsize), key=id)

def iff_patch(patch: mock._patch):
    def decorator(f):
        def wrapped(*args, **kwargs):
            with patch:
                f(*args, **kwargs)
            try:
                f(*args, **kwargs)
            except Exception:
                pass
            else:
                raise AssertionError('Test did not fail without patch')

        return wrapped

    return decorator


def iff_not_patch(patch: mock._patch):
    def decorator(f):
        def wrapped(*args, **kwargs):
            f(*args, **kwargs)
            try:
                with patch:
                    f(*args, **kwargs)
            except Exception:
                pass
            else:
                raise AssertionError('Test did not fail with patch')

        return wrapped

    return decorator

EMAIL_CRASHES_TO = []


def list_logger(base_logging_function, store_in_list: list):
    def print_and_store(*args, **kwargs):
        base_logging_function(*args, **kwargs)
        store_in_list.extend(args)

    return print_and_store

def main_wrapper(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        # import lib.stack_tracer
        import __main__
        # does not help much
        # monitoring_thread = hanging_threads.start_monitoring(seconds_frozen=180, test_interval=1000)
        os.makedirs('logs', exist_ok=True)
        stack_tracer.trace_start('logs/' + os.path.split(__main__.__file__)[-1] + '.html', interval=5)
        faulthandler.enable()
        profile_wall_time_instead_if_profiling()
        # noinspection PyBroadException
        try:
            return f(*args, **kwargs)
        except KeyboardInterrupt:
            error_messages = []
            print_exc_plus.print_exc_plus(print=list_logger(logging.error, error_messages),
                                          serialize_to='logs/' + os.path.split(__main__.__file__)[-1] + '.dill')
        except:
            error_messages = []
            print_exc_plus.print_exc_plus(print=list_logger(logging.error, error_messages),
                                          serialize_to='logs/' + os.path.split(__main__.__file__)[-1] + '.dill')
            for recipient in EMAIL_CRASHES_TO:
                from jobs.sending_emails import send_mail
                send_mail.create_simple_mail_via_gmail(body='\n'.join(error_messages), filepath=None,
                                                       excel_name=None, to_mail=recipient, subject='Crash report')
        finally:
            logging.info('Terminated.')
            total_time = time.perf_counter() - start
            faulthandler.disable()
            stack_tracer.trace_stop()
            frequency = 2000
            duration = 500
            beep(frequency, duration)
            print('Total time', total_time)
            try:
                from algorithm_development.metatrader import ZeroMQ_Connector
                ZeroMQ_Connector.DWX_ZeroMQ_Connector.deactivate_all()
            except ImportError:
                pass

    return wrapper

def required_size_for_safe_rotation(base: Tuple[X, Y, Z], rotate_range_deg) -> Tuple[X, Y, Z]:
    if abs(rotate_range_deg) > 45:
        raise NotImplementedError
    if abs(rotate_range_deg) > 0:
        x_length = base[2] * math.sin(rotate_range_deg / 180 * math.pi) + base[1] * math.cos(
            rotate_range_deg / 180 * math.pi)
        y_length = base[2] * math.cos(rotate_range_deg / 180 * math.pi) + base[1] * math.sin(
            rotate_range_deg / 180 * math.pi)
        result = (base[0],
                  x_length,
                  y_length,)
    else:
        result = base
    return result

def round_to_closest_value(x, values, assume_sorted=False):
    if not assume_sorted:
        values = sorted(values)
    next_largest = bisect_left(values, x)  # binary search
    if next_largest == 0:
        return values[0]
    if next_largest == len(values):
        return values[-1]
    next_smallest = next_largest - 1
    smaller = values[next_smallest]
    larger = values[next_largest]
    if abs(smaller - x) < abs(larger - x):
        return smaller
    else:
        return larger

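# Example (illustrative): round_to_closest_value(12, [0, 10, 20]) == 10 and
# round_to_closest_value(16, [0, 10, 20]) == 20.
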
def binary_search(a, x, lo=0, hi=None):
    hi = hi if hi is not None else len(a)  # hi defaults to len(a)
    pos = bisect_left(a, x, lo, hi)  # find insertion position
    return pos if pos != hi and a[pos] == x else -1  # don't walk off the end

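# Example (illustrative): binary_search([1, 3, 5], 5) == 2, while
# binary_search([1, 3, 5], 4) == -1 because 4 is not in the list.
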
def ceil_to_closest_value(x, values):
    values = sorted(values)
    next_largest = bisect_left(values, x)  # binary search
    if next_largest < len(values):
        return values[next_largest]
    else:
        return values[-1]  # if there is no larger value, use the largest one

def print_progress_bar(iteration, total, prefix='Progress:', suffix='', decimals=1, length=50, fill='█',
                       print_eta=True):
    """
    Call in a loop to create terminal progress bar
    @params:
        iteration  - Required : current iteration (Int)
        total      - Required : total iterations (Int)
        prefix     - Optional : prefix string (Str)
        suffix     - Optional : suffix string (Str)
        decimals   - Optional : positive number of decimals in percent complete (Int)
        length     - Optional : character length of bar (Int)
        fill       - Optional : bar fill character (Str)
        print_eta  - Optional : currently unused (Bool)
    """
    percent = ("{0:" + str(4 + decimals) + "." + str(decimals) + "f}").format(100 * (iteration / float(total)))
    filled_length = int(length * iteration // total)
    bar = fill * filled_length + '-' * (length - filled_length)
    if getattr(print_progress_bar, 'last_printed_value', None) == (prefix, bar, percent, suffix):
        return
    print_progress_bar.last_printed_value = (prefix, bar, percent, suffix)
    print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end='')
    # Print New Line on Complete
    if iteration == total:
        print()

def get_all_subclasses(klass):
    all_subclasses = []
    for subclass in klass.__subclasses__():
        all_subclasses.append(subclass)
        all_subclasses.extend(get_all_subclasses(subclass))
    return all_subclasses

def latin1_json(data):
    return json.dumps(data, ensure_ascii=False).encode('latin-1')


def utf8_json(data):
    return json.dumps(data, ensure_ascii=False).encode('utf-8')

def l2_norm(v1, v2):
    if len(v1) != len(v2):
        raise ValueError('Both vectors must be of the same size')
    return math.sqrt(sum((x1 - x2) * (x1 - x2) for x1, x2 in zip(v1, v2)))

def allow_additional_unused_keyword_arguments(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        allowed_kwargs = [param.name for param in inspect.signature(func).parameters.values()]
        allowed_kwargs = {a: kwargs[a] for a in kwargs if a in allowed_kwargs}
        return func(*args, **allowed_kwargs)

    return wrapper

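# Example (illustrative): given def add(a, b): return a + b, calling
# allow_additional_unused_keyword_arguments(add)(1, b=2, c=3) returns 3,
# because the unknown keyword argument c is silently dropped.
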
class BiDict(dict):
    """
    https://stackoverflow.com/a/21894086
    """

    def __init__(self, *args, **kwargs):
        super(BiDict, self).__init__(*args, **kwargs)
        self.inverse = {}
        for key, value in self.items():
            self.inverse.setdefault(value, []).append(key)

    def __setitem__(self, key, value):
        if key in self:
            self.inverse[self[key]].remove(key)
        super(BiDict, self).__setitem__(key, value)
        self.inverse.setdefault(value, []).append(key)

    def __delitem__(self, key):
        self.inverse.setdefault(self[key], []).remove(key)
        if self[key] in self.inverse and not self.inverse[self[key]]:
            del self.inverse[self[key]]
        super(BiDict, self).__delitem__(key)

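# Example (illustrative): d = BiDict({'a': 1, 'b': 1}) gives d.inverse == {1: ['a', 'b']};
# after d['a'] = 2, d.inverse == {1: ['b'], 2: ['a']}.
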
def copy_and_rename_method(func, new_name):
    funcdetails = [
        func.__code__,
        func.__globals__,
        func.__name__,
        func.__defaults__,
        func.__closure__
    ]
    old_name = func.__name__
    # copy
    # new_func = dill.loads(dill.dumps(func))
    new_func = FunctionType(*funcdetails)
    assert new_func is not func
    # rename
    new_func.__name__ = new_name
    assert func.__name__ == old_name
    return new_func


def rename(new_name):
    def decorator(f):
        f.__name__ = new_name
        return f

    return decorator

class LogicError(Exception):
    pass


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

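# Example (illustrative): list(chunks([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]].
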
def shorten_name(name):
    name = re.sub(r'\s+', r' ', str(name))
    name = name.replace(', ', ',')
    name = name.replace(' ', '_')
    return re.sub(r'([A-Za-z])[a-z]*_?', r'\1', str(name))

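# Example (illustrative): shorten_name('Mean Absolute Error') == 'MAE'.
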
def array_analysis(a: numpy.ndarray):
    print(f'  Shape: {a.shape}')
    mean = a.mean()
    print(f'  Mean: {mean}')
    print(f'  Std: {a.std()}')
    print(f'  Min, Max: {a.min(), a.max()}')
    print(f'  Mean absolute: {numpy.abs(a).mean()}')
    print(f'  Mean square: {numpy.square(a).mean()}')
    print(f'  Mean absolute difference from mean: {numpy.abs(a - mean).mean()}')
    print(f'  Mean squared difference from mean: {numpy.square(a - mean).mean()}')
    nonzero = numpy.count_nonzero(a)
    print(f'  Number of non-zeros: {nonzero}')
    print(f'  Number of zeros: {numpy.prod(a.shape) - nonzero}')
    if 1 < a.shape[-1] <= 1000:
        # last dim is probably the number of classes
        print(f'  Class counts: {numpy.count_nonzero(a, axis=tuple(range(len(a.shape) - 1)))}')

def current_year_begin():
    return datetime.datetime(datetime.datetime.today().year, 1, 1).timestamp()


def current_day_begin():
    return datetime.datetime.today().timestamp() // (3600 * 24) * (3600 * 24)


def current_second_begin():
    return floor(datetime.datetime.today().timestamp())

def running_workers(executor):
    # count the pool's live worker threads (relies on the private _threads attribute)
    return sum(1 for t in executor._threads
               if t.is_alive())

class Bunch(dict):
    def __init__(self, **kwargs):
        dict.__init__(self, kwargs)
        self.__dict__.update(kwargs)

    def add_method(self, m):
        setattr(self, m.__name__, functools.partial(m, self))


def queued_calls(executor):
    return len(executor._work_queue.queue)

def single_use_function(f: Callable):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if not wrapper.already_called:
            wrapper.already_called = True
            return f(*args, **kwargs)
        else:
            raise RuntimeError(f'{f} is supposed to be called only once.')

    wrapper.already_called = False
    return wrapper

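# Example (illustrative): f = single_use_function(print); f('hello') works once,
# and a second call like f('again') raises RuntimeError.
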
def single_use_method(f: Callable):
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        if not hasattr(self, f'_already_called_{id(wrapper)}') or not getattr(self, f'_already_called_{id(wrapper)}'):
            setattr(self, f'_already_called_{id(wrapper)}', True)
            return f(self, *args, **kwargs)
        else:
            raise RuntimeError(f'{f} is supposed to be called only once per object.')

    return wrapper

class JSONConvertible:
    SUBCLASSES_BY_NAME: Dict[str, Type['JSONConvertible']] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        JSONConvertible.SUBCLASSES_BY_NAME[cls.__name__] = cls

    @classmethod
    def subclasses(cls, strict=False):
        return [t for t in cls.SUBCLASSES_BY_NAME.values() if issubclass(t, cls) and (not strict or t != cls)]

    def to_json(self) -> Dict[str, Any]:
        raise NotImplementedError('Abstract method')

    @staticmethod
    def from_json(data: Dict[str, Any]):
        cls = JSONConvertible.SUBCLASSES_BY_NAME[data['type']]
        if cls.from_json is JSONConvertible.from_json:
            raise NotImplementedError(f'Abstract method from_json of class {cls.__name__}')
        return cls.from_json(data)

    @classmethod
    def schema(cls) -> Dict[str, Any]:
        fields = cls.field_types()
        result: Dict[str, Any] = {
            "type": "object",
            "properties": {'type': {'type': 'string', 'enum': [cls.__name__]}},
            'required': ['type'],
        }
        for field, field_type in fields.items():
            optional = (hasattr(field_type, '__origin__')
                        and field_type.__origin__ is typing.Union
                        and type(None) in field_type.__args__)
            if optional:
                field_type = Union[tuple(filter(lambda t: t is not type(None), field_type.__args__))]
            result['properties'][field] = schema_by_python_type(field_type)
            if not optional:
                result['required'].append(field)
        subclasses = cls.subclasses(strict=True)
        if len(subclasses) > 0:
            result = {'oneOf': [result] + [{'$ref': f'#/components/schemas/{s.__name__}'} for s in subclasses]}
        return result

    @classmethod
    def schema_for_referencing(cls):
        return {'$ref': f'#/components/schemas/{cls.__name__}'}

    @classmethod
    def schema_python_types(cls) -> Dict[str, Type]:
        fields = typing.get_type_hints(cls.__init__)
        result: Dict[str, Type] = {}
        for field, field_type in fields.items():
            result[field] = field_type
        return result

    @classmethod
    def field_types(cls) -> Dict[str, Type]:
        return typing.get_type_hints(cls.__init__)

class EBC(JSONConvertible):
    def __eq__(self, other):
        return type(other) == type(self) and self.__dict__ == other.__dict__

    def __str__(self):
        return str(self.__dict__)

    def __repr__(self):
        return f'{type(self).__name__}(**' + str(self.__dict__) + ')'

    def to_json(self) -> Dict[str, Any]:
        def obj_to_json(o):
            if isinstance(o, JSONConvertible):
                return o.to_json()
            elif isinstance(o, numpy.ndarray):
                return array_to_json(o)
            elif isinstance(o, list):
                return list_to_json(o)
            elif isinstance(o, dict):
                return dict_to_json(o)
            else:
                return o  # probably a primitive type

        def dict_to_json(d: Dict[str, Any]):
            return {k: obj_to_json(v) for k, v in d.items()}

        def list_to_json(l: list):
            return [obj_to_json(v) for v in l]

        def array_to_json(o: numpy.ndarray):
            return o.tolist()

        result: Dict[str, Any] = {
            'type': type(self).__name__,
            **self.__dict__,
        }
        for k in result:
            result[k] = obj_to_json(result[k])
        return result

    @staticmethod
    def from_json(data: Dict[str, Any]):
        cls = EBC.SUBCLASSES_BY_NAME[data['type']]
        return class_from_json(cls, data)

class _Container(EBC):
    def __init__(self, item):
        self.item = item


def container_to_json(obj: Union[list, dict, numpy.ndarray]) -> Dict[str, Any]:
    # Useful when converting lists or dicts of json-convertible objects to json (recursively)
    e = _Container(obj)
    return e.to_json()['item']

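# Example (illustrative, with a hypothetical subclass P(EBC) whose __init__ stores x and y):
# container_to_json([P(1, 2)]) == [{'type': 'P', 'x': 1, 'y': 2}].
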
class OpenAPISchema:
    def __init__(self, schema: dict):
        self.schema = schema

def schema_by_python_type(t: Union[Type, dict, OpenAPISchema]):
    if t == 'TODO':
        return t
    if isinstance(t, dict):
        return {
            "type": "object",
            "properties": {n: schema_by_python_type(t2) for n, t2 in t.items()},
            'required': [n for n in t],
        }
    elif isinstance(t, OpenAPISchema):
        return t.schema
    elif t is int:
        return {'type': 'integer'}
    elif t is Any:
        return {}
    elif t is float:
        return {'type': 'number'}
    elif t is str:
        return {'type': 'string'}
    elif t is bool:
        return {'type': 'boolean'}
    elif t is type(None) or t is None:
        return {'type': 'string', 'nullable': True}
    elif hasattr(t, '__origin__'):  # probably a generic typing-type
        if t.__origin__ is list:
            return {'type': 'array', 'items': schema_by_python_type(t.__args__[0])}
        elif t.__origin__ is dict:
            return {'type': 'object', 'additionalProperties': schema_by_python_type(t.__args__[1])}
        elif t.__origin__ is tuple and len(set(t.__args__)) == 1:
            return {'type': 'array', 'items': schema_by_python_type(t.__args__[0]),
                    'minItems': len(t.__args__), 'maxItems': len(t.__args__)}
        elif t.__origin__ is Union:
            return {'oneOf': [schema_by_python_type(t2) for t2 in t.__args__]}
        elif t.__origin__ is Literal:
            return {'enum': list(t.__args__)}
        else:
            raise ValueError(f'Unknown type {t}')
    elif issubclass(t, list):
        return {'type': 'array'}
    elif issubclass(t, tuple) and hasattr(t, '__args__') and len(set(t.__args__)) == 1:
        return {'type': 'array', 'items': schema_by_python_type(t.__args__[0]),
                'minItems': len(t.__args__), 'maxItems': len(t.__args__)}
    elif t is datetime.datetime:
        from lib.date_time import SerializableDateTime
        return schema_by_python_type(SerializableDateTime)
    elif hasattr(t, 'schema_for_referencing'):
        return t.schema_for_referencing()
    else:
        raise ValueError(f'Unknown type {t}')

class Synchronizable:
    def __init__(self):
        self._lock = threading.RLock()

    @staticmethod
    def synchronized(f):
        @functools.wraps(f)
        def wrapper(self, *args, **kwargs):
            with self._lock:
                return f(self, *args, **kwargs)

        return wrapper

def class_from_json(cls, data: Dict[str, Any]):
    if isinstance(data, str):
        data = json.loads(data)
    # noinspection PyArgumentList
    try:
        return cls(**data)
    except TypeError as e:
        if "__init__() got an unexpected keyword argument 'type'" in str(e):
            # probably this dict was created by a to_json method
            if data['type'] != cls.__name__:
                t = data['type']
                logging.warning(f'Reconstructing a {cls.__name__} from a dict with type={t}')
            data = data.copy()
            del data['type']
            for k, v in data.items():
                if probably_serialized_from_json_convertible(v):
                    data[k] = JSONConvertible.SUBCLASSES_BY_NAME[v['type']].from_json(v)
                elif isinstance(v, list):
                    data[k] = [JSONConvertible.SUBCLASSES_BY_NAME[x['type']].from_json(x)
                               if probably_serialized_from_json_convertible(x)
                               else x
                               for x in v]
                elif isinstance(v, dict):
                    data[k] = {
                        k2: JSONConvertible.SUBCLASSES_BY_NAME[x['type']].from_json(x)
                        if probably_serialized_from_json_convertible(x)
                        else x
                        for k2, x in v.items()
                    }
            try:
                return cls(**data)
            except TypeError:
                return allow_additional_unused_keyword_arguments(cls)(**data)
        else:
            raise

def probably_serialized_from_json_convertible(data):
    return isinstance(data, dict) and 'type' in data and data['type'] in JSONConvertible.SUBCLASSES_BY_NAME

class EBE(JSONConvertible, Enum):
    def __int__(self):
        return self.value

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    def __lt__(self, other):
        return list(type(self)).index(self) < list(type(self)).index(other)

    def __ge__(self, other):
        return list(type(self)).index(self) >= list(type(self)).index(other)

    @classmethod
    def from_name(cls, variable_name) -> typing.Self:
        return cls.__dict__[variable_name]

    def to_json(self) -> Dict[str, Any]:
        return {
            'type': type(self).__name__,
            'name': self.name,
        }

    @classmethod
    def schema(cls) -> Dict[str, Any]:
        return {
            "type": "object",
            "properties": {'type': schema_by_python_type(str), 'name': schema_by_python_type(str)}
        }

    @staticmethod
    def from_json(data: Dict[str, Any]) -> 'JSONConvertible':
        cls = EBE.SUBCLASSES_BY_NAME[data['type']]
        return cls.from_name(data['name'])

def floor_to_multiple_of(x, multiple_of):
    return math.floor(x / multiple_of) * multiple_of


def ceil_to_multiple_of(x, multiple_of):
    return math.ceil(x / multiple_of) * multiple_of

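# Example (illustrative): floor_to_multiple_of(17, 5) == 15 and
# ceil_to_multiple_of(17, 5) == 20.
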
def call_tool(command, cwd=None, shell=False):
    try:
        print(f'Calling `{" ".join(command)}`...')
        sub_env = os.environ.copy()
        output: bytes = check_output(command, stderr=PIPE, env=sub_env, cwd=cwd, shell=shell)
        output: str = output.decode('utf-8', errors='ignore')
        return output
    except CalledProcessError as e:
        stdout = e.stdout.decode('utf-8', errors='ignore')
        stderr = e.stderr.decode('utf-8', errors='ignore')
        if len(stdout) == 0:
            print('stdout was empty.')
        else:
            print('stdout was: ')
            print(stdout)
        if len(stderr) == 0:
            print('stderr was empty.')
        else:
            print('stderr was: ')
            print(stderr)
        raise