diff --git a/src/config.rs b/src/config.rs index b5162f34908..a6c8277e051 100644 --- a/src/config.rs +++ b/src/config.rs @@ -18,6 +18,7 @@ pub struct Config { pub mirror: Replica, pub api_protocol: String, pub publish_rate_limit: PublishRateLimit, + pub blocked_traffic: Vec<(String, Vec)>, } impl Default for Config { @@ -42,6 +43,8 @@ impl Default for Config { /// - `GH_CLIENT_ID`: The client ID of the associated GitHub application. /// - `GH_CLIENT_SECRET`: The client secret of the associated GitHub application. /// - `DATABASE_URL`: The URL of the postgres database to use. + /// - `BLOCKED_TRAFFIC`: A list of headers and environment variables to use for blocking + ///. traffic. See the `block_traffic` module for more documentation. fn default() -> Config { let checkout = PathBuf::from(env("GIT_REPO_CHECKOUT")); let api_protocol = String::from("https"); @@ -135,6 +138,48 @@ impl Default for Config { mirror, api_protocol, publish_rate_limit: Default::default(), + blocked_traffic: blocked_traffic(), } } } + +fn blocked_traffic() -> Vec<(String, Vec)> { + let pattern_list = dotenv::var("BLOCKED_TRAFFIC").unwrap_or_default(); + parse_traffic_patterns(&pattern_list) + .map(|(header, value_env_var)| { + let value_list = dotenv::var(value_env_var).unwrap_or_default(); + let values = value_list.split(',').map(String::from).collect(); + (header.into(), values) + }) + .collect() +} + +fn parse_traffic_patterns(patterns: &str) -> impl Iterator { + patterns.split_terminator(',').map(|pattern| { + if let Some(idx) = pattern.find('=') { + (&pattern[..idx], &pattern[(idx + 1)..]) + } else { + panic!( + "BLOCKED_TRAFFIC must be in the form HEADER=VALUE_ENV_VAR, \ + got invalid pattern {}", + pattern + ) + } + }) +} + +#[test] +fn parse_traffic_patterns_splits_on_comma_and_looks_for_equal_sign() { + let pattern_string_1 = "Foo=BAR,Bar=BAZ"; + let pattern_string_2 = "Baz=QUX"; + let pattern_string_3 = ""; + + let patterns_1 = parse_traffic_patterns(pattern_string_1).collect::>(); + assert_eq!(vec![("Foo", "BAR"), ("Bar", "BAZ")], patterns_1); + + let patterns_2 = parse_traffic_patterns(pattern_string_2).collect::>(); + assert_eq!(vec![("Baz", "QUX")], patterns_2); + + let patterns_3 = parse_traffic_patterns(pattern_string_3).collect::>(); + assert!(patterns_3.is_empty()); +} diff --git a/src/middleware.rs b/src/middleware.rs index 2ff4de66eb4..db840654670 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -14,7 +14,7 @@ pub use self::security_headers::SecurityHeaders; pub use self::static_or_continue::StaticOrContinue; pub mod app; -mod block_ips; +mod block_traffic; pub mod current_user; mod debug; mod ember_index_rewrite; @@ -38,7 +38,8 @@ use crate::{App, Env}; pub fn build_middleware(app: Arc, endpoints: R404) -> MiddlewareBuilder { let mut m = MiddlewareBuilder::new(endpoints); - let env = app.config.env; + let config = app.config.clone(); + let env = config.env; if env != Env::Test { m.add(ensure_well_formed_500::EnsureWellFormed500); @@ -69,7 +70,7 @@ pub fn build_middleware(app: Arc, endpoints: R404) -> MiddlewareBuilder { )); if env == Env::Production { - m.add(SecurityHeaders::new(&app.config.uploader)); + m.add(SecurityHeaders::new(&config.uploader)); } m.add(AppMiddleware::new(app)); @@ -87,9 +88,8 @@ pub fn build_middleware(app: Arc, endpoints: R404) -> MiddlewareBuilder { m.around(Head::default()); - if let Ok(ip_list) = env::var("BLOCKED_IPS") { - let ips = ip_list.split(',').map(String::from).collect(); - m.around(block_ips::BlockIps::new(ips)); + for (header, blocked_values) in config.blocked_traffic { + m.around(block_traffic::BlockTraffic::new(header, blocked_values)); } m.around(require_user_agent::RequireUserAgent::default()); diff --git a/src/middleware/block_ips.rs b/src/middleware/block_traffic.rs similarity index 51% rename from src/middleware/block_ips.rs rename to src/middleware/block_traffic.rs index 91914c28187..c234f7ce7ac 100644 --- a/src/middleware/block_ips.rs +++ b/src/middleware/block_traffic.rs @@ -1,4 +1,12 @@ -//! Middleware that blocks requests from a list of given IPs +//! Middleware that blocks requests if a header matches the given list +//! +//! To use, set the `BLOCKED_TRAFFIC` environment variable to a comma-separated list of pairs +//! containing a header name, an equals sign, and the name of another environment variable that +//! contains the values of that header that should be blocked. For example, set `BLOCKED_TRAFFIC` +//! to `User-Agent=BLOCKED_UAS,X-Real-Ip=BLOCKED_IPS`, `BLOCKED_UAS` to `curl/7.54.0,cargo 1.36.0 +//! (c4fcfb725 2019-05-15)`, and `BLOCKED_IPS` to `192.168.0.1,127.0.0.1` to block requests from +//! the versions of curl or Cargo specified or from either of the IPs (values are nonsensical +//! examples). Values of the headers must match exactly. use super::prelude::*; @@ -8,32 +16,37 @@ use std::io::Cursor; // Can't derive debug because of Handler. #[allow(missing_debug_implementations)] #[derive(Default)] -pub struct BlockIps { - ips: Vec, +pub struct BlockTraffic { + header_name: String, + blocked_values: Vec, handler: Option>, } -impl BlockIps { - pub fn new(ips: Vec) -> Self { - Self { ips, handler: None } +impl BlockTraffic { + pub fn new(header_name: String, blocked_values: Vec) -> Self { + Self { + header_name, + blocked_values, + handler: None, + } } } -impl AroundMiddleware for BlockIps { +impl AroundMiddleware for BlockTraffic { fn with_handler(&mut self, handler: Box) { self.handler = Some(handler); } } -impl Handler for BlockIps { +impl Handler for BlockTraffic { fn call(&self, req: &mut dyn Request) -> Result> { - let has_blocked_ip = req + let has_blocked_value = req .headers() - .find("X-Real-Ip") - .unwrap() + .find(&self.header_name) + .unwrap_or_default() .iter() - .any(|ip| self.ips.iter().any(|v| v == ip)); - if has_blocked_ip { + .any(|value| self.blocked_values.iter().any(|v| v == value)); + if has_blocked_value { let body = format!( "We are unable to process your request at this time. \ This usually means that you are in violation of our crawler \ diff --git a/src/tests/all.rs b/src/tests/all.rs index 7b8ea6f3b46..310cdf5db9c 100644 --- a/src/tests/all.rs +++ b/src/tests/all.rs @@ -148,6 +148,7 @@ fn simple_config() -> Config { // sniff/record it, but everywhere else we use https api_protocol: String::from("http"), publish_rate_limit: Default::default(), + blocked_traffic: Default::default(), } } diff --git a/src/tests/server.rs b/src/tests/server.rs index 6ffd1e5de58..6a07041c078 100644 --- a/src/tests/server.rs +++ b/src/tests/server.rs @@ -26,3 +26,48 @@ fn user_agent_is_not_required_for_download() { let resp = anon.run::<()>(req); resp.assert_status(302); } + +#[test] +fn blocked_traffic_doesnt_panic_if_checked_header_is_not_present() { + let (app, anon, user) = TestApp::init() + .with_config(|config| { + config.blocked_traffic = vec![("Never-Given".into(), vec!["1".into()])]; + }) + .with_user(); + + app.db(|conn| { + CrateBuilder::new("dl_no_ua", user.as_model().id).expect_build(conn); + }); + + let mut req = anon.request_builder(Method::Get, "/api/v1/crates/dl_no_ua/0.99.0/download"); + req.header("User-Agent", ""); + let resp = anon.run::<()>(req); + resp.assert_status(302); +} + +#[test] +fn block_traffic_via_arbitrary_header_and_value() { + let (app, anon, user) = TestApp::init() + .with_config(|config| { + config.blocked_traffic = vec![("User-Agent".into(), vec!["1".into(), "2".into()])]; + }) + .with_user(); + + app.db(|conn| { + CrateBuilder::new("dl_no_ua", user.as_model().id).expect_build(conn); + }); + + let mut req = anon.request_builder(Method::Get, "/api/v1/crates/dl_no_ua/0.99.0/download"); + // A request with a header value we want to block isn't allowed + req.header("User-Agent", "1"); + req.header("X-Request-Id", "abcd"); // Needed for the error message we generate + let resp = anon.run::<()>(req); + resp.assert_status(403); + + let mut req = anon.request_builder(Method::Get, "/api/v1/crates/dl_no_ua/0.99.0/download"); + // A request with a header value we don't want to block is allowed, even though there might + // be a substring match + req.header("User-Agent", "1value-must-match-exactly-this-is-allowed"); + let resp = anon.run::<()>(req); + resp.assert_status(302); +} diff --git a/src/tests/util.rs b/src/tests/util.rs index cac849b0192..6a9d3bf30f3 100644 --- a/src/tests/util.rs +++ b/src/tests/util.rs @@ -268,12 +268,18 @@ impl TestAppBuilder { (app, anon, user, token) } - pub fn with_publish_rate_limit(mut self, rate: Duration, burst: i32) -> Self { - self.config.publish_rate_limit.rate = rate; - self.config.publish_rate_limit.burst = burst; + pub fn with_config(mut self, f: impl FnOnce(&mut Config)) -> Self { + f(&mut self.config); self } + pub fn with_publish_rate_limit(self, rate: Duration, burst: i32) -> Self { + self.with_config(|config| { + config.publish_rate_limit.rate = rate; + config.publish_rate_limit.burst = burst; + }) + } + pub fn with_git_index(mut self) -> Self { use crate::git;