Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 68 additions & 26 deletions mesa_geo/geospace.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ def total_bounds(self) -> np.ndarray | None:
"""
Return the bounds of the GeoSpace in [min_x, min_y, max_x, max_y] format.
"""
if self._total_bounds is None:
if len(self.agents) > 0:
self._update_bounds(self._agent_layer.total_bounds)
if len(self.layers) > 0:
for layer in self.layers:
self._update_bounds(layer.total_bounds)
return self._total_bounds

def _update_bounds(self, new_bounds: np.ndarray) -> None:
Expand Down Expand Up @@ -128,7 +134,7 @@ def add_layer(self, layer: ImageLayer | RasterLayer | gpd.GeoDataFrame) -> None:
"to `False` to suppress this warning message."
)
layer.to_crs(self.crs, inplace=True)
self._update_bounds(layer.total_bounds)
self._total_bounds = None
self._static_layers.append(layer)

def _check_agent(self, agent):
Expand Down Expand Up @@ -161,7 +167,7 @@ def add_agents(self, agents):
for agent in agents:
self._check_agent(agent)
self._agent_layer.add_agents(agents)
self._update_bounds(new_bounds=self._agent_layer.total_bounds)
self._total_bounds = None

def _recreate_rtree(self, new_agents=None):
"""Create a new rtree index from agents geometries."""
Expand All @@ -170,6 +176,7 @@ def _recreate_rtree(self, new_agents=None):
def remove_agent(self, agent):
"""Remove an agent from the GeoSpace."""
self._agent_layer.remove_agent(agent)
self._total_bounds = None

def get_relation(self, agent, relation):
"""Return a list of related agents.
Expand Down Expand Up @@ -231,34 +238,50 @@ class _AgentLayer:
"""

def __init__(self):
# neighborhood graph for touching neighbors
self._neighborhood = None

# Set up rtree index
self.idx = index.Index()
self.idx.agents = {}
# rtree index for spatial indexing (e.g., neighbors within distance, agents at pos, etc.)
self._idx = None
self._id_to_agent = {}
# bounds of the layer in [min_x, min_y, max_x, max_y] format
# While it is possible to calculate the bounds from rtree index,
# total_bounds is almost always needed (e.g., for plotting), while rtree index is not.
# Hence we compute total_bounds separately from rtree index.
self._total_bounds = None

@property
def agents(self):
"""
Return a list of all agents in the layer.
"""

return list(self.idx.agents.values())
return list(self._id_to_agent.values())

@property
def total_bounds(self):
"""
Return the bounds of the layer in [min_x, min_y, max_x, max_y] format.
"""

return self.idx.get_bounds(coordinate_interleaved=True)
if self._total_bounds is None and len(self.agents) > 0:
bounds = np.array([agent.geometry.bounds for agent in self.agents])
min_x, min_y = np.min(bounds[:, :2], axis=0)
max_x, max_y = np.max(bounds[:, 2:], axis=0)
self._total_bounds = np.array([min_x, min_y, max_x, max_y])
return self._total_bounds

def _get_rtree_intersections(self, geometry):
"""
Calculate rtree intersections for candidate agents.
"""

return (self.idx.agents[i] for i in self.idx.intersection(geometry.bounds))
self._ensure_index()
if self._idx is None:
return []
else:
return [
self._id_to_agent[i] for i in self._idx.intersection(geometry.bounds)
]

