Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions activitysim/abm/models/auto_ownership.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@
import pandas as pd
from pydantic import validator

from activitysim.core import config, estimation, simulate, tracing, workflow
from activitysim.core import (
config,
expressions,
estimation,
simulate,
tracing,
workflow,
)
from activitysim.core.configuration.base import PreprocessorSettings, PydanticReadable
from activitysim.core.configuration.logit import LogitComponentSettings
from .util import annotate

logger = logging.getLogger(__name__)

Expand All @@ -19,7 +27,8 @@ class AutoOwnershipSettings(LogitComponentSettings):
Settings for the `auto_ownership` component.
"""

# This model is relatively simple and has no unique settings
preprocessor: PreprocessorSettings | None = None
annotate_households: PreprocessorSettings | None = None


@workflow.step
Expand Down Expand Up @@ -57,6 +66,21 @@ def auto_ownership_simulate(

logger.info("Running %s with %d households", trace_label, len(choosers))

# - preprocessor
preprocessor_settings = model_settings.preprocessor
if preprocessor_settings:

locals_d = {}
if constants is not None:
locals_d.update(constants)

expressions.assign_columns(
df=choosers,
model_settings=preprocessor_settings,
locals_dict=locals_d,
trace_label=trace_label,
)

if estimator:
estimator.write_model_settings(model_settings, model_settings_file_name)
estimator.write_spec(model_settings)
Expand Down Expand Up @@ -92,5 +116,8 @@ def auto_ownership_simulate(
"auto_ownership", households.auto_ownership, value_counts=True
)

if model_settings.annotate_households:
annotate.annotate_households(model_settings, trace_label)

if trace_hh_id:
state.tracing.trace_df(households, label="auto_ownership", warn_if_empty=True)
56 changes: 56 additions & 0 deletions activitysim/abm/models/util/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,62 @@
logger = logging.getLogger(__name__)


def annotate_households(
state: workflow.State,
model_settings: dict | PydanticBase,
trace_label: str,
locals_dict: dict | None = None,
):
"""
Add columns to the households table in the pipeline according to spec.

Parameters
----------
model_settings : dict
trace_label : str
"""
if isinstance(model_settings, PydanticBase):
model_settings = model_settings.dict()
if locals_dict is None:
locals_dict = {}
households = state.get_dataframe("households")
expressions.assign_columns(
df=households,
model_settings=model_settings.get("annotate_households"),
locals_dict=locals_dict,
trace_label=tracing.extend_trace_label(trace_label, "annotate_households"),
)
state.add_table("households", households)


def annotate_persons(
state: workflow.State,
model_settings: dict | PydanticBase,
trace_label: str,
locals_dict: dict | None = None,
):
"""
Add columns to the persons table in the pipeline according to spec.

Parameters
----------
model_settings : dict
trace_label : str
"""
if isinstance(model_settings, PydanticBase):
model_settings = model_settings.dict()
if locals_dict is None:
locals_dict = {}
persons = state.get_dataframe("persons")
expressions.assign_columns(
df=persons,
model_settings=model_settings.get("annotate_persons"),
locals_dict=locals_dict,
trace_label=tracing.extend_trace_label(trace_label, "annotate_persons"),
)
state.add_table("persons", persons)


def annotate_tours(
state: workflow.State,
model_settings: dict | PydanticBase,
Expand Down