
Commit ac6f732

feat: add dashboard with fan plots and scores
1 parent 5171d91 commit ac6f732

7 files changed: +481 / -5 lines

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -23,4 +23,4 @@ data/
 .nhsn_flu_cache.parquet
 meta/
 **/unnamed-chunk*
-decreasing_forecasters_cache/
+decreasing_forecasters_cache/

R/forecasters/formatters.R

Lines changed: 3 additions & 3 deletions
@@ -58,16 +58,16 @@ format_covidhub <- function(pred, true_forecast_date, target_end_date, quantile_
 format_flusight <- function(pred, disease = c("flu", "covid")) {
   disease <- arg_match(disease)
   pred %>%
+    add_state_info(geo_value_col = "geo_value", old_geo_code = "state_id", new_geo_code = "state_code") %>%
     mutate(
       reference_date = get_forecast_reference_date(forecast_date),
       target = glue::glue("wk inc {disease} hosp"),
       horizon = as.integer(floor((target_end_date - reference_date) / 7)),
       output_type = "quantile",
       output_type_id = quantile,
-      value = value
+      value = value,
+      location = state_code
     ) %>%
-    left_join(get_population_data() %>% select(state_id, state_code), by = c("geo_value" = "state_id")) %>%
-    mutate(location = state_code) %>%
     select(reference_date, target, horizon, target_end_date, location, output_type, output_type_id, value)
 }
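For reference, a minimal sketch of how the reworked formatter might be called. The toy pred tibble and its values are hypothetical (column names inferred from the function body), and add_state_info() / get_forecast_reference_date() are repo helpers assumed to be loaded via R/load_all.R:

# Hypothetical single-quantile prediction; real inputs carry one row per
# geo_value / quantile / target_end_date.
pred <- tibble::tibble(
  geo_value       = "ma",
  forecast_date   = as.Date("2025-01-04"),
  target_end_date = as.Date("2025-01-11"),
  quantile        = 0.5,
  value           = 120
)
# Returns the FluSight submission columns: reference_date, target, horizon,
# target_end_date, location (state_code attached by add_state_info()),
# output_type, output_type_id, value.
format_flusight(pred, disease = "flu")

The change replaces the old left_join() on population data with a single add_state_info() call up front, so location can be set inside the same mutate() as the other submission columns.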

R/utils.R

Lines changed: 13 additions & 0 deletions
@@ -316,6 +316,19 @@ write_submission_file <- function(pred, forecast_reference_date, submission_dire
   readr::write_csv(pred, file_path)
 }
 
+#' The quantile levels used by the covidhub repository
+#'
+#' @param type either standard or inc_case, with inc_case being a small subset of the standard
+#'
+#' @export
+covidhub_probs <- function(type = c("standard", "inc_case")) {
+  type <- match.arg(type)
+  switch(type,
+    standard = c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99),
+    inc_case = c(0.025, 0.100, 0.250, 0.500, 0.750, 0.900, 0.975)
+  ) |> round(digits = 3)
+}
+
 #' Utility to get the reference date for a given date. This is the last day of
 #' the epiweek that the date falls in.
 get_forecast_reference_date <- function(date) {
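A quick illustration of the new helper, assuming the package code is loaded (e.g. by sourcing R/load_all.R, as the new dashboard script does):

covidhub_probs()
#> 23 levels: 0.010 0.025 0.050 0.100 ... 0.900 0.950 0.975 0.990
covidhub_probs("inc_case")
#> 0.025 0.100 0.250 0.500 0.750 0.900 0.975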

_targets.yaml

Lines changed: 4 additions & 1 deletion
@@ -18,4 +18,7 @@ covid_hosp_prod:
   store: covid_hosp_prod
   use_crew: yes
   reporter_make: timestamp
-
+dashboard-proj:
+  script: scripts/dashboard-proj.R
+  store: dashboard-proj
+  use_crew: yes
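The new entry registers a separate targets project. A minimal sketch of how it would be run, using standard targets multi-project selection (the project and target names come from this commit; the session is assumed to start at the repo root):

# Point targets at the dashboard project defined in _targets.yaml,
# then build the pipeline and inspect the resulting scores.
Sys.setenv(TAR_PROJECT = "dashboard-proj")
targets::tar_make()
targets::tar_read(all_scores)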

scripts/dashboard-proj.R

Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
library(tidyverse)
library(httr)
library(lubridate)
library(progress)
library(targets)
source(here::here("R", "load_all.R"))

options(readr.show_progress = FALSE)
options(readr.show_col_types = FALSE)

insufficient_data_geos <- c("as", "mp", "vi", "gu")

# Configuration
config <- list(
  base_url = "https://raw.githubusercontent.com/cdcepi/FluSight-forecast-hub/main/model-output",
  forecasters = c("CMU-TimeSeries", "FluSight-baseline", "FluSight-ensemble", "FluSight-base_seasonal", "UMass-flusion"),
  local_storage = "data/forecasts",
  tracking_file = "data/download_tracking.csv"
)

# Function to ensure directory structure exists
setup_directories <- function(base_dir) {
  dir.create(file.path(base_dir), recursive = TRUE, showWarnings = FALSE)
  for (forecaster in config$forecasters) {
    dir.create(file.path(base_dir, forecaster), recursive = TRUE, showWarnings = FALSE)
  }
}

# Function to load tracking data
load_tracking_data <- function() {
  if (file.exists(config$tracking_file)) {
    read_csv(config$tracking_file)
  } else {
    tibble(
      forecaster = character(),
      filename = character(),
      download_date = character(),
      status = character()
    )
  }
}

# Function to generate possible filenames for a date range
generate_filenames <- function(start_date, end_date, forecaster) {
  dates <- seq(as_date(start_date), as_date(end_date), by = "week")
  filenames <- paste0(
    format(dates, "%Y-%m-%d"),
    "-",
    forecaster,
    ".csv"
  )
  return(filenames)
}

# Function to check if file exists on GitHub
check_github_file <- function(forecaster, filename) {
  url <- paste0(config$base_url, "/", forecaster, "/", filename)
  response <- GET(url)
  return(status_code(response) == 200)
}

# Function to download a single file
download_forecast_file <- function(forecaster, filename) {
  url <- paste0(config$base_url, "/", forecaster, "/", filename)
  local_path <- file.path(config$local_storage, forecaster, filename)

  tryCatch(
    {
      download.file(url, local_path, mode = "wb", quiet = TRUE)
      return("success")
    },
    error = function(e) {
      return("failed")
    }
  )
}

# Main function to update forecast files
update_forecast_files <- function(days_back = 30) {
  # Setup
  setup_directories(config$local_storage)
  tracking_data <- load_tracking_data()

  # Generate date range
  end_date <- Sys.Date()
  start_date <- get_forecast_reference_date(end_date - days_back)

  # Process each forecaster
  new_tracking_records <- list()

  pb_forecasters <- progress_bar$new(
    format = "Downloading forecasts from :forecaster [:bar] :percent :eta",
    total = length(config$forecasters),
    clear = FALSE,
    width = 60
  )

  for (forecaster in config$forecasters) {
    pb_forecasters$tick(tokens = list(forecaster = forecaster))

    # Get potential filenames
    filenames <- generate_filenames(start_date, end_date, forecaster)

    # Filter out already downloaded files
    existing_files <- tracking_data %>%
      filter(forecaster == !!forecaster, status == "success") %>%
      pull(filename)

    new_files <- setdiff(filenames, existing_files)

    if (length(new_files) > 0) {
      # Create nested progress bar for files
      pb_files <- progress_bar$new(
        format = " Downloading files [:bar] :current/:total :filename",
        total = length(new_files)
      )

      for (filename in new_files) {
        pb_files$tick(tokens = list(filename = filename))

        if (check_github_file(forecaster, filename)) {
          status <- download_forecast_file(forecaster, filename)

          new_tracking_records[[length(new_tracking_records) + 1]] <- tibble(
            forecaster = forecaster,
            filename = filename,
            download_date = as.character(Sys.time()),
            status = status
          )
        }
      }
    }
  }

  # Update tracking data
  if (length(new_tracking_records) > 0) {
    new_tracking_data <- bind_rows(new_tracking_records)
    tracking_data <- bind_rows(tracking_data, new_tracking_data)
    write_csv(tracking_data, config$tracking_file)
  }

  return(tracking_data)
}

# Function to read all forecast data
read_all_forecasts <- function() {
  tracking_data <- read_csv(config$tracking_file)

  successful_downloads <- tracking_data %>%
    filter(status == "success")

  forecast_data <- map(1:nrow(successful_downloads), function(i) {
    row <- successful_downloads[i, ]
    path <- file.path(config$local_storage, row$forecaster, row$filename)
    if (file.exists(path)) {
      read_csv(path, col_types = list(
        reference_date = col_date(format = "%Y-%m-%d"),
        target_end_date = col_date(format = "%Y-%m-%d"),
        target = col_character(),
        location = col_character(),
        horizon = col_integer(),
        output_type = col_character(),
        output_type_id = col_character(),
        value = col_double(),
        forecaster = col_character(),
        forecast_date = col_date(format = "%Y-%m-%d")
      )) %>%
        mutate(
          forecaster = row$forecaster,
          forecast_date = as.Date(str_extract(row$filename, "\\d{4}-\\d{2}-\\d{2}")),
        )
    }
  })

  bind_rows(forecast_data) %>%
    add_state_info(geo_value_col = "location", old_geo_code = "state_code", new_geo_code = "state_id") %>%
    rename(geo_value = state_id) %>%
    select(-location) %>%
    filter(
      target == "wk inc flu hosp",
      output_type == "quantile",
    )
}

score_forecasts <- function(all_forecasts, nhsn_latest_data) {
  predictions_cards <- all_forecasts %>%
    rename(model = forecaster) %>%
    mutate(
      quantile = as.numeric(output_type_id),
      prediction = value
    ) %>%
    select(model, geo_value, forecast_date, target_end_date, quantile, prediction)

  truth_data <- nhsn_latest_data %>%
    mutate(
      target_end_date = as.Date(time_value),
      true_value = value
    ) %>%
    select(geo_value, target_end_date, true_value)

  evaluate_predictions(predictions_cards = predictions_cards, truth_data = truth_data) %>%
    rename(forecaster = model)
}

get_latest_data <- function() {
  update_forecast_files(days_back = 120)
  read_all_forecasts()
}

rlang::list2(
  tar_target(
    nhsn_latest_data,
    command = {
      if (wday(Sys.Date()) < 6 & wday(Sys.Date()) > 3) {
        # download from the preliminary data source from Wednesday to Friday
        most_recent_result <- readr::read_csv("https://data.cdc.gov/resource/mpgq-jmmr.csv?$limit=20000&$select=weekendingdate,jurisdiction,totalconfc19newadm,totalconfflunewadm")
      } else {
        most_recent_result <- readr::read_csv("https://data.cdc.gov/resource/ua7e-t2fy.csv?$limit=20000&$select=weekendingdate,jurisdiction,totalconfc19newadm,totalconfflunewadm")
      }
      most_recent_result %>%
        process_nhsn_data() %>%
        filter(disease == "nhsn_flu") %>%
        select(-disease) %>%
        filter(geo_value %nin% insufficient_data_geos) %>%
        mutate(
          source = "nhsn",
          geo_value = ifelse(geo_value == "usa", "us", geo_value),
          time_value = time_value
        ) %>%
        filter(version == max(version)) %>%
        select(-version) %>%
        data_substitutions(disease = "flu") %>%
        as_epi_df(other_keys = "source", as_of = Sys.Date())
    }
  ),
  tar_target(
    name = nhsn_archive_data,
    command = {
      create_nhsn_data_archive(disease = "nhsn_flu")
    }
  ),
  tar_target(download_forecasts, update_forecast_files(days_back = 120)),
  tar_target(all_forecasts, read_all_forecasts()),
  tar_target(all_scores, score_forecasts(all_forecasts, nhsn_latest_data))
)
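The script itself stops at scoring; the fan plots mentioned in the commit message are not in this file. Purely as a hypothetical illustration, one way a quantile fan could be drawn from the all_forecasts target (column names taken from read_all_forecasts() above; this plotting code is not part of the commit):

library(ggplot2)
# Pull one forecaster and state, keep three quantile levels, and spread them
# into columns q0.05 / q0.5 / q0.95 for a ribbon-plus-median fan.
fan_df <- targets::tar_read(all_forecasts, store = "dashboard-proj") |>
  dplyr::filter(geo_value == "ma", forecaster == "CMU-TimeSeries") |>
  dplyr::mutate(quantile = as.numeric(output_type_id)) |>
  dplyr::filter(quantile %in% c(0.05, 0.5, 0.95)) |>
  tidyr::pivot_wider(
    id_cols = c(forecast_date, target_end_date),
    names_from = quantile, values_from = value, names_prefix = "q"
  )
ggplot(fan_df, aes(x = target_end_date)) +
  geom_ribbon(aes(ymin = q0.05, ymax = q0.95, group = forecast_date), alpha = 0.3) +
  geom_line(aes(y = q0.5, group = forecast_date)) +
  labs(x = NULL, y = "wk inc flu hosp")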
