"""Drawing functions for visualizing the simulation."""
import math
from collections.abc import Callable
from enum import StrEnum
from functools import cache, reduce
from typing import TYPE_CHECKING, Any, Union
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.figure import Figure
from matplotlib.patches import Circle, RegularPolygon
from networkx import draw_networkx_edges as __d_netx_edges
from networkx import draw_networkx_labels as __d_netx_labels
from numpy import uint8, zeros
from pydistsim._exceptions import SimulationException
from pydistsim.algorithm.node_algorithm import NodeAlgorithm
from pydistsim.logging import logger
from pydistsim.simulation import Simulation
if TYPE_CHECKING:
from pydistsim.message import Message
from pydistsim.network.network import NetworkType
from pydistsim.network.node import Node
[docs]
class MessageType(StrEnum):
    """Kinds of messages the drawing code distinguishes; each value is the human-readable name of the kind."""

    # Messages found in a node's ``inbox``.
    IN = "Inbox"
    # Messages found in a node's ``outbox``.
    OUT = "Outbox"
    # Messages reported by ``net.get_transit_messages``.
    TRANSIT = "Transit"
    # Messages reported by ``net.get_lost_messages``.
    LOST = "Lost"
# Face colour of the marker drawn for each kind of message. The same colours are used for the
# "Messages:" legend, and as the fallback when no ``message_colors`` callback is supplied.
MESSAGE_COLOR = {
    MessageType.IN: "tab:cyan",
    MessageType.OUT: "w",
    MessageType.TRANSIT: "y",
    MessageType.LOST: "r",
}
# Opacity of the regular network edges (tree edges are drawn with this value plus 0.2).
EDGES_ALPHA = 0.6

# The ten named "tab:" colours, in the order they are assigned to statuses / subclusters.
_TABLEAU_PALETTE = (
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:red",
    "tab:purple",
    "tab:brown",
    "tab:pink",
    "tab:gray",
    "tab:olive",
    "tab:cyan",
)

# Palette tiled 100 times (1000 entries) so that positional lookups past the tenth colour
# keep cycling through the same ten colours.
NODE_COLORS = list(_TABLEAU_PALETTE) * 100

# z-order given to the message markers and to the message-count annotations, respectively.
MESSAGE_SHAPE_ZORDER = 3
MESSAGE_ANNOTATION_ZORDER = 4
def __get_message_positions_and_orientation(xd, yd, xs, ys, direction: MessageType) -> tuple[float, float, float]:
    """
    Compute where the marker of a message is drawn and how it is rotated.

    Outbox, inbox and in-transit markers are placed on the segment that joins the source with the
    destination, at a fixed fraction of its length measured from the source (1/6, 7/8 and 1/3
    respectively). Lost messages are placed beside the midpoint of the segment, shifted 10 units
    (in the coordinates of the inputs) along the direction perpendicular to it.

    :param xd: x coordinate of the destination
    :param yd: y coordinate of the destination
    :param xs: x coordinate of the source
    :param ys: y coordinate of the source
    :param direction: kind of message being drawn
    :return: ``(x, y, angle_in_rads)``, where the angle is the source-to-destination direction minus 90 degrees
    :raises ValueError: if ``direction`` is not one of the known message types
    """
    angle_in_rads = -math.pi / 2 + math.atan2(yd - ys, xd - xs)

    if direction == MessageType.LOST:
        # The side of the edge on which the marker is placed depends on the sign of the angle.
        offset_distance = 10 * (-1 if angle_in_rads < 0 else 1)
        xm = (xs + xd) / 2.0
        ym = (ys + yd) / 2.0

        if xs == xd:  # vertical line: shift horizontally
            return xm + offset_distance, ym, angle_in_rads
        if yd == ys:  # horizontal line: shift vertically
            return xm, ym + offset_distance, angle_in_rads

        # Diagonal line: move `offset_distance` along the unit vector with perpendicular slope.
        slope = (yd - ys) / (xd - xs)
        slope_perpendicular = -1 / slope
        norm = (slope_perpendicular**2 + 1) ** 0.5
        return (
            xm + offset_distance / norm,
            ym + offset_distance / norm * slope_perpendicular,
            angle_in_rads,
        )

    # Fraction of the source -> destination segment at which each kind of marker sits.
    fraction_by_type = {
        MessageType.OUT: 1 / 6,
        MessageType.IN: 7 / 8,
        MessageType.TRANSIT: 1 / 3,
    }
    try:
        offset = fraction_by_type[direction]
    except KeyError:
        raise ValueError(f"Unknown message type: {direction!r}") from None

    return xs + (xd - xs) * offset, ys + (yd - ys) * offset, angle_in_rads
