diff --git a/nat/src/lib.rs b/nat/src/lib.rs index 412c9161d..3450d49c7 100644 --- a/nat/src/lib.rs +++ b/nat/src/lib.rs @@ -43,8 +43,18 @@ struct NatTranslationData { } impl NatTranslationData { #[must_use] - pub(crate) fn new() -> Self { - Self::default() + pub(crate) fn new( + src_addr: Option, + dst_addr: Option, + src_port: Option, + dst_port: Option, + ) -> Self { + Self { + src_addr, + dst_addr, + src_port, + dst_port, + } } #[must_use] pub(crate) fn src_addr(mut self, address: IpAddr) -> Self { diff --git a/nat/src/portfw/icmp_handling.rs b/nat/src/portfw/icmp_handling.rs index b02c35fd3..ae66c770a 100644 --- a/nat/src/portfw/icmp_handling.rs +++ b/nat/src/portfw/icmp_handling.rs @@ -51,10 +51,10 @@ use tracing::debug; // Build `NatTranslationData` from `PortFwState` to translate the packet embedded in the ICMP error fn as_nat_translation(pfw_state: &PortFwState) -> NatTranslationData { match pfw_state.action { - PortFwAction::SrcNat => NatTranslationData::new() + PortFwAction::SrcNat => NatTranslationData::default() .dst_addr(pfw_state.use_ip().inner()) .dst_port(NatPort::Port(pfw_state.use_port())), - PortFwAction::DstNat => NatTranslationData::new() + PortFwAction::DstNat => NatTranslationData::default() .src_addr(pfw_state.use_ip().inner()) .src_port(NatPort::Port(pfw_state.use_port())), } diff --git a/nat/src/stateful/allocator.rs b/nat/src/stateful/allocator.rs index f7edc4b8c..ff9d98d7c 100644 --- a/nat/src/stateful/allocator.rs +++ b/nat/src/stateful/allocator.rs @@ -3,10 +3,12 @@ //! NAT allocator trait: a trait to build allocators to manage IP addresses and ports for stateful NAT. +use crate::NatPort; use crate::port::NatPortError; use net::ExtendedFlowKey; use net::ip::NextHeader; use std::fmt::{Debug, Display}; +use std::net::IpAddr; use std::time::Duration; #[derive(Debug, Clone, PartialEq, Eq, Hash, thiserror::Error)] @@ -37,42 +39,23 @@ pub enum AllocatorError { /// `AllocationResult` is a struct to represent the result of an allocation. /// -/// It contains the allocated IP addresses and ports for both source and destination NAT for the -/// packet forwarded. In addition, it "reserves" IP addresses and ports for packets on the return -/// path for this flow, and returns them so that the stateful NAT pipeline stage can update the flow -/// table to prepare for the reply. It is necessary to "reserve" the IP and ports at this stage, to -/// limit the risk of another flow accidentally getting the same resources assigned. +/// It contains the allocated IP addresses and ports for source NAT for the packet forwarded. In +/// addition, it "reserves" IP addresses and ports for packets on the return path for this flow, and +/// returns them so that the stateful NAT pipeline stage can update the flow table to prepare for +/// the reply. It is necessary to "reserve" the IP and ports at this stage, to limit the risk of +/// another flow accidentally getting the same resources assigned. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct AllocationResult { pub src: Option, - pub dst: Option, - pub return_src: Option, - pub return_dst: Option, - pub src_flow_idle_timeout: Option, - pub dst_flow_idle_timeout: Option, + pub return_dst: Option<(IpAddr, NatPort)>, + pub idle_timeout: Option, } impl AllocationResult { /// Returns the idle timeout for the flow. - /// - /// # Returns - /// - /// * `Some(Duration)` if at least one of `src_flow_idle_timeout` or `dst_flow_idle_timeout` is set. - /// * `None` if both `src_flow_idle_timeout` and `dst_flow_idle_timeout` are `None`. #[must_use] pub fn idle_timeout(&self) -> Option { - // Use the minimum of the two timeouts (source/destination). - // - // FIXME: We shouldn't use just one of the two timeouts, but doing otherwise will require - // uncoupling entry creation for source and destination NAT. - match (self.src_flow_idle_timeout, self.dst_flow_idle_timeout) { - (Some(src), Some(dst)) => Some(src.min(dst)), - (Some(src), None) => Some(src), - (None, Some(dst)) => Some(dst), - // Given that at least one of alloc.src or alloc.dst is set, we should always have at - // least one timeout set. - (None, None) => None, - } + self.idle_timeout } } @@ -80,17 +63,16 @@ impl Display for AllocationResult { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "src: {}, dst: {}, return_src: {}, return_dst: {}, src_flow_idle_timeout: {:?}, dst_flow_idle_timeout: {:?}", + "src: {}, return_dst: {}, idle_timeout: {:?}", self.src.as_ref().map_or("None".to_string(), T::to_string), - self.dst.as_ref().map_or("None".to_string(), T::to_string), - self.return_src - .as_ref() - .map_or("None".to_string(), T::to_string), self.return_dst .as_ref() - .map_or("None".to_string(), T::to_string), - self.src_flow_idle_timeout, - self.dst_flow_idle_timeout + .map_or("None".to_string(), |(ip, port)| format!( + "{}:{}", + ip, + port.as_u16() + )), + self.idle_timeout, ) } } diff --git a/nat/src/stateful/apalloc/display.rs b/nat/src/stateful/apalloc/display.rs index d458a8342..9adc8a940 100644 --- a/nat/src/stateful/apalloc/display.rs +++ b/nat/src/stateful/apalloc/display.rs @@ -37,10 +37,6 @@ impl Display for NatDefaultAllocator { writeln!(with_indent!(f), "{}", self.pools_src44)?; writeln!(f, "source pools (IPv6):")?; writeln!(with_indent!(f), "{}", self.pools_src66)?; - writeln!(f, "destination pools (IPv4):")?; - writeln!(with_indent!(f), "{}", self.pools_dst44)?; - writeln!(f, "destination pools (IPv6):")?; - writeln!(with_indent!(f), "{}", self.pools_dst66)?; Ok(()) } } diff --git a/nat/src/stateful/apalloc/mod.rs b/nat/src/stateful/apalloc/mod.rs index 7546f5b5e..69f3cbd69 100644 --- a/nat/src/stateful/apalloc/mod.rs +++ b/nat/src/stateful/apalloc/mod.rs @@ -76,7 +76,7 @@ use net::packet::VpcDiscriminant; use net::{ExtendedFlowKey, FlowKey}; use std::collections::BTreeMap; use std::fmt::Display; -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use tracing::error; mod alloc; @@ -160,7 +160,6 @@ impl PoolTable { /// [`AllocatedIpPort`] is the public type for the object returned by our allocator. pub type AllocatedIpPort = port_alloc::AllocatedPort; -type AllocationMapping = (Option>, Option>); impl Display for AllocatedIpPort { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -180,9 +179,7 @@ impl Display for AllocatedIpPort { #[derive(Debug)] pub struct NatDefaultAllocator { pools_src44: PoolTable, - pools_dst44: PoolTable, pools_src66: PoolTable, - pools_dst66: PoolTable, #[cfg(test)] disable_randomness: bool, } @@ -191,9 +188,7 @@ impl NatAllocator, AllocatedIpPort> for NatD fn new() -> Self { Self { pools_src44: PoolTable::new(), - pools_dst44: PoolTable::new(), pools_src66: PoolTable::new(), - pools_dst66: PoolTable::new(), #[cfg(test)] disable_randomness: false, } @@ -203,24 +198,14 @@ impl NatAllocator, AllocatedIpPort> for NatD &self, eflow_key: &ExtendedFlowKey, ) -> Result>, AllocatorError> { - Self::allocate_from_tables( - eflow_key, - &self.pools_src44, - &self.pools_dst44, - self.must_disable_randomness(), - ) + Self::allocate_from_tables(eflow_key, &self.pools_src44, self.must_disable_randomness()) } fn allocate_v6( &self, eflow_key: &ExtendedFlowKey, ) -> Result>, AllocatorError> { - Self::allocate_from_tables( - eflow_key, - &self.pools_src66, - &self.pools_dst66, - self.must_disable_randomness(), - ) + Self::allocate_from_tables(eflow_key, &self.pools_src66, self.must_disable_randomness()) } } @@ -228,14 +213,15 @@ impl NatDefaultAllocator { fn allocate_from_tables( eflow_key: &ExtendedFlowKey, pools_src: &PoolTable, - pools_dst: &PoolTable, disable_randomness: bool, ) -> Result>, AllocatorError> { // get flow key from extended flow key let flow_key = eflow_key.flow_key(); let next_header = Self::get_next_header(flow_key); Self::check_proto(next_header)?; - let (src_vpc_id, dst_vpc_id) = Self::get_vpc_discriminants(eflow_key)?; + let dst_vpc_id = eflow_key + .dst_vpcd() + .ok_or(AllocatorError::MissingDiscriminant)?; // Get address pools for source let pool_src_opt = pools_src.get_entry( @@ -261,52 +247,15 @@ impl NatDefaultAllocator { return Err(AllocatorError::Denied); } - // Get address pools for destination - let pool_dst_opt = pools_dst.get_entry( - next_header, - dst_vpc_id, - NatIp::try_from_addr(*flow_key.data().dst_ip()).map_err(|()| { - AllocatorError::InternalIssue( - "Failed to convert IP address to Ipv4Addr".to_string(), - ) - })?, - ); - // Allocate IP and ports from pools, for source and destination NAT let allow_null = matches!(flow_key.data().proto_key_info(), IpProtoKey::Icmp(_)); - let (src_mapping, dst_mapping) = - Self::get_mapping(pool_src_opt, pool_dst_opt, allow_null, disable_randomness)?; - - // Now based on the previous allocation, we need to "reserve" IP and ports for the reverse - // path for the flow. First retrieve the relevant address pools. - - let reverse_pool_src_opt = if let Some(mapping) = &dst_mapping { - pools_src.get_entry(next_header, src_vpc_id, mapping.ip()) - } else { - None - }; - - let reverse_pool_dst_opt = if let Some(mapping) = &src_mapping { - pools_dst.get_entry(next_header, src_vpc_id, mapping.ip()) - } else { - None - }; - - // Reserve IP and ports for the reverse path for the flow. - let (reverse_src_mapping, reverse_dst_mapping) = Self::get_reverse_mapping( - flow_key, - reverse_pool_src_opt, - reverse_pool_dst_opt, - disable_randomness, - )?; + let src_mapping = Self::get_mapping(pool_src_opt, allow_null, disable_randomness)?; + let reverse_dst_mapping = Self::get_reverse_mapping(flow_key)?; Ok(AllocationResult { src: src_mapping, - dst: dst_mapping, - return_src: reverse_src_mapping, return_dst: reverse_dst_mapping, - src_flow_idle_timeout: pool_src_opt.and_then(IpAllocator::idle_timeout), - dst_flow_idle_timeout: pool_dst_opt.and_then(IpAllocator::idle_timeout), + idle_timeout: pool_src_opt.and_then(IpAllocator::idle_timeout), }) } @@ -335,27 +284,11 @@ impl NatDefaultAllocator { } } - fn get_vpc_discriminants( - eflow_key: &ExtendedFlowKey, - ) -> Result<(VpcDiscriminant, VpcDiscriminant), AllocatorError> { - let src_vpc_id = eflow_key - .flow_key() - .data() - .src_vpcd() - .ok_or(AllocatorError::MissingDiscriminant)?; - - let dst_vpc_id = eflow_key - .dst_vpcd() - .ok_or(AllocatorError::MissingDiscriminant)?; - Ok((src_vpc_id, dst_vpc_id)) - } - fn get_mapping( pool_src_opt: Option<&alloc::IpAllocator>, - pool_dst_opt: Option<&alloc::IpAllocator>, allow_null: bool, disable_randomness: bool, - ) -> Result, AllocatorError> { + ) -> Result>, AllocatorError> { // Allocate IP and ports for source and destination NAT. // // In the case of ICMP Query messages, use dst_mapping to hold an allocated identifier @@ -372,69 +305,19 @@ impl NatDefaultAllocator { None => None, }; - let dst_mapping = match pool_dst_opt { - Some(pool_dst) => Some(pool_dst.allocate(allow_null, disable_randomness)?), - None => None, - }; - - Ok((src_mapping, dst_mapping)) + Ok(src_mapping) } - fn get_reverse_mapping( + fn get_reverse_mapping( flow_key: &FlowKey, - reverse_pool_src_opt: Option<&alloc::IpAllocator>, - reverse_pool_dst_opt: Option<&alloc::IpAllocator>, - disable_randomness: bool, - ) -> Result, AllocatorError> { - let reverse_src_mapping = match reverse_pool_src_opt { - Some(pool_src) => { - let reservation_src_port_number = match flow_key.data().proto_key_info() { - IpProtoKey::Tcp(tcp) => tcp.dst_port.into(), - IpProtoKey::Udp(udp) => udp.dst_port.into(), - // FIXME: We're doing a useless port reservation here, but without reserving a - // "port" (or an ID for ICMP) we can't reserve an IP, given the current - // architecture of the allocator. The ID will be overwritten by the ID for the - // destination mapping. Note: this does not mean we're exhausting allocatable - // identifiers sooner, because we allocate from a ports/identifier pool we don't - // need. - IpProtoKey::Icmp(icmp) => NatPort::Identifier(Self::get_icmp_query_id(icmp)?), - }; - - Some(pool_src.reserve( - NatIp::try_from_addr(*flow_key.data().dst_ip()).map_err(|()| { - AllocatorError::InternalIssue( - "Failed to convert IP address to Ipv4Addr".to_string(), - ) - })?, - reservation_src_port_number, - disable_randomness, - )?) - } - None => None, + ) -> Result, AllocatorError> { + let reverse_target_ip = *flow_key.data().src_ip(); + let reverse_target_port = match flow_key.data().proto_key_info() { + IpProtoKey::Tcp(tcp) => tcp.src_port.into(), + IpProtoKey::Udp(udp) => udp.src_port.into(), + IpProtoKey::Icmp(icmp) => NatPort::Identifier(Self::get_icmp_query_id(icmp)?), }; - - let reverse_dst_mapping = match reverse_pool_dst_opt { - Some(pool_dst) => { - let reservation_dst_port_number = match flow_key.data().proto_key_info() { - IpProtoKey::Tcp(tcp) => tcp.src_port.into(), - IpProtoKey::Udp(udp) => udp.src_port.into(), - IpProtoKey::Icmp(icmp) => NatPort::Identifier(Self::get_icmp_query_id(icmp)?), - }; - - Some(pool_dst.reserve( - NatIp::try_from_addr(*flow_key.data().src_ip()).map_err(|()| { - AllocatorError::InternalIssue( - "Failed to convert IP address to Ipv4Addr".to_string(), - ) - })?, - reservation_dst_port_number, - disable_randomness, - )?) - } - None => None, - }; - - Ok((reverse_src_mapping, reverse_dst_mapping)) + Ok(Some((reverse_target_ip, reverse_target_port))) } fn get_icmp_query_id(key: &IcmpProtoKey) -> Result { diff --git a/nat/src/stateful/apalloc/setup.rs b/nat/src/stateful/apalloc/setup.rs index 0cf430805..af5430dd8 100644 --- a/nat/src/stateful/apalloc/setup.rs +++ b/nat/src/stateful/apalloc/setup.rs @@ -57,84 +57,34 @@ impl NatDefaultAllocator { ) -> Result<(), AllocatorError> { let new_peering = collapse_prefixes_peering(peering); - // Update tables for source NAT - self.build_src_nat_pool_for_expose(&new_peering, dst_vpc_id)?; - - // Update table for destination NAT - self.build_dst_nat_pool_for_expose(&new_peering, dst_vpc_id)?; - - Ok(()) - } - - fn build_src_nat_pool_for_expose( - &mut self, - peering: &Peering, - dst_vpc_id: VpcDiscriminant, - ) -> Result<(), AllocatorError> { build_nat_pool_generic( - &peering.local, + &new_peering.local, dst_vpc_id, VpcManifest::stateful_nat_exposes_44, VpcManifest::port_forwarding_exposes_44, - VpcExpose::as_range_or_empty, - |expose| &expose.ips, &mut self.pools_src44, NextHeader::ICMP, )?; build_nat_pool_generic( - &peering.local, + &new_peering.local, dst_vpc_id, VpcManifest::stateful_nat_exposes_66, VpcManifest::port_forwarding_exposes_66, - VpcExpose::as_range_or_empty, - |expose| &expose.ips, &mut self.pools_src66, NextHeader::ICMP6, ) } - - fn build_dst_nat_pool_for_expose( - &mut self, - peering: &Peering, - dst_vpc_id: VpcDiscriminant, - ) -> Result<(), AllocatorError> { - build_nat_pool_generic( - &peering.remote, - dst_vpc_id, - VpcManifest::stateful_nat_exposes_44, - VpcManifest::port_forwarding_exposes_44, - |expose| &expose.ips, - VpcExpose::as_range_or_empty, - &mut self.pools_dst44, - NextHeader::ICMP, - )?; - - build_nat_pool_generic( - &peering.remote, - dst_vpc_id, - VpcManifest::stateful_nat_exposes_66, - VpcManifest::port_forwarding_exposes_66, - |expose| &expose.ips, - VpcExpose::as_range_or_empty, - &mut self.pools_dst66, - NextHeader::ICMP6, - ) - } } #[allow(clippy::too_many_arguments)] -fn build_nat_pool_generic<'a, I: NatIpWithBitmap, J: NatIpWithBitmap, F, FIter, G, H, P, PIter>( +fn build_nat_pool_generic<'a, I: NatIpWithBitmap, J: NatIpWithBitmap, F, FIter, P, PIter>( manifest: &'a VpcManifest, dst_vpc_id: VpcDiscriminant, // A filter to select relevant exposes: those with stateful NAT, for the relevant IP version exposes_filter: F, // A filter to select other exposes with port forwarding, for the relevant IP version port_forwarding_exposes_filter: P, - // A function to get the list of prefixes to translate into - original_prefixes_from_expose: G, - // A function to get the list of prefixes to translate from - target_prefixes_from_expose: H, table: &mut PoolTable, icmp_proto: NextHeader, ) -> Result<(), AllocatorError> @@ -143,8 +93,6 @@ where FIter: Iterator, P: FnOnce(&'a VpcManifest) -> PIter, PIter: Iterator, - G: Fn(&'a VpcExpose) -> &'a PrefixPortsSet, - H: Fn(&'a VpcExpose) -> &'a PrefixPortsSet, { let port_forwarding_exposes: Vec<&'a VpcExpose> = port_forwarding_exposes_filter(manifest).collect(); @@ -157,24 +105,24 @@ where .unwrap_or(DEFAULT_MASQUERADE_IDLE_TIMEOUT); let tcp_ip_allocator = ip_allocator_for_prefixes( - original_prefixes_from_expose(expose), + expose.as_range_or_empty(), idle_timeout, &prefixes_and_ports_to_exclude_from_pools.tcp, )?; let udp_ip_allocator = ip_allocator_for_prefixes( - original_prefixes_from_expose(expose), + expose.as_range_or_empty(), idle_timeout, &prefixes_and_ports_to_exclude_from_pools.udp, )?; let icmp_ip_allocator = ip_allocator_for_prefixes( - original_prefixes_from_expose(expose), + expose.as_range_or_empty(), idle_timeout, &PrefixPortsSet::default(), )?; add_pool_entries( table, - target_prefixes_from_expose(expose), + &expose.ips, dst_vpc_id, &tcp_ip_allocator, &udp_ip_allocator, diff --git a/nat/src/stateful/apalloc/test_alloc.rs b/nat/src/stateful/apalloc/test_alloc.rs index 1de0208c7..99eba2da4 100644 --- a/nat/src/stateful/apalloc/test_alloc.rs +++ b/nat/src/stateful/apalloc/test_alloc.rs @@ -74,10 +74,7 @@ mod context { "".to_string() } }; - println!("src: {}", format_ip_port(&allocation.src)); - println!("dst: {}", format_ip_port(&allocation.dst)); - println!("return_src: {}", format_ip_port(&allocation.return_src)); - println!("return_dst: {}", format_ip_port(&allocation.return_dst)); + println!("allocation: {allocation}"); } pub fn get_ip_allocator_v4( @@ -107,12 +104,7 @@ mod context { .unwrap() .not_as("10.1.0.3/32".into()) .unwrap(); - let expose2 = VpcExpose::empty() - .make_stateful_nat(None) - .unwrap() - .ip("2.0.0.0/16".into()) - .as_range("10.2.0.0/29".into()) - .unwrap(); + let expose2 = VpcExpose::empty().ip("2.0.0.0/16".into()); let manifest1 = VpcManifest { name: "VPC-1".into(), @@ -126,14 +118,7 @@ mod context { .ip("3.0.1.0/24".into()) .as_range("10.3.0.0/30".into()) .unwrap(); - let expose4 = VpcExpose::empty() - .make_stateful_nat(None) - .unwrap() - .ip("4.0.0.0/16".into()) - .as_range("10.4.0.0/31".into()) - .unwrap() - .as_range("10.4.1.0/30".into()) - .unwrap(); + let expose4 = VpcExpose::empty().ip("4.0.0.0/16".into()); let manifest2 = VpcManifest { name: "VPC-2".into(), @@ -219,7 +204,7 @@ mod std_tests { .keys() .filter(|k| k.protocol == NextHeader::TCP) .count(), - 7 + 5 ); assert_eq!( allocator @@ -228,39 +213,10 @@ mod std_tests { .keys() .filter(|k| k.protocol == NextHeader::UDP) .count(), - 7 - ); - - assert!( - allocator - .pools_dst44 - .0 - .keys() - .all(|k| k.dst_id == vpcd2() || k.dst_id == vpcd1()) - ); - // One entry for each ".as_range()" from the VPCExpose objects, - // after exclusion ranges have been applied - assert_eq!( - allocator - .pools_dst44 - .0 - .keys() - .filter(|k| k.protocol == NextHeader::TCP) - .count(), - 6 - ); - assert_eq!( - allocator - .pools_dst44 - .0 - .keys() - .filter(|k| k.protocol == NextHeader::UDP) - .count(), - 6 + 5 ); assert_eq!(allocator.pools_src66.0.len(), 0); - assert_eq!(allocator.pools_dst66.0.len(), 0); let ip_allocator = allocator .pools_src44 @@ -276,21 +232,6 @@ mod std_tests { assert!(bitmap.contains_range(addr_v4_bits("10.1.0.0")..=addr_v4_bits("10.1.0.2"))); assert_eq!(bitmap.len(), 3); assert_eq!(in_use.len(), 0); - - let ip_allocator = allocator - .pools_dst44 - .get(&PoolTableKey::new( - NextHeader::TCP, - vpcd2(), - addr_v4("10.3.0.0"), - addr_v4("255.255.255.255"), - )) - .unwrap(); - let (bitmap, in_use) = ip_allocator.get_pool_clone_for_tests(); - - assert!(bitmap.contains_range(addr_v4_bits("3.0.0.0")..=addr_v4_bits("3.0.1.255"))); - assert_eq!(bitmap.len(), 512); - assert_eq!(in_use.len(), 0); } // Allocate IP addresses and ports for running NAT on a tuple from a simple packet. Ensure that @@ -321,27 +262,19 @@ mod std_tests { print_allocation(&allocation); assert!(allocation.src.is_some()); - assert!(allocation.dst.is_some()); - assert!(allocation.return_src.is_some()); assert!(allocation.return_dst.is_some()); assert_eq!(allocation.src.as_ref().unwrap().ip(), addr_v4("10.1.0.0")); - assert_eq!(allocation.dst.as_ref().unwrap().ip(), addr_v4("3.0.0.0")); - assert_eq!( - allocation.return_src.as_ref().unwrap().ip(), - addr_v4("10.3.0.2") - ); - assert_eq!( - allocation.return_src.as_ref().unwrap().port().as_u16(), - 5678 - ); assert_eq!( - allocation.return_dst.as_ref().unwrap().ip(), - addr_v4("1.1.0.0") + allocation.return_dst.as_ref().map(|(ip, _)| ip), + Some(&ipaddr("1.1.0.0")) ); assert_eq!( - allocation.return_dst.as_ref().unwrap().port().as_u16(), - 1234 + allocation + .return_dst + .as_ref() + .map(|(_, port)| port.as_u16()), + Some(1234) ); let (bitmap, in_use) = get_ip_allocator_v4( diff --git a/nat/src/stateful/icmp_handling.rs b/nat/src/stateful/icmp_handling.rs index f9e0f60b0..91b0114de 100644 --- a/nat/src/stateful/icmp_handling.rs +++ b/nat/src/stateful/icmp_handling.rs @@ -3,18 +3,14 @@ //! Handling of ICMP errors in stateful NAT (masquerading) -use std::net::{Ipv4Addr, Ipv6Addr}; - use crate::icmp_handler::icmp_error_msg::nat_translate_icmp_inner; +use crate::stateful::NatFlowState; use net::buffer::PacketBufferMut; use net::flows::ExtractRef; use net::flows::FlowInfo; use net::headers::{TryInnerIp, TryIp, TryIpMut}; use net::packet::{DoneReason, Packet}; - -use crate::StatefulNat; -use crate::stateful::NatFlowState; - +use std::net::{Ipv4Addr, Ipv6Addr}; use tracing::debug; pub(crate) fn handle_icmp_error_masquerading( @@ -31,19 +27,17 @@ pub(crate) fn handle_icmp_error_masquerading( .src_addr() .is_ipv4() { - let nat_state = flow_info_locked + flow_info_locked .nat_state .extract_ref::>() - .unwrap_or_else(|| unreachable!()); - - StatefulNat::get_translation_data(&nat_state.dst_alloc, &nat_state.src_alloc) + .unwrap_or_else(|| unreachable!()) + .reverse_translation_data() } else { - let nat_state = flow_info_locked + flow_info_locked .nat_state .extract_ref::>() - .unwrap_or_else(|| unreachable!()); - - StatefulNat::get_translation_data(&nat_state.dst_alloc, &nat_state.src_alloc) + .unwrap_or_else(|| unreachable!()) + .reverse_translation_data() }; // translate inner packet fragment diff --git a/nat/src/stateful/mod.rs b/nat/src/stateful/mod.rs index 0309ec464..6799175bc 100644 --- a/nat/src/stateful/mod.rs +++ b/nat/src/stateful/mod.rs @@ -1,19 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright Open Network Fabric Authors -mod allocator; +pub(crate) mod allocator; mod allocator_writer; pub mod apalloc; pub(crate) mod icmp_handling; mod natip; +mod state; mod test; use super::NatTranslationData; use crate::stateful::allocator::{AllocationResult, AllocatorError, NatAllocator}; use crate::stateful::allocator_writer::NatAllocatorReader; -use crate::stateful::apalloc::AllocatedIpPort; -use crate::stateful::apalloc::{NatDefaultAllocator, NatIpWithBitmap}; +use crate::stateful::apalloc::{AllocatedIpPort, NatDefaultAllocator, NatIpWithBitmap}; use crate::stateful::natip::NatIp; +use crate::stateful::state::NatFlowState; pub use allocator_writer::NatAllocatorWriter; use concurrency::sync::Arc; use flow_entry::flow_table::FlowTable; @@ -24,7 +25,7 @@ use net::headers::{Net, Transport, TryIp, TryIpMut, TryTransportMut}; use net::packet::{DoneReason, Packet, VpcDiscriminant}; use net::{FlowKey, IpProtoKey}; use pipeline::{NetworkFunction, PipelineData}; -use std::fmt::{Debug, Display}; +use std::fmt::Debug; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::time::{Duration, Instant}; @@ -56,27 +57,6 @@ enum StatefulNatError { UnexpectedKeyVariant, } -#[derive(Debug, Clone)] -pub(crate) struct NatFlowState { - src_alloc: Option>, - dst_alloc: Option>, - idle_timeout: Duration, -} - -impl Display for NatFlowState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.src_alloc.as_ref() { - Some(a) => write!(f, "({}:{}, ", a.ip(), a.port().as_u16()), - None => write!(f, "(unchanged, "), - }?; - match self.dst_alloc.as_ref() { - Some(a) => write!(f, "{}:{})", a.ip(), a.port().as_u16()), - None => write!(f, "unchanged)"), - }?; - write!(f, "[{}s]", self.idle_timeout.as_secs()) - } -} - /// A stateful NAT processor, implementing the [`NetworkFunction`] trait. [`StatefulNat`] processes /// packets to run source or destination Network Address Translation (NAT) on their IP addresses. #[derive(Debug)] @@ -145,10 +125,9 @@ impl StatefulNat { let flow_info = packet.meta_mut().flow_info.as_mut()?; let value = flow_info.locked.read().unwrap(); let state = value.nat_state.as_ref()?.extract_ref::>()?; - flow_info.reset_expiry(state.idle_timeout).ok()?; + flow_info.reset_expiry(state.idle_timeout()).ok()?; flow_info.set_genid_pair(genid); // unconditionally (FIXME) - let translation_data = Self::get_translation_data(&state.src_alloc, &state.dst_alloc); - Some(translation_data) + Some(state.translation_data()) } // Look up for a session by passing the parameters that make up a flow key. @@ -167,8 +146,7 @@ impl StatefulNat { let flow_info = self.sessions.lookup(&flow_key)?; let value = flow_info.locked.read().unwrap(); let state = value.nat_state.as_ref()?.extract_ref::>()?; - let translation_data = Self::get_translation_data(&state.src_alloc, &state.dst_alloc); - Some((translation_data, state.idle_timeout)) + Some((state.translation_data(), state.idle_timeout())) } fn session_timeout_time(timeout: Duration) -> Instant { @@ -208,7 +186,7 @@ impl StatefulNat { let reverse_key = Self::new_reverse_session(flow_key, &alloc, dst_vpc_id)?; // build NAT state for both flows - let (forward_state, reverse_state) = Self::new_states_from_alloc(alloc, idle_timeout); + let (forward_state, reverse_state) = NatFlowState::new_pair_from_alloc(alloc, idle_timeout); // build a flow pair from the keys (without NAT state) let expires_at = Self::session_timeout_time(idle_timeout); @@ -327,23 +305,6 @@ impl StatefulNat { } } - fn new_states_from_alloc( - alloc: AllocationResult>, - idle_timeout: Duration, - ) -> (NatFlowState, NatFlowState) { - let forward_state = NatFlowState { - src_alloc: alloc.src, - dst_alloc: alloc.dst, - idle_timeout, - }; - let reverse_state = NatFlowState { - src_alloc: alloc.return_src, - dst_alloc: alloc.return_dst, - idle_timeout, - }; - (forward_state, reverse_state) - } - fn new_reverse_session( flow_key: &FlowKey, alloc: &AllocationResult>, @@ -359,18 +320,14 @@ impl StatefulNat { // - tuple r.init = (src: f.nated.dst, dst: f.nated.src) // - mapping r.nated = (src: f.init.dst, dst: f.init.src) - let (reverse_src_addr, allocated_src_port_to_use) = - match alloc.dst.as_ref().map(|a| (a.ip(), a.port())) { - Some((ip, port)) => (ip.to_ip_addr(), Some(port)), - // No destination NAT for forward session: - // f.init:(src: a, dst: b) -> f.nated:(src: A, dst: b) - // - // Reverse session will be: - // r.init:(src: b, dst: A) -> r.nated:(src: b, dst: a) - // - // Use destination IP and port from forward tuple. - None => (*flow_key.data().dst_ip(), None), - }; + // No destination NAT for forward session: + // f.init:(src: a, dst: b) -> f.nated:(src: A, dst: b) + // + // Reverse session will be: + // r.init:(src: b, dst: A) -> r.nated:(src: b, dst: a) + // + // Use destination IP and port from forward tuple. + let reverse_src_addr = *flow_key.data().dst_ip(); let (reverse_dst_addr, allocated_dst_port_to_use) = match alloc.src.as_ref().map(|a| (a.ip(), a.port())) { Some((ip, port)) => (ip.to_ip_addr(), Some(port)), @@ -380,31 +337,6 @@ impl StatefulNat { // Reverse the forward protocol key... let mut reverse_proto_key = flow_key.data().proto_key_info().reverse(); // ... but adjust ports as necessary (use allocated ports for the reverse session) - if let Some(src_port) = allocated_src_port_to_use { - match reverse_proto_key { - IpProtoKey::Tcp(_) | IpProtoKey::Udp(_) => { - reverse_proto_key - .try_set_src_port( - src_port - .try_into() - .map_err(|_| StatefulNatError::InvalidPort(src_port.as_u16()))?, - ) - .map_err(|_| StatefulNatError::BadTransportHeader)?; - } - IpProtoKey::Icmp(IcmpProtoKey::QueryMsgData(_)) => { - // For ICMP, we only need to set the identifier once. Use the "dst_port" below if - // available, otherwise, use the "src_port" here. - if allocated_dst_port_to_use.is_none() { - reverse_proto_key - .try_set_identifier(src_port.as_u16()) - .map_err(|_| StatefulNatError::BadTransportHeader)?; - } - } - IpProtoKey::Icmp(_) => { - return Err(StatefulNatError::UnexpectedKeyVariant); - } - } - } if let Some(dst_port) = allocated_dst_port_to_use { match reverse_proto_key { IpProtoKey::Tcp(_) | IpProtoKey::Udp(_) => { @@ -472,7 +404,7 @@ impl StatefulNat { debug!("{}: Allocated translation data: {alloc}", self.name()); - let translation_data = Self::get_translation_data(&alloc.src, &alloc.dst); + let translation_data = Self::get_translation_data(&alloc.src, &None); self.create_flow_pair(packet, &flow_key, alloc)?; diff --git a/nat/src/stateful/state.rs b/nat/src/stateful/state.rs new file mode 100644 index 000000000..e694c0815 --- /dev/null +++ b/nat/src/stateful/state.rs @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use super::NatIpWithBitmap; +use super::allocator::AllocationResult; +use super::apalloc::AllocatedIpPort; +use crate::{NatPort, NatTranslationData}; +use std::fmt::Display; +use std::time::Duration; + +#[derive(Debug, Clone)] +pub(crate) enum NatFlowState { + Allocated(AllocatedFlowState), + Computed(ComputedFlowState), +} + +impl NatFlowState { + pub(crate) fn new_pair_from_alloc( + alloc: AllocationResult>, + idle_timeout: Duration, + ) -> (Self, Self) { + ( + Self::Allocated(AllocatedFlowState { + src_alloc: alloc.src, + dst_alloc: None, + idle_timeout, + }), + Self::Computed(ComputedFlowState { + src: None, + dst: alloc.return_dst.map(|(addr, port)| { + ( + I::try_from_addr(addr).unwrap_or_else(|()| unreachable!()), + port, + ) + }), + idle_timeout, + }), + ) + } + + pub(crate) fn idle_timeout(&self) -> Duration { + match self { + NatFlowState::Allocated(allocated) => allocated.idle_timeout, + NatFlowState::Computed(computed) => computed.idle_timeout, + } + } + + pub(crate) fn translation_data(&self) -> NatTranslationData { + match self { + NatFlowState::Allocated(allocated) => allocated.translation_data(), + NatFlowState::Computed(computed) => computed.translation_data(), + } + } + + pub(crate) fn reverse_translation_data(&self) -> NatTranslationData { + match self { + NatFlowState::Allocated(allocated) => allocated.reverse_translation_data(), + NatFlowState::Computed(computed) => computed.reverse_translation_data(), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct AllocatedFlowState { + src_alloc: Option>, + dst_alloc: Option>, + idle_timeout: Duration, +} + +impl AllocatedFlowState { + fn build_translation_data( + src: Option<&AllocatedIpPort>, + dst: Option<&AllocatedIpPort>, + ) -> NatTranslationData { + let (src_addr, src_port) = src + .as_ref() + .map(|a| (a.ip().to_ip_addr(), a.port())) + .unzip(); + let (dst_addr, dst_port) = dst + .as_ref() + .map(|a| (a.ip().to_ip_addr(), a.port())) + .unzip(); + NatTranslationData::new(src_addr, dst_addr, src_port, dst_port) + } + + fn translation_data(&self) -> NatTranslationData { + Self::build_translation_data(self.src_alloc.as_ref(), self.dst_alloc.as_ref()) + } + + fn reverse_translation_data(&self) -> NatTranslationData { + Self::build_translation_data(self.dst_alloc.as_ref(), self.src_alloc.as_ref()) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ComputedFlowState { + src: Option<(I, NatPort)>, + dst: Option<(I, NatPort)>, + idle_timeout: Duration, +} + +impl ComputedFlowState { + fn build_translation_data( + src: Option<(I, NatPort)>, + dst: Option<(I, NatPort)>, + ) -> NatTranslationData { + let (src_addr, src_port) = src.map(|(ip, port)| (ip.to_ip_addr(), port)).unzip(); + let (dst_addr, dst_port) = dst.map(|(ip, port)| (ip.to_ip_addr(), port)).unzip(); + NatTranslationData::new(src_addr, dst_addr, src_port, dst_port) + } + + fn translation_data(&self) -> NatTranslationData { + Self::build_translation_data(self.src, self.dst) + } + + fn reverse_translation_data(&self) -> NatTranslationData { + Self::build_translation_data(self.dst, self.src) + } +} + +// From / TryFrom + +impl From> for NatFlowState { + fn from(value: AllocatedFlowState) -> Self { + NatFlowState::Allocated(value) + } +} + +impl From> for NatFlowState { + fn from(value: ComputedFlowState) -> Self { + NatFlowState::Computed(value) + } +} + +impl TryFrom> for AllocatedFlowState { + type Error = (); + + fn try_from(value: NatFlowState) -> Result { + match value { + NatFlowState::Allocated(allocated) => Ok(allocated), + NatFlowState::Computed(_) => Err(()), + } + } +} + +impl From> for ComputedFlowState { + fn from(value: AllocatedFlowState) -> Self { + ComputedFlowState { + src: value.src_alloc.map(|a| (a.ip(), a.port())), + dst: value.dst_alloc.map(|a| (a.ip(), a.port())), + idle_timeout: value.idle_timeout, + } + } +} + +impl From> for ComputedFlowState { + fn from(value: NatFlowState) -> Self { + match value { + NatFlowState::Allocated(allocated) => allocated.into(), + NatFlowState::Computed(computed) => computed, + } + } +} + +// Display + +impl Display for NatFlowState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + NatFlowState::Allocated(allocated) => allocated.fmt(f), + NatFlowState::Computed(computed) => computed.fmt(f), + } + } +} + +impl Display for AllocatedFlowState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.src_alloc.as_ref() { + Some(a) => write!(f, "({}:{}, ", a.ip(), a.port().as_u16()), + None => write!(f, "(unchanged, "), + }?; + match self.dst_alloc.as_ref() { + Some(a) => write!(f, "{}:{})", a.ip(), a.port().as_u16()), + None => write!(f, "unchanged)"), + }?; + write!(f, "[{}s]", self.idle_timeout.as_secs()) + } +} + +impl Display for ComputedFlowState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.src.as_ref() { + Some((ip, port)) => write!(f, "({ip}:{}, ", port.as_u16()), + None => write!(f, "(unchanged, "), + }?; + match self.dst.as_ref() { + Some((ip, port)) => write!(f, "{ip}:{})", port.as_u16()), + None => write!(f, "unchanged)"), + }?; + write!(f, "[{}s]", self.idle_timeout.as_secs()) + } +} diff --git a/nat/src/stateful/test.rs b/nat/src/stateful/test.rs index 6629b4436..d84742fe9 100644 --- a/nat/src/stateful/test.rs +++ b/nat/src/stateful/test.rs @@ -849,9 +849,9 @@ mod tests { orig_dst_port, ); assert_eq!(done_reason, None); - assert_eq!(output_src, addr_v4(target_src)); assert_eq!(output_dst, addr_v4(orig_dst)); + // Reverse path let ( return_output_src, @@ -868,11 +868,11 @@ mod tests { output_dst_port, output_src_port, ); + assert_eq!(done_reason, None); assert_eq!(return_output_src, addr_v4(orig_dst)); assert_eq!(return_output_dst, addr_v4(orig_src)); assert_eq!(return_output_src_port, orig_dst_port); assert_eq!(return_output_dst_port, orig_src_port); - assert_eq!(done_reason, None); // Using the default expose let (orig_src, orig_dst, orig_src_port, orig_dst_port) = @@ -888,9 +888,9 @@ mod tests { orig_dst_port, ); assert_eq!(done_reason, None); - assert_eq!(output_src, addr_v4(target_src)); assert_eq!(output_dst, addr_v4(orig_dst)); + // Reverse path let ( return_output_src, @@ -907,11 +907,11 @@ mod tests { output_dst_port, output_src_port, ); + assert_eq!(done_reason, None); assert_eq!(return_output_src, addr_v4(orig_dst)); assert_eq!(return_output_dst, addr_v4(orig_src)); assert_eq!(return_output_src_port, orig_dst_port); assert_eq!(return_output_dst_port, orig_src_port); - assert_eq!(done_reason, None); } fn build_overlay_3vpcs_unidirectional_nat_overlapping_addr() -> Overlay {