def _create_neighborhood(self):
"""
Expand All @@ -273,21 +296,27 @@ def _create_neighborhood(self):
for agent, key in zip(agents, self._neighborhood.neighbors.keys()):
self._neighborhood.idx[agent] = key

def _ensure_index(self):
"""
Ensure that the rtree index is created.
"""

if self._idx is None:
self._recreate_rtree()

def _recreate_rtree(self, new_agents=None):
"""
Create a new rtree index from agents geometries.
"""

if new_agents is None:
new_agents = []
old_agents = list(self.agents)
agents = old_agents + new_agents

# Bulk insert agents
index_data = ((id(agent), agent.geometry.bounds, None) for agent in agents)
agents = list(self.agents) + new_agents

self.idx = index.Index(index_data)
self.idx.agents = {id(agent): agent for agent in agents}
if len(agents) > 0:
# Bulk insert agents
index_data = ((id(agent), agent.geometry.bounds, None) for agent in agents)
self._idx = index.Index(index_data)

def add_agents(self, agents):
"""
Expand All @@ -303,18 +332,25 @@ def add_agents(self, agents):

if isinstance(agents, GeoAgent):
agent = agents
self.idx.insert(id(agent), agent.geometry.bounds, None)
self.idx.agents[id(agent)] = agent
self._id_to_agent[id(agent)] = agent
if self._idx:
self._idx.insert(id(agent), agent.geometry.bounds, None)
else:
self._recreate_rtree(agents)
for agent in agents:
self._id_to_agent[id(agent)] = agent
if self._idx:
self._recreate_rtree(agents)
self._total_bounds = None

def remove_agent(self, agent):
"""
Remove an agent from the layer.
"""

self.idx.delete(id(agent), agent.geometry.bounds)
del self.idx.agents[id(agent)]
del self._id_to_agent[id(agent)]
if self._idx:
self._idx.delete(id(agent), agent.geometry.bounds)
self._total_bounds = None

def get_relation(self, agent, relation):
"""Return a list of related agents.
Expand All @@ -327,6 +363,7 @@ def get_relation(self, agent, relation):
Omit to compare against all other agents of the layer.
"""

self._ensure_index()
possible_agents = self._get_rtree_intersections(agent.geometry)
for other_agent in possible_agents:
if (
Expand All @@ -336,6 +373,7 @@ def get_relation(self, agent, relation):
yield other_agent

def get_intersecting_agents(self, agent):
self._ensure_index()
intersecting_agents = self.get_relation(agent, "intersects")
return intersecting_agents

Expand All @@ -347,7 +385,7 @@ def get_neighbors_within_distance(
Distance is measured as a buffer around the agent's geometry,
set center=True to calculate distance from center.
"""

self._ensure_index()
if center:
geometry = agent.geometry.centroid.buffer(distance)
else:
Expand All @@ -363,6 +401,7 @@ def agents_at(self, pos):
Return a generator of agents at given pos.
"""

self._ensure_index()
if not isinstance(pos, Point):
pos = Point(pos)

Expand All @@ -386,10 +425,13 @@ def get_neighbors(self, agent):
if not self._neighborhood or self._neighborhood.agents != self.agents:
self._create_neighborhood()

idx = self._neighborhood.idx[agent]
neighbors_idx = self._neighborhood.neighbors[idx]
neighbors = [self.agents[i] for i in neighbors_idx]
return neighbors
if self._neighborhood is None:
return []
else:
idx = self._neighborhood.idx[agent]
neighbors_idx = self._neighborhood.neighbors[idx]
neighbors = [self.agents[i] for i in neighbors_idx]
return neighbors

def get_agents_as_GeoDataFrame(self, agent_cls=GeoAgent) -> gpd.GeoDataFrame:
"""
Expand Down
10 changes: 10 additions & 0 deletions tests/test_GeoSpace.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,13 @@ def test_agents_at(self):
agent.unique_id for agent in self.geo_space.agents_at((1, 1))
}
self.assertEqual(agents_id_found, agents_id)

def test_get_neighbors(self):
self.geo_space.add_agents(self.polygon_agent)
self.assertEqual(len(self.geo_space.get_neighbors(self.polygon_agent)), 0)
self.geo_space.add_agents(self.touching_agent)
self.assertEqual(len(self.geo_space.get_neighbors(self.polygon_agent)), 1)
self.assertEqual(
self.geo_space.get_neighbors(self.polygon_agent)[0].unique_id,
self.touching_agent.unique_id,
)