def __draw_tree(tree_key: str, net: "NetworkType", axes: Axes):
    """
    Show tree representation of network by highlighting the tree edges.

    Nothing is drawn when ``net.get_tree_net`` reports no tree nodes.

    :param tree_key: key in nodes memory (dictionary) where tree data is stored. The storage format
        can be a list of tree neighbors or a dict:
        ``{'parent': parent_node, 'children': [child_node1, child_node2 ...]}``
    :param net: network that holds the tree
    :param axes: matplotlib axes to draw on
    """
    tree_nodes, tree_edges = net.get_tree_net(tree_key, return_subnetwork=False)
    if not tree_nodes:
        return

    # Tree edges are thicker and slightly more opaque than the regular edges.
    tree_style = {
        "edge_color": "tab:brown",
        "width": 2.5,
        "alpha": EDGES_ALPHA + 0.2,
    }
    __d_netx_edges(net, net.pos, edgelist=tree_edges, ax=axes, **tree_style)
@cache
def __draw_nodes(node_pos_color: tuple[tuple[int, tuple[float, ...], str]], fig_id=0) -> PatchCollection:
    """
    Build the patch collection that holds one circle per node.

    The result is memoized on the arguments, which is why the input has to be a hashable tuple.
    ``fig_id`` is not used by the body; it only separates the cache entries of different callers.

    NOTE(review): ``functools.cache`` is unbounded, so every distinct combination of arguments stays
    cached (together with its collection) for the lifetime of the process.

    :param node_pos_color: ``(radius, position, color)`` triples, one per node
    :param fig_id: extra value that takes part in the cache key
    :return: collection with the node circles; it is not added to any axes here
    """
    circles = [
        Circle(
            center,
            node_radius,
            color=face_color,
            ec="k",
            lw=1.0,
            ls="solid",
            picker=3,
        )
        for node_radius, center, face_color in node_pos_color
    ]

    # match_original keeps the per-circle colours and line style instead of collection-wide defaults.
    collection = PatchCollection(circles, match_original=True)
    collection.set_picker(3)
    return collection
@cache
def __create_figure_legend(color_map: tuple, algorithm_name: str, show_messages: bool, fig_id=0):
    """
    Create the figure-level legends: one for node statuses and, optionally, one for messages.

    The result is memoized on the arguments; ``fig_id`` is not used by the body and only takes part
    in the cache key.

    NOTE(review): the legends are attached to the current pyplot figure (``plt.gcf()``), which is
    not necessarily the figure the caller identified through ``fig_id``.

    :param color_map: ``(status, color)`` pairs, in the order they should be listed
    :param algorithm_name: name shown in the title of the status legend
    :param show_messages: whether the message legend is created as well
    :param fig_id: extra value that takes part in the cache key
    :return: list with the created legend objects
    """

    def marker(face_color):
        # Proxy artist: a small outlined circle used only as a legend handle.
        return Circle(color=face_color, xy=(0, 0), radius=8.0, ec="k", lw=1.0, ls="solid")

    current_figure = plt.gcf()

    # Node status legend
    status_handles = [marker(color) for _, color in color_map]
    status_labels = [status for status, _ in color_map]
    legends = [
        current_figure.legend(
            status_handles,
            status_labels,
            loc="outside right upper",
            fontsize=9,
            ncol=1,
            title="Statuses for\n%s:" % algorithm_name,
        )
    ]

    if not show_messages:
        return legends

    # Message legend
    message_types = (
        MessageType.IN,
        MessageType.OUT,
        MessageType.TRANSIT,
        MessageType.LOST,
    )
    legends.append(
        current_figure.legend(
            [marker(MESSAGE_COLOR[msg]) for msg in message_types],
            ["Inbox", "Outbox", "Transit", "Lost"],
            loc="outside right lower",
            ncol=1,
            fontsize=9,
            title="Messages:",
        )
    )
    return legends
