Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@
import numpy as np
import numpy.typing as npt

from scipy.spatial import KDTree

# For Mypy
from .agent import Agent


# for better performance, we calculate the tuple to use in the is_integer function
_types_integer = (int, np.integer)

Expand Down Expand Up @@ -867,6 +870,7 @@ def __init__(
self.torus = torus

self._agent_points: npt.NDArray[FloatCoordinate] | None = None
self._kdtree = None
self._index_to_agent: dict[int, Agent] = {}
self._agent_to_index: dict[Agent, int | None] = {}

Expand All @@ -878,6 +882,7 @@ def _build_agent_cache(self):
self._index_to_agent[idx] = agent
# Since dicts are ordered by insertion, we can iterate through agents keys
self._agent_points = np.array([agent.pos for agent in self._agent_to_index])
self._kdtree = KDTree(self._agent_points, boxsize=self.size if self.torus else None)

def _invalidate_agent_cache(self):
"""Clear cached data of agents and positions in the space."""
Expand Down Expand Up @@ -941,15 +946,14 @@ def get_neighbors(
if self._agent_points is None:
self._build_agent_cache()

deltas = np.abs(self._agent_points - np.array(pos))
if self.torus:
deltas = np.minimum(deltas, self.size - deltas)
dists = deltas[:, 0] ** 2 + deltas[:, 1] ** 2
pos_arr = np.array(pos)

(idxs,) = np.where(dists <= radius**2)
neighbors = [
self._index_to_agent[x] for x in idxs if include_center or dists[x] > 0
]
idxs = self._kdtree.query_ball_point(pos_arr, radius)

if not include_center:
idxs = [idx for idx in idxs if not np.array_equal(self._agent_points[idx], pos_arr)]

neighbors = [self._index_to_agent[idx] for idx in idxs]
return neighbors

def get_heading(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from setuptools import find_packages, setup

requires = ["click", "cookiecutter", "networkx", "numpy", "pandas", "tornado", "tqdm"]
requires = ["click", "cookiecutter", "networkx", "numpy", "pandas", "scipy", "tornado", "tqdm"]

extras_require = {
"dev": [
Expand Down