diff --git a/Cargo.lock b/Cargo.lock index 682c6890..b0ee0dda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1131,6 +1131,7 @@ dependencies = [ "anstyle", "bitflags", "clap_lex 0.4.1", + "once_cell", "strsim", ] @@ -3787,6 +3788,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-test", + "toml 0.7.4", "tonic", "tonic-build", "url", @@ -3874,7 +3876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e17d47ce914bf4de440332250b0edd23ce48c005f59fab39d3335866b114f11a" dependencies = [ "thiserror", - "toml", + "toml 0.5.11", ] [[package]] @@ -4667,6 +4669,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93107647184f6027e3b7dcb2e11034cf95ffa1e3a682c67951963ac69c1c007d" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -5324,6 +5335,40 @@ dependencies = [ "serde", ] +[[package]] +name = "toml" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6135d499e69981f9ff0ef2167955a5333c35e36f6937d382974566b3d5b94ec" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a76a9312f5ba4c2dec6b9161fdf25d87ad8a09256ccea5a556fef03c706a10f" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.19.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380d56e8670370eee6566b0bfd4265f65b3f432e8c6d85623f728d4fa31f739" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + [[package]] name = "tonic" version = "0.8.3" @@ -6287,6 +6332,15 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +[[package]] +name = "winnow" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca0ace3845f0d96209f0375e6d367e3eb87eb65d27d445bdc9f1843a26f39448" +dependencies = [ + "memchr", +] + [[package]] name = "winreg" version = "0.10.1" diff --git a/config/config.toml.example b/config/config.toml.example new file mode 100644 index 00000000..924d4a00 --- /dev/null +++ b/config/config.toml.example @@ -0,0 +1,59 @@ +## Copy this file, rename it to `config.toml` and place it in the `config` sub-directory of the `root_dir`. +## For instance, if the `root_dir` is `~/.polybase`, then the config file should be: `~/.polybase/config/config.toml`. +## +## Uncomment the options and provide values as required. Default values are supplied where sensible. + +# The core options are the options supported by the cli and env (via `clap`) as well as the TOML-based configurtion modes. +[core] + +# The command string. Possible values: "start", "generate-key" +#command = "start" + +# Node ID: an unsigned 64-bit value. For instance: 1 +#id = 1 + +# Log level. Possible values are DEBUG, INFO, and ERROR. Default: "INFO". +#log_level = "INFO" + +# Log format. Possible values: "PRETTY". Default: "PRETTY". +#log_format = "PRETTY" + +# RPC listen address. Default: "0.0.0.0:8080" +#rpc_laddr = "0.0.0.0:8080" + +# Secret key encoded as a hex string. +#secret_key = "" + +# Peer listen address as a string. List of strings. Default: "/ip4/0.0.0.0/tcp/0" +#network_laddr = ["/ip4/0.0.0.0/tcp/0"] + +# Peers to dial as a string. Comma separated strings. Default: [""] +#dial_addr = [""] + +# Validator peers - list of strings. Default: [""] +#peers = [""] + +# Maximum history of blocks to keep in memory. Default = 1024 +#block_cache_count = 1024 + +# Maximum number of txns to include in a block. Default = 1024 +#block_txns_count = 1024 + +# Size of the chunks of data sent during snapshot load. Default = 4194304 +#snapshot_chunk_size = 4194304 + +# Size of the chunks of data sent during snapshot load. Default = 500 +#min_block_duration = 500 + +# Sentry DSN. Default: "" +#sentry_dsn = "" + +# Public key whitelist: list of strings +#whitelist = [""] + +# Restrict namespaces to pk//. Default: false +#restrict_namespaces = false + +# Non-core configuration options + +[extra] diff --git a/polybase/Cargo.toml b/polybase/Cargo.toml index d100e0f8..e20291c7 100644 --- a/polybase/Cargo.toml +++ b/polybase/Cargo.toml @@ -24,7 +24,7 @@ secp256k1 = { version = "0.26", features = [ "global-context", "rand", ] } -clap = { version = "4.1.4", features = ["env", "derive"] } +clap = { version = "4.1.4", features = ["cargo", "env", "derive"] } bincode = "1.3.3" winter-crypto = "0.4.2" cid = "0.10" @@ -64,6 +64,7 @@ void = "1.0.2" either = "1.8.1" ed25519-dalek = "1.0.1" bs58 = "0.5.0" +toml = "0.7.4" [dev-dependencies] tokio-test = "0.4.2" diff --git a/polybase/src/config.rs b/polybase/src/config.rs deleted file mode 100644 index 5ca75087..00000000 --- a/polybase/src/config.rs +++ /dev/null @@ -1,118 +0,0 @@ -use clap::{Parser, Subcommand, ValueEnum}; - -/// Polybase is a p2p decentralized database -#[derive(Parser, Debug)] -#[command(name = "Polybase")] -#[command(author = "Polybase ")] -#[command(author, version, about = "The p2p decentralized database", long_about = None)] -#[command(propagate_version = true)] -pub struct Config { - #[command(subcommand)] - pub command: Option, - - /// ID of the node - #[arg(long, env = "ID")] - pub id: Option, - - /// Root directory where application data is stored - #[arg(short, long, env = "ROOT_DIR", default_value = "~/.polybase")] - pub root_dir: String, - - /// Log level - #[arg(value_enum, long, env = "LOG_LEVEL", default_value = "INFO")] - pub log_level: LogLevel, - - /// Log format - #[arg(value_enum, long, env = "LOG_FORMAT", default_value = "PRETTY")] - pub log_format: LogFormat, - - /// RPC listen address - #[arg(long, env = "RPC_LADDR", default_value = "0.0.0.0:8080")] - pub rpc_laddr: String, - - /// Secret key encoded as hex - #[arg(long, env = "SECRET_KEY")] - pub secret_key: Option, - - /// Peer listen address - #[arg( - long, - env = "NETWORK_LADDR", - value_parser, - value_delimiter = ',', - default_value = "/ip4/0.0.0.0/tcp/0" - )] - pub network_laddr: Vec, - - /// Peers to dial - #[arg( - long, - env = "DIAL_ADDR", - default_value = "", - value_parser, - value_delimiter = ',' - )] - pub dial_addr: Vec, - - /// Validator peers - #[arg( - long, - env = "PEERS", - default_value = "", - value_parser, - value_delimiter = ',' - )] - pub peers: Vec, - - // Maximum history of blocks to keep in memory - #[arg(long, env = "BLOCK_CACHE_SIZE", default_value = "1024")] - pub block_cache_count: usize, - - /// Maximum number of txns to include in a block - #[arg(long, env = "BLOCK_TXN_COUNT", default_value = "1024")] - pub block_txns_count: usize, - - /// Size of the chunks of data sent during snapshot load - #[arg(long, env = "SNAPSHOT_CHUNK_SIZE", default_value = "4194304")] - pub snapshot_chunk_size: usize, - - /// Size of the chunks of data sent during snapshot load - #[arg(long, env = "MIN_BLOCK_DURATION", default_value = "500")] - pub min_block_duration: u64, - - /// Sentry DSN - #[arg(long, env = "SENTRY_DSN", default_value = "")] - pub sentry_dsn: Option, - - /// Public key whitelist - #[arg(long, env = "WHITELIST", value_parser, value_delimiter = ',')] - pub whitelist: Option>, - - /// Restrict namespaces to pk// - #[arg(long, env = "RESTRICT_NAMESPACES", default_value = "false")] - pub restrict_namespaces: bool, -} - -#[derive(Subcommand, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)] -#[clap(rename_all = "SNAKE_CASE")] -pub enum Command { - /// Start the server - Start, - /// Generate a new secret key - GenerateKey, -} - -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)] -#[clap(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum LogLevel { - Debug, - Info, - Error, -} - -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)] -#[clap(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum LogFormat { - Pretty, - Json, -} diff --git a/polybase/src/config/clap_config.rs b/polybase/src/config/clap_config.rs new file mode 100644 index 00000000..7bec6027 --- /dev/null +++ b/polybase/src/config/clap_config.rs @@ -0,0 +1,164 @@ +use clap::{crate_version, Arg, ArgAction, ArgMatches, Command}; + +use super::{LogFormat, LogLevel}; + +/// Low-level `clap` object which provides with `value_source` which +/// indicates whether an option was set by the user (cli/env) or by the +/// default value. +/// +/// This also encapsulates the core configuation that is supported for the cli, env, +/// and TOML (file-based) configuration. +pub(super) fn get_matches() -> ArgMatches { + Command::new("Polybase") + .author("Polybase ") + .about("The p2p decentralized database") + .version(crate_version!()) // pick the version from `Cargo.toml` + .propagate_version(true) + .subcommand(Command::new("start").about("Start the server")) + .subcommand(Command::new("generate_key").about("Generate a new secret key")) + .arg( + Arg::new("id") + .help("ID of the node") + .long("id") + .value_name("ID") + .env("ID") + .value_parser(clap::value_parser!(u64)), + ) + .arg( + Arg::new("root-dir") + .help("Root directory where application data is stored") + .short('r') + .long("root-dir") + .value_name("ROOT_DIR") + .env("ROOT_DIR") + .value_parser(clap::value_parser!(String)) + .default_value("~/.polybase"), + ) + .arg( + Arg::new("log-level") + .help("Log level") + .long("log-level") + .value_name("LOG_LEVEL") + .env("LOG_LEVEL") + .value_parser(clap::builder::EnumValueParser::::new()) + .default_value("INFO"), + ) + .arg( + Arg::new("log-format") + .help("Log format") + .long("log-format") + .value_name("LOG_FORMAT") + .env("LOG_FORMAT") + .value_parser(clap::builder::EnumValueParser::::new()) + .default_value("PRETTY"), + ) + .arg( + Arg::new("rpc-laddr") + .help("RPC listen address") + .long("rpc-laddr") + .value_name("RPC_LADDR") + .env("RPC_LADDR") + .value_parser(clap::value_parser!(String)) + .default_value("0.0.0.0:8080"), + ) + .arg( + Arg::new("secret-key") + .help("Secret key encoded as hex") + .long("secret-key") + .value_name("SECRET_KEY") + .env("SECRET_KEY") + .value_parser(clap::value_parser!(String)), + ) + .arg( + Arg::new("network-laddr") + .help("Peer listen address") + .long("network-laddr") + .value_name("NETWORK_LADDR") + .env("NETWORK_LADDR") + .value_parser(clap::value_parser!(String)) + .value_delimiter(',') + .default_value("/ip4/0.0.0.0/tcp/0"), + ) + .arg( + Arg::new("dial-addr") + .help("Peers to dial") + .long("dial-addr") + .value_name("DIAL_ADDR") + .env("DIAL_ADDR") + .value_parser(clap::value_parser!(String)) + .value_delimiter(',') + .default_value(""), + ) + .arg( + Arg::new("peers") + .help("Validator peers") + .long("peers") + .value_name("PEERS") + .env("PEERS") + .value_parser(clap::value_parser!(String)) + .value_delimiter(',') + .default_value(""), + ) + .arg( + Arg::new("block-cache-count") + .help("Maximum history of blocks to keep in memory") + .long("block-cache-count") + .value_name("BLOCK_CACHE_COUNT") + .env("BLOCK_CACHE_COUNT") + .value_parser(clap::value_parser!(usize)) + .default_value("1024"), + ) + .arg( + Arg::new("block-txns-count") + .help("Maximum number of txns to include in a block") + .long("block-txns-count") + .value_name("BLOCK_TXNS_COUNT") + .env("BLOCK_TXNS_COUNT") + .value_parser(clap::value_parser!(usize)) + .default_value("1024"), + ) + .arg( + Arg::new("snapshot-chunk-size") + .help("Size of the chunks of data sent during snapshot load") + .long("snapshot-chunk-size") + .value_name("SNAPSHOT_CHUNK_SIZE") + .env("SNAPSHOT_CHUNK_SIZE") + .value_parser(clap::value_parser!(usize)) + .default_value("4194304"), + ) + .arg( + Arg::new("min-block-duration") + .help("Size of the chunks of data sent during snapshot load") + .long("min-block-duration") + .value_name("MIN_BLOCK_DURATION") + .env("MIN_BLOCK_DURATION") + .value_parser(clap::value_parser!(u64)) + .default_value("500"), + ) + .arg( + Arg::new("sentry-dsn") + .help("Sentry DSN") + .long("sentry-dsn") + .value_name("SENTRY_DSN") + .env("SENTRY_DSN") + .value_parser(clap::value_parser!(String)) + .default_value(""), + ) + .arg( + Arg::new("whitelist") + .help("Public key whitelist") + .long("whitelist") + .value_name("WHITELIST") + .env("WHITELIST") + .value_parser(clap::value_parser!(String)) + .value_delimiter(','), + ) + .arg( + Arg::new("restrict-namespaces") + .help("Restrict namespaces to pk//") + .long("restrict-namespaces") + .env("RESTRICT_NAMESPACES") + .action(ArgAction::SetTrue), + ) + .get_matches() +} diff --git a/polybase/src/config/mod.rs b/polybase/src/config/mod.rs new file mode 100644 index 00000000..f01be996 --- /dev/null +++ b/polybase/src/config/mod.rs @@ -0,0 +1,261 @@ +//! Configuration for Polybase - using the CLI (clap), env (clap), and configuration file (toml). + +mod clap_config; +mod toml_config; + +use clap::{parser::ValueSource, ArgMatches, ValueEnum}; +use serde::Deserialize; + +#[derive(Debug, thiserror::Error)] +pub enum ConfigError { + #[error("toml config error")] + TomlConfig(#[from] toml_config::TomlConfigError), +} + +pub type ConfigResult = std::result::Result; + +#[derive(Debug, Deserialize)] +pub struct ExtraConfig {} + +#[derive(Debug, Deserialize)] +pub struct Config { + pub command: Option, + + /// ID of the node + pub id: Option, + + /// Root directory where application data is stored + pub root_dir: String, + + /// Log level + pub log_level: LogLevel, + + /// Log format + pub log_format: LogFormat, + + /// RPC listen address + pub rpc_laddr: String, + + /// Secret key encoded as hex + pub secret_key: Option, + + /// Peer listen address + pub network_laddr: Vec, + + /// Peers to dial + pub dial_addr: Vec, + + /// Validator peers + pub peers: Vec, + + /// Maximum history of blocks to keep in memory + pub block_cache_count: usize, + + /// Maximum number of txns to include in a block + pub block_txns_count: usize, + + /// Size of the chunks of data sent during snapshot load + pub snapshot_chunk_size: usize, + + /// Size of the chunks of data sent during snapshot load + pub min_block_duration: u64, + + /// Sentry DSN + pub sentry_dsn: Option, + + /// Public key whitelist + pub whitelist: Option>, + + /// Restrict namespaces to pk// + pub restrict_namespaces: bool, + + // extra configurations + extra_config: Option, +} + +impl Config { + pub fn new() -> ConfigResult { + let clap_matches = clap_config::get_matches(); + + let mut config: Config = clap_matches.clone().into(); + Self::merge_toml_core_config(&mut config, clap_matches)?; + + Ok(config) + } + + #[allow(dead_code)] + pub fn extra_config(&mut self) -> Option { + self.extra_config.take() + } + + fn was_supplied_by_user(key: &str, matches: &ArgMatches) -> bool { + !matches!(matches.value_source(key), Some(ValueSource::DefaultValue)) + } + + /// The order of priority is (in decreasing order): + /// cli -> env -> toml -> default + /// + /// As such, here we will check if a field with a default value Was + /// supplied by the user. If so, do nothing. If not, if the TOML config + /// has a value for the same field, use that instead. + /// + /// Secondly, if a value for an optional type has not been set, and the TOML config again has a + /// value for it, then set it. + fn merge_toml_core_config(&mut self, matches: ArgMatches) -> ConfigResult<()> { + if let Some(mut toml_config) = toml_config::read_config(&self.root_dir)? { + if self.command.is_none() && toml_config.core.command.is_some() { + self.command = toml_config.core.command.take(); + } + + if self.id.is_none() && toml_config.core.id.is_some() { + self.id = toml_config.core.id.take(); + } + + if !Self::was_supplied_by_user("log-level", &matches) { + self.log_level = toml_config.core.log_level; + } + + if !Self::was_supplied_by_user("log-format", &matches) { + self.log_format = toml_config.core.log_format; + } + + if !Self::was_supplied_by_user("rpc-laddr", &matches) { + self.rpc_laddr = toml_config.core.rpc_laddr; + } + + if self.secret_key.is_none() && toml_config.core.secret_key.is_some() { + self.secret_key = toml_config.core.secret_key.take(); + } + + if !Self::was_supplied_by_user("network-laddr", &matches) { + self.network_laddr = toml_config.core.network_laddr; + } + + if !Self::was_supplied_by_user("dial-addr", &matches) { + self.dial_addr = toml_config.core.dial_addr; + } + + if !Self::was_supplied_by_user("peers", &matches) { + self.peers = toml_config.core.peers; + } + + if !Self::was_supplied_by_user("block-cache-count", &matches) { + self.block_cache_count = toml_config.core.block_cache_count; + } + + if !Self::was_supplied_by_user("block-txns-count", &matches) { + self.block_txns_count = toml_config.core.block_txns_count; + } + + if !Self::was_supplied_by_user("snapshot-chunk-size", &matches) { + self.snapshot_chunk_size = toml_config.core.snapshot_chunk_size; + } + + if !Self::was_supplied_by_user("min-block-duration", &matches) { + self.min_block_duration = toml_config.core.min_block_duration; + } + + if !Self::was_supplied_by_user("sentry-dsn", &matches) + && toml_config.core.sentry_dsn.is_some() + { + self.sentry_dsn = toml_config.core.sentry_dsn.take(); + } + + if self.whitelist.is_none() && toml_config.core.whitelist.is_some() { + self.whitelist = toml_config.core.whitelist.take(); + } + + if !Self::was_supplied_by_user("restrict-namespaces", &matches) { + self.restrict_namespaces = toml_config.core.restrict_namespaces; + } + + self.extra_config = toml_config.extra; + } + + Ok(()) + } +} + +// To convert from an ArgMatches into the main `Config` enity used by Polybase main. +// `clap` does not provide an automated way to do so in builder mode. +#[allow(clippy::unwrap_used)] +impl From for Config { + fn from(am: ArgMatches) -> Self { + Config { + command: { + match am.subcommand() { + Some(("start", _)) => Some(PolybaseCommand::Start), + Some(("generate_key", _)) => Some(PolybaseCommand::GenerateKey), + _ => None, + } + }, + + id: am.get_one::("id").copied(), + root_dir: am.get_one::("root-dir").unwrap().clone(), + log_level: *am.get_one::("log-level").unwrap(), + log_format: *am.get_one::("log-format").unwrap(), + rpc_laddr: am.get_one::("rpc-laddr").unwrap().clone(), + secret_key: am + .get_one::>("secret-key") + .unwrap_or(&None) + .clone(), + network_laddr: am + .get_many::("network-laddr") + .unwrap() + .map(|s| s.to_string()) + .collect::>(), + dial_addr: am + .get_many::("dial-addr") + .unwrap() + .map(|s| s.to_string()) + .collect::>(), + peers: am + .get_many::("peers") + .unwrap() + .map(|s| s.to_string()) + .collect::>(), + block_cache_count: *am.get_one::("block-cache-count").unwrap(), + block_txns_count: *am.get_one::("block-txns-count").unwrap(), + snapshot_chunk_size: *am.get_one::("snapshot-chunk-size").unwrap(), + min_block_duration: *am.get_one::("min-block-duration").unwrap(), + sentry_dsn: Some(am.get_one::("sentry-dsn").unwrap().clone()), + + whitelist: am + .get_many::("whitelist") + .map(|values| values.into_iter().map(String::from).collect::>()), + + restrict_namespaces: *am.get_one::("restrict-namespaces").unwrap(), + extra_config: None, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize)] +pub enum PolybaseCommand { + /// Start the server + #[serde(rename = "start")] + Start, + /// Generate a new secret key + #[serde(rename = "generate_key")] + GenerateKey, +} + +#[derive(Copy, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, ValueEnum)] +#[clap(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum LogLevel { + #[serde(rename = "DEBUG")] + Debug, + #[serde(rename = "INFO")] + Info, + #[serde(rename = "ERROR")] + Error, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug, Deserialize)] +#[clap(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum LogFormat { + #[serde(rename = "PRETTY")] + Pretty, + #[serde(rename = "JSON")] + Json, +} diff --git a/polybase/src/config/toml_config.rs b/polybase/src/config/toml_config.rs new file mode 100644 index 00000000..0513debe --- /dev/null +++ b/polybase/src/config/toml_config.rs @@ -0,0 +1,422 @@ +//! module for handling file-based (TOML) configuration for Polybase. + +use crate::util; +use std::{fmt, fs}; +use toml::{self, Value}; + +use super::{ + Config, ConfigError, ConfigResult, Deserialize, ExtraConfig, LogFormat, LogLevel, + PolybaseCommand, +}; + +#[derive(thiserror::Error)] +pub enum TomlConfigError { + #[error("toml file read error")] + Read(#[from] std::io::Error), + + #[error("toml deserialization error")] + Deserialization(#[from] toml::de::Error), + + #[error("invalid command '{0}': command must be one of 'start' or 'generate_key'")] + InvalidCommand(String), + + #[error( + "`command` must be a string with one of the following values: 'start', 'generate_key'" + )] + InvalidCommandType, + + #[error("`id' must be an unsigned 64-bit value")] + InvalidIdType, + + #[error("invalid log_level: '{0}': log_level must be one of 'DEBUG', 'INFO', or 'ERROR")] + InvalidLogLevel(String), + + #[error( + "`log_level` must be a string with one of the following values: 'DEBUG', 'INFO', 'ERROR'" + )] + InvalidLogLevelType, + + #[error("invalid log_format: '{0}': log_format must be one of 'PRETTY' or 'JSON'")] + InvalidLogFormat(String), + + #[error("`log_format` must be a string with one of the following values: 'PRETTY', 'JSON'")] + InvalidLogFormatType, + + #[error("`rpc_laddr` must be a string")] + InvalidRpcLaddrType, + + #[error("`secret_key` must be a hex string")] + InvalidSecretKeyType, + + #[error("`network_laddr` must be a list of strings delimited by commas")] + InvalidNetworkLaddrType, + + #[error("`dial_addr` must be a list of strings delimited by commas")] + InvalidDialAddrType, + + #[error("`peers` must be a list of strings delimited by commas")] + InvalidPeersType, + + #[error("`block_cache_count` must be an unsigned integer")] + InvalidBlockCacheCountType, + + #[error("`block_txns_count` must be an unsigned integer")] + InvalidBlockTxnsCountType, + + #[error("`snapshot_chunk_size` must be an unsigned integer")] + InvalidSnapshotChunkSizeType, + + #[error("`min_block_duration` must be an unsigned 64-bit value")] + InvalidMinBlockDurationType, + + #[error("`sentry_dsn` must be a string")] + InvalidSentryDsnType, + + #[error("`whitelist` must be a list of strings delimited by commas")] + InvalidWhiteListType, + + #[error("`restrict_namespaces` must be a boolean with value either 'false' or 'true")] + InvalidRestrictNamespacesType, +} + +impl fmt::Debug for TomlConfigError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use TomlConfigError::*; + + write!( + f, + "{}", + match self { + Read(ref e) => format!("TOML file read error. {e:?}"), + Deserialization(ref e) => format!("TOML deserialization error: {e:?}"), + InvalidCommand(ref cmd) => format!( + "invalid command '{cmd}': command must be one of 'start' or 'generate_key'" + ), + InvalidCommandType => "`command` must be a string with one of the following values: 'start', 'generate_key'".into(), + InvalidIdType => "`id` must be an unsigned 64-bit value".into(), + InvalidLogLevel(ref log_level) => format!("invalid log_level: '{log_level}': log_level must be one of 'DEBUG', 'INFO', or 'ERROR"), + InvalidLogLevelType => "`log_level` must be a string with one of the following values: 'DEBUG', 'INFO', 'ERROR'".into(), + InvalidLogFormat(ref log_format) => format!("invalid log_format: '{log_format}': log_format must be one of 'PRETTY' or 'JSON'"), + InvalidLogFormatType => "`log_format` must be a string with one of the following values: 'PRETTY', 'JSON'".into(), + InvalidRpcLaddrType => "`rpc_laddr` must be a string".into(), + InvalidSecretKeyType => "`secret_key` must be a hex string".into(), + InvalidNetworkLaddrType => "`network_laddr` must be a list of strings delimited by commas".into(), + InvalidDialAddrType => "`dial_addr` must be a list of strings delimited by commas".into(), + InvalidPeersType => "`peers` must be a list of strings delimited by commas".into(), + InvalidBlockCacheCountType => "`block_cache_count` must be an unsigned integer".into(), + InvalidBlockTxnsCountType => "`block_txns_count` must be an unsigned integer".into(), + InvalidSnapshotChunkSizeType => "`snapshot_chunk_size` must be an unsigned integer".into(), + InvalidMinBlockDurationType => "`min_block_duration` must be an unsigned 64-bit value".into(), + InvalidSentryDsnType => "`sentry_dsn` must be a string".into(), + InvalidWhiteListType => "`whitelist` must be a list of strings delimited by commas".into(), + InvalidRestrictNamespacesType => "`restrict_namespaces` must be a boolean with value either 'false' or 'true".into(), + } + ) + } +} + +#[derive(Debug, Deserialize)] +pub(crate) struct TomlConfig { + pub core: Config, + pub extra: Option, +} + +/// Read the TOML configuration file, if present in the `config` sub-directory under the +/// root Polybase directory. +pub(super) fn read_config(root_dir: &str) -> ConfigResult> { + util::get_toml_config_file(root_dir, "config").map_or(Ok(None), |config_file| { + if !config_file.exists() { + return Ok(None); + } + + // read the core configuration into the`Config` struct. + let toml_value = toml::from_str::( + fs::read_to_string(config_file) + .map_err(TomlConfigError::from)? + .as_str(), + ) + .map_err(TomlConfigError::from)?; + + let core = read_core_config(&toml_value)?; + let extra = read_extra_config(&toml_value)?; + + Ok(Some(TomlConfig { core, extra })) + }) +} + +/// Read the TOML configuration file and populate two fields: +/// - `core` for the core configurations common to the cli and env, and +/// - `extra` for extra configurations peculiar to the TOML config file. +fn read_core_config(toml_value: &Value) -> ConfigResult { + // default and optional values - separate from the default values + // read in by `clap` + let mut command = None; + let mut id = None; + let root_dir = "~/.polybase".into(); + let mut log_level = LogLevel::Info; + let mut log_format = LogFormat::Pretty; + let mut rpc_laddr = "0.0.0.0:8080".into(); + let mut secret_key = None; + let mut network_laddr = vec!["/ip4/0.0.0.0/tcp/0".into()]; + let mut dial_addr = vec!["".into()]; + let mut peers = vec!["".into()]; + let mut block_cache_count = 1024; + let mut block_txns_count = 1024; + let mut snapshot_chunk_size = 4194304; + let mut min_block_duration = 500; + let mut sentry_dsn = None; + let mut whitelist = None; + let mut restrict_namespaces = false; + + if let Some(core) = toml_value.get("core").and_then(|core| core.as_table()) { + if let Some(toml_cmd) = core.get("command") { + if let Value::String(toml_cmd) = toml_cmd { + command = match toml_cmd.as_str() { + "start" => Some(PolybaseCommand::Start), + "generate_key" => Some(PolybaseCommand::GenerateKey), + _ => { + return Err(ConfigError::from(TomlConfigError::InvalidCommand( + toml_cmd.clone(), + ))) + } + } + } else { + return Err(ConfigError::from(TomlConfigError::InvalidCommandType)); + } + } + + if let Some(toml_id) = core.get("id") { + if let Value::Integer(toml_id) = toml_id { + if *toml_id < 0 { + return Err(ConfigError::from(TomlConfigError::InvalidIdType)); + } + id = Some(*toml_id as u64); + } else { + return Err(ConfigError::from(TomlConfigError::InvalidIdType)); + } + } + + if let Some(toml_log_level) = core.get("log_level") { + if let Value::String(toml_log_level) = toml_log_level { + log_level = match toml_log_level.as_str() { + "DEBUG" => LogLevel::Debug, + "INFO" => LogLevel::Info, + "ERROR" => LogLevel::Error, + _ => { + return Err(ConfigError::from(TomlConfigError::InvalidLogLevel( + toml_log_level.clone(), + ))) + } + } + } else { + return Err(ConfigError::from(TomlConfigError::InvalidLogLevelType)); + } + } + + if let Some(toml_log_format) = core.get("log_format") { + if let Value::String(toml_log_format) = toml_log_format { + log_format = match toml_log_format.as_str() { + "PRETTY" => LogFormat::Pretty, + "JSON" => LogFormat::Json, + _ => { + return Err(ConfigError::from(TomlConfigError::InvalidLogFormat( + toml_log_format.clone(), + ))) + } + } + } else { + return Err(ConfigError::from(TomlConfigError::InvalidLogFormatType)); + } + } + + if let Some(toml_rpc_laddr) = core.get("rpc_laddr") { + if let Value::String(toml_rpc_laddr) = toml_rpc_laddr { + rpc_laddr = toml_rpc_laddr.clone(); + } else { + return Err(ConfigError::from(TomlConfigError::InvalidRpcLaddrType)); + } + } + + if let Some(toml_secret_key) = core.get("secret_key") { + if let Value::String(toml_secret_key) = toml_secret_key { + secret_key = Some(toml_secret_key.clone()); + } else { + return Err(ConfigError::from(TomlConfigError::InvalidSecretKeyType)); + } + } + + if let Some(toml_network_laddr) = core.get("network_laddr") { + if let Value::Array(toml_network_laddr) = toml_network_laddr { + if toml_network_laddr.is_empty() + || !toml_network_laddr.iter().all(|laddr| laddr.is_str()) + { + return Err(ConfigError::from(TomlConfigError::InvalidNetworkLaddrType)); + } + + network_laddr = toml_network_laddr + .iter() + .map(|laddr| laddr.to_string()) + .collect::>(); + } else { + return Err(ConfigError::from(TomlConfigError::InvalidNetworkLaddrType)); + } + } + + if let Some(toml_dial_addr) = core.get("dial_addr") { + if let Value::Array(toml_dial_addr) = toml_dial_addr { + if toml_dial_addr.is_empty() || !toml_dial_addr.iter().all(|addr| addr.is_str()) { + return Err(ConfigError::from(TomlConfigError::InvalidDialAddrType)); + } + + dial_addr = toml_dial_addr + .iter() + .map(|addr| addr.to_string()) + .collect::>(); + } else { + return Err(ConfigError::from(TomlConfigError::InvalidDialAddrType)); + } + } + + if let Some(toml_peers) = core.get("peers") { + if let Value::Array(toml_peers) = toml_peers { + if toml_peers.is_empty() || !toml_peers.iter().all(|peer| peer.is_str()) { + return Err(ConfigError::from(TomlConfigError::InvalidPeersType)); + } + + peers = toml_peers + .iter() + .map(|peer| peer.to_string()) + .collect::>(); + } else { + return Err(ConfigError::from(TomlConfigError::InvalidPeersType)); + } + } + + if let Some(toml_block_cache_count) = core.get("block_cache_count") { + if let Value::Integer(toml_block_cache_count) = toml_block_cache_count { + if *toml_block_cache_count < 0 { + return Err(ConfigError::from( + TomlConfigError::InvalidBlockCacheCountType, + )); + } + block_cache_count = *toml_block_cache_count as usize; + } else { + return Err(ConfigError::from( + TomlConfigError::InvalidBlockCacheCountType, + )); + } + } + + if let Some(toml_block_txns_count) = core.get("block_txns_count") { + if let Value::Integer(toml_block_txns_count) = toml_block_txns_count { + if *toml_block_txns_count < 0 { + return Err(ConfigError::from( + TomlConfigError::InvalidBlockTxnsCountType, + )); + } + block_txns_count = *toml_block_txns_count as usize; + } else { + return Err(ConfigError::from( + TomlConfigError::InvalidBlockTxnsCountType, + )); + } + } + + if let Some(toml_snapshot_chunk_size) = core.get("snapshot_chunk_size") { + if let Value::Integer(toml_snapshot_chunk_size) = toml_snapshot_chunk_size { + if *toml_snapshot_chunk_size < 0 { + return Err(ConfigError::from( + TomlConfigError::InvalidSnapshotChunkSizeType, + )); + } + snapshot_chunk_size = *toml_snapshot_chunk_size as usize; + } else { + return Err(ConfigError::from( + TomlConfigError::InvalidSnapshotChunkSizeType, + )); + } + } + + if let Some(toml_min_block_duration) = core.get("min_block_duration") { + if let Value::Integer(toml_min_block_duration) = toml_min_block_duration { + if *toml_min_block_duration < 0 { + return Err(ConfigError::from( + TomlConfigError::InvalidMinBlockDurationType, + )); + } + + min_block_duration = *toml_min_block_duration as u64; + } else { + return Err(ConfigError::from( + TomlConfigError::InvalidMinBlockDurationType, + )); + } + } + + if let Some(toml_sentry_dsn) = core.get("sentry_dsn") { + if let Value::String(toml_sentry_dsn) = toml_sentry_dsn { + sentry_dsn = Some(toml_sentry_dsn.to_string()); + } else { + return Err(ConfigError::from(TomlConfigError::InvalidSentryDsnType)); + } + } + + if let Some(toml_whitelist) = core.get("whitelist") { + if let Value::Array(toml_whitelist) = toml_whitelist { + if toml_whitelist.is_empty() || !toml_whitelist.iter().all(|wl| wl.is_str()) { + return Err(ConfigError::from(TomlConfigError::InvalidWhiteListType)); + } + + whitelist = Some( + toml_whitelist + .iter() + .map(|wl| wl.to_string()) + .collect::>(), + ); + } else { + return Err(ConfigError::from(TomlConfigError::InvalidWhiteListType)); + } + } + + if let Some(toml_restrict_namespaces) = core.get("restrict_namespaces") { + if let Value::Boolean(toml_restrict_namespaces) = toml_restrict_namespaces { + restrict_namespaces = *toml_restrict_namespaces; + } else { + return Err(ConfigError::from( + TomlConfigError::InvalidRestrictNamespacesType, + )); + } + } + } + + Ok(Config { + command, + id, + root_dir, + log_level, + log_format, + rpc_laddr, + secret_key, + network_laddr, + dial_addr, + peers, + block_cache_count, + block_txns_count, + snapshot_chunk_size, + min_block_duration, + sentry_dsn, + whitelist, + restrict_namespaces, + extra_config: None, + }) +} + +fn read_extra_config(toml_value: &Value) -> ConfigResult> { + let extra = toml_value.get("extra").and_then(|extra| extra.as_table()); + + if let Some(_extra) = extra { + Ok(Some(ExtraConfig {})) + } else { + Ok(None) + } +} diff --git a/polybase/src/errors.rs b/polybase/src/errors.rs index 45af7e7c..18726968 100644 --- a/polybase/src/errors.rs +++ b/polybase/src/errors.rs @@ -1,3 +1,4 @@ +use super::config; use super::db; use super::network; use indexer::IndexerError; @@ -53,4 +54,7 @@ pub enum AppError { "namespace is invalid, must be in format pk// got {0}" )] InvalidNamespace(String), + + #[error("configuration is invalid")] + InvalidConfiguration(#[from] config::ConfigError), } diff --git a/polybase/src/main.rs b/polybase/src/main.rs index 351cf2ea..baad6c56 100644 --- a/polybase/src/main.rs +++ b/polybase/src/main.rs @@ -17,12 +17,11 @@ mod rpc; mod txn; mod util; -use crate::config::{Command, Config, LogFormat}; +use crate::config::{Config, LogFormat, LogLevel, PolybaseCommand}; use crate::db::{Db, DbConfig}; use crate::errors::AppError; use crate::rpc::create_rpc_server; use chrono::Utc; -use clap::Parser; use ed25519_dalek::{self as ed25519}; use futures::StreamExt; use libp2p::PeerId; @@ -46,9 +45,9 @@ type Result = std::result::Result; #[tokio::main] async fn main() -> Result<()> { - let config = Config::parse(); + let config = Config::new()?; - if let Some(Command::GenerateKey) = config.command { + if let Some(PolybaseCommand::GenerateKey) = config.command { let (keypair, bytes) = util::generate_key(); #[allow(clippy::unwrap_used)] let key = ed25519::SecretKey::from_bytes(&bytes).unwrap(); @@ -82,9 +81,9 @@ async fn main() -> Result<()> { // Parse log level let log_level = match &config.log_level { - config::LogLevel::Debug => slog::Level::Debug, - config::LogLevel::Info => slog::Level::Info, - config::LogLevel::Error => slog::Level::Error, + LogLevel::Debug => slog::Level::Debug, + LogLevel::Info => slog::Level::Info, + LogLevel::Error => slog::Level::Error, }; // Create logger drain (json/pretty) diff --git a/polybase/src/util.rs b/polybase/src/util.rs index 6282eb18..a61eee7a 100644 --- a/polybase/src/util.rs +++ b/polybase/src/util.rs @@ -31,6 +31,13 @@ pub(crate) fn get_base_dir(dir: &str) -> Option { Some(path_buf) } +pub(crate) fn get_toml_config_file(dir: &str, config_dir: &str) -> Option { + let mut path_buf = get_base_dir(dir)?; + path_buf.push(config_dir); + path_buf.push("config.toml"); + Some(path_buf) +} + pub(crate) fn to_peer_id(base58_string: &String) -> Result { let decoded = bs58::decode(base58_string).into_vec()?; Ok(solid::peer::PeerId::new(decoded))