diff --git a/Cargo.lock b/Cargo.lock index e68af936c..d045af828 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -180,6 +180,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bindgen" version = "0.69.5" @@ -669,7 +675,7 @@ version = "7.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" dependencies = [ - "base64", + "base64 0.21.7", "byteorder", "crossbeam-channel", "flate2", @@ -1885,6 +1891,7 @@ name = "sozu-lib" version = "1.0.6" dependencies = [ "anyhow", + "base64 0.22.1", "cookie-factory", "hdrhistogram", "hex", diff --git a/bin/rewrites.toml b/bin/rewrites.toml new file mode 100644 index 000000000..3732b4d9d --- /dev/null +++ b/bin/rewrites.toml @@ -0,0 +1,37 @@ +log_target = "stdout" +log_colored = true +worker_count = 1 + +[[listeners]] +protocol = "http" +address = "0.0.0.0:8080" +#answers = { "404" = "default_404.html", "503" = "default_503.html", "custom_200" = "default_200.html" } + +[clusters.MyCluster] +protocol = "http" +answers = { "custom_200" = "../lib/assets/mycluster_200.html" } +https_redirect = true +https_redirect_port = 8443 + +backends = [] + +[[clusters.MyCluster.frontends]] +address = "0.0.0.0:8080" +hostname = "/cdn([0-9]*)/.foo./(.*)/.com" +path = "\\A/client/id_([0-9]*)/(.*)" +path_type = "REGEX" +redirect_scheme = "USE_HTTPS" +redirect_template = "custom_200" +rewrite_host = "client_$PATH[1].bar.$HOST[2].com" +rewrite_path = "/$PATH[2]?cdn=$HOST[1]" +rewrite_port = 8442 + +[[clusters.MyCluster.frontends]] +address = "0.0.0.0:8080" +hostname = "localhost" +path = "\\A/download(/(.*))?\\z" +path_type = "REGEX" +redirect_scheme = "USE_HTTPS" +redirect_template = "custom_200" +rewrite_path = "/rewritten/$PATH[2]" + diff --git a/bin/src/ctl/request_builder.rs b/bin/src/ctl/request_builder.rs index eb07929d7..d34d26c87 100644 --- a/bin/src/ctl/request_builder.rs +++ b/bin/src/ctl/request_builder.rs @@ -76,11 +76,7 @@ impl CommandManager { pub fn reload_configuration(&mut self, path: Option) -> Result<(), CtlError> { debug!("Reloading configuration…"); - let path = match path { - Some(p) => p, - None => String::new(), - }; - self.send_request(RequestType::ReloadConfiguration(path).into()) + self.send_request(RequestType::ReloadConfiguration(path.unwrap_or_default()).into()) } pub fn list_frontends( @@ -246,10 +242,15 @@ impl CommandManager { path: PathRule::from_cli_options(path_prefix, path_regex, path_equals), method: method.map(String::from), position: RulePosition::Tree.into(), - tags: match tags { - Some(tags) => tags, - None => BTreeMap::new(), - }, + tags: tags.unwrap_or_default(), + required_auth: todo!(), + redirect: todo!(), + redirect_scheme: todo!(), + redirect_template: todo!(), + rewrite_host: todo!(), + rewrite_path: todo!(), + rewrite_port: todo!(), + headers: todo!(), }) .into(), ), @@ -294,10 +295,15 @@ impl CommandManager { path: PathRule::from_cli_options(path_prefix, path_regex, path_equals), method: method.map(String::from), position: RulePosition::Tree.into(), - tags: match tags { - Some(tags) => tags, - None => BTreeMap::new(), - }, + tags: tags.unwrap_or_default(), + required_auth: todo!(), + redirect: todo!(), + redirect_scheme: todo!(), + redirect_template: todo!(), + rewrite_host: todo!(), + rewrite_path: todo!(), + rewrite_port: todo!(), + headers: todo!(), }) .into(), ), @@ -341,8 +347,8 @@ impl CommandManager { } => { let https_listener = ListenerBuilder::new_https(address.into()) .with_public_address(public_address) - .with_answer_404_path(answer_404) - .with_answer_503_path(answer_503) + .with_answer("404", answer_404) + .with_answer("503", answer_503) .with_tls_versions(tls_versions) .with_cipher_list(cipher_list) .with_expect_proxy(expect_proxy) @@ -384,8 +390,8 @@ impl CommandManager { } => { let http_listener = ListenerBuilder::new_http(address.into()) .with_public_address(public_address) - .with_answer_404_path(answer_404) - .with_answer_503_path(answer_503) + .with_answer("404", answer_404) + .with_answer("503", answer_503) .with_expect_proxy(expect_proxy) .with_sticky_name(sticky_name) .with_front_timeout(front_timeout) diff --git a/command/assets/config.toml b/command/assets/config.toml index 3d48b60cd..8675c58a1 100644 --- a/command/assets/config.toml +++ b/command/assets/config.toml @@ -17,7 +17,7 @@ protocol = "http" [[listeners]] address = "0.0.0.0:443" protocol = "https" -answer_404 = "./assets/custom_404.html" +answers = { "404" = "./assets/custom_404.html" } tls_versions = ["TLS_V12"] [[listeners]] @@ -28,7 +28,7 @@ expect_proxy = true [clusters] [clusters.MyCluster] protocol = "http" -answer_503 = "./assets/custom_503.html" +answers = { "503" = "./assets/custom_503.html" } #sticky_session = false #https_redirect = false frontends = [ diff --git a/command/assets/custom_200.html b/command/assets/custom_200.html new file mode 100644 index 000000000..12f2ca492 --- /dev/null +++ b/command/assets/custom_200.html @@ -0,0 +1,6 @@ +HTTP/1.1 200 OK +Sozu-Id: %REQUEST_ID + +

%CLUSTER_ID Custom 200

+

original url: %ROUTE

+

rewritten url: %REDIRECT_LOCATION

\ No newline at end of file diff --git a/command/assets/custom_404.html b/command/assets/custom_404.html index 34f6c80d0..e331b285e 100644 --- a/command/assets/custom_404.html +++ b/command/assets/custom_404.html @@ -1,7 +1,6 @@ HTTP/1.1 404 Not Found Cache-Control: no-cache -Connection: close -Sozu-Id: %SOZU_ID +Sozu-Id: %REQUEST_ID

My own 404 error page

-

Your request %SOZU_ID found no frontend and cannot be redirected.

\ No newline at end of file +

Your request %REQUEST_ID found no frontend and cannot be redirected.

\ No newline at end of file diff --git a/command/assets/custom_503.html b/command/assets/custom_503.html index 8f174262b..4484f8a4e 100644 --- a/command/assets/custom_503.html +++ b/command/assets/custom_503.html @@ -1,11 +1,10 @@ HTTP/1.1 503 Service Unavailable Cache-Control: no-cache Connection: close -%Content-Length: %CONTENT_LENGTH -Sozu-Id: %SOZU_ID +Sozu-Id: %REQUEST_ID

MyCluster: 503 Service Unavailable

-

No server seems to be alive, could not redirect request %SOZU_ID.

+

No server seems to be alive, could not redirect request %REQUEST_ID.

-%DETAILS
+%MESSAGE
 
\ No newline at end of file
diff --git a/command/src/command.proto b/command/src/command.proto
index ad93b9e34..b604ffcb8 100644
--- a/command/src/command.proto
+++ b/command/src/command.proto
@@ -129,7 +129,7 @@ message HttpListenerConfig {
     required uint32 request_timeout = 10 [default = 10];
     // wether the listener is actively listening on its socket
     required bool active = 11 [default = false];
-    optional CustomHttpAnswers http_answers = 12;
+    map answers = 13;
 }
 
 // details of an HTTPS listener
@@ -161,7 +161,7 @@ message HttpsListenerConfig {
     // The tickets allow the client to resume a session. This protects the client
     // agains session tracking. Defaults to 4.
     required uint64 send_tls13_tickets = 20;
-    optional CustomHttpAnswers http_answers = 21;
+    map answers = 22;
 }
 
 // details of an TCP listener
@@ -179,31 +179,6 @@ message TcpListenerConfig {
     required bool active = 7 [default = false];
 }
 
