Skip to content

Commit 42043d3

Browse files
authored
Add support for drawing discrete grids (#2386)
1 parent 9f13c30 commit 42043d3

File tree

1 file changed

+53
-11
lines changed

1 file changed

+53
-11
lines changed

mesa/visualization/components/matplotlib.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from matplotlib.figure import Figure
1212

1313
import mesa
14-
from mesa.experimental.cell_space import VoronoiGrid
14+
from mesa.experimental.cell_space import Grid, VoronoiGrid
1515
from mesa.space import PropertyLayer
1616
from mesa.visualization.utils import update_counter
1717

@@ -52,16 +52,20 @@ def SpaceMatplotlib(
5252
if space is None:
5353
space = getattr(model, "space", None)
5454

55-
if isinstance(space, mesa.space._Grid):
56-
_draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)
57-
elif isinstance(space, mesa.space.ContinuousSpace):
58-
_draw_continuous_space(space, space_ax, agent_portrayal, model)
59-
elif isinstance(space, mesa.space.NetworkGrid):
60-
_draw_network_grid(space, space_ax, agent_portrayal)
61-
elif isinstance(space, VoronoiGrid):
62-
_draw_voronoi(space, space_ax, agent_portrayal)
63-
elif space is None and propertylayer_portrayal:
64-
draw_property_layers(space_ax, space, propertylayer_portrayal, model)
55+
# https://stackoverflow.com/questions/67524641/convert-multiple-isinstance-checks-to-structural-pattern-matching
56+
match space:
57+
case mesa.space._Grid():
58+
_draw_continuous_space(space, space_ax, agent_portrayal, model)
59+
case mesa.space.NetworkGrid():
60+
_draw_network_grid(space, space_ax, agent_portrayal)
61+
case VoronoiGrid():
62+
_draw_voronoi(space, space_ax, agent_portrayal)
63+
case Grid(): # matches OrthogonalMooreGrid, OrthogonalVonNeumannGrid, and Hexgrid
64+
# fixme add a separate draw method for hexgrids in the future
65+
_draw_discrete_space_grid(space, space_ax, agent_portrayal)
66+
case None:
67+
if propertylayer_portrayal:
68+
draw_property_layers(space_ax, space, propertylayer_portrayal, model)
6569

6670
solara.FigureMatplotlib(
6771
space_fig, format="png", bbox_inches="tight", dependencies=dependencies
@@ -291,6 +295,44 @@ def portray(g):
291295
space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black
292296

293297

298+
def _draw_discrete_space_grid(space: Grid, space_ax, agent_portrayal):
299+
if space._ndims != 2:
300+
raise ValueError("Space must be 2D")
301+
302+
def portray(g):
303+
x = []
304+
y = []
305+
s = [] # size
306+
c = [] # color
307+
308+
for cell in g.all_cells:
309+
for agent in cell.agents:
310+
data = agent_portrayal(agent)
311+
x.append(cell.coordinate[0])
312+
y.append(cell.coordinate[1])
313+
if "size" in data:
314+
s.append(data["size"])
315+
if "color" in data:
316+
c.append(data["color"])
317+
out = {"x": x, "y": y}
318+
out["s"] = s
319+
if len(c) > 0:
320+
out["c"] = c
321+
322+
return out
323+
324+
space_ax.set_xlim(0, space.width)
325+
space_ax.set_ylim(0, space.height)
326+
327+
# Draw grid lines
328+
for x in range(space.width + 1):
329+
space_ax.axvline(x, color="gray", linestyle=":")
330+
for y in range(space.height + 1):
331+
space_ax.axhline(y, color="gray", linestyle=":")
332+
333+
space_ax.scatter(**portray(space))
334+
335+
294336
def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]):
295337
"""Create a plotting function for a specified measure.
296338

0 commit comments

Comments
 (0)