plots.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  import os
  import sqlite3
  import sys
  import time
  from typing import Union

  import matplotlib.animation
  import matplotlib.pyplot as plt
  import networkx as nx
  import numpy as np
  class Plots:
      """Matplotlib helpers for simple 2-D plots.

      Creating an instance opens a new pyplot figure (``self.figure``). The
      plotting methods then draw through the ``pyplot`` state interface, i.e. on
      the current figure/axes.
      """

      def __init__(self, x_axe, y_axe, z_axe=None, x_label='x', y_label='y', z_label='z',
                   x_axe2=None, y_axe2=None, z_axe2=None, grid=True, figsize=(13, 10)):
          """Store the data series and open a new figure.

          :param x_axe: x values / category labels of the first series
          :param y_axe: y values of the first series (blue bars)
          :param z_axe: z values of the first series (not used by this class's methods)
          :param x_label: x-axis label
          :param y_label: y-axis label
          :param z_label: z-axis label (not used by this class's methods)
          :param x_axe2: x values of the second series (not used by this class's methods)
          :param y_axe2: y values of the second series (red bars)
          :param z_axe2: z values of the second series (not used by this class's methods)
          :param grid: draw a grid in the bar chart
          :param figsize: figure size passed to ``plt.figure``
          """
          self.x_axe = x_axe
          self.y_axe = y_axe
          # ``None`` defaults give each instance its own list. The former ``[]``
          # defaults were a single list object shared by every instance.
          self.z_axe = [] if z_axe is None else z_axe
          self.x_label = x_label
          self.y_label = y_label
          self.z_label = z_label
          self.x_axe2 = [] if x_axe2 is None else x_axe2
          self.y_axe2 = [] if y_axe2 is None else y_axe2
          self.z_axe2 = [] if z_axe2 is None else z_axe2
          self.grid = grid
          self.figure = plt.figure(figsize=figsize)

      def get_a_exponential_func_as_plot(self, linrange=(-5, 5), expo=2, num=50):
          """Plot the power function ``y = x ** expo`` as a green line.

          Despite the name, the curve is a power function, not an exponential.

          :param linrange: ``(start, stop)`` of the sampled x interval, both inclusive
          :param expo: exponent applied to every sample
          :param num: number of samples (50 matches the previous implicit
              ``np.linspace`` default)
          """
          x = np.linspace(linrange[0], linrange[1], num)
          plt.plot(x, x ** expo, 'g')

      def plot_2D_compare_bar_chart(self, legend=None, width=0.35, title=None, path=''):
          """Draw ``y_axe`` (blue) and ``y_axe2`` (red) as side-by-side bars, then save and show.

          :param legend: up to two labels, ``[label for y_axe, label for y_axe2]``
          :param width: width of each bar; the red bars are shifted right by it
          :param title: optional plot title. The previous default was the builtin
              ``str``, which is truthy and rendered "<class 'str'>" as the title.
          :param path: directory for the saved PNG; falsy saves to the working directory
          """
          positions = np.arange(len(self.x_axe))
          labels = list(legend) if legend else []
          # Pass ``label`` only for bars that have one; a one-element legend no
          # longer raises IndexError.
          blue_kwargs = {'label': labels[0]} if len(labels) > 0 else {}
          red_kwargs = {'label': labels[1]} if len(labels) > 1 else {}
          plt.bar(positions, self.y_axe, width=width, color='blue', **blue_kwargs)
          plt.bar(positions + width, self.y_axe2, width=width, color='red', **red_kwargs)
          if self.grid:
              plt.grid()
          plt.xlabel(self.x_label)
          plt.ylabel(self.y_label)
          if title:
              plt.title(title)
          # Centre each tick label between its pair of bars.
          plt.xticks(positions + width / 2, self.x_axe)
          if labels:
              # Without labelled artists, legend() only emits a matplotlib warning.
              plt.legend(loc='best')
          self.save_fig(name='plot_2D_compare_bar_chart', path=path or False)
          plt.show()

      def save_fig(self, name, path: Union[bool, str] = False):
          """Save the current figure as ``<name>_<timestamp>.png``.

          The timestamp looks like ``2024-01-31_H13-M05`` (minute resolution).

          :param name: file-name prefix
          :param path: target directory; falsy saves to the working directory
          """
          plt.savefig(self._timestamped_path(name, path, 'png'))

      @staticmethod
      def _timestamped_path(name, path, extension, stamp=None):
          """Build ``[path/]<name>_<stamp>.<extension>``; ``stamp`` defaults to the current minute."""
          if stamp is None:
              stamp = time.strftime("%Y-%m-%d_H%H-M%M")
          file_name = '{}_{}.{}'.format(name, stamp, extension)
          # os.path.join supplies the separator the old ``path + file`` concatenation lacked.
          return os.path.join(path, file_name) if path else file_name
  # Constant Layout: drawing and saving parameters read by the NetworkxPlots methods.

  # Nodes
  NODE_SIZE = 200                      # marker size passed to draw_networkx_nodes
  NODE_EDGE_COLOR = 'black'            # outline colour of every node marker

  # Edges
  EDGE_WIDTH = 0.5
  CONNECTION_STYLE = 'arc3, rad=0.2'   # curved edges in the directed plot
  ARROW_SIZE = 12                      # arrow-head size in the directed plot
  LABLE_POS = 0.35                     # (sic) edge-label position along the edge; name kept for callers

  # Text
  FONT_SIZE = 8                        # node labels
  FONT_SIZE_EDGES = 1                  # edge-weight labels in the directed plot
  FONT_FAMILY = 'sans-serif'

  # Saving (used when a plot method receives a (name, path) tuple)
  SAVE_FORMAT = 'svg'
  DPI = 1200
  class NetworkxPlots:
      """Build and draw a networkx graph whose nodes carry a type, a colour and a position.

      ``node_dict`` maps node id -> dict with the keys ``'node_type'``,
      ``'node_color'`` and ``'node_name'``. The value under ``'node_name'`` is used
      as the node's ``(x, y)`` drawing position. ``add_nodes_to_graph`` and
      ``edges_for_complete_graph`` number nodes ``1..n``, so ``node_dict`` is
      presumably keyed by those integers (TODO: confirm with callers).

      Drawing reads the module-level layout constants (``NODE_SIZE``,
      ``EDGE_WIDTH``, ``SAVE_FORMAT``, ``DPI``, ...).
      """

      def __init__(self, node_dict, pathing=None, color_map=None, legend=None, edges=None,
                   edge_weightings=None, directed_graph=False):
          """
          :param node_dict: node id -> {'node_type', 'node_color', 'node_name' (= position)}
          :param pathing: ``(start, end)`` steps used by ``directions_for_directed_graph``
          :param color_map: stored only (not used by this class)
          :param legend: stored only (not used by this class)
          :param edges: edge list for ``add_edges_to_graph`` or an n x n matrix for
              ``edges_for_complete_graph``
          :param edge_weightings: per-edge weights, or one scalar, for ``add_edges_to_graph``
          :param directed_graph: build an ``nx.DiGraph`` instead of an ``nx.Graph``
          """
          self.graph = nx.DiGraph() if directed_graph else nx.Graph()
          # Nodes
          self.node_type_dict, self.node_color_dict, self.node_pos_dict = self._group_nodes(node_dict)
          # Use ``is None`` (not ``or``) so numpy matrices are stored unchanged.
          # Fresh lists replace the previously shared mutable ``[]`` defaults.
          self.edges = [] if edges is None else edges
          self.pathing = [] if pathing is None else pathing
          self.color_map = [] if color_map is None else color_map
          self.edge_weightings = [] if edge_weightings is None else edge_weightings
          self.legend = [] if legend is None else legend

      @staticmethod
      def _group_nodes(node_dict):
          """Split ``node_dict`` into (ids per type, colour per type, position per id) in one pass."""
          type_dict, color_dict, pos_dict = {}, {}, {}
          for node_id, attrs in node_dict.items():
              node_type = attrs['node_type']
              type_dict.setdefault(node_type, []).append(node_id)
              # One colour per type: the last node seen of that type wins.
              color_dict[node_type] = attrs['node_color']
              pos_dict[node_id] = attrs['node_name']
          return type_dict, color_dict, pos_dict

      def edges_for_complete_graph(self):
          """Add edge ``(row + 1, column + 1)`` for every matrix entry that is non-zero at one decimal.

          ``self.edges`` is read as an n x n matrix (numpy array or nested lists).
          Every edge gets ``weight=1``; the matrix value itself is not stored.
          Diagonal entries are not skipped, so self-loops are avoided only when the
          diagonal rounds to 0.

          :raises ValueError: if a non-empty ``self.edges`` is not 2-D
          """
          matrix = np.asarray(self.edges)
          if matrix.size == 0:
              return
          if matrix.ndim != 2:
              raise ValueError('edges must be a 2-D matrix, got shape {}'.format(matrix.shape))
          # np.nonzero yields indices in row-major order, the same order as the old nested loops.
          rows, columns = np.nonzero(np.round(matrix, 1) != 0)
          for row, column in zip(rows.tolist(), columns.tolist()):
              self.graph.add_edge(row + 1, column + 1, weight=1)

      def directions_for_directed_graph(self, pathing=None):
          """Add one edge per ``(start, end)`` step, weighted by its 1-based position in the path.

          :param pathing: steps to add; falsy falls back to ``self.pathing``
          """
          steps = pathing if pathing else self.pathing
          for order, (start, end, *_) in enumerate(steps, start=1):
              self.graph.add_edge(start, end, weight=order)

      def add_nodes_to_graph(self):
          """Add nodes ``1..n``, where n is the number of positioned nodes."""
          self.graph.add_nodes_from(range(1, len(self.node_pos_dict) + 1))

      def add_edges_to_graph(self):
          """Add ``self.edges`` with their weights from ``self.edge_weightings``.

          Previously the whole weight list was attached to every edge. With the
          default ``[]`` the weight was falsy, so no edge was ever drawn.

          :raises ValueError: if the number of weights differs from the number of edges
          """
          self.graph.add_edges_from(self._weighted_edge_bunch(self.edges, self.edge_weightings))

      @staticmethod
      def _weighted_edge_bunch(edges, weightings):
          """Pair each edge with its weight as ``(u, v, {'weight': w})``.

          A scalar weight applies to every edge, and an empty sequence gives weight 1
          (as ``edges_for_complete_graph`` does). Attributes carried by a 3-tuple
          edge override the weight, matching networkx's ebunch precedence.
          """
          edges = list(edges)
          if np.ndim(weightings) == 0:
              weights = [weightings] * len(edges)
          else:
              weights = list(weightings)
              if not weights:
                  weights = [1] * len(edges)
              elif len(weights) != len(edges):
                  raise ValueError('got {} edge weightings for {} edges'.format(len(weights), len(edges)))
          bunch = []
          for edge, weight in zip(edges, weights):
              u, v, *rest = edge
              attrs = {'weight': weight}
              if rest:
                  attrs.update(rest[0])
              bunch.append((u, v, attrs))
          return bunch

      def undirected_graph_plt_(self, name_path_tupl=''):
          """Draw nodes and straight edges, optionally save, then show.

          :param name_path_tupl: ``(name, path)`` to save the picture; falsy skips saving
          :return: None; shows the undirected graph (and saves it if requested)
          """
          ax = self._prepare_axes()
          nx.draw_networkx_edges(self.graph, self.node_pos_dict, edgelist=self._drawable_edges(),
                                 width=EDGE_WIDTH, ax=ax)
          self.create_graph_settings_with_cartesic_coord(plt, ax)
          self._save_if_requested(name_path_tupl)
          plt.show()

      def directed_graph_with_path_plt_(self, name_path_tupl=''):
          """Draw nodes, curved arrows and edge-weight labels, optionally save, then show.

          :param name_path_tupl: ``(name, path)`` to save the picture; falsy skips saving
          :return: None; shows the directed graph (and saves it if requested)
          """
          ax = self._prepare_axes()
          # Curved connection style, presumably so u->v and v->u stay distinguishable.
          nx.draw_networkx_edges(self.graph, self.node_pos_dict, edgelist=self._drawable_edges(),
                                 width=EDGE_WIDTH, ax=ax, connectionstyle=CONNECTION_STYLE,
                                 arrowsize=ARROW_SIZE)
          edge_labels = nx.get_edge_attributes(self.graph, 'weight')
          nx.draw_networkx_edge_labels(self.graph, self.node_pos_dict, edge_labels=edge_labels, ax=ax,
                                       label_pos=LABLE_POS, font_size=FONT_SIZE_EDGES)
          self.create_graph_settings_with_cartesic_coord(plt, ax)
          # Apply equal aspect before saving so the file matches the shown window.
          # It used to be set only after the figure had been written.
          ax.set_aspect('equal', adjustable='box')
          self._save_if_requested(name_path_tupl)
          plt.show()

      def _prepare_axes(self):
          """Switch the axis on, draw all nodes and labels, and return the current axes."""
          plt.axis("on")
          ax = plt.gca()
          self.create_edgeless_graph_plot(ax)
          return ax

      def _drawable_edges(self):
          """Return edges whose ``weight`` is truthy; an edge without a weight is skipped."""
          return [(u, v) for u, v, weight in self.graph.edges(data='weight') if weight]

      def _save_if_requested(self, name_path_tupl):
          """Save with ``SAVE_FORMAT``/``DPI`` when a ``(name, path)`` tuple is given."""
          if name_path_tupl:
              self.save_fig(name=name_path_tupl[0], path=name_path_tupl[1], format=SAVE_FORMAT, dpi=DPI)

      def save_fig(self, name, path: Union[bool, str] = False, format='png', dpi=1200):
          """Save the current figure as ``[path/]<name>_<timestamp>.<format>``.

          Uses the OS path separator (previously a hard-coded ``\\``). The dot before
          the extension is now also present when no path is given.

          :param name: file-name prefix
          :param path: target directory; falsy saves to the working directory
          :param format: matplotlib output format and file extension
          :param dpi: resolution passed to ``plt.savefig``
          """
          plt.savefig(self._timestamped_path(name, path, format), format=format, dpi=dpi)

      @staticmethod
      def _timestamped_path(name, path, extension, stamp=None):
          """Build ``[path/]<name>_<stamp>.<extension>``; ``stamp`` defaults to the current minute."""
          if stamp is None:
              stamp = time.strftime("%Y-%m-%d_H%H-M%M")
          file_name = '{}_{}.{}'.format(name, stamp, extension)
          return os.path.join(path, file_name) if path else file_name

      def create_edgeless_graph_plot(self, ax):
          """Draw every node type in its colour (one legend entry per type), then the node labels."""
          for node_type, node_ids in self.node_type_dict.items():
              nx.draw_networkx_nodes(self.graph,
                                     pos=self.node_pos_dict,
                                     nodelist=node_ids,
                                     ax=ax,
                                     node_color=self.node_color_dict[node_type],
                                     label=node_type,
                                     edgecolors=NODE_EDGE_COLOR,
                                     node_size=NODE_SIZE)
          nx.draw_networkx_labels(self.graph, self.node_pos_dict, font_size=FONT_SIZE,
                                  font_family=FONT_FAMILY, ax=ax)

      def create_graph_settings_with_cartesic_coord(self, plt, ax):
          """Show ticks, the legend and axis labels, and set limits around the nodes.

          :param plt: the ``matplotlib.pyplot`` module (as passed by the plot methods)
          :param ax: axes whose tick parameters are set
          """
          x_min, x_max, y_min, y_max = self._padded_bounds()
          ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
          legend_size = 16
          plt.legend(scatterpoints=1, prop={'size': legend_size})
          # Extra room on the right so the legend does not cover the graph.
          plt.xlim(x_min - 0.5, x_max + legend_size / 6)
          plt.ylim(y_min - 0.5, y_max + 0.5)
          plt.xlabel('x')
          plt.ylabel('y')
          plt.tight_layout()

      def _padded_bounds(self):
          """Return (x_min, x_max, y_min, y_max) of the node positions, padded outwards.

          Padding is 1 % of the value in x and 5 % in y. ``abs()`` keeps it pointing
          outwards for negative coordinates; the old ``v * 0.01`` shrank the range there.
          """
          xs = [pos[0] for pos in self.node_pos_dict.values()]
          ys = [pos[1] for pos in self.node_pos_dict.values()]
          x_min, x_max, y_min, y_max = min(xs), max(xs), min(ys), max(ys)
          return (x_min - abs(x_min) * 0.01, x_max + abs(x_max) * 0.01,
                  y_min - abs(y_min) * 0.05, y_max + abs(y_max) * 0.05)