-// custom HTTP answers, useful for 404, 503 pages
-message CustomHttpAnswers {
-    // MovedPermanently
-    optional string answer_301 = 1;
-    // BadRequest
-    optional string answer_400 = 2;
-    // Unauthorized
-    optional string answer_401 = 3;
-    // NotFound
-    optional string answer_404 = 4;
-    // RequestTimeout
-    optional string answer_408 = 5;
-    // PayloadTooLarge
-    optional string answer_413 = 6;
-    // BadGateway
-    optional string answer_502 = 7;
-    // ServiceUnavailable
-    optional string answer_503 = 8;
-    // GatewayTimeout
-    optional string answer_504 = 9;
-    // InsufficientStorage
-    optional string answer_507 = 10;
-
-}
-
 message ActivateListener {
     required SocketAddress address = 1;
     required ListenerType proxy = 2;
@@ -237,6 +212,18 @@ message ListenersList {
     map tcp_listeners = 3;
 }
 
+enum RedirectPolicy {
+    FORWARD = 0;
+    PERMANENT = 1;
+    UNAUTHORIZED = 2;
+}
+
+enum RedirectScheme {
+    USE_SAME = 0;
+    USE_HTTP = 1;
+    USE_HTTPS = 2;
+}
+
 // An HTTP or HTTPS frontend, as order to, or received from, Sōzu
 message RequestHttpFrontend {
     optional string cluster_id = 1;
@@ -247,6 +234,26 @@ message RequestHttpFrontend {
     required RulePosition position = 6 [default = TREE];
     // custom tags to identify the frontend in the access logs
     map tags = 7;
+    optional RedirectPolicy redirect = 8;
+    optional bool required_auth = 9;
+    optional RedirectScheme redirect_scheme = 10;
+    optional string redirect_template = 11;
+    optional string rewrite_host = 12;
+    optional string rewrite_path = 13;
+    optional uint32 rewrite_port = 14;
+    repeated Header headers = 15;
+}
+
+enum HeaderPosition {
+    REQUEST = 1;
+    RESPONSE = 2;
+    BOTH = 3;
+}
+
+message Header {
+    required HeaderPosition position = 1;
+    required string key = 2;
+    required string val = 3;
 }
 
 message RequestTcpFrontend {
@@ -374,8 +381,11 @@ message Cluster {
     required bool https_redirect = 3;
     optional ProxyProtocolConfig proxy_protocol = 4;
     required LoadBalancingAlgorithms load_balancing = 5 [default = ROUND_ROBIN];
-    optional string answer_503 = 6;
     optional LoadMetric load_metric = 7;
+    optional uint32 https_redirect_port = 8;
+    map answers = 9;
+    repeated string authorized_hashes = 10;
+    optional string www_authenticate = 11;
 }
 
 enum LoadBalancingAlgorithms {
diff --git a/command/src/config.rs b/command/src/config.rs
index e148cc1d0..47b5e41fb 100644
--- a/command/src/config.rs
+++ b/command/src/config.rs
@@ -51,7 +51,7 @@ use std::{
     collections::{BTreeMap, HashMap, HashSet},
     env, fmt,
     fs::{create_dir_all, metadata, File},
-    io::{ErrorKind, Read},
+    io::ErrorKind,
     net::SocketAddr,
     ops::Range,
     path::PathBuf,
@@ -62,11 +62,11 @@ use crate::{
     logging::AccessLogFormat,
     proto::command::{
         request::RequestType, ActivateListener, AddBackend, AddCertificate, CertificateAndKey,
-        Cluster, CustomHttpAnswers, HttpListenerConfig, HttpsListenerConfig, ListenerType,
+        Cluster, Header, HeaderPosition, HttpListenerConfig, HttpsListenerConfig, ListenerType,
         LoadBalancingAlgorithms, LoadBalancingParams, LoadMetric, MetricsConfiguration, PathRule,
-        ProtobufAccessLogFormat, ProxyProtocolConfig, Request, RequestHttpFrontend,
-        RequestTcpFrontend, RulePosition, ServerConfig, ServerMetricsConfig, SocketAddress,
-        TcpListenerConfig, TlsVersion, WorkerRequest,
+        ProtobufAccessLogFormat, ProxyProtocolConfig, RedirectPolicy, RedirectScheme, Request,
+        RequestHttpFrontend, RequestTcpFrontend, RulePosition, ServerConfig, ServerMetricsConfig,
+        SocketAddress, TcpListenerConfig, TlsVersion, WorkerRequest,
     },
     ObjectKind,
 };
@@ -210,13 +210,15 @@ pub enum ConfigError {
     },
     #[error("Invalid '{0}' field for a TCP frontend")]
     InvalidFrontendConfig(String),
-    #[error("invalid path {0:?}")]
+    #[error("Invalid path {0:?}")]
     InvalidPath(PathBuf),
-    #[error("listening address {0:?} is already used in the configuration")]
+    #[error("Invalid Sha256 hash '{0}'")]
+    InvalidHash(String),
+    #[error("Listening address {0:?} is already used in the configuration")]
     ListenerAddressAlreadyInUse(SocketAddr),
-    #[error("missing {0:?}")]
+    #[error("Missing {0:?}")]
     Missing(MissingKind),
-    #[error("could not get parent directory for file {0}")]
+    #[error("Could not get parent directory for file {0}")]
     NoFileParent(String),
     #[error("Could not get the path of the saved state")]
     SaveStatePath(String),
@@ -240,16 +242,7 @@ pub struct ListenerBuilder {
     pub address: SocketAddr,
     pub protocol: Option,
     pub public_address: Option,
-    pub answer_301: Option,
-    pub answer_400: Option,
-    pub answer_401: Option,
-    pub answer_404: Option,
-    pub answer_408: Option,
-    pub answer_413: Option,
-    pub answer_502: Option,
-    pub answer_503: Option,
-    pub answer_504: Option,
-    pub answer_507: Option,
+    pub answers: Option>,
     pub tls_versions: Option>,
     pub cipher_list: Option>,
     pub cipher_suites: Option>,
@@ -279,6 +272,26 @@ pub fn default_sticky_name() -> String {
     DEFAULT_STICKY_NAME.to_string()
 }
 
+pub fn load_answers(
+    answers: Option<&BTreeMap>,
+) -> Result, ConfigError> {
+    if let Some(answers) = answers {
+        answers
+            .iter()
+            .map(|(name, path)| match Config::load_file(path) {
+                Ok(content) => Ok((name.to_owned(), content)),
+                Err(e) => Err((name.to_owned(), path, e)),
+            })
+            .collect::, _>>()
+            .map_err(|(name, path, e)| {
+                error!("cannot load answer {:?} at path {:?}: {:?}", name, path, e);
+                e
+            })
+    } else {
+        Ok(BTreeMap::new())
+    }
+}
+
 impl ListenerBuilder {
     /// starts building an HTTP Listener with config values for timeouts,
     /// or defaults if no config is provided
@@ -302,16 +315,7 @@ impl ListenerBuilder {
     fn new(address: SocketAddress, protocol: ListenerProtocol) -> ListenerBuilder {
         ListenerBuilder {
             address: address.into(),
-            answer_301: None,
-            answer_401: None,
-            answer_400: None,
-            answer_404: None,
-            answer_408: None,
-            answer_413: None,
-            answer_502: None,
-            answer_503: None,
-            answer_504: None,
-            answer_507: None,
+            answers: None,
             back_timeout: None,
             certificate_chain: None,
             certificate: None,
@@ -338,23 +342,22 @@ impl ListenerBuilder {
         self
     }
 
-    pub fn with_answer_404_path(&mut self, answer_404_path: Option) -> &mut Self
+    pub fn with_answer(&mut self, name: S, path: Option) -> &mut Self
     where
         S: ToString,
     {
-        if let Some(path) = answer_404_path {
-            self.answer_404 = Some(path.to_string());
+        if let Some(path) = path {
+            self.answers
+                .get_or_insert_with(BTreeMap::new)
+                .insert(name.to_string(), path);
         }
         self
     }
 
-    pub fn with_answer_503_path(&mut self, answer_503_path: Option) -> &mut Self
-    where
-        S: ToString,
-    {
-        if let Some(path) = answer_503_path {
-            self.answer_503 = Some(path.to_string());
-        }
+    pub fn with_answers(&mut self, mut answers: BTreeMap) -> &mut Self {
+        self.answers
+            .get_or_insert_with(BTreeMap::new)
+            .append(&mut answers);
         self
     }
 
@@ -429,23 +432,6 @@ impl ListenerBuilder {
         self
     }
 
-    /// Get the custom HTTP answers from the file system using the provided paths
-    fn get_http_answers(&self) -> Result, ConfigError> {
-        let http_answers = CustomHttpAnswers {
-            answer_301: read_http_answer_file(&self.answer_301)?,
-            answer_400: read_http_answer_file(&self.answer_400)?,
-            answer_401: read_http_answer_file(&self.answer_401)?,
-            answer_404: read_http_answer_file(&self.answer_404)?,
-            answer_408: read_http_answer_file(&self.answer_408)?,
-            answer_413: read_http_answer_file(&self.answer_413)?,
-            answer_502: read_http_answer_file(&self.answer_502)?,
-            answer_503: read_http_answer_file(&self.answer_503)?,
-            answer_504: read_http_answer_file(&self.answer_504)?,
-            answer_507: read_http_answer_file(&self.answer_507)?,
-        };
-        Ok(Some(http_answers))
-    }
-
     /// Assign the timeouts of the config to this listener, only if timeouts did not exist
     fn assign_config_timeouts(&mut self, config: &Config) {
         self.front_timeout = Some(self.front_timeout.unwrap_or(config.front_timeout));
@@ -467,8 +453,6 @@ impl ListenerBuilder {
             self.assign_config_timeouts(config);
         }
 
-        let http_answers = self.get_http_answers()?;
-
         let configuration = HttpListenerConfig {
             address: self.address.into(),
             public_address: self.public_address.map(|a| a.into()),
@@ -478,7 +462,7 @@ impl ListenerBuilder {
             back_timeout: self.back_timeout.unwrap_or(DEFAULT_BACK_TIMEOUT),
             connect_timeout: self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT),
             request_timeout: self.request_timeout.unwrap_or(DEFAULT_REQUEST_TIMEOUT),
-            http_answers,
+            answers: load_answers(self.answers.as_ref())?,
             ..Default::default()
         };
 
@@ -550,8 +534,6 @@ impl ListenerBuilder {
             .map(split_certificate_chain)
             .unwrap_or_default();
 
-        let http_answers = self.get_http_answers()?;
-
         if let Some(config) = config {
             self.assign_config_timeouts(config);
         }
@@ -577,7 +559,7 @@ impl ListenerBuilder {
             send_tls13_tickets: self
                 .send_tls13_tickets
                 .unwrap_or(DEFAULT_SEND_TLS_13_TICKETS),
-            http_answers,
+            answers: load_answers(self.answers.as_ref())?,
         };
 
         Ok(https_listener_config)
@@ -608,28 +590,6 @@ impl ListenerBuilder {
     }
 }
 
-/// read a custom HTTP answer from a file
-fn read_http_answer_file(path: &Option) -> Result, ConfigError> {
-    match path {
-        Some(path) => {
-            let mut content = String::new();
-            let mut file = File::open(path).map_err(|io_error| ConfigError::FileOpen {
-                path_to_open: path.to_owned(),
-                io_error,
-            })?;
-
-            file.read_to_string(&mut content)
-                .map_err(|io_error| ConfigError::FileRead {
-                    path_to_read: path.to_owned(),
-                    io_error,
-                })?;
-
-            Ok(Some(content))
-        }
-        None => Ok(None),
-    }
-}
-
 #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct MetricsConfig {
@@ -649,6 +609,14 @@ pub enum PathRuleType {
     Equals,
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct HeaderConfig {
+    position: HeaderPosition,
+    key: String,
+    val: String,
+}
+
 #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct FileClusterFrontendConfig {
@@ -664,9 +632,18 @@ pub struct FileClusterFrontendConfig {
     pub certificate_chain: Option,
     #[serde(default)]
     pub tls_versions: Vec,
-    #[serde(default)]
+    #[serde(default = "default_rule_position")]
     pub position: RulePosition,
     pub tags: Option>,
+    pub required_auth: Option,
+    pub redirect: Option,
+    pub redirect_scheme: Option,
+    pub redirect_template: Option,
+    pub rewrite_host: Option,
+    pub rewrite_path: Option,
+    pub rewrite_port: Option,
+    #[serde(default)]
+    pub headers: Vec,
 }
 
 impl FileClusterFrontendConfig {
@@ -752,6 +729,14 @@ impl FileClusterFrontendConfig {
             path,
             method: self.method.clone(),
             tags: self.tags.clone(),
+            required_auth: self.required_auth.unwrap_or(false),
+            redirect: self.redirect,
+            redirect_scheme: self.redirect_scheme,
+            redirect_template: self.redirect_template.clone(),
+            rewrite_host: self.rewrite_host.clone(),
+            rewrite_path: self.rewrite_path.clone(),
+            rewrite_port: self.rewrite_port,
+            headers: self.headers.clone(),
         })
     }
 }
@@ -771,7 +756,7 @@ pub enum FileClusterProtocolConfig {
     Tcp,
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct FileClusterConfig {
     pub frontends: Vec,
@@ -779,13 +764,19 @@ pub struct FileClusterConfig {
     pub protocol: FileClusterProtocolConfig,
     pub sticky_session: Option,
     pub https_redirect: Option,
+    pub https_redirect_port: Option,
     #[serde(default)]
     pub send_proxy: Option,
     #[serde(default)]
     pub load_balancing: LoadBalancingAlgorithms,
-    pub answer_503: Option,
     #[serde(default)]
     pub load_metric: Option,
+    #[serde(default)]
+    pub answers: Option>,
+    #[serde(default)]
+    pub authorized_hashes: Vec,
+    #[serde(default)]
+    pub www_authenticate: Option,
 }
 
 #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
@@ -862,15 +853,17 @@ impl FileClusterConfig {
                     let http_frontend = frontend.to_http_front(cluster_id)?;
                     frontends.push(http_frontend);
                 }
-
-                let answer_503 = self.answer_503.as_ref().and_then(|path| {
-                    Config::load_file(path)
-                        .map_err(|e| {
-                            error!("cannot load 503 error page at path '{}': {:?}", path, e);
-                            e
-                        })
-                        .ok()
-                });
+                // self.authorized_hashes
+                //     .iter()
+                //     .map(|hash| {
+                //         hex::decode(hash)
+                //             .map_err(|_| ConfigError::InvalidHash(hash.clone()))
+                //             .and_then(|v| {
+                //                 v.try_into()
+                //                     .map_err(|_| ConfigError::InvalidHash(hash.clone()))
+                //             })
+                //     })
+                //     .collect::, ConfigError>>()?;
 
                 Ok(ClusterConfig::Http(HttpClusterConfig {
                     cluster_id: cluster_id.to_string(),
@@ -878,9 +871,12 @@ impl FileClusterConfig {
                     backends: self.backends,
                     sticky_session: self.sticky_session.unwrap_or(false),
                     https_redirect: self.https_redirect.unwrap_or(false),
+                    https_redirect_port: self.https_redirect_port,
                     load_balancing: self.load_balancing,
                     load_metric: self.load_metric,
-                    answer_503,
+                    answers: load_answers(self.answers.as_ref())?,
+                    authorized_hashes: self.authorized_hashes,
+                    www_authenticate: self.www_authenticate,
                 }))
             }
         }
@@ -899,9 +895,21 @@ pub struct HttpFrontendConfig {
     pub certificate_chain: Option>,
     #[serde(default)]
     pub tls_versions: Vec,
-    #[serde(default)]
+    #[serde(default = "default_rule_position")]
     pub position: RulePosition,
     pub tags: Option>,
+    pub required_auth: bool,
+    pub redirect: Option,
+    pub redirect_scheme: Option,
+    pub redirect_template: Option,
+    pub rewrite_host: Option,
+    pub rewrite_path: Option,
+    pub rewrite_port: Option,
+    pub headers: Vec,
+}
+
+fn default_rule_position() -> RulePosition {
+    RulePosition::Tree
 }
 
 impl HttpFrontendConfig {
@@ -909,6 +917,16 @@ impl HttpFrontendConfig {
         let mut v = Vec::new();
 
         let tags = self.tags.clone().unwrap_or_default();
+        let headers = self
+            .headers
+            .iter()
+            .cloned()
+            .map(|h| Header {
+                position: h.position.into(),
+                key: h.key,
+                val: h.val,
+            })
+            .collect();
 
         if self.key.is_some() && self.certificate.is_some() {
             v.push(
@@ -939,6 +957,14 @@ impl HttpFrontendConfig {
                     method: self.method.clone(),
                     position: self.position.into(),
                     tags,
+                    required_auth: Some(self.required_auth),
+                    redirect: self.redirect.map(Into::into),
+                    redirect_scheme: self.redirect_scheme.map(Into::into),
+                    redirect_template: self.redirect_template.clone(),
+                    rewrite_host: self.rewrite_host.clone(),
+                    rewrite_path: self.rewrite_path.clone(),
+                    rewrite_port: self.rewrite_port.map(|x| x as u32),
+                    headers,
                 })
                 .into(),
             );
@@ -953,6 +979,14 @@ impl HttpFrontendConfig {
                     method: self.method.clone(),
                     position: self.position.into(),
                     tags,
+                    required_auth: Some(self.required_auth),
+                    redirect: self.redirect.map(Into::into),
+                    redirect_scheme: self.redirect_scheme.map(Into::into),
+                    redirect_template: self.redirect_template.clone(),
+                    rewrite_host: self.rewrite_host.clone(),
+                    rewrite_path: self.rewrite_path.clone(),
+                    rewrite_port: self.rewrite_port.map(|x| x as u32),
+                    headers,
                 })
                 .into(),
             );
@@ -962,7 +996,7 @@ impl HttpFrontendConfig {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct HttpClusterConfig {
     pub cluster_id: String,
@@ -970,9 +1004,12 @@ pub struct HttpClusterConfig {
     pub backends: Vec,
     pub sticky_session: bool,
     pub https_redirect: bool,
+    pub https_redirect_port: Option,
     pub load_balancing: LoadBalancingAlgorithms,
     pub load_metric: Option,
-    pub answer_503: Option,
+    pub answers: BTreeMap,
+    pub authorized_hashes: Vec,
+    pub www_authenticate: Option,
 }
 
 impl HttpClusterConfig {
@@ -981,10 +1018,13 @@ impl HttpClusterConfig {
             cluster_id: self.cluster_id.clone(),
             sticky_session: self.sticky_session,
             https_redirect: self.https_redirect,
+            https_redirect_port: self.https_redirect_port.map(|s| s as u32),
             proxy_protocol: None,
             load_balancing: self.load_balancing as i32,
-            answer_503: self.answer_503.clone(),
             load_metric: self.load_metric.map(|s| s as i32),
+            answers: self.answers.clone(),
+            authorized_hashes: self.authorized_hashes.clone(),
+            www_authenticate: self.www_authenticate.clone(),
         })
         .into()];
 
@@ -1040,10 +1080,13 @@ impl TcpClusterConfig {
             cluster_id: self.cluster_id.clone(),
             sticky_session: false,
             https_redirect: false,
+            https_redirect_port: None,
             proxy_protocol: self.proxy_protocol.map(|s| s as i32),
             load_balancing: self.load_balancing as i32,
             load_metric: self.load_metric.map(|s| s as i32),
-            answer_503: None,
+            answers: Default::default(),
+            authorized_hashes: Default::default(),
+            www_authenticate: None,
         })
         .into()];
 
@@ -1082,7 +1125,7 @@ impl TcpClusterConfig {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub enum ClusterConfig {
     Http(HttpClusterConfig),
     Tcp(TcpClusterConfig),
@@ -1868,7 +1911,7 @@ mod tests {
             SocketAddress::new_v4(127, 0, 0, 1, 8080),
             ListenerProtocol::Http,
         )
-        .with_answer_404_path(Some("404.html"))
+        .with_answer(404, Some("404.html".to_string()))
         .to_owned();
         println!("http: {:?}", to_string(&http));
 
@@ -1876,7 +1919,7 @@ mod tests {
             SocketAddress::new_v4(127, 0, 0, 1, 8443),
             ListenerProtocol::Https,
         )
-        .with_answer_404_path(Some("404.html"))
+        .with_answer(404, Some("404.html".to_string()))
         .to_owned();
         println!("https: {:?}", to_string(&https));
 
diff --git a/command/src/proto/display.rs b/command/src/proto/display.rs
index 9331de3a7..b2701c2fb 100644
--- a/command/src/proto/display.rs
+++ b/command/src/proto/display.rs
@@ -13,12 +13,11 @@ use crate::{
         command::{
             filtered_metrics, protobuf_endpoint, request::RequestType,
             response_content::ContentType, AggregatedMetrics, AvailableMetrics, CertificateAndKey,
-            CertificateSummary, CertificatesWithFingerprints, ClusterMetrics, CustomHttpAnswers,
-            Event, EventKind, FilteredMetrics, HttpEndpoint, HttpListenerConfig,
-            HttpsListenerConfig, ListOfCertificatesByAddress, ListedFrontends, ListenersList,
-            ProtobufEndpoint, QueryCertificatesFilters, RequestCounts, Response, ResponseContent,
-            ResponseStatus, RunState, SocketAddress, TlsVersion, WorkerInfos, WorkerMetrics,
-            WorkerResponses,
+            CertificateSummary, CertificatesWithFingerprints, ClusterMetrics, Event, EventKind,
+            FilteredMetrics, HttpEndpoint, HttpListenerConfig, HttpsListenerConfig,
+            ListOfCertificatesByAddress, ListedFrontends, ListenersList, ProtobufEndpoint,
+            QueryCertificatesFilters, RequestCounts, Response, ResponseContent, ResponseStatus,
+            RunState, SocketAddress, TlsVersion, WorkerInfos, WorkerMetrics, WorkerResponses,
         },
         DisplayError,
     },
@@ -1011,8 +1010,8 @@ impl Display for HttpListenerConfig {
         table.set_format(*prettytable::format::consts::FORMAT_BOX_CHARS);
         table.add_row(row!["socket address", format!("{:?}", self.address)]);
         table.add_row(row!["public address", format!("{:?}", self.public_address),]);
-        for http_answer_row in CustomHttpAnswers::to_rows(&self.http_answers) {
-            table.add_row(http_answer_row);
+        for (name, content) in &self.answers {
+            table.add_row(row![format!("answer({name})"), content]);
         }
         table.add_row(row!["expect proxy", self.expect_proxy]);
         table.add_row(row!["sticky name", self.sticky_name]);
@@ -1036,8 +1035,8 @@ impl Display for HttpsListenerConfig {
 
         table.add_row(row!["socket address", format!("{:?}", self.address)]);
         table.add_row(row!["public address", format!("{:?}", self.public_address)]);
-        for http_answer_row in CustomHttpAnswers::to_rows(&self.http_answers) {
-            table.add_row(http_answer_row);
+        for (name, content) in &self.answers {
+            table.add_row(row![format!("answer({name})"), content]);
         }
         table.add_row(row!["versions", tls_versions]);
         table.add_row(row!["cipher list", list_string_vec(&self.cipher_list),]);
@@ -1059,42 +1058,6 @@ impl Display for HttpsListenerConfig {
     }
 }
 
-impl CustomHttpAnswers {
-    fn to_rows(option: &Option) -> Vec {
-        let mut rows = Vec::new();
-        if let Some(answers) = option {
-            if let Some(a) = &answers.answer_301 {
-                rows.push(row!("301", a));
-            }
-            if let Some(a) = &answers.answer_400 {
-                rows.push(row!("400", a));
-            }
-            if let Some(a) = &answers.answer_404 {
-                rows.push(row!("404", a));
-            }
-            if let Some(a) = &answers.answer_408 {
-                rows.push(row!("408", a));
-            }
-            if let Some(a) = &answers.answer_413 {
-                rows.push(row!("413", a));
-            }
-            if let Some(a) = &answers.answer_502 {
-                rows.push(row!("502", a));
-            }
-            if let Some(a) = &answers.answer_503 {
-                rows.push(row!("503", a));
-            }
-            if let Some(a) = &answers.answer_504 {
-                rows.push(row!("504", a));
-            }
-            if let Some(a) = &answers.answer_507 {
-                rows.push(row!("507", a));
-            }
-        }
-        rows
-    }
-}
-
 impl Display for Event {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         let kind = match self.kind() {
diff --git a/command/src/request.rs b/command/src/request.rs
index f43e7c621..beb08501a 100644
--- a/command/src/request.rs
+++ b/command/src/request.rs
@@ -14,8 +14,7 @@ use crate::{
     proto::{
         command::{
             ip_address, request::RequestType, InitialState, IpAddress, LoadBalancingAlgorithms,
-            PathRuleKind, Request, RequestHttpFrontend, RulePosition, SocketAddress, Uint128,
-            WorkerRequest,
+            PathRuleKind, Request, RequestHttpFrontend, SocketAddress, Uint128, WorkerRequest,
         },
         display::format_request_type,
     },
@@ -161,18 +160,21 @@ impl RequestHttpFrontend {
     /// convert a requested frontend to a usable one by parsing its address
     pub fn to_frontend(self) -> Result {
         Ok(HttpFrontend {
+            position: self.position(),
+            required_auth: self.required_auth.unwrap_or(false),
+            redirect: self.redirect(),
+            redirect_scheme: self.redirect_scheme(),
+            redirect_template: self.redirect_template,
+            rewrite_host: self.rewrite_host,
+            rewrite_path: self.rewrite_path,
+            rewrite_port: self.rewrite_port.map(|x| x as u16),
             address: self.address.into(),
             cluster_id: self.cluster_id,
             hostname: self.hostname,
             path: self.path,
             method: self.method,
-            position: RulePosition::try_from(self.position).map_err(|_| {
-                RequestError::InvalidValue {
-                    name: "position".to_string(),
-                    value: self.position,
-                }
-            })?,
             tags: Some(self.tags),
+            headers: self.headers,
         })
     }
 }
diff --git a/command/src/response.rs b/command/src/response.rs
index c399b657a..6bd56a739 100644
--- a/command/src/response.rs
+++ b/command/src/response.rs
@@ -1,10 +1,15 @@
-use std::{cmp::Ordering, collections::BTreeMap, fmt, net::SocketAddr};
+use std::{
+    cmp::Ordering,
+    collections::BTreeMap,
+    fmt::{self, Debug},
+    net::SocketAddr,
+};
 
 use crate::{
     proto::command::{
-        AddBackend, FilteredTimeSerie, LoadBalancingParams, PathRule, PathRuleKind,
-        RequestHttpFrontend, RequestTcpFrontend, Response, ResponseContent, ResponseStatus,
-        RulePosition, RunState, WorkerResponse,
+        AddBackend, FilteredTimeSerie, Header, LoadBalancingParams, PathRule, PathRuleKind,
+        RedirectPolicy, RedirectScheme, RequestHttpFrontend, RequestTcpFrontend, Response,
+        ResponseContent, ResponseStatus, RulePosition, RunState, WorkerResponse,
     },
     state::ClusterId,
 };
@@ -39,6 +44,18 @@ pub struct HttpFrontend {
     #[serde(default)]
     pub position: RulePosition,
     pub tags: Option>,
+    pub required_auth: bool,
+    pub redirect: RedirectPolicy,
+    pub redirect_scheme: RedirectScheme,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub redirect_template: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub rewrite_host: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub rewrite_path: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub rewrite_port: Option,
+    pub headers: Vec
, } impl From for RequestHttpFrontend { @@ -51,6 +68,14 @@ impl From for RequestHttpFrontend { method: val.method, position: val.position.into(), tags: val.tags.unwrap_or_default(), + required_auth: Some(val.required_auth), + redirect: Some(val.redirect.into()), + redirect_scheme: Some(val.redirect_scheme.into()), + redirect_template: val.redirect_template, + rewrite_host: val.rewrite_host, + rewrite_path: val.rewrite_path, + rewrite_port: val.rewrite_port.map(|x| x as u32), + headers: val.headers, } } } diff --git a/command/src/state.rs b/command/src/state.rs index e8371a8bd..0d6524b23 100644 --- a/command/src/state.rs +++ b/command/src/state.rs @@ -1482,7 +1482,7 @@ mod tests { use super::*; use crate::proto::command::{ - CustomHttpAnswers, LoadBalancingParams, RequestHttpFrontend, RulePosition, + LoadBalancingParams, RedirectPolicy, RedirectScheme, RequestHttpFrontend, RulePosition, }; #[test] @@ -1724,6 +1724,9 @@ mod tests { hostname: String::from("test.local"), path: PathRule::prefix(String::from("/abc")), address: SocketAddress::new_v4(0, 0, 0, 0, 8080), + required_auth: Some(false), + redirect: Some(RedirectPolicy::Forward.into()), + redirect_scheme: Some(RedirectScheme::UseSame.into()), ..Default::default() }) .into(), @@ -1985,10 +1988,7 @@ mod tests { #[test] fn listener_diff() { let mut state: ConfigState = Default::default(); - let custom_http_answers = Some(CustomHttpAnswers { - answer_404: Some("test".to_string()), - ..Default::default() - }); + let answers = BTreeMap::from([("404".to_string(), "test".to_string())]); state .dispatch( &RequestType::AddTcpListener(TcpListenerConfig { @@ -2052,7 +2052,7 @@ mod tests { .dispatch( &RequestType::AddHttpListener(HttpListenerConfig { address: SocketAddress::new_v4(0, 0, 0, 0, 8080), - http_answers: custom_http_answers.clone(), + answers: answers.clone(), ..Default::default() }) .into(), @@ -2072,7 +2072,7 @@ mod tests { .dispatch( &RequestType::AddHttpsListener(HttpsListenerConfig { address: SocketAddress::new_v4(0, 0, 0, 0, 8443), - http_answers: custom_http_answers.clone(), + answers: answers.clone(), ..Default::default() }) .into(), @@ -2114,7 +2114,7 @@ mod tests { .into(), RequestType::AddHttpListener(HttpListenerConfig { address: SocketAddress::new_v4(0, 0, 0, 0, 8080), - http_answers: custom_http_answers.clone(), + answers: answers.clone(), ..Default::default() }) .into(), @@ -2131,7 +2131,7 @@ mod tests { .into(), RequestType::AddHttpsListener(HttpsListenerConfig { address: SocketAddress::new_v4(0, 0, 0, 0, 8443), - http_answers: custom_http_answers.clone(), + answers: answers.clone(), ..Default::default() }) .into(), diff --git a/e2e/src/http_utils/mod.rs b/e2e/src/http_utils/mod.rs index 9e6248df2..345ebc1a3 100644 --- a/e2e/src/http_utils/mod.rs +++ b/e2e/src/http_utils/mod.rs @@ -24,12 +24,17 @@ pub fn http_request, S2: Into, S3: Into, S4: In ) } -pub fn immutable_answer(status: u16) -> String { +pub fn immutable_answer(status: u16, content_length: bool) -> String { + let content_length = if content_length { + "\r\nContent-Length: 0" + } else { + "" + }; match status { - 400 => String::from("HTTP/1.1 400 Bad Request\r\nCache-Control: no-cache\r\nConnection: close\r\n\r\n"), - 404 => String::from("HTTP/1.1 404 Not Found\r\nCache-Control: no-cache\r\nConnection: close\r\n\r\n"), - 502 => String::from("HTTP/1.1 502 Bad Gateway\r\nCache-Control: no-cache\r\nConnection: close\r\n\r\n"), - 503 => String::from("HTTP/1.1 503 Service Unavailable\r\nCache-Control: no-cache\r\nConnection: close\r\n\r\n"), + 400 => format!("HTTP/1.1 400 Bad Request\r\nCache-Control: no-cache\r\nConnection: close{content_length}\r\n\r\n"), + 404 => format!("HTTP/1.1 404 Not Found\r\nCache-Control: no-cache\r\nConnection: close{content_length}\r\n\r\n"), + 502 => format!("HTTP/1.1 502 Bad Gateway\r\nCache-Control: no-cache\r\nConnection: close{content_length}\r\n\r\n"), + 503 => format!("HTTP/1.1 503 Service Unavailable\r\nCache-Control: no-cache\r\nConnection: close{content_length}\r\n\r\n"), _ => unimplemented!() } } diff --git a/e2e/src/tests/tests.rs b/e2e/src/tests/tests.rs index 6bcedfc06..e2cd874c1 100644 --- a/e2e/src/tests/tests.rs +++ b/e2e/src/tests/tests.rs @@ -1,4 +1,5 @@ use std::{ + collections::BTreeMap, net::SocketAddr, thread, time::{Duration, Instant}, @@ -10,7 +11,7 @@ use sozu_command_lib::{ logging::setup_default_logging, proto::command::{ request::RequestType, ActivateListener, AddCertificate, CertificateAndKey, Cluster, - CustomHttpAnswers, ListenerType, RemoveBackend, RequestHttpFrontend, SocketAddress, + ListenerType, RemoveBackend, RequestHttpFrontend, SocketAddress, }, scm_socket::Listeners, state::ConfigState, @@ -643,14 +644,12 @@ fn try_http_behaviors() -> State { let mut http_config = ListenerBuilder::new_http(front_address.into()) .to_http(None) .unwrap(); - let http_answers = CustomHttpAnswers { - answer_400: Some(immutable_answer(400)), - answer_404: Some(immutable_answer(404)), - answer_502: Some(immutable_answer(502)), - answer_503: Some(immutable_answer(503)), - ..Default::default() - }; - http_config.http_answers = Some(http_answers); + http_config.answers = BTreeMap::from([ + ("400".to_string(), immutable_answer(400, false)), + ("404".to_string(), immutable_answer(404, false)), + ("502".to_string(), immutable_answer(502, false)), + ("503".to_string(), immutable_answer(503, false)), + ]); worker.send_proxy_request_type(RequestType::AddHttpListener(http_config)); worker.send_proxy_request_type(RequestType::ActivateListener(ActivateListener { @@ -672,7 +671,7 @@ fn try_http_behaviors() -> State { let response = client.receive(); println!("response: {response:?}"); - assert_eq!(response, Some(immutable_answer(404))); + assert_eq!(response, Some(immutable_answer(404, true))); assert_eq!(client.receive(), None); worker.send_proxy_request_type(RequestType::AddHttpFrontend(RequestHttpFrontend { @@ -687,7 +686,7 @@ fn try_http_behaviors() -> State { let response = client.receive(); println!("response: {response:?}"); - assert_eq!(response, Some(immutable_answer(503))); + assert_eq!(response, Some(immutable_answer(503, true))); assert_eq!(client.receive(), None); let back_address = create_local_address(); @@ -707,7 +706,7 @@ fn try_http_behaviors() -> State { let response = client.receive(); println!("response: {response:?}"); - assert_eq!(response, Some(immutable_answer(400))); + assert_eq!(response, Some(immutable_answer(400, true))); assert_eq!(client.receive(), None); let mut backend = SyncBackend::new("backend", back_address, "TEST\r\n\r\n"); @@ -724,7 +723,7 @@ fn try_http_behaviors() -> State { let response = client.receive(); println!("request: {request:?}"); println!("response: {response:?}"); - assert_eq!(response, Some(immutable_answer(502))); + assert_eq!(response, Some(immutable_answer(502, true))); assert_eq!(client.receive(), None); info!("expecting 200"); @@ -787,7 +786,7 @@ fn try_http_behaviors() -> State { let response = client.receive(); println!("request: {request:?}"); println!("response: {response:?}"); - assert_eq!(response, Some(immutable_answer(503))); + assert_eq!(response, Some(immutable_answer(503, true))); assert_eq!(client.receive(), None); worker.send_proxy_request_type(RequestType::RemoveBackend(RemoveBackend { @@ -951,12 +950,10 @@ fn try_https_redirect() -> State { .to_http(None) .unwrap(); let answer_301_prefix = "HTTP/1.1 301 Moved Permanently\r\nLocation: "; - - let http_answers = CustomHttpAnswers { - answer_301: Some(format!("{answer_301_prefix}%REDIRECT_LOCATION\r\n\r\n")), - ..Default::default() - }; - http_config.http_answers = Some(http_answers); + http_config.answers = BTreeMap::from([( + "301".to_string(), + format!("{answer_301_prefix}%REDIRECT_LOCATION\r\n\r\n"), + )]); worker.send_proxy_request_type(RequestType::AddHttpListener(http_config)); worker.send_proxy_request_type(RequestType::ActivateListener(ActivateListener { @@ -987,7 +984,9 @@ fn try_https_redirect() -> State { client.connect(); client.send(); let answer = client.receive(); - let expected_answer = format!("{answer_301_prefix}https://example.com/redirected?true\r\n\r\n"); + let expected_answer = format!( + "{answer_301_prefix}https://example.com/redirected?true\r\nContent-Length: 0\r\n\r\n" + ); assert_eq!(answer, Some(expected_answer)); State::Success @@ -1242,7 +1241,7 @@ pub fn try_stick() -> State { backend1.send(0); let response = client.receive(); println!("response: {response:?}"); - assert!(request.unwrap().starts_with("GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\nX-Forwarded-For:")); + assert!(request.unwrap().starts_with("GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\nContent-Length: 0\r\nX-Forwarded-For:")); assert!(response.unwrap().starts_with("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nSet-Cookie: SOZUBALANCEID=sticky_cluster_0-0; Path=/\r\nSozu-Id:")); // invalid sticky_session @@ -1255,7 +1254,7 @@ pub fn try_stick() -> State { backend2.send(0); let response = client.receive(); println!("response: {response:?}"); - assert!(request.unwrap().starts_with("GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\nX-Forwarded-For:")); + assert!(request.unwrap().starts_with("GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\nContent-Length: 0\r\nX-Forwarded-For:")); assert!(response.unwrap().starts_with("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nSet-Cookie: SOZUBALANCEID=sticky_cluster_0-1; Path=/\r\nSozu-Id:")); // good sticky_session (force use backend2, round-robin would have chosen backend1) @@ -1268,7 +1267,7 @@ pub fn try_stick() -> State { backend2.send(0); let response = client.receive(); println!("response: {response:?}"); - assert!(request.unwrap().starts_with("GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\nX-Forwarded-For:")); + assert!(request.unwrap().starts_with("GET /api HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: foo=bar\r\nContent-Length: 0\r\nX-Forwarded-For:")); assert!(response .unwrap() .starts_with("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nSozu-Id:")); diff --git a/lib/Cargo.toml b/lib/Cargo.toml index d38e2a3d5..9c37cfbce 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -29,12 +29,13 @@ include = [ [dependencies] anyhow = "^1.0.89" +base64 = "0.22.1" cookie-factory = "^0.3.3" hdrhistogram = "^7.5.4" hex = "^0.4.3" hpack = "^0.3.0" idna = "^1.0.2" -kawa = { version = "^0.6.7", default-features = false } +kawa = { version = "^0.6.7", default-features = false, features = ["rc-alloc"]} libc = "^0.2.159" memchr = "^2.7.4" mio = { version = "^1.0.2", features = ["os-poll", "os-ext", "net"] } diff --git a/lib/assets/mycluster_200.html b/lib/assets/mycluster_200.html new file mode 100644 index 000000000..df7d10e88 --- /dev/null +++ b/lib/assets/mycluster_200.html @@ -0,0 +1,7 @@ +HTTP/1.1 200 OK +%Content-Length: %CONTENT_LENGTH +Sozu-Id: %REQUEST_ID + +

%CLUSTER_ID Custom 200

+

original url: %ROUTE

+

rewritten url: %REDIRECT_LOCATION

diff --git a/lib/examples/http.rs b/lib/examples/http.rs index 68ec35a03..7ecbc295e 100644 --- a/lib/examples/http.rs +++ b/lib/examples/http.rs @@ -45,7 +45,6 @@ fn main() -> anyhow::Result<()> { sticky_session: false, https_redirect: false, load_balancing: LoadBalancingAlgorithms::RoundRobin as i32, - answer_503: Some("A custom forbidden message".to_string()), ..Default::default() }; diff --git a/lib/src/http.rs b/lib/src/http.rs index 2f6e69e39..359ba01d9 100644 --- a/lib/src/http.rs +++ b/lib/src/http.rs @@ -1,11 +1,10 @@ use std::{ cell::RefCell, - collections::{hash_map::Entry, BTreeMap, HashMap}, + collections::{hash_map::Entry, HashMap}, io::ErrorKind, net::{Shutdown, SocketAddr}, os::unix::io::AsRawFd, rc::{Rc, Weak}, - str::from_utf8_unchecked, time::{Duration, Instant}, }; @@ -17,7 +16,6 @@ use mio::{ use rusty_ulid::Ulid; use sozu_command::{ - logging::CachedTags, proto::command::{ request::RequestType, Cluster, HttpListenerConfig, ListenerType, RemoveListener, RequestHttpFrontend, WorkerRequest, WorkerResponse, @@ -31,21 +29,17 @@ use crate::{ backends::BackendMap, pool::Pool, protocol::{ - http::{ - answers::HttpAnswers, - parser::{hostname_and_port, Method}, - ResponseStream, - }, + http::{answers::HttpAnswers, parser::Method, ResponseStream}, proxy_protocol::expect::ExpectProxyProtocol, Http, Pipe, SessionState, }, - router::{Route, Router}, + router::{RouteResult, Router}, server::{ListenToken, SessionManager}, socket::server_bind, timer::TimeoutContainer, - AcceptError, FrontendFromRequestError, L7ListenerHandler, L7Proxy, ListenerError, - ListenerHandler, Protocol, ProxyConfiguration, ProxyError, ProxySession, SessionIsToBeClosed, - SessionMetrics, SessionResult, StateMachineBuilder, StateResult, + AcceptError, FrontendFromRequestError, L7ListenerHandler, L7Proxy, ListenerError, Protocol, + ProxyConfiguration, ProxyError, ProxySession, SessionIsToBeClosed, SessionMetrics, + SessionResult, StateMachineBuilder, StateResult, }; #[derive(PartialEq, Eq)] @@ -63,7 +57,7 @@ StateMachineBuilder! { enum HttpStateMachine impl SessionState { Expect(ExpectProxyProtocol), Http(Http), - WebSocket(Pipe), + WebSocket(Pipe), } } @@ -214,11 +208,14 @@ impl HttpSession { } } - fn upgrade_http(&mut self, http: Http) -> Option { + fn upgrade_http( + &mut self, + mut http: Http, + ) -> Option { debug!("http switching to ws"); - let front_token = self.frontend_token; - let back_token = match http.backend_token { - Some(back_token) => back_token, + let frontend_token = self.frontend_token; + let origin = match http.origin.take() { + Some(origin) => origin, None => { warn!( "Could not upgrade http request on cluster '{:?}' ({:?}) using backend '{:?}' into websocket for request '{}'", @@ -228,7 +225,7 @@ impl HttpSession { } }; - let ws_context = http.websocket_context(); + let websocket_context = http.websocket_context(); let mut container_frontend_timeout = http.container_frontend_timeout; let mut container_backend_timeout = http.container_backend_timeout; container_frontend_timeout.reset(); @@ -242,25 +239,25 @@ impl HttpSession { let mut pipe = Pipe::new( backend_buffer, - http.context.backend_id, - http.backend_socket, - http.backend, + Some(origin.backend_id), + Some(origin.socket), + Some(origin.backend), Some(container_backend_timeout), Some(container_frontend_timeout), http.context.cluster_id, http.request_stream.storage.buffer, - front_token, + frontend_token, http.frontend_socket, - self.listener.clone(), Protocol::HTTP, http.context.id, http.context.session_address, - ws_context, + websocket_context, + http.context.tags, ); pipe.frontend_readiness.event = http.frontend_readiness.event; pipe.backend_readiness.event = http.backend_readiness.event; - pipe.set_back_token(back_token); + pipe.set_back_token(origin.token); gauge_add!("protocol.http", -1); gauge_add!("protocol.ws", 1); @@ -269,7 +266,7 @@ impl HttpSession { Some(HttpStateMachine::WebSocket(pipe)) } - fn upgrade_websocket(&self, ws: Pipe) -> Option { + fn upgrade_websocket(&self, ws: Pipe) -> Option { // what do we do here? error!("Upgrade called on WS, this should not happen"); Some(HttpStateMachine::WebSocket(ws)) @@ -400,27 +397,9 @@ pub struct HttpListener { config: HttpListenerConfig, fronts: Router, listener: Option, - tags: BTreeMap, token: Token, } -impl ListenerHandler for HttpListener { - fn get_addr(&self) -> &SocketAddr { - &self.address - } - - fn get_tags(&self, key: &str) -> Option<&CachedTags> { - self.tags.get(key) - } - - fn set_tags(&mut self, key: String, tags: Option>) { - match tags { - Some(tags) => self.tags.insert(key, CachedTags::new(tags)), - None => self.tags.remove(&key), - }; - } -} - impl L7ListenerHandler for HttpListener { fn get_sticky_name(&self) -> &str { &self.config.sticky_name @@ -430,53 +409,16 @@ impl L7ListenerHandler for HttpListener { self.config.connect_timeout } - // redundant, already called once in extract_route fn frontend_from_request( &self, host: &str, - uri: &str, + path: &str, method: &Method, - ) -> Result { - let start = Instant::now(); - let (remaining_input, (hostname, _)) = match hostname_and_port(host.as_bytes()) { - Ok(tuple) => tuple, - Err(parse_error) => { - // parse_error contains a slice of given_host, which should NOT escape this scope - return Err(FrontendFromRequestError::HostParse { - host: host.to_owned(), - error: parse_error.to_string(), - }); - } - }; - if remaining_input != &b""[..] { - return Err(FrontendFromRequestError::InvalidCharsAfterHost( - host.to_owned(), - )); - } - - /*if port == Some(&b"80"[..]) { - // it is alright to call from_utf8_unchecked, - // we already verified that there are only ascii - // chars in there - unsafe { from_utf8_unchecked(hostname) } - } else { - host - } - */ - let host = unsafe { from_utf8_unchecked(hostname) }; - - let route = self.fronts.lookup(host, uri, method).map_err(|e| { + ) -> Result { + self.fronts.lookup(host, path, method).map_err(|e| { incr!("http.failed_backend_matching"); FrontendFromRequestError::NoClusterFound(e) - })?; - - let now = Instant::now(); - - if let Route::ClusterId(cluster) = &route { - time!("frontend_matching_time", cluster, (now - start).as_millis()); - } - - Ok(route) + }) } } @@ -593,18 +535,19 @@ impl HttpProxy { } pub fn add_cluster(&mut self, mut cluster: Cluster) -> Result<(), ProxyError> { - if let Some(answer_503) = cluster.answer_503.take() { + if !cluster.answers.is_empty() { for listener in self.listeners.values() { listener .borrow() .answers .borrow_mut() - .add_custom_answer(&cluster.cluster_id, answer_503.clone()) - .map_err(|(status, error)| { - ProxyError::AddCluster(ListenerError::TemplateParse(status, error)) + .add_cluster_answers(&cluster.cluster_id, &cluster.answers) + .map_err(|(name, error)| { + ProxyError::AddCluster(ListenerError::TemplateParse(name, error)) })?; } } + cluster.answers.clear(); self.clusters.insert(cluster.cluster_id.clone(), cluster); Ok(()) } @@ -617,7 +560,7 @@ impl HttpProxy { .borrow() .answers .borrow_mut() - .remove_custom_answer(cluster_id); + .remove_cluster_answers(cluster_id); } Ok(()) } @@ -637,13 +580,9 @@ impl HttpProxy { .ok_or(ProxyError::NoListenerFound(front.address))? .borrow_mut(); - let hostname = front.hostname.to_owned(); - let tags = front.tags.to_owned(); - listener .add_http_front(front) .map_err(ProxyError::AddFrontend)?; - listener.set_tags(hostname, tags); Ok(()) } @@ -662,13 +601,10 @@ impl HttpProxy { .ok_or(ProxyError::NoListenerFound(front.address))? .borrow_mut(); - let hostname = front.hostname.to_owned(); - listener .remove_http_front(front) .map_err(ProxyError::RemoveFrontend)?; - listener.set_tags(hostname, None); Ok(()) } @@ -725,13 +661,12 @@ impl HttpListener { active: false, address: config.address.into(), answers: Rc::new(RefCell::new( - HttpAnswers::new(&config.http_answers) - .map_err(|(status, error)| ListenerError::TemplateParse(status, error))?, + HttpAnswers::new(&config.answers) + .map_err(|(name, error)| ListenerError::TemplateParse(name, error))?, )), config, fronts: Router::new(), listener: None, - tags: BTreeMap::new(), token, }) } @@ -1050,9 +985,19 @@ pub mod testing { mod tests { extern crate tiny_http; + use std::{ + collections::BTreeMap, + io::{Read, Write}, + net::TcpStream, + str, + sync::{Arc, Barrier}, + thread, + time::Duration, + }; + use super::testing::start_http_worker; use super::*; - use sozu_command::proto::command::{CustomHttpAnswers, SocketAddress}; + use sozu_command::proto::command::{RedirectPolicy, RedirectScheme, SocketAddress}; use crate::sozu_command::{ channel::Channel, @@ -1061,15 +1006,6 @@ mod tests { response::{Backend, HttpFrontend}, }; - use std::{ - io::{Read, Write}, - net::TcpStream, - str, - sync::{Arc, Barrier}, - thread, - time::Duration, - }; - /* #[test] #[cfg(target_pointer_width = "64")] @@ -1325,6 +1261,14 @@ mod tests { path: PathRule::prefix(uri1), position: RulePosition::Tree, cluster_id: Some(cluster_id1), + required_auth: false, + redirect: RedirectPolicy::Forward, + redirect_scheme: RedirectScheme::UseSame, + redirect_template: None, + rewrite_host: None, + rewrite_path: None, + rewrite_port: None, + headers: vec![], tags: None, }) .expect("Could not add http frontend"); @@ -1336,6 +1280,14 @@ mod tests { path: PathRule::prefix(uri2), position: RulePosition::Tree, cluster_id: Some(cluster_id2), + required_auth: false, + redirect: RedirectPolicy::Forward, + redirect_scheme: RedirectScheme::UseSame, + redirect_template: None, + rewrite_host: None, + rewrite_path: None, + rewrite_port: None, + headers: vec![], tags: None, }) .expect("Could not add http frontend"); @@ -1347,6 +1299,14 @@ mod tests { path: PathRule::prefix(uri3), position: RulePosition::Tree, cluster_id: Some(cluster_id3), + required_auth: false, + redirect: RedirectPolicy::Forward, + redirect_scheme: RedirectScheme::UseSame, + redirect_template: None, + rewrite_host: None, + rewrite_path: None, + rewrite_port: None, + headers: vec![], tags: None, }) .expect("Could not add http frontend"); @@ -1358,6 +1318,14 @@ mod tests { path: PathRule::prefix("/test".to_owned()), position: RulePosition::Tree, cluster_id: Some("cluster_1".to_owned()), + required_auth: false, + redirect: RedirectPolicy::Forward, + redirect_scheme: RedirectScheme::UseSame, + redirect_template: None, + rewrite_host: None, + rewrite_path: None, + rewrite_port: None, + headers: vec![], tags: None, }) .expect("Could not add http frontend"); @@ -1372,13 +1340,10 @@ mod tests { listener: None, address: address.into(), fronts, - answers: Rc::new(RefCell::new( - HttpAnswers::new(&Some(CustomHttpAnswers::default())).unwrap(), - )), + answers: Rc::new(RefCell::new(HttpAnswers::new(&BTreeMap::new()).unwrap())), config: default_config, token: Token(0), active: true, - tags: BTreeMap::new(), }; let frontend1 = listener.frontend_from_request("lolcatho.st", "/", &Method::Get); @@ -1387,20 +1352,20 @@ mod tests { let frontend4 = listener.frontend_from_request("lolcatho.st", "/yolo/swag", &Method::Get); let frontend5 = listener.frontend_from_request("domain", "/", &Method::Get); assert_eq!( - frontend1.expect("should find frontend"), - Route::ClusterId("cluster_1".to_string()) + frontend1.expect("should find frontend").cluster_id, + Some("cluster_1".to_string()) ); assert_eq!( - frontend2.expect("should find frontend"), - Route::ClusterId("cluster_1".to_string()) + frontend2.expect("should find frontend").cluster_id, + Some("cluster_1".to_string()) ); assert_eq!( - frontend3.expect("should find frontend"), - Route::ClusterId("cluster_2".to_string()) + frontend3.expect("should find frontend").cluster_id, + Some("cluster_2".to_string()) ); assert_eq!( - frontend4.expect("should find frontend"), - Route::ClusterId("cluster_3".to_string()) + frontend4.expect("should find frontend").cluster_id, + Some("cluster_3".to_string()) ); assert!(frontend5.is_err()); } diff --git a/lib/src/https.rs b/lib/src/https.rs index ce03d33bd..c699df0df 100644 --- a/lib/src/https.rs +++ b/lib/src/https.rs @@ -1,11 +1,11 @@ use std::{ cell::RefCell, - collections::{hash_map::Entry, BTreeMap, HashMap}, + collections::{hash_map::Entry, HashMap}, io::ErrorKind, net::{Shutdown, SocketAddr as StdSocketAddr}, os::unix::io::AsRawFd, rc::{Rc, Weak}, - str::{from_utf8, from_utf8_unchecked}, + str::from_utf8, sync::Arc, time::{Duration, Instant}, }; @@ -53,24 +53,20 @@ use crate::{ pool::Pool, protocol::{ h2::Http2, - http::{ - answers::HttpAnswers, - parser::{hostname_and_port, Method}, - ResponseStream, - }, + http::{answers::HttpAnswers, parser::Method, ResponseStream}, proxy_protocol::expect::ExpectProxyProtocol, rustls::TlsHandshake, Http, Pipe, SessionState, }, - router::{Route, Router}, + router::{RouteResult, Router}, server::{ListenToken, SessionManager}, socket::{server_bind, FrontRustls}, timer::TimeoutContainer, tls::MutexCertificateResolver, util::UnwrapLog, - AcceptError, CachedTags, FrontendFromRequestError, L7ListenerHandler, L7Proxy, ListenerError, - ListenerHandler, Protocol, ProxyConfiguration, ProxyError, ProxySession, SessionIsToBeClosed, - SessionMetrics, SessionResult, StateMachineBuilder, StateResult, + AcceptError, FrontendFromRequestError, L7ListenerHandler, L7Proxy, ListenerError, Protocol, + ProxyConfiguration, ProxyError, ProxySession, SessionIsToBeClosed, SessionMetrics, + SessionResult, StateMachineBuilder, StateResult, }; // const SERVER_PROTOS: &[&str] = &["http/1.1", "h2"]; @@ -87,7 +83,7 @@ StateMachineBuilder! { Expect(ExpectProxyProtocol, ServerConnection), Handshake(TlsHandshake), Http(Http), - WebSocket(Pipe), + WebSocket(Pipe), Http2(Http2) -> todo!("H2"), } } @@ -333,11 +329,14 @@ impl HttpsSession { } } - fn upgrade_http(&self, http: Http) -> Option { + fn upgrade_http( + &self, + mut http: Http, + ) -> Option { debug!("https switching to wss"); let front_token = self.frontend_token; - let back_token = match http.backend_token { - Some(back_token) => back_token, + let origin = match http.origin.take() { + Some(origin) => origin, None => { warn!( "Could not upgrade https request on cluster '{:?}' ({:?}) using backend '{:?}' into secure websocket for request '{}'", @@ -347,7 +346,7 @@ impl HttpsSession { } }; - let ws_context = http.websocket_context(); + let websocket_context = http.websocket_context(); let mut container_frontend_timeout = http.container_frontend_timeout; let mut container_backend_timeout = http.container_backend_timeout; container_frontend_timeout.reset(); @@ -361,25 +360,25 @@ impl HttpsSession { let mut pipe = Pipe::new( backend_buffer, - http.context.backend_id, - http.backend_socket, - http.backend, + Some(origin.backend_id), + Some(origin.socket), + Some(origin.backend), Some(container_backend_timeout), Some(container_frontend_timeout), http.context.cluster_id, http.request_stream.storage.buffer, front_token, http.frontend_socket, - self.listener.clone(), - Protocol::HTTP, + Protocol::HTTPS, http.context.id, http.context.session_address, - ws_context, + websocket_context, + http.context.tags, ); pipe.frontend_readiness.event = http.frontend_readiness.event; pipe.backend_readiness.event = http.backend_readiness.event; - pipe.set_back_token(back_token); + pipe.set_back_token(origin.token); gauge_add!("protocol.https", -1); gauge_add!("protocol.wss", 1); @@ -392,10 +391,7 @@ impl HttpsSession { todo!() } - fn upgrade_websocket( - &self, - wss: Pipe, - ) -> Option { + fn upgrade_websocket(&self, wss: Pipe) -> Option { // what do we do here? error!("Upgrade called on WSS, this should not happen"); Some(HttpsStateMachine::WebSocket(wss)) @@ -534,27 +530,9 @@ pub struct HttpsListener { listener: Option, resolver: Arc, rustls_details: Arc, - tags: BTreeMap, token: Token, } -impl ListenerHandler for HttpsListener { - fn get_addr(&self) -> &StdSocketAddr { - &self.address - } - - fn get_tags(&self, key: &str) -> Option<&CachedTags> { - self.tags.get(key) - } - - fn set_tags(&mut self, key: String, tags: Option>) { - match tags { - Some(tags) => self.tags.insert(key, CachedTags::new(tags)), - None => self.tags.remove(&key), - }; - } -} - impl L7ListenerHandler for HttpsListener { fn get_sticky_name(&self) -> &str { &self.config.sticky_name @@ -567,44 +545,13 @@ impl L7ListenerHandler for HttpsListener { fn frontend_from_request( &self, host: &str, - uri: &str, + path: &str, method: &Method, - ) -> Result { - let start = Instant::now(); - let (remaining_input, (hostname, _)) = match hostname_and_port(host.as_bytes()) { - Ok(tuple) => tuple, - Err(parse_error) => { - // parse_error contains a slice of given_host, which should NOT escape this scope - return Err(FrontendFromRequestError::HostParse { - host: host.to_owned(), - error: parse_error.to_string(), - }); - } - }; - - if remaining_input != &b""[..] { - return Err(FrontendFromRequestError::InvalidCharsAfterHost( - host.to_owned(), - )); - } - - // it is alright to call from_utf8_unchecked, - // we already verified that there are only ascii - // chars in there - let host = unsafe { from_utf8_unchecked(hostname) }; - - let route = self.fronts.lookup(host, uri, method).map_err(|e| { + ) -> Result { + self.fronts.lookup(host, path, method).map_err(|e| { incr!("http.failed_backend_matching"); FrontendFromRequestError::NoClusterFound(e) - })?; - - let now = Instant::now(); - - if let Route::ClusterId(cluster) = &route { - time!("frontend_matching_time", cluster, (now - start).as_millis()); - } - - Ok(route) + }) } } @@ -624,13 +571,11 @@ impl HttpsListener { rustls_details: server_config, active: false, fronts: Router::new(), - answers: Rc::new(RefCell::new( - HttpAnswers::new(&config.http_answers) - .map_err(|(status, error)| ListenerError::TemplateParse(status, error))?, - )), + answers: Rc::new(RefCell::new(HttpAnswers::new(&config.answers).map_err( + |(status, error)| ListenerError::TemplateParse(status, error), + )?)), config, token, - tags: BTreeMap::new(), }) } @@ -997,18 +942,19 @@ impl HttpsProxy { &mut self, mut cluster: Cluster, ) -> Result, ProxyError> { - if let Some(answer_503) = cluster.answer_503.take() { + if !cluster.answers.is_empty() { for listener in self.listeners.values() { listener .borrow() .answers .borrow_mut() - .add_custom_answer(&cluster.cluster_id, answer_503.clone()) - .map_err(|(status, error)| { - ProxyError::AddCluster(ListenerError::TemplateParse(status, error)) + .add_cluster_answers(&cluster.cluster_id, &cluster.answers) + .map_err(|(name, error)| { + ProxyError::AddCluster(ListenerError::TemplateParse(name, error)) })?; } } + cluster.answers.clear(); self.clusters.insert(cluster.cluster_id.clone(), cluster); Ok(None) } @@ -1023,7 +969,7 @@ impl HttpsProxy { .borrow() .answers .borrow_mut() - .remove_custom_answer(cluster_id); + .remove_cluster_answers(cluster_id); } Ok(None) @@ -1047,7 +993,6 @@ impl HttpsProxy { .ok_or(ProxyError::NoListenerFound(front.address))? .borrow_mut(); - listener.set_tags(front.hostname.to_owned(), front.tags.to_owned()); listener .add_https_front(front) .map_err(ProxyError::AddFrontend)?; @@ -1072,7 +1017,6 @@ impl HttpsProxy { .ok_or(ProxyError::NoListenerFound(front.address))? .borrow_mut(); - listener.set_tags(front.hostname.to_owned(), None); listener .remove_https_front(front) .map_err(ProxyError::RemoveFrontend)?; @@ -1502,14 +1446,11 @@ pub mod testing { mod tests { use super::*; - use std::sync::Arc; + use std::{collections::BTreeMap, sync::Arc}; - use sozu_command::{ - config::ListenerBuilder, - proto::command::{CustomHttpAnswers, SocketAddress}, - }; + use sozu_command::{config::ListenerBuilder, proto::command::SocketAddress}; - use crate::router::{pattern_trie::TrieNode, MethodRule, PathRule, Route, Router}; + use crate::router::{pattern_trie::TrieNode, Frontend, MethodRule, PathRule, Router}; /* #[test] @@ -1542,25 +1483,25 @@ mod tests { "lolcatho.st".as_bytes(), &PathRule::Prefix(uri1), &MethodRule::new(None), - &Route::ClusterId(cluster_id1.clone()) + &Frontend::forward(cluster_id1.clone()) )); assert!(fronts.add_tree_rule( "lolcatho.st".as_bytes(), &PathRule::Prefix(uri2), &MethodRule::new(None), - &Route::ClusterId(cluster_id2) + &Frontend::forward(cluster_id2) )); assert!(fronts.add_tree_rule( "lolcatho.st".as_bytes(), &PathRule::Prefix(uri3), &MethodRule::new(None), - &Route::ClusterId(cluster_id3) + &Frontend::forward(cluster_id3) )); assert!(fronts.add_tree_rule( "other.domain".as_bytes(), &PathRule::Prefix("test".to_string()), &MethodRule::new(None), - &Route::ClusterId(cluster_id1) + &Frontend::forward(cluster_id1) )); let address = SocketAddress::new_v4(127, 0, 0, 1, 1032); @@ -1588,38 +1529,35 @@ mod tests { fronts, rustls_details, resolver, - answers: Rc::new(RefCell::new( - HttpAnswers::new(&Some(CustomHttpAnswers::default())).unwrap(), - )), + answers: Rc::new(RefCell::new(HttpAnswers::new(&BTreeMap::new()).unwrap())), config: default_config, token: Token(0), active: true, - tags: BTreeMap::new(), }; println!("TEST {}", line!()); let frontend1 = listener.frontend_from_request("lolcatho.st", "/", &Method::Get); assert_eq!( - frontend1.expect("should find a frontend"), - Route::ClusterId("cluster_1".to_string()) + frontend1.expect("should find a frontend").cluster_id, + Some("cluster_1".to_string()) ); println!("TEST {}", line!()); let frontend2 = listener.frontend_from_request("lolcatho.st", "/test", &Method::Get); assert_eq!( - frontend2.expect("should find a frontend"), - Route::ClusterId("cluster_1".to_string()) + frontend2.expect("should find a frontend").cluster_id, + Some("cluster_1".to_string()) ); println!("TEST {}", line!()); let frontend3 = listener.frontend_from_request("lolcatho.st", "/yolo/test", &Method::Get); assert_eq!( - frontend3.expect("should find a frontend"), - Route::ClusterId("cluster_2".to_string()) + frontend3.expect("should find a frontend").cluster_id, + Some("cluster_2".to_string()) ); println!("TEST {}", line!()); let frontend4 = listener.frontend_from_request("lolcatho.st", "/yolo/swag", &Method::Get); assert_eq!( - frontend4.expect("should find a frontend"), - Route::ClusterId("cluster_3".to_string()) + frontend4.expect("should find a frontend").cluster_id, + Some("cluster_3".to_string()) ); println!("TEST {}", line!()); let frontend5 = listener.frontend_from_request("domain", "/", &Method::Get); diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 0c2bca139..bd90b325d 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -106,7 +106,6 @@ //! sticky_session: false, //! https_redirect: false, //! load_balancing: LoadBalancingAlgorithms::RoundRobin as i32, -//! answer_503: Some("A custom forbidden message".to_string()), //! ..Default::default() //! }; //! ``` @@ -249,7 +248,6 @@ //! sticky_session: false, //! https_redirect: false, //! load_balancing: LoadBalancingAlgorithms::RoundRobin as i32, -//! answer_503: Some("A custom forbidden message".to_string()), //! ..Default::default() //! }; //! @@ -350,7 +348,7 @@ use backends::BackendError; use hex::FromHexError; use mio::{net::TcpStream, Interest, Token}; use protocol::http::{answers::TemplateError, parser::Method}; -use router::RouterError; +use router::{RouteResult, RouterError}; use socket::ServerBindError; use tls::CertificateResolverError; @@ -362,7 +360,7 @@ use sozu_command::{ AsStr, ObjectKind, }; -use crate::{backends::BackendMap, router::Route}; +use crate::backends::BackendMap; /// Anything that can be registered in mio (subscribe to kernel events) #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -521,16 +519,10 @@ macro_rules! StateMachineBuilder { } } -pub trait ListenerHandler { - fn get_addr(&self) -> &SocketAddr; +pub trait L4ListenerHandler { + fn get_tags(&self) -> Option<&CachedTags>; - fn get_tags(&self, key: &str) -> Option<&CachedTags>; - - fn get_concatenated_tags(&self, key: &str) -> Option<&str> { - self.get_tags(key).map(|tags| tags.concatenated.as_str()) - } - - fn set_tags(&mut self, key: String, tags: Option>); + fn set_tags(&mut self, tags: Option>); } #[derive(thiserror::Error, Debug)] @@ -554,7 +546,7 @@ pub trait L7ListenerHandler { host: &str, uri: &str, method: &Method, - ) -> Result; + ) -> Result; } #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -602,6 +594,8 @@ pub enum RetrieveClusterError { NoPath, #[error("unauthorized route")] UnauthorizedRoute, + #[error("redirected")] + Redirected, #[error("{0}")] RetrieveFrontend(FrontendFromRequestError), } @@ -624,8 +618,8 @@ pub enum ListenerError { Resolver(CertificateResolverError), #[error("failed to parse pem, {0}")] PemParse(String), - #[error("failed to parse template {0}: {1}")] - TemplateParse(u16, TemplateError), + #[error("failed to parse template {0:?}: {1}")] + TemplateParse(String, TemplateError), #[error("failed to build rustls context, {0}")] BuildRustls(String), #[error("could not activate listener with address {address:?}: {error}")] @@ -953,7 +947,6 @@ pub struct SessionMetrics { pub service_start: Option, pub wait_start: Instant, - pub backend_id: Option, pub backend_start: Option, pub backend_connected: Option, pub backend_stop: Option, @@ -971,7 +964,6 @@ impl SessionMetrics { bout: 0, service_start: None, wait_start: Instant::now(), - backend_id: None, backend_start: None, backend_connected: None, backend_stop: None, @@ -1072,7 +1064,7 @@ impl SessionMetrics { time!("request_time", request_time.as_millis()); time!("service_time", service_time.as_millis()); - if let Some(backend_id) = self.backend_id.as_ref() { + if let Some(backend_id) = context.backend_id { if let Some(backend_response_time) = self.backend_response_time() { record_backend_metrics!( context.cluster_id.as_str_or("-"), diff --git a/lib/src/protocol/kawa_h1/answers.rs b/lib/src/protocol/kawa_h1/answers.rs index 06be73b56..b0ea8b49c 100644 --- a/lib/src/protocol/kawa_h1/answers.rs +++ b/lib/src/protocol/kawa_h1/answers.rs @@ -1,15 +1,18 @@ use crate::{protocol::http::DefaultAnswer, sozu_command::state::ClusterId}; use kawa::{ - h1::NoCallbacks, AsBuffer, Block, BodySize, Buffer, Chunk, Kawa, Kind, Pair, ParsingPhase, - ParsingPhaseMarker, StatusLine, Store, + h1::NoCallbacks, AsBuffer, Block, BodySize, Buffer, Chunk, Flags, Kawa, Kind, Pair, + ParsingPhase, ParsingPhaseMarker, StatusLine, Store, }; -use sozu_command::proto::command::CustomHttpAnswers; +use nom::AsBytes; use std::{ - collections::{HashMap, VecDeque}, + collections::{BTreeMap, HashMap, VecDeque}, fmt, rc::Rc, + str::from_utf8_unchecked, }; +use super::parser::compare_no_case; + #[derive(Clone)] pub struct SharedBuffer(Rc<[u8]>); @@ -33,6 +36,8 @@ pub enum TemplateError { InvalidTemplate(ParsingPhase), #[error("unexpected status code: {0}")] InvalidStatusCode(u16), + #[error("unexpected size info: {0:?}")] + InvalidSizeInfo(BodySize), #[error("streaming is not supported in templates")] UnsupportedStreaming, #[error("template variable {0} is not allowed in headers")] @@ -48,6 +53,7 @@ pub struct TemplateVariable { name: &'static str, valid_in_body: bool, valid_in_header: bool, + or_elide_header: bool, typ: ReplacementType, } @@ -61,11 +67,14 @@ pub enum ReplacementType { #[derive(Clone, Copy, Debug)] pub struct Replacement { block_index: usize, + or_elide_header: bool, typ: ReplacementType, } // TODO: rename for clarity, for instance HttpAnswerTemplate pub struct Template { + status: u16, + keep_alive: bool, kawa: DefaultAnswerStream, body_replacements: Vec, header_replacements: Vec, @@ -86,8 +95,8 @@ impl fmt::Debug for Template { impl Template { /// sanitize the template: transform newlines \r (CR) to \r\n (CRLF) fn new( - status: u16, - answer: String, + status: Option, + answer: &str, variables: &[TemplateVariable], ) -> Result { let mut i = 0; @@ -124,28 +133,54 @@ impl Template { if !kawa.is_main_phase() { return Err(TemplateError::InvalidTemplate(kawa.parsing_phase)); } - if let StatusLine::Response { code, .. } = kawa.detached.status_line { - if code != status { - return Err(TemplateError::InvalidStatusCode(code)); + if kawa.body_size != BodySize::Empty { + return Err(TemplateError::InvalidSizeInfo(kawa.body_size)); + } + let status = if let StatusLine::Response { code, .. } = &kawa.detached.status_line { + if let Some(expected_code) = status { + if expected_code != *code { + return Err(TemplateError::InvalidStatusCode(*code)); + } } + *code } else { return Err(TemplateError::InvalidType); - } + }; let buf = kawa.storage.buffer(); let mut blocks = VecDeque::new(); let mut header_replacements = Vec::new(); let mut body_replacements = Vec::new(); let mut body_size = 0; + let mut keep_alive = true; let mut used_once = Vec::new(); for mut block in kawa.blocks.into_iter() { match &mut block { Block::ChunkHeader(_) => return Err(TemplateError::UnsupportedStreaming), + Block::Flags(Flags { + end_header: true, .. + }) => { + header_replacements.push(Replacement { + block_index: blocks.len(), + or_elide_header: false, + typ: ReplacementType::ContentLength, + }); + blocks.push_back(Block::Header(Pair { + key: Store::Static(b"Content-Length"), + val: Store::Static(b"PLACEHOLDER"), + })); + blocks.push_back(block); + } Block::StatusLine | Block::Cookies | Block::Flags(_) => { blocks.push_back(block); } Block::Header(Pair { key, val }) => { let val_data = val.data(buf); let key_data = key.data(buf); + if compare_no_case(key_data, b"connection") + && compare_no_case(val_data, b"close") + { + keep_alive = false; + } if let Some(b'%') = val_data.first() { for variable in &variables { if &val_data[1..] == variable.name.as_bytes() { @@ -163,14 +198,11 @@ impl Template { } used_once.push(var_index); } - ReplacementType::ContentLength => { - if let Some(b'%') = key_data.first() { - *key = Store::new_slice(buf, &key_data[1..]); - } - } + ReplacementType::ContentLength => {} } header_replacements.push(Replacement { block_index: blocks.len(), + or_elide_header: variable.or_elide_header, typ: variable.typ, }); break; @@ -213,6 +245,7 @@ impl Template { } body_replacements.push(Replacement { block_index: blocks.len(), + or_elide_header: false, typ: variable.typ, }); blocks.push_back(Block::Chunk(Chunk { @@ -234,6 +267,8 @@ impl Template { } kawa.blocks = blocks; Ok(Self { + status, + keep_alive, kawa, body_replacements, header_replacements, @@ -277,6 +312,10 @@ impl Template { pair.val = Store::from_string(body_size.to_string()) } } + if pair.val.len() == 0 && replacement.or_elide_header { + pair.elide(); + continue; + } } } Kawa { @@ -293,44 +332,13 @@ impl Template { } } -/// a set of templates for HTTP answers, meant for one listener to use -pub struct ListenerAnswers { - /// MovedPermanently - pub answer_301: Template, - /// BadRequest - pub answer_400: Template, - /// Unauthorized - pub answer_401: Template, - /// NotFound - pub answer_404: Template, - /// RequestTimeout - pub answer_408: Template, - /// PayloadTooLarge - pub answer_413: Template, - /// BadGateway - pub answer_502: Template, - /// ServiceUnavailable - pub answer_503: Template, - /// GatewayTimeout - pub answer_504: Template, - /// InsufficientStorage - pub answer_507: Template, -} - -/// templates for HTTP answers, set for one cluster -#[allow(non_snake_case)] -pub struct ClusterAnswers { - /// ServiceUnavailable - pub answer_503: Template, -} - pub struct HttpAnswers { - pub listener_answers: ListenerAnswers, // configurated answers - pub cluster_custom_answers: HashMap, + pub cluster_answers: HashMap>, + pub listener_answers: BTreeMap, + pub fallback: Template, } // const HEADERS: &str = "Connection: close\r -// Content-Length: 0\r // Sozu-Id: %REQUEST_ID\r // \r"; // const STYLE: &str = ""; // const FOOTER: &str = "
This is an automatic answer by Sōzu.
"; +fn fallback() -> String { + String::from( + "\ +HTTP/1.1 404 Not Found\r +Cache-Control: no-cache\r +Connection: close\r +Sozu-Id: %REQUEST_ID\r +\r + + +

404 Not Found

+
+{
+    \"status_code\": 404,
+    \"route\": \"%ROUTE\",
+    \"rewritten_url\": \"%REDIRECT_LOCATION\",
+    \"request_id\": \"%REQUEST_ID\"
+    \"cluster_id\": \"%CLUSTER_ID\",
+}
+
+

A frontend requested template \"%TEMPLATE_NAME\" that couldn't be found

+
This is an automatic answer by Sōzu.
", + ) +} + fn default_301() -> String { String::from( "\ HTTP/1.1 301 Moved Permanently\r Location: %REDIRECT_LOCATION\r Connection: close\r -Content-Length: 0\r Sozu-Id: %REQUEST_ID\r \r\n", ) @@ -360,7 +392,6 @@ fn default_400() -> String { HTTP/1.1 400 Bad Request\r Cache-Control: no-cache\r Connection: close\r -%Content-Length: %CONTENT_LENGTH\r Sozu-Id: %REQUEST_ID\r \r @@ -407,6 +438,7 @@ fn default_401() -> String { String::from( "\ HTTP/1.1 401 Unauthorized\r +WWW-Authenticate: %WWW_AUTHENTICATE\r Cache-Control: no-cache\r Connection: close\r Sozu-Id: %REQUEST_ID\r @@ -430,7 +462,6 @@ fn default_404() -> String { "\ HTTP/1.1 404 Not Found\r Cache-Control: no-cache\r -Connection: close\r Sozu-Id: %REQUEST_ID\r \r @@ -476,7 +507,6 @@ fn default_413() -> String { HTTP/1.1 413 Payload Too Large\r Cache-Control: no-cache\r Connection: close\r -%Content-Length: %CONTENT_LENGTH\r Sozu-Id: %REQUEST_ID\r \r @@ -500,7 +530,6 @@ fn default_502() -> String { HTTP/1.1 502 Bad Gateway\r Cache-Control: no-cache\r Connection: close\r -%Content-Length: %CONTENT_LENGTH\r Sozu-Id: %REQUEST_ID\r \r @@ -514,9 +543,9 @@ Sozu-Id: %REQUEST_ID\r \"cluster_id\": \"%CLUSTER_ID\", \"backend_id\": \"%BACKEND_ID\", \"parsing_phase\": \"%PHASE\", - \"successfully_parsed\": \"%SUCCESSFULLY_PARSED\", - \"partially_parsed\": \"%PARTIALLY_PARSED\", - \"invalid\": \"%INVALID\" + \"successfully_parsed\": %SUCCESSFULLY_PARSED, + \"partially_parsed\": %PARTIALLY_PARSED, + \"invalid\": %INVALID }

Response could not be parsed. %MESSAGE

@@ -551,7 +580,6 @@ fn default_503() -> String { HTTP/1.1 503 Service Unavailable\r Cache-Control: no-cache\r Connection: close\r -%Content-Length: %CONTENT_LENGTH\r Sozu-Id: %REQUEST_ID\r \r @@ -602,7 +630,6 @@ fn default_507() -> String { HTTP/1.1 507 Insufficient Storage\r Cache-Control: no-cache\r Connection: close\r -%Content-Length: %CONTENT_LENGTH\r Sozu-Id: %REQUEST_ID\r \r @@ -638,225 +665,218 @@ fn phase_to_vec(phase: ParsingPhaseMarker) -> Vec { impl HttpAnswers { #[rustfmt::skip] - pub fn template(status: u16, answer: String) -> Result { - let length = TemplateVariable { - name: "CONTENT_LENGTH", - valid_in_body: false, - valid_in_header: true, - typ: ReplacementType::ContentLength, - }; - + pub fn template(name: &str, answer: &str) -> Result { let route = TemplateVariable { name: "ROUTE", valid_in_body: true, valid_in_header: true, + or_elide_header: false, typ: ReplacementType::Variable(0), }; let request_id = TemplateVariable { name: "REQUEST_ID", valid_in_body: true, valid_in_header: true, + or_elide_header: false, typ: ReplacementType::Variable(0), }; let cluster_id = TemplateVariable { name: "CLUSTER_ID", valid_in_body: true, valid_in_header: true, + or_elide_header: false, typ: ReplacementType::Variable(0), }; let backend_id = TemplateVariable { name: "BACKEND_ID", valid_in_body: true, valid_in_header: true, + or_elide_header: false, typ: ReplacementType::Variable(0), }; let duration = TemplateVariable { name: "DURATION", valid_in_body: true, valid_in_header: true, + or_elide_header: false, typ: ReplacementType::Variable(0), }; let capacity = TemplateVariable { name: "CAPACITY", valid_in_body: true, valid_in_header: true, + or_elide_header: false, typ: ReplacementType::Variable(0), }; let phase = TemplateVariable { name: "PHASE", valid_in_body: true, valid_in_header: true, + or_elide_header: false, typ: ReplacementType::Variable(0), }; let location = TemplateVariable { name: "REDIRECT_LOCATION", + valid_in_body: true, + valid_in_header: true, + or_elide_header: false, + typ: ReplacementType::VariableOnce(0), + }; + let www_authenticate = TemplateVariable { + name: "WWW_AUTHENTICATE", valid_in_body: false, valid_in_header: true, + or_elide_header: true, typ: ReplacementType::VariableOnce(0), }; let message = TemplateVariable { name: "MESSAGE", valid_in_body: true, valid_in_header: false, + or_elide_header: false, typ: ReplacementType::VariableOnce(0), }; let successfully_parsed = TemplateVariable { name: "SUCCESSFULLY_PARSED", valid_in_body: true, valid_in_header: false, + or_elide_header: false, typ: ReplacementType::Variable(0), }; let partially_parsed = TemplateVariable { name: "PARTIALLY_PARSED", valid_in_body: true, valid_in_header: false, + or_elide_header: false, typ: ReplacementType::Variable(0), }; let invalid = TemplateVariable { name: "INVALID", valid_in_body: true, valid_in_header: false, + or_elide_header: false, + typ: ReplacementType::Variable(0), + }; + let template_name = TemplateVariable { + name: "TEMPLATE_NAME", + valid_in_body: true, + valid_in_header: true, + or_elide_header: false, typ: ReplacementType::Variable(0), }; - match status { - 301 => Template::new( - 301, + match name { + "301" => Template::new( + Some(301), answer, - &[length, route, request_id, location] + &[route, request_id, location] ), - 400 => Template::new( - 400, + "400" => Template::new( + Some(400), answer, - &[length, route, request_id, message, phase, successfully_parsed, partially_parsed, invalid], + &[route, request_id, message, phase, successfully_parsed, partially_parsed, invalid], ), - 401 => Template::new( - 401, + "401" => Template::new( + Some(401), answer, - &[length, route, request_id] + &[route, request_id, www_authenticate] ), - 404 => Template::new( - 404, + "404" => Template::new( + Some(404), answer, - &[length, route, request_id] + &[route, request_id] ), - 408 => Template::new( - 408, + "408" => Template::new( + Some(408), answer, - &[length, route, request_id, duration] + &[route, request_id, duration] ), - 413 => Template::new( - 413, + "413" => Template::new( + Some(413), answer, - &[length, route, request_id, capacity, message, phase], + &[route, request_id, capacity, message, phase], ), - 502 => Template::new( - 502, + "502" => Template::new( + Some(502), answer, - &[length, route, request_id, cluster_id, backend_id, message, phase, successfully_parsed, partially_parsed, invalid], + &[route, request_id, cluster_id, backend_id, message, phase, successfully_parsed, partially_parsed, invalid], ), - 503 => Template::new( - 503, + "503" => Template::new( + Some(503), answer, - &[length, route, request_id, cluster_id, backend_id, message], + &[route, request_id, cluster_id, backend_id, message], ), - 504 => Template::new( - 504, + "504" => Template::new( + Some(504), answer, - &[length, route, request_id, cluster_id, backend_id, duration], + &[route, request_id, cluster_id, backend_id, duration], ), - 507 => Template::new( - 507, + "507" => Template::new( + Some(507), answer, - &[length, route, request_id, cluster_id, backend_id, capacity, message, phase], + &[route, request_id, cluster_id, backend_id, capacity, message, phase], ), - _ => Err(TemplateError::InvalidStatusCode(status)), + _ => Template::new( + None, + answer, + &[route, request_id, cluster_id, location, template_name] + ) } - .map_err(|e| (status, e)) + .map_err(|e| (name.to_owned(), e)) + } + + pub fn templates( + answers: &BTreeMap, + ) -> Result, (String, TemplateError)> { + answers + .iter() + .map(|(name, answer)| { + Self::template(name, answer).map(|template| (name.clone(), template)) + }) + .collect::>() } - pub fn new(conf: &Option) -> Result { + pub fn new(answers: &BTreeMap) -> Result { + let mut listener_answers = Self::templates(answers)?; + let expected_defaults: &[(&str, fn() -> String)] = &[ + ("301", default_301), + ("400", default_400), + ("401", default_401), + ("404", default_404), + ("408", default_408), + ("413", default_413), + ("502", default_502), + ("503", default_503), + ("504", default_504), + ("507", default_507), + ]; + for (name, default) in expected_defaults { + listener_answers + .entry(name.to_string()) + .or_insert_with(|| Self::template(name, &default()).unwrap()); + } Ok(HttpAnswers { - listener_answers: ListenerAnswers { - answer_301: Self::template( - 301, - conf.as_ref() - .and_then(|c| c.answer_301.clone()) - .unwrap_or(default_301()), - )?, - answer_400: Self::template( - 400, - conf.as_ref() - .and_then(|c| c.answer_400.clone()) - .unwrap_or(default_400()), - )?, - answer_401: Self::template( - 401, - conf.as_ref() - .and_then(|c| c.answer_401.clone()) - .unwrap_or(default_401()), - )?, - answer_404: Self::template( - 404, - conf.as_ref() - .and_then(|c| c.answer_404.clone()) - .unwrap_or(default_404()), - )?, - answer_408: Self::template( - 408, - conf.as_ref() - .and_then(|c| c.answer_408.clone()) - .unwrap_or(default_408()), - )?, - answer_413: Self::template( - 413, - conf.as_ref() - .and_then(|c| c.answer_413.clone()) - .unwrap_or(default_413()), - )?, - answer_502: Self::template( - 502, - conf.as_ref() - .and_then(|c| c.answer_502.clone()) - .unwrap_or(default_502()), - )?, - answer_503: Self::template( - 503, - conf.as_ref() - .and_then(|c| c.answer_503.clone()) - .unwrap_or(default_503()), - )?, - answer_504: Self::template( - 504, - conf.as_ref() - .and_then(|c| c.answer_504.clone()) - .unwrap_or(default_504()), - )?, - answer_507: Self::template( - 507, - conf.as_ref() - .and_then(|c| c.answer_507.clone()) - .unwrap_or(default_507()), - )?, - }, - cluster_custom_answers: HashMap::new(), + fallback: Self::template("", &fallback()).unwrap(), + listener_answers, + cluster_answers: HashMap::new(), }) } - pub fn add_custom_answer( + pub fn add_cluster_answers( &mut self, cluster_id: &str, - answer_503: String, - ) -> Result<(), (u16, TemplateError)> { - let answer_503 = Self::template(503, answer_503)?; - self.cluster_custom_answers - .insert(cluster_id.to_string(), ClusterAnswers { answer_503 }); + answers: &BTreeMap, + ) -> Result<(), (String, TemplateError)> { + self.cluster_answers + .entry(cluster_id.to_owned()) + .or_default() + .append(&mut Self::templates(answers)?); Ok(()) } - pub fn remove_custom_answer(&mut self, cluster_id: &str) { - self.cluster_custom_answers.remove(cluster_id); + pub fn remove_cluster_answers(&mut self, cluster_id: &str) { + self.cluster_answers.remove(cluster_id); } pub fn get( @@ -866,14 +886,14 @@ impl HttpAnswers { cluster_id: Option<&str>, backend_id: Option<&str>, route: String, - ) -> DefaultAnswerStream { + ) -> (u16, bool, DefaultAnswerStream) { let variables: Vec>; let mut variables_once: Vec>; - let template = match answer { + let name = match answer { DefaultAnswer::Answer301 { location } => { variables = vec![route.into(), request_id.into()]; variables_once = vec![location.into()]; - &self.listener_answers.answer_301 + "301" } DefaultAnswer::Answer400 { message, @@ -891,22 +911,22 @@ impl HttpAnswers { invalid.into(), ]; variables_once = vec![message.into()]; - &self.listener_answers.answer_400 + "400" } - DefaultAnswer::Answer401 {} => { + DefaultAnswer::Answer401 { www_authenticate } => { variables = vec![route.into(), request_id.into()]; - variables_once = vec![]; - &self.listener_answers.answer_401 + variables_once = vec![www_authenticate.map(Into::into).unwrap_or_default()]; + "401" } DefaultAnswer::Answer404 {} => { variables = vec![route.into(), request_id.into()]; variables_once = vec![]; - &self.listener_answers.answer_404 + "404" } DefaultAnswer::Answer408 { duration } => { variables = vec![route.into(), request_id.into(), duration.to_string().into()]; variables_once = vec![]; - &self.listener_answers.answer_408 + "408" } DefaultAnswer::Answer413 { message, @@ -920,7 +940,7 @@ impl HttpAnswers { phase_to_vec(phase), ]; variables_once = vec![message.into()]; - &self.listener_answers.answer_413 + "413" } DefaultAnswer::Answer502 { message, @@ -940,7 +960,7 @@ impl HttpAnswers { invalid.into(), ]; variables_once = vec![message.into()]; - &self.listener_answers.answer_502 + "502" } DefaultAnswer::Answer503 { message } => { variables = vec![ @@ -950,10 +970,7 @@ impl HttpAnswers { backend_id.unwrap_or_default().into(), ]; variables_once = vec![message.into()]; - cluster_id - .and_then(|id: &str| self.cluster_custom_answers.get(id)) - .map(|c| &c.answer_503) - .unwrap_or_else(|| &self.listener_answers.answer_503) + "503" } DefaultAnswer::Answer504 { duration } => { variables = vec![ @@ -964,7 +981,7 @@ impl HttpAnswers { duration.to_string().into(), ]; variables_once = vec![]; - &self.listener_answers.answer_504 + "504" } DefaultAnswer::Answer507 { phase, @@ -980,11 +997,30 @@ impl HttpAnswers { phase_to_vec(phase), ]; variables_once = vec![message.into()]; - &self.listener_answers.answer_507 + "507" + } + DefaultAnswer::AnswerCustom { name, location, .. } => { + variables = vec![ + route.into(), + request_id.into(), + cluster_id.unwrap_or_default().into(), + name.into(), + ]; + variables_once = vec![location.into()]; + unsafe { from_utf8_unchecked(variables[3].as_bytes()) } } }; // kawa::debug_kawa(&template.kawa); // println!("{template:#?}"); - template.fill(&variables, &mut variables_once) + let template = cluster_id + .and_then(|id| self.cluster_answers.get(id)) + .and_then(|answers| answers.get(name)) + .or_else(|| self.listener_answers.get(name)) + .unwrap_or(&self.fallback); + ( + template.status, + template.keep_alive, + template.fill(&variables, &mut variables_once), + ) } } diff --git a/lib/src/protocol/kawa_h1/editor.rs b/lib/src/protocol/kawa_h1/editor.rs index 26a8d3658..c48951ff0 100644 --- a/lib/src/protocol/kawa_h1/editor.rs +++ b/lib/src/protocol/kawa_h1/editor.rs @@ -1,18 +1,30 @@ use std::{ net::{IpAddr, SocketAddr}, + rc::Rc, str::{from_utf8, from_utf8_unchecked}, }; +use base64::Engine; use rusty_ulid::Ulid; +use sha2::{Digest, Sha256}; +use sozu_command::logging::CachedTags; use crate::{ pool::Checkout, protocol::http::{parser::compare_no_case, GenericHttpStream, Method}, - Protocol, + router::HeaderEdit, + Protocol, RetrieveClusterError, }; use sozu_command_lib::logging::LogContext; +#[derive(Debug)] +pub struct HttpRoute { + pub method: Option, + pub authority: Option, + pub path: Option, +} + /// This is the container used to store and use information about the session from within a Kawa parser callback #[derive(Debug)] pub struct HttpContext { @@ -23,16 +35,13 @@ pub struct HttpContext { pub keep_alive_frontend: bool, /// the value of the sticky session cookie in the request pub sticky_session_found: Option, - // ---------- Status Line - /// the value of the method in the request line - pub method: Option, - /// the value of the authority of the request (in the request line of "Host" header) - pub authority: Option, - /// the value of the path in the request line - pub path: Option, - /// the value of the status code in the response line + /// hashed value of the last authorization header + pub authorization_found: Option, + /// position of the last header of the request (the "Sozu-Id"), only valid until prepare is called + pub last_header: Option, + // ---------- Route + pub route: HttpRoute, pub status: Option, - /// the value of the reason in the response line pub reason: Option, // ---------- Additional optional data pub user_agent: Option, @@ -55,6 +64,10 @@ pub struct HttpContext { /// the sticky session that should be used /// used to create a "Set-Cookie" header in the response in case it differs from sticky_session_found pub sticky_session: Option, + /// Headers to add to response + pub headers_response: Rc<[HeaderEdit]>, + /// tags of the contacted frontend + pub tags: Option>, } impl kawa::h1::ParserCallbacks for HttpContext { @@ -78,6 +91,14 @@ impl HttpContext { /// - sticky cookie /// - user-agent fn on_request_headers(&mut self, request: &mut GenericHttpStream) { + if request.body_size == kawa::BodySize::Empty { + request.parsing_phase = kawa::ParsingPhase::Terminated; + request.push_block(kawa::Block::Header(kawa::Pair { + key: kawa::Store::Static(b"Content-Length"), + val: kawa::Store::Static(b"0"), + })); + }; + let buf = &mut request.storage.mut_buffer(); // Captures the request line @@ -88,21 +109,17 @@ impl HttpContext { .. } = &request.detached.status_line { - self.method = method.data_opt(buf).map(Method::new); - self.authority = authority + self.route.method = method.data_opt(buf).map(Method::new); + self.route.authority = authority .data_opt(buf) .and_then(|data| from_utf8(data).ok()) .map(ToOwned::to_owned); - self.path = path + self.route.path = path .data_opt(buf) .and_then(|data| from_utf8(data).ok()) .map(ToOwned::to_owned); } - // if self.method == Some(Method::Get) && request.body_size == kawa::BodySize::Empty { - // request.parsing_phase = kawa::ParsingPhase::Terminated; - // } - let public_ip = self.public_address.ip(); let public_port = self.public_address.port(); let proto = match self.protocol { @@ -117,7 +134,7 @@ impl HttpContext { let key = cookie.key.data(buf); if key == self.sticky_name.as_bytes() { let val = cookie.val.data(buf); - self.sticky_session_found = from_utf8(val).ok().map(|val| val.to_string()); + self.sticky_session_found = from_utf8(val).ok().map(ToOwned::to_owned); cookie.elide(); } } @@ -130,11 +147,14 @@ impl HttpContext { // - store X-Forwarded-For // - store Forwarded // - store User-Agent + // - compute sha256 of Authorization let mut x_for = None; let mut forwarded = None; let mut has_x_port = false; let mut has_x_proto = false; let mut has_connection = false; + + let mut auth = None; for block in &mut request.blocks { match block { kawa::Block::Header(header) if !header.is_elided() => { @@ -156,7 +176,7 @@ impl HttpContext { incr!("http.trusting.x_proto.diff"); debug!( "Trusting X-Forwarded-Proto for {:?} even though {:?} != {}", - self.authority, val, proto + self.route.authority, val, proto ); } } else if compare_no_case(key, b"X-Forwarded-Port") { @@ -169,7 +189,7 @@ impl HttpContext { incr!("http.trusting.x_port.diff"); debug!( "Trusting X-Forwarded-Port for {:?} even though {:?} != {}", - self.authority, val, expected + self.route.authority, val, expected ); } } else if compare_no_case(key, b"X-Forwarded-For") { @@ -182,12 +202,29 @@ impl HttpContext { .data_opt(buf) .and_then(|data| from_utf8(data).ok()) .map(ToOwned::to_owned); + } else if compare_no_case(key, b"Authorization") { + auth = Some(header); } } _ => {} } } + self.authorization_found = + auth.and_then(|header| header.val.data_opt(buf)) + .and_then(|auth| { + let (kind, token) = auth.trim_ascii_start().split_at("Basic ".len()); + compare_no_case(kind, b"Basic ").then_some(())?; + let token = base64::prelude::BASE64_STANDARD.decode(token).ok()?; + let (name, pwd) = token + .iter() + .position(|c| *c == b':') + .map(|i| token.split_at(i+1))?; + let mut auth = String::from_utf8(name.to_vec()).ok()?; + auth.push_str(&hex::encode(Sha256::digest(pwd))); + Some(auth) + }); + // If session_address is set: // - append its ip address to the list of "X-Forwarded-For" if it was found, creates it if not // - append "proto=[PROTO];for=[PEER];by=[PUBLIC]" to the list of "Forwarded" if it was found, creates it if not @@ -227,6 +264,8 @@ impl HttpContext { header.val = kawa::Store::from_string(new_value); } + request.blocks.reserve(8); + if !has_x_for { request.push_block(kawa::Block::Header(kawa::Pair { key: kawa::Store::Static(b"X-Forwarded-For"), @@ -275,6 +314,7 @@ impl HttpContext { })); } + self.last_header = Some(request.blocks.len()); // Create a custom "Sozu-Id" header request.push_block(kawa::Block::Header(kawa::Pair { key: kawa::Store::Static(b"Sozu-Id"), @@ -301,7 +341,7 @@ impl HttpContext { .map(ToOwned::to_owned); } - if self.method == Some(Method::Head) { + if self.route.method == Some(Method::Head) { response.parsing_phase = kawa::ParsingPhase::Terminated; } @@ -325,6 +365,8 @@ impl HttpContext { } } + response.blocks.reserve(2 + self.headers_response.len()); + // If the sticky_session is set and differs from the one found in the request // create a "Set-Cookie" header to update the sticky_name value if let Some(sticky_session) = &self.sticky_session { @@ -339,6 +381,8 @@ impl HttpContext { } } + apply_header_edits(response, self.headers_response.clone()); + // Create a custom "Sozu-Id" header response.push_block(kawa::Block::Header(kawa::Pair { key: kawa::Store::Static(b"Sozu-Id"), @@ -350,12 +394,15 @@ impl HttpContext { self.keep_alive_backend = true; self.keep_alive_frontend = true; self.sticky_session_found = None; - self.method = None; - self.authority = None; - self.path = None; + self.route.method = None; + self.route.authority = None; + self.route.path = None; self.status = None; self.reason = None; self.user_agent = None; + self.cluster_id = None; + self.backend_id = None; + self.headers_response = Rc::new([]); } pub fn log_context(&self) -> LogContext { @@ -366,3 +413,26 @@ impl HttpContext { } } } + +impl HttpRoute { + // -> host, path, method + pub fn extract(&self) -> Result<(&str, &str, &Method), RetrieveClusterError> { + let given_method = self.method.as_ref().ok_or(RetrieveClusterError::NoMethod)?; + let given_authority = self + .authority + .as_deref() + .ok_or(RetrieveClusterError::NoHost)?; + let given_path = self.path.as_deref().ok_or(RetrieveClusterError::NoPath)?; + + Ok((given_authority, given_path, given_method)) + } +} + +pub fn apply_header_edits(kawa: &mut GenericHttpStream, headers: Rc<[HeaderEdit]>) { + for header in &*headers { + kawa.push_block(kawa::Block::Header(kawa::Pair { + key: kawa::Store::Shared(header.key.clone(), 0), + val: kawa::Store::Shared(header.val.clone(), 0), + })); + } +} diff --git a/lib/src/protocol/kawa_h1/mod.rs b/lib/src/protocol/kawa_h1/mod.rs index 270ca7e7e..8758f4a9d 100644 --- a/lib/src/protocol/kawa_h1/mod.rs +++ b/lib/src/protocol/kawa_h1/mod.rs @@ -6,19 +6,22 @@ pub mod parser; use std::{ cell::RefCell, io::ErrorKind, + mem, net::{Shutdown, SocketAddr}, rc::{Rc, Weak}, + str::from_utf8_unchecked, time::{Duration, Instant}, }; +use editor::{apply_header_edits, HttpRoute}; use mio::{net::TcpStream, Interest, Token}; +use parser::hostname_and_port; use rusty_ulid::Ulid; use sozu_command::{ config::MAX_LOOP_ITERATIONS, logging::EndpointRecord, - proto::command::{Event, EventKind, ListenerType}, + proto::command::{Event, EventKind, ListenerType, RedirectPolicy, RedirectScheme}, }; -// use time::{Duration, Instant}; use crate::{ backends::{Backend, BackendError}, @@ -34,14 +37,14 @@ use crate::{ SessionState, }, retry::RetryPolicy, - router::Route, + router::RouteResult, server::{push_event, CONN_RETRIES}, socket::{stats::socket_rtt, SocketHandler, SocketResult, TransportProtocol}, sozu_command::{logging::LogContext, ready::Ready}, timer::TimeoutContainer, - AcceptError, BackendConnectAction, BackendConnectionError, BackendConnectionStatus, - L7ListenerHandler, L7Proxy, ListenerHandler, Protocol, ProxySession, Readiness, - RetrieveClusterError, SessionIsToBeClosed, SessionMetrics, SessionResult, StateResult, + AcceptError, BackendConnectAction, BackendConnectionError, FrontendFromRequestError, + L7ListenerHandler, L7Proxy, Protocol, ProxySession, Readiness, RetrieveClusterError, + SessionIsToBeClosed, SessionMetrics, SessionResult, StateResult, }; /// This macro is defined uniquely in this module to help the tracking of kawa h1 @@ -55,7 +58,7 @@ macro_rules! log_context { $self.context.session_address.map(|addr| addr.to_string()).unwrap_or_else(|| "".to_string()), $self.frontend_token.0, $self.frontend_readiness, - $self.backend_token.map(|token| token.0.to_string()).unwrap_or_else(|| "".to_string()), + $self.origin.as_ref().map(|origin| origin.token.0.to_string()).unwrap_or_else(|| "".to_string()), $self.backend_readiness, ) }; @@ -75,6 +78,10 @@ impl kawa::AsBuffer for Checkout { #[derive(Debug, Clone, PartialEq, Eq)] pub enum DefaultAnswer { + AnswerCustom { + name: String, + location: String, + }, Answer301 { location: String, }, @@ -85,7 +92,9 @@ pub enum DefaultAnswer { partially_parsed: String, invalid: String, }, - Answer401 {}, + Answer401 { + www_authenticate: Option, + }, Answer404 {}, Answer408 { duration: String, @@ -109,29 +118,12 @@ pub enum DefaultAnswer { duration: String, }, Answer507 { - phase: kawa::ParsingPhaseMarker, message: String, + phase: kawa::ParsingPhaseMarker, capacity: usize, }, } -impl From<&DefaultAnswer> for u16 { - fn from(answer: &DefaultAnswer) -> u16 { - match answer { - DefaultAnswer::Answer301 { .. } => 301, - DefaultAnswer::Answer400 { .. } => 400, - DefaultAnswer::Answer401 { .. } => 401, - DefaultAnswer::Answer404 { .. } => 404, - DefaultAnswer::Answer408 { .. } => 408, - DefaultAnswer::Answer413 { .. } => 413, - DefaultAnswer::Answer502 { .. } => 502, - DefaultAnswer::Answer503 { .. } => 503, - DefaultAnswer::Answer504 { .. } => 504, - DefaultAnswer::Answer507 { .. } => 507, - } - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TimeoutStatus { Request, @@ -143,17 +135,35 @@ pub enum TimeoutStatus { pub enum ResponseStream { BackendAnswer(GenericHttpStream), DefaultAnswer(u16, DefaultAnswerStream), + DefaultAnswerKA(u16, DefaultAnswerStream, GenericHttpStream), + Swaping, +} + +#[derive(Debug)] +pub struct Origin { + pub cluster_id: String, + pub backend_id: String, + pub backend: Rc>, + pub token: Token, + connected: bool, + pub socket: TcpStream, +} + +impl Origin { + fn is_connected_to(&self, cluster_id: &str) -> bool { + self.cluster_id == cluster_id && self.connected + } } /// Http will be contained in State which itself is contained by Session -pub struct Http { +pub struct Http { answers: Rc>, - pub backend: Option>>, - backend_connection_status: BackendConnectionStatus, - pub backend_readiness: Readiness, - pub backend_socket: Option, + /// The last origin server we tried to communicate with. + /// It may be connected or connecting. It may be the server serving the current response, + /// or a server kept alive while a default answer is sent. Its cluster_id might differ from context.cluster_id + pub origin: Option, backend_stop: Option, - pub backend_token: Option, + pub backend_readiness: Readiness, pub container_backend_timeout: TimeoutContainer, pub container_frontend_timeout: TimeoutContainer, configured_backend_timeout: Duration, @@ -175,7 +185,7 @@ pub struct Http { pub context: HttpContext, } -impl Http { +impl Http { /// Instantiate a new HTTP SessionState with: /// /// - frontend_interest: READABLE | HUP | ERROR @@ -213,12 +223,9 @@ impl Http Http Http Http Http Http response_stream, - ResponseStream::DefaultAnswer(..) => { + _ => { error!( "{} Sending default answer, should not read from frontend socket", log_context!(self) @@ -617,7 +631,8 @@ impl Http response_stream, + ResponseStream::DefaultAnswer(_, response_stream) + | ResponseStream::DefaultAnswerKA(_, response_stream, _) => response_stream, _ => return StateResult::CloseSession, }; let bufs = response_stream.as_io_slice(); @@ -632,11 +647,16 @@ impl Http { + self.response_stream = ResponseStream::BackendAnswer(kawa_back); + metrics.reset(); + self.reset(); + return StateResult::Continue; + } + _ => StateResult::CloseSession, + }; } if socket_state == SocketResult::Error { @@ -663,10 +683,10 @@ impl Http Http Http { - backend_socket.read_error(); + // backend_socket.read_error(); + incr!("tcp.read.error"); self.log_request_error( metrics, &format!( - "back socket {socket_state:?}, closing session. Readiness: {:?} -> {:?}, read {size} bytes", + "back socket {socket_state:?}, closing session. Readiness: {:?} -> {:?}, read {size} bytes", self.frontend_readiness, self.backend_readiness, ), @@ -872,12 +893,12 @@ impl Http Http { +impl Http { fn log_endpoint(&self) -> EndpointRecord { EndpointRecord::Http { - method: self.context.method.as_deref(), - authority: self.context.authority.as_deref(), - path: self.context.path.as_deref(), + method: self.context.route.method.as_deref(), + authority: self.context.route.authority.as_deref(), + path: self.context.route.path.as_deref(), reason: self.context.reason.as_deref(), status: self.context.status, } @@ -890,14 +911,9 @@ impl Http Option { - self.backend + self.origin .as_ref() - .map(|backend| backend.borrow().address) - .or_else(|| { - self.backend_socket - .as_ref() - .and_then(|backend| backend.peer_addr().ok()) - }) + .map(|origin| origin.backend.borrow().address) } // The protocol name used in the access logs @@ -920,24 +936,15 @@ impl Http WebSocketContext { WebSocketContext::Http { - method: self.context.method.clone(), - authority: self.context.authority.clone(), - path: self.context.path.clone(), + method: self.context.route.method.clone(), + authority: self.context.route.authority.clone(), + path: self.context.route.path.clone(), reason: self.context.reason.clone(), status: self.context.status, } } pub fn log_request(&self, metrics: &SessionMetrics, error: bool, message: Option<&str>) { - let listener = self.listener.borrow(); - let tags = self.context.authority.as_ref().and_then(|host| { - let hostname = match host.split_once(':') { - None => host, - Some((hostname, _)) => hostname, - }; - listener.get_tags(hostname) - }); - let context = self.context.log_context(); metrics.register_end_of_session(&context); @@ -950,9 +957,9 @@ impl Http Http Http incr!( - "http.301.redirection", - self.context.cluster_id.as_deref(), - self.context.backend_id.as_deref() - ), - DefaultAnswer::Answer400 { .. } => incr!("http.400.errors"), - DefaultAnswer::Answer401 { .. } => incr!( - "http.401.errors", - self.context.cluster_id.as_deref(), - self.context.backend_id.as_deref() - ), - DefaultAnswer::Answer404 { .. } => incr!("http.404.errors"), - DefaultAnswer::Answer408 { .. } => incr!( - "http.408.errors", - self.context.cluster_id.as_deref(), - self.context.backend_id.as_deref() - ), - DefaultAnswer::Answer413 { .. } => incr!( - "http.413.errors", - self.context.cluster_id.as_deref(), - self.context.backend_id.as_deref() - ), - DefaultAnswer::Answer502 { .. } => incr!( - "http.502.errors", - self.context.cluster_id.as_deref(), - self.context.backend_id.as_deref() - ), - DefaultAnswer::Answer503 { .. } => incr!( - "http.503.errors", - self.context.cluster_id.as_deref(), - self.context.backend_id.as_deref() - ), - DefaultAnswer::Answer504 { .. } => incr!( - "http.504.errors", - self.context.cluster_id.as_deref(), - self.context.backend_id.as_deref() - ), - DefaultAnswer::Answer507 { .. } => incr!( - "http.507.errors", - self.context.cluster_id.as_deref(), - self.context.backend_id.as_deref() - ), - }; + match answer { + DefaultAnswer::AnswerCustom { .. } => incr!( + "http.custom_asnwers", + self.context.cluster_id.as_deref(), + self.context.backend_id.as_deref() + ), + DefaultAnswer::Answer301 { .. } => incr!( + "http.301.redirection", + self.context.cluster_id.as_deref(), + self.context.backend_id.as_deref() + ), + DefaultAnswer::Answer400 { .. } => incr!("http.400.errors"), + DefaultAnswer::Answer401 { .. } => incr!( + "http.401.errors", + self.context.cluster_id.as_deref(), + self.context.backend_id.as_deref() + ), + DefaultAnswer::Answer404 { .. } => incr!("http.404.errors"), + DefaultAnswer::Answer408 { .. } => incr!( + "http.408.errors", + self.context.cluster_id.as_deref(), + self.context.backend_id.as_deref() + ), + DefaultAnswer::Answer413 { .. } => incr!( + "http.413.errors", + self.context.cluster_id.as_deref(), + self.context.backend_id.as_deref() + ), + DefaultAnswer::Answer502 { .. } => incr!( + "http.502.errors", + self.context.cluster_id.as_deref(), + self.context.backend_id.as_deref() + ), + DefaultAnswer::Answer503 { .. } => incr!( + "http.503.errors", + self.context.cluster_id.as_deref(), + self.context.backend_id.as_deref() + ), + DefaultAnswer::Answer504 { .. } => incr!( + "http.504.errors", + self.context.cluster_id.as_deref(), + self.context.backend_id.as_deref() + ), + DefaultAnswer::Answer507 { .. } => incr!( + "http.507.errors", + self.context.cluster_id.as_deref(), + self.context.backend_id.as_deref() + ), } - - let mut kawa = self.answers.borrow().get( + let (status, keep_alive, mut kawa) = self.answers.borrow().get( answer, self.context.id.to_string(), self.context.cluster_id.as_deref(), self.context.backend_id.as_deref(), self.get_route(), ); + if let ResponseStream::DefaultAnswer(old_status, ..) = self.response_stream { + error!( + "already set the default answer to {}, trying to set to {}", + old_status, status + ); + }; kawa.prepare(&mut kawa::h1::BlockConverter); self.context.status = Some(status); self.context.reason = None; - self.context.keep_alive_frontend = false; - self.response_stream = ResponseStream::DefaultAnswer(status, kawa); + if keep_alive { + match mem::replace(&mut self.response_stream, ResponseStream::Swaping) { + ResponseStream::BackendAnswer(back_kawa) => { + self.response_stream = ResponseStream::DefaultAnswerKA(status, kawa, back_kawa); + } + _ => unreachable!(), + } + } else { + self.response_stream = ResponseStream::DefaultAnswer(status, kawa); + } self.frontend_readiness.interest = Ready::WRITABLE | Ready::HUP | Ready::ERROR; self.backend_readiness.interest = Ready::HUP | Ready::ERROR; } pub fn test_backend_socket(&self) -> bool { - match self.backend_socket { - Some(ref s) => { + match &self.origin { + Some(origin) => { let mut tmp = [0u8; 1]; - let res = s.peek(&mut tmp[..]); + let res = origin.socket.peek(&mut tmp[..]); match res { // if the socket is half open, it will report 0 bytes read (EOF) Ok(0) => false, Ok(_) => true, - Err(e) => matches!(e.kind(), std::io::ErrorKind::WouldBlock), + Err(e) => matches!(e.kind(), ErrorKind::WouldBlock), } } None => false, @@ -1070,110 +1084,76 @@ impl Http bool { // if socket was last used in the last second, test it - match self.backend_stop.as_ref() { + match &self.backend_stop { Some(stop_instant) => { - let now = Instant::now(); - let dur = now - *stop_instant; - if dur > Duration::from_secs(1) { - return self.test_backend_socket(); - } + stop_instant.elapsed() < Duration::from_secs(1) || self.test_backend_socket() } - None => return self.test_backend_socket(), + None => self.test_backend_socket(), } - - true - } - - pub fn set_backend_socket(&mut self, socket: TcpStream, backend: Option>>) { - self.backend_socket = Some(socket); - self.backend = backend; - } - - pub fn set_cluster_id(&mut self, cluster_id: String) { - self.context.cluster_id = Some(cluster_id); } - pub fn set_backend_id(&mut self, backend_id: String) { - self.context.backend_id = Some(backend_id); - } - - pub fn set_backend_token(&mut self, token: Token) { - self.backend_token = Some(token); - } - - pub fn clear_backend_token(&mut self) { - self.backend_token = None; - } - - pub fn set_backend_timeout(&mut self, dur: Duration) { - if let Some(token) = self.backend_token.as_ref() { - self.container_backend_timeout.set_duration(dur); - self.container_backend_timeout.set(*token); - } + pub fn set_backend_timeout(&mut self, token: Token, dur: Duration) { + self.container_backend_timeout.set_duration(dur); + self.container_backend_timeout.set(token); } pub fn front_socket(&self) -> &TcpStream { self.frontend_socket.socket_ref() } - /// WARNING: this function removes the backend entry in the session manager - /// IF the backend_token is set, so that entry can be reused for new backend. - /// I don't think this is a good idea, but it is a quick fix - fn close_backend(&mut self, proxy: Rc>, metrics: &mut SessionMetrics) { + fn close_backend( + &mut self, + proxy: Rc>, + reuse_entry: bool, + ) -> Option { self.container_backend_timeout.cancel(); - debug!( - "{}\tPROXY [{}->{}] CLOSED BACKEND", - log_context!(self), - self.frontend_token.0, - self.backend_token - .map(|t| format!("{}", t.0)) - .unwrap_or_else(|| "-".to_string()) - ); let proxy = proxy.borrow(); - if let Some(socket) = &mut self.backend_socket.take() { - if let Err(e) = proxy.deregister_socket(socket) { + if let Some(mut origin) = self.origin.take() { + debug!( + "{}\tPROXY [{}->{}] CLOSE BACKEND {:?}", + log_context!(self), + self.frontend_token.0, + format!("{}", origin.token.0), + origin + ); + + if let Err(e) = proxy.deregister_socket(&mut origin.socket) { error!( "{} Error deregistering back socket({:?}): {:?}", log_context!(self), - socket, + origin.socket, e ); } - if let Err(e) = socket.shutdown(Shutdown::Both) { + if let Err(e) = origin.socket.shutdown(Shutdown::Both) { if e.kind() != ErrorKind::NotConnected { error!( "{} Error shutting down back socket({:?}): {:?}", log_context!(self), - socket, + origin.socket, e ); } } - } - - if let Some(token) = self.backend_token.take() { - proxy.remove_session(token); - - if self.backend_connection_status != BackendConnectionStatus::NotConnected { - self.backend_readiness.event = Ready::EMPTY; - } - - if self.backend_connection_status == BackendConnectionStatus::Connected { + self.backend_readiness.event = Ready::EMPTY; + if origin.connected { gauge_add!("backend.connections", -1); gauge_add!( "connections_per_backend", -1, - self.context.cluster_id.as_deref(), - metrics.backend_id.as_deref() + Some(&origin.cluster_id), + Some(&origin.backend_id) ); } + origin.backend.borrow_mut().dec_connections(); - self.set_backend_connected(BackendConnectionStatus::NotConnected, metrics); - - if let Some(backend) = self.backend.take() { - backend.borrow_mut().dec_connections(); + if !reuse_entry { + proxy.remove_session(origin.token); } + Some(origin.token) + } else { + None } } @@ -1197,48 +1177,24 @@ impl Http bool { + fn check_backend_connection(&self, metrics: &mut SessionMetrics) -> bool { let is_valid_backend_socket = self.is_valid_backend_socket(); if !is_valid_backend_socket { return false; } - //matched on keepalive - metrics.backend_id = self.backend.as_ref().map(|i| i.borrow().backend_id.clone()); - metrics.backend_start(); - if let Some(b) = self.backend.as_mut() { - b.borrow_mut().active_requests += 1; + if let Some(origin) = &self.origin { + origin.backend.borrow_mut().active_requests += 1; } true } - // -> host, path, method - pub fn extract_route(&self) -> Result<(&str, &str, &Method), RetrieveClusterError> { - let given_method = self - .context - .method - .as_ref() - .ok_or(RetrieveClusterError::NoMethod)?; - let given_authority = self - .context - .authority - .as_deref() - .ok_or(RetrieveClusterError::NoHost)?; - let given_path = self - .context - .path - .as_deref() - .ok_or(RetrieveClusterError::NoPath)?; - - Ok((given_authority, given_path, given_method)) - } - pub fn get_route(&self) -> String { - if let Some(method) = &self.context.method { - if let Some(authority) = &self.context.authority { - if let Some(path) = &self.context.path { + if let Some(method) = &self.context.route.method { + if let Some(authority) = &self.context.route.authority { + if let Some(path) = &self.context.route.path { return format!("{method} {authority}{path}"); } return format!("{method} {authority}"); @@ -1252,12 +1208,12 @@ impl Http>, ) -> Result { - let (host, uri, method) = match self.extract_route() { + let (host, path, method) = match self.context.route.extract() { Ok(tuple) => tuple, Err(cluster_error) => { self.set_answer(DefaultAnswer::Answer400 { message: "Could not extract the route after connection started, this should not happen.".into(), - phase: self.request_stream.parsing_phase.marker(), + phase: kawa::ParsingPhaseMarker::StatusLine, successfully_parsed: "null".into(), partially_parsed: "null".into(), invalid: "null".into(), @@ -1266,43 +1222,184 @@ impl Http (unsafe { from_utf8_unchecked(hostname) }, port), + Ok(_) => { + let host = host.to_owned(); + self.set_answer(DefaultAnswer::Answer400 { + message: "Invalid characters after hostname, this should not happen.".into(), + phase: kawa::ParsingPhaseMarker::StatusLine, + successfully_parsed: "null".into(), + partially_parsed: "null".into(), + invalid: "null".into(), + }); + return Err(RetrieveClusterError::RetrieveFrontend( + FrontendFromRequestError::InvalidCharsAfterHost(host), + )); + } + Err(parse_error) => { + let host = host.to_owned(); + let error = parse_error.to_string(); + self.set_answer(DefaultAnswer::Answer400 { + message: "Could not parse port from hostname, this should not happen.".into(), + phase: kawa::ParsingPhaseMarker::StatusLine, + successfully_parsed: "null".into(), + partially_parsed: "null".into(), + invalid: "null".into(), + }); + return Err(RetrieveClusterError::RetrieveFrontend( + FrontendFromRequestError::HostParse { host, error }, + )); + } + }; + + let start = Instant::now(); let route_result = self .listener .borrow() - .frontend_from_request(host, uri, method); + .frontend_from_request(host, path, method); - let route = match route_result { + let RouteResult { + cluster_id, + required_auth, + redirect, + redirect_scheme, + redirect_template, + rewritten_host, + rewritten_path, + rewritten_port, + headers_request, + headers_response, + tags, + } = match route_result { Ok(route) => route, Err(frontend_error) => { self.set_answer(DefaultAnswer::Answer404 {}); return Err(RetrieveClusterError::RetrieveFrontend(frontend_error)); } }; + self.context.cluster_id = cluster_id; + self.context.headers_response = headers_response; + self.context.tags = tags; + let cluster_id = &self.context.cluster_id; + + if let Some(cluster_id) = cluster_id { + time!( + "frontend_matching_time", + cluster_id, + start.elapsed().as_millis() + ); + } - let cluster_id = match route { - Route::ClusterId(cluster_id) => cluster_id, - Route::Deny => { - self.set_answer(DefaultAnswer::Answer401 {}); - return Err(RetrieveClusterError::UnauthorizedRoute); - } + let body = if let Some(position) = self.context.last_header { + self.request_stream.blocks.split_off(position) + } else { + unreachable!(); }; - let frontend_should_redirect_https = matches!(proxy.borrow().kind(), ListenerType::Http) - && proxy - .borrow() - .clusters() - .get(&cluster_id) - .map(|cluster| cluster.https_redirect) - .unwrap_or(false); - - if frontend_should_redirect_https { - self.set_answer(DefaultAnswer::Answer301 { - location: format!("https://{host}{uri}"), - }); - return Err(RetrieveClusterError::UnauthorizedRoute); + match &mut self.request_stream.detached.status_line { + kawa::StatusLine::Request { authority, uri, .. } => { + let buf = self.request_stream.storage.mut_buffer(); + if let Some(new) = &rewritten_host { + authority.modify(buf, new.as_bytes()); + self.request_stream + .blocks + .push_back(kawa::Block::Header(kawa::Pair { + key: kawa::Store::Static(b"X-Forwarded-Host"), + val: kawa::Store::from_slice(host.as_bytes()), + })); + } + if let Some(new) = &rewritten_path { + uri.modify(buf, new.as_bytes()); + } + } + _ => unreachable!(), } - Ok(cluster_id) + let host = rewritten_host.as_deref().unwrap_or(host); + let path = rewritten_path.as_deref().unwrap_or(path); + let is_https = matches!(proxy.borrow().kind(), ListenerType::Https); + let proto = match (redirect_scheme, is_https) { + (RedirectScheme::UseHttp, _) | (RedirectScheme::UseSame, false) => "http", + (RedirectScheme::UseHttps, _) | (RedirectScheme::UseSame, true) => "https", + }; + + let (authorized, www_authenticate, https_redirect, https_redirect_port) = + match (&cluster_id, redirect, &redirect_template, required_auth) { + // unauthorized frontends + (_, RedirectPolicy::Unauthorized, _, _) => (false, None, false, None), + // forward frontends with no target (no cluster nor template) + (None, RedirectPolicy::Forward, None, _) => (false, None, false, None), + // clusterless frontend with auth (unsupported) + (None, _, _, true) => (false, None, false, None), + // clusterless frontends + (None, _, _, false) => (true, None, false, None), + // "attached" frontends + (Some(cluster_id), _, _, _) => { + proxy.borrow().clusters().get(cluster_id).map_or( + (true, None, false, None), // cluster not found, consider authorized? + |cluster| { + let authorized = + match (required_auth, &self.context.authorization_found) { + // auth not required + (false, _) => true, + // no auth found + (true, None) => false, + // validation + (true, Some(hash)) => { + println!("{hash:?}"); + cluster.authorized_hashes.contains(hash) + } + }; + ( + authorized, + cluster.www_authenticate.clone(), + cluster.https_redirect, + cluster.https_redirect_port, + ) + }, + ) + } + }; + + let port = match ( + port, + rewritten_port, + https_redirect_port, + !is_https && https_redirect, + ) { + (_, Some(port), _, _) => format!(":{port}"), + (_, _, Some(port), true) => format!(":{port}"), + (Some(port), _, _, _) => format!(":{}", unsafe { from_utf8_unchecked(port) }), + _ => String::new(), + }; + + match (cluster_id, redirect, redirect_template, authorized) { + (_, RedirectPolicy::Permanent, _, true) => { + let location = format!("{proto}://{host}{port}{path}"); + self.set_answer(DefaultAnswer::Answer301 { location }); + Err(RetrieveClusterError::Redirected) + } + (_, RedirectPolicy::Forward, Some(name), true) => { + let location = format!("{proto}://{host}{port}{path}"); + self.set_answer(DefaultAnswer::AnswerCustom { name, location }); + Err(RetrieveClusterError::Redirected) + } + (Some(cluster_id), RedirectPolicy::Forward, None, true) => { + if !is_https && https_redirect { + let location = format!("https://{host}{port}{path}"); + self.set_answer(DefaultAnswer::Answer301 { location }); + return Err(RetrieveClusterError::Redirected); + } + apply_header_edits(&mut self.request_stream, headers_request); + self.request_stream.blocks.extend(body); + Ok(cluster_id.clone()) + } + _ => { + self.set_answer(DefaultAnswer::Answer401 { www_authenticate }); + Err(RetrieveClusterError::UnauthorizedRoute) + } + } } pub fn backend_from_request( @@ -1310,8 +1407,7 @@ impl Http>, - metrics: &mut SessionMetrics, - ) -> Result { + ) -> Result<(Rc>, TcpStream), BackendConnectionError> { let (backend, conn) = self .get_backend_for_sticky_session( frontend_should_stick, @@ -1341,12 +1437,7 @@ impl Http Http>, metrics: &mut SessionMetrics, ) -> Result { - let old_cluster_id = self.context.cluster_id.clone(); - let old_backend_token = self.backend_token; - self.check_circuit_breaker()?; - let cluster_id = self - .cluster_id_from_request(proxy.clone()) - .map_err(BackendConnectionError::RetrieveClusterError)?; + let cluster_id = if self.connection_attempts == 0 { + // cluster_id is determined from the triplet (method, authority, path) + // which doesn't ever change for a request + // WARNING: cluster_id_from_request is NOT idempotent, but connect_to_backend SHOULD be + self.cluster_id_from_request(proxy.clone()) + .map_err(BackendConnectionError::RetrieveClusterError)? + } else if let Some(cluster_id) = self.context.cluster_id.take() { + cluster_id + } else { + unreachable!(); + }; + + // update response cluster producer + self.context.cluster_id = Some(cluster_id.clone()); trace!( - "{} Connect_to_backend: {:?} {:?} {:?}", + "{} Connect_to_backend: {:?}", log_context!(self), - self.context.cluster_id, cluster_id, - self.backend_connection_status ); + // check if we can reuse the backend connection - if (self.context.cluster_id.as_ref()) == Some(&cluster_id) - && self.backend_connection_status == BackendConnectionStatus::Connected - { - let has_backend = self - .backend - .as_ref() - .map(|backend| { - let backend = backend.borrow(); - proxy - .borrow() - .backends() - .borrow() - .has_backend(&cluster_id, &backend) - }) - .unwrap_or(false); - - if has_backend && self.check_backend_connection(metrics) { - return Ok(BackendConnectAction::Reuse); - } else if self.backend_token.take().is_some() { - self.close_backend(proxy.clone(), metrics); - } - } + if let Some(origin) = &self.origin { + if origin.is_connected_to(&cluster_id) { + let has_backend = proxy + .borrow() + .backends() + .borrow() + .has_backend(&cluster_id, &origin.backend.borrow()); - //replacing with a connection to another cluster - if old_cluster_id.is_some() - && old_cluster_id.as_ref() != Some(&cluster_id) - && self.backend_token.take().is_some() - { - self.close_backend(proxy.clone(), metrics); + if has_backend && self.check_backend_connection(metrics) { + // update response backend producer + self.context.backend_id = Some(origin.backend.borrow().backend_id.clone()); + return Ok(BackendConnectAction::Reuse); + } + } } - self.context.cluster_id = Some(cluster_id.clone()); + // close old backend (if it exists), replace with a connection to another cluster + let reusable_token = self.close_backend(proxy.clone(), true); let frontend_should_stick = proxy .borrow() @@ -1433,8 +1517,9 @@ impl Http Http { - self.set_backend_token(backend_token); - if let Err(e) = proxy.borrow().register_socket( - &mut socket, - backend_token, - Interest::READABLE | Interest::WRITABLE, - ) { - error!( - "{} Error registering back socket({:?}): {:?}", - log_context!(self), - socket, - e - ); - } + let backend_id = backend.borrow().backend_id.clone(); + // update response backend producer + self.context.backend_id = Some(backend_id.clone()); - self.set_backend_socket(socket, self.backend.clone()); - self.set_backend_timeout(self.configured_connect_timeout); + let (token, action) = if let Some(token) = reusable_token { + (token, BackendConnectAction::Replace) + } else { + ( + proxy.borrow().add_session(session_rc), + BackendConnectAction::New, + ) + }; - Ok(BackendConnectAction::Replace) - } - None => { - let backend_token = proxy.borrow().add_session(session_rc); - - if let Err(e) = proxy.borrow().register_socket( - &mut socket, - backend_token, - Interest::READABLE | Interest::WRITABLE, - ) { - error!( - "{} Error registering back socket({:?}): {:?}", - log_context!(self), - socket, - e - ); - } + if let Err(e) = proxy.borrow().register_socket( + &mut socket, + token, + Interest::READABLE | Interest::WRITABLE, + ) { + error!( + "{} Error registering backend socket({:?}): {:?}", + log_context!(self), + socket, + e + ); + } - self.set_backend_socket(socket, self.backend.clone()); - self.set_backend_token(backend_token); - self.set_backend_timeout(self.configured_connect_timeout); + self.origin = Some(Origin { + cluster_id, + backend_id, + backend, + token, + connected: false, + socket, + }); + self.set_backend_timeout(token, self.configured_connect_timeout); - Ok(BackendConnectAction::New) - } - } + Ok(action) } - fn set_backend_connected( - &mut self, - connected: BackendConnectionStatus, - metrics: &mut SessionMetrics, - ) { - let last = self.backend_connection_status; - self.backend_connection_status = connected; + fn set_backend_connected(&mut self, metrics: &mut SessionMetrics) { + if let Some(Origin { + connected: connected @ false, + token, + backend, + backend_id, + cluster_id, + .. + }) = &mut self.origin + { + *connected = true; - if connected == BackendConnectionStatus::Connected { gauge_add!("backend.connections", 1); gauge_add!( "connections_per_backend", 1, - self.context.cluster_id.as_deref(), - metrics.backend_id.as_deref() + Some(cluster_id), + Some(backend_id) ); - // the back timeout was of connect_timeout duration before, - // now that we're connected, move to backend_timeout duration - self.set_backend_timeout(self.configured_backend_timeout); - // if we are not waiting for the backend response, its timeout is concelled - // it should be set when the request has been entirely transmitted - if !self.backend_readiness.interest.is_readable() { - self.container_backend_timeout.cancel(); - } - - if let Some(backend) = &self.backend { + { let mut backend = backend.borrow_mut(); - if backend.retry_policy.is_down() { - incr!( - "backend.up", - self.context.cluster_id.as_deref(), - metrics.backend_id.as_deref() - ); - + incr!("backend.up", Some(cluster_id), Some(backend_id)); info!( "backend server {} at {} is up", backend.backend_id, backend.address @@ -1541,21 +1605,36 @@ impl Http Http Http Http StateResult { + pub fn backend_hup(&mut self, _metrics: &mut SessionMetrics) -> StateResult { let response_stream = match &mut self.response_stream { ResponseStream::BackendAnswer(response_stream) => response_stream, _ => return StateResult::CloseBackend, @@ -1673,7 +1748,11 @@ impl Http SessionResult { let mut counter = 0; - if self.backend_connection_status.is_connecting() + if self + .origin + .as_ref() + .map(|origin| !origin.connected) + .unwrap_or(false) && !self.backend_readiness.event.is_empty() { if self.backend_readiness.event.is_hup() && !self.test_backend_socket() { @@ -1685,13 +1764,10 @@ impl Http Http Http self.close_backend(proxy.clone(), metrics), + StateResult::CloseBackend => { + self.close_backend(proxy.clone(), false); + } StateResult::CloseSession => return SessionResult::Close, StateResult::Upgrade => return SessionResult::Upgrade, StateResult::Continue => {} @@ -1831,7 +1909,9 @@ impl Http {} - StateResult::CloseBackend => self.close_backend(proxy.clone(), metrics), + StateResult::CloseBackend => { + self.close_backend(proxy.clone(), false); + } StateResult::CloseSession => return SessionResult::Close, StateResult::ConnectBackend | StateResult::Upgrade => unreachable!(), } @@ -1871,7 +1951,7 @@ impl Http SessionState for Http { +impl SessionState for Http { fn ready( &mut self, session: Rc>, @@ -1901,25 +1981,32 @@ impl SessionState fn update_readiness(&mut self, token: Token, events: Ready) { if self.frontend_token == token { self.frontend_readiness.event |= events; - } else if self.backend_token == Some(token) { + } else if self + .origin + .as_ref() + .map(|origin| origin.token == token) + .unwrap_or(false) + { self.backend_readiness.event |= events; } } - fn close(&mut self, proxy: Rc>, metrics: &mut SessionMetrics) { - self.close_backend(proxy, metrics); - self.frontend_socket.socket_close(); - let _ = self.frontend_socket.socket_write_vectored(&[]); - + fn close(&mut self, proxy: Rc>, _metrics: &mut SessionMetrics) { //if the state was initial, the connection was already reset if !self.request_stream.is_initial() { gauge_add!("http.active_requests", -1); - if let Some(b) = self.backend.as_mut() { - let mut backend = b.borrow_mut(); - backend.active_requests = backend.active_requests.saturating_sub(1); + if let Some(origin) = &self.origin { + if origin.connected { + let mut backend = origin.backend.borrow_mut(); + backend.active_requests = backend.active_requests.saturating_sub(1); + } } } + + self.close_backend(proxy, false); + self.frontend_socket.socket_close(); + let _ = self.frontend_socket.socket_write_vectored(&[]); } fn timeout(&mut self, token: Token, metrics: &mut SessionMetrics) -> StateResult { @@ -1951,7 +2038,12 @@ impl SessionState }; } - if self.backend_token == Some(token) { + if self + .origin + .as_ref() + .map(|origin| origin.token == token) + .unwrap_or(false) + { //info!("backend timeout triggered for token {:?}", token); self.container_backend_timeout.triggered(); return match self.timeout_status() { @@ -1992,21 +2084,18 @@ impl SessionState } fn print_state(&self, context: &str) { + let kawa_back = match &self.response_stream { + ResponseStream::BackendAnswer(kawa) => format!("{:?}", kawa.parsing_phase), + ResponseStream::DefaultAnswer(status, ..) + | ResponseStream::DefaultAnswerKA(status, ..) => format!("DefaulAnswer({status})"), + ResponseStream::Swaping => unreachable!(), + }; error!( - "\ -{} {} Session(Kawa) -\tFrontend: -\t\ttoken: {:?}\treadiness: {:?}\tstate: {:?} -\tBackend: -\t\ttoken: {:?}\treadiness: {:?}", + "{} {} kawa_front={:?}, kawa_back={}", log_context!(self), context, - self.frontend_token, - self.frontend_readiness, self.request_stream.parsing_phase, - self.backend_token, - self.backend_readiness, - // self.response_stream.parsing_phase + kawa_back ); } diff --git a/lib/src/protocol/pipe.rs b/lib/src/protocol/pipe.rs index 63d9dffc0..cb1c7349c 100644 --- a/lib/src/protocol/pipe.rs +++ b/lib/src/protocol/pipe.rs @@ -4,7 +4,7 @@ use mio::{net::TcpStream, Token}; use rusty_ulid::Ulid; use sozu_command::{ config::MAX_LOOP_ITERATIONS, - logging::{EndpointRecord, LogContext}, + logging::{CachedTags, EndpointRecord, LogContext}, }; use crate::{ @@ -14,7 +14,7 @@ use crate::{ socket::{stats::socket_rtt, SocketHandler, SocketResult, TransportProtocol}, sozu_command::ready::Ready, timer::TimeoutContainer, - L7Proxy, ListenerHandler, Protocol, Readiness, SessionMetrics, SessionResult, StateResult, + L7Proxy, Protocol, Readiness, SessionMetrics, SessionResult, StateResult, }; /// This macro is defined uniquely in this module to help the tracking of pipelining @@ -59,7 +59,7 @@ pub enum WebSocketContext { Tcp, } -pub struct Pipe { +pub struct Pipe { backend_buffer: Checkout, backend_id: Option, pub backend_readiness: Readiness, @@ -75,14 +75,14 @@ pub struct Pipe { frontend_status: ConnectionStatus, frontend_token: Token, frontend: Front, - listener: Rc>, protocol: Protocol, request_id: Ulid, session_address: Option, + tags: Option>, websocket_context: WebSocketContext, } -impl Pipe { +impl Pipe { /// Instantiate a new Pipe SessionState with: /// /// - frontend_interest: READABLE | WRITABLE | HUP | ERROR @@ -103,12 +103,12 @@ impl Pipe { frontend_buffer: Checkout, frontend_token: Token, frontend: Front, - listener: Rc>, protocol: Protocol, request_id: Ulid, session_address: Option, websocket_context: WebSocketContext, - ) -> Pipe { + tags: Option>, + ) -> Pipe { let frontend_status = ConnectionStatus::Normal; let backend_status = if backend_socket.is_none() { ConnectionStatus::Closed @@ -138,10 +138,10 @@ impl Pipe { frontend_status, frontend_token, frontend, - listener, protocol, request_id, session_address, + tags, websocket_context, }; @@ -232,7 +232,6 @@ impl Pipe { } pub fn log_request(&self, metrics: &SessionMetrics, error: bool, message: Option<&str>) { - let listener = self.listener.borrow(); let context = self.log_context(); let endpoint = self.log_endpoint(); metrics.register_end_of_session(&context); @@ -245,7 +244,7 @@ impl Pipe { backend_address: self.get_backend_address(), protocol: self.protocol_string(), endpoint, - tags: listener.get_tags(&listener.get_addr().to_string()), + tags: self.tags.as_deref(), client_rtt: socket_rtt(self.front_socket()), server_rtt: self.backend_socket.as_ref().and_then(socket_rtt), service_time: metrics.service_time(), @@ -657,7 +656,7 @@ impl Pipe { } } -impl SessionState for Pipe { +impl SessionState for Pipe { fn ready( &mut self, _session: Rc>, diff --git a/lib/src/protocol/proxy_protocol/expect.rs b/lib/src/protocol/proxy_protocol/expect.rs index 031d1cdf9..fffab0fd3 100644 --- a/lib/src/protocol/proxy_protocol/expect.rs +++ b/lib/src/protocol/proxy_protocol/expect.rs @@ -3,7 +3,10 @@ use std::{cell::RefCell, rc::Rc}; use mio::{net::TcpStream, *}; use nom::{Err, HexDisplay}; use rusty_ulid::Ulid; -use sozu_command::{config::MAX_LOOP_ITERATIONS, logging::LogContext}; +use sozu_command::{ + config::MAX_LOOP_ITERATIONS, + logging::{CachedTags, LogContext}, +}; use crate::{ pool::Checkout, @@ -13,7 +16,6 @@ use crate::{ }, socket::{SocketHandler, SocketResult}, sozu_command::ready::Ready, - tcp::TcpListener, timer::TimeoutContainer, Protocol, Readiness, SessionMetrics, StateResult, }; @@ -166,8 +168,8 @@ impl ExpectProxyProtocol { back_buf: Checkout, backend_socket: Option, backend_token: Option, - listener: Rc>, - ) -> Pipe { + tags: Option>, + ) -> Pipe { let addr = self.front_socket().peer_addr().ok(); let mut pipe = Pipe::new( @@ -181,11 +183,11 @@ impl ExpectProxyProtocol { front_buf, self.frontend_token, self.frontend, - listener, Protocol::TCP, self.request_id, addr, WebSocketContext::Tcp, + tags, ); pipe.frontend_readiness.event = self.frontend_readiness.event; diff --git a/lib/src/protocol/proxy_protocol/relay.rs b/lib/src/protocol/proxy_protocol/relay.rs index 594ea3819..c2705c746 100644 --- a/lib/src/protocol/proxy_protocol/relay.rs +++ b/lib/src/protocol/proxy_protocol/relay.rs @@ -1,8 +1,9 @@ -use std::{cell::RefCell, io::Write, rc::Rc}; +use std::{io::Write, rc::Rc}; use mio::{net::TcpStream, Token}; use nom::{Err, Offset}; use rusty_ulid::Ulid; +use sozu_command::logging::CachedTags; use crate::{ pool::Checkout, @@ -12,7 +13,6 @@ use crate::{ }, socket::{SocketHandler, SocketResult}, sozu_command::ready::Ready, - tcp::TcpListener, Protocol, Readiness, SessionMetrics, SessionResult, }; @@ -172,11 +172,7 @@ impl RelayProxyProtocol { self.backend_token = Some(token); } - pub fn into_pipe( - mut self, - back_buf: Checkout, - listener: Rc>, - ) -> Pipe { + pub fn into_pipe(mut self, back_buf: Checkout, tags: Option>) -> Pipe { let backend_socket = self.backend.take().unwrap(); let addr = self.front_socket().peer_addr().ok(); @@ -191,11 +187,11 @@ impl RelayProxyProtocol { self.frontend_buffer, self.frontend_token, self.frontend, - listener, Protocol::TCP, self.request_id, addr, WebSocketContext::Tcp, + tags, ); pipe.frontend_readiness.event = self.frontend_readiness.event; diff --git a/lib/src/protocol/proxy_protocol/send.rs b/lib/src/protocol/proxy_protocol/send.rs index 061982509..7ed666f65 100644 --- a/lib/src/protocol/proxy_protocol/send.rs +++ b/lib/src/protocol/proxy_protocol/send.rs @@ -1,11 +1,11 @@ use std::{ - cell::RefCell, io::{ErrorKind, Write}, rc::Rc, }; use mio::{net::TcpStream, Token}; use rusty_ulid::Ulid; +use sozu_command::logging::CachedTags; use crate::{ pool::Checkout, @@ -15,7 +15,6 @@ use crate::{ }, socket::SocketHandler, sozu_command::ready::Ready, - tcp::TcpListener, BackendConnectionStatus, Protocol, Readiness, SessionMetrics, SessionResult, }; @@ -156,8 +155,8 @@ impl SendProxyProtocol { mut self, front_buf: Checkout, back_buf: Checkout, - listener: Rc>, - ) -> Pipe { + tags: Option>, + ) -> Pipe { let backend_socket = self.backend.take().unwrap(); let addr = self.front_socket().peer_addr().ok(); @@ -172,11 +171,11 @@ impl SendProxyProtocol { front_buf, self.frontend_token, self.frontend, - listener, Protocol::TCP, self.request_id, addr, WebSocketContext::Tcp, + tags, ); pipe.frontend_readiness = self.frontend_readiness; diff --git a/lib/src/router/mod.rs b/lib/src/router/mod.rs index 5f91baa7c..f6727c424 100644 --- a/lib/src/router/mod.rs +++ b/lib/src/router/mod.rs @@ -1,11 +1,22 @@ pub mod pattern_trie; -use std::{str::from_utf8, time::Instant}; +use std::{ + fmt::{Debug, Write}, + rc::Rc, + str::{from_utf8, from_utf8_unchecked}, + time::Instant, +}; +use nom::AsChar; +use pattern_trie::{TrieMatches, TrieSubMatch}; use regex::bytes::Regex; use sozu_command::{ - proto::command::{PathRule as CommandPathRule, PathRuleKind, RulePosition}, + logging::CachedTags, + proto::command::{ + HeaderPosition, PathRule as CommandPathRule, PathRuleKind, RedirectPolicy, RedirectScheme, + RulePosition, + }, response::HttpFrontend, state::ClusterId, }; @@ -16,8 +27,12 @@ use crate::{protocol::http::parser::Method, router::pattern_trie::TrieNode}; pub enum RouterError { #[error("Could not parse rule from frontend path {0:?}")] InvalidPathRule(String), - #[error("parsing hostname {hostname} failed")] - InvalidDomain { hostname: String }, + #[error("Could not parse hostname {0:?}")] + InvalidDomain(String), + #[error("Could not parse host rewrite {0:?}")] + InvalidHostRewrite(String), + #[error("Could not parse path rewrite {0:?}")] + InvalidPathRewrite(String), #[error("Could not add route {0}")] AddRoute(String), #[error("Could not remove route {0}")] @@ -31,9 +46,9 @@ pub enum RouterError { } pub struct Router { - pre: Vec<(DomainRule, PathRule, MethodRule, Route)>, - pub tree: TrieNode>, - post: Vec<(DomainRule, PathRule, MethodRule, Route)>, + pre: Vec<(DomainRule, PathRule, MethodRule, Frontend)>, + pub tree: TrieNode>, + post: Vec<(DomainRule, PathRule, MethodRule, Frontend)>, } impl Default for Router { @@ -51,35 +66,46 @@ impl Router { } } - pub fn lookup( + pub fn lookup<'a>( &self, - hostname: &str, - path: &str, - method: &Method, - ) -> Result { + hostname: &'a str, + path: &'a str, + method: &'a Method, + ) -> Result { let hostname_b = hostname.as_bytes(); let path_b = path.as_bytes(); - for (domain_rule, path_rule, method_rule, cluster_id) in &self.pre { + for (domain_rule, path_rule, method_rule, route) in &self.pre { if domain_rule.matches(hostname_b) && path_rule.matches(path_b) != PathRuleResult::None && method_rule.matches(method) != MethodRuleResult::None { - return Ok(cluster_id.clone()); + return Ok(RouteResult::new_no_trie( + hostname_b, + domain_rule, + path_b, + path_rule, + route, + )); } } - if let Some((_, path_rules)) = self.tree.lookup(hostname_b, true) { + let trie_path = Vec::with_capacity(16); + if let Some(((_, rules), trie_path)) = self.tree.lookup(hostname_b, true, trie_path) { let mut prefix_length = 0; - let mut route = None; + let mut frontend = None; - for (rule, method_rule, cluster_id) in path_rules { - match rule.matches(path_b) { + for (path_rule, method_rule, route) in rules { + match path_rule.matches(path_b) { PathRuleResult::Regex | PathRuleResult::Equals => { match method_rule.matches(method) { - MethodRuleResult::Equals => return Ok(cluster_id.clone()), + MethodRuleResult::Equals => { + return Ok(RouteResult::new_with_trie( + hostname_b, trie_path, path_b, path_rule, route, + )) + } MethodRuleResult::All => { prefix_length = path_b.len(); - route = Some(cluster_id); + frontend = Some((path_rule, route)); } MethodRuleResult::None => {} } @@ -90,11 +116,11 @@ impl Router { // FIXME: the rule order will be important here MethodRuleResult::Equals => { prefix_length = size; - route = Some(cluster_id); + frontend = Some((path_rule, route)); } MethodRuleResult::All => { prefix_length = size; - route = Some(cluster_id); + frontend = Some((path_rule, route)); } MethodRuleResult::None => {} } @@ -104,17 +130,25 @@ impl Router { } } - if let Some(cluster_id) = route { - return Ok(cluster_id.clone()); + if let Some((path_rule, route)) = frontend { + return Ok(RouteResult::new_with_trie( + hostname_b, trie_path, path_b, path_rule, route, + )); } } - for (domain_rule, path_rule, method_rule, cluster_id) in self.post.iter() { + for (domain_rule, path_rule, method_rule, route) in self.post.iter() { if domain_rule.matches(hostname_b) && path_rule.matches(path_b) != PathRuleResult::None && method_rule.matches(method) != MethodRuleResult::None { - return Ok(cluster_id.clone()); + return Ok(RouteResult::new_no_trie( + hostname_b, + domain_rule, + path_b, + path_rule, + route, + )); } } @@ -126,34 +160,22 @@ impl Router { } pub fn add_http_front(&mut self, front: &HttpFrontend) -> Result<(), RouterError> { + let domain_rule = front + .hostname + .parse::() + .map_err(|_| RouterError::InvalidDomain(front.hostname.clone()))?; + let path_rule = PathRule::from_config(front.path.clone()) .ok_or(RouterError::InvalidPathRule(front.path.to_string()))?; let method_rule = MethodRule::new(front.method.clone()); - let route = match &front.cluster_id { - Some(cluster_id) => Route::ClusterId(cluster_id.clone()), - None => Route::Deny, - }; + let route = Frontend::new(&domain_rule, &path_rule, front)?; let success = match front.position { - RulePosition::Pre => { - let domain = front.hostname.parse::().map_err(|_| { - RouterError::InvalidDomain { - hostname: front.hostname.clone(), - } - })?; - - self.add_pre_rule(&domain, &path_rule, &method_rule, &route) - } + RulePosition::Pre => self.add_pre_rule(&domain_rule, &path_rule, &method_rule, &route), RulePosition::Post => { - let domain = front.hostname.parse::().map_err(|_| { - RouterError::InvalidDomain { - hostname: front.hostname.clone(), - } - })?; - - self.add_post_rule(&domain, &path_rule, &method_rule, &route) + self.add_post_rule(&domain_rule, &path_rule, &method_rule, &route) } RulePosition::Tree => { self.add_tree_rule(front.hostname.as_bytes(), &path_rule, &method_rule, &route) @@ -173,22 +195,20 @@ impl Router { let remove_success = match front.position { RulePosition::Pre => { - let domain = front.hostname.parse::().map_err(|_| { - RouterError::InvalidDomain { - hostname: front.hostname.clone(), - } - })?; + let domain_rule = front + .hostname + .parse::() + .map_err(|_| RouterError::InvalidDomain(front.hostname.clone()))?; - self.remove_pre_rule(&domain, &path_rule, &method_rule) + self.remove_pre_rule(&domain_rule, &path_rule, &method_rule) } RulePosition::Post => { - let domain = front.hostname.parse::().map_err(|_| { - RouterError::InvalidDomain { - hostname: front.hostname.clone(), - } - })?; + let domain_rule = front + .hostname + .parse::() + .map_err(|_| RouterError::InvalidDomain(front.hostname.clone()))?; - self.remove_post_rule(&domain, &path_rule, &method_rule) + self.remove_post_rule(&domain_rule, &path_rule, &method_rule) } RulePosition::Tree => { self.remove_tree_rule(front.hostname.as_bytes(), &path_rule, &method_rule) @@ -205,7 +225,7 @@ impl Router { hostname: &[u8], path: &PathRule, method: &MethodRule, - cluster: &Route, + cluster: &Frontend, ) -> bool { let hostname = match from_utf8(hostname) { Err(_) => return false, @@ -282,7 +302,7 @@ impl Router { domain: &DomainRule, path: &PathRule, method: &MethodRule, - cluster_id: &Route, + cluster_id: &Frontend, ) -> bool { if !self .pre @@ -306,7 +326,7 @@ impl Router { domain: &DomainRule, path: &PathRule, method: &MethodRule, - cluster_id: &Route, + cluster_id: &Frontend, ) -> bool { if !self .post @@ -367,13 +387,13 @@ impl Router { #[derive(Clone, Debug)] pub enum DomainRule { Any, - Exact(String), + Equals(String), Wildcard(String), Regex(Regex), } fn convert_regex_domain_rule(hostname: &str) -> Option { - let mut result = String::new(); + let mut result = String::from("\\A"); let s = hostname.as_bytes(); let mut index = 0; @@ -388,6 +408,7 @@ fn convert_regex_domain_rule(hostname: &str) -> Option { } index = i + 1; found = true; + break; } } @@ -415,6 +436,7 @@ fn convert_regex_domain_rule(hostname: &str) -> Option { } if index == s.len() { + result.push_str("\\z"); return Some(result); } else if s[index] == b'.' { result.push_str("\\."); @@ -434,7 +456,7 @@ impl DomainRule { hostname.ends_with(s[1..].as_bytes()) && !&hostname[..len_without_suffix].contains(&b'.') } - DomainRule::Exact(s) => s.as_bytes() == hostname, + DomainRule::Equals(s) => s.as_bytes() == hostname, DomainRule::Regex(r) => { let start = Instant::now(); let is_a_match = r.is_match(hostname); @@ -451,7 +473,7 @@ impl std::cmp::PartialEq for DomainRule { match (self, other) { (DomainRule::Any, DomainRule::Any) => true, (DomainRule::Wildcard(s1), DomainRule::Wildcard(s2)) => s1 == s2, - (DomainRule::Exact(s1), DomainRule::Exact(s2)) => s1 == s2, + (DomainRule::Equals(s1), DomainRule::Equals(s2)) => s1 == s2, (DomainRule::Regex(r1), DomainRule::Regex(r2)) => r1.as_str() == r2.as_str(), _ => false, } @@ -483,7 +505,7 @@ impl std::str::FromStr for DomainRule { } } else { match ::idna::domain_to_ascii(s) { - Ok(r) => DomainRule::Exact(r), + Ok(r) => DomainRule::Equals(r), Err(_) => return Err(()), } }) @@ -497,7 +519,7 @@ pub enum PathRule { Equals(String), } -#[derive(PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum PathRuleResult { Regex, Prefix(usize), @@ -590,13 +612,458 @@ impl MethodRule { } } -/// The cluster to which the traffic will be redirected -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub enum Route { - /// send a 401 default answer - Deny, - /// the cluster to which the frontend belongs - ClusterId(ClusterId), +#[derive(Debug, Clone)] +enum RewritePart { + String(String), + Host(usize), + Path(usize), +} +impl RewritePart { + pub fn string(s: &str) -> Self { + Self::String(String::from(s)) + } + pub fn bytes(b: &[u8]) -> Self { + Self::String(unsafe { String::from_utf8_unchecked(b.to_vec()) }) + } +} + +#[derive(Debug, Clone)] +pub struct RewriteParts(Vec); +impl RewriteParts { + pub fn new( + pattern: &str, + index_max_host: usize, + index_max_path: usize, + used_index_host: &mut usize, + used_index_path: &mut usize, + ) -> Option { + let mut result = Vec::new(); + let mut i = 0; + let pattern = pattern.as_bytes(); + while i < pattern.len() { + if pattern[i] == b'$' { + let is_host = if pattern[i..].starts_with(b"$HOST[") { + i += 6; + true + } else if pattern[i..].starts_with(b"$PATH[") { + i += 6; + false + } else { + return None; + }; + let mut index = 0; + while i < pattern.len() && pattern[i].is_dec_digit() { + index = index * 10 + (pattern[i] - b'0') as usize; + i += 1; + } + if i >= pattern.len() || pattern[i] != b']' { + return None; + } + if is_host { + if index >= index_max_host { + return None; + } + if index >= *used_index_host { + *used_index_host = index + 1; + } + result.push(RewritePart::Host(index)); + } else { + if index >= index_max_path { + return None; + } + if index >= *used_index_path { + *used_index_path = index + 1; + } + result.push(RewritePart::Path(index)); + } + i += 1; + } else { + let start = i; + while i < pattern.len() && pattern[i] != b'$' { + i += 1; + } + result.push(RewritePart::bytes(&pattern[start..i])); + } + } + Some(Self(result)) + } + pub fn run(&self, host_captures: &[&str], path_captures: &[&str]) -> String { + let mut cap = 0; + for part in &self.0 { + cap += match part { + RewritePart::String(s) => s.len(), + RewritePart::Host(i) => unsafe { host_captures.get_unchecked(*i).len() }, + RewritePart::Path(i) => unsafe { path_captures.get_unchecked(*i).len() }, + }; + } + let mut result = String::with_capacity(cap); + for part in &self.0 { + let _ = match part { + RewritePart::String(s) => result.write_str(s), + RewritePart::Host(i) => { + result.write_str(unsafe { host_captures.get_unchecked(*i) }) + } + RewritePart::Path(i) => { + result.write_str(unsafe { path_captures.get_unchecked(*i) }) + } + }; + } + result + } +} + +#[derive(Clone, PartialEq, Eq)] +pub struct HeaderEdit { + pub key: Rc<[u8]>, + pub val: Rc<[u8]>, +} + +impl Debug for HeaderEdit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!( + "({:?}, {:?})", + String::from_utf8_lossy(&self.key), + String::from_utf8_lossy(&self.val) + )) + } +} + +/// What to do with the traffic +/// TODO: tags should be moved here +#[derive(Debug, Clone)] +pub struct Frontend { + cluster_id: Option, + required_auth: bool, + redirect: RedirectPolicy, + redirect_scheme: RedirectScheme, + redirect_template: Option, + capture_cap_host: usize, + capture_cap_path: usize, + rewrite_host: Option, + rewrite_path: Option, + rewrite_port: Option, + headers_request: Rc<[HeaderEdit]>, + headers_response: Rc<[HeaderEdit]>, + tags: Option>, +} + +impl Frontend { + pub fn new( + domain_rule: &DomainRule, + path_rule: &PathRule, + front: &HttpFrontend, + ) -> Result { + let cluster_id = front.cluster_id.clone(); + let required_auth = front.required_auth; + let rewrite_port = front.rewrite_port; + let rewrite_path = front.rewrite_path.clone(); + let rewrite_host = front.rewrite_host.clone(); + let redirect = front.redirect; + let redirect_scheme = front.redirect_scheme; + let redirect_template = front.redirect_template.clone(); + let headers = &front.headers; + let tags = front + .tags + .clone() + .map(|tags| Rc::new(CachedTags::new(tags))); + + let deny = match (&cluster_id, redirect, &redirect_template, required_auth) { + (_, RedirectPolicy::Unauthorized, _, false) => true, + (_, RedirectPolicy::Unauthorized, _, true) => { + warn!("Frontend[cluster: {:?}, domain: {:?}, path: {:?}, redirect: {:?}]: unauthorized frontends ignore auth", cluster_id, domain_rule, path_rule, redirect); + true + } + (None, RedirectPolicy::Forward, None, _) => { + warn!("Frontend[domain: {:?}, path: {:?}]: forward on clusterless frontends are unauthorized", domain_rule, path_rule); + true + } + (None, _, _, true) => { + warn!( + "Frontend[domain: {:?}, path: {:?}]: clusterless frontends ignore auth", + domain_rule, path_rule + ); + true + } + _ => false, + }; + if deny { + return Ok(Self { + cluster_id: cluster_id.clone(), + required_auth, + redirect: RedirectPolicy::Unauthorized, + redirect_scheme, + redirect_template: None, + capture_cap_host: 0, + capture_cap_path: 0, + rewrite_host: None, + rewrite_path: None, + rewrite_port: None, + headers_request: Rc::new([]), + headers_response: Rc::new([]), + tags: None, + }); + } + let mut capture_cap_host = match domain_rule { + DomainRule::Any => 1, + DomainRule::Equals(_) => 1, + DomainRule::Wildcard(_) => 2, + DomainRule::Regex(regex) => regex.captures_len(), + }; + let mut capture_cap_path = match path_rule { + PathRule::Equals(_) => 1, + PathRule::Prefix(_) => 2, + PathRule::Regex(regex) => regex.captures_len(), + }; + let mut used_capture_host = 0; + let mut used_capture_path = 0; + let rewrite_host = if let Some(p) = rewrite_host { + Some( + RewriteParts::new( + &p, + capture_cap_host, + capture_cap_path, + &mut used_capture_host, + &mut used_capture_path, + ) + .ok_or(RouterError::InvalidHostRewrite(p))?, + ) + } else { + None + }; + let rewrite_path = if let Some(p) = rewrite_path { + Some( + RewriteParts::new( + &p, + capture_cap_host, + capture_cap_path, + &mut used_capture_host, + &mut used_capture_path, + ) + .ok_or(RouterError::InvalidPathRewrite(p))?, + ) + } else { + None + }; + if used_capture_host == 0 { + capture_cap_host = 0; + } + if used_capture_path == 0 { + capture_cap_path = 0; + } + let mut headers_request = Vec::new(); + let mut headers_response = Vec::new(); + for header in headers { + let edit = HeaderEdit { + key: header.key.clone().into_bytes().into(), + val: header.val.clone().into_bytes().into(), + }; + match header.position() { + HeaderPosition::Request => headers_request.push(edit), + HeaderPosition::Response => headers_response.push(edit), + HeaderPosition::Both => { + headers_request.push(edit.clone()); + headers_response.push(edit); + } + } + } + Ok(Frontend { + cluster_id, + required_auth, + redirect, + redirect_scheme, + redirect_template, + capture_cap_host, + capture_cap_path, + rewrite_host, + rewrite_path, + rewrite_port, + headers_request: headers_request.into(), + headers_response: headers_response.into(), + tags, + }) + } + + #[cfg(test)] + pub fn forward(cluster_id: ClusterId) -> Self { + Self { + cluster_id: Some(cluster_id), + required_auth: false, + redirect: RedirectPolicy::Forward, + redirect_scheme: RedirectScheme::UseSame, + redirect_template: None, + capture_cap_host: 0, + capture_cap_path: 0, + rewrite_host: None, + rewrite_path: None, + rewrite_port: None, + headers_request: Rc::new([]), + headers_response: Rc::new([]), + tags: None, + } + } +} + +#[derive(Debug, Clone)] +pub struct RouteResult { + pub cluster_id: Option, + pub required_auth: bool, + pub redirect: RedirectPolicy, + pub redirect_scheme: RedirectScheme, + pub redirect_template: Option, + pub rewritten_host: Option, + pub rewritten_path: Option, + pub rewritten_port: Option, + pub headers_request: Rc<[HeaderEdit]>, + pub headers_response: Rc<[HeaderEdit]>, + pub tags: Option>, +} + +impl RouteResult { + fn deny(cluster_id: &Option) -> Self { + Self { + cluster_id: cluster_id.clone(), + required_auth: false, + redirect: RedirectPolicy::Unauthorized, + redirect_scheme: RedirectScheme::UseSame, + redirect_template: None, + rewritten_host: None, + rewritten_path: None, + rewritten_port: None, + headers_request: Rc::new([]), + headers_response: Rc::new([]), + tags: None, + } + } + + fn new<'a>( + captures_host: Vec<&'a str>, + path: &'a [u8], + path_rule: &PathRule, + route: &Frontend, + ) -> Self { + let Frontend { + cluster_id, + required_auth, + redirect, + redirect_scheme, + redirect_template, + capture_cap_path, + rewrite_host, + rewrite_path, + rewrite_port, + headers_request, + headers_response, + tags, + .. + } = route; + let mut captures_path = Vec::with_capacity(*capture_cap_path); + if *capture_cap_path > 0 { + captures_path.push(unsafe { from_utf8_unchecked(path) }); + match path_rule { + PathRule::Prefix(prefix) => { + captures_path.push(unsafe { from_utf8_unchecked(&path[prefix.len()..]) }) + } + PathRule::Regex(regex) => { + captures_path.extend(regex.captures(path).unwrap().iter().skip(1).map(|c| { + c.map(|c| unsafe { from_utf8_unchecked(c.as_bytes()) }) + .unwrap_or("") + })) + } + _ => {} + } + } + // println!("========HOST_CAPTURES: {captures_host:?}"); + // println!("========PATH_CAPTURES: {captures_path:?}"); + Self { + cluster_id: cluster_id.clone(), + required_auth: *required_auth, + redirect: *redirect, + redirect_scheme: *redirect_scheme, + redirect_template: redirect_template.clone(), + rewritten_host: rewrite_host + .as_ref() + .map(|rewrite| rewrite.run(&captures_host, &captures_path)), + rewritten_path: rewrite_path + .as_ref() + .map(|rewrite| rewrite.run(&captures_host, &captures_path)), + rewritten_port: *rewrite_port, + headers_request: headers_request.clone(), + headers_response: headers_response.clone(), + tags: tags.clone(), + } + } + fn new_no_trie<'a>( + domain: &'a [u8], + domain_rule: &DomainRule, + path: &'a [u8], + path_rule: &PathRule, + route: &Frontend, + ) -> Self { + let Frontend { + cluster_id, + redirect, + capture_cap_host, + .. + } = route; + if *redirect == RedirectPolicy::Unauthorized { + return Self::deny(cluster_id); + } + let mut captures_host = Vec::with_capacity(*capture_cap_host); + if *capture_cap_host > 0 { + captures_host.push(unsafe { from_utf8_unchecked(domain) }); + match domain_rule { + DomainRule::Wildcard(suffix) => captures_host + .push(unsafe { from_utf8_unchecked(&domain[..domain.len() - suffix.len()]) }), + DomainRule::Regex(regex) => captures_host.extend( + regex + .captures(domain) + .unwrap() + .iter() + .skip(1) + .map(|c| unsafe { from_utf8_unchecked(c.unwrap().as_bytes()) }), + ), + _ => {} + } + } + Self::new(captures_host, path, path_rule, route) + } + fn new_with_trie<'a>( + domain: &'a [u8], + domain_submatches: TrieMatches<'_, 'a>, + path: &'a [u8], + path_rule: &PathRule, + route: &Frontend, + ) -> Self { + let Frontend { + cluster_id, + redirect, + capture_cap_host, + .. + } = route; + if *redirect == RedirectPolicy::Unauthorized { + return Self::deny(cluster_id); + } + let mut captures_host = Vec::with_capacity(*capture_cap_host); + if *capture_cap_host > 0 { + captures_host.push(unsafe { from_utf8_unchecked(domain) }); + for submatch in domain_submatches { + match submatch { + TrieSubMatch::Wildcard(part) => { + captures_host.push(unsafe { from_utf8_unchecked(part) }) + } + TrieSubMatch::Regexp(part, regex) => captures_host.extend( + regex + .captures(part) + .unwrap() + .iter() + .skip(1) + .map(|c| unsafe { from_utf8_unchecked(c.unwrap().as_bytes()) }), + ), + } + } + } + Self::new(captures_host, path, path_rule, route) + } } #[cfg(test)] @@ -643,7 +1110,7 @@ mod tests { assert_eq!("*".parse::().unwrap(), DomainRule::Any); assert_eq!( "www.example.com".parse::().unwrap(), - DomainRule::Exact("www.example.com".to_string()) + DomainRule::Equals("www.example.com".to_string()) ); assert_eq!( "*.example.com".parse::().unwrap(), @@ -660,7 +1127,7 @@ mod tests { fn match_domain_rule() { assert!(DomainRule::Any.matches("www.example.com".as_bytes())); assert!( - DomainRule::Exact("www.example.com".to_string()).matches("www.example.com".as_bytes()) + DomainRule::Equals("www.example.com".to_string()).matches("www.example.com".as_bytes()) ); assert!( DomainRule::Wildcard("*.example.com".to_string()).matches("www.example.com".as_bytes()) @@ -715,27 +1182,33 @@ mod tests { b"*.sozu.io", &PathRule::Prefix("".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("base".to_string()) + &Frontend::forward("base".to_string()) )); println!("{:#?}", router.tree); assert_eq!( - router.lookup("www.sozu.io", "/api", &Method::Get), - Ok(Route::ClusterId("base".to_string())) + router + .lookup("www.sozu.io", "/api", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("base".to_string())) ); assert!(router.add_tree_rule( b"*.sozu.io", &PathRule::Prefix("/api".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("api".to_string()) + &Frontend::forward("api".to_string()) )); println!("{:#?}", router.tree); assert_eq!( - router.lookup("www.sozu.io", "/ap", &Method::Get), - Ok(Route::ClusterId("base".to_string())) + router + .lookup("www.sozu.io", "/ap", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("base".to_string())) ); assert_eq!( - router.lookup("www.sozu.io", "/api", &Method::Get), - Ok(Route::ClusterId("api".to_string())) + router + .lookup("www.sozu.io", "/api", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("api".to_string())) ); } @@ -754,27 +1227,33 @@ mod tests { b"*.sozu.io", &PathRule::Prefix("".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("base".to_string()) + &Frontend::forward("base".to_string()) )); println!("{:#?}", router.tree); assert_eq!( - router.lookup("www.sozu.io", "/api", &Method::Get), - Ok(Route::ClusterId("base".to_string())) + router + .lookup("www.sozu.io", "/api", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("base".to_string())) ); assert!(router.add_tree_rule( b"api.sozu.io", &PathRule::Prefix("".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("api".to_string()) + &Frontend::forward("api".to_string()) )); println!("{:#?}", router.tree); assert_eq!( - router.lookup("www.sozu.io", "/api", &Method::Get), - Ok(Route::ClusterId("base".to_string())) + router + .lookup("www.sozu.io", "/api", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("base".to_string())) ); assert_eq!( - router.lookup("api.sozu.io", "/api", &Method::Get), - Ok(Route::ClusterId("api".to_string())) + router + .lookup("api.sozu.io", "/api", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("api".to_string())) ); } @@ -786,23 +1265,27 @@ mod tests { b"www./.*/.io", &PathRule::Prefix("".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("base".to_string()) + &Frontend::forward("base".to_string()) )); println!("{:#?}", router.tree); assert!(router.add_tree_rule( b"www.doc./.*/.io", &PathRule::Prefix("".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("doc".to_string()) + &Frontend::forward("doc".to_string()) )); println!("{:#?}", router.tree); assert_eq!( - router.lookup("www.sozu.io", "/", &Method::Get), - Ok(Route::ClusterId("base".to_string())) + router + .lookup("www.sozu.io", "/", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("base".to_string())) ); assert_eq!( - router.lookup("www.doc.sozu.io", "/", &Method::Get), - Ok(Route::ClusterId("doc".to_string())) + router + .lookup("www.doc.sozu.io", "/", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("doc".to_string())) ); assert!(router.remove_tree_rule( b"www./.*/.io", @@ -812,8 +1295,10 @@ mod tests { println!("{:#?}", router.tree); assert!(router.lookup("www.sozu.io", "/", &Method::Get).is_err()); assert_eq!( - router.lookup("www.doc.sozu.io", "/", &Method::Get), - Ok(Route::ClusterId("doc".to_string())) + router + .lookup("www.doc.sozu.io", "/", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("doc".to_string())) ); } @@ -825,53 +1310,57 @@ mod tests { &"*".parse::().unwrap(), &PathRule::Prefix("/.well-known/acme-challenge".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("acme".to_string()) + &Frontend::forward("acme".to_string()) )); assert!(router.add_tree_rule( "www.example.com".as_bytes(), &PathRule::Prefix("/".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("example".to_string()) + &Frontend::forward("example".to_string()) )); assert!(router.add_tree_rule( "*.test.example.com".as_bytes(), &PathRule::Regex(Regex::new("/hello[A-Z]+/").unwrap()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("examplewildcard".to_string()) + &Frontend::forward("examplewildcard".to_string()) )); assert!(router.add_tree_rule( "/test[0-9]/.example.com".as_bytes(), &PathRule::Prefix("/".to_string()), &MethodRule::new(Some("GET".to_string())), - &Route::ClusterId("exampleregex".to_string()) + &Frontend::forward("exampleregex".to_string()) )); assert_eq!( - router.lookup("www.example.com", "/helloA", &Method::new(&b"GET"[..])), - Ok(Route::ClusterId("example".to_string())) + router + .lookup("www.example.com", "/helloA", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("example".to_string())) ); assert_eq!( - router.lookup( - "www.example.com", - "/.well-known/acme-challenge", - &Method::new(&b"GET"[..]) - ), - Ok(Route::ClusterId("acme".to_string())) + router + .lookup( + "www.example.com", + "/.well-known/acme-challenge", + &Method::Get + ) + .map(|r| r.cluster_id), + Ok(Some("acme".to_string())) ); assert!(router - .lookup("www.test.example.com", "/", &Method::new(&b"GET"[..])) + .lookup("www.test.example.com", "/", &Method::Get) .is_err()); assert_eq!( - router.lookup( - "www.test.example.com", - "/helloAB/", - &Method::new(&b"GET"[..]) - ), - Ok(Route::ClusterId("examplewildcard".to_string())) + router + .lookup("www.test.example.com", "/helloAB/", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("examplewildcard".to_string())) ); assert_eq!( - router.lookup("test1.example.com", "/helloAB/", &Method::new(&b"GET"[..])), - Ok(Route::ClusterId("exampleregex".to_string())) + router + .lookup("test1.example.com", "/helloAB/", &Method::Get) + .map(|r| r.cluster_id), + Ok(Some("exampleregex".to_string())) ); } } diff --git a/lib/src/router/pattern_trie.rs b/lib/src/router/pattern_trie.rs index f25a46962..90e05b79e 100644 --- a/lib/src/router/pattern_trie.rs +++ b/lib/src/router/pattern_trie.rs @@ -4,6 +4,12 @@ use regex::bytes::Regex; pub type Key = Vec; pub type KeyValue = (K, V); +pub type TrieMatches<'a, 'b> = Vec>; + +pub enum TrieSubMatch<'a, 'b> { + Wildcard(&'a [u8]), + Regexp(&'a [u8], &'b Regex), +} #[derive(Debug, PartialEq, Eq)] pub enum InsertResult { @@ -126,7 +132,8 @@ impl TrieNode { } } - if let Ok(r) = Regex::new(s) { + let s = format!("\\A{s}\\z"); + if let Ok(r) = Regex::new(&s) { if pos > 0 { let mut node = TrieNode::root(); let pos = pos - 1; @@ -263,11 +270,16 @@ impl TrieNode { } } - pub fn lookup(&self, partial_key: &[u8], accept_wildcard: bool) -> Option<&KeyValue> { + pub fn lookup<'a, 'b>( + &'a self, + partial_key: &'b [u8], + accept_wildcard: bool, + mut path: TrieMatches<'a, 'b>, + ) -> Option<(&'a KeyValue, TrieMatches<'a, 'b>)> { //println!("lookup: key == {}", std::str::from_utf8(partial_key).unwrap()); if partial_key.is_empty() { - return self.key_value.as_ref(); + return self.key_value.as_ref().map(|kv| (kv, path)); } let pos = find_last_dot(partial_key); @@ -278,27 +290,29 @@ impl TrieNode { //println!("lookup: prefix|suffix: {} | {}", std::str::from_utf8(prefix).unwrap(), std::str::from_utf8(suffix).unwrap()); match self.children.get(suffix) { - Some(child) => child.lookup(prefix, accept_wildcard), + Some(child) => child.lookup(prefix, accept_wildcard, path), None => { //println!("no child found, testing wildcard and regexps"); if prefix.is_empty() && self.wildcard.is_some() && accept_wildcard { //println!("no dot, wildcard applies"); - self.wildcard.as_ref() + path.insert(0, TrieSubMatch::Wildcard(suffix)); + self.wildcard.as_ref().map(|kv| (kv, path)) } else { //println!("there's still a subdomain, wildcard does not apply"); - for (ref regexp, ref child) in self.regexps.iter() { - let suffix = if suffix[0] == b'.' { - &suffix[1..] - } else { - suffix - }; + let suffix = if suffix[0] == b'.' { + &suffix[1..] + } else { + suffix + }; + for (regexp, child) in &self.regexps { //println!("testing regexp: {} on suffix {}", r.as_str(), str::from_utf8(s).unwrap()); if regexp.is_match(suffix) { //println!("matched"); - return child.lookup(prefix, accept_wildcard); + path.insert(0, TrieSubMatch::Regexp(suffix, regexp)); + return child.lookup(prefix, accept_wildcard, path); } } @@ -361,12 +375,12 @@ impl TrieNode { } else { //println!("there's still a subdomain, wildcard does not apply"); - for (ref regexp, ref mut child) in self.regexps.iter_mut() { - let suffix = if suffix[0] == b'.' { - &suffix[1..] - } else { - suffix - }; + let suffix = if suffix[0] == b'.' { + &suffix[1..] + } else { + suffix + }; + for (regexp, child) in &mut self.regexps { //println!("testing regexp: {} on suffix {}", r.as_str(), str::from_utf8(s).unwrap()); if regexp.is_match(suffix) { @@ -421,7 +435,8 @@ impl TrieNode { } pub fn domain_lookup(&self, key: &[u8], accept_wildcard: bool) -> Option<&KeyValue> { - self.lookup(key, accept_wildcard) + let path = Vec::new(); + self.lookup(key, accept_wildcard, path).map(|(kv, _)| kv) } pub fn domain_lookup_mut( @@ -809,7 +824,7 @@ mod tests { } //match root.domain_lookup(k.as_bytes()) { - match root.lookup(k.as_bytes(), false) { + match root.domain_lookup(k.as_bytes(), false) { None => { println!("did not find key '{k}'"); return false; diff --git a/lib/src/tcp.rs b/lib/src/tcp.rs index ea6955e86..90807f273 100644 --- a/lib/src/tcp.rs +++ b/lib/src/tcp.rs @@ -44,7 +44,7 @@ use crate::{ }, timer::TimeoutContainer, AcceptError, BackendConnectAction, BackendConnectionError, BackendConnectionStatus, CachedTags, - ListenerError, ListenerHandler, Protocol, ProxyConfiguration, ProxyError, ProxySession, + L4ListenerHandler, ListenerError, Protocol, ProxyConfiguration, ProxyError, ProxySession, Readiness, SessionIsToBeClosed, SessionMetrics, SessionResult, StateMachineBuilder, }; @@ -54,7 +54,7 @@ StateMachineBuilder! { /// 1. optional (ExpectProxyProtocol | SendProxyProtocol | RelayProxyProtocol) /// 2. Pipe enum TcpStateMachine { - Pipe(Pipe), + Pipe(Pipe), SendProxyProtocol(SendProxyProtocol), RelayProxyProtocol(RelayProxyProtocol), ExpectProxyProtocol(ExpectProxyProtocol), @@ -91,13 +91,14 @@ pub struct TcpSession { frontend_address: Option, frontend_buffer: Option, frontend_token: Token, - has_been_closed: SessionIsToBeClosed, + has_been_closed: bool, last_event: Instant, listener: Rc>, metrics: SessionMetrics, proxy: Rc>, request_id: Ulid, state: TcpStateMachine, + tags: Option>, } impl TcpSession { @@ -127,6 +128,8 @@ impl TcpSession { TimeoutContainer::new(configured_frontend_timeout, frontend_token); let container_backend_timeout = TimeoutContainer::new_empty(configured_connect_timeout); + let tags = listener.borrow().tags.clone(); + let state = match proxy_protocol { Some(ProxyProtocolConfig::RelayHeader) => { backend_buffer_session = Some(backend_buffer); @@ -174,11 +177,11 @@ impl TcpSession { frontend_buffer, frontend_token, socket, - listener.clone(), Protocol::TCP, request_id, frontend_address, WebSocketContext::Tcp, + tags.clone(), ); pipe.set_cluster_id(cluster_id.clone()); TcpStateMachine::Pipe(pipe) @@ -209,11 +212,11 @@ impl TcpSession { proxy, request_id, state, + tags, } } fn log_request(&self) { - let listener = self.listener.borrow(); let context = self.log_context(); self.metrics.register_end_of_session(&context); info_access!( @@ -224,7 +227,7 @@ impl TcpSession { backend_address: None, protocol: "TCP", endpoint: EndpointRecord::Tcp, - tags: listener.get_tags(&listener.get_addr().to_string()), + tags: self.tags.as_deref(), client_rtt: socket_rtt(self.state.front_socket()), server_rtt: None, user_agent: None, @@ -364,7 +367,7 @@ impl TcpSession { let mut pipe = send_proxy_protocol.into_pipe( self.frontend_buffer.take().unwrap(), self.backend_buffer.take().unwrap(), - self.listener.clone(), + self.tags.clone(), ); pipe.set_cluster_id(self.cluster_id.clone()); @@ -382,8 +385,7 @@ impl TcpSession { fn upgrade_relay(&mut self, rpp: RelayProxyProtocol) -> Option { if self.backend_buffer.is_some() { - let mut pipe = - rpp.into_pipe(self.backend_buffer.take().unwrap(), self.listener.clone()); + let mut pipe = rpp.into_pipe(self.backend_buffer.take().unwrap(), self.tags.clone()); pipe.set_cluster_id(self.cluster_id.clone()); gauge_add!("protocol.proxy.relay", -1); gauge_add!("protocol.tcp", 1); @@ -407,7 +409,7 @@ impl TcpSession { self.backend_buffer.take().unwrap(), None, None, - self.listener.clone(), + self.tags.clone(), ); pipe.set_cluster_id(self.cluster_id.clone()); @@ -495,7 +497,7 @@ impl TcpSession { "connections_per_backend", 1, self.cluster_id.as_deref(), - self.metrics.backend_id.as_deref() + self.backend_id.as_deref() ); // the back timeout was of connect_timeout duration before, @@ -515,7 +517,7 @@ impl TcpSession { incr!( "backend.up", self.cluster_id.as_deref(), - self.metrics.backend_id.as_deref() + self.backend_id.as_deref() ); info!( "backend server {} at {} is up", @@ -558,7 +560,7 @@ impl TcpSession { incr!( "backend.connections.error", self.cluster_id.as_deref(), - self.metrics.backend_id.as_deref() + self.backend_id.as_deref() ); if !already_unavailable && backend.retry_policy.is_down() { error!( @@ -568,7 +570,7 @@ impl TcpSession { incr!( "backend.down", self.cluster_id.as_deref(), - self.metrics.backend_id.as_deref() + self.backend_id.as_deref() ); push_event(Event { @@ -822,7 +824,7 @@ impl TcpSession { "connections_per_backend", -1, self.cluster_id.as_deref(), - self.metrics.backend_id.as_deref() + self.backend_id.as_deref() ); } @@ -902,7 +904,7 @@ impl TcpSession { self.set_back_token(back_token); self.set_back_socket(stream); - self.metrics.backend_id = Some(backend.borrow().backend_id.clone()); + // self.metrics.backend_id = Some(backend.borrow().backend_id.clone()); self.metrics.backend_start(); self.set_backend_id(backend.borrow().backend_id.clone()); @@ -1089,29 +1091,22 @@ impl ProxySession for TcpSession { } pub struct TcpListener { - active: SessionIsToBeClosed, + active: bool, address: SocketAddr, cluster_id: Option, config: TcpListenerConfig, listener: Option, - tags: BTreeMap, + pub tags: Option>, token: Token, } -impl ListenerHandler for TcpListener { - fn get_addr(&self) -> &SocketAddr { - &self.address - } - - fn get_tags(&self, key: &str) -> Option<&CachedTags> { - self.tags.get(key) +impl L4ListenerHandler for TcpListener { + fn get_tags(&self) -> Option<&CachedTags> { + self.tags.as_deref() } - fn set_tags(&mut self, key: String, tags: Option>) { - match tags { - Some(tags) => self.tags.insert(key, CachedTags::new(tags)), - None => self.tags.remove(&key), - }; + fn set_tags(&mut self, tags: Option>) { + self.tags = tags.map(|tags| Rc::new(CachedTags::new(tags))) } } @@ -1124,7 +1119,7 @@ impl TcpListener { address: config.address.into(), config, active: false, - tags: BTreeMap::new(), + tags: None, }) } @@ -1224,7 +1219,7 @@ impl TcpProxy { } } - pub fn remove_listener(&mut self, address: SocketAddr) -> SessionIsToBeClosed { + pub fn remove_listener(&mut self, address: SocketAddr) -> bool { let len = self.listeners.len(); self.listeners.retain(|_, l| l.borrow().address != address); @@ -1291,7 +1286,7 @@ impl TcpProxy { self.fronts .insert(front.cluster_id.to_string(), listener.token); - listener.set_tags(address.to_string(), Some(front.tags)); + listener.set_tags(Some(front.tags)); listener.cluster_id = Some(front.cluster_id); Ok(()) } @@ -1308,7 +1303,7 @@ impl TcpProxy { None => return Err(ProxyError::NoListenerFound(address)), }; - listener.set_tags(address.to_string(), None); + listener.set_tags(None); if let Some(cluster_id) = listener.cluster_id.take() { self.fronts.remove(&cluster_id); }