def __create_and_get_color_labels(net, algorithm=None, subclusters=None, figure: Figure = None, show_messages=True):
    """
    Decide the colour of every node and, when a figure is given, attach the legends to it.

    With an algorithm, each member name of ``algorithm.Status`` (only looked up for
    :class:`NodeAlgorithm` instances) gets the palette colour at its position, and every node is
    coloured by its ``status``; nodes whose status has no entry are drawn black (``"k"``).
    Without an algorithm, nodes are coloured by the position of their subcluster; a node that shows
    up in more than one subcluster is drawn black.

    :param net: network whose nodes are coloured
    :param algorithm: current algorithm, if any
    :param subclusters: iterable of node groups, used only when ``algorithm`` is falsy
    :param figure: figure that receives the legends; no legend is created when ``None``
    :param show_messages: forwarded to the legend creation to include the message legend
    :return: dict mapping each coloured node to its colour
    """
    node_colors = {}
    palette_size = len(NODE_COLORS)

    if algorithm:
        color_map = {}
        if isinstance(algorithm, NodeAlgorithm):
            # The modulo keeps the lookup valid no matter how many statuses are declared.
            color_map = {
                status: NODE_COLORS[ind % palette_size]
                for ind, status in enumerate(algorithm.Status.__members__)
            }

        if figure is not None:
            # The legend helper is memoized, hence the hashable tuple and the figure id.
            figure.legends = __create_figure_legend(
                tuple(color_map.items()), algorithm.name, show_messages, id(figure)
            )

        for n in net.nodes():
            # Unknown (or empty) statuses fall back to black.
            node_colors[n] = color_map.get(n.status, "k")
    elif subclusters:
        for i, sc in enumerate(subclusters):
            for n in sc:
                node_colors[n] = "k" if n in node_colors else NODE_COLORS[i % palette_size]

    return node_colors
def __draw_edges(net, edges, axes) -> LineCollection:
    """
    Draw the edges of the network with the default edge opacity.

    :param net: network to draw; node positions are taken from ``net.pos``
    :param edges: edges to draw, forwarded unchanged as networkx's ``edgelist``
    :param axes: matplotlib axes to draw on
    :return: whatever ``draw_networkx_edges`` returns for the drawn edges
    """
    edge_artists = __d_netx_edges(
        net,
        net.pos,
        edgelist=edges,
        alpha=EDGES_ALPHA,
        ax=axes,
    )
    return edge_artists
def __draw_messages(
    net: "NetworkType", axes: Axes, message_colors: Callable[["Message", MessageType], Any], message_radius: float
):
    """
    Draw one triangular marker per (message type, source, destination) combination that has messages.

    When several messages share the same combination a single marker is drawn, coloured after the
    first of them, and the number of messages is annotated next to it.

    :param net: network whose messages are drawn
    :param axes: matplotlib axes to draw on
    :param message_colors: optional function ``(message, message_type) -> color``; the default
        colour of the message type is used when it is falsy
    :param message_radius: radius of the triangular markers
    """
    MESSAGE_LINE_WIDTH = 1.0
    patch_kwargs = {
        "numVertices": 3,
        "radius": message_radius,
        "lw": MESSAGE_LINE_WIDTH,
        "ls": "solid",
        "picker": 3,
        "zorder": MESSAGE_SHAPE_ZORDER,
        "ec": "k",
    }

    msg_artists = []
    for node in net.nodes():
        for msg_type, msgs_by_edge in __group_messages_by_edge(net, node).items():
            for (src, dst), msgs in msgs_by_edge.items():
                if not src or not dst:
                    continue  # Defensive check

                x, y, rads_orientation = __get_message_positions_and_orientation(
                    *net.pos[dst], *net.pos[src], msg_type
                )
                face_color = message_colors(msgs[0], msg_type) if message_colors else MESSAGE_COLOR[msg_type]
                msg_artists.append(
                    RegularPolygon(
                        (x, y),
                        orientation=rads_orientation,
                        fc=face_color,
                        label=msg_type,
                        **patch_kwargs,
                    )
                )

                if len(msgs) > 1:
                    axes.annotate(
                        f"{len(msgs)}",
                        (x + 5, y + 5),
                        color="k",
                        fontsize=8,
                        zorder=MESSAGE_ANNOTATION_ZORDER,
                    )

    if msg_artists:
        marker_collection = PatchCollection(msg_artists, match_original=True)
        marker_collection.set_picker(3)
        axes.add_collection(marker_collection)


