Skip to content

Commit cf7fbba

Browse files
committed
nwss: adding tests for pull functions
1 parent 1a347c6 commit cf7fbba

File tree

6 files changed

+743
-54
lines changed

6 files changed

+743
-54
lines changed

nwss_wastewater/delphi_nwss/pull.py

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""Functions for pulling NWSS wastewater data API."""
33

4-
from typing import Optional
5-
64
import numpy as np
75
import pandas as pd
86
from sodapy import Socrata
@@ -32,7 +30,48 @@ def sig_digit_round(value, n_digits):
3230
return result
3331

3432

35-
def pull_nwss_data(token: str, test_file: Optional[str] = None):
33+
def construct_typedicts():
    """Create the type conversion dictionaries for both dataframes.

    Returns
    -------
    (dict, dict)
        The concentration-dataframe type dict and the metric-dataframe type
        dict, each mapping a column name to the dtype it should be cast to.
    """
    # Concentration dataframe: every signal column is a float, plus a
    # datetime timestamp column.
    type_dict = dict.fromkeys(SIGNALS, float)
    type_dict["timestamp"] = "datetime64[ns]"
    # Metric dataframe: date columns first, then float signal columns, then
    # the sample-site name columns; later updates win on any key overlap,
    # matching the original merge order.
    type_dict_metric = {date_col: "datetime64[ns]" for date_col in METRIC_DATES}
    type_dict_metric.update({signal: float for signal in METRIC_SIGNALS})
    type_dict_metric.update(SAMPLE_SITE_NAMES)
    return type_dict, type_dict_metric
44+
45+
46+
def warn_string(df, type_dict):
    """Format the error message for a dataset schema mismatch.

    Used when ``df.astype(type_dict)`` raises because an expected column is
    absent; lists the columns we needed next to the columns we actually got.

    Parameters
    ----------
    df: pd.DataFrame
        The dataframe whose columns did not match the expected schema.
    type_dict: dict
        Mapping of expected column name to target dtype.

    Returns
    -------
    str
        A multi-line human-readable message (NEWLINE is a module constant;
        f-strings cannot contain a backslash escape directly).
    """
    return f"""
    Expected column(s) missed, The dataset schema may
    have changed. Please investigate and amend the code.

    Columns needed:
    {NEWLINE.join(type_dict.keys())}

    Columns available:
    {NEWLINE.join(df.columns)}
    """
58+
59+
60+
def reformat(df, df_metric):
    """Add the population column from df_metric to df, and rename some columns.

    Keeps only the (key_plot_id, date_start, population_served) columns of
    ``df_metric``, relabels ``date_start`` to ``timestamp`` so the keys line
    up, and left-joins the population onto ``df``.
    """
    join_keys = ["key_plot_id", "timestamp"]
    # restrict the metric data to the population information, keyed the same
    # way as the concentration data
    population = (
        df_metric[["key_plot_id", "date_start", "population_served"]]
        .rename(columns={"date_start": "timestamp"})
        .set_index(join_keys)
    )
    # join on (key_plot_id, timestamp), then restore a flat integer index
    return df.set_index(join_keys).join(population).reset_index()
72+
73+
74+
def pull_nwss_data(token: str):
3675
"""Pull the latest NWSS Wastewater data, and conforms it into a dataset.
3776
3877
The output dataset has:
@@ -55,54 +94,26 @@ def pull_nwss_data(token: str, test_file: Optional[str] = None):
5594
"""
5695
# Constants
5796
keep_columns = SIGNALS.copy()
58-
signals_dict_metric = {key: float for key in METRIC_SIGNALS}
59-
metric_dates_dict = {key: "datetime64[ns]" for key in METRIC_DATES}
60-
type_dict_metric = {**metric_dates_dict, **signals_dict_metric, **SAMPLE_SITE_NAMES}
6197
# concentration key types
62-
signals_dict = {key: float for key in SIGNALS}
63-
type_dict = {**signals_dict}
64-
type_dict["timestamp"] = "datetime64[ns]"
98+
type_dict, type_dict_metric = construct_typedicts()
99+
100+
# Pull data from Socrata API
101+
client = Socrata("data.cdc.gov", token)
102+
results_concentration = client.get("g653-rqe2", limit=10**10)
103+
results_metric = client.get("2ew6-ywp6", limit=10**10)
104+
df_metric = pd.DataFrame.from_records(results_metric)
105+
df = pd.DataFrame.from_records(results_concentration)
106+
df = df.rename(columns={"date": "timestamp"})
65107

66-
if test_file:
67-
df = pd.read_csv(f"./test_data/{test_file}")
68-
else:
69-
# Pull data from Socrata API
70-
client = Socrata("data.cdc.gov", token)
71-
results_concentration = client.get("g653-rqe2", limit=10**10)
72-
results_metric = client.get("2ew6-ywp6", limit=10**10)
73-
df_metric = pd.DataFrame.from_records(results_metric)
74-
df = pd.DataFrame.from_records(results_concentration)
75-
df = df.rename(columns={"date": "timestamp"})
76108
try:
77109
df = df.astype(type_dict)
78110
except KeyError as exc:
79-
raise ValueError(
80-
f"""
81-
Expected column(s) missed, The dataset schema may
82-
have changed. Please investigate and amend the code.
83-
84-
Columns needed:
85-
{NEWLINE.join(type_dict.keys())}
111+
raise ValueError(warn_string(df, type_dict)) from exc
86112

