Config deserialization field type protection

Signed-off-by: 21pages <pages21@163.com>
This commit is contained in:
21pages 2023-05-16 17:57:40 +08:00
parent 5bd8befb0f
commit 33fb415b9d

View File

@ -105,7 +105,7 @@ macro_rules! serde_field_string {
where where
D: de::Deserializer<'de>, D: de::Deserializer<'de>,
{ {
let s: &str = de::Deserialize::deserialize(deserializer)?; let s: &str = de::Deserialize::deserialize(deserializer).unwrap_or_default();
Ok(if s.is_empty() { Ok(if s.is_empty() {
Self::$default_func() Self::$default_func()
} else { } else {
@ -119,7 +119,7 @@ macro_rules! serde_field_bool {
($struct_name: ident, $field_name: literal, $func: ident, $default: literal) => { ($struct_name: ident, $field_name: literal, $func: ident, $default: literal) => {
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct $struct_name { pub struct $struct_name {
#[serde(default = $default, rename = $field_name)] #[serde(default = $default, rename = $field_name, deserialize_with = "deserialize_bool")]
pub v: bool, pub v: bool,
} }
impl Default for $struct_name { impl Default for $struct_name {
@ -143,59 +143,63 @@ pub enum NetworkType {
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct Config { pub struct Config {
#[serde(default, skip_serializing_if = "String::is_empty")] #[serde(
default,
skip_serializing_if = "String::is_empty",
deserialize_with = "deserialize_string"
)]
pub id: String, // use pub id: String, // use
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
enc_id: String, // store enc_id: String, // store
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
password: String, password: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
salt: String, salt: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_keypair")]
key_pair: KeyPair, // sk, pk key_pair: KeyPair, // sk, pk
#[serde(default)] #[serde(default, deserialize_with = "deserialize_bool")]
key_confirmed: bool, key_confirmed: bool,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_hashmap_string_bool")]
keys_confirmed: HashMap<String, bool>, keys_confirmed: HashMap<String, bool>,
} }
#[derive(Debug, Default, PartialEq, Serialize, Deserialize, Clone)] #[derive(Debug, Default, PartialEq, Serialize, Deserialize, Clone)]
pub struct Socks5Server { pub struct Socks5Server {
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub proxy: String, pub proxy: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub username: String, pub username: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub password: String, pub password: String,
} }
// more variable configs // more variable configs
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct Config2 { pub struct Config2 {
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
rendezvous_server: String, rendezvous_server: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_i32")]
nat_type: i32, nat_type: i32,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_i32")]
serial: i32, serial: i32,
#[serde(default)] #[serde(default)]
socks: Option<Socks5Server>, socks: Option<Socks5Server>,
// the other scalar value must before this // the other scalar value must before this
#[serde(default)] #[serde(default, deserialize_with = "deserialize_hashmap_string_string")]
pub options: HashMap<String, String>, pub options: HashMap<String, String>,
} }
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct PeerConfig { pub struct PeerConfig {
#[serde(default)] #[serde(default, deserialize_with = "deserialize_vec_u8")]
pub password: Vec<u8>, pub password: Vec<u8>,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_size")]
pub size: Size, pub size: Size,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_size")]
pub size_ft: Size, pub size_ft: Size,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_size")]
pub size_pf: Size, pub size_pf: Size,
#[serde( #[serde(
default = "PeerConfig::default_view_style", default = "PeerConfig::default_view_style",
@ -225,9 +229,9 @@ pub struct PeerConfig {
pub privacy_mode: PrivacyMode, pub privacy_mode: PrivacyMode,
#[serde(flatten)] #[serde(flatten)]
pub allow_swap_key: AllowSwapKey, pub allow_swap_key: AllowSwapKey,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_vec_i32_string_i32")]
pub port_forwards: Vec<(i32, String, i32)>, pub port_forwards: Vec<(i32, String, i32)>,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_i32")]
pub direct_failures: i32, pub direct_failures: i32,
#[serde(flatten)] #[serde(flatten)]
pub disable_audio: DisableAudio, pub disable_audio: DisableAudio,
@ -237,7 +241,7 @@ pub struct PeerConfig {
pub enable_file_transfer: EnableFileTransfer, pub enable_file_transfer: EnableFileTransfer,
#[serde(flatten)] #[serde(flatten)]
pub show_quality_monitor: ShowQualityMonitor, pub show_quality_monitor: ShowQualityMonitor,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub keyboard_mode: String, pub keyboard_mode: String,
#[serde(flatten)] #[serde(flatten)]
pub view_only: ViewOnly, pub view_only: ViewOnly,
@ -246,7 +250,7 @@ pub struct PeerConfig {
#[serde(default, deserialize_with = "PeerConfig::deserialize_options")] #[serde(default, deserialize_with = "PeerConfig::deserialize_options")]
pub options: HashMap<String, String>, // not use delete to represent default values pub options: HashMap<String, String>, // not use delete to represent default values
// Various data for flutter ui // Various data for flutter ui
#[serde(default)] #[serde(default, deserialize_with = "deserialize_hashmap_string_string")]
pub ui_flutter: HashMap<String, String>, pub ui_flutter: HashMap<String, String>,
#[serde(default)] #[serde(default)]
pub info: PeerInfoSerde, pub info: PeerInfoSerde,
@ -256,48 +260,51 @@ pub struct PeerConfig {
#[derive(Debug, PartialEq, Default, Serialize, Deserialize, Clone)] #[derive(Debug, PartialEq, Default, Serialize, Deserialize, Clone)]
pub struct PeerInfoSerde { pub struct PeerInfoSerde {
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub username: String, pub username: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub hostname: String, pub hostname: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub platform: String, pub platform: String,
} }
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct ConfigOidc { pub struct ConfigOidc {
#[serde(default)] #[serde(default, deserialize_with = "deserialize_usize")]
pub max_auth_count: usize, pub max_auth_count: usize,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub callback_url: String, pub callback_url: String,
#[serde(default)] #[serde(
default,
deserialize_with = "deserialize_hashmap_string_configoidcprovider"
)]
pub providers: HashMap<String, ConfigOidcProvider>, pub providers: HashMap<String, ConfigOidcProvider>,
} }
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct ConfigOidcProvider { pub struct ConfigOidcProvider {
// seconds. 0 means never expires // seconds. 0 means never expires
#[serde(default)] #[serde(default, deserialize_with = "deserialize_u32")]
pub refresh_token_expires_in: u32, pub refresh_token_expires_in: u32,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub client_id: String, pub client_id: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub client_secret: String, pub client_secret: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_option_string")]
pub issuer: Option<String>, pub issuer: Option<String>,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_option_string")]
pub authorization_endpoint: Option<String>, pub authorization_endpoint: Option<String>,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_option_string")]
pub token_endpoint: Option<String>, pub token_endpoint: Option<String>,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_option_string")]
pub userinfo_endpoint: Option<String>, pub userinfo_endpoint: Option<String>,
} }
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
pub struct TransferSerde { pub struct TransferSerde {
#[serde(default)] #[serde(default, deserialize_with = "deserialize_vec_string")]
pub write_jobs: Vec<String>, pub write_jobs: Vec<String>,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_vec_string")]
pub read_jobs: Vec<String>, pub read_jobs: Vec<String>,
} }
@ -1148,18 +1155,18 @@ serde_field_bool!(
#[derive(Debug, Default, Serialize, Deserialize, Clone)] #[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct LocalConfig { pub struct LocalConfig {
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
remote_id: String, // latest used one remote_id: String, // latest used one
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
kb_layout_type: String, kb_layout_type: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_size")]
size: Size, size: Size,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_vec_string")]
pub fav: Vec<String>, pub fav: Vec<String>,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_hashmap_string_string")]
options: HashMap<String, String>, options: HashMap<String, String>,
// Various data for flutter ui // Various data for flutter ui
#[serde(default)] #[serde(default, deserialize_with = "deserialize_hashmap_string_string")]
ui_flutter: HashMap<String, String>, ui_flutter: HashMap<String, String>,
} }
@ -1267,17 +1274,17 @@ impl LocalConfig {
#[derive(Debug, Default, Serialize, Deserialize, Clone)] #[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct DiscoveryPeer { pub struct DiscoveryPeer {
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub id: String, pub id: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub username: String, pub username: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub hostname: String, pub hostname: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_string")]
pub platform: String, pub platform: String,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_bool")]
pub online: bool, pub online: bool,
#[serde(default)] #[serde(default, deserialize_with = "deserialize_hashmap_string_string")]
pub ip_mac: HashMap<String, String>, pub ip_mac: HashMap<String, String>,
} }
@ -1289,6 +1296,7 @@ impl DiscoveryPeer {
#[derive(Debug, Default, Serialize, Deserialize, Clone)] #[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct LanPeers { pub struct LanPeers {
#[serde(default, deserialize_with = "deserialize_vec_discoverypeer")]
pub peers: Vec<DiscoveryPeer>, pub peers: Vec<DiscoveryPeer>,
} }
@ -1324,7 +1332,7 @@ impl LanPeers {
#[derive(Debug, Default, Serialize, Deserialize, Clone)] #[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct HwCodecConfig { pub struct HwCodecConfig {
#[serde(default)] #[serde(default, deserialize_with = "deserialize_hashmap_string_string")]
pub options: HashMap<String, String>, pub options: HashMap<String, String>,
} }
@ -1354,7 +1362,7 @@ impl HwCodecConfig {
#[derive(Debug, Default, Serialize, Deserialize, Clone)] #[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct UserDefaultConfig { pub struct UserDefaultConfig {
#[serde(default)] #[serde(default, deserialize_with = "deserialize_hashmap_string_string")]
options: HashMap<String, String>, options: HashMap<String, String>,
} }
@ -1451,6 +1459,34 @@ impl ConfigOidc {
} }
} }
// use default value when field type is wrong
macro_rules! deserialize_default {
($func_name:ident, $return_type:ty) => {
fn $func_name<'de, D>(deserializer: D) -> Result<$return_type, D::Error>
where
D: de::Deserializer<'de>,
{
Ok(de::Deserialize::deserialize(deserializer).unwrap_or_default())
}
};
}
deserialize_default!(deserialize_string, String);
deserialize_default!(deserialize_bool, bool);
deserialize_default!(deserialize_i32, i32);
deserialize_default!(deserialize_u32, u32);
deserialize_default!(deserialize_usize, usize);
deserialize_default!(deserialize_vec_u8, Vec<u8>);
deserialize_default!(deserialize_vec_string, Vec<String>);
deserialize_default!(deserialize_vec_i32_string_i32, Vec<(i32, String, i32)>);
deserialize_default!(deserialize_vec_discoverypeer, Vec<DiscoveryPeer>);
deserialize_default!(deserialize_keypair, KeyPair);
deserialize_default!(deserialize_size, Size);
deserialize_default!(deserialize_option_string, Option<String>);
deserialize_default!(deserialize_hashmap_string_string, HashMap<String, String>);
deserialize_default!(deserialize_hashmap_string_bool, HashMap<String, bool>);
deserialize_default!(deserialize_hashmap_string_configoidcprovider, HashMap<String, ConfigOidcProvider>);
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -1464,4 +1500,38 @@ mod tests {
let res = toml::to_string_pretty(&cfg); let res = toml::to_string_pretty(&cfg);
assert!(res.is_ok()); assert!(res.is_ok());
} }
#[test]
fn test_config_deserialize() {
let wrong_type_str = r#"
id = true
enc_id = []
password = 1
salt = "123456"
key_pair = {}
key_confirmed = "1"
keys_confirmed = 1
"#;
let cfg = toml::from_str::<Config>(wrong_type_str);
assert_eq!(
cfg,
Ok(Config {
salt: "123456".to_string(),
..Default::default()
})
);
let wrong_field_str = r#"
hello = "world"
key_confirmed = true
"#;
let cfg = toml::from_str::<Config>(wrong_field_str);
assert_eq!(
cfg,
Ok(Config {
key_confirmed: true,
..Default::default()
})
);
}
} }