diff --git a/activitysim/abm/models/auto_ownership.py b/activitysim/abm/models/auto_ownership.py index 16b3141e96..c993445662 100644 --- a/activitysim/abm/models/auto_ownership.py +++ b/activitysim/abm/models/auto_ownership.py @@ -7,9 +7,17 @@ import pandas as pd from pydantic import validator -from activitysim.core import config, estimation, simulate, tracing, workflow +from activitysim.core import ( + config, + expressions, + estimation, + simulate, + tracing, + workflow, +) from activitysim.core.configuration.base import PreprocessorSettings, PydanticReadable from activitysim.core.configuration.logit import LogitComponentSettings +from .util import annotate logger = logging.getLogger(__name__) @@ -19,7 +27,8 @@ class AutoOwnershipSettings(LogitComponentSettings): Settings for the `auto_ownership` component. """ - # This model is relatively simple and has no unique settings + preprocessor: PreprocessorSettings | None = None + annotate_households: PreprocessorSettings | None = None @workflow.step @@ -57,6 +66,21 @@ def auto_ownership_simulate( logger.info("Running %s with %d households", trace_label, len(choosers)) + # - preprocessor + preprocessor_settings = model_settings.preprocessor + if preprocessor_settings: + + locals_d = {} + if constants is not None: + locals_d.update(constants) + + expressions.assign_columns( + df=choosers, + model_settings=preprocessor_settings, + locals_dict=locals_d, + trace_label=trace_label, + ) + if estimator: estimator.write_model_settings(model_settings, model_settings_file_name) estimator.write_spec(model_settings) @@ -92,5 +116,8 @@ def auto_ownership_simulate( "auto_ownership", households.auto_ownership, value_counts=True ) + if model_settings.annotate_households: + annotate.annotate_households(model_settings, trace_label) + if trace_hh_id: state.tracing.trace_df(households, label="auto_ownership", warn_if_empty=True) diff --git a/activitysim/abm/models/util/annotate.py b/activitysim/abm/models/util/annotate.py index a689637f95..69958b1abf 100644 --- a/activitysim/abm/models/util/annotate.py +++ b/activitysim/abm/models/util/annotate.py @@ -16,6 +16,62 @@ logger = logging.getLogger(__name__) +def annotate_households( + state: workflow.State, + model_settings: dict | PydanticBase, + trace_label: str, + locals_dict: dict | None = None, +): + """ + Add columns to the households table in the pipeline according to spec. + + Parameters + ---------- + model_settings : dict + trace_label : str + """ + if isinstance(model_settings, PydanticBase): + model_settings = model_settings.dict() + if locals_dict is None: + locals_dict = {} + households = state.get_dataframe("households") + expressions.assign_columns( + df=households, + model_settings=model_settings.get("annotate_households"), + locals_dict=locals_dict, + trace_label=tracing.extend_trace_label(trace_label, "annotate_households"), + ) + state.add_table("households", households) + + +def annotate_persons( + state: workflow.State, + model_settings: dict | PydanticBase, + trace_label: str, + locals_dict: dict | None = None, +): + """ + Add columns to the persons table in the pipeline according to spec. + + Parameters + ---------- + model_settings : dict + trace_label : str + """ + if isinstance(model_settings, PydanticBase): + model_settings = model_settings.dict() + if locals_dict is None: + locals_dict = {} + persons = state.get_dataframe("persons") + expressions.assign_columns( + df=persons, + model_settings=model_settings.get("annotate_persons"), + locals_dict=locals_dict, + trace_label=tracing.extend_trace_label(trace_label, "annotate_persons"), + ) + state.add_table("persons", persons) + + def annotate_tours( state: workflow.State, model_settings: dict | PydanticBase,