progress_bar.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import functools
  2. import math
  3. import time
  4. from math import floor
  5. from typing import Iterable, Sized, Iterator
  class ProgressBar(Sized, Iterable):
      """Text progress bar that redraws itself in place on stdout.

      It can be used in three ways:

      * iterate over it: ``for x in ProgressBar(data): ...``
      * drive it manually with ``print_progress`` / ``step``
      * decorate a function with ``monitor`` so every call advances the bar

      ``num_steps`` may be an int (number of steps), a sized iterable (it is
      iterated over and its length is the number of steps), a sized
      non-iterable (only its length is used) or ``None`` (define it later via
      ``set_num_steps``).
      """

      def __init__(self,
                   num_steps=None,
                   prefix='',
                   suffix='',
                   line_length=75,
                   empty_char='-',
                   fill_char='#',
                   print_eta=True,
                   decimals=1):
          """
          @params:
              num_steps   - Optional : int, sized iterable, sized object or None
              prefix      - Optional : text printed before the bar (Str)
              suffix      - Optional : text printed after the percentage (Str)
              line_length - Optional : total width of one printed line (Int)
              empty_char  - Optional : character for the unfilled part (Str)
              fill_char   - Optional : character for the filled part (Str)
              print_eta   - Optional : append the remaining / total time (Bool)
              decimals    - Optional : decimals of the percentage (Int)
          @raises:
              TypeError - if num_steps has an unsupported type
          """
          self.decimals = decimals
          self.line_length = line_length
          self.suffix = suffix
          self.empty_char = empty_char
          self.prefix = prefix
          self.fill_char = fill_char
          self.print_eta = print_eta
          self.current_iteration = 0
          self.last_printed_value = None
          self.current_iterator = None
          self.start_time = time.perf_counter()
          self._backend = None
          # Single place that interprets num_steps (shared with set_num_steps).
          self.set_num_steps(num_steps)

      def set_num_steps(self, num_steps):
          """(Re)define the number of steps of one bar.

          @params:
              num_steps - int (number of steps), sized iterable (iterated over),
                          sized non-iterable (only its length is used) or None
          @raises:
              TypeError - if num_steps is none of the above
          """
          try:
              self._backend = range(num_steps)
          except TypeError:
              if isinstance(num_steps, Sized):
                  if isinstance(num_steps, Iterable):
                      self._backend = num_steps
                  else:
                      self._backend = range(len(num_steps))
              elif num_steps is None:
                  self._backend = None
              else:
                  raise

      def __len__(self):
          """Number of steps of one bar (TypeError while it is still undefined)."""
          return len(self._backend)

      def __iter__(self) -> Iterator:
          self.check_if_num_steps_defined()
          # -1 so that handing out the first element advances the counter to 0:
          # while the loop body for element k (1-based) runs, k - 1 steps are done.
          self.current_iteration = -1
          self.current_iterator = iter(self._backend)
          self.start_time = time.perf_counter()
          return self

      def __next__(self):
          # No redraw before fetching: the current state was already drawn after
          # the previous step, and drawing at -1 would render a negative bar.
          try:
              result = next(self.current_iterator)
          except StopIteration:
              # The body for the last element has finished: draw the full bar.
              self.increment_iteration()
              self.print_progress()
              raise
          self.increment_iteration()
          self.print_progress()
          return result

      def step(self, num_iterations=1):
          """Advance the bar by ``num_iterations`` steps and redraw it."""
          self.current_iteration += num_iterations
          self.print_progress()

      def print_progress(self, iteration=None):
          """
          Call in a loop to create terminal progress bar
          @params:
              iteration - Optional : current iteration (Int); when given it
                          replaces ``current_iteration`` before drawing
          """
          if iteration is not None:
              self.current_iteration = iteration
          num_steps = len(self)
          try:
              progress = self.current_iteration / num_steps
          except ZeroDivisionError:
              # A bar with zero steps is complete by definition.
              progress = 1
          if self.current_iteration == 0:
              # (Re)start the clock whenever the bar is drawn at its beginning.
              self.start_time = time.perf_counter()
          if self.print_eta and progress > 0:
              time_spent = time.perf_counter() - self.start_time
              if progress >= 1:
                  eta = f' T = {self._format_duration(time_spent)}'
              else:
                  # Linear extrapolation of the time spent so far.
                  remaining = time_spent / progress * (1 - progress)
                  eta = f' ETA {self._format_duration(remaining)}'
          else:
              eta = ''
          percent = f'{100 * progress:{4 + self.decimals}.{self.decimals}f}'
          # 6 = ' |' + '| ' + '% ' surrounding the bar and the percentage; the
          # percentage itself is measured so the line really fits line_length.
          bar_length = max(0, self.line_length - len(self.prefix) - len(self.suffix)
                           - len(eta) - len(percent) - 6)
          bar = self._render_bar(bar_length, progress)
          print_value = '\r{0} |{1}| {2}% {4}{3}'.format(self.prefix, bar, percent, eta, self.suffix)
          if self.current_iteration == num_steps:
              print_value += '\n'  # Print New Line on Complete
          if self.last_printed_value == print_value:
              return  # nothing changed on screen: avoid redundant writes
          self.last_printed_value = print_value
          print(print_value, end='')

      def _render_bar(self, bar_length, progress):
          """Return the bar body: filled cells, an optional tenths digit, empty cells."""
          try:
              filled_length = int(bar_length * self.current_iteration // len(self))
          except ZeroDivisionError:
              filled_length = bar_length
          if math.isclose(bar_length * progress, filled_length):
              overflow = 0
          else:
              # The partially filled cell shows how many tenths of it are done.
              overflow = floor((bar_length * progress - filled_length) * 10)
          assert overflow in range(10), overflow
          if overflow > 0:
              return (self.fill_char * filled_length + str(overflow)
                      + self.empty_char * (bar_length - filled_length - 1))
          return self.fill_char * filled_length + self.empty_char * (bar_length - filled_length)

      @staticmethod
      def _format_duration(seconds):
          """Format a duration as ``MM:SS``; negative durations are shown as ``00:00``.

          Rounding happens on the total so 59.6 s becomes ``01:00``, not ``00:60``.
          """
          minutes, secs = divmod(max(0, round(seconds)), 60)
          return f'{minutes:02d}:{secs:02d}'

      def increment_iteration(self):
          """Advance by one step; past a completed bar the counter wraps around."""
          num_steps = len(self)
          self.current_iteration += 1
          # catches the special case at the end of the bar (n + 1 -> 1);
          # skipped for zero-step bars, where the modulo would divide by zero.
          if num_steps and self.current_iteration > num_steps:
              self.current_iteration %= num_steps

      def monitor(self, func=None):
          """ Decorates the given function func to print a progress bar before and after each call. """
          if func is None:
              # Partial application, to be able to specify extra keyword
              # arguments in decorators
              return functools.partial(self.monitor)

          @functools.wraps(func)
          def wrapper(*args, **kwargs):
              self.check_if_num_steps_defined()
              num_steps = len(self)
              if num_steps and self.current_iteration >= num_steps:
                  # The previous bar is complete: start a fresh one at 0 so it is
                  # drawn empty and its clock restarts (otherwise the new bar
                  # would begin at 1 and its ETA would include the old bar).
                  self.current_iteration = 0
              self.print_progress()
              result = func(*args, **kwargs)
              self.increment_iteration()
              self.print_progress()
              return result
          return wrapper

      def check_if_num_steps_defined(self):
          """Raise a RuntimeError if no number of steps has been set yet."""
          if self._backend is None:
              raise RuntimeError('You need to specify the number of iterations before starting to iterate. '
                                 'You can either pass it to the constructor or use the method `set_num_steps`.')
  if __name__ == '__main__':
      # Simplest use: wrap an iterable and loop over it.
      for delay in ProgressBar([0.5, 2, 0.5]):
          time.sleep(delay)

      # Drive the bar by hand: draw, do the work, then advance one step.
      data = [1, 5, 5, 6, 12, 3, 4, 5]
      total = 0
      p = ProgressBar(len(data))
      for value in data:
          p.print_progress()
          time.sleep(0.2)
          total += value
          p.step()
      print(total)

      # Or redraw the bar around every call of a decorated function.
      p = ProgressBar()

      @p.monitor
      def heavy_computation(t=0.25):
          time.sleep(t)

      p.set_num_steps(10)  # 10 steps per bar
      for _ in range(20):  # draws 2 bars
          heavy_computation(0.25)