diff --git a/Cargo.lock b/Cargo.lock index b884591..f48e6e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,7 +39,9 @@ dependencies = [ "anyhow", "clap", "env_logger", + "ipnet", "log", + "netaddr2", "serde", "serde_json", "ureq", @@ -187,6 +189,12 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "ipnet" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f2d64f2edebec4ce84ad108148e67e1064789bee435edc5b60ad398714a3a9" + [[package]] name = "itoa" version = "0.4.7" @@ -235,6 +243,12 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" +[[package]] +name = "netaddr2" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6680d6da37e96edea724d1b065d82b516040191a38453a6a440d8000eb3b479" + [[package]] name = "once_cell" version = "1.8.0" diff --git a/Cargo.toml b/Cargo.toml index 3ffe0b9..9da2edf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,9 @@ license = "MIT" anyhow = "1.0.41" clap = "3.0.0-beta.2" env_logger = "0.8.4" +ipnet = "2.3.1" log = "0.4.14" +netaddr2 = "0.10.0" serde = { version = "1.0.126", features = ["derive"] } serde_json = "1.0.64" ureq = { version = "2.1.1", features = ["json"] } diff --git a/src/main.rs b/src/main.rs index b4cc9f1..13b19a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -30,7 +30,7 @@ struct ConfigOpts { #[clap(short, long)] no_dns: bool, - #[clap(short='4', long)] + #[clap(short = '4', long)] no_ipv6: bool, } @@ -96,6 +96,13 @@ struct WireguardKeyPair { } impl WireguardConfigData { + fn addresses(&self) -> Result, ipnet::AddrParseError> { + self.address + .split(',') + .map(|s: &str| -> Result { s.trim().parse() }) + .collect() + } + fn dns(&self) -> Result, AddrParseError> { self.dns .split(',') @@ -169,12 +176,17 @@ fn write_config( debug!("endpoint_addr = {:?}", &endpoint_addr); writeln!(output, "[Interface]")?; writeln!(output, "PrivateKey = {}", &keys.private_key)?; - writeln!(output, "Address = {}", &config.data.address)?; + let addresses = config.data.addresses()?; + let allowed_addresses = addresses + .iter() + .filter(|addr| addr.addr().is_ipv4() || !config_opts.no_ipv6); + write_list(output, "Address = ", allowed_addresses)?; if !config_opts.no_dns { let dns_addrs = config.data.dns()?; - let allowed_dns_addrs = dns_addrs.iter().filter(|addr| addr.is_ipv4() || !config_opts.no_ipv6); - write!(output, "DNS = ")?; - write_list(output, allowed_dns_addrs)?; + let allowed_dns_addrs = dns_addrs + .iter() + .filter(|addr| addr.is_ipv4() || !config_opts.no_ipv6); + write_list(output, "DNS = ", allowed_dns_addrs)?; } writeln!(output)?; @@ -186,11 +198,12 @@ fn write_config( Ok(()) } -fn write_list(output: &mut dyn Write, values: I) -> Result<(), std::io::Error> +fn write_list(output: &mut dyn Write, prefix: &str, values: I) -> Result<(), std::io::Error> where I: IntoIterator, T: std::fmt::Display, { + write!(output, "{}", prefix)?; for (i, value) in values.into_iter().enumerate() { if i != 0 { write!(output, ", ")?;