|
11 | 11 | from matplotlib.figure import Figure |
12 | 12 |
|
13 | 13 | import mesa |
14 | | -from mesa.experimental.cell_space import VoronoiGrid |
| 14 | +from mesa.experimental.cell_space import Grid, VoronoiGrid |
15 | 15 | from mesa.space import PropertyLayer |
16 | 16 | from mesa.visualization.utils import update_counter |
17 | 17 |
|
@@ -52,16 +52,20 @@ def SpaceMatplotlib( |
52 | 52 | if space is None: |
53 | 53 | space = getattr(model, "space", None) |
54 | 54 |
|
55 | | - if isinstance(space, mesa.space._Grid): |
56 | | - _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model) |
57 | | - elif isinstance(space, mesa.space.ContinuousSpace): |
58 | | - _draw_continuous_space(space, space_ax, agent_portrayal, model) |
59 | | - elif isinstance(space, mesa.space.NetworkGrid): |
60 | | - _draw_network_grid(space, space_ax, agent_portrayal) |
61 | | - elif isinstance(space, VoronoiGrid): |
62 | | - _draw_voronoi(space, space_ax, agent_portrayal) |
63 | | - elif space is None and propertylayer_portrayal: |
64 | | - draw_property_layers(space_ax, space, propertylayer_portrayal, model) |
| 55 | + # https://stackoverflow.com/questions/67524641/convert-multiple-isinstance-checks-to-structural-pattern-matching |
| 56 | + match space: |
| 57 | + case mesa.space._Grid(): |
| 58 | + _draw_continuous_space(space, space_ax, agent_portrayal, model) |
| 59 | + case mesa.space.NetworkGrid(): |
| 60 | + _draw_network_grid(space, space_ax, agent_portrayal) |
| 61 | + case VoronoiGrid(): |
| 62 | + _draw_voronoi(space, space_ax, agent_portrayal) |
| 63 | + case Grid(): # matches OrthogonalMooreGrid, OrthogonalVonNeumannGrid, and Hexgrid |
| 64 | + # fixme add a separate draw method for hexgrids in the future |
| 65 | + _draw_discrete_space_grid(space, space_ax, agent_portrayal) |
| 66 | + case None: |
| 67 | + if propertylayer_portrayal: |
| 68 | + draw_property_layers(space_ax, space, propertylayer_portrayal, model) |
65 | 69 |
|
66 | 70 | solara.FigureMatplotlib( |
67 | 71 | space_fig, format="png", bbox_inches="tight", dependencies=dependencies |
@@ -291,6 +295,44 @@ def portray(g): |
291 | 295 | space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in black |
292 | 296 |
|
293 | 297 |
|
| 298 | +def _draw_discrete_space_grid(space: Grid, space_ax, agent_portrayal): |
| 299 | + if space._ndims != 2: |
| 300 | + raise ValueError("Space must be 2D") |
| 301 | + |
| 302 | + def portray(g): |
| 303 | + x = [] |
| 304 | + y = [] |
| 305 | + s = [] # size |
| 306 | + c = [] # color |
| 307 | + |
| 308 | + for cell in g.all_cells: |
| 309 | + for agent in cell.agents: |
| 310 | + data = agent_portrayal(agent) |
| 311 | + x.append(cell.coordinate[0]) |
| 312 | + y.append(cell.coordinate[1]) |
| 313 | + if "size" in data: |
| 314 | + s.append(data["size"]) |
| 315 | + if "color" in data: |
| 316 | + c.append(data["color"]) |
| 317 | + out = {"x": x, "y": y} |
| 318 | + out["s"] = s |
| 319 | + if len(c) > 0: |
| 320 | + out["c"] = c |
| 321 | + |
| 322 | + return out |
| 323 | + |
| 324 | + space_ax.set_xlim(0, space.width) |
| 325 | + space_ax.set_ylim(0, space.height) |
| 326 | + |
| 327 | + # Draw grid lines |
| 328 | + for x in range(space.width + 1): |
| 329 | + space_ax.axvline(x, color="gray", linestyle=":") |
| 330 | + for y in range(space.height + 1): |
| 331 | + space_ax.axhline(y, color="gray", linestyle=":") |
| 332 | + |
| 333 | + space_ax.scatter(**portray(space)) |
| 334 | + |
| 335 | + |
294 | 336 | def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): |
295 | 337 | """Create a plotting function for a specified measure. |
296 | 338 |
|
|
0 commit comments