def __group_messages_by_edge(net: "NetworkType", node: "Node") -> dict:
    """Group the messages related to ``node`` by message type and by the ``(source, destination)`` pair."""

    def outgoing():
        for msg in node.outbox:
            if msg.destination is not None:
                yield msg, msg.destination
            else:
                # No explicit destination: pair the message with every adjacent node.
                for neighbor in net.adj[node].keys():
                    yield msg, neighbor

    def on_links(get_messages):
        for other_node in net.out_neighbors(node):
            # `or ()` tolerates a getter that returns an empty or None result.
            for msg in get_messages(node, other_node) or ():
                if msg.source == node:
                    yield msg, other_node

    pairs_by_type = {
        MessageType.OUT: outgoing(),
        MessageType.IN: ((msg, msg.source) for msg in node.inbox),
        MessageType.TRANSIT: on_links(net.get_transit_messages),
        MessageType.LOST: on_links(net.get_lost_messages),
    }

    grouped = {}
    for msg_type, pairs in pairs_by_type.items():
        msgs_by_edge = grouped[msg_type] = {}
        for msg, other_node in pairs:
            if other_node is None:
                continue
            # Inbox messages travel towards `node`; every other kind travels away from it.
            edge = (other_node, node) if msg_type == MessageType.IN else (node, other_node)
            msgs_by_edge.setdefault(edge, []).append(msg)
    return grouped
def __draw_labels(net: "NetworkType", node_size, dpi) -> dict["Node", plt.Text]:
    """
    Draw the node labels next to the nodes.

    Every label is anchored at the node position plus ``1.5 * sqrt(node_size) * dpi / 100`` on each
    coordinate, with its bottom-left corner on that point. Positions may be arrays or plain tuples.

    NOTE(review): no ``ax`` is forwarded to networkx here, so the labels are presumably drawn on
    the current pyplot axes - confirm this matches the axes used by the caller.

    :param net: network to label; positions come from ``net.pos`` and texts from ``net.labels``
    :param node_size: size of the nodes, or function that returns the size of a given node
    :param dpi: dots per inch, used to scale the label offset
    :return: dict of the created text artists keyed by node
    """

    def delta_for(size):
        return 1.5 * math.sqrt(size) * dpi / 100

    # A fixed size needs a single offset; a per-node size needs one offset per node.
    fixed_delta = None if callable(node_size) else delta_for(node_size)

    label_pos = {}
    for n in net.nodes():
        label_delta = delta_for(node_size(n)) if fixed_delta is None else fixed_delta
        # Built element-wise so that tuple positions work as well as array positions.
        label_pos[n] = tuple(coordinate + label_delta for coordinate in net.pos[n])

    return __d_netx_labels(
        net,
        label_pos,
        labels=net.labels,
        horizontalalignment="left",
        verticalalignment="bottom",
    )