87-
Columns available:
88-
{NEWLINE.join(df.columns)}
89-
"""
90-
) from exc
91113
try:
92114
df_metric = df_metric.astype(type_dict_metric)
93115
except KeyError as exc:
94-
raise ValueError(
95-
f"""
96-
Expected column(s) missed, The metric dataset schema may
97-
have changed. Please investigate and amend the code.
98-
99-
Columns needed:
100-
{NEWLINE.join(type_dict_metric.keys())}
101-
102-
Columns available:
103-
{NEWLINE.join(df_metric.columns)}
104-
"""
105-
) from exc
116+
raise ValueError(warn_string(df_metric, type_dict_metric)) from exc
106117

107118
# pull 2 letter state labels out of the key_plot_id labels
108119
df["state"] = df.key_plot_id.str.extract(r"_(\w\w)_")
@@ -111,15 +122,7 @@ def pull_nwss_data(token: str, test_file: Optional[str] = None):
111122
for signal in SIGNALS:
112123
df[signal] = sig_digit_round(df[signal], SIG_DIGITS)
113124

114-
# drop unused columns from df_metric
115-
df_population = df_metric.loc[:, ["key_plot_id", "date_start", "population_served"]]
116-
# get matching keys
117-
df_population = df_population.rename(columns={"date_start": "timestamp"})
118-
df_population = df_population.set_index(["key_plot_id", "timestamp"])
119-
df = df.set_index(["key_plot_id", "timestamp"])
120-
121-
df = df.join(df_population)
122-
df = df.reset_index()
125+
df = reformat(df, df_metric)
123126
# if there are population NA's, assume the previous value is accurate (most
124127
# likely introduced by dates only present in one and not the other; even
125128
# otherwise, best to assume some value rather than break the data)

nwss_wastewater/delphi_nwss/run.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ def run_module(params):
126126
)
127127
export_dir = params["common"]["export_dir"]
128128
token = params["indicator"]["token"]
129-
test_file = params["indicator"].get("test_file", None)
130129
if "archive" in params:
131130
daily_arch_diff = S3ArchiveDiffer(
132131
params["archive"]["cache_dir"],
@@ -140,7 +139,7 @@ def run_module(params):
140139
run_stats = []
141140
## build the base version of the signal at the most detailed geo level you can get.
142141
## compute stuff here or farm out to another function or file
143-
df_pull = pull_nwss_data(token, test_file)
142+
df_pull = pull_nwss_data(token)
144143
## aggregate
145144
for sensor in SIGNALS:
146145
df = df_pull.copy()

0 commit comments

Comments
 (0)