  1. """
  2. made by Eren Yilmaz
  3. """
  4. import functools
  5. import itertools
  6. from copy import deepcopy
  7. from goto import with_goto
  8. from lib import util
  9. from lib.tuned_cache import TunedMemory
  10. from lib.progress_bar import ProgressBar
  11. import pickle
  12. from matplotlib.axes import Axes
  13. import os
  14. import random
  15. import sqlite3
  16. from datetime import datetime
  17. from math import inf, nan, ceil, sqrt
  18. from timeit import default_timer
  19. from typing import Dict, Any, List, Callable, Optional, Tuple, Iterable, Union
  20. import numpy
  21. import pandas
  22. import scipy.stats
  23. import seaborn as sns
  24. import matplotlib.pyplot as plt
  25. import matplotlib.dates
  26. from lib.util import my_tabulate, round_to_digits, print_progress_bar, heatmap_from_points, LogicError, shorten_name
  27. score_cache = TunedMemory(location='.cache/scores', verbose=0)
  28. sns.set(style="whitegrid", font_scale=1.5)
  29. # set up db for results
  30. connection: sqlite3.Connection = sqlite3.connect('random_search_results.db')
  31. connection.cursor().execute('PRAGMA foreign_keys = 1')
  32. connection.cursor().execute('PRAGMA journal_mode = WAL')
  33. connection.cursor().execute('PRAGMA synchronous = NORMAL')
  34. Parameters = Dict[str, Any]
  35. MetricValue = float
  36. Metrics = Dict[str, MetricValue]
  37. History = List[MetricValue]
  38. suppress_intermediate_beeps = False
  39. label = goto = object()


class Prediction:
    def __init__(self, dataset: str, y_true, y_pred, name: str):
        self.y_pred = y_pred
        self.y_true = y_true
        if not isinstance(name, str):
            self.name = str(name)
        else:
            self.name = name
        if not isinstance(dataset, str):
            raise TypeError
        self.dataset = dataset

    def __str__(self):
        return str(self.__dict__)

    def __repr__(self):
        return self.__class__.__name__ + repr({k: v for k, v in self.__dict__.items() if k != 'predictions'})


class EvaluationResult:
    def __init__(self,
                 results: Dict[str, MetricValue],
                 parameters: Parameters = None,
                 predictions: Optional[List[Prediction]] = None):
        self.predictions = predictions
        if self.predictions is not None:
            self.predictions = self.predictions.copy()
        else:
            self.predictions = []
        if parameters is None:
            self.parameters = parameters
        else:
            self.parameters = parameters.copy()
        self.results = results.copy()

    def __iter__(self):
        yield self

    def __eq__(self, other):
        return (isinstance(other, EvaluationResult)
                and self.parameters == other.parameters
                and self.predictions == other.predictions
                and self.results == other.results)

    def __str__(self):
        return '{0}{1}'.format(self.__class__.__name__,
                               {k: getattr(self, k) for k in ['results', 'parameters', 'predictions']})

    def __repr__(self):
        return self.__class__.__name__ + repr(self.__dict__)


assert list(EvaluationResult({}, {})) == list([EvaluationResult({}, {})])

EvaluationFunction = Callable[[Parameters], Union[List[EvaluationResult], EvaluationResult, List[float], float]]
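

# A minimal sketch of an EvaluationFunction (illustrative only, not used by the
# search code): it assumes a hypothetical 'hidden_layer_count' parameter and a
# single 'accuracy' metric; a real implementation would train and score a model.
def _example_evaluate(params: Parameters) -> EvaluationResult:
    if params['hidden_layer_count'] < 0:
        raise InvalidParametersError(params)  # marks the configuration as unusable
    accuracy = random.random()  # stand-in for a real measurement
    return EvaluationResult(results={'accuracy': accuracy}, parameters=params)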


class Parameter:
    def __init__(self, name: str, initial_value, larger_value, smaller_value, first_try_increase=False):
        self.name = name
        self.initial_value = initial_value
        self.larger_value = larger_value
        self.smaller_value = smaller_value
        self.first_try_increase = first_try_increase

    def __repr__(self):
        return self.__class__.__name__ + repr(self.__dict__)

    def copy(self, new_name=None):
        result: Parameter = deepcopy(self)
        if new_name is not None:
            result.name = new_name
        return result


class BoundedParameter(Parameter):
    def __init__(self, name, initial_value, larger_value, smaller_value, minimum=-inf, maximum=inf,
                 first_try_increase=False):
        self.minimum = minimum
        self.maximum = maximum
        super().__init__(name,
                         initial_value,
                         lambda x: self._bounded(larger_value(x)),
                         lambda x: self._bounded(smaller_value(x)),
                         first_try_increase=first_try_increase)
        if self.initial_value < self.minimum:
            raise ValueError('Initial value is lower than minimum value.')
        if self.initial_value > self.maximum:
            raise ValueError('Initial value is larger than maximum value.')

    def _bounded(self, y):
        y = max(self.minimum, y)
        y = min(self.maximum, y)
        return y


class ConstantParameter(Parameter):
    def __init__(self, name, value):
        super().__init__(name,
                         value,
                         lambda x: value,
                         lambda x: value)


class BinaryParameter(Parameter):
    def __init__(self, name, value1, value2):
        super().__init__(name,
                         value1,
                         lambda x: value2 if x == value1 else value1,
                         lambda x: value2 if x == value1 else value1)


class BooleanParameter(Parameter):
    def __init__(self, name, initial_value: bool):
        super().__init__(name,
                         bool(initial_value),
                         lambda x: not x,
                         lambda x: not x)


class TernaryParameter(Parameter):
    def __init__(self, name, value1, value2, value3):
        self.smaller = {value1: value3, value2: value1, value3: value2}
        self.larger = {value1: value2, value2: value3, value3: value1}
        super().__init__(name,
                         value1,
                         lambda x: self.smaller[x],
                         lambda x: self.larger[x])


class ListParameter(Parameter):
    def __init__(self, name, initial_value, possible_values: List, first_try_increase=False, circle=False):
        self.possible_values = possible_values.copy()
        if initial_value not in self.possible_values:
            raise ValueError('Initial value {0} is not in the list of possible values for {1}.'.format(initial_value,
                                                                                                       name))
        if len(set(self.possible_values)) != len(self.possible_values):
            print('WARNING: It seems that there are duplicates in the list of possible values for {0}'.format(name))
        length = len(self.possible_values)
        if circle:
            smaller = lambda x: self.possible_values[(self.possible_values.index(x) + 1) % length]
            larger = lambda x: self.possible_values[(self.possible_values.index(x) - 1) % length]
        else:
            smaller = lambda x: self.possible_values[min(self.possible_values.index(x) + 1, length - 1)]
            larger = lambda x: self.possible_values[max(self.possible_values.index(x) - 1, 0)]
        super().__init__(name,
                         initial_value,
                         smaller,
                         larger,
                         first_try_increase=first_try_increase)


class ExponentialParameter(BoundedParameter):
    def __init__(self, name, initial_value, base, minimum=-inf, maximum=inf, first_try_increase=False):
        super().__init__(name,
                         initial_value,
                         lambda x: float(x * base),
                         lambda x: float(x / base),
                         minimum,
                         maximum,
                         first_try_increase=first_try_increase)
        self.plot_scale = 'log'


class ExponentialIntegerParameter(BoundedParameter):
    def __init__(self, name, initial_value, base, minimum=-inf, maximum=inf, first_try_increase=False):
        if minimum != -inf:
            minimum = round(minimum)
        if maximum != inf:
            maximum = round(maximum)
        super().__init__(name,
                         round(initial_value),
                         lambda x: round(x * base),
                         lambda x: round(x / base),
                         minimum,
                         maximum,
                         first_try_increase=first_try_increase)
        self.plot_scale = 'log'


class LinearParameter(BoundedParameter):
    def __init__(self, name, initial_value, summand, minimum=-inf, maximum=inf, first_try_increase=False):
        super().__init__(name,
                         initial_value,
                         lambda x: float(x + summand),
                         lambda x: float(x - summand),
                         minimum,
                         maximum,
                         first_try_increase=first_try_increase)


class LinearIntegerParameter(BoundedParameter):
    def __init__(self, name, initial_value, summand, minimum=-inf, maximum=inf, first_try_increase=False):
        super().__init__(name,
                         initial_value,
                         lambda x: x + summand,
                         lambda x: x - summand,
                         minimum,
                         maximum,
                         first_try_increase=first_try_increase)
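

# A short, hedged sanity check of the step semantics above (not executed on
# import; the numbers follow directly from the definitions): `larger_value` and
# `smaller_value` produce the neighboring candidate values, clipped at bounds.
def _example_parameter_steps():
    epochs = ExponentialIntegerParameter('epochs', 100, 5, minimum=1)
    assert epochs.larger_value(100) == 500 and epochs.smaller_value(100) == 20
    layers = LinearIntegerParameter('hidden_layer_count', 3, 1, minimum=0)
    assert layers.smaller_value(0) == 0  # clipped at the minimum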


class InvalidParametersError(Exception):
    def __init__(self, parameters=None):
        self.parameters = parameters


class BadParametersError(InvalidParametersError):
    pass


class InvalidReturnError(Exception):
    pass


class EmptyTableError(Exception):
    pass


EXAMPLE_PARAMS = [
    ExponentialParameter('learn_rate', 0.001, 10),
    ExponentialIntegerParameter('hidden_layer_size', 512, 2, minimum=1),
    LinearIntegerParameter('hidden_layer_count', 3, 1, minimum=0),
    ExponentialIntegerParameter('epochs', 100, 5, minimum=1),
    LinearParameter('dropout_rate', 0.5, 0.2, minimum=0, maximum=1),
]
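

# Hedged usage sketch (not executed on import): wiring EXAMPLE_PARAMS into the
# random search defined below. `my_evaluate` stands for any EvaluationFunction
# that reports an 'accuracy' metric, e.g. `_example_evaluate` above.
def _example_random_search(my_evaluate: EvaluationFunction):
    random_parameter_search(experiment_name='example_experiment',
                            params=EXAMPLE_PARAMS,
                            evaluate=my_evaluate,
                            optimize='accuracy',
                            larger_result_is_better=True,
                            max_num_experiments=10)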


def mean_confidence_interval_size(data, confidence=0.95, force_v: Optional[int] = None,
                                  force_sem: Optional[float] = None):
    if len(data) == 0:
        return nan
    if force_sem is None:
        if len(data) == 1:
            return inf
        sem = scipy.stats.sem(data)
    else:
        sem = force_sem
    if sem == 0:
        return 0
    if force_v is None:
        v = len(data) - 1
    else:
        v = force_v
    return numpy.mean(data) - scipy.stats.t.interval(confidence,
                                                     df=v,
                                                     loc=numpy.mean(data),
                                                     scale=sem)[0]
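

# Note: the value returned above is the half-width of the two-sided t-based
# confidence interval, i.e. mean(data) - lower_bound. A small illustrative check:
def _example_confidence_interval():
    half_width = mean_confidence_interval_size([1.0, 2.0, 3.0], confidence=0.95)
    assert half_width > 0  # the interval is mean ± half_width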


def try_parameters(experiment_name: str,
                   evaluate: EvaluationFunction,
                   params: Dict[str, Any],
                   optimize: Optional[str] = None,
                   larger_result_is_better: bool = None, ):
    print('running experiment...')
    params = params.copy()
    if larger_result_is_better is None and optimize is not None:
        raise NotImplementedError(
            'Don\'t know how to optimize {0}. Did you specify `larger_result_is_better`?'.format(optimize))
    assert larger_result_is_better is not None or optimize is None
    worst_score = -inf if larger_result_is_better else inf
    cursor = connection.cursor()
    start = default_timer()
    try:
        result = evaluate(params)
        if not isinstance(result, Iterable):
            result = [result]
        evaluation_results: List[EvaluationResult] = list(result)
    except InvalidParametersError as e:
        if optimize is not None:
            bad_results: Dict[str, float] = {
                optimize: worst_score
            }
        else:
            bad_results = {}
        if e.parameters is None:
            evaluation_results = [EvaluationResult(
                parameters=params,
                results=bad_results
            )]
        else:
            evaluation_results = [EvaluationResult(
                parameters=e.parameters,
                results=bad_results
            )]
    finally:
        duration = default_timer() - start
    for idx in range(len(evaluation_results)):
        if isinstance(evaluation_results[idx], float):
            evaluation_results[idx] = EvaluationResult(parameters=params,
                                                       results={optimize: evaluation_results[idx]})
    p_count = 0
    for evaluation_result in evaluation_results:
        if evaluation_result.parameters is None:
            evaluation_result.parameters = params
        metric_names = sorted(evaluation_result.results.keys())
        param_names = list(sorted(evaluation_result.parameters.keys()))
        for metric_name in metric_names:
            add_metric_column(experiment_name, metric_name, verbose=1)
        for param_name in param_names:
            add_parameter_column(experiment_name, param_name, evaluation_result.parameters[param_name], verbose=1)
        if not set(param_names).isdisjoint(metric_names):
            raise RuntimeError('Metrics and parameter names should be disjoint')
        if optimize is not None and numpy.isnan(evaluation_result.results[optimize]):
            evaluation_result.results[optimize] = worst_score
        metric_values = [evaluation_result.results[metric_name] for metric_name in metric_names]
        param_names_comma_separated = ','.join('"' + param_name + '"' for param_name in param_names)
        metric_names_comma_separated = ','.join('"' + metric_name + '"' for metric_name in metric_names)
        insert_question_marks = ','.join('?' for _ in range(len(param_names) + len(metric_names)))
        cursor.execute('''
            INSERT INTO {0} ({1}) VALUES ({2})
        '''.format(experiment_name,
                   param_names_comma_separated + ',' + metric_names_comma_separated,
                   insert_question_marks), (*[evaluation_result.parameters[name] for name in param_names],
                                            *metric_values))
        result_id = cursor.lastrowid
        assert cursor.execute(f'SELECT COUNT(*) FROM {experiment_name}_predictions WHERE result_id = ? LIMIT 1',
                              (result_id,)).fetchone()[0] == 0
        p_count += len(evaluation_result.predictions)
        dataset_names = [(prediction.dataset, prediction.name) for prediction in evaluation_result.predictions]
        if len(set(dataset_names)) != len(dataset_names):
            print('\n'.join(sorted(str(dsn)
                                   for idx, dsn in enumerate(dataset_names)
                                   if dsn in dataset_names[idx + 1:])))
            raise InvalidReturnError(
                'Every combination of name and dataset in a single evaluation result must be unique. '
                'There should be a list of duplicates printed above where the number of occurrences '
                'of an element in the list is the actual number of occurrences minus 1 '
                '(so only duplicates are listed).')
        # noinspection SqlResolve
        cursor.executemany('''
            INSERT INTO {0}_predictions (dataset, y_true, y_pred, result_id, name)
            VALUES (?, ?, ?, ?, ?)
        '''.format(experiment_name), [(prediction.dataset,
                                       pickle.dumps(prediction.y_true),
                                       pickle.dumps(prediction.y_pred),
                                       result_id,
                                       prediction.name) for prediction in evaluation_result.predictions])
    connection.commit()
    print('saved', len(evaluation_results), 'results and', p_count, 'predictions to db')
    if not suppress_intermediate_beeps:
        util.beep(1000, 500)
    if optimize is not None:
        scores = [r.results[optimize] for r in evaluation_results]
        if larger_result_is_better:
            best_score = max(scores)
        else:
            best_score = min(scores)
        print(' finished in', duration, 'seconds, best loss/score:', best_score)
    for r in evaluation_results:
        if list(sorted(r.results.keys())) != list(sorted(metric_names)):
            raise InvalidReturnError("""
                Wrong metric names were returned by `evaluate`:
                Expected metric_names={0}
                but was {1}.
                The result was saved to the database anyway, possibly with missing values.
            """.format(list(sorted(metric_names)),
                       list(sorted(r.results.keys()))))
    return evaluation_results


def config_dict_from_param_list(params: List[Parameter]):
    return {
        p.name: p.initial_value
        for p in params
    }
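

# For example, applied to EXAMPLE_PARAMS this yields the initial configuration
# (values read off the definitions above):
# {'learn_rate': 0.001, 'hidden_layer_size': 512, 'hidden_layer_count': 3,
#  'epochs': 100, 'dropout_rate': 0.5}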


def evaluate_with_initial_params(experiment_name: str,
                                 params: List[Parameter],
                                 evaluate: EvaluationFunction,
                                 optimize: str,
                                 larger_result_is_better: bool,
                                 metric_names=None,
                                 num_experiments=1, ):
    random_parameter_search(experiment_name=experiment_name,
                            params=params,
                            evaluate=evaluate,
                            optimize=optimize,
                            larger_result_is_better=larger_result_is_better,
                            mutation_probability=1.,
                            no_mutations_probability=0.,
                            max_num_experiments=num_experiments,
                            metric_names=metric_names,
                            initial_experiments=num_experiments,
                            experiment_count='db_tries_initial', )


def random_parameter_search(experiment_name: str,
                            params: List[Parameter],
                            evaluate: EvaluationFunction,
                            optimize: str,
                            larger_result_is_better: bool,
                            mutation_probability: float = None,
                            no_mutations_probability: float = None,
                            allow_multiple_mutations=False,
                            max_num_experiments=inf,
                            metric_names=None,
                            initial_experiments=1,
                            runs_per_configuration=inf,
                            initial_runs=1,
                            ignore_configuration_condition='0',
                            experiment_count='tries', ):
    print('experiment name:', experiment_name)
    if metric_names is None:
        metric_names = [optimize]
    if optimize not in metric_names:
        raise ValueError('trying to optimize {0} but only metrics available are {1}'.format(optimize, metric_names))
    params = sorted(params, key=lambda p: p.name)
    validate_parameter_set(params)
    param_names = [param.name for param in params]
    cursor = connection.cursor()
    create_experiment_tables_if_not_exists(experiment_name, params, metric_names)

    def max_tries_reached(ps):
        return len(result_ids_for_parameters(experiment_name, ps)) >= runs_per_configuration

    def min_tries_reached(ps):
        return len(result_ids_for_parameters(experiment_name, ps)) >= initial_runs

    def try_(ps) -> List[EvaluationResult]:
        # returns the results of all runs actually conducted (empty if skipped)
        results: List[EvaluationResult] = []
        if not max_tries_reached(ps):
            results += try_parameters(experiment_name=experiment_name,
                                      evaluate=evaluate,
                                      params=ps,
                                      optimize=optimize,
                                      larger_result_is_better=larger_result_is_better, )
        else:
            print('Skipping because maximum number of tries is already reached.')
        while not min_tries_reached(ps):
            print('Repeating because minimum number of tries is not reached.')
            results += try_parameters(experiment_name=experiment_name,
                                      evaluate=evaluate,
                                      params=ps,
                                      optimize=optimize,
                                      larger_result_is_better=larger_result_is_better, )
        return results

    if mutation_probability is None:
        mutation_probability = 1 / (len(params) + 1)
    if no_mutations_probability is None:
        no_mutations_probability = (1 - 1 / len(params)) / 4
    initial_params = {param.name: param.initial_value for param in params}
    print('initial parameters:', initial_params)

    def skip():
        best_scores, best_mean, best_std, best_conf = get_results_for_params(optimize, experiment_name,
                                                                             best_params, 0.99)
        try_scores, try_mean, try_std, try_conf = get_results_for_params(optimize, experiment_name,
                                                                         try_params, 0.99)
        if larger_result_is_better:
            if best_mean - best_conf > try_mean + try_conf:
                return True
        else:
            if best_mean + best_conf < try_mean - try_conf:
                return True
        return False

    # get best params
    any_results = cursor.execute('SELECT EXISTS (SELECT * FROM {0} WHERE NOT ({1}))'
                                 .format(experiment_name, ignore_configuration_condition)).fetchone()[0]
    if any_results:
        best_params = get_best_params(experiment_name,
                                      larger_result_is_better,
                                      optimize,
                                      param_names,
                                      additional_condition=f'NOT ({ignore_configuration_condition})')
    else:
        best_params = initial_params
    try_params = best_params.copy()

    def results_for_params(ps):
        return get_results_for_params(
            metric=optimize,
            experiment_name=experiment_name,
            parameters=ps
        )

    if experiment_count == 'tries':
        num_experiments = 0
    elif experiment_count == 'results':
        num_experiments = 0
    elif experiment_count == 'db_total':
        num_experiments = cursor.execute('SELECT COUNT(*) FROM {0} WHERE NOT ({1})'
                                         .format(experiment_name, ignore_configuration_condition)).fetchone()[0]
    elif experiment_count == 'db_tries_best':
        num_experiments = len(result_ids_for_parameters(experiment_name, best_params))
    elif experiment_count == 'db_tries_initial':
        num_experiments = len(result_ids_for_parameters(experiment_name, initial_params))
    else:
        raise ValueError('Invalid argument for experiment_count')
    last_best_score = results_for_params(best_params)[1]
    while num_experiments < max_num_experiments:
        if num_experiments < initial_experiments:
            try_params = initial_params.copy()
        else:
            any_results = cursor.execute('SELECT EXISTS (SELECT * FROM {0} WHERE NOT ({1}))'
                                         .format(experiment_name, ignore_configuration_condition)).fetchone()[0]
            if any_results:
                last_best_params = best_params
                best_params = get_best_params(experiment_name,
                                              larger_result_is_better,
                                              optimize,
                                              param_names,
                                              additional_condition=f'NOT ({ignore_configuration_condition})')
                best_scores, best_score, _, best_conf_size = results_for_params(best_params)
                if last_best_score is not None and best_score is not None:
                    if last_best_params != best_params:
                        if (last_best_score < best_score and larger_result_is_better
                                or last_best_score > best_score and not larger_result_is_better):
                            print(' --> Parameters were improved by this change!')
                        if (last_best_score > best_score and larger_result_is_better
                                or last_best_score < best_score and not larger_result_is_better):
                            print(' --> Actually other parameters are better...')
                last_best_score = best_score
                # print('currently best parameters:', best_params)
                changed_params = {k: v for k, v in best_params.items() if best_params[k] != initial_params[k]}
                print('currently best parameters (excluding unchanged parameters):', changed_params)
                print('currently best score:', best_score, 'conf.', best_conf_size, 'num.', len(best_scores))
            else:
                best_params = {param.name: param.initial_value for param in params}
                best_conf_size = inf
            try_params = best_params.copy()
            verbose = 1
            if best_conf_size != inf:
                if random.random() > no_mutations_probability:
                    modify_params_randomly(mutation_probability, params, try_params, verbose,
                                           allow_multiple_mutations=allow_multiple_mutations)
        if num_experiments < initial_experiments:
            try_params = initial_params.copy()
        else:
            # check if this already has a bad score
            if skip():
                print('skipping because this set of parameters is known to be worse with high probability.')
                print()
                continue
        # print('trying parameters', {k: v for k, v in try_params.items() if try_params[k] != initial_params[k]})
        results = try_(try_params)
        if experiment_count == 'tries':
            num_experiments += 1
        elif experiment_count == 'results':
            num_experiments += len(results)
        elif experiment_count == 'db_total':
            num_experiments = cursor.execute('SELECT COUNT(*) FROM {0}'.format(experiment_name)).fetchone()[0]
        elif experiment_count == 'db_tries_best':
            num_experiments = len(result_ids_for_parameters(experiment_name, best_params))
        elif experiment_count == 'db_tries_initial':
            num_experiments = len(result_ids_for_parameters(experiment_name, initial_params))
        else:
            raise LogicError('It is not possible that this is reached.')


@with_goto
def diamond_parameter_search(experiment_name: str,
                             diamond_size: int,
                             params: List[Parameter],
                             evaluate: EvaluationFunction,
                             optimize: str,
                             larger_result_is_better: bool,
                             runs_per_configuration=inf,
                             initial_runs=1,
                             metric_names=None,
                             filter_results_condition='1'):
    print('experiment name:', experiment_name)
    if metric_names is None:
        metric_names = [optimize]
    if optimize not in metric_names:
        raise ValueError('trying to optimize {0} but only metrics available are {1}'.format(optimize, metric_names))
    print('Optimizing metric', optimize)
    if runs_per_configuration > initial_runs:
        print(
            f'WARNING: You are using initial_runs={initial_runs} and runs_per_configuration={runs_per_configuration}. '
            f'This may lead to unexpected results if you don\'t know what you are doing.')
    params_in_original_order = params
    params = sorted(params, key=lambda p: p.name)
    validate_parameter_set(params)
    create_experiment_tables_if_not_exists(experiment_name, params, metric_names)
    initial_params = {param.name: param.initial_value for param in params}
    print('initial parameters:', initial_params)
    # get best params
    try:
        best_params = get_best_params_and_compare_with_initial(experiment_name, initial_params,
                                                               larger_result_is_better, optimize,
                                                               additional_condition=filter_results_condition)
    except EmptyTableError:
        best_params = initial_params

    def max_tries_reached(ps):
        return len(result_ids_for_parameters(experiment_name, ps)) >= runs_per_configuration

    def min_tries_reached(ps):
        return len(result_ids_for_parameters(experiment_name, ps)) >= initial_runs

    def try_(ps) -> bool:
        tried = False
        if not max_tries_reached(ps):
            try_parameters(experiment_name=experiment_name,
                           evaluate=evaluate,
                           params=ps,
                           optimize=optimize,
                           larger_result_is_better=larger_result_is_better, )
            tried = True
        else:
            print('Skipping because maximum number of tries is already reached.')
        while not min_tries_reached(ps):
            print('Repeating because minimum number of tries is not reached.')
            try_parameters(experiment_name=experiment_name,
                           evaluate=evaluate,
                           params=ps,
                           optimize=optimize,
                           larger_result_is_better=larger_result_is_better, )
            tried = True
        return tried

    last_best_score = results_for_params(optimize, experiment_name, best_params)[1]
    modifications_steps = [
        {'param_name': param.name, 'direction': direction}
        for param in params_in_original_order
        for direction in ([param.larger_value, param.smaller_value] if param.first_try_increase
                          else [param.smaller_value, param.larger_value])
    ]
    label.restart
    restart_scheduled = False
    while True:  # repeatedly iterate parameters
        any_tries_done_this_iteration = False
        for num_modifications in range(diamond_size + 1):  # first try small changes, later larger changes
            modification_sets = itertools.product(*(modifications_steps for _ in range(num_modifications)))
            for modifications in modification_sets:  # which modifications to try this time
                while True:  # repeatedly modify parameters in this direction
                    improvement_found_in_this_iteration = False
                    try_params = best_params.copy()
                    for modification in modifications:
                        try_params[modification['param_name']] = modification['direction'](
                            try_params[modification['param_name']])
                    for param_name, param_value in try_params.items():
                        if best_params[param_name] != param_value:
                            print(f'Setting {param_name} = {param_value} for the next run.')
                    if try_params == best_params:
                        print('Repeating experiment with best found parameters.')
                    last_best_params = best_params
                    if try_(try_params):  # if the experiment was actually conducted
                        any_tries_done_this_iteration = True
                        best_params = get_best_params_and_compare_with_initial(experiment_name, initial_params,
                                                                               larger_result_is_better, optimize,
                                                                               filter_results_condition)
                        best_scores, best_score, _, best_conf_size = results_for_params(optimize, experiment_name,
                                                                                        best_params)
                        changed_params = {k: v for k, v in best_params.items() if best_params[k] != initial_params[k]}
                        print('currently best parameters (excluding unchanged parameters):', changed_params)
                        print('currently best score:', best_score, 'conf.', best_conf_size, 'num.', len(best_scores))
                    else:
                        _, best_score, _, best_conf_size = results_for_params(optimize, experiment_name, best_params)
                    if last_best_score is not None and best_score is not None:
                        if last_best_params != best_params:
                            if (last_best_score < best_score and larger_result_is_better
                                    or last_best_score > best_score and not larger_result_is_better):
                                print(' --> Parameters were improved by this change!')
                                improvement_found_in_this_iteration = True
                                if num_modifications > 1:
                                    # two or more parameters were modified and this improved the results
                                    # -> first try to modify them again in the same direction,
                                    # then restart the search from the best found configuration
                                    restart_scheduled = True
                            elif (last_best_score > best_score and larger_result_is_better
                                  or last_best_score < best_score and not larger_result_is_better):
                                print(' --> Actually other parameters are better...')
                    last_best_score = best_score
                    if not improvement_found_in_this_iteration:
                        break  # stop if no improvement was found in this direction
                if restart_scheduled:
                    break
            if restart_scheduled:
                break
        if restart_scheduled:
            goto.restart
        if not any_tries_done_this_iteration:
            break  # parameter search finished (converged in some sense)


cross_parameter_search = functools.partial(diamond_parameter_search, diamond_size=1)
cross_parameter_search.__name__ = 'cross_parameter_search'
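
# With diamond_size=1 the loop above tries num_modifications in {0, 1}, i.e. it
# degenerates to a coordinate ("cross") search: at most one parameter is changed
# per step, each direction in turn, always restarting from the best
# configuration found so far.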


def get_best_params_and_compare_with_initial(experiment_name, initial_params, larger_result_is_better, optimize,
                                             additional_condition='1'):
    best_params = get_best_params(experiment_name, larger_result_is_better, optimize, list(initial_params),
                                  additional_condition=additional_condition)
    changed_params = {k: v for k, v in best_params.items() if best_params[k] != initial_params[k]}
    best_scores, best_score, _, best_conf_size = results_for_params(optimize, experiment_name, best_params)
    print('currently best parameters (excluding unchanged parameters):', changed_params)
    print('currently best score:', best_score, 'conf.', best_conf_size, 'num.', len(best_scores))
    return best_params


def results_for_params(optimize, experiment_name, ps):
    return get_results_for_params(
        metric=optimize,
        experiment_name=experiment_name,
        parameters=ps
    )


def modify_params_randomly(mutation_probability, params, try_params, verbose, allow_multiple_mutations=False):
    for param in params:
        while random.random() < mutation_probability:
            next_value = random.choice([param.smaller_value, param.larger_value])
            old_value = try_params[param.name]
            try:
                try_params[param.name] = round_to_digits(next_value(try_params[param.name]), 4)
            except TypeError:  # when the parameter is not a number
                try_params[param.name] = next_value(try_params[param.name])
            if verbose and try_params[param.name] != old_value:
                print('setting', param.name, '=', try_params[param.name], 'for this run')
            if not allow_multiple_mutations:
                break


def finish_experiments(experiment_name: str,
                       params: List[Parameter],
                       optimize: str,
                       larger_result_is_better: bool,
                       metric_names=None,
                       filter_results_table='1',
                       max_display_results=None,
                       print_results_table=False,
                       max_table_row_count=inf,
                       plot_metrics_by_metrics=False,
                       plot_metric_over_time=False,
                       plot_metrics_by_parameters=False, ):
    if max_display_results is inf:
        max_display_results = None
    if metric_names is None:
        metric_names = [optimize]
    # get the best parameters
    cursor = connection.cursor()
    params = sorted(params, key=lambda param: param.name)
    param_names = sorted(set(param.name for param in params))
    param_names_comma_separated = ','.join('"' + param_name + '"' for param_name in param_names)
    best_params = get_best_params(experiment_name, larger_result_is_better, optimize, param_names,
                                  additional_condition=filter_results_table, )
    best_score = get_results_for_params(
        metric=optimize,
        experiment_name=experiment_name,
        parameters=best_params
    )
    initial_params = {param.name: param.initial_value for param in params}
    # get a list of all results with mean, std and conf
    if print_results_table or plot_metrics_by_parameters or plot_metrics_by_metrics:
        concatenated_metric_names = ','.join('GROUP_CONCAT("' + metric_name + '", \'@\') AS ' + metric_name
                                             for metric_name in metric_names)
        worst_score = '-1e999999' if larger_result_is_better else '1e999999'
        limit_string = f'LIMIT {max_table_row_count}' if max_table_row_count is not None and max_table_row_count < inf else ''
        # noinspection SqlAggregates
        cursor.execute('''
            SELECT {1}, {4}
            FROM {0} AS params
            WHERE ({5})
            GROUP BY {1}
            ORDER BY AVG(CASE WHEN params.{3} IS NULL THEN {6} ELSE params.{3} END) {2}
            {7}
        '''.format(experiment_name,
                   param_names_comma_separated,
                   'DESC' if larger_result_is_better else 'ASC',
                   optimize,
                   concatenated_metric_names,
                   filter_results_table,
                   worst_score,
                   limit_string))
        all_results = cursor.fetchall()
        column_description = list(cursor.description)
        for idx, row in enumerate(all_results):
            all_results[idx] = list(row)
    # prepare results table
    if print_results_table or plot_metrics_by_metrics or plot_metrics_by_parameters:
        iterations = 0
        print('Generating table of parameters')
        for column_index, column in list(enumerate(column_description))[::-1]:  # reverse
            print_progress_bar(iterations, len(metric_names))
            column_name = column[0]
            column_description[column_index] = column
            if column_name in metric_names:
                if max_display_results is None or max_display_results > 0:
                    column_description[column_index] = column_name + ' values'
                column_description.insert(column_index + 1, column_name + ' mean')
                column_description.insert(column_index + 2, column_name + ' std')
                column_description.insert(column_index + 3, column_name + ' conf')
                # noinspection PyUnusedLocal
                list_row: List
                for list_row in all_results:
                    string_values: str = list_row[column_index]
                    if string_values is None:
                        metric_values: List[float] = [nan]
                    else:
                        metric_values = list(map(float, string_values.split('@')))
                    list_row[column_index] = [round_to_digits(x, 3) for x in metric_values[:max_display_results]]
                    list_row.insert(column_index + 1, numpy.mean(metric_values))
                    list_row.insert(column_index + 2, numpy.std(metric_values))
                    list_row.insert(column_index + 3, mean_confidence_interval_size(metric_values))
                if all(len(list_row[column_index]) == 0 for list_row in all_results):
                    del column_description[column_index]
                    for list_row in all_results:
                        del list_row[column_index]
                iterations += 1
            else:
                column_description[column_index] = column_name
        print_progress_bar(iterations, len(metric_names))
    if print_results_table:  # actually print the table
        table = my_tabulate(all_results,
                            headers=column_description,
                            tablefmt='pipe')
        print(table)
        cursor.execute('''
            SELECT COUNT(*)
            FROM {0}
        '''.format(experiment_name))
        print('Total number of rows, experiments, cells in this table:',
              (len(all_results), cursor.fetchone()[0], len(all_results) * len(all_results[0])))
    print('Best parameters:', best_params)
    changed_params = {k: v for k, v in best_params.items() if best_params[k] != initial_params[k]}
    print('Best parameters (excluding unchanged parameters):', changed_params)
    print('loss/score for best parameters (mean, std, conf):', best_score[1:])
    if plot_metrics_by_parameters or plot_metrics_by_metrics:
        print('Loading data...')
        df = pandas.DataFrame.from_records(all_results, columns=param_names + [x
                                                                               for name in metric_names
                                                                               for x in
                                                                               [name + '_values',
                                                                                name + '_mean',
                                                                                name + '_std',
                                                                                name + '_conf']])
        if plot_metrics_by_parameters:
            print('Plotting metrics by parameter...')
            plots = [
                (param.name,
                 getattr(param, 'plot_scale', None),
                 param.smaller_value if isinstance(param, BoundedParameter) else None,
                 param.larger_value if isinstance(param, BoundedParameter) else None)
                for param in params
            ]
            iterations = 0
            for metric_name in metric_names:
                dirname = 'img/results/{0}/{1}/'.format(experiment_name, metric_name)
                os.makedirs(dirname, exist_ok=True)
                for plot, x_scale, min_mod, max_mod in plots:
                    print_progress_bar(iterations, len(metric_names) * len(plots))
                    if min_mod is None:
                        min_mod = lambda x: x
                    if max_mod is None:
                        max_mod = lambda x: x
                    if df[plot].nunique() <= 1:
                        iterations += 1
                        continue
                    grid = sns.relplot(x=plot, y=metric_name + '_mean', data=df)
                    if x_scale is not None:
                        if x_scale == 'log' and min_mod(df.min(axis=0)[plot]) <= 0:
                            x_min = None
                        else:
                            x_min = min_mod(df.min(axis=0)[plot])
                        grid.set(xscale=x_scale,
                                 xlim=(x_min,
                                       max_mod(df.max(axis=0)[plot]),))
                    plt.savefig(dirname + '{0}.png'.format(plot))
                    plt.clf()
                    plt.close()
                    iterations += 1
                    print_progress_bar(iterations, len(metric_names) * len(plots))
        if plot_metrics_by_metrics:
            print('Plotting metrics by metrics...')
            dirname = 'img/results/{0}/'.format(experiment_name)
            os.makedirs(dirname, exist_ok=True)
            # Generate some plots, metric by metric
            iterations = 0
            print('Plotting metric by metric, grouped')
            for metric_name in metric_names:
                for metric_2 in metric_names:
                    if metric_name == metric_2:
                        iterations += 1
                        print_progress_bar(iterations, len(metric_names) ** 2)
                        continue
                    print_progress_bar(iterations, len(metric_names) ** 2)
                    sns.relplot(x=metric_name + '_mean', y=metric_2 + '_mean', data=df)
                    plt.savefig(dirname + '{0}_{1}_grouped.png'.format(metric_name, metric_2))
                    plt.clf()
                    plt.close()
                    heatmap_from_points(x=df[metric_name + '_mean'], y=df[metric_2 + '_mean'])
                    plt.xlabel(f'mean {metric_name}')
                    plt.ylabel(f'mean {metric_2}')
                    plt.savefig(dirname + '{0}_{1}_heatmap.png'.format(metric_name, metric_2))
                    plt.clf()
                    plt.close()
                    iterations += 1
                    print_progress_bar(iterations, len(metric_names) ** 2)
    df = pandas.read_sql_query('SELECT * FROM {0}'.format(experiment_name),
                               connection)
    df['dt_created'] = pandas.to_datetime(df['dt_created'])
    if plot_metric_over_time:
        # Generate some plots, metric over time
        dirname = 'img/results/{0}/'.format(experiment_name)
        os.makedirs(dirname, exist_ok=True)
        print('Plotting metric over time')
        iterations = 0
        for metric_name in metric_names:
            if not df[metric_name].any():
                continue
            print_progress_bar(iterations, len(metric_names))
            ax = df.plot(x='dt_created', y=metric_name, style='.')
            ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter('%m-%d %H:00'))
            plt.savefig(dirname + 'dt_created_{0}.png'.format(metric_name))
            plt.clf()
            plt.close()
            iterations += 1
            print_progress_bar(iterations, len(metric_names))
        # plot optimize grouped over time
        assert df['dt_created'].is_monotonic_increasing  # sorting should not be a problem but we are lazy
        y_means = []
        df = df.drop_duplicates(subset='dt_created')
        timestamps = pandas.DatetimeIndex(df.dt_created).asi8 // 10 ** 9
        iterations = 0
        print('Preparing plot {0} over time'.format(optimize))
        for x in timestamps:
            print_progress_bar(iterations, len(timestamps))
            not_after_x = 'CAST(strftime(\'%s\', dt_created) AS INT) <= {0}'.format(x)
            param = get_best_params(additional_condition=not_after_x,
                                    param_names=param_names,
                                    experiment_name=experiment_name,
                                    larger_result_is_better=larger_result_is_better,
                                    optimize=optimize)
            scores, mean, std, conf = get_results_for_params(optimize, experiment_name, param,
                                                             additional_condition=not_after_x)
            y_means.append(mean)
            iterations += 1
            print_progress_bar(iterations, len(timestamps))
        df['score'] = y_means
        ax = df.plot(x='dt_created', y='score')
        ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter('%m-%d %H:00'))
        plt.savefig(dirname + '{0}_over_time.png'.format(optimize))
        plt.clf()
        plt.close()
    return best_params, best_score


def predictions_for_parameters(experiment_name: str, parameters, show_progress=False):
    result_ids = result_ids_for_parameters(experiment_name, parameters)
    if not show_progress:
        return [
            predictions_for_result_id(experiment_name, result_id)
            for result_id in result_ids
        ]
    else:
        return [
            predictions_for_result_id(experiment_name, result_id)
            for result_id in ProgressBar(result_ids)
        ]


def result_ids_for_parameters(experiment_name, parameters: Dict[str, Any]):
    condition, parameters = only_specific_parameters_condition(parameters)
    cursor = connection.cursor()
    cursor.execute('''
        SELECT rowid FROM {0}
        WHERE {1}
        ORDER BY rowid
    '''.format(experiment_name, condition), parameters)
    result_ids = [row[0] for row in cursor.fetchall()]
    return result_ids


def creation_times_for_parameters(experiment_name, parameters):
    condition, parameters = only_specific_parameters_condition(parameters)
    cursor = connection.cursor()
    cursor.execute('''
        SELECT dt_created FROM {0}
        WHERE {1}
        ORDER BY rowid
    '''.format(experiment_name, condition), parameters)
    creation_times = [row[0] for row in cursor.fetchall()]
    return creation_times


def predictions_for_result_id(experiment_name: str, result_id):
    cursor = connection.cursor()
    cursor.execute('''
        SELECT name, dataset, y_pred, y_true FROM {0}_predictions
        WHERE result_id = ?
    '''.format(experiment_name, ), (result_id,))
    predictions = [{
        'name': row[0],
        'dataset': row[1],
        'y_pred': row[2],
        'y_true': row[3],
    } for row in cursor.fetchall()]
    return predictions


def list_difficult_samples(experiment_name,
                           loss_functions,
                           dataset,
                           max_losses_to_average=20,
                           additional_condition='1',
                           additional_parameters=(),
                           also_print=False):
    names = all_sample_names(dataset, experiment_name)
    cursor = connection.cursor()
    if 'epochs' in additional_condition:
        try:
            print('Creating index to fetch results faster (if not exists)...')
            cursor.execute('''
                CREATE INDEX IF NOT EXISTS {0}_by_name_epochs_dataset
                ON {0} (name, epochs, dataset)'''.format(experiment_name))
        except Exception as e:  # TODO check error type
            print(e)
    cursor = connection.cursor()
    table = []
    print('Fetching results for names...')
    for name in ProgressBar(names):
        if additional_condition == '1':
            additional_join = ''
        else:
            additional_join = 'JOIN {0} ON {0}.rowid = result_id'.format(experiment_name)
        if isinstance(max_losses_to_average, int) and max_losses_to_average != inf:
            limit_string = 'LIMIT ?'
            limit_args = [max_losses_to_average]
        elif max_losses_to_average is None or max_losses_to_average == inf:
            limit_string = ''
            limit_args = []
        else:
            raise ValueError
        cursor.execute('''
            SELECT y_pred, y_true
            FROM {0}_predictions
            {2}
            WHERE name = ? AND dataset = ? AND ({1})
            {3}'''.format(experiment_name, additional_condition, additional_join, limit_string),
                       (name,
                        dataset,
                        *additional_parameters,
                        *limit_args,))
        data = cursor.fetchall()
        if len(data) > 0:
            def aggregate(xs):
                if len(set(xs)) == 1:
                    return xs[0]
                else:
                    return numpy.mean(xs)

            table.append((*[aggregate([loss_function(y_pred=y_pred, y_true=y_true, name=name)
                                       for y_pred, y_true in data])
                            for loss_function in loss_functions],
                          name, len(data)))
    print('sorting table...')
    table.sort(reverse=True)
    if also_print:
        print('stringifying table...')
        print(my_tabulate(table,
                          headers=[loss_function.__name__ for loss_function in loss_functions] + ['name', '#results'],
                          tablefmt='pipe'))
    return table


def all_sample_names(dataset, experiment_name):
    cursor = connection.cursor()
    print('Creating index to have faster queries by name (if not exists)...')
    cursor.execute('''
        CREATE INDEX IF NOT EXISTS {0}_predictions_by_name_and_dataset
        ON {0}_predictions (dataset, name)'''.format(experiment_name))
    print('Fetching all names...')
    names = []
    last_found = ''  # smaller than all other strings
    while True:
        cursor.execute('SELECT name '
                       'FROM {0}_predictions '
                       'WHERE dataset = ? AND name > ? '
                       'LIMIT 1'.format(experiment_name), (dataset, last_found))
        row = cursor.fetchone()
        if row is None:
            break
        names.append(row[0])
        last_found = row[0]
    return names


def only_specific_parameters_condition(parameters: Dict[str, Any]) -> Tuple[str, Tuple]:
    items = list(parameters.items())  # to have the correct ordering
    return ('(' + ' AND '.join(f'"{name}" IS ?' for name, _ in items) + ')',
            tuple(value for name, value in items))
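

# For example (illustrative): only_specific_parameters_condition(
#     {'learn_rate': 0.001, 'epochs': 100})
# returns ('("learn_rate" IS ? AND "epochs" IS ?)', (0.001, 100)).
# `IS` is used instead of `=` so that NULL-valued parameters compare as equal.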


def only_best_parameters_condition(experiment_name: str,
                                   larger_result_is_better: bool,
                                   optimize: str,
                                   param_names: List[str],
                                   additional_condition: str = '1') -> Tuple[str, Tuple]:
    parameters = get_best_params(experiment_name=experiment_name,
                                 larger_result_is_better=larger_result_is_better,
                                 optimize=optimize,
                                 param_names=param_names,
                                 additional_condition=additional_condition)
    return only_specific_parameters_condition(parameters)


def get_results_for_params(metric, experiment_name, parameters, confidence=0.95,
                           additional_condition='1'):
    param_names = list(parameters.keys())
    cursor = connection.cursor()
    params_equal = '\nAND '.join('"' + param_name + '" IS ?' for param_name in param_names)
    cursor.execute(
        '''
        SELECT {0}
        FROM {1}
        WHERE {2} AND ({3})
        '''.format(metric,
                   experiment_name,
                   params_equal,
                   additional_condition),
        tuple(parameters[name] for name in param_names)
    )
    # noinspection PyShadowingNames
    scores = [row[0] if row[0] is not None else nan for row in cursor.fetchall()]
    if len(scores) == 0:
        return scores, nan, nan, nan
    return scores, numpy.mean(scores), numpy.std(scores), mean_confidence_interval_size(scores, confidence)


def num_results_for_params(param_names, experiment_name, parameters,
                           additional_condition='1'):
    cursor = connection.cursor()
    params_equal = '\nAND '.join('"' + param_name + '" IS ?' for param_name in param_names)
    cursor.execute(
        '''
        SELECT COUNT(*)
        FROM {0}
        WHERE {1} AND ({2})
        '''.format(experiment_name,
                   params_equal,
                   additional_condition),
        tuple(parameters[name] for name in param_names)
    )
    return cursor.fetchone()[0]


def get_best_params(experiment_name: str,
                    larger_result_is_better: bool,
                    optimize: str,
                    param_names: List[str],
                    additional_condition='1') -> Optional[Parameters]:
    cursor = connection.cursor()
    param_names_comma_separated = ','.join('"' + param_name + '"' for param_name in param_names)
    worst_score = '-1e999999' if larger_result_is_better else '1e999999'
    # noinspection SqlAggregates
    cursor.execute('''
        SELECT * FROM {0} AS params
        WHERE ({4})
        GROUP BY {1}
        ORDER BY AVG(CASE WHEN params.{3} IS NULL THEN {5} ELSE params.{3} END) {2}, MIN(rowid) ASC
        LIMIT 1
    '''.format(experiment_name,
               param_names_comma_separated,
               'DESC' if larger_result_is_better else 'ASC',
               optimize,
               additional_condition,
               worst_score, ))
    row = cursor.fetchone()
    if row is None:
        raise EmptyTableError()
    else:
        return params_from_row(cursor.description, row, param_names=param_names)


def params_from_row(description, row, param_names=None) -> Parameters:
    best_params = {}
    for idx, column_description in enumerate(description):
        column_name = column_description[0]
        if param_names is None or column_name in param_names:
            best_params[column_name] = row[idx]
    return best_params


def create_experiment_tables_if_not_exists(experiment_name, params, metric_names):
    cursor = connection.cursor()
    param_names = set(param.name for param in params)
    initial_params = {param.name: param.initial_value for param in params}
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS {0}(
            rowid INTEGER PRIMARY KEY,
            dt_created datetime DEFAULT CURRENT_TIMESTAMP
        )
    '''.format(experiment_name))
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS {0}_predictions(
            rowid INTEGER PRIMARY KEY,
            dataset TEXT NOT NULL,
            y_true BLOB,
            y_pred BLOB,
            name TEXT NOT NULL, -- used to identify the samples
            result_id INTEGER NOT NULL REFERENCES {0}(rowid),
            UNIQUE(result_id, dataset, name) -- gives additional indices
        )
    '''.format(experiment_name))
    connection.commit()
    for param_name in param_names:
        default_value = initial_params[param_name]
        add_parameter_column(experiment_name, param_name, default_value)
    for metric_name in metric_names:
        add_metric_column(experiment_name, metric_name)
  1144. def add_metric_column(experiment_name, metric_name, verbose=0):
  1145. cursor = connection.cursor()
  1146. try:
  1147. cursor.execute('ALTER TABLE {0} ADD COLUMN "{1}" NUMERIC DEFAULT NULL'.format(experiment_name,
  1148. metric_name))
  1149. except sqlite3.OperationalError as e:
  1150. if 'duplicate column name' not in e.args[0]:
  1151. raise
  1152. else:
  1153. if verbose:
  1154. print(f'WARNING: created additional column {metric_name}. This may or may not be intentional')
  1155. connection.commit()
def add_parameter_column(experiment_name, param_name, default_value, verbose=0):
    cursor = connection.cursor()
    try:
        if isinstance(default_value, str):
            # SQLite escapes single quotes by doubling them; the result of str.replace
            # must be assigned back, since strings are immutable.
            default_value = default_value.replace("'", "''")
            default_value = "'" + default_value + "'"
        if default_value is None:
            default_value = 'NULL'
        cursor.execute('ALTER TABLE {0} ADD COLUMN "{1}" BLOB DEFAULT {2}'.format(experiment_name,
                                                                                  param_name,
                                                                                  default_value))
    except sqlite3.OperationalError as e:
        if 'duplicate column name' not in e.args[0]:
            raise
    else:
        if verbose:
            print(f'WARNING: created additional column {param_name} with default value {default_value}. '
                  f'This may or may not be intentional')
    connection.commit()
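# Parameter columns are typed BLOB so SQLite's type affinity stores numbers, strings
# and NULL unchanged. For example, a hypothetical add_parameter_column('test', 'lr', 0.01)
# issues:
#   ALTER TABLE test ADD COLUMN "lr" BLOB DEFAULT 0.01
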
def markdown_table(all_results, sort_by):
    rows = [list(result['params'].values()) + [result['mean'], result['std'], result['conf'], result['all']]
            for result in all_results]
    rows.sort(key=sort_by)
    table = my_tabulate(rows,
                        headers=list(all_results[0]['params'].keys()) + ['mean', 'std', 'conf', 'results'],
                        tablefmt='pipe')
    return table
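# Example output (hypothetical data, assuming my_tabulate mirrors tabulate's 'pipe'
# format) for a single parameter column A:
#   |   A |   mean |   std |   conf | results      |
#   |----:|-------:|------:|-------:|:-------------|
#   |  10 |   0.35 |  0.02 |   0.01 | [0.34, 0.36] |
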
def validate_parameter_set(params):
    if len(params) == 0:
        raise ValueError('Parameter set empty')
    for i, param in enumerate(params):
        # noinspection PyUnusedLocal
        other_param: Parameter
        for other_param in params[i + 1:]:
            if param.name == other_param.name and param.initial_value != other_param.initial_value:
                msg = '''
                A single parameter can't have multiple initial values.
                Parameter "{0}" has initial values "{1}" and "{2}"
                '''.format(param.name, param.initial_value, other_param.initial_value)
                raise ValueError(msg)
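# Example (argument meanings assumed to match the __main__ block below): this raises,
# because "A" is given two different initial values:
#   validate_parameter_set([LinearParameter('A', 10, 10), LinearParameter('A', 20, 10)])
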
def run_name(parameters=None) -> str:
    if parameters is None:
        parameters = {}
    shorter_parameters = {
        shorten_name(k): shorten_name(v)
        for k, v in parameters.items()
    }
    # Strip characters that are problematic in file names; ':' is replaced by the
    # visually similar '⦂' so timestamps stay readable.
    return ((str(datetime.now()) + str(shorter_parameters).replace(' ', ''))
            .replace("'", '')
            .replace('"', '')
            .replace(':', '⦂')
            .replace(',', '')
            .replace('_', '')
            .replace('<', '')
            .replace('>', '')
            .replace('{', '')
            .replace('}', ''))
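# Example (assuming shorten_name abbreviates 'learning_rate' to something like 'lr'):
# run_name({'learning_rate': 0.01}) returns a string such as
# '2023-01-01 12⦂34⦂56.789012lr⦂0.01', unique per call and safe as a file name.
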
def plot_experiment(metric_names,
                    experiment_name: str,
                    plot_name: str,
                    param_names: List[str],
                    params_list: List[Parameters],
                    evaluate: EvaluationFunction,
                    ignore: List[str] = None,
                    plot_shape=None,
                    metric_limits: Dict = None,
                    titles=None,
                    natural_metric_names: Dict[str, str] = None,
                    min_runs_per_params=0,
                    single_plot_width=6.4,
                    single_plot_height=4.8):
    if natural_metric_names is None:
        natural_metric_names = {}
    for parameters in params_list:
        if 'epochs' not in parameters:
            raise ValueError('`plot_experiment` needs the number of epochs to plot (`epochs`)')
    if metric_limits is None:
        metric_limits = {}
    if ignore is None:
        ignore = []
    if titles is None:
        titles = [None for _ in params_list]
    if plot_shape is None:
        width = ceil(sqrt(len(params_list)))
        plot_shape = (ceil(len(params_list) / width), width)
    else:
        width = plot_shape[1]
    plot_shape_offset = 100 * plot_shape[0] + 10 * plot_shape[1]
    axes: Dict[int, Axes] = {}
    legend: List[str] = []
    results_dir = 'img/results/{0}/over_time/'.format(experiment_name)
    os.makedirs(results_dir, exist_ok=True)
    # Sort metrics so that metrics sharing the same y-limits are adjacent and thus end
    # up in the same figure; ties keep their original order.
    metric_names = sorted(metric_names, key=lambda m: (metric_limits.get(m, ()), metric_names.index(m)))
    print(metric_names)
    plotted_metric_names = []
    iterations = 0
    for plot_std in [False, True]:
        plt.figure(figsize=(single_plot_width * plot_shape[1], single_plot_height * plot_shape[0]))
        for idx, metric in enumerate(metric_names):
            print_progress_bar(iterations, 2 * (len(metric_names) - len(ignore)))
            limits = metric_limits.get(metric, None)
            try:
                next_limits = metric_limits.get(metric_names[idx + 1], None)
            except IndexError:
                next_limits = None
            if metric in ignore:
                continue
            sqlite_infinity = '1e999999'
            metric_is_finite = '{0} IS NOT NULL AND {0} != {1} AND {0} != -{1}'.format(metric, sqlite_infinity)
            for plot_idx, parameters in enumerate(params_list):
                # Run any missing experiments for this configuration before plotting it.
                while num_results_for_params(param_names=param_names,
                                             experiment_name=experiment_name,
                                             parameters=parameters) < min_runs_per_params:
                    print('Doing one of the missing experiments for the plot:')
                    print(parameters)
                    results = try_parameters(experiment_name=experiment_name,
                                             evaluate=evaluate,
                                             params=parameters)
                    assert any(result.parameters == parameters for result in results)
                contains_avg_over = 'average_over_last_epochs' in parameters
                total_epochs = parameters['epochs']
                history = []
                lower_conf_limits = []
                upper_conf_limits = []
                for epoch_end in range(total_epochs):
                    current_parameters = parameters.copy()
                    if contains_avg_over:
                        current_parameters['average_over_last_epochs'] = None
                    current_parameters['epochs'] = epoch_end + 1
                    scores, mean, std, conf = get_results_for_params(
                        metric=metric,
                        experiment_name=experiment_name,
                        parameters=current_parameters,
                        additional_condition=metric_is_finite
                    )
                    history.append(mean)
                    if plot_std:
                        # 1.959964 is the two-sided 95% z-score: the band shows the
                        # spread of the individual results.
                        lower_conf_limits.append(mean - 1.959964 * std)
                        upper_conf_limits.append(mean + 1.959964 * std)
                    else:
                        # Here the band is the 95% confidence interval of the mean.
                        lower_conf_limits.append(mean - conf)
                        upper_conf_limits.append(mean + conf)
                x = list(range(len(history)))
                if plot_shape_offset + plot_idx + 1 not in axes:
                    # noinspection PyTypeChecker
                    ax: Axes = plt.subplot(plot_shape_offset + plot_idx + 1)
                    assert isinstance(ax, Axes)
                    axes[plot_shape_offset + plot_idx + 1] = ax
                ax = axes[plot_shape_offset + plot_idx + 1]
                ax.plot(x, history)
                ax.fill_between(x, lower_conf_limits, upper_conf_limits, alpha=0.4)
                if titles[plot_idx] is not None:
                    ax.set_title(titles[plot_idx])
                if limits is not None:
                    ax.set_ylim(limits)
                ax.set_xlim(0, max(total_epochs, ax.get_xlim()[1]))
                current_row = plot_idx // width
                if current_row == plot_shape[0] - 1:
                    ax.set_xlabel('Epoch')
            natural_name = natural_metric_names.get(metric, metric)
            if plot_std:
                legend += ['mean ' + natural_name, '1.96σ of {0}'.format(natural_name)]
            else:
                legend += ['mean ' + natural_name, '95% conf. of mean {0}'.format(natural_name)]
            plotted_metric_names.append(metric)
            # Finish the current figure once the next metric has different y-limits.
            if limits is None or next_limits is None or limits != next_limits:
                # Bare legend labels are assigned to line artists before fill artists,
                # so reorder the entries: all means first, then all bands.
                legend = legend[0::2] + legend[1::2]
                for ax in axes.values():
                    ax.legend(legend)
                if plot_std:
                    plt.savefig(results_dir + plot_name + '_' + ','.join(plotted_metric_names) + '_std' + '.png')
                else:
                    plt.savefig(results_dir + plot_name + '_' + ','.join(plotted_metric_names) + '.png')
                plt.clf()
                plt.close()
                plt.figure(figsize=(single_plot_width * plot_shape[1], single_plot_height * plot_shape[0]))
                axes = {}
                plotted_metric_names = []
                legend = []
            iterations += 1
            print_progress_bar(iterations, 2 * (len(metric_names) - len(ignore)))
    plt.clf()
    plt.close()
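# Hypothetical usage sketch (parameter values illustrative): one subplot per
# configuration, with shared y-limits so both loss metrics land in one figure:
#   plot_experiment(['train_loss', 'val_loss'],
#                   experiment_name='test',
#                   plot_name='losses',
#                   param_names=['A', 'B'],
#                   params_list=[{'A': 10, 'B': 8, 'epochs': 50}],
#                   evaluate=evaluate,
#                   metric_limits={'train_loss': (0, 1), 'val_loss': (0, 1)})
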
if __name__ == '__main__':
    def evaluate(params):
        return (params['A'] - 30) ** 2 + 10 * ((params['B'] / (params['A'] + 1)) - 1) ** 2 + params['C']

    diamond_parameter_search('test',
                             diamond_size=2,
                             params=[LinearParameter('A', 10, 10),
                                     ExponentialIntegerParameter('B', 8, 2),
                                     ConstantParameter('C', 5)],
                             runs_per_configuration=1,
                             initial_runs=1,
                             evaluate=evaluate,
                             optimize='loss',
                             larger_result_is_better=False)
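    # The toy objective is minimized near A=30 and B=A+1=31, where both squared terms
    # vanish and the loss approaches C=5, so the search should move toward that region.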