[docs]
def draw_current_state(
    sim: Union["Simulation", "NetworkType"],
    axes: Axes = None,
    clear: bool = True,
    tree_key: str = None,
    dpi: int = 100,
    node_radius: int | Callable[["Node"], int] = 10,
    node_positions: dict | Callable[[], dict["Node", tuple[float, ...]]] = None,
    node_colors: dict | Callable[[], dict["Node", Any]] = None,
    edge_filter: list = None,
    show_messages: bool = True,
    message_colors: Callable[["Message", MessageType], Any] = None,
    message_size: int = None,
    show_legends: bool = True,
    space_for_legend: float = 0.15,
    show_labels: bool = True,
    node_labels: dict["Node", str] | Callable[[], dict["Node", str]] = None,
):
    """
    Function to draw the current state of the simulation or network. This function is used to visualize the network
    and the messages in the network.

    Automatically determines the current algorithm and draws the network accordingly. This includes a mapping of
    node colors to the status of the nodes, as well as the messages in the network.

    Custom positions and labels are applied to the network only while drawing; the original values are put back
    before returning, even if drawing fails.

    :param sim: Simulation or NetworkType object
    :param axes: matplotlib axes object; a new figure and axes are created when omitted
    :param clear: boolean to clear the axes before drawing
    :param tree_key: key in nodes memory (dictionary) where tree data is stored
    :param dpi: dots per inch
    :param node_radius: radius of nodes, or function that returns the radius of a given node
    :param node_positions: dictionary of node positions or function to get node positions
    :param node_colors: dictionary of node colors or function to get node colors
    :param edge_filter: list of edges to draw
    :param show_messages: boolean to show messages in the network
    :param message_colors: function ``(message, message_type) -> color`` to choose the color of the messages
    :param message_size: radius of the message markers; defaults to 3/4 of ``node_radius`` and must be given when
        ``node_radius`` is a function
    :param show_legends: boolean to show legends
    :param space_for_legend: space for legend, as a fraction of the figure width
    :param show_labels: boolean to show labels of nodes
    :param node_labels: dictionary of node labels or function to get node labels
    :return: matplotlib figure object that owns the axes
    """
    if isinstance(sim, Simulation):
        net = sim.network
        try:
            currentAlgorithm = sim.get_current_algorithm()
        except SimulationException:
            currentAlgorithm = None
    else:
        net = sim
        currentAlgorithm = None

    positions_overridden = bool(node_positions)
    labels_overridden = False
    if positions_overridden:
        original_positions = net.pos

    try:
        if positions_overridden:
            net.pos = node_positions() if callable(node_positions) else node_positions

        if axes is None:
            with plt.ioff():
                _, axes = plt.subplots()
        if clear:
            axes.clear()

        axes.pcolormesh(net.environment.image, vmin=0, cmap="binary_r")

        if tree_key:
            __draw_tree(tree_key, net, axes)
        __draw_edges(net, edge_filter, axes)
        __create_nodes(axes, node_radius, node_colors, show_messages, show_legends, net, currentAlgorithm)

        if show_labels:
            if node_labels:
                original_labels = net.labels
                labels_overridden = True
                net.labels = node_labels() if callable(node_labels) else node_labels
            __draw_labels(net, node_radius, dpi)

        if show_messages:
            if not callable(node_radius):
                message_size = message_size or 3 * node_radius / 4
            else:
                # NOTE(review): assert statements are skipped when Python runs with -O.
                assert message_size is not None, "Message size must be provided when node_radius is a function."
            __draw_messages(net, axes, message_colors, message_radius=message_size)

        step_text = " (step %d)" % sim.algorithmState["step"] if isinstance(currentAlgorithm, NodeAlgorithm) else ""
        axes.set_title((currentAlgorithm.name if currentAlgorithm else "") + step_text)

        # remove as much whitespace as possible
        axes.axis("off")
        # Lay out the figure that owns the axes, which may differ from pyplot's current figure.
        figure = axes.figure
        figure.tight_layout()
        right_margin = 1 - space_for_legend if show_legends else 1
        figure.subplots_adjust(left=0, bottom=0, right=right_margin, top=0.95)
    finally:
        # Undo the temporary overrides so the network is left as it was received.
        if labels_overridden:
            net.labels = original_labels
        if positions_overridden:
            net.pos = original_positions

    return axes.figure
