Skip to content

Commit 04c9c3b

Browse files
authored
Merge pull request #245 from cmu-delphi/ndefries/scoring-performance
Optimize local error measures and pin container version
2 parents 01a0bb5 + ab4d480 commit 04c9c3b

File tree

1 file changed

+148
-21
lines changed

1 file changed

+148
-21
lines changed

Report/error_measures.R

Lines changed: 148 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
library(assertthat)
2-
31
overprediction <- function(quantile, value, actual_value) {
42
score_func_param_checker(quantile, value, actual_value, "overprediction")
53
if (!is_symmetric(quantile)) {
@@ -12,7 +10,9 @@ overprediction <- function(quantile, value, actual_value) {
1210
if (all(is.na(actual_value))) {
1311
return(NA)
1412
}
15-
actual_value <- unique(actual_value)
13+
14+
# Already checking that actual_value is unique in score_func_param_checker
15+
actual_value <- actual_value[1]
1616

1717
lower <- value[!is.na(quantile) & quantile < .5]
1818
med <- value[find_quantile_match(quantile, 0.5)]
@@ -45,7 +45,9 @@ underprediction <- function(quantile, value, actual_value) {
4545
if (all(is.na(actual_value))) {
4646
return(NA)
4747
}
48-
actual_value <- unique(actual_value)
48+
49+
# Already checking that actual_value is unique in score_func_param_checker
50+
actual_value <- actual_value[1]
4951

5052
upper <- value[!is.na(quantile) & quantile > .5]
5153
med <- value[find_quantile_match(quantile, 0.5)]
@@ -65,6 +67,122 @@ underprediction <- function(quantile, value, actual_value) {
6567
return(ans)
6668
}
6769

70+
#' Compute weighted interval score
71+
#'
72+
#' Computes weighted interval score (WIS), a well-known quantile-based
73+
#' approximation of the commonly-used continuous ranked probability score
74+
#' (CRPS). WIS is a proper score, and can be thought of as a distributional
75+
#' generalization of absolute error. For example, see [Bracher et
76+
#' al. (2020)](https://arxiv.org/abs/2005.12881) for discussion in the context
77+
#' of COVID-19 forecasting.
78+
#'
79+
#' @param quantile vector of forecasted quantiles
80+
#' @param value vector of forecasted values
81+
#' @param actual_value Actual value.
82+
#'
83+
#' @export
84+
weighted_interval_score <- function(quantile, value, actual_value) {
85+
score_func_param_checker(quantile, value, actual_value, "weighted_interval_score")
86+
if (all(is.na(actual_value))) {
87+
return(NA)
88+
}
89+
90+
# Already checking that actual_value is unique in score_func_param_checker
91+
actual_value <- actual_value[1]
92+
93+
value <- value[!is.na(quantile)]
94+
quantile <- quantile[!is.na(quantile)]
95+
96+
# per Ryan: WIS is equivalent to quantile loss modulo an extra 0.5 AE term
97+
# for the median forecast (counted twice).
98+
#
99+
# update: WIS is now being redefined to match exactly, still some question
100+
# about the correct denominator but the formula seems to be 1 / (K + 0.5)
101+
#
102+
# Finally, the multiplication by 2 is because alpha_k = 2*quantile_k
103+
#
104+
med <- value[find_quantile_match(quantile, 0.5)]
105+
106+
if (length(med) > 1L) {
107+
return(NA)
108+
}
109+
110+
wis <- 2 * mean(pmax(
111+
quantile * (actual_value - value),
112+
(1 - quantile) * (value - actual_value),
113+
na.rm = TRUE
114+
))
115+
116+
return(wis)
117+
}
118+
119+
#' Compute absolute error
120+
#'
121+
#' Absolute error of a forecaster
122+
#'
123+
#'
124+
#' Intended to be used with `evaluate_predictions()`, it expects three arguments
125+
#' of the same length, finds the location of the point forecast, and returns
126+
#' the absolute error.
127+
#'
128+
#' @param quantile vector of forecasted quantiles
129+
#' @param value vector of forecasted values
130+
#' @param actual_value vector of actual values of the same length as
131+
#' `quantile`/`value` or a scalar
132+
#'
133+
#' @export
134+
absolute_error <- function(quantile, value, actual_value) {
135+
score_func_param_checker(quantile, value, actual_value, "absolute_error")
136+
point_fcast <- which(is.na(quantile))
137+
ae <- abs(actual_value - value)
138+
if (length(point_fcast) == 1L) {
139+
return(ae[point_fcast])
140+
}
141+
point_fcast <- which(find_quantile_match(quantile, 0.5))
142+
if (length(point_fcast) == 1L) {
143+
return(ae[point_fcast])
144+
}
145+
warning(paste(
146+
"Absolute error: Forecaster must return either a point forecast",
147+
"with quantile == NA or a median with quantile == 0.5",
148+
"Returning NA."
149+
))
150+
return(NA)
151+
}
152+
153+
#' Generate interval coverage error measure function
154+
#'
155+
#' Returns an error measure function indicating whether a central interval
156+
#' covers the actual value. The interval is defined as the (alpha/2)-quantile
157+
#' to the (1 - alpha/2)-quantile, where alpha = 1 - coverage.
158+
#'
159+
#' @param coverage Nominal interval coverage (from 0 to 1).
160+
#'
161+
#' @export
162+
interval_coverage <- function(coverage) {
163+
function(quantiles, value, actual_value) {
164+
score_func_param_checker(quantiles, value, actual_value, "interval_coverage")
165+
value <- value[!is.na(quantiles)]
166+
quantiles <- quantiles[!is.na(quantiles)]
167+
alpha <- 1 - coverage
168+
lower_interval <- alpha / 2
169+
upper_interval <- 1 - (alpha / 2)
170+
if (!any(find_quantile_match(quantiles, lower_interval)) ||
171+
!any(find_quantile_match(quantiles, upper_interval))) {
172+
warning(paste(
173+
"Interval Coverage:",
174+
"Quantiles must cover an interval of specified width",
175+
"centered at 0.5. Returning NA."
176+
))
177+
return(NA)
178+
}
179+
180+
lower <- value[which(find_quantile_match(quantiles, lower_interval))]
181+
upper <- value[which(find_quantile_match(quantiles, upper_interval))]
182+
return(actual_value[1] >= lower & actual_value[1] <= upper)
183+
}
184+
}
185+
68186
sharpness <- function(quantile, value, actual_value) {
69187
weighted_interval_score(quantile, value, actual_value) -
70188
overprediction(quantile, value, actual_value) -
@@ -74,7 +192,14 @@ sharpness <- function(quantile, value, actual_value) {
74192
# Utility functions required from evalcast that are not exported
75193

76194
is_symmetric <- function(x, tol = 1e-8) {
77-
x <- sort(x)
195+
# Checking if `x` is sorted is much faster than trying to sort it again
196+
if (is.unsorted(x, na.rm = TRUE)) {
197+
# Implicitly drops NA values
198+
x <- sort(x)
199+
} else {
200+
# Match `sort` behavior
201+
x <- x[!is.na(x)]
202+
}
78203
all(abs(x + rev(x) - 1) < tol)
79204
}
80205

@@ -106,31 +231,33 @@ get_quantile_prediction_factory <- function(val_to_match, tol = 1e-8) {
106231
score_func_param_checker <- function(quantiles, values, actual_value, id = "") {
107232
id_str <- paste0(id, ": ")
108233
if (length(actual_value) > 1) {
109-
assert_that(length(actual_value) == length(values),
110-
msg = paste0(
234+
if (length(actual_value) != length(values)) {
235+
stop(paste0(
111236
id_str,
112237
"actual_value must be a scalar or the same length",
113238
" as values"
114-
)
115-
)
239+
))
240+
}
116241
actual_value <- unique(actual_value)
117242
}
118-
assert_that(length(actual_value) == 1,
119-
msg = paste0(
243+
244+
if (length(actual_value) != 1) {
245+
stop(paste0(
120246
id_str,
121247
"actual_value must have exactly 1 unique value"
122-
)
123-
)
124-
assert_that(length(quantiles) == length(values),
125-
msg = paste0(
248+
))
249+
}
250+
if (length(quantiles) != length(values)) {
251+
stop(paste0(
126252
id_str,
127253
"quantiles and values must be of the same length"
128-
)
129-
)
130-
assert_that(!any(duplicated(quantiles)),
131-
msg = paste0(
254+
))
255+
}
256+
257+
if (anyDuplicated(quantiles)) {
258+
stop(paste0(
132259
id_str,
133260
"quantiles must be unique."
134-
)
135-
)
261+
))
262+
}
136263
}

0 commit comments

Comments
 (0)