def __create_nodes(axes, node_radius, node_colors, show_messages, show_legends, net, currentAlgorithm):
    """
    Add the node circles to ``axes`` and return their patch collection.

    Colours come from, in order of precedence: the ``node_colors`` dict (nodes missing from it are
    drawn red), the value returned by ``node_colors()`` when it is callable, or the status-based
    colours of the current algorithm when ``node_colors`` is falsy. A non-dict colour value is
    applied to every node.

    :param axes: matplotlib axes to draw on
    :param node_radius: radius of the nodes, or function that returns the radius of a given node
    :param node_colors: dict of colours by node, function that returns the colours, or falsy
    :param show_messages: forwarded to the legend creation when colours are derived from statuses
    :param show_legends: whether legends are attached to the figure when colours are derived
    :param net: network whose nodes are drawn
    :param currentAlgorithm: algorithm used to derive status colours, if any
    :return: the patch collection that was added to ``axes``
    """
    if not node_colors:
        node_colors = __create_and_get_color_labels(
            net,
            algorithm=currentAlgorithm,
            figure=axes.figure if show_legends else None,
            show_messages=show_messages,
        )
    elif callable(node_colors):
        node_colors = node_colors()

    colors_by_node = isinstance(node_colors, dict)
    radius_by_node = callable(node_radius)

    # Everything is converted to tuples because the node collection is memoized on this value.
    node_radius_pos_color = tuple(
        (
            node_radius(n) if radius_by_node else node_radius,
            tuple(net.pos[n]),
            node_colors.get(n, "r") if colors_by_node else node_colors,
        )
        for n in net.nodes()
    )

    # Keyed by the target axes: an artist cannot be shared between two axes, so each axes needs
    # its own cached collection even when several of them live in the same figure.
    patches = __draw_nodes(node_radius_pos_color, fig_id=id(axes))
    patches.axes = axes
    patches.figure = axes.figure
    axes.add_collection(patches)

    return patches
[docs]
def create_animation(
    sim: "Simulation",
    figsize=None,
    dpi: int = 100,
    milliseconds_per_frame: int = 300,
    frame_limit: int = 2000,
    reset_on_start: bool = True,
    **kwargs,
) -> animation.FuncAnimation:
    """
    Create an animation of the simulation.

    Every frame draws the current state and then advances the simulation one step. If a step raises
    an exception, it is logged and the animation stops producing new frames.

    Example for visualizing in Jupyter Notebook:

    .. code-block:: python

        anim = create_animation(sim)
        video = anim.to_html5_video()
        from IPython.display import HTML
        HTML(video)

    Example for saving as a video file:

    .. code-block:: python

        from matplotlib.animation import FFMpegFileWriter
        moviewriter = FFMpegFileWriter()
        anim = draw.create_animation(sim)
        anim.save("flood.mp4", writer=moviewriter)

    :param sim: Simulation object
    :param figsize: figure size
    :param dpi: dots per inch
    :param milliseconds_per_frame: milliseconds per frame
    :param frame_limit: limit of frames, default is 2000; a falsy value disables the limit
    :param reset_on_start: control if the simulation will restart on animation start
    :param kwargs: additional keyword arguments to pass to the :func:`draw_current_state` function
    :return: animation object
    """
    exception_occurred = False

    with plt.ioff():  # Turn off interactive mode
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    def draw_frame(frame_index):
        nonlocal exception_occurred

        if exception_occurred:
            return

        if frame_index == 0 and reset_on_start:
            sim.reset()

        draw_current_state(sim, ax, dpi=dpi, **kwargs)

        try:
            sim.run(1)
        except Exception as e:
            # Only regular errors are absorbed; KeyboardInterrupt and SystemExit still propagate.
            exception_occurred = True
            logger.exception(e)

        return ax.artists

    def frame_count():
        frame_index = 0

        def should_continue():
            if exception_occurred:
                return False
            if frame_limit and frame_index >= frame_limit:
                logger.warning("Frame limit reached.")
                return False
            if not (sim.is_halted() and sim.get_current_algorithm() is None):
                logger.debug(f"Frame {frame_index}, simulation still running.")
                return True
            if frame_index == 0:
                logger.debug(f"Frame {frame_index}, simulation not started.")
                return True
            return False

        while should_continue():
            yield frame_index
            frame_index += 1
        # One last frame after the stop condition is met, presumably so that the state left by
        # the final step gets drawn.
        yield frame_index

    return animation.FuncAnimation(
        fig,
        func=draw_frame,
        frames=frame_count,
        interval=milliseconds_per_frame,
        cache_frame_data=True,
        save_count=frame_limit,
    )