diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9ae72286c..9e42b837a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -105,15 +105,12 @@ jobs: matrix: include: - name: "easytier" - opts: "-E 'not test(connector::dns_connector::tests) and not test(tests::three_node)' --test-threads 1 --no-fail-fast" + opts: "-E 'not test(tests::three_node)' --test-threads 1 --no-fail-fast" - - name: "easytier::connector::dns_connector::tests" - opts: "-E 'test(connector::dns_connector::tests)' --test-threads 1 --no-fail-fast" - - - name: "easytier::tests::three_node" + - name: "three_node" opts: "-E 'test(tests::three_node) and not test(subnet_proxy_three_node_test)' --test-threads 1 --no-fail-fast" - - name: "easytier::tests::three_node::subnet_proxy_three_node_test" + - name: "three_node::subnet_proxy_three_node_test" opts: "-E 'test(subnet_proxy_three_node_test)' --test-threads 1 --no-fail-fast" steps: - uses: actions/checkout@v3 diff --git a/Cargo.lock b/Cargo.lock index b27a6825a..77e118d2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5238,11 +5238,10 @@ dependencies = [ [[package]] name = "num-bigint-dig" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +checksum = "e661dda6640fad38e827a6d4a310ff4763082116fe217f279885c97f511bb0b7" dependencies = [ - "byteorder", "lazy_static", "libm", "num-integer", diff --git a/README.md b/README.md index 8b6cd7ce9..78cda8621 100644 --- a/README.md +++ b/README.md @@ -48,31 +48,36 @@ Choose the installation method that best suits your needs: +Linux (Recommended): ```bash -# 1. Download pre-built binary (Recommended, All platforms supported) -# Visit https://github.com/EasyTier/EasyTier/releases +curl -fsSL "https://github.com/EasyTier/EasyTier/blob/main/script/install.sh?raw=true" | sudo bash -s install +``` -# 2. Install via cargo (Latest development version) -cargo install --git https://github.com/EasyTier/EasyTier.git easytier +Homebrew (MacOS/Linux): +```bash +brew tap brewforge/chinese +brew install --cask easytier-gui +``` + +Windows (Recommended, run with administrator privileges): +```powershell +irm "https://github.com/EasyTier/EasyTier/blob/main/script/install.ps1?raw=true" | iex +``` -# 3. Install via Docker -# See https://easytier.cn/en/guide/installation.html#installation-methods +Install via cargo (Latest development version): +```bash +cargo install --git https://github.com/EasyTier/EasyTier.git easytier +``` -# 4. Linux Quick Install -wget -O- https://raw.githubusercontent.com/EasyTier/EasyTier/main/script/install.sh | sudo bash -s install +[Install pre-built binary](https://github.com/EasyTier/EasyTier/releases) (Recommended, All platforms supported) -# 5. MacOS via Homebrew -brew tap brewforge/chinese -brew install --cask easytier-gui +[Install via Docker](https://easytier.cn/en/guide/installation.html#installation-methods) -# 6. OpenWrt Luci Web UI -# Visit https://github.com/EasyTier/luci-app-easytier +[Install OpenWrt ipk package](https://github.com/EasyTier/luci-app-easytier) -# 7. (Optional) Install shell completions: -easytier-core --gen-autocomplete fish > ~/.config/fish/completions/easytier-core.fish -easytier-cli gen-autocomplete fish > ~/.config/fish/completions/easytier-cli.fish +Additional steps: -``` +[One-Click Register Service](https://easytier.cn/en/guide/network/oneclick-install-as-service.html) (Automatically start when the system boots and run in the background) ### 🚀 Basic Usage diff --git a/README_CN.md b/README_CN.md index b90560a15..2e9d31fa9 100644 --- a/README_CN.md +++ b/README_CN.md @@ -48,32 +48,36 @@ 选择最适合您需求的安装方式: +Linux(推荐): ```bash -# 1. 下载预编译二进制文件(推荐,支持所有平台) -# 访问 https://github.com/EasyTier/EasyTier/releases +curl -fsSL "https://github.com/EasyTier/EasyTier/blob/main/script/install.sh?raw=true" | sudo bash -s install +``` -# 2. 通过 cargo 安装(最新开发版本) -cargo install --git https://github.com/EasyTier/EasyTier.git easytier +Homebrew(MacOS/Linux): +```bash +brew tap brewforge/chinese +brew install --cask easytier-gui +``` + +Windows(推荐,请以管理员权限运行): +```powershell +irm "https://github.com/EasyTier/EasyTier/blob/main/script/install.ps1?raw=true" | iex +``` -# 3. 通过 Docker 安装 -# 参见 https://easytier.cn/guide/installation.html#%E5%AE%89%E8%A3%85%E6%96%B9%E5%BC%8F +通过 cargo 安装(最新开发版本): +```bash +cargo install --git https://github.com/EasyTier/EasyTier.git easytier +``` -# 4. Linux 快速安装 -wget -O- https://raw.githubusercontent.com/EasyTier/EasyTier/main/script/install.sh | sudo bash -s install +[下载预编译文件](https://github.com/EasyTier/EasyTier/releases)(推荐,支持所有平台) -# 5. MacOS 通过 Homebrew 安装 -brew tap brewforge/chinese -brew install --cask easytier-gui +[通过 Docker 安装](https://easytier.cn/guide/installation.html#%E5%AE%89%E8%A3%85%E6%96%B9%E5%BC%8F) -# 6. OpenWrt Luci Web 界面 -# 访问 https://github.com/EasyTier/luci-app-easytier +[安装 OpenWrt ipk 软件包](https://github.com/EasyTier/luci-app-easytier) -# 7.(可选)安装 Shell 补全功能: -# Fish 补全 -easytier-core --gen-autocomplete fish > ~/.config/fish/completions/easytier-core.fish -easytier-cli gen-autocomplete fish > ~/.config/fish/completions/easytier-cli.fish +附加步骤: -``` +[一键注册系统服务](https://easytier.cn/guide/network/oneclick-install-as-service.html)(系统启动时自动后台运行) ### 🚀 基本用法 diff --git a/easytier-gui/src-tauri/src/lib.rs b/easytier-gui/src-tauri/src/lib.rs index 354bce90b..b72a50251 100644 --- a/easytier-gui/src-tauri/src/lib.rs +++ b/easytier-gui/src-tauri/src/lib.rs @@ -472,11 +472,17 @@ async fn init_web_client(app: AppHandle, url: Option) -> Result<(), Stri let hooks = Arc::new(manager::GuiHooks { app: app.clone() }); - let web_client = - web_client::run_web_client(url.as_str(), None, None, instance_manager, Some(hooks)) - .await - .with_context(|| "Failed to initialize web client") - .map_err(|e| format!("{:#}", e))?; + let web_client = web_client::run_web_client( + url.as_str(), + None, + None, + false, + instance_manager, + Some(hooks), + ) + .await + .with_context(|| "Failed to initialize web client") + .map_err(|e| format!("{:#}", e))?; *web_client_guard = Some(web_client); Ok(()) } diff --git a/easytier-web/src/client_manager/mod.rs b/easytier-web/src/client_manager/mod.rs index a3e510bc5..390daccb2 100644 --- a/easytier-web/src/client_manager/mod.rs +++ b/easytier-web/src/client_manager/mod.rs @@ -13,6 +13,7 @@ use easytier::{ }, rpc_service::remote_client::{self, RemoteClientManager}, tunnel::TunnelListener, + web_client::security, }; use maxminddb::geoip2; use session::{Location, Session}; @@ -99,12 +100,20 @@ impl ClientManager { let feature_flags = self.feature_flags.clone(); self.tasks.spawn(async move { while let Ok(tunnel) = listener.accept().await { + let (tunnel, secure) = match security::accept_or_upgrade_server_tunnel(tunnel).await { + Ok(v) => v, + Err(error) => { + tracing::warn!(%error, "failed to accept secure tunnel, dropping connection"); + continue; + } + }; let info = tunnel.info().unwrap(); let client_url: url::Url = info.remote_addr.unwrap().into(); let location = Self::lookup_location(&client_url, geoip_db.clone()); tracing::info!( - "New session from {:?}, location: {:?}", + "New session from {:?}, secure: {}, location: {:?}", client_url, + secure, location ); let mut session = Session::new( @@ -326,26 +335,36 @@ mod tests { connector, "test", "test", + false, Arc::new(NetworkInstanceManager::new()), None, ); wait_for_condition( - || async { mgr.client_sessions.len() == 1 }, - Duration::from_secs(6), + || async { !mgr.client_sessions.is_empty() }, + Duration::from_secs(12), ) .await; - let mut a = mgr - .client_sessions - .iter() - .next() - .unwrap() - .data() - .read() - .await - .heartbeat_waiter(); - let req = a.recv().await.unwrap(); + let req = tokio::time::timeout(Duration::from_secs(12), async { + loop { + let session = mgr + .client_sessions + .iter() + .next() + .map(|item| item.value().clone()); + let Some(session) = session else { + tokio::time::sleep(Duration::from_millis(100)).await; + continue; + }; + let mut waiter = session.data().read().await.heartbeat_waiter(); + if let Ok(req) = waiter.recv().await { + break req; + } + } + }) + .await + .unwrap(); println!("{:?}", req); println!("{:?}", mgr); } diff --git a/easytier-web/src/client_manager/session.rs b/easytier-web/src/client_manager/session.rs index 4e7a3ecd9..40b3980e6 100644 --- a/easytier-web/src/client_manager/session.rs +++ b/easytier-web/src/client_manager/session.rs @@ -169,6 +169,16 @@ impl WebServerService for SessionRpcService { } ret } + + async fn get_feature( + &self, + _: BaseController, + _: easytier::proto::web::GetFeatureRequest, + ) -> rpc_types::error::Result { + Ok(easytier::proto::web::GetFeatureResponse { + support_encryption: true, + }) + } } pub struct Session { diff --git a/easytier/docs/credential_peer.md b/easytier/docs/credential_peer.md new file mode 100644 index 000000000..bf9837ad9 --- /dev/null +++ b/easytier/docs/credential_peer.md @@ -0,0 +1,724 @@ +# 临时凭据(Credential)系统实现计划 + +## Context + +EasyTier 的 secure mode 已实现 Noise XX 握手 + X25519 静态公钥认证。当前节点通过 `network_secret` 双向确认身份。用户需要一种"临时凭据"机制: + +- **管理节点**(任何持有 network_secret 的节点)可为当前网络生成凭据 +- **新节点**可使用凭据代替 `network_secret` 加入网络 +- **管理节点**可撤销凭据 +- **撤销后**,使用该凭据接入的节点被全网踢出 + +**核心设计**:凭据 = X25519 密钥对。完全复用现有 Noise `Noise_XX_25519_ChaChaPoly_SHA256` 握手流程,无需修改握手消息格式。通过 OSPF 路由同步传播可信公钥列表,撤销时全网自然断开。 + +## 整体架构 + +``` +凭据 = X25519 密钥对 + - 管理节点生成密钥对,将公钥加入可信列表 + - 临时节点持有私钥,用作 Noise static key + - 全网通过 OSPF 路由同步可信公钥列表 + +管理节点 (持有 network_secret): + 1. generate_credential() → 生成 X25519 密钥对 + 2. 公钥记入 trusted_credential_pubkeys → 随 RoutePeerInfo 通过 OSPF 传播 + 3. revoke → 从 trusted 列表移除 → OSPF 同步 → 全网感知 + +临时节点 (持有凭据私钥): + 1. 使用凭据私钥作为 SecureModeConfig.local_private_key + 2. Noise 握手完全走现有流程(XX 模式交换 static pubkey) + 3. 不持有 network_secret,secret_proof 验证会失败,但公钥在可信列表中即可 + 4. RoutePeerInfo.noise_static_pubkey 自然携带凭据公钥 + +校验逻辑(每个节点在路由同步时执行): + 1. 从全网 RoutePeerInfo 中收集管理节点的 trusted_credential_pubkeys(取并集) + **安全约束: 仅信任 secure_auth_level=NetworkSecretConfirmed 的节点发布的列表** + 临时节点(CredentialAuthenticated)发布的 trusted_credential_pubkeys 必须被忽略 + 2. 对每个 peer,如果其 secure_auth_level < NetworkSecretConfirmed: + - 检查其 noise_static_pubkey 是否在可信公钥集合中 + - 不在 → 从路由表移除 → 断开连接 +``` + +## 详细设计 + +### Step 1: Protobuf 定义 + +**文件: `easytier/src/proto/peer_rpc.proto`** + +在 `RoutePeerInfo` 新增字段(利用已有 `noise_static_pubkey` 字段 #18): +```protobuf +message TrustedCredentialPubkey { + bytes pubkey = 1; // X25519 公钥 (32 bytes) + repeated string groups = 2; // 该凭据所属的 ACL group(管理节点声明,无需 proof) + bool allow_relay = 3; // 是否允许该临时节点提供 peer relay 能力 + int64 expiry_unix = 4; // 必选:过期时间(Unix timestamp),过期后自动失效 + repeated string allowed_proxy_cidrs = 5; // 允许该临时节点声明的 proxy_cidrs 范围 +} + +message RoutePeerInfo { + // ... existing fields 1-18 ... + // 管理节点发布的可信凭据公钥列表(含 group 关联) + repeated TrustedCredentialPubkey trusted_credential_pubkeys = 19; +} +``` + +临时节点无需新字段——其 `noise_static_pubkey`(字段 18)已经在 OSPF 中传播,只需在校验端判断该公钥是否在可信列表中。 + +新增 `SecureAuthLevel` 枚举值: +```protobuf +enum SecureAuthLevel { + None = 0; + EncryptedUnauthenticated = 1; + SharedNodePubkeyVerified = 2; + NetworkSecretConfirmed = 3; + CredentialAuthenticated = 4; // 新增:凭据公钥已验证 +} +``` + +**文件: `easytier/src/proto/api_instance.proto`** + +新增凭据管理 RPC: +```protobuf +message GenerateCredentialRequest { + repeated string groups = 1; // 可选: 凭据关联的 ACL group + bool allow_relay = 2; // 可选: 是否允许该临时节点提供 peer relay + repeated string allowed_proxy_cidrs = 3; // 可选: 限制可声明的 proxy_cidrs + int64 ttl_seconds = 4; // 必选: 凭据有效期(秒) +} +message GenerateCredentialResponse { + string credential_id = 1; // 公钥的 base64 + string credential_secret = 2; // 私钥的 base64 +} +message RevokeCredentialRequest { string credential_id = 1; } +message RevokeCredentialResponse { bool success = 1; } +message ListCredentialsRequest {} +message CredentialInfo { + string credential_id = 1; // 公钥 base64 + google.protobuf.Timestamp created_at = 2; +} +message ListCredentialsResponse { repeated CredentialInfo credentials = 1; } + +service CredentialManageRpc { + rpc GenerateCredential(GenerateCredentialRequest) returns (GenerateCredentialResponse); + rpc RevokeCredential(RevokeCredentialRequest) returns (RevokeCredentialResponse); + rpc ListCredentials(ListCredentialsRequest) returns (ListCredentialsResponse); +} +``` + +### Step 2: 凭据管理模块 + +**新文件: `easytier/src/peers/credential_manager.rs`** + +```rust +use x25519_dalek::{StaticSecret, PublicKey}; + +pub struct CredentialManager { + // 本节点管理的可信凭据 + credentials: DashMap, // credential_id (pubkey base64) -> entry + storage_path: Option, // 可选: 凭据 JSON 文件路径 +} + +struct CredentialEntry { + pubkey_bytes: [u8; 32], + groups: Vec, // 关联的 ACL group(管理节点声明) + allow_relay: bool, // 是否允许 relay + allowed_proxy_cidrs: Vec, // 允许声明的 proxy_cidrs 范围 + expiry: SystemTime, // 过期时间(必选) + created_at: SystemTime, +} + +impl CredentialManager { + /// 生成新凭据(含 group 关联) + /// 返回 (credential_id=公钥base64, credential_secret=私钥base64) + pub fn generate_credential(&self, groups: Vec, allow_relay: bool, expiry: SystemTime) -> (String, String) { + let private = StaticSecret::random_from_rng(OsRng); + let public = PublicKey::from(&private); + let id = BASE64_STANDARD.encode(public.as_bytes()); + let secret = BASE64_STANDARD.encode(private.as_bytes()); + self.credentials.insert(id.clone(), CredentialEntry { + pubkey_bytes: *public.as_bytes(), + groups, + allow_relay, + expiry, // 由调用方传入 + created_at: SystemTime::now(), + }); + self.save_to_disk(); // 持久化 + (id, secret) + } + + /// 撤销凭据 + pub fn revoke_credential(&self, credential_id: &str) -> bool; + + /// 获取可信凭据列表(用于 RoutePeerInfo.trusted_credential_pubkeys) + pub fn get_trusted_pubkeys(&self) -> Vec; + + /// 列出所有凭据 + pub fn list_credentials(&self) -> Vec; +} +``` + +### Step 3: Noise 握手适配(最小改动) + +**文件: `easytier/src/peers/peer_conn.rs`** + +临时节点的握手流程**完全不需要修改**,因为: +- 临时节点配置 `SecureModeConfig { enabled: true, local_private_key: 凭据私钥, local_public_key: 凭据公钥 }` +- `get_keypair()` (line 434) 自然返回凭据密钥对 +- Noise XX 握手正常交换 static pubkey +- 唯一区别:`secret_proof_32` 验证会失败(临时节点没有 network_secret) + +需要修改 `do_noise_handshake_as_server()` (line 934): +- **当前行为**: `secret_proof` 验证失败 → 返回错误断开连接 (line 1059) +- **修改为**: `secret_proof` 验证失败时,不立即断开,而是将 `secure_auth_level` 保持为 `EncryptedUnauthenticated` +- 后续由 OSPF 路由同步阶段决定该 peer 是否可信(公钥是否在 trusted 列表中) + +同样修改 `do_noise_handshake_as_client()` (line 680): +- 当临时节点连接管理节点时,`secret_proof` 验证失败不应报错 +- 临时节点可以通过 `pinned_remote_pubkey` 或不验证来处理 + +**NoiseHandshakeResult** 新增: +```rust +// 标记此连接使用了凭据而非 network_secret +is_credential_conn: bool, +``` + +### Step 4: RoutePeerInfo 传播凭据信息 + +**文件: `easytier/src/peers/peer_ospf_route.rs`** + +修改 `RoutePeerInfo::new_updated_self()` (line 164): +- 管理节点(持有 network_secret): 从 `CredentialManager.get_trusted_pubkeys()` 获取列表,填入 `trusted_credential_pubkeys` +- 临时节点: **不填写 `trusted_credential_pubkeys`**(该字段留空),即使收到其他管理节点传播的列表也不转发 + - 实现方式: 在 `new_updated_self()` 中检查节点身份,临时节点跳过 trusted_credential_pubkeys 填充 +- 临时节点: 无需额外操作,`noise_static_pubkey` 已自然包含凭据公钥 + +### Step 5: 全网校验与自动踢出(核心逻辑) + +**文件: `easytier/src/peers/peer_ospf_route.rs`** + +在 `SyncedRouteInfo` 中新增: +```rust +// 从全网管理节点汇总的可信凭据公钥集合 +trusted_credential_pubkeys: DashSet>, // pubkey bytes +``` + +新增校验方法(类似 `verify_and_update_group_trusts` line 743): +```rust +fn verify_credential_peers(&self, peer_infos: &[RoutePeerInfo]) { + // 1. 收集管理节点的 trusted_credential_pubkeys(取并集) + // **安全约束: 仅信任 secret_digest 与本网络匹配的节点(即持有 network_secret 的管理节点)** + // 临时节点的 trusted_credential_pubkeys 直接忽略,防止恶意临时节点自我授权 + let mut all_trusted = HashSet::new(); + for info in peer_infos { + if self.is_peer_secret_verified(info.peer_id) { + // 该 peer 通过了 network_secret 双向确认,是合法管理节点 + for tc in &info.trusted_credential_pubkeys { + all_trusted.insert(tc.pubkey.clone()); + } + } + // else: 该 peer 未通过 network_secret 确认(含临时节点),忽略其 trusted 列表 + } + self.trusted_credential_pubkeys = all_trusted; + + // 2. 检查所有 peer 的凭据状态 + for info in peer_infos { + if !self.is_peer_secret_verified(info.peer_id) + && !info.noise_static_pubkey.is_empty() + { + if !self.trusted_credential_pubkeys.contains(&info.noise_static_pubkey) { + // 该 peer 既不持有 network_secret,其公钥也不在可信列表中 + // → 标记为不可信,后续从路由表移除 + self.mark_peer_untrusted(info.peer_id); + } + } + } +} +``` + +在 `do_sync_route_info()` (line 2614) 中调用此校验。 + +在路由表构建中(`update_route_table_and_cached_local_conn_bitmap()`): +- 不可信 peer 不加入路由图 +- 已连接的不可信 peer 调用 `PeerMap::close_peer()` 断开 + +**判断 peer 是否持有 network_secret**: 利用现有 `secret_digest` 字段。管理节点的 `RoutePeerInfo` 中 `secret_digest` 与本节点匹配,说明双方持有相同的 network_secret。 + +### Step 6: GlobalCtx / Config 集成 + +**文件: `easytier/src/common/global_ctx.rs`** + +在 `GlobalCtx` 新增: +```rust +credential_manager: Arc, // 所有节点都持有,管理节点用于生成/撤销 +``` + +**文件: `easytier/src/common/global_ctx.rs` - `GlobalCtxEvent`** + +新增: +```rust +CredentialChanged, // 触发 OSPF 立即同步 +``` + +**文件: `easytier/src/common/config.rs`** + +临时节点的配置方式: 直接使用凭据私钥作为 `SecureModeConfig.local_private_key`。 +可在 `TomlConfigLoader` 中新增便捷字段或 CLI 参数: +- `--credential <私钥base64>`: 临时节点使用凭据私钥加入网络 +- `--credential-file `: 管理节点指定凭据存储 JSON 文件路径 + +### Step 7: RPC 服务 + CLI + +**文件: `easytier/src/peers/rpc_service.rs`** + +实现 `CredentialManageRpc`,参考 `PeerManagerRpcService` 模式。 + +**CLI** (`easytier-cli`): +``` +easytier-cli credential generate + 输出: credential_id=<公钥base64> credential_secret=<私钥base64> + +easytier-cli credential revoke +easytier-cli credential list +``` + +**临时节点启动**: +```bash +# 方式1: 直接传入凭据私钥 +easytier-core --network-name test \ + --secure-mode \ + --credential <私钥base64> \ + --peers tcp://管理节点:11010 + +# 内部实现: 将凭据私钥设为 SecureModeConfig.local_private_key +``` + +### Step 8: 连接时验证(握手后快速拒绝,必选) + +在 `do_noise_handshake_as_server()` 完成后,**必须**进行快速检查: +- 如果对端 `secret_proof` 验证失败(非管理节点),且对端 `noise_static_pubkey` 不在本节点已知的 `trusted_credential_pubkeys` 中 +- 立即断开连接 + +这是**必选的安全措施**(非可选优化)。因为 Step 3 放宽了 secret_proof 失败的处理,如果不做快速拒绝,任何随机节点都能与管理节点建立加密连接并持有,浪费资源。 + +```rust +// 在 handshake 完成后 +if !secret_proof_verified { + let remote_pubkey = handshake_result.remote_static_pubkey; + if !self.global_ctx.credential_manager.is_pubkey_trusted(&remote_pubkey) { + return Err(Error::AuthError("unknown credential".to_string())); + } + // 公钥在 trusted 列表中 → 允许连接,标记为 CredentialAuthenticated + handshake_result.secure_auth_level = SecureAuthLevel::CredentialAuthenticated; +} +``` + +## 关键文件清单 + +| 文件 | 修改内容 | +|------|----------| +| `easytier/src/proto/peer_rpc.proto` | `RoutePeerInfo` 加 `trusted_credential_pubkeys`; `SecureAuthLevel` 加 `CredentialAuthenticated` | +| `easytier/src/proto/api_instance.proto` | 新增 `CredentialManageRpc` 服务及消息定义 | +| `easytier/src/peers/credential_manager.rs` | **新文件** — 凭据管理器(密钥对生成/撤销/列表) | +| `easytier/src/peers/mod.rs` | 导出 credential_manager | +| `easytier/src/peers/peer_ospf_route.rs` | `new_updated_self()` 填 trusted_pubkeys; 新增 `verify_credential_peers()`; 路由表过滤 | +| `easytier/src/peers/peer_conn.rs` | `do_noise_handshake_as_server()` 放宽 secret_proof 失败为非致命; 可选握手阶段快速拒绝 | +| `easytier/src/peers/peer_manager.rs` | 集成 CredentialManager; 不可信 peer 断连逻辑 | +| `easytier/src/common/global_ctx.rs` | 持有 CredentialManager; 新增 CredentialChanged 事件 | +| `easytier/src/common/config.rs` | 新增 `--credential` 参数处理 | +| `easytier/src/peers/rpc_service.rs` | 实现 CredentialManageRpc | +| `easytier/src/proto/common.rs` | SecureModeConfig 可选: credential 模式识别 | + +## 复用现有机制 + +| 现有机制 | 路径 | 复用方式 | +|----------|------|----------| +| Noise XX 握手 | `peer_conn.rs:680,934` | 临时节点直接使用凭据密钥对走完整 Noise 流程 | +| `SecureModeConfig` | `proto/common.rs:367` | 临时节点的凭据私钥直接设为 local_private_key | +| `noise_static_pubkey` | `RoutePeerInfo` 字段 18 | 临时节点的凭据公钥已在 OSPF 中传播 | +| `verify_and_update_group_trusts()` | `peer_ospf_route.rs:743` | 凭据校验逻辑参考此模式 | +| `PeerMap::close_peer()` | `peer_map.rs:317` | 断开不可信 peer | +| OSPF 路由同步 | `SyncRouteInfoRequest` | 可信公钥列表随 RoutePeerInfo 自然传播 | +| `PeerManagerRpcService` | `rpc_service.rs:24` | RPC 服务实现模式 | +| `GlobalCtxEvent` | `global_ctx.rs:32` | 新增事件触发同步 | + +## 验证方案 + +1. **单元测试**: + - `credential_manager.rs`: 密钥对生成、撤销、列表 + - `peer_conn.rs`: 凭据节点 Noise 握手成功(无 network_secret) + +2. **集成测试** (参考 `tests/three_node.rs`): + - 3 节点: A + B (管理节点, network_secret) + C (临时节点, credential) + - A 生成凭据(groups=["guest"])→ C 使用凭据连接 → 验证 C 加入路由表、可达 + - 验证 C 的 ACL group 为 "guest",配置 group ACL 规则后生效 + - A 撤销凭据 → 等待 OSPF 同步 (~1-3s) → 验证 C 被 A 和 B 断开 + - C 尝试重连 → 验证握手阶段被拒 + +3. **手动测试**: + ```bash + # A: 管理节点 + easytier-core -n test -s secret --secure-mode --listeners tcp://0.0.0.0:11010 + easytier-cli credential generate # → credential_id + credential_secret + + # C: 临时节点 + easytier-core -n test --secure-mode --credential <私钥base64> --peers tcp://A:11010 + + # 验证后撤销 + easytier-cli credential revoke + # C 数秒内被踢出 + ``` + +### Step 9: 临时节点 OSPF 路由限制 + +**约束**: 临时节点传播的路由信息不可信,需严格限制。 + +#### 9a. 管理节点不主动发起到临时节点的 OSPF session + +**核心原则**: OSPF `maintain_sessions()` 构建最小生成树时,只在管理节点之间选择 initiator,不将临时节点纳入 `dst_peer_id_to_initiate`。但管理节点**被动接受**临时节点发起的 session。 + +**文件: `easytier/src/peers/peer_ospf_route.rs`** + +修改 `maintain_sessions()` (line 2485): +- 在构建 `dst_peer_id_to_initiate` 候选列表时,过滤掉临时节点 +- 管理节点之间的 MST 不受影响 + +```rust +// 在 maintain_sessions() 中,构建 initiator 候选时过滤临时节点 +let peers: Vec = peers.into_iter().filter(|peer_id| { + // 只主动发起到管理节点的 session,不主动连临时节点 + !self.is_credential_peer(*peer_id) +}).collect(); +``` + +- **临时节点自身**: 在 `maintain_sessions()` 中只将管理节点作为 initiator 候选,跳过其他临时节点 + +```rust +// 临时节点侧: 只主动连管理节点 +if self.is_credential_node() { + let peers: Vec = peers.into_iter().filter(|peer_id| { + !self.is_credential_peer(*peer_id) // 只连管理节点 + }).collect(); +} +``` + +**session 建立方式**: +- **管理节点 → 管理节点**: 正常 MST initiator 选择(不变) +- **临时节点 → 管理节点**: 临时节点主动发起 session,管理节点被动接受 +- **临时节点 → 临时节点**: 不建立(双方都过滤掉对方) +- **管理节点 → 临时节点**: 不主动发起(不在 initiator 候选中) + +**路由信息传播**: 临时节点通过其主动发起的 session 调用 `sync_route_info` 推送自身 RoutePeerInfo。管理节点在正常 OSPF sync 中将其代理传播给其他管理节点。管理节点也通过该 session 向临时节点推送完整路由表。 + +#### 9b. 管理节点只选择性接收临时节点的路由信息 + +**文件: `easytier/src/peers/peer_ospf_route.rs`** + +临时节点通过其主动发起的 session 调用 `sync_route_info`,管理节点在处理时需做过滤: + +- 只接收该临时节点**自己的** `RoutePeerInfo`(`route_info.peer_id == dst_peer_id`),丢弃其声称的其他 peer 的路由信息 +- 对临时节点自身的 RoutePeerInfo,过滤其 `proxy_cidrs`:只保留在 `TrustedCredentialPubkey.allowed_proxy_cidrs` 范围内的网段,移除超出范围的声明 +- 临时节点的 `foreign_network_infos` 应忽略 +- 临时节点的 `conn_info`(连接拓扑)**根据 `allow_relay` 标志决定**(见下方) + +修改 `update_peer_infos()` (line 461): + +```rust +fn update_peer_infos( + &self, my_peer_id, my_peer_route_id, dst_peer_id, + peer_infos, raw_peer_infos, +) -> Result<(), Error> { + let dst_is_credential_peer = self.is_credential_peer(dst_peer_id); + + for (idx, route_info) in peer_infos.iter().enumerate() { + // 临时节点只允许传播自己的路由信息 + if dst_is_credential_peer && route_info.peer_id != dst_peer_id { + tracing::debug!( + ?dst_peer_id, peer_id=?route_info.peer_id, + "ignoring route info from credential peer for other peer" + ); + continue; + } + + // 过滤临时节点的 proxy_cidrs,只保留凭据允许的范围 + if dst_is_credential_peer { + let allowed = self.get_credential_allowed_proxy_cidrs(dst_peer_id); + if let Some(allowed_cidrs) = allowed { + route_info.proxy_cidrs.retain(|cidr| { + allowed_cidrs.iter().any(|a| cidr_is_subset(cidr, a)) + }); + } + } + // ... existing logic ... + } +} +``` + +修改 `do_sync_route_info()` (line 2614): + +```rust +// 在 do_sync_route_info 中 +let from_is_credential = self.is_credential_peer(from_peer_id); +let credential_allows_relay = from_is_credential + && self.is_credential_relay_allowed(from_peer_id); + +if let Some(peer_infos) = &peer_infos { + // update_peer_infos 内部会过滤临时节点的非自身信息 + service_impl.synced_route_info.update_peer_infos(...); +} + +// 临时节点的 conn_info: 仅当 allow_relay=true 时接收 +if let Some(conn_info) = &conn_info { + if !from_is_credential || credential_allows_relay { + service_impl.synced_route_info.update_conn_info(conn_info); + } +} + +// 临时节点的 foreign_network_infos 始终不接收 +if let Some(foreign_network) = &foreign_network { + if !from_is_credential { + service_impl.synced_route_info.update_foreign_network(foreign_network); + } +} +``` + +**conn_info 处理**: +- 临时节点的 `conn_info`: 根据凭据的 `allow_relay` 标志决定是否接收 + - `allow_relay = true`: 管理节点接收并传播该临时节点的 conn_info,使其参与路由图,可作为 relay 转发数据 + - `allow_relay = false`(默认): 忽略 conn_info,该临时节点不参与中继(仅作为叶子节点存在于路由图中) +- 临时节点的 `foreign_network_infos` 始终忽略 + +**`is_credential_relay_allowed()` 实现**: +```rust +fn is_credential_relay_allowed(&self, peer_id: PeerId) -> bool { + // 从全网汇总的 trusted_credential_pubkeys 中查找该 peer 的凭据 + // 检查对应 TrustedCredentialPubkey.allow_relay 标志 + let peer_info = self.peer_infos.read(); + if let Some(info) = peer_info.get(&peer_id) { + for tc in &self.all_trusted_credentials { + if tc.pubkey == info.noise_static_pubkey { + return tc.allow_relay; + } + } + } + false +} +``` + +**注意**: 即使 `allow_relay=true`,临时节点仍然不能转发握手包(Step 10b 限制不变),因此不会有新节点通过 relay 临时节点接入网络。relay 能力仅用于已建立连接的 peer 之间的数据转发。 + +#### 9c. 临时节点的 `RoutePeerInfo` 中的 `trusted_credential_pubkeys` 被忽略 + +已在 Step 5 中说明:只信任 `secret_digest` 匹配的管理节点发布的 trusted 列表。 + +#### 判断 peer 是否为临时节点的方法 + +在 `SyncedRouteInfo` / `PeerRouteServiceImpl` 中新增: +```rust +fn is_credential_peer(&self, peer_id: PeerId) -> bool { + // 方法: 检查该 peer 的 RoutePeerInfo + // 1. 如果 peer 的 noise_static_pubkey 在 trusted_credential_pubkeys 中 → 是临时节点 + // 2. 如果 peer 通过了 network_secret 确认 (secret_digest 匹配) → 是管理节点 + // 3. 在 peer_conn 握手后,可以记录 secure_auth_level 到连接信息中 + let peer_info = self.synced_route_info.peer_infos.read(); + if let Some(info) = peer_info.get(&peer_id) { + if !info.noise_static_pubkey.is_empty() + && self.trusted_credential_pubkeys.contains(&info.noise_static_pubkey) { + return true; + } + } + false +} +``` + +对于直连 peer,也可以在握手阶段直接记录 `secure_auth_level`,用于快速判断。 + +### Step 10: 禁止通过临时节点接入网络 + +**约束**: 不得有新节点(无论是否持有 network_secret)通过临时节点的 listener 接入网络。但允许通过管理节点中继后建立 P2P 连接。 + +#### 10a. 临时节点天然无法接受新节点接入(无需额外代码) + +临时节点作为 listener 时,新节点的连接会**自然失败**,因为: +1. 临时节点没有 `network_secret`,无法验证对端的 `secret_proof` → 无法确认对端是管理节点 +2. 临时节点不发布 `trusted_credential_pubkeys` → 对端公钥不在可信列表中 +3. 对端也无法验证临时节点的 `secret_proof`(临时节点没有 network_secret) + +因此 **不需要在 `add_tunnel_as_server()` 中添加显式拦截逻辑**。已有的 Noise 握手 + 凭据校验机制已足够阻止新节点通过临时节点接入。 + +**例外**: 已知的管理节点可以连接到临时节点(如 P2P hole punch 场景),因为管理节点的公钥已通过 OSPF 同步被临时节点知晓,握手可以成功。 + +#### 10b. 临时节点不转发来自未知 peer 的连接请求 + +**文件: `easytier/src/peers/peer_manager.rs`** + +在 packet forwarding 路径 (line 718-766) 中: +- 临时节点不应转发 `HandShake` / `NoiseHandshakeMsg*` 类型的包 +- 这防止新节点通过临时节点的中继接入网络 + +```rust +// 在 peer_recv 循环的 forward 分支中 +if to_peer_id != my_peer_id { + // 临时节点不转发握手包(阻止新节点通过临时节点接入) + if is_credential_node && ( + hdr.packet_type == PacketType::HandShake as u8 + || hdr.packet_type == PacketType::NoiseHandshakeMsg1 as u8 + || hdr.packet_type == PacketType::NoiseHandshakeMsg2 as u8 + || hdr.packet_type == PacketType::NoiseHandshakeMsg3 as u8 + ) { + tracing::debug!("credential node dropping forwarded handshake packet"); + continue; + } + // ... existing forward logic ... +} +``` + +#### 10c. P2P 连接通过管理节点中继仍然允许 + +P2P hole punch 的流程: +1. 两个节点通过管理节点交换打洞信息(RPC) +2. 建立直接 P2P tunnel +3. 在 P2P tunnel 上握手 + +这个流程不受影响,因为: +- 打洞信息交换通过管理节点中继(RPC),不经过临时节点 +- P2P tunnel 建立后的握手是直连,不通过临时节点的 listener +- `is_directly_connected=false` 的连接(hole punch 结果)可以被临时节点接受 + +**设计思路**: 将凭据映射为 ACL Group,复用现有的 group-based ACL 规则系统。 + +现有 ACL 系统已支持基于 group 的规则匹配: +- `Rule.source_groups` / `Rule.destination_groups` (acl.proto:72-73) +- `PeerGroupInfo` 通过 HMAC proof 验证 peer 所属 group (peer_rpc.rs:8-38) +- `verify_and_update_group_trusts()` 在 OSPF 同步时更新 group trust map (peer_ospf_route.rs:743) +- `get_peer_groups()` 返回 peer 所属的 group 列表,用于 ACL 匹配 (peer_ospf_route.rs:2287) + +**方案**: 生成凭据时,为每个凭据创建一个隐式 ACL Group。 + +1. **凭据生成时**: 管理节点为凭据创建一个关联的 group: + - group_name = `"credential:"` 或用户自定义名称 + - group_secret = 由 credential_secret 派生的密钥 + - 可选:指定凭据所属的 group_name(批量管理,如 `"guest"`, `"contractor"`) + +2. **临时节点加入时**: 临时节点使用凭据私钥连接。其 group 归属由管理节点在 `TrustedCredentialPubkey.groups` 中声明(无需临时节点自己提供 group proof)。验证节点在 `verify_credential_peers()` 中匹配公钥后,直接将声明的 groups 加入 `group_trust_map`。 + +3. **ACL 规则配置**: 管理员可配置基于 group 的 ACL 规则: + ```toml + # 示例配置: 限制 "guest" group 只能访问特定子网 + [[acl.acl_v1.chains]] + name = "inbound" + chain_type = "Inbound" + default_action = "Allow" + + [[acl.acl_v1.chains.rules]] + name = "restrict_guest" + source_groups = ["guest"] + destination_ips = ["10.0.0.0/24"] + action = "Drop" + ``` + +4. **管理节点发布 group 信息**: + - 在 `RoutePeerInfo.trusted_credential_pubkeys` 中传播可信公钥时,同时包含关联的 group 信息 + - 扩展 proto: + (使用 Step 1 中定义的 `TrustedCredentialPubkey`,group 归属由管理节点声明,无需 proof 验证) + - 替换 `repeated bytes trusted_credential_pubkeys` 为 `repeated TrustedCredentialPubkey trusted_credential_pubkeys` + +5. **校验节点处理**: 在 `verify_credential_peers()` 中: + - 验证凭据公钥在可信列表中后 + - 直接将 `TrustedCredentialPubkey.groups` 中声明的 group 加入 `group_trust_map` / `group_trust_map_cache`(无需验证 group proof,因为管理节点的声明已是可信的) + - ACL filter 在处理数据包时自动基于 group 匹配规则 + +**API 扩展**: + +生成凭据时可指定 group: +```protobuf +message GenerateCredentialRequest { + repeated string groups = 1; // 可选: 为该凭据关联的 group 名称 + bool allow_relay = 2; // 可选: 是否允许 relay + repeated string allowed_proxy_cidrs = 3; // 可选: 限制可声明的 proxy_cidrs + int64 ttl_seconds = 4; // 必选: 凭据有效期(秒) +} +``` + +CLI: +```bash +# 生成带 group 的凭据,有效期 24 小时 +easytier-cli credential generate --groups guest,restricted --ttl 86400 + +# 生成允许 relay 的凭据,有效期 7 天 +easytier-cli credential generate --groups relay-node --allow-relay --ttl 604800 + +# 最简用法(默认 group 名为 "credential") +easytier-cli credential generate --ttl 3600 +``` + +## 安全审查 + +### 已覆盖的安全性 + +- **端到端加密**: 数据包在源端加密、目的端解密,relay 节点(含 `allow_relay` 的临时节点)无法看到明文 +- **临时节点自我授权防护**: 只信任 `secret_digest` 匹配的管理节点发布的 `trusted_credential_pubkeys` +- **临时节点路由篡改防护**: 只接收临时节点自身的 RoutePeerInfo,忽略其转发的其他路由 +- **临时节点网络接入防护**: 临时节点天然无法接受新节点接入(无 network_secret、不发布 trusted 列表) + +### 需要关注的安全问题 + +**1. Step 8 握手后快速拒绝应为必选(非可选)** + +当前 Step 8 标记为"可选优化",但实际上是**必须的安全措施**。如果不做快速拒绝: +- 任何随机节点(无 credential、无 network_secret)都能完成 Noise 握手(因为 Step 3 放宽了 secret_proof 失败) +- 在等待 OSPF 同步验证期间,该节点持有一个有效的加密连接,浪费资源 +- **修改**: Step 8 改为必选。握手完成后立即检查:对端 secret_proof 失败 + 公钥不在本节点已知的 trusted 列表中 → 立即断开 + +**2. Group proof 验证机制需要明确** + +当前方案:临时节点在 `RoutePeerInfo.groups` 中携带 `PeerGroupInfo`(HMAC proof),管理节点在 `TrustedCredentialPubkey` 中传播 `group_secret_hash`。 + +问题:HMAC 验证需要**原始 secret**,不是 hash。验证节点如何知道 credential 的 group secret? + +**解决方案**: `TrustedCredentialPubkey.group_secret_hash` 改为 `group_secret_digest`,使用与现有 `NetworkIdentity.network_secret_digest` 相同的 digest 算法。验证时: +- 管理节点在 `TrustedCredentialPubkey` 中包含 `group_secret_digest` +- 临时节点发送的 `PeerGroupInfo` 中包含 `group_proof`(HMAC) +- 验证节点无法直接验证 HMAC(没有原始 secret),但可以信任管理节点的声明:如果管理节点在 `TrustedCredentialPubkey.groups` 中列出了某个 group,且临时节点的公钥匹配,就直接信任该 group 归属 +- 即:**group 归属由管理节点在 `TrustedCredentialPubkey` 中声明,无需临时节点提供 proof** +- 这简化了实现,且安全性不降低(管理节点已是可信源) + +**3. 凭据持久化** + +`CredentialManager` 当前设计为内存存储。管理节点重启后所有凭据丢失,导致使用这些凭据的临时节点被踢出。 + +**解决方案**: +- 管理节点可配置凭据存储的 JSON 文件路径(如 `--credential-file /path/to/credentials.json`) +- `CredentialManager` 启动时从该文件加载已有凭据 +- 生成/撤销凭据时自动写入该文件 +- 未配置文件路径时,凭据仅存内存(重启丢失) + +**4. 同一凭据多节点复用** + +同一个 credential 私钥可以被多个节点同时使用。它们有不同的 `peer_id` 但相同的 `noise_static_pubkey`。这会导致: +- 路由表中多个 RoutePeerInfo 有相同的 `noise_static_pubkey` +- 撤销时所有使用该凭据的节点同时被踢出(符合预期) +- **这是预期行为**,但应在文档中说明 + +**5. 临时节点 proxy_cidrs 限制** + +临时节点可能声明虚假的 `proxy_cidrs`(子网代理),导致流量黑洞。 + +**解决方案**(已纳入设计): +- 生成凭据时通过 `allowed_proxy_cidrs` 字段限制该凭据可声明的网段范围 +- 管理节点在 Step 9b 的 `update_peer_infos()` 中过滤:只保留临时节点声明的 proxy_cidrs 中属于 `allowed_proxy_cidrs` 子集的网段 +- 未配置 `allowed_proxy_cidrs` 时(空列表),临时节点不允许声明任何 proxy_cidrs + +**6. 凭据过期时间(TTL)** + +凭据必须设置过期时间。过期后自动失效,等同于被撤销。 +- 生成凭据时必须指定 `--ttl` 或 `--expiry` +- `verify_credential_peers()` 中检查 `expiry_unix`,过期的凭据从可信列表中移除 +- 过期检查在每次路由同步时执行,无需额外定时器 + +## 优势 + +- **最小改动**: Noise 握手消息格式不变,完全复用现有流程 +- **安全性**: X25519 密钥对提供强身份认证,不弱于 network_secret;端到端加密保护 relay 场景 +- **自然传播**: 利用 OSPF 已有基础设施,无需新 RPC +- **去中心化撤销**: 任何管理节点都可撤销,全网通过路由同步感知 +- **ACL 复用**: 凭据映射为 ACL Group,完全复用现有 group-based ACL 规则系统,无需新的 ACL 机制 diff --git a/easytier/docs/relay_peer_manager_design.md b/easytier/docs/relay_peer_manager_design.md new file mode 100644 index 000000000..3be4fb0d3 --- /dev/null +++ b/easytier/docs/relay_peer_manager_design.md @@ -0,0 +1,177 @@ +# Relay Peer 管理模块设计文档 + +## 背景与现状 + +当前出站转发路径中,PeerManager 根据路由直接选择下一跳并发送,转发路径以“取下一跳 → 发送”为核心流程: + +- 发送内部路径:[peer_manager.rs:L1053-L1082](file:///data/project/EasyTier/easytier/src/peers/peer_manager.rs#L1053-L1082) +- 数据面发送入口:[peer_manager.rs:L1187-L1238](file:///data/project/EasyTier/easytier/src/peers/peer_manager.rs#L1187-L1238) + +现状缺少面向“非直连目标”的统一管理模块,无法对 Relay Peer 进行会话、状态与策略层面的治理。 + +## 设计目标 + +- 对非直连 Relay Peer 做生命周期管理 +- 提供统一的会话(如 PeerSession)与路径选择入口 +- 与现有路由模块解耦,只消费下一跳候选与路由变更信息 +- 不改变现有数据面主路径流程 + +## 架构设计 + +### 模块命名 + +**RelayPeerMap** + +### 引用关系 + +- **PeerManager**: 作为顶层协调者,同时持有 `Arc` 和 `Arc`。 +- **RelayPeerMap**: 持有 `Arc`(或 `Weak`),用于在决策后调用底层发送能力。 +- **PeerMap**: 专注直连 Peer 管理与基础路由表维护,不直接持有 RelayPeerMap(避免循环依赖)。 + +### 职责划分 + +- **PeerManager**: + - 发送入口。 + - 判断目标是否直连: + - 若目标在 PeerMap:直接调用 `PeerMap` 发送。 + - 若目标不在 PeerMap:调用 `RelayPeerMap` 处理。 +- **RelayPeerMap**: + - 维护非直连 Peer 的状态(会话、健康度)。 + - 决策下一跳(Next Hop)。 + - 调用 `PeerMap` 将数据包发送给下一跳。 +- **ForeignNetworkManager**: + - 拥有独立的 RelayPeerMap 实例,用于 foreign network 的非直连转发。 +- **PeerMap**: + - 维护直连 Peer 连接。 + - 提供基础路由表查询。 + - 执行向直连邻居的物理发送。 + +## 数据模型 + +### RelayPeerKey + +- **dst_peer_id** (PeerId) +- 注:RelayPeerMap 实例隶属于特定网络上下文,因此 Key 仅需 PeerId。 + +### RelayPeerState + +- selected_next_hop: PeerId +- session: Option +- last_active_at: Instant +- path_metrics: latency, loss, hop_count (可选) + +### RelayPathCandidate + +- next_hop_peer_id +- cost / latency / availability + +## 简化状态管理 + +不再引入复杂状态机(如 Establishing/Suspect 等),仅依赖以下状态判断: + +- **会话是否存在**:`session.is_some()` +- **会话是否有效**:检查 session 过期时间或 generation +- **路由是否可达**:检查路由表中是否有 next hop + +## 关键流程 + +### 出站发送流程(非直连) + +1. **PeerManager** 接收发送请求(目标 `dst_peer_id`)。 +2. **PeerManager** 检查 `PeerMap` 是否直连 `dst_peer_id`。 +3. 若非直连,**PeerManager** 将请求转交给 **RelayPeerMap**。 +4. **RelayPeerMap** 处理: + - 查找 `RelayPeerState`。 + - 若首次与该 Relay Peer 通信,创建 RelayPeerState 并进入握手流程。 + - 确保会话存在(若无则触发握手与同步)。 + - 选择下一跳(由 RelayPeerMap 决策)。 + - 调用 **PeerMap** 的 `send_msg_directly(next_hop, packet)`。 + +### Relay 数据面握手出站流程(Relay Peer 特例) + +说明:Relay Peer 初次通信前必须先完成基于数据面消息的 Noise 握手,否则无法安全发送加密数据面包。握手消息通过普通数据面路径转发,但其目标是创建会话而非携带业务数据。 + +流程要点(发起方视角): + +1. 发送路径命中 `dst_peer_id` 为非直连目标后,进入 RelayPeerMap 流程。 +2. 若目标会话不存在或已失效,则发送 **RelayHandshake** 消息(携带 `m1`),通过 `send_msg_directly(next_hop, packet)` 转发给对端。 +3. 对端收到后返回 **RelayHandshakeAck**(携带 `m2`)沿原路径回传,双方派生会话并落库。 +4. 握手完成后,使用已建立会话的密钥对数据面包加密/鉴别,再走正常转发流程。 +5. 若握手失败或控制面公钥信息缺失,则不进入数据发送,返回可重试的错误(由上层决定重试节奏)。 + +### Relay 会话建立流程(数据面 + Noise 1-RTT) + +背景:直连 Peer 的 Noise 握手在 `PeerConn` 内完成;Relay Peer 没有 `PeerConn`,因此无法复用该握手逻辑。Relay 会话需要通过 **数据面握手消息** 完成握手与密钥派生,并把结果落到 `PeerSessionStore`(或等价的会话存储)中供数据面复用。 + +关键假设:Relay Peer 握手前即可拿到对端静态公钥(通过 OSPF 等控制面传播),因此可选用 **1-RTT 的 Noise 握手模式**(例如 IK/KK 一类的两报文握手),并将“两报文”映射为 **RelayHandshake / RelayHandshakeAck** 两种数据面消息。 + +建议流程(以本端作为 initiator 为例): + +1. `ensure_session(dst_peer_id)` 发现无可用会话,触发一次握手流程(可选:对并发请求做 in-flight 去重)。 +2. 从控制面缓存中读取 `dst_peer_id` 的静态公钥(若不存在则等待控制面收敛,或退化为非 1-RTT 的握手模式)。 +3. 生成 Noise 握手首报文 `m1`(包含必要的认证信息与抗重放字段,例如 session generation / nonce / 时间窗等)。 +4. 发送 `RelayHandshake(m1)`,对端返回 `RelayHandshakeAck(m2)`。 +5. initiator 处理 `m2`,双方派生出相同的会话密钥与会话标识,将会话写入 `PeerSessionStore`,供后续发送复用。 +6. 后续 Relay 数据面包使用该会话密钥进行加解密/鉴别(具体包格式不在本层定义,保持与直连会话的语义一致)。 + +实现要点: + +- **角色确定**:为避免并发双向握手导致的竞态,可使用确定性规则选择 initiator(如 `min(peer_id)` 发起),或由第一次发送方发起并在冲突时做幂等合并。 +- **幂等与重试**:数据面握手应支持重试(同一 generation/nonce 重放可安全拒绝或复用),并与路由收敛解耦。 +- **会话绑定**:握手需绑定 `dst_peer_id` 与其静态公钥指纹,避免控制面短暂不一致造成的密钥混用。 + +### 会话管理 + +- PeerSessionStore 仅用于 secure mode,会话创建与密钥派生在该模式下生效。 +- 在发送时若发现无会话,则触发 Create/Join/Sync 逻辑。 +- 对于 Relay Peer,会话创建阶段由 **数据面握手消息承载 Noise 握手**(见上节),以替代直连 `PeerConn` 内的握手流程。 + +### PacketType 规划(新增) + +- 新增 PacketType: + - `RelayHandshake`:承载 `m1`(initiator -> responder) + - `RelayHandshakeAck`:承载 `m2`(responder -> initiator) +- 载荷建议: + - `RelayHandshake`: `RelayNoiseMsg1Pb`(包含 a_session_generation/conn_id/算法等字段) + - `RelayHandshakeAck`: `RelayNoiseMsg2Pb`(包含 b_session_generation/root_key/initial_epoch/算法等字段) +- 约束: + - 两类包应与普通 Data 包一样可被转发,但不应被当作业务数据消费。 + - 需要在路由转发链路中识别为“握手控制类”消息。 + +## 策略设计 + +- 下一跳策略由 RelayPeerMap 决策,可结合 latency_first 选择 LeastHop 或 LatencyFirst。 +- 握手策略:优先采用“已知对端静态公钥”的 **1-RTT Noise 握手**,并通过 **RelayHandshake/RelayHandshakeAck** 消息触发会话建立。 +- 失败处理:依赖上层重试或底层路由收敛,暂不在此层做复杂的 Failover 状态流转。 +- 公钥来源:对端静态公钥以控制面传播为准;在控制面信息缺失或变更时,应阻止复用旧会话或触发重新握手。 + +## 接口草案 + +### RelayPeerMap 接口 + +- `send_msg(packet, dst_peer_id)`: 处理非直连发送逻辑。 +- `ensure_session(dst_peer_id)`: 确保会话可用。 +- `handshake_session(dst_peer_id)`: 通过握手消息完成 Relay 会话握手(对上层透明,可由 `ensure_session` 内部调用)。 +- `remove_peer(dst_peer_id)`: 删除已经失效的 Peer。 +## 监控与指标建议 + +- Relay 会话数 +- Relay 发送成功/失败计数 + +## 渐进式落地计划 + +### 阶段 1:基础能力 + +- 引入 RelayPeerMap 结构。 +- 在 PeerManager 中集成 RelayPeerMap。 +- 实现基础的“非直连转发”委托逻辑。 + +## 兼容性说明 + +- 需要新增 PacketType 用于 RelayHandshake/RelayHandshakeAck。 +- 在 secure mode 下,压缩由 PeerManager 完成;加密由 PeerConn(直连)或 RelayPeer(非直连)完成。 +- RelayPeer 在 secure mode 下需要提供会话级加密/解密入口: + - 发送:在 RelayPeerMap 决策完成后、调用 `send_msg_directly` 前,用 Relay 会话密钥加密。 + - 接收:在数据面包进入业务处理前,按 `from_peer_id/to_peer_id` 定位会话并解密。 +- PeerSessionStore 为 secure mode 的会话兼容性保留,非 secure mode 仅保持现有行为。 +- 不改变路由模块的计算结果。 diff --git a/easytier/locales/app.yml b/easytier/locales/app.yml index dd3397ef2..ff4484d20 100644 --- a/easytier/locales/app.yml +++ b/easytier/locales/app.yml @@ -244,6 +244,12 @@ core_clap: local_public_key: en: "local public key for secure mode. if not provided, a random key will be generated, or use local private key to derive public key" zh-CN: "安全模式下的本地公钥。如果未提供,则会随机生成一个密钥,或者使用本地私钥派生公钥" + credential: + en: "credential secret (base64-encoded private key) for joining network as a temporary node without network_secret" + zh-CN: "凭据密钥(base64编码的私钥),用于作为临时节点加入网络,无需 network_secret" + credential_file: + en: "path to credential storage file for persisting generated credentials across restarts (admin nodes)" + zh-CN: "凭据存储文件路径,用于在管理节点重启后保留已生成的凭据" check_config: en: Check config validity without starting the network zh-CN: 检查配置文件的有效性并退出 diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index 8689490a6..f2491fb46 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -216,6 +216,11 @@ pub trait ConfigLoader: Send + Sync { fn get_secure_mode(&self) -> Option; fn set_secure_mode(&self, secure_mode: Option); + fn get_credential_file(&self) -> Option { + None + } + fn set_credential_file(&self, _path: Option) {} + fn dump(&self) -> String; } @@ -296,6 +301,16 @@ impl NetworkIdentity { network_secret_digest: Some(network_secret_digest), } } + + /// Create a NetworkIdentity for a credential node (no network_secret). + /// The node identifies by network_name only and authenticates via credential keypair. + pub fn new_credential(network_name: String) -> Self { + NetworkIdentity { + network_name, + network_secret: None, + network_secret_digest: None, + } + } } impl Default for NetworkIdentity { @@ -428,6 +443,8 @@ struct Config { udp_whitelist: Option>, stun_servers: Option>, stun_servers_v6: Option>, + + credential_file: Option, } #[derive(Debug, Clone)] @@ -821,6 +838,14 @@ impl ConfigLoader for TomlConfigLoader { self.config.lock().unwrap().secure_mode = secure_mode; } + fn get_credential_file(&self) -> Option { + self.config.lock().unwrap().credential_file.clone() + } + + fn set_credential_file(&self, path: Option) { + self.config.lock().unwrap().credential_file = path; + } + fn dump(&self) -> String { let default_flags_json = serde_json::to_string(&gen_default_flags()).unwrap(); let default_flags_hashmap = diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index d8df3d83a..ef6eafeac 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -1,14 +1,21 @@ use std::collections::hash_map::DefaultHasher; +use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::{ hash::Hasher, sync::{Arc, Mutex}, + time::{SystemTime, UNIX_EPOCH}, }; +use arc_swap::ArcSwap; +use dashmap::DashMap; + use crate::common::config::ProxyNetworkConfig; +use crate::common::shrink_dashmap; use crate::common::stats_manager::StatsManager; use crate::common::token_bucket::TokenBucketManager; use crate::peers::acl_filter::AclFilter; +use crate::peers::credential_manager::CredentialManager; use crate::proto::acl::GroupIdentity; use crate::proto::api::config::InstanceConfigPatch; use crate::proto::api::instance::PeerConnInfo; @@ -59,11 +66,90 @@ pub enum GlobalCtxEvent { ConfigPatched(InstanceConfigPatch), ProxyCidrsUpdated(Vec, Vec), // (added, removed) + + CredentialChanged, } pub type EventBus = tokio::sync::broadcast::Sender; pub type EventBusSubscriber = tokio::sync::broadcast::Receiver; +/// Source of a trusted public key from OSPF route propagation +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TrustedKeySource { + /// Peer node's noise static pubkey + OspfNode, + /// Admin-declared trusted credential pubkey + OspfCredential, +} + +/// Metadata for a trusted public key +#[derive(Debug, Clone)] +pub struct TrustedKeyMetadata { + pub source: TrustedKeySource, + /// Expiry time in Unix seconds. None means never expires. + pub expiry_unix: Option, +} + +impl TrustedKeyMetadata { + pub fn is_expired(&self) -> bool { + if let Some(expiry) = self.expiry_unix { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + return now >= expiry; + } + false + } +} + +// key is (pubkey, network-name) +pub type TrustedKeyMap = HashMap, TrustedKeyMetadata>; + +struct TrustedKeyMapManager { + network_trusted_keys: DashMap>, +} + +impl TrustedKeyMapManager { + pub fn new() -> Self { + Self { + network_trusted_keys: DashMap::new(), + } + } + + pub fn update_trusted_keys(&self, network_name: &str, trusted_keys: TrustedKeyMap) { + match self.network_trusted_keys.entry(network_name.to_string()) { + dashmap::Entry::Vacant(entry) => { + entry.insert(ArcSwap::new(Arc::new(trusted_keys))); + } + dashmap::Entry::Occupied(entry) => { + entry.get().store(Arc::new(trusted_keys)); + } + } + } + + pub fn remove_trusted_keys(&self, network_name: &str) { + self.network_trusted_keys.remove(network_name); + shrink_dashmap(&self.network_trusted_keys, None); + } + + pub fn verify_trusted_key(&self, pubkey: &[u8], network_name: &str) -> bool { + let Some(trusted_keys) = self + .network_trusted_keys + .get(network_name) + .map(|v| v.load_full()) + else { + return false; + }; + + let Some(metadata) = trusted_keys.get(&pubkey.to_vec()) else { + return false; + }; + + !metadata.is_expired() + } +} + pub struct GlobalCtx { pub inst_name: String, pub id: uuid::Uuid, @@ -97,6 +183,12 @@ pub struct GlobalCtx { stats_manager: Arc, acl_filter: Arc, + + credential_manager: Arc, + + /// OSPF propagated trusted keys (peer pubkeys and admin credentials) + /// Stored in ArcSwap for lock-free reads and atomic batch updates + trusted_keys: Arc, } impl std::fmt::Debug for GlobalCtx { @@ -152,6 +244,9 @@ impl GlobalCtx { ..Default::default() }; + let credential_storage_path = config_fs.get_credential_file(); + let credential_manager = Arc::new(CredentialManager::new(credential_storage_path)); + GlobalCtx { inst_name: config_fs.get_inst_name(), id, @@ -187,6 +282,10 @@ impl GlobalCtx { stats_manager: Arc::new(StatsManager::new()), acl_filter: Arc::new(AclFilter::new()), + + credential_manager, + + trusted_keys: Arc::new(TrustedKeyMapManager::new()), } } @@ -404,6 +503,37 @@ impl GlobalCtx { &self.acl_filter } + pub fn get_credential_manager(&self) -> &Arc { + &self.credential_manager + } + + /// Check if a public key is trusted using two-level lookup: + /// 1. OSPF propagated trusted_keys (lock-free) + /// 2. Local credential_manager + pub fn is_pubkey_trusted(&self, pubkey: &[u8], network_name: &str) -> bool { + // First level: check OSPF propagated keys (lock-free) + if self.trusted_keys.verify_trusted_key(pubkey, network_name) { + return true; + } + + // Second level: check local credential_manager if in the same network + if network_name == self.get_network_name() { + return self.credential_manager.is_pubkey_trusted(pubkey); + } + + false + } + + /// Atomically replace all OSPF trusted keys with a new set + /// Called by OSPF route layer after each route update + pub fn update_trusted_keys(&self, keys: TrustedKeyMap, network_name: &str) { + self.trusted_keys.update_trusted_keys(network_name, keys); + } + + pub fn remove_trusted_keys(&self, network_name: &str) { + self.trusted_keys.remove_trusted_keys(network_name); + } + pub fn get_acl_groups(&self, peer_id: PeerId) -> Vec { use std::collections::HashSet; self.config diff --git a/easytier/src/common/log.rs b/easytier/src/common/log.rs index a9e4d205b..0c9b1f084 100644 --- a/easytier/src/common/log.rs +++ b/easytier/src/common/log.rs @@ -1,3 +1,5 @@ +use std::io::IsTerminal as _; + use crate::common::config::LoggingConfigLoader; use crate::common::get_logger_timer_rfc3339; use crate::common::tracing_rolling_appender::{FileAppenderWrapper, RollingFileAppenderBase}; @@ -175,7 +177,8 @@ pub fn init( let layer = || { layer() - .pretty() + .compact() + .with_ansi(std::io::stderr().is_terminal()) .with_timer(get_logger_timer_rfc3339()) .with_writer(std::io::stderr) }; diff --git a/easytier/src/common/network.rs b/easytier/src/common/network.rs index 78bd4a5a8..3343b1827 100644 --- a/easytier/src/common/network.rs +++ b/easytier/src/common/network.rs @@ -18,7 +18,8 @@ struct InterfaceFilter { #[cfg(any( target_os = "android", - any(target_os = "ios", feature = "macos-ne"), + target_os = "ios", + all(target_os = "macos", feature = "macos-ne"), target_env = "ohos" ))] impl InterfaceFilter { diff --git a/easytier/src/common/stun.rs b/easytier/src/common/stun.rs index cfe5ec0b6..61988ffa2 100644 --- a/easytier/src/common/stun.rs +++ b/easytier/src/common/stun.rs @@ -25,6 +25,25 @@ use crate::common::error::Error; use super::dns::resolve_txt_record; use super::stun_codec_ext::*; +const DEFAULT_UDP_STUN_SERVERS: &[&str] = &[ + "txt:stun.easytier.cn", + "stun.miwifi.com", + "stun.chat.bilibili.com", + "stun.hitv.com", +]; + +const DEFAULT_TCP_STUN_SERVERS: &[&str] = &[ + "stun.hot-chilli.net", + "stun.fitauto.ru", + "fwa.lifesizecloud.com", + "global.turn.twilio.com", + "turn.cloudflare.com", + "stun.voip.blackberry.com", + "stun.radiojar.com", +]; + +const DEFAULT_UDP_V6_STUN_SERVERS: &[&str] = &["txt:stun-v6.easytier.cn"]; + struct HostResolverIter { hostnames: Vec, ips: Vec, @@ -1100,39 +1119,39 @@ impl StunInfoCollector { } pub fn get_default_servers() -> Vec { - // NOTICE: we may need to choose stun server based on geolocation - // stun server cross nation may return an external ip address with high latency and loss rate - [ - "txt:stun.easytier.cn", - "stun.miwifi.com", - "stun.chat.bilibili.com", - "stun.hitv.com", - ] - .iter() - .map(|x| x.to_string()) - .collect() + if cfg!(test) { + Vec::new() + } else { + // NOTICE: we may need to choose stun server based on geolocation + // stun server cross nation may return an external ip address with high latency and loss rate + DEFAULT_UDP_STUN_SERVERS + .iter() + .map(ToString::to_string) + .collect() + } } pub fn get_default_tcp_servers() -> Vec { - [ - "stun.hot-chilli.net", - "stun.fitauto.ru", - "fwa.lifesizecloud.com", - "global.turn.twilio.com", - "turn.cloudflare.com", - "stun.voip.blackberry.com", - "stun.radiojar.com", - ] - .iter() - .map(|x| x.to_string()) - .collect() + // if test, return empty vector + if cfg!(test) { + Vec::new() + } else { + DEFAULT_TCP_STUN_SERVERS + .iter() + .map(ToString::to_string) + .collect() + } } pub fn get_default_servers_v6() -> Vec { - ["txt:stun-v6.easytier.cn"] - .iter() - .map(|x| x.to_string()) - .collect() + if cfg!(test) { + Vec::new() + } else { + DEFAULT_UDP_V6_STUN_SERVERS + .iter() + .map(ToString::to_string) + .collect() + } } async fn get_public_ipv6(servers: &[String]) -> Option { @@ -1328,7 +1347,14 @@ mod tests { #[tokio::test] async fn test_udp_nat_type_detector() { - let collector = StunInfoCollector::new_with_default_servers(); + let collector = StunInfoCollector::new( + DEFAULT_UDP_STUN_SERVERS + .iter() + .map(ToString::to_string) + .collect(), + vec![], + vec![], + ); collector.update_stun_info(); loop { let ret = collector.get_stun_info(); diff --git a/easytier/src/connector/udp_hole_punch/cone.rs b/easytier/src/connector/udp_hole_punch/cone.rs index d28948a39..69b2d7613 100644 --- a/easytier/src/connector/udp_hole_punch/cone.rs +++ b/easytier/src/connector/udp_hole_punch/cone.rs @@ -98,7 +98,6 @@ impl PunchConeHoleClient { } } - #[tracing::instrument(skip(self))] pub(crate) async fn do_hole_punching( &self, dst_peer_id: PeerId, @@ -241,7 +240,7 @@ impl PunchConeHoleClient { } } - return Ok(None); + Ok(None) } } diff --git a/easytier/src/connector/udp_hole_punch/mod.rs b/easytier/src/connector/udp_hole_punch/mod.rs index b8c6249aa..8c61a72d1 100644 --- a/easytier/src/connector/udp_hole_punch/mod.rs +++ b/easytier/src/connector/udp_hole_punch/mod.rs @@ -245,7 +245,7 @@ impl UdpHoePunchConnectorData { tracing::info!(?tunnel, "hole punching get tunnel success"); if let Err(e) = self.peer_mgr.add_client_tunnel(tunnel, false).await { - tracing::warn!(?e, "add client tunnel failed"); + tracing::warn!("add client tunnel failed, err: {}", e); op(true); false } else { @@ -258,7 +258,7 @@ impl UdpHoePunchConnectorData { false } Err(e) => { - tracing::info!(?e, "hole punching failed"); + tracing::info!("hole punching failed, err: {}", e); op(true); false } diff --git a/easytier/src/core.rs b/easytier/src/core.rs index 488348bc9..a34341388 100644 --- a/easytier/src/core.rs +++ b/easytier/src/core.rs @@ -636,6 +636,20 @@ struct NetworkOptions { help = t!("core_clap.local_public_key").to_string() )] local_public_key: Option, + + #[arg( + long, + env = "ET_CREDENTIAL", + help = t!("core_clap.credential").to_string() + )] + credential: Option, + + #[arg( + long, + env = "ET_CREDENTIAL_FILE", + help = t!("core_clap.credential_file").to_string() + )] + credential_file: Option, } #[derive(Parser, Debug)] @@ -802,11 +816,17 @@ impl NetworkOptions { let old_ns = cfg.get_network_identity(); let network_name = self.network_name.clone().unwrap_or(old_ns.network_name); - let network_secret = self - .network_secret - .clone() - .unwrap_or(old_ns.network_secret.unwrap_or_default()); - cfg.set_network_identity(NetworkIdentity::new(network_name, network_secret)); + + if self.credential.is_some() { + // Credential mode: no network_secret, authenticate via credential keypair + cfg.set_network_identity(NetworkIdentity::new_credential(network_name)); + } else { + let network_secret = self + .network_secret + .clone() + .unwrap_or(old_ns.network_secret.unwrap_or_default()); + cfg.set_network_identity(NetworkIdentity::new(network_name, network_secret)); + } if let Some(dhcp) = self.dhcp { cfg.set_dhcp(dhcp); @@ -975,7 +995,19 @@ impl NetworkOptions { cfg.set_port_forwards(old); } - if let Some(secure_mode) = self.secure_mode { + if let Some(ref credential_file) = self.credential_file { + cfg.set_credential_file(Some(credential_file.clone())); + } + + if let Some(ref credential_secret) = self.credential { + // --credential implies --secure-mode and sets the credential private key + let c = SecureModeConfig { + enabled: true, + local_private_key: Some(credential_secret.clone()), + local_public_key: None, + }; + cfg.set_secure_mode(Some(Self::process_secure_mode_cfg(c)?)); + } else if let Some(secure_mode) = self.secure_mode { if secure_mode { let c = SecureModeConfig { enabled: secure_mode, @@ -1249,6 +1281,7 @@ async fn run_main(cli: Cli) -> anyhow::Result<()> { config_server_url_s, cli.machine_id.clone(), cli.network_options.hostname.clone(), + cli.network_options.secure_mode.unwrap_or(false), manager.clone(), None, ) diff --git a/easytier/src/easytier-cli.rs b/easytier/src/easytier-cli.rs index 923b9a87d..52f5d9656 100644 --- a/easytier/src/easytier-cli.rs +++ b/easytier/src/easytier-cli.rs @@ -37,17 +37,18 @@ use easytier::{ instance::{ instance_identifier::{InstanceSelector, Selector}, list_peer_route_pair, AclManageRpc, AclManageRpcClientFactory, ConnectorManageRpc, - ConnectorManageRpcClientFactory, DumpRouteRequest, GetAclStatsRequest, - GetPrometheusStatsRequest, GetStatsRequest, GetVpnPortalInfoRequest, - GetWhitelistRequest, InstanceIdentifier, ListConnectorRequest, - ListForeignNetworkRequest, ListGlobalForeignNetworkRequest, - ListMappedListenerRequest, ListPeerRequest, ListPeerResponse, - ListPortForwardRequest, ListRouteRequest, ListRouteResponse, + ConnectorManageRpcClientFactory, CredentialManageRpc, + CredentialManageRpcClientFactory, DumpRouteRequest, GenerateCredentialRequest, + GetAclStatsRequest, GetPrometheusStatsRequest, GetStatsRequest, + GetVpnPortalInfoRequest, GetWhitelistRequest, InstanceIdentifier, + ListConnectorRequest, ListCredentialsRequest, ListForeignNetworkRequest, + ListGlobalForeignNetworkRequest, ListMappedListenerRequest, ListPeerRequest, + ListPeerResponse, ListPortForwardRequest, ListRouteRequest, ListRouteResponse, MappedListenerManageRpc, MappedListenerManageRpcClientFactory, NodeInfo, PeerManageRpc, PeerManageRpcClientFactory, PortForwardManageRpc, - PortForwardManageRpcClientFactory, ShowNodeInfoRequest, StatsRpc, - StatsRpcClientFactory, TcpProxyEntryState, TcpProxyEntryTransportType, TcpProxyRpc, - TcpProxyRpcClientFactory, VpnPortalRpc, VpnPortalRpcClientFactory, + PortForwardManageRpcClientFactory, RevokeCredentialRequest, ShowNodeInfoRequest, + StatsRpc, StatsRpcClientFactory, TcpProxyEntryState, TcpProxyEntryTransportType, + TcpProxyRpc, TcpProxyRpcClientFactory, VpnPortalRpc, VpnPortalRpcClientFactory, }, logger::{ GetLoggerConfigRequest, LogLevel, LoggerRpc, LoggerRpcClientFactory, @@ -134,6 +135,8 @@ enum SubCommand { Stats(StatsArgs), #[command(about = "manage logger configuration")] Logger(LoggerArgs), + #[command(about = "manage temporary credentials")] + Credential(CredentialArgs), #[command(about = t!("core_clap.generate_completions").to_string())] GenAutocomplete { shell: ShellType }, } @@ -340,6 +343,42 @@ enum LoggerSubCommand { }, } +#[derive(Args, Debug)] +struct CredentialArgs { + #[command(subcommand)] + sub_command: CredentialSubCommand, +} + +#[derive(Subcommand, Debug)] +enum CredentialSubCommand { + /// Generate a new temporary credential + Generate { + #[arg(long, help = "TTL in seconds (required)")] + ttl: i64, + #[arg(long, value_delimiter = ',', help = "ACL groups (comma-separated)")] + groups: Option>, + #[arg( + long, + default_value = "false", + help = "allow relay through this credential node" + )] + allow_relay: bool, + #[arg( + long, + value_delimiter = ',', + help = "allowed proxy CIDRs (comma-separated)" + )] + allowed_proxy_cidrs: Option>, + }, + /// Revoke a credential by its ID + Revoke { + #[arg(help = "credential ID (UUID)")] + credential_id: String, + }, + /// List all active credentials + List, +} + #[derive(Args, Debug)] struct ServiceArgs { #[arg(short, long, default_value = env!("CARGO_PKG_NAME"), help = "service name")] @@ -537,6 +576,18 @@ impl CommandHandler<'_> { .with_context(|| "failed to get config client")?) } + async fn get_credential_client( + &self, + ) -> Result>, Error> { + Ok(self + .client + .lock() + .await + .scoped_client::>("".to_string()) + .await + .with_context(|| "failed to get credential client")?) + } + async fn list_peers(&self) -> Result { let client = self.get_peer_manager_client().await?; let request = ListPeerRequest { @@ -1363,6 +1414,121 @@ impl CommandHandler<'_> { Ok(()) } + async fn handle_credential_generate( + &self, + ttl: i64, + groups: Vec, + allow_relay: bool, + allowed_proxy_cidrs: Vec, + ) -> Result<(), Error> { + let client = self.get_credential_client().await?; + let request = GenerateCredentialRequest { + groups, + allow_relay, + allowed_proxy_cidrs, + ttl_seconds: ttl, + }; + let response = client + .generate_credential(BaseController::default(), request) + .await?; + + match self.output_format { + OutputFormat::Table => { + println!("Credential generated successfully:"); + println!(" credential_id: {}", response.credential_id); + println!(" credential_secret: {}", response.credential_secret); + println!(); + println!("To use this credential on a new node:"); + println!( + " easytier-core --network-name --secure-mode --credential {} -p ", + response.credential_secret + ); + } + OutputFormat::Json => { + let json = serde_json::to_string_pretty(&response)?; + println!("{}", json); + } + } + + Ok(()) + } + + async fn handle_credential_revoke(&self, credential_id: &str) -> Result<(), Error> { + let client = self.get_credential_client().await?; + let request = RevokeCredentialRequest { + credential_id: credential_id.to_string(), + }; + let response = client + .revoke_credential(BaseController::default(), request) + .await?; + + match self.output_format { + OutputFormat::Table => { + if response.success { + println!("Credential revoked successfully"); + } else { + println!("Credential not found"); + } + } + OutputFormat::Json => { + let json = serde_json::to_string_pretty(&response)?; + println!("{}", json); + } + } + + Ok(()) + } + + async fn handle_credential_list(&self) -> Result<(), Error> { + let client = self.get_credential_client().await?; + let request = ListCredentialsRequest {}; + let response = client + .list_credentials(BaseController::default(), request) + .await?; + + match self.output_format { + OutputFormat::Table => { + if response.credentials.is_empty() { + println!("No active credentials"); + } else { + use tabled::{builder::Builder, settings::Style}; + let mut builder = Builder::default(); + builder.push_record(["ID", "Groups", "Relay", "Expiry", "Allowed CIDRs"]); + for cred in &response.credentials { + let expiry = { + let secs = cred.expiry_unix; + let remaining = secs + - std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + if remaining > 0 { + format!("{}s remaining", remaining) + } else { + "expired".to_string() + } + }; + builder.push_record([ + &cred.credential_id[..], + &cred.groups.join(","), + if cred.allow_relay { "yes" } else { "no" }, + &expiry, + &cred.allowed_proxy_cidrs.join(","), + ]); + } + let table = builder.build().with(Style::rounded()).to_string(); + println!("{}", table); + } + } + OutputFormat::Json => { + let json = serde_json::to_string_pretty(&response)?; + println!("{}", json); + } + } + + Ok(()) + } + fn parse_port_list(ports_str: &str) -> Result, Error> { let mut ports = Vec::new(); for port_spec in ports_str.split(',') { @@ -2193,6 +2359,29 @@ async fn main() -> Result<(), Error> { handler.handle_logger_set(level).await?; } }, + SubCommand::Credential(credential_args) => match &credential_args.sub_command { + CredentialSubCommand::Generate { + ttl, + groups, + allow_relay, + allowed_proxy_cidrs, + } => { + handler + .handle_credential_generate( + *ttl, + groups.clone().unwrap_or_default(), + *allow_relay, + allowed_proxy_cidrs.clone().unwrap_or_default(), + ) + .await?; + } + CredentialSubCommand::Revoke { credential_id } => { + handler.handle_credential_revoke(credential_id).await?; + } + CredentialSubCommand::List => { + handler.handle_credential_list().await?; + } + }, SubCommand::GenAutocomplete { shell } => { let mut cmd = Cli::command(); if let Some(shell) = shell.to_shell() { diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index b9ccb44f5..da8c09972 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -1316,6 +1316,7 @@ impl Instance { stats_rpc_service: G, config_rpc_service: H, peer_center_rpc_service: Arc, + credential_manage_rpc_service: PeerManagerRpcService, } #[async_trait::async_trait] @@ -1383,6 +1384,12 @@ impl Instance { ) -> Arc + Send + Sync> { self.peer_center_rpc_service.clone() } + + fn get_credential_manage_service( + &self, + ) -> &dyn CredentialManageRpc { + &self.credential_manage_rpc_service + } } ApiRpcServiceImpl { @@ -1444,6 +1451,7 @@ impl Instance { stats_rpc_service: self.get_stats_rpc_service(), config_rpc_service: self.get_config_service(), peer_center_rpc_service: Arc::new(self.peer_center.get_rpc_service()), + credential_manage_rpc_service: PeerManagerRpcService::new(self.peer_manager.clone()), } } diff --git a/easytier/src/instance_manager.rs b/easytier/src/instance_manager.rs index 4e6c8eea7..9e01ec2e5 100644 --- a/easytier/src/instance_manager.rs +++ b/easytier/src/instance_manager.rs @@ -423,6 +423,10 @@ fn handle_event( instance_id ); } + + GlobalCtxEvent::CredentialChanged => { + event!(info, "[{}] credential changed", instance_id); + } } } else { events = events.resubscribe(); diff --git a/easytier/src/launcher.rs b/easytier/src/launcher.rs index 2abfe5f8f..fa46d7bbe 100644 --- a/easytier/src/launcher.rs +++ b/easytier/src/launcher.rs @@ -241,6 +241,7 @@ impl EasyTierLauncher { } instance_alive.store(false, std::sync::atomic::Ordering::Relaxed); notifier.notify_one(); + rt.shutdown_background(); })); } diff --git a/easytier/src/peers/credential_manager.rs b/easytier/src/peers/credential_manager.rs new file mode 100644 index 000000000..b2a239542 --- /dev/null +++ b/easytier/src/peers/credential_manager.rs @@ -0,0 +1,407 @@ +use std::{ + collections::HashMap, + path::PathBuf, + sync::Mutex, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use base64::Engine; +use serde::{Deserialize, Serialize}; +use x25519_dalek::{PublicKey, StaticSecret}; + +use crate::proto::peer_rpc::{TrustedCredentialPubkey, TrustedCredentialPubkeyProof}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct CredentialEntry { + pubkey: String, + groups: Vec, + allow_relay: bool, + allowed_proxy_cidrs: Vec, + expiry_unix: i64, + created_at_unix: i64, +} + +pub struct CredentialManager { + credentials: Mutex>, + storage_path: Option, +} + +impl CredentialManager { + pub fn new(storage_path: Option) -> Self { + let mgr = CredentialManager { + credentials: Mutex::new(HashMap::new()), + storage_path, + }; + mgr.load_from_disk(); + mgr + } + + pub fn generate_credential( + &self, + groups: Vec, + allow_relay: bool, + allowed_proxy_cidrs: Vec, + ttl: Duration, + ) -> (String, String) { + let private = StaticSecret::random_from_rng(rand::rngs::OsRng); + let public = PublicKey::from(&private); + let id = uuid::Uuid::new_v4().to_string(); + let pubkey = BASE64_STANDARD.encode(public.as_bytes()); + let secret = BASE64_STANDARD.encode(private.as_bytes()); + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + let expiry_unix = now + ttl.as_secs() as i64; + + let entry = CredentialEntry { + pubkey, + groups, + allow_relay, + allowed_proxy_cidrs, + expiry_unix, + created_at_unix: now, + }; + + self.credentials.lock().unwrap().insert(id.clone(), entry); + self.save_to_disk(); + (id, secret) + } + + pub fn revoke_credential(&self, credential_id: &str) -> bool { + let removed = self + .credentials + .lock() + .unwrap() + .remove(credential_id) + .is_some(); + if removed { + self.save_to_disk(); + } + removed + } + + pub fn get_trusted_pubkeys(&self, network_secret: &str) -> Vec { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + self.credentials + .lock() + .unwrap() + .values() + .filter(|e| e.expiry_unix > now) + .map(|e| { + let credential = TrustedCredentialPubkey { + pubkey: Self::decode_pubkey_b64(&e.pubkey).unwrap_or_default(), + groups: e.groups.clone(), + allow_relay: e.allow_relay, + expiry_unix: e.expiry_unix, + allowed_proxy_cidrs: e.allowed_proxy_cidrs.clone(), + }; + TrustedCredentialPubkeyProof::new_signed(credential, network_secret) + }) + .filter(|e| { + e.credential + .as_ref() + .map(|x| !x.pubkey.is_empty()) + .unwrap_or(false) + }) + .collect() + } + + pub fn is_pubkey_trusted(&self, pubkey: &[u8]) -> bool { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + let encoded = BASE64_STANDARD.encode(pubkey); + self.credentials + .lock() + .unwrap() + .values() + .any(|e| e.pubkey == encoded && e.expiry_unix > now) + } + + pub fn list_credentials(&self) -> Vec { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + self.credentials + .lock() + .unwrap() + .iter() + .filter(|(_, e)| e.expiry_unix > now) + .map(|(id, e)| crate::proto::api::instance::CredentialInfo { + credential_id: id.clone(), + groups: e.groups.clone(), + allow_relay: e.allow_relay, + expiry_unix: e.expiry_unix, + allowed_proxy_cidrs: e.allowed_proxy_cidrs.clone(), + }) + .collect() + } + + fn save_to_disk(&self) { + let Some(path) = &self.storage_path else { + return; + }; + let creds = self.credentials.lock().unwrap(); + if let Ok(json) = serde_json::to_string_pretty(&*creds) { + if let Err(e) = std::fs::write(path, json) { + tracing::warn!(?e, "failed to save credentials to disk"); + } + } + } + + fn load_from_disk(&self) { + let Some(path) = &self.storage_path else { + return; + }; + let Ok(data) = std::fs::read_to_string(path) else { + return; + }; + match serde_json::from_str::>(&data) { + Ok(loaded) => { + *self.credentials.lock().unwrap() = loaded; + tracing::info!("loaded credentials from {}", path.display()); + } + Err(e) => { + tracing::warn!(?e, "failed to parse credentials file"); + } + } + } + + fn decode_pubkey_b64(s: &str) -> Option> { + let decoded = BASE64_STANDARD.decode(s).ok()?; + if decoded.len() != 32 { + return None; + } + Some(decoded) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_and_revoke() { + let mgr = CredentialManager::new(None); + let (id, secret) = mgr.generate_credential( + vec!["guest".to_string()], + false, + vec![], + Duration::from_secs(3600), + ); + + assert!(!id.is_empty()); + assert!(!secret.is_empty()); + assert!(uuid::Uuid::parse_str(&id).is_ok()); + + let privkey_bytes: [u8; 32] = BASE64_STANDARD.decode(&secret).unwrap().try_into().unwrap(); + let private = StaticSecret::from(privkey_bytes); + let pubkey_bytes = PublicKey::from(&private).as_bytes().to_vec(); + assert!(mgr.is_pubkey_trusted(&pubkey_bytes)); + + let trusted = mgr.get_trusted_pubkeys("sec"); + assert_eq!(trusted.len(), 1); + assert_eq!( + trusted[0].credential.as_ref().unwrap().groups, + vec!["guest".to_string()] + ); + + assert!(mgr.revoke_credential(&id)); + assert!(!mgr.is_pubkey_trusted(&pubkey_bytes)); + assert!(mgr.get_trusted_pubkeys("sec").is_empty()); + } + + #[test] + fn test_expired_credential() { + let mgr = CredentialManager::new(None); + // TTL of 0 seconds - immediately expired + let (_, secret) = mgr.generate_credential(vec![], false, vec![], Duration::from_secs(0)); + + let privkey_bytes: [u8; 32] = BASE64_STANDARD.decode(&secret).unwrap().try_into().unwrap(); + let private = StaticSecret::from(privkey_bytes); + let pubkey_bytes = PublicKey::from(&private).as_bytes().to_vec(); + assert!(!mgr.is_pubkey_trusted(&pubkey_bytes)); + assert!(mgr.get_trusted_pubkeys("sec").is_empty()); + } + + #[test] + fn test_list_credentials() { + let mgr = CredentialManager::new(None); + mgr.generate_credential( + vec!["a".to_string()], + true, + vec!["10.0.0.0/24".to_string()], + Duration::from_secs(3600), + ); + mgr.generate_credential(vec![], false, vec![], Duration::from_secs(3600)); + + let list = mgr.list_credentials(); + assert_eq!(list.len(), 2); + } + + #[test] + fn test_keypair_validity() { + // Verify the generated private key can derive the same public key + let mgr = CredentialManager::new(None); + let (id, secret) = + mgr.generate_credential(vec![], false, vec![], Duration::from_secs(3600)); + + let privkey_bytes: [u8; 32] = BASE64_STANDARD.decode(&secret).unwrap().try_into().unwrap(); + let private = StaticSecret::from(privkey_bytes); + let derived_public = PublicKey::from(&private); + assert!(uuid::Uuid::parse_str(&id).is_ok()); + assert!(mgr.is_pubkey_trusted(derived_public.as_bytes())); + } + + #[test] + fn test_revoke_nonexistent() { + let mgr = CredentialManager::new(None); + assert!(!mgr.revoke_credential("nonexistent_id")); + } + + #[test] + fn test_multiple_credentials_independent() { + let mgr = CredentialManager::new(None); + let (id1, secret1) = mgr.generate_credential( + vec!["group1".to_string()], + false, + vec![], + Duration::from_secs(3600), + ); + let (_id2, secret2) = mgr.generate_credential( + vec!["group2".to_string()], + true, + vec!["10.0.0.0/8".to_string()], + Duration::from_secs(3600), + ); + + let sk1: [u8; 32] = BASE64_STANDARD + .decode(&secret1) + .unwrap() + .try_into() + .unwrap(); + let sk2: [u8; 32] = BASE64_STANDARD + .decode(&secret2) + .unwrap() + .try_into() + .unwrap(); + let pk1 = PublicKey::from(&StaticSecret::from(sk1)) + .as_bytes() + .to_vec(); + let pk2 = PublicKey::from(&StaticSecret::from(sk2)) + .as_bytes() + .to_vec(); + + assert!(mgr.is_pubkey_trusted(&pk1)); + assert!(mgr.is_pubkey_trusted(&pk2)); + + // Revoke first, second should still be trusted + mgr.revoke_credential(&id1); + assert!(!mgr.is_pubkey_trusted(&pk1)); + assert!(mgr.is_pubkey_trusted(&pk2)); + + let trusted = mgr.get_trusted_pubkeys("sec"); + assert_eq!(trusted.len(), 1); + assert_eq!( + trusted[0].credential.as_ref().unwrap().groups, + vec!["group2".to_string()] + ); + assert!(trusted[0].credential.as_ref().unwrap().allow_relay); + assert_eq!( + trusted[0].credential.as_ref().unwrap().allowed_proxy_cidrs, + vec!["10.0.0.0/8".to_string()] + ); + } + + #[test] + fn test_trusted_pubkeys_include_metadata() { + let mgr = CredentialManager::new(None); + let (_, secret) = mgr.generate_credential( + vec!["admin".to_string(), "ops".to_string()], + true, + vec!["192.168.0.0/16".to_string(), "10.0.0.0/8".to_string()], + Duration::from_secs(7200), + ); + + let trusted = mgr.get_trusted_pubkeys("sec"); + assert_eq!(trusted.len(), 1); + let tc = &trusted[0]; + assert_eq!( + tc.credential.as_ref().unwrap().groups, + vec!["admin".to_string(), "ops".to_string()] + ); + assert!(tc.credential.as_ref().unwrap().allow_relay); + assert_eq!( + tc.credential.as_ref().unwrap().allowed_proxy_cidrs, + vec!["192.168.0.0/16".to_string(), "10.0.0.0/8".to_string()] + ); + assert!(tc.credential.as_ref().unwrap().expiry_unix > 0); + assert!(tc.verify_credential_hmac("sec")); + assert!(tc + .credential + .as_ref() + .map(|x| !x.pubkey.is_empty()) + .unwrap_or(false)); + + let sk: [u8; 32] = BASE64_STANDARD.decode(&secret).unwrap().try_into().unwrap(); + let pk = PublicKey::from(&StaticSecret::from(sk)).as_bytes().to_vec(); + assert_eq!(tc.credential.as_ref().unwrap().pubkey, pk); + } + + #[test] + fn test_unknown_pubkey_not_trusted() { + let mgr = CredentialManager::new(None); + mgr.generate_credential(vec![], false, vec![], Duration::from_secs(3600)); + + let random_key = [42u8; 32]; + assert!(!mgr.is_pubkey_trusted(&random_key)); + } + + #[test] + fn test_persistence_roundtrip() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("creds.json"); + + // Create and save + { + let mgr = CredentialManager::new(Some(path.clone())); + mgr.generate_credential( + vec!["persist_group".to_string()], + true, + vec!["10.0.0.0/24".to_string()], + Duration::from_secs(3600), + ); + assert_eq!(mgr.list_credentials().len(), 1); + } + + // Load from disk + { + let mgr = CredentialManager::new(Some(path)); + let list = mgr.list_credentials(); + assert_eq!(list.len(), 1); + assert_eq!(list[0].groups, vec!["persist_group".to_string()]); + assert!(list[0].allow_relay); + } + } + + #[test] + fn test_list_credentials_filters_expired() { + let mgr = CredentialManager::new(None); + mgr.generate_credential(vec![], false, vec![], Duration::from_secs(3600)); + mgr.generate_credential(vec![], false, vec![], Duration::from_secs(0)); // expired + + let list = mgr.list_credentials(); + assert_eq!(list.len(), 1); + } +} diff --git a/easytier/src/peers/foreign_network_client.rs b/easytier/src/peers/foreign_network_client.rs index 13da76d77..99e2794f6 100644 --- a/easytier/src/peers/foreign_network_client.rs +++ b/easytier/src/peers/foreign_network_client.rs @@ -38,7 +38,7 @@ impl ForeignNetworkClient { } } - pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) { + pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> { tracing::warn!(peer_conn = ?peer_conn.get_conn_info(), network = ?peer_conn.get_network_identity(), "add new peer conn in foreign network client"); self.peer_map.add_new_peer_conn(peer_conn).await } diff --git a/easytier/src/peers/foreign_network_manager.rs b/easytier/src/peers/foreign_network_manager.rs index 3feb65a26..903b713d0 100644 --- a/easytier/src/peers/foreign_network_manager.rs +++ b/easytier/src/peers/foreign_network_manager.rs @@ -34,7 +34,7 @@ use crate::{ proto::{ api::instance::{ForeignNetworkEntryPb, ListForeignNetworkResponse, PeerInfo}, common::LimiterConfig, - peer_rpc::DirectConnectorRpcServer, + peer_rpc::{DirectConnectorRpcServer, PeerIdentityType}, }, tunnel::packet_def::{PacketType, ZCPacket}, use_global_var, @@ -47,7 +47,9 @@ use super::{ peer_ospf_route::PeerRoute, peer_rpc::{PeerRpcManager, PeerRpcManagerTransport}, peer_rpc_service::DirectConnectorManagerRpcServer, + peer_session::PeerSessionStore, recv_packet_from_chan, + relay_peer_map::RelayPeerMap, route_trait::NextHopPolicy, PacketRecvChan, PacketRecvChanReceiver, PUBLIC_SERVER_HOSTNAME_PREFIX, }; @@ -64,6 +66,8 @@ struct ForeignNetworkEntry { global_ctx: ArcGlobalCtx, network: NetworkIdentity, peer_map: Arc, + relay_peer_map: Arc, + peer_session_store: Arc, relay_data: bool, pm_packet_sender: Mutex>, @@ -90,6 +94,7 @@ impl ForeignNetworkEntry { my_peer_id: PeerId, global_ctx: ArcGlobalCtx, relay_data: bool, + peer_session_store: Arc, pm_packet_sender: PacketRecvChan, ) -> Self { let stats_mgr = global_ctx.stats_manager().clone(); @@ -103,6 +108,13 @@ impl ForeignNetworkEntry { foreign_global_ctx.clone(), my_peer_id, )); + let relay_peer_map = RelayPeerMap::new( + peer_map.clone(), + None, + foreign_global_ctx.clone(), + my_peer_id, + peer_session_store.clone(), + ); let (peer_rpc, rpc_transport_sender) = Self::build_rpc_tspt(my_peer_id, peer_map.clone()); @@ -136,6 +148,8 @@ impl ForeignNetworkEntry { global_ctx: foreign_global_ctx, network, peer_map, + relay_peer_map, + peer_session_store, relay_data, pm_packet_sender: Mutex::new(Some(pm_packet_sender)), @@ -168,6 +182,7 @@ impl ForeignNetworkEntry { PUBLIC_SERVER_HOSTNAME_PREFIX, global_ctx.get_hostname() ))); + config.set_secure_mode(global_ctx.config.get_secure_mode()); let mut flags = config.get_flags(); flags.disable_relay_kcp = !global_ctx.get_flags().enable_relay_foreign_network_kcp; @@ -314,6 +329,7 @@ impl ForeignNetworkEntry { let my_node_id = self.my_peer_id; let rpc_sender = self.rpc_sender.clone(); let peer_map = self.peer_map.clone(); + let relay_peer_map = self.relay_peer_map.clone(); let relay_data = self.relay_data; let pm_sender = self.pm_packet_sender.lock().await.take().unwrap(); let network_name = self.network.network_name.clone(); @@ -335,34 +351,64 @@ impl ForeignNetworkEntry { .get_counter(MetricName::TrafficPacketsRx, label_set.clone()); self.tasks.lock().await.spawn(async move { - while let Ok(zc_packet) = recv_packet_from_chan(&mut recv).await { + while let Ok(mut zc_packet) = recv_packet_from_chan(&mut recv).await { let buf_len = zc_packet.buf_len(); let Some(hdr) = zc_packet.peer_manager_header() else { tracing::warn!("invalid packet, skip"); continue; }; tracing::trace!(?hdr, "recv packet in foreign network manager"); + let from_peer_id = hdr.from_peer_id.get(); + let packet_type = hdr.packet_type; + let len = hdr.len.get(); let to_peer_id = hdr.to_peer_id.get(); if to_peer_id == my_node_id { - if hdr.packet_type == PacketType::TaRpc as u8 - || hdr.packet_type == PacketType::RpcReq as u8 - || hdr.packet_type == PacketType::RpcResp as u8 + if packet_type == PacketType::RelayHandshake as u8 + || packet_type == PacketType::RelayHandshakeAck as u8 + { + let _ = relay_peer_map.handle_handshake_packet(zc_packet).await; + continue; + } + + if !peer_map.has_peer(from_peer_id) && relay_peer_map.is_secure_mode_enabled() { + match relay_peer_map.decrypt_if_needed(&mut zc_packet).await { + Ok(true) => {} + Ok(false) => { + tracing::error!("relay session not found"); + continue; + } + Err(e) => { + tracing::error!(?e, "relay decrypt failed"); + continue; + } + } + } + + if packet_type == PacketType::TaRpc as u8 + || packet_type == PacketType::RpcReq as u8 + || packet_type == PacketType::RpcResp as u8 { rx_bytes.add(buf_len as u64); rx_packets.inc(); rpc_sender.send(zc_packet).unwrap(); continue; } - tracing::trace!(?hdr, "ignore packet in foreign network"); + tracing::trace!( + ?packet_type, + ?len, + ?from_peer_id, + ?to_peer_id, + "ignore packet in foreign network" + ); } else { - if hdr.packet_type == PacketType::Data as u8 - || hdr.packet_type == PacketType::KcpSrc as u8 - || hdr.packet_type == PacketType::KcpDst as u8 + if packet_type == PacketType::Data as u8 + || packet_type == PacketType::KcpSrc as u8 + || packet_type == PacketType::KcpDst as u8 { if !relay_data { continue; } - if !bps_limiter.try_consume(hdr.len.into()) { + if !bps_limiter.try_consume(len.into()) { continue; } } @@ -376,7 +422,19 @@ impl ForeignNetworkEntry { match gateway_peer_id { Some(peer_id) if peer_map.has_peer(peer_id) => { - if let Err(e) = peer_map.send_msg_directly(zc_packet, peer_id).await { + if peer_id != to_peer_id && hdr.from_peer_id.get() == my_node_id { + if let Err(e) = relay_peer_map + .send_msg(zc_packet, to_peer_id, NextHopPolicy::LeastHop) + .await + { + tracing::error!( + ?e, + "send packet to foreign peer inside relay peer map failed" + ); + } + } else if let Err(e) = + peer_map.send_msg_directly(zc_packet, peer_id).await + { tracing::error!( ?e, "send packet to foreign peer inside peer map failed" @@ -405,9 +463,20 @@ impl ForeignNetworkEntry { }); } + async fn run_relay_session_gc_routine(&self) { + let relay_peer_map = self.relay_peer_map.clone(); + self.tasks.lock().await.spawn(async move { + loop { + relay_peer_map.evict_idle_sessions(std::time::Duration::from_secs(60)); + tokio::time::sleep(std::time::Duration::from_secs(30)).await; + } + }); + } + async fn prepare(&self, accessor: Box) { self.prepare_route(accessor).await; self.start_packet_recv().await; + self.run_relay_session_gc_routine().await; self.peer_rpc.run(); self.peer_center.init().await; } @@ -419,6 +488,8 @@ impl Drop for ForeignNetworkEntry { .rpc_server() .registry() .unregister_by_domain(&self.network.network_name); + self.global_ctx + .remove_trusted_keys(&self.network.network_name); tracing::debug!(self.my_peer_id, ?self.network, "drop foreign network entry"); } @@ -484,6 +555,7 @@ impl ForeignNetworkManagerData { self.network_peer_last_update.remove(network_name); } + #[allow(clippy::too_many_arguments)] async fn get_or_insert_entry( &self, network_identity: &NetworkIdentity, @@ -491,6 +563,7 @@ impl ForeignNetworkManagerData { dst_peer_id: PeerId, relay_data: bool, global_ctx: &ArcGlobalCtx, + peer_session_store: Arc, pm_packet_sender: &PacketRecvChan, ) -> (Arc, bool) { let mut new_added = false; @@ -506,6 +579,7 @@ impl ForeignNetworkManagerData { my_peer_id, global_ctx.clone(), relay_data, + peer_session_store, pm_packet_sender.clone(), )) }) @@ -534,6 +608,7 @@ pub const FOREIGN_NETWORK_SERVICE_ID: u32 = 1; pub struct ForeignNetworkManager { my_peer_id: PeerId, global_ctx: ArcGlobalCtx, + peer_session_store: Arc, packet_sender_to_mgr: PacketRecvChan, data: Arc, @@ -545,6 +620,7 @@ impl ForeignNetworkManager { pub fn new( my_peer_id: PeerId, global_ctx: ArcGlobalCtx, + peer_session_store: Arc, packet_sender_to_mgr: PacketRecvChan, accessor: Box, ) -> Self { @@ -562,6 +638,7 @@ impl ForeignNetworkManager { Self { my_peer_id, global_ctx, + peer_session_store, packet_sender_to_mgr, data, @@ -597,13 +674,15 @@ impl ForeignNetworkManager { peer_conn.get_peer_id(), ret.is_ok(), &self.global_ctx, + self.peer_session_store.clone(), &self.packet_sender_to_mgr, ) .await; let _g = entry.lock.lock().await; - if entry.network != peer_conn.get_network_identity() + if (entry.network != peer_conn.get_network_identity() + && peer_conn.get_peer_identity_type() != PeerIdentityType::SharedNode) || entry.my_peer_id != peer_conn.get_my_peer_id() { if new_added { @@ -642,7 +721,7 @@ impl ForeignNetworkManager { } } - entry.peer_map.add_new_peer_conn(peer_conn).await; + entry.peer_map.add_new_peer_conn(peer_conn).await?; Ok(()) } @@ -726,7 +805,7 @@ impl ForeignNetworkManager { .map(|v| *v) } - pub async fn send_msg_to_peer( + pub async fn forward_foreign_network_packet( &self, network_name: &str, dst_peer_id: PeerId, diff --git a/easytier/src/peers/mod.rs b/easytier/src/peers/mod.rs index 9b3b0ee04..6c7a34732 100644 --- a/easytier/src/peers/mod.rs +++ b/easytier/src/peers/mod.rs @@ -1,6 +1,7 @@ mod graph_algo; pub mod acl_filter; +pub mod credential_manager; pub mod peer; pub mod peer_conn; pub mod peer_conn_ping; @@ -10,6 +11,7 @@ pub mod peer_ospf_route; pub mod peer_rpc; pub mod peer_rpc_service; pub mod peer_session; +pub mod relay_peer_map; pub mod route_trait; pub mod rpc_service; diff --git a/easytier/src/peers/peer.rs b/easytier/src/peers/peer.rs index 3e982d7eb..376f793b2 100644 --- a/easytier/src/peers/peer.rs +++ b/easytier/src/peers/peer.rs @@ -17,6 +17,7 @@ use crate::{ global_ctx::{ArcGlobalCtx, GlobalCtxEvent}, PeerId, }, + proto::peer_rpc::PeerIdentityType, tunnel::packet_def::ZCPacket, }; use crate::{ @@ -40,6 +41,7 @@ pub struct Peer { shutdown_notifier: Arc, default_conn_id: Arc>, + peer_identity_type: Arc>>, default_conn_id_clear_task: ScopedTask<()>, } @@ -52,6 +54,8 @@ impl Peer { let conns: ConnMap = Arc::new(DashMap::new()); let (close_event_sender, mut close_event_receiver) = mpsc::channel(10); let shutdown_notifier = Arc::new(tokio::sync::Notify::new()); + let peer_identity_type = Arc::new(AtomicCell::new(None)); + let peer_identity_type_copy = peer_identity_type.clone(); let conns_copy = conns.clone(); let shutdown_notifier_copy = shutdown_notifier.clone(); @@ -76,6 +80,9 @@ impl Peer { conn.get_conn_info(), )); shrink_dashmap(&conns_copy, Some(4)); + if conns_copy.is_empty() { + peer_identity_type_copy.store(None); + } } } @@ -118,11 +125,25 @@ impl Peer { shutdown_notifier, default_conn_id, + peer_identity_type, default_conn_id_clear_task, } } - pub async fn add_peer_conn(&self, mut conn: PeerConn) { + pub async fn add_peer_conn(&self, mut conn: PeerConn) -> Result<(), Error> { + let conn_identity_type = conn.get_peer_identity_type(); + let peer_identity_type = self.peer_identity_type.load(); + if let Some(peer_identity_type) = peer_identity_type { + if peer_identity_type != conn_identity_type { + return Err(Error::SecretKeyError(format!( + "peer identity type mismatch. peer: {:?}, conn: {:?}", + peer_identity_type, conn_identity_type + ))); + } + } else { + self.peer_identity_type.store(Some(conn_identity_type)); + } + let close_notifier = conn.get_close_notifier(); let conn_info = conn.get_conn_info(); @@ -143,6 +164,7 @@ impl Peer { self.global_ctx .issue_event(GlobalCtxEvent::PeerConnAdded(conn_info)); + Ok(()) } async fn select_conn(&self) -> Option { @@ -221,6 +243,10 @@ impl Peer { pub fn get_default_conn_id(&self) -> PeerConnId { self.default_conn_id.load() } + + pub fn get_peer_identity_type(&self) -> Option { + self.peer_identity_type.load() + } } // pritn on drop @@ -238,17 +264,38 @@ impl Drop for Peer { #[cfg(test)] mod tests { + use base64::prelude::{Engine as _, BASE64_STANDARD}; + use rand::rngs::OsRng; use std::sync::Arc; use tokio::time::timeout; use crate::{ - common::{global_ctx::tests::get_mock_global_ctx, new_peer_id}, + common::{ + config::{NetworkIdentity, PeerConfig}, + global_ctx::{tests::get_mock_global_ctx, GlobalCtx}, + new_peer_id, + }, peers::{create_packet_recv_chan, peer_conn::PeerConn, peer_session::PeerSessionStore}, + proto::common::SecureModeConfig, tunnel::ring::create_ring_tunnel_pair, }; use super::Peer; + fn set_secure_mode_cfg(global_ctx: &GlobalCtx, enabled: bool) { + if !enabled { + global_ctx.config.set_secure_mode(None); + } else { + let private = x25519_dalek::StaticSecret::random_from_rng(OsRng); + let public = x25519_dalek::PublicKey::from(&private); + global_ctx.config.set_secure_mode(Some(SecureModeConfig { + enabled: true, + local_private_key: Some(BASE64_STANDARD.encode(private.as_bytes())), + local_public_key: Some(BASE64_STANDARD.encode(public.as_bytes())), + })); + } + } + #[tokio::test] async fn close_peer() { let (local_packet_send, _local_packet_recv) = create_packet_recv_chan(); @@ -284,8 +331,8 @@ mod tests { let local_conn_id = local_peer_conn.get_conn_id(); - local_peer.add_peer_conn(local_peer_conn).await; - remote_peer.add_peer_conn(remote_peer_conn).await; + local_peer.add_peer_conn(local_peer_conn).await.unwrap(); + remote_peer.add_peer_conn(remote_peer_conn).await.unwrap(); assert_eq!(local_peer.list_peer_conns().await.len(), 1); assert_eq!(remote_peer.list_peer_conns().await.len(), 1); @@ -305,4 +352,110 @@ mod tests { println!("wait for close handler"); close_handler.await.unwrap().unwrap(); } + + #[tokio::test] + async fn reject_peer_conn_with_mismatched_identity_type() { + let (packet_send, _packet_recv) = create_packet_recv_chan(); + let global_ctx = get_mock_global_ctx(); + let local_peer_id = new_peer_id(); + let remote_peer_id = new_peer_id(); + let peer = Peer::new(remote_peer_id, packet_send, global_ctx); + + let ps = Arc::new(PeerSessionStore::new()); + + let (shared_client_tunnel, shared_server_tunnel) = create_ring_tunnel_pair(); + let shared_client_ctx = get_mock_global_ctx(); + let shared_server_ctx = get_mock_global_ctx(); + shared_client_ctx + .config + .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec2".to_string())); + shared_server_ctx + .config + .set_network_identity(NetworkIdentity { + network_name: "net2".to_string(), + network_secret: None, + network_secret_digest: None, + }); + set_secure_mode_cfg(&shared_client_ctx, true); + set_secure_mode_cfg(&shared_server_ctx, true); + let remote_url: url::Url = shared_client_tunnel + .info() + .unwrap() + .remote_addr + .unwrap() + .url + .parse() + .unwrap(); + shared_client_ctx.config.set_peers(vec![PeerConfig { + uri: remote_url, + peer_public_key: Some( + shared_server_ctx + .config + .get_secure_mode() + .unwrap() + .local_public_key + .unwrap(), + ), + }]); + let mut shared_client_conn = PeerConn::new( + local_peer_id, + shared_client_ctx, + Box::new(shared_client_tunnel), + ps.clone(), + ); + let mut shared_server_conn = PeerConn::new( + remote_peer_id, + shared_server_ctx, + Box::new(shared_server_tunnel), + ps.clone(), + ); + let (c1, s1) = tokio::join!( + shared_client_conn.do_handshake_as_client(), + shared_server_conn.do_handshake_as_server() + ); + c1.unwrap(); + s1.unwrap(); + assert_eq!( + shared_client_conn.get_peer_identity_type(), + crate::proto::peer_rpc::PeerIdentityType::SharedNode + ); + + let (admin_client_tunnel, admin_server_tunnel) = create_ring_tunnel_pair(); + let admin_client_ctx = get_mock_global_ctx(); + let admin_server_ctx = get_mock_global_ctx(); + admin_client_ctx + .config + .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec2".to_string())); + admin_server_ctx + .config + .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec2".to_string())); + set_secure_mode_cfg(&admin_client_ctx, true); + set_secure_mode_cfg(&admin_server_ctx, true); + let mut admin_client_conn = PeerConn::new( + local_peer_id, + admin_client_ctx, + Box::new(admin_client_tunnel), + Arc::new(PeerSessionStore::new()), + ); + let mut admin_server_conn = PeerConn::new( + remote_peer_id, + admin_server_ctx, + Box::new(admin_server_tunnel), + Arc::new(PeerSessionStore::new()), + ); + let (c2, s2) = tokio::join!( + admin_client_conn.do_handshake_as_client(), + admin_server_conn.do_handshake_as_server() + ); + c2.unwrap(); + s2.unwrap(); + assert_eq!( + admin_client_conn.get_peer_identity_type(), + crate::proto::peer_rpc::PeerIdentityType::Admin + ); + + peer.add_peer_conn(shared_client_conn).await.unwrap(); + let ret = peer.add_peer_conn(admin_client_conn).await; + assert!(ret.is_err()); + } } diff --git a/easytier/src/peers/peer_conn.rs b/easytier/src/peers/peer_conn.rs index 5213d036d..5c1491b15 100644 --- a/easytier/src/peers/peer_conn.rs +++ b/easytier/src/peers/peer_conn.rs @@ -43,7 +43,7 @@ use crate::{ common::{LimiterConfig, SecureModeConfig, TunnelInfo}, peer_rpc::{ HandshakeRequest, PeerConnNoiseMsg1Pb, PeerConnNoiseMsg2Pb, PeerConnNoiseMsg3Pb, - PeerConnSessionActionPb, SecureAuthLevel, + PeerConnSessionActionPb, PeerIdentityType, SecureAuthLevel, }, }, tunnel::{ @@ -83,6 +83,7 @@ struct NoiseHandshakeResult { remote_static_pubkey: Vec, handshake_hash: Vec, secure_auth_level: SecureAuthLevel, + peer_identity_type: PeerIdentityType, remote_network_name: String, secret_digest: Vec, @@ -138,6 +139,8 @@ impl PeerSessionTunnelFilter { hdr.packet_type == PacketType::NoiseHandshakeMsg1 as u8 || hdr.packet_type == PacketType::NoiseHandshakeMsg2 as u8 || hdr.packet_type == PacketType::NoiseHandshakeMsg3 as u8 + || hdr.packet_type == PacketType::RelayHandshake as u8 + || hdr.packet_type == PacketType::RelayHandshakeAck as u8 || hdr.packet_type == PacketType::Ping as u8 || hdr.packet_type == PacketType::Pong as u8 } @@ -169,9 +172,19 @@ impl TunnelFilter for PeerSessionTunnelFilter { }; let my_peer_id = self.my_peer_id.load(); - session - .encrypt_payload(my_peer_id, peer_id, &mut data) - .ok()?; + if my_peer_id != hdr.from_peer_id.get() { + return Some(data); + } + + if let Err(e) = session.encrypt_payload(my_peer_id, peer_id, &mut data) { + tracing::warn!( + ?my_peer_id, + ?peer_id, + ?e, + "PeerSessionTunnelFilter: encrypt failed, dropping packet" + ); + return None; + } Some(data) } @@ -198,7 +211,14 @@ impl TunnelFilter for PeerSessionTunnelFilter { if from_peer_id == 0 { return Some(Ok(data)); } - self.peer_id.store(Some(from_peer_id)); + + let Some(peer_id) = self.peer_id.load() else { + return Some(Ok(data)); + }; + + if from_peer_id != peer_id { + return Some(Ok(data)); + } let mut guard = self.session.lock().unwrap(); let Some(session) = guard.as_mut() else { @@ -206,7 +226,22 @@ impl TunnelFilter for PeerSessionTunnelFilter { }; let my_peer_id = self.my_peer_id.load(); - let _ = session.decrypt_payload(from_peer_id, my_peer_id, &mut data); + if hdr.to_peer_id.get() != my_peer_id { + return Some(Ok(data)); + } + + if let Err(e) = session.decrypt_payload(from_peer_id, my_peer_id, &mut data) { + if !session.is_valid() { + // Session auto-invalidated after too many consecutive failures. + // Close the connection to trigger reconnection with a fresh handshake. + tracing::error!(?e, "session invalidated, closing connection"); + return Some(Err(TunnelError::InternalError( + "session invalidated due to consecutive decrypt failures".to_string(), + ))); + } + // Transient failure, drop this packet but keep the connection alive. + return None; + } Some(Ok(data)) } @@ -643,6 +678,108 @@ impl PeerConn { Ok(self.sink.send(pkt).await?) } + /// Unified remote peer authentication verification. + /// + /// Auth outcome matrix (current behavior): + /// + /// | Client role | Server role | Typical credential condition | Client auth level | Server auth level | Client sees server type | Server sees client type | + /// | --- | --- | --- | --- | --- | --- | --- | + /// | Admin | Admin | same network_secret, proof verified | NetworkSecretConfirmed | NetworkSecretConfirmed | Admin | Admin | + /// | Credential | Admin | client pubkey is trusted by admin | EncryptedUnauthenticated | PeerVerified | Admin | Credential | + /// | Credential | Admin | client pubkey is unknown | handshake may fail | handshake reject | unknown | unknown | + /// | Admin | SharedNode | pinned key match | PeerVerified | EncryptedUnauthenticated | SharedNode | SharedNode | + /// | Admin | SharedNode | local has no pinned key requirement | EncryptedUnauthenticated | EncryptedUnauthenticated | SharedNode | SharedNode | + /// | Credential | SharedNode | no pin and not trusted | EncryptedUnauthenticated | EncryptedUnauthenticated | SharedNode | SharedNode | + /// | Credential | Credential | should reject | handshake reject | handshake reject | unknown | unknown | + /// + /// Logic (in priority order): + /// 1. **NetworkSecretConfirmed**: proof verification succeeds + /// 2. **PeerVerified**: pinned_pubkey matches and is in trusted list + /// (if no network_secret, pinned_pubkey must be in trusted list) + /// 3. **PeerVerified**: pubkey is in trusted list + /// 4. **EncryptedUnauthenticated**: initiator without network_secret + /// 5. **Reject**: none of the above + #[allow(clippy::too_many_arguments)] + fn verify_remote_auth( + &self, + proof: Option<&[u8]>, + handshake_hash: &[u8], + remote_pubkey: &[u8], + pinned_pubkey: Option<&[u8]>, + has_network_secret: bool, + is_initiator: bool, + remote_network_name: &str, + ) -> Result { + // 1. Verify proof + if let Some(proof) = proof { + if let Some(mac) = self.global_ctx.get_secret_proof(handshake_hash) { + if mac.verify_slice(proof).is_ok() { + return Ok(SecureAuthLevel::NetworkSecretConfirmed); + } + } + } + + // 2. Check pinned pubkey + if let Some(pinned) = pinned_pubkey { + if pinned != remote_pubkey { + return Err(Error::WaitRespError( + "pinned remote static pubkey mismatch".to_owned(), + )); + } + // If no network_secret, pinned key must be in trusted list + if !has_network_secret + && !self + .global_ctx + .is_pubkey_trusted(remote_pubkey, remote_network_name) + { + return Err(Error::WaitRespError( + "pinned pubkey not in trusted list".to_owned(), + )); + } + return Ok(SecureAuthLevel::PeerVerified); + } + + // 3. Check if pubkey is in trusted list + if self + .global_ctx + .is_pubkey_trusted(remote_pubkey, remote_network_name) + { + return Ok(SecureAuthLevel::PeerVerified); + } + + // 4. If we are the initiator without network_secret, keep encrypted channel only. + if is_initiator && !has_network_secret { + return Ok(SecureAuthLevel::EncryptedUnauthenticated); + } + + // 5. Reject + Err(Error::WaitRespError( + "authentication failed: invalid proof and unknown credential".to_owned(), + )) + } + + fn classify_remote_identity( + &self, + remote_network_name: &str, + secure_auth_level: SecureAuthLevel, + remote_role_hint_is_same_network: bool, + remote_sent_secret_proof: bool, + ) -> PeerIdentityType { + if !remote_role_hint_is_same_network + || remote_network_name != self.global_ctx.get_network_name() + { + return PeerIdentityType::SharedNode; + } + + if matches!(secure_auth_level, SecureAuthLevel::NetworkSecretConfirmed) + || remote_sent_secret_proof + { + return PeerIdentityType::Admin; + } + + PeerIdentityType::Credential + } + async fn do_noise_handshake_as_client(&self) -> Result { let prologue = b"easytier-peerconn-noise".to_vec(); @@ -681,8 +818,6 @@ impl PeerConn { .local_private_key(&local_private_key)? .build_initiator()?; - let mut secure_auth_level = SecureAuthLevel::EncryptedUnauthenticated; - self.send_noise_msg( msg1_pb, PacketType::NoiseHandshakeMsg1, @@ -717,29 +852,12 @@ impl PeerConn { let action = PeerConnSessionActionPb::try_from(msg2_pb.action) .map_err(|_| Error::WaitRespError("invalid session action".to_owned()))?; let remote_network_name = msg2_pb.b_network_name.clone(); + let remote_sent_secret_proof = msg2_pb.secret_proof_32.is_some(); - if remote_network_name == network.network_name { - if msg2_pb.role_hint != 1 { - return Err(Error::WaitRespError( - "role_hint must be 1 when network_name is same".to_owned(), - )); - } - let Some(secret_proof_32) = msg2_pb.secret_proof_32 else { - return Err(Error::WaitRespError( - "secret_proof_32 must be present when role_hint is 1".to_owned(), - )); - }; - let verify_result = self - .global_ctx - .get_secret_proof(&server_handshake_hash) - .map(|mac| mac.verify_slice(&secret_proof_32).is_ok()); - if verify_result != Some(true) { - return Err(Error::WaitRespError(format!( - "secret_proof_32 verify failed: {verify_result:?}" - ))); - } - - secure_auth_level = secure_auth_level.max(SecureAuthLevel::NetworkSecretConfirmed); + if remote_network_name == network.network_name && msg2_pb.role_hint != 1 { + return Err(Error::WaitRespError( + "role_hint must be 1 when network_name is same".to_owned(), + )); } let handshake_hash_for_proof = hs.get_handshake_hash().to_vec(); @@ -775,17 +893,34 @@ impl PeerConn { .get_remote_static() .map(|x: &[u8]| x.to_vec()) .unwrap_or_default(); + let remote_static_key = if remote_static.len() == 32 { + let mut key = [0u8; 32]; + key.copy_from_slice(&remote_static); + Some(key) + } else { + None + }; - if let Some(pinned) = pinned_remote_pubkey.as_ref() { - if pinned.as_slice() == remote_static.as_slice() { - secure_auth_level = - secure_auth_level.max(SecureAuthLevel::SharedNodePubkeyVerified); - } else { - return Err(Error::WaitRespError( - "pinned remote static pubkey mismatch".to_owned(), - )); - } - } + // Verify server authentication using unified logic + let secure_auth_level = if msg2_pb.role_hint != 1 && pinned_remote_pubkey.is_none() { + SecureAuthLevel::EncryptedUnauthenticated + } else { + self.verify_remote_auth( + msg2_pb.secret_proof_32.as_deref(), + &server_handshake_hash, + &remote_static, + pinned_remote_pubkey.as_deref(), + network.network_secret.is_some(), + true, // is_initiator + &remote_network_name, + )? + }; + let peer_identity_type = self.classify_remote_identity( + &remote_network_name, + secure_auth_level, + msg2_pb.role_hint == 1, + remote_sent_secret_proof, + ); let handshake_hash = hs.get_handshake_hash().to_vec(); @@ -812,6 +947,7 @@ impl PeerConn { msg2_pb.initial_epoch, algo, msg2_pb.server_encryption_algorithm.clone(), + remote_static_key, )?; Ok(NoiseHandshakeResult { @@ -821,6 +957,7 @@ impl PeerConn { remote_static_pubkey: remote_static, handshake_hash, secure_auth_level, + peer_identity_type, remote_network_name, // we have authorized the peer with noise handshake, so just set secret digest same as us even remote is a shared node. secret_digest, @@ -949,6 +1086,7 @@ impl PeerConn { msg1_pb.a_session_generation, algo.clone(), msg1_pb.client_encryption_algorithm.clone(), + None, )?; let b_conn_id = uuid::Uuid::new_v4(); @@ -1000,28 +1138,43 @@ impl PeerConn { )); } - let mut secure_auth_level = SecureAuthLevel::EncryptedUnauthenticated; - let Some(proof) = msg3_pb.secret_proof_32.as_ref() else { - return Err(Error::WaitRespError( - "noise msg3 secret_proof_32 is required".to_owned(), - )); - }; - - if role_hint == 1 { - if let Some(mac) = self.global_ctx.get_secret_proof(&handshake_hash_for_proof) { - if mac.verify_slice(proof).is_ok() { - secure_auth_level = - secure_auth_level.max(SecureAuthLevel::NetworkSecretConfirmed); - } else { - return Err(Error::WaitRespError("invalid secret_proof".to_owned())); - } - } - } - let remote_static = hs .get_remote_static() .map(|x: &[u8]| x.to_vec()) .unwrap_or_default(); + let remote_static_key = if remote_static.len() == 32 { + let mut key = [0u8; 32]; + key.copy_from_slice(&remote_static); + Some(key) + } else { + None + }; + session.check_or_set_peer_static_pubkey(remote_static_key)?; + + // Verify client authentication using unified logic + // Note: Server doesn't use pinned_pubkey since it's the responder + let secure_auth_level = if role_hint == 1 { + self.verify_remote_auth( + msg3_pb.secret_proof_32.as_deref(), + &handshake_hash_for_proof, + &remote_static, + None, // Server doesn't have pinned_remote_pubkey + self.global_ctx + .get_network_identity() + .network_secret + .is_some(), + false, // is_initiator + &remote_network_name, + )? + } else { + SecureAuthLevel::EncryptedUnauthenticated + }; + let peer_identity_type = self.classify_remote_identity( + &remote_network_name, + secure_auth_level, + role_hint == 1, + msg3_pb.secret_proof_32.is_some(), + ); let handshake_hash = hs.get_handshake_hash().to_vec(); @@ -1032,11 +1185,12 @@ impl PeerConn { remote_static_pubkey: remote_static, handshake_hash, secure_auth_level, + peer_identity_type, remote_network_name, secret_digest: msg3_pb.secret_digest, - client_secret_proof: Some(SecretProof { + client_secret_proof: msg3_pb.secret_proof_32.as_ref().map(|p| SecretProof { challenge: handshake_hash_for_proof, - proof: proof.clone(), + proof: p.clone(), }), my_encrypt_algo: self.my_encrypt_algo.clone(), @@ -1341,9 +1495,21 @@ impl PeerConn { .as_ref() .map(|x| x.secure_auth_level as i32) .unwrap_or_default(), + peer_identity_type: self + .noise_handshake_result + .as_ref() + .map(|x| x.peer_identity_type as i32) + .unwrap_or(PeerIdentityType::Admin as i32), } } + pub fn get_peer_identity_type(&self) -> PeerIdentityType { + self.noise_handshake_result + .as_ref() + .map(|x| x.peer_identity_type) + .unwrap_or(PeerIdentityType::Admin) + } + pub fn set_peer_id(&mut self, peer_id: PeerId) { if self.info.is_some() { panic!("set_peer_id should only be called before handshake"); @@ -1707,6 +1873,14 @@ pub mod tests { s_peer.get_conn_info().secure_auth_level, SecureAuthLevel::NetworkSecretConfirmed as i32, ); + assert_eq!( + c_peer.get_conn_info().peer_identity_type, + PeerIdentityType::Admin as i32, + ); + assert_eq!( + s_peer.get_conn_info().peer_identity_type, + PeerIdentityType::Admin as i32, + ); } #[tokio::test] @@ -1758,7 +1932,66 @@ pub mod tests { assert_eq!( c_peer.get_conn_info().secure_auth_level, - SecureAuthLevel::SharedNodePubkeyVerified as i32, + SecureAuthLevel::PeerVerified as i32, + ); + assert_eq!( + c_peer.get_conn_info().peer_identity_type, + PeerIdentityType::SharedNode as i32, + ); + assert_eq!( + s_peer.get_conn_info().peer_identity_type, + PeerIdentityType::SharedNode as i32, + ); + } + + #[tokio::test] + async fn peer_conn_secure_mode_shared_node_without_pin_is_unauthenticated() { + let (c, s) = create_ring_tunnel_pair(); + + let c_peer_id = new_peer_id(); + let s_peer_id = new_peer_id(); + + let c_ctx = get_mock_global_ctx(); + let s_ctx = get_mock_global_ctx(); + + c_ctx + .config + .set_network_identity(NetworkIdentity::new("net1".to_string(), "sec2".to_string())); + s_ctx.config.set_network_identity(NetworkIdentity { + network_name: "net2".to_string(), + network_secret: None, + network_secret_digest: None, + }); + + set_secure_mode_cfg(&c_ctx, true); + set_secure_mode_cfg(&s_ctx, true); + + let ps = Arc::new(PeerSessionStore::new()); + let mut c_peer = PeerConn::new(c_peer_id, c_ctx, Box::new(c), ps.clone()); + let mut s_peer = PeerConn::new(s_peer_id, s_ctx, Box::new(s), ps.clone()); + + let (c_ret, s_ret) = tokio::join!( + c_peer.do_handshake_as_client(), + s_peer.do_handshake_as_server() + ); + c_ret.unwrap(); + s_ret.unwrap(); + + assert_eq!( + c_peer.get_conn_info().secure_auth_level, + SecureAuthLevel::EncryptedUnauthenticated as i32, + ); + assert_eq!( + s_peer.get_conn_info().secure_auth_level, + SecureAuthLevel::EncryptedUnauthenticated as i32, + ); + assert_eq!( + c_peer.get_conn_info().peer_identity_type, + PeerIdentityType::SharedNode as i32, + ); + assert_eq!( + s_peer.get_conn_info().peer_identity_type, + PeerIdentityType::SharedNode as i32, ); } @@ -1852,4 +2085,227 @@ pub mod tests { .unwrap_err(); let _ = tokio::join!(j); } + + /// Helper: set up a credential node's GlobalCtx with a specific private key + /// (no network_secret, secure mode enabled with the given keypair) + fn set_credential_mode_cfg( + global_ctx: &GlobalCtx, + network_name: &str, + private_key: &x25519_dalek::StaticSecret, + ) { + use crate::common::config::NetworkIdentity; + let public = x25519_dalek::PublicKey::from(private_key); + global_ctx + .config + .set_network_identity(NetworkIdentity::new_credential(network_name.to_string())); + global_ctx.config.set_secure_mode(Some(SecureModeConfig { + enabled: true, + local_private_key: Some(BASE64_STANDARD.encode(private_key.as_bytes())), + local_public_key: Some(BASE64_STANDARD.encode(public.as_bytes())), + })); + } + + /// Test: credential node connects to admin node, admin has credential in trusted list. + /// Handshake should succeed with PeerVerified auth level on server side. + #[tokio::test] + async fn peer_conn_credential_node_connects_to_admin() { + let (c, s) = create_ring_tunnel_pair(); + + let c_peer_id = new_peer_id(); + let s_peer_id = new_peer_id(); + + // Admin node (server) has network_secret + let s_ctx = get_mock_global_ctx(); + s_ctx.config.set_network_identity(NetworkIdentity::new( + "net1".to_string(), + "secret".to_string(), + )); + set_secure_mode_cfg(&s_ctx, true); + + // Generate a credential on admin and get the private key for the client + let (cred_id, cred_secret) = s_ctx.get_credential_manager().generate_credential( + vec!["guest".to_string()], + false, + vec![], + std::time::Duration::from_secs(3600), + ); + + // Credential node (client) uses credential private key + let c_ctx = get_mock_global_ctx(); + let privkey_bytes: [u8; 32] = BASE64_STANDARD + .decode(&cred_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + set_credential_mode_cfg(&c_ctx, "net1", &private); + + let ps = Arc::new(PeerSessionStore::new()); + let mut c_peer = PeerConn::new(c_peer_id, c_ctx, Box::new(c), ps.clone()); + let mut s_peer = PeerConn::new(s_peer_id, s_ctx, Box::new(s), ps.clone()); + + let (c_ret, s_ret) = tokio::join!( + c_peer.do_handshake_as_client(), + s_peer.do_handshake_as_server() + ); + + c_ret.unwrap(); + s_ret.unwrap(); + + // Server should see credential node as PeerVerified + assert_eq!( + s_peer.get_conn_info().secure_auth_level, + SecureAuthLevel::PeerVerified as i32, + ); + assert_eq!( + s_peer.get_conn_info().peer_identity_type, + PeerIdentityType::Credential as i32, + ); + + // Client (credential node) keeps encrypted unauthenticated level + assert_eq!( + c_peer.get_conn_info().secure_auth_level, + SecureAuthLevel::EncryptedUnauthenticated as i32, + ); + assert_eq!( + c_peer.get_conn_info().peer_identity_type, + PeerIdentityType::Admin as i32, + ); + + // Verify credential ID matches + let _ = cred_id; // just to use it + } + + /// Test: unknown credential node (not in trusted list) is rejected by admin. + #[tokio::test] + async fn peer_conn_unknown_credential_rejected() { + let (c, s) = create_ring_tunnel_pair(); + + let c_peer_id = new_peer_id(); + let s_peer_id = new_peer_id(); + + // Admin node (server) with no credentials generated + let s_ctx = get_mock_global_ctx(); + s_ctx.config.set_network_identity(NetworkIdentity::new( + "net1".to_string(), + "secret".to_string(), + )); + set_secure_mode_cfg(&s_ctx, true); + + // Unknown credential node (client) with random key, not in admin's trusted list + let c_ctx = get_mock_global_ctx(); + let random_private = x25519_dalek::StaticSecret::random_from_rng(OsRng); + set_credential_mode_cfg(&c_ctx, "net1", &random_private); + + let ps = Arc::new(PeerSessionStore::new()); + let mut c_peer = PeerConn::new(c_peer_id, c_ctx, Box::new(c), ps.clone()); + let mut s_peer = PeerConn::new(s_peer_id, s_ctx, Box::new(s), ps.clone()); + + let (c_ret, s_ret) = tokio::join!( + c_peer.do_handshake_as_client(), + s_peer.do_handshake_as_server() + ); + + // Server should reject the unknown credential + assert!(s_ret.is_err(), "server should reject unknown credential"); + // Client may also fail due to connection being closed + let _ = c_ret; + } + + /// Test: two admin nodes with same network_secret still get NetworkSecretConfirmed. + /// (Regression test: credential system should not break normal admin-to-admin auth) + #[tokio::test] + async fn peer_conn_admin_to_admin_still_works() { + let (c, s) = create_ring_tunnel_pair(); + + let c_peer_id = new_peer_id(); + let s_peer_id = new_peer_id(); + + let c_ctx = get_mock_global_ctx(); + let s_ctx = get_mock_global_ctx(); + + c_ctx.config.set_network_identity(NetworkIdentity::new( + "net1".to_string(), + "secret".to_string(), + )); + s_ctx.config.set_network_identity(NetworkIdentity::new( + "net1".to_string(), + "secret".to_string(), + )); + + set_secure_mode_cfg(&c_ctx, true); + set_secure_mode_cfg(&s_ctx, true); + + let ps = Arc::new(PeerSessionStore::new()); + let mut c_peer = PeerConn::new(c_peer_id, c_ctx, Box::new(c), ps.clone()); + let mut s_peer = PeerConn::new(s_peer_id, s_ctx, Box::new(s), ps.clone()); + + let (c_ret, s_ret) = tokio::join!( + c_peer.do_handshake_as_client(), + s_peer.do_handshake_as_server() + ); + + c_ret.unwrap(); + s_ret.unwrap(); + + assert_eq!( + c_peer.get_conn_info().secure_auth_level, + SecureAuthLevel::NetworkSecretConfirmed as i32, + ); + assert_eq!( + s_peer.get_conn_info().secure_auth_level, + SecureAuthLevel::NetworkSecretConfirmed as i32, + ); + } + + /// Test: revoked credential is rejected on new connection attempt. + #[tokio::test] + async fn peer_conn_revoked_credential_rejected() { + // Admin generates credential, then revokes it + let admin_ctx = get_mock_global_ctx(); + admin_ctx.config.set_network_identity(NetworkIdentity::new( + "net1".to_string(), + "secret".to_string(), + )); + set_secure_mode_cfg(&admin_ctx, true); + + let (cred_id, cred_secret) = admin_ctx.get_credential_manager().generate_credential( + vec![], + false, + vec![], + std::time::Duration::from_secs(3600), + ); + + // Revoke the credential + assert!(admin_ctx + .get_credential_manager() + .revoke_credential(&cred_id)); + + // Now try to connect with the revoked credential + let (c, s) = create_ring_tunnel_pair(); + let c_peer_id = new_peer_id(); + let s_peer_id = new_peer_id(); + + let c_ctx = get_mock_global_ctx(); + let privkey_bytes: [u8; 32] = BASE64_STANDARD + .decode(&cred_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + set_credential_mode_cfg(&c_ctx, "net1", &private); + + let ps = Arc::new(PeerSessionStore::new()); + let mut c_peer = PeerConn::new(c_peer_id, c_ctx, Box::new(c), ps.clone()); + let mut s_peer = PeerConn::new(s_peer_id, admin_ctx, Box::new(s), ps.clone()); + + let (c_ret, s_ret) = tokio::join!( + c_peer.do_handshake_as_client(), + s_peer.do_handshake_as_server() + ); + + // Server should reject the revoked credential + assert!(s_ret.is_err(), "server should reject revoked credential"); + let _ = c_ret; + } } diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 59672db93..687f85d49 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -43,7 +43,8 @@ use crate::{ ListGlobalForeignNetworkResponse, }, peer_rpc::{ - ForeignNetworkRouteInfoEntry, ForeignNetworkRouteInfoKey, RouteForeignNetworkSummary, + ForeignNetworkRouteInfoEntry, ForeignNetworkRouteInfoKey, PeerIdentityType, + RouteForeignNetworkSummary, }, }, tunnel::{ @@ -62,6 +63,7 @@ use super::{ peer_map::PeerMap, peer_ospf_route::PeerRoute, peer_rpc::PeerRpcManager, + relay_peer_map::RelayPeerMap, route_trait::{ArcRoute, Route}, BoxNicPacketFilter, BoxPeerPacketFilter, PacketRecvChan, PacketRecvChanReceiver, }; @@ -76,6 +78,7 @@ struct RpcTransport { peer_rpc_tspt_sender: UnboundedSender, encryptor: Arc, + is_secure_mode_enabled: bool, } #[async_trait::async_trait] @@ -93,7 +96,7 @@ impl PeerRpcManagerTransport for RpcTransport { .and_then(|x| x.feature_flag.map(|x| x.is_public_server)) // if dst is directly connected, it's must not public server .unwrap_or(!peers.has_peer(dst_peer_id)); - if !is_dst_peer_public_server { + if !is_dst_peer_public_server && !self.is_secure_mode_enabled { self.encryptor .encrypt(&mut msg) .with_context(|| "encrypt failed")?; @@ -150,6 +153,7 @@ pub struct PeerManager { foreign_network_manager: Arc, foreign_network_client: Arc, + relay_peer_map: Arc, encryptor: Arc, data_compress_algo: CompressorAlgo, @@ -163,6 +167,7 @@ pub struct PeerManager { self_tx_counters: SelfTxCounters, peer_session_store: Arc, + is_secure_mode_enabled: bool, } impl Debug for PeerManager { @@ -189,6 +194,7 @@ impl PeerManager { global_ctx.clone(), my_peer_id, )); + let peer_session_store = Arc::new(PeerSessionStore::new()); let encryptor = if global_ctx.get_flags().enable_encryption { // 只有在启用加密时才使用工厂函数选择算法 @@ -213,6 +219,12 @@ impl PeerManager { global_ctx.set_feature_flags(f); } + let is_secure_mode_enabled = global_ctx + .config + .get_secure_mode() + .map(|cfg| cfg.enabled) + .unwrap_or(false); + // TODO: remove these because we have impl pipeline processor. let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel(); let rpc_tspt = Arc::new(RpcTransport { @@ -222,6 +234,7 @@ impl PeerManager { packet_recv: Mutex::new(peer_rpc_tspt_recv), peer_rpc_tspt_sender, encryptor: encryptor.clone(), + is_secure_mode_enabled, }); let peer_rpc_mgr = Arc::new(PeerRpcManager::new_with_stats_manager( rpc_tspt.clone(), @@ -240,6 +253,7 @@ impl PeerManager { let foreign_network_manager = Arc::new(ForeignNetworkManager::new( my_peer_id, global_ctx.clone(), + peer_session_store.clone(), packet_send.clone(), Self::build_foreign_network_manager_accessor(&peers), )); @@ -250,6 +264,14 @@ impl PeerManager { my_peer_id, )); + let relay_peer_map = RelayPeerMap::new( + peers.clone(), + Some(foreign_network_client.clone()), + global_ctx.clone(), + my_peer_id, + peer_session_store.clone(), + ); + let data_compress_algo = global_ctx .get_flags() .data_compress_algo() @@ -304,6 +326,7 @@ impl PeerManager { foreign_network_manager, foreign_network_client, + relay_peer_map, encryptor, data_compress_algo, @@ -316,7 +339,8 @@ impl PeerManager { self_tx_counters, - peer_session_store: Arc::new(PeerSessionStore::new()), + peer_session_store, + is_secure_mode_enabled, } } @@ -354,12 +378,34 @@ impl PeerManager { } async fn add_new_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> { - if self.global_ctx.get_network_identity() != peer_conn.get_network_identity() { + let my_identity = self.global_ctx.get_network_identity(); + let peer_identity = peer_conn.get_network_identity(); + + // For credential nodes, network_secret_digest is either None or all-zeros + // (all-zeros when received over the wire via handshake). + // In this case, only compare network_name. + let my_digest_empty = my_identity + .network_secret_digest + .as_ref() + .is_none_or(|d| d.iter().all(|b| *b == 0)); + let peer_digest_empty = peer_identity + .network_secret_digest + .as_ref() + .is_none_or(|d| d.iter().all(|b| *b == 0)); + + let identity_ok = if my_digest_empty || peer_digest_empty { + // Credential node: only check network_name + my_identity.network_name == peer_identity.network_name + } else { + my_identity == peer_identity + }; + + if !identity_ok { return Err(Error::SecretKeyError( "network identity not match".to_string(), )); } - self.peers.add_new_peer_conn(peer_conn).await; + self.peers.add_new_peer_conn(peer_conn).await?; Ok(()) } @@ -394,7 +440,7 @@ impl PeerManager { { self.add_new_peer_conn(peer).await?; } else { - self.foreign_network_client.add_new_peer_conn(peer).await; + self.foreign_network_client.add_new_peer_conn(peer).await?; } Ok((peer_id, conn_id)) } @@ -576,7 +622,7 @@ impl PeerManager { MetricName::TrafficPacketsForeignForwardRx, ); if let Err(e) = foreign_network_mgr - .send_msg_to_peer( + .forward_foreign_network_packet( &foreign_network_name, foreign_peer_id, packet.foreign_network_packet(), @@ -645,13 +691,21 @@ impl PeerManager { let peers = self.peers.clone(); let pipe_line = self.peer_packet_process_pipeline.clone(); let foreign_client = self.foreign_network_client.clone(); + let relay_peer_map = self.relay_peer_map.clone(); let foreign_mgr = self.foreign_network_manager.clone(); let encryptor = self.encryptor.clone(); let compress_algo = self.data_compress_algo; let acl_filter = self.global_ctx.get_acl_filter().clone(); let global_ctx = self.global_ctx.clone(); + let secure_mode_enabled = self.is_secure_mode_enabled; let stats_mgr = self.global_ctx.stats_manager().clone(); let route = self.get_route(); + let is_credential_node = self + .global_ctx + .get_network_identity() + .network_secret + .is_none() + && secure_mode_enabled; let label_set = LabelSet::new().with_label_type(LabelType::NetworkName(global_ctx.get_network_name())); @@ -699,6 +753,17 @@ impl PeerManager { continue; } + // Step 10b: credential nodes don't forward handshake packets + if is_credential_node + && (hdr.packet_type == PacketType::HandShake as u8 + || hdr.packet_type == PacketType::NoiseHandshakeMsg1 as u8 + || hdr.packet_type == PacketType::NoiseHandshakeMsg2 as u8 + || hdr.packet_type == PacketType::NoiseHandshakeMsg3 as u8) + { + tracing::debug!("credential node dropping forwarded handshake packet"); + continue; + } + if hdr.forward_counter > 2 && hdr.is_latency_first() { tracing::trace!(?hdr, "set_latency_first false because too many hop"); hdr.set_latency_first(false); @@ -713,9 +778,13 @@ impl PeerManager { || hdr.packet_type == PacketType::KcpSrc as u8 || hdr.packet_type == PacketType::KcpDst as u8 { - let _ = - Self::try_compress_and_encrypt(compress_algo, &encryptor, &mut ret) - .await; + let _ = Self::try_compress_and_encrypt( + compress_algo, + &encryptor, + &mut ret, + secure_mode_enabled, + ) + .await; } compress_tx_bytes_after.add(ret.buf_len() as u64); @@ -727,16 +796,44 @@ impl PeerManager { } tracing::trace!(?to_peer_id, ?my_peer_id, "need forward"); - let ret = - Self::send_msg_internal(&peers, &foreign_client, ret, to_peer_id).await; + let ret = Self::send_msg_internal( + &peers, + &foreign_client, + &relay_peer_map, + ret, + to_peer_id, + ) + .await; if ret.is_err() { tracing::error!(?ret, ?to_peer_id, ?from_peer_id, "forward packet error"); } } else { - if let Err(e) = encryptor.decrypt(&mut ret) { - tracing::error!(?e, "decrypt failed"); + if hdr.packet_type == PacketType::RelayHandshake as u8 + || hdr.packet_type == PacketType::RelayHandshakeAck as u8 + { + let _ = relay_peer_map.handle_handshake_packet(ret).await; continue; } + if !secure_mode_enabled { + if let Err(e) = encryptor.decrypt(&mut ret) { + tracing::error!(?e, "decrypt failed"); + continue; + } + } else if !peers.has_peer(from_peer_id) + && !foreign_client.has_next_hop(from_peer_id) + { + match relay_peer_map.decrypt_if_needed(&mut ret).await { + Ok(true) => {} + Ok(false) => { + tracing::error!("relay session not found"); + continue; + } + Err(e) => { + tracing::error!(?e, "relay decrypt failed"); + continue; + } + } + } self_rx_bytes.add(buf_len as u64); self_rx_packets.inc(); @@ -882,6 +979,21 @@ impl PeerManager { self.my_peer_id } + async fn close_peer(&self, peer_id: PeerId) { + if let Some(peer_map) = self.peers.upgrade() { + let _ = peer_map.close_peer(peer_id).await; + } + + if let Some(foreign_client) = self.foreign_network_client.upgrade() { + let _ = foreign_client.get_peer_map().close_peer(peer_id).await; + } + } + + async fn get_peer_identity_type(&self, peer_id: PeerId) -> Option { + let peer_map = self.peers.upgrade()?; + peer_map.get_peer_identity_type(peer_id) + } + async fn list_foreign_networks(&self) -> ForeignNetworkRouteInfoMap { let ret = DashMap::new(); let Some(foreign_mgr) = self.foreign_network_manager.upgrade() else { @@ -1033,16 +1145,27 @@ impl PeerManager { .compress_tx_bytes_before .add(msg.buf_len() as u64); - Self::try_compress_and_encrypt(self.data_compress_algo, &self.encryptor, &mut msg).await?; + Self::try_compress_and_encrypt( + self.data_compress_algo, + &self.encryptor, + &mut msg, + self.is_secure_mode_enabled, + ) + .await?; self.self_tx_counters .compress_tx_bytes_after .add(msg.buf_len() as u64); let msg_len = msg.buf_len() as u64; - let result = - Self::send_msg_internal(&self.peers, &self.foreign_network_client, msg, dst_peer_id) - .await; + let result = Self::send_msg_internal( + &self.peers, + &self.foreign_network_client, + &self.relay_peer_map, + msg, + dst_peer_id, + ) + .await; if result.is_ok() { self.self_tx_counters.self_tx_bytes.add(msg_len); self.self_tx_counters.self_tx_packets.inc(); @@ -1053,17 +1176,22 @@ impl PeerManager { async fn send_msg_internal( peers: &Arc, foreign_network_client: &Arc, + relay_peer_map: &Arc, msg: ZCPacket, dst_peer_id: PeerId, ) -> Result<(), Error> { let policy = Self::get_next_hop_policy(msg.peer_manager_header().unwrap().is_latency_first()); + if peers.has_peer(dst_peer_id) { + return peers.send_msg_directly(msg, dst_peer_id).await; + } else if foreign_network_client.has_next_hop(dst_peer_id) { + return foreign_network_client.send_msg(msg, dst_peer_id).await; + } + if let Some(gateway) = peers.get_gateway_peer_id(dst_peer_id, policy.clone()).await { - if peers.has_peer(gateway) { - peers.send_msg_directly(msg, gateway).await - } else if foreign_network_client.has_next_hop(gateway) { - foreign_network_client.send_msg(msg, gateway).await + if peers.has_peer(gateway) || foreign_network_client.has_next_hop(gateway) { + relay_peer_map.send_msg(msg, dst_peer_id, policy).await } else { tracing::warn!( ?gateway, @@ -1174,13 +1302,16 @@ impl PeerManager { compress_algo: CompressorAlgo, encryptor: &Arc, msg: &mut ZCPacket, + secure_mode_enabled: bool, ) -> Result<(), Error> { let compressor = DefaultCompressor {}; compressor .compress(msg, compress_algo) .await .with_context(|| "compress failed")?; - encryptor.encrypt(msg).with_context(|| "encrypt failed")?; + if !secure_mode_enabled { + encryptor.encrypt(msg).with_context(|| "encrypt failed")?; + } Ok(()) } @@ -1209,6 +1340,7 @@ impl PeerManager { return Self::send_msg_internal( &self.peers, &self.foreign_network_client, + &self.relay_peer_map, msg, cur_to_peer_id, ) @@ -1229,7 +1361,13 @@ impl PeerManager { .compress_tx_bytes_before .add(msg.buf_len() as u64); - Self::try_compress_and_encrypt(self.data_compress_algo, &self.encryptor, &mut msg).await?; + Self::try_compress_and_encrypt( + self.data_compress_algo, + &self.encryptor, + &mut msg, + self.is_secure_mode_enabled, + ) + .await?; self.self_tx_counters .compress_tx_bytes_after @@ -1273,9 +1411,14 @@ impl PeerManager { .add(msg.buf_len() as u64); self.self_tx_counters.self_tx_packets.inc(); - if let Err(e) = - Self::send_msg_internal(&self.peers, &self.foreign_network_client, msg, *peer_id) - .await + if let Err(e) = Self::send_msg_internal( + &self.peers, + &self.foreign_network_client, + &self.relay_peer_map, + msg, + *peer_id, + ) + .await { errs.push(e); } @@ -1301,6 +1444,26 @@ impl PeerManager { }); } + async fn run_relay_session_gc_routine(&self) { + let relay_peer_map = self.relay_peer_map.clone(); + self.tasks.lock().await.spawn(async move { + loop { + relay_peer_map.evict_idle_sessions(std::time::Duration::from_secs(60)); + tokio::time::sleep(std::time::Duration::from_secs(30)).await; + } + }); + } + + async fn run_peer_session_gc_routine(&self) { + let peer_session_store = self.peer_session_store.clone(); + self.tasks.lock().await.spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + peer_session_store.evict_unused_sessions(); + } + }); + } + async fn run_foriegn_network(&self) { self.peer_rpc_tspt .foreign_peers @@ -1322,6 +1485,8 @@ impl PeerManager { self.start_peer_recv().await; self.run_clean_peer_without_conn_routine().await; + self.run_relay_session_gc_routine().await; + self.run_peer_session_gc_routine().await; self.run_foriegn_network().await; @@ -1332,10 +1497,18 @@ impl PeerManager { self.peers.clone() } + pub fn get_relay_peer_map(&self) -> Arc { + self.relay_peer_map.clone() + } + pub fn get_peer_rpc_mgr(&self) -> Arc { self.peer_rpc_mgr.clone() } + pub fn get_peer_session_store(&self) -> Arc { + self.peer_session_store.clone() + } + pub fn my_node_id(&self) -> uuid::Uuid { self.global_ctx.get_id() } @@ -1852,7 +2025,7 @@ mod tests { return false; }; conns.iter().any(|c| { - c.secure_auth_level == SecureAuthLevel::SharedNodePubkeyVerified as i32 + c.secure_auth_level == SecureAuthLevel::PeerVerified as i32 && c.noise_local_static_pubkey.len() == 32 && c.noise_remote_static_pubkey.len() == 32 }) diff --git a/easytier/src/peers/peer_map.rs b/easytier/src/peers/peer_map.rs index ffe0bce60..5b055a813 100644 --- a/easytier/src/peers/peer_map.rs +++ b/easytier/src/peers/peer_map.rs @@ -16,7 +16,7 @@ use crate::{ }, proto::{ api::instance::{self, PeerConnInfo}, - peer_rpc::RoutePeerInfo, + peer_rpc::{PeerIdentityType, RoutePeerInfo}, }, tunnel::{packet_def::ZCPacket, TunnelError}, }; @@ -56,18 +56,19 @@ impl PeerMap { .issue_event(GlobalCtxEvent::PeerAdded(peer_id)); } - pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) { + pub async fn add_new_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> { let _ = self.maintain_alive_client_urls(&peer_conn); let peer_id = peer_conn.get_peer_id(); let no_entry = self.peer_map.get(&peer_id).is_none(); if no_entry { let new_peer = Peer::new(peer_id, self.packet_send.clone(), self.global_ctx.clone()); - new_peer.add_peer_conn(peer_conn).await; + new_peer.add_peer_conn(peer_conn).await?; self.add_new_peer(new_peer).await; } else { let peer = self.peer_map.get(&peer_id).unwrap().clone(); - peer.add_peer_conn(peer_conn).await; + peer.add_peer_conn(peer_conn).await?; } + Ok(()) } fn maintain_alive_client_urls(&self, peer_conn: &PeerConn) -> Option<()> { @@ -302,6 +303,11 @@ impl PeerMap { .map(|p| p.get_default_conn_id()) } + pub fn get_peer_identity_type(&self, peer_id: PeerId) -> Option { + self.get_peer_by_id(peer_id) + .and_then(|p| p.get_peer_identity_type()) + } + pub async fn close_peer_conn( &self, peer_id: PeerId, diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index bdefcbf33..386c00abf 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -1,5 +1,5 @@ use std::{ - collections::{BTreeMap, BTreeSet, HashMap}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, fmt::Debug, net::{IpAddr, Ipv4Addr, Ipv6Addr}, sync::{ @@ -43,9 +43,10 @@ use crate::{ route_foreign_network_infos, route_foreign_network_summary, sync_route_info_request::ConnInfo, ForeignNetworkRouteInfoEntry, ForeignNetworkRouteInfoKey, OspfRouteRpc, OspfRouteRpcClientFactory, - OspfRouteRpcServer, PeerGroupInfo, PeerIdVersion, RouteForeignNetworkInfos, - RouteForeignNetworkSummary, RoutePeerInfo, RoutePeerInfos, SyncRouteInfoError, - SyncRouteInfoRequest, SyncRouteInfoResponse, + OspfRouteRpcServer, PeerGroupInfo, PeerIdVersion, PeerIdentityType, + RouteForeignNetworkInfos, RouteForeignNetworkSummary, RoutePeerInfo, RoutePeerInfos, + SyncRouteInfoError, SyncRouteInfoRequest, SyncRouteInfoResponse, + TrustedCredentialPubkey, }, rpc_types::{ self, @@ -80,6 +81,38 @@ static REMOVE_UNREACHABLE_PEER_INFO_AFTER: Duration = Duration::from_secs(90); type Version = u32; +/// Check if `child` CIDR is a subset of `parent` CIDR (both as string representations). +/// Returns true if child is contained within parent, or if they are equal. +fn cidr_is_subset_str(child: &str, parent: &str) -> bool { + let Ok(child_cidr) = child.parse::() else { + return false; + }; + let Ok(parent_cidr) = parent.parse::() else { + return false; + }; + match (child_cidr, parent_cidr) { + (IpCidr::V4(c), IpCidr::V4(p)) => { + p.first_address() <= c.first_address() && c.last_address() <= p.last_address() + } + (IpCidr::V6(c), IpCidr::V6(p)) => { + p.first_address() <= c.first_address() && c.last_address() <= p.last_address() + } + _ => false, // mixed v4/v6 + } +} + +/// Patch specific fields in a raw DynamicMessage from a decoded RoutePeerInfo, +/// preserving all other fields (including unknown ones). +fn patch_raw_from_info(raw: &mut DynamicMessage, info: &RoutePeerInfo, fields: &[&str]) { + let mut decoded_raw = DynamicMessage::new(RoutePeerInfo::default().descriptor()); + decoded_raw.transcode_from(info).unwrap(); + for field_name in fields { + if let Some(value) = decoded_raw.get_field_by_name(field_name) { + raw.set_field_by_name(field_name, value.into_owned()); + } + } +} + #[derive(Debug, Clone)] struct AtomicVersion(Arc); @@ -146,6 +179,8 @@ impl RoutePeerInfo { groups: Vec::new(), quic_port: None, + noise_static_pubkey: Vec::new(), + trusted_credential_pubkeys: Vec::new(), } } @@ -164,6 +199,12 @@ impl RoutePeerInfo { global_ctx: &ArcGlobalCtx, ) -> Self { let stun_info = global_ctx.get_stun_info_collector().get_stun_info(); + let noise_static_pubkey = global_ctx + .config + .get_secure_mode() + .and_then(|cfg| cfg.public_key().ok()) + .map(|pk| pk.as_bytes().to_vec()) + .unwrap_or_default(); Self { peer_id: my_peer_id, inst_id: Some(global_ctx.get_id().into()), @@ -197,6 +238,19 @@ impl RoutePeerInfo { groups: global_ctx.get_acl_groups(my_peer_id), + noise_static_pubkey, + + // Only admin nodes (holding network_secret) publish trusted credential pubkeys + trusted_credential_pubkeys: if let Some(network_secret) = + global_ctx.get_network_identity().network_secret + { + global_ctx + .get_credential_manager() + .get_trusted_pubkeys(&network_secret) + } else { + Vec::new() + }, + ..Default::default() } } @@ -327,6 +381,10 @@ struct SyncedRouteInfo { group_trust_map: DashMap>>, group_trust_map_cache: DashMap>>, // cache for group trust map, should sync with group_trust_map + // Aggregated trusted credential pubkeys from all admin nodes + // Maps pubkey bytes -> TrustedCredentialPubkey + trusted_credential_pubkeys: DashMap, TrustedCredentialPubkey>, + version: AtomicVersion, } @@ -343,6 +401,19 @@ impl Debug for SyncedRouteInfo { } impl SyncedRouteInfo { + fn mark_credential_peer(info: &mut RoutePeerInfo, is_credential_peer: bool) { + let mut feature_flag = info.feature_flag.unwrap_or_default(); + feature_flag.is_credential_peer = is_credential_peer; + info.feature_flag = Some(feature_flag); + } + + fn is_credential_peer_info(info: &RoutePeerInfo) -> bool { + info.feature_flag + .as_ref() + .map(|x| x.is_credential_peer) + .unwrap_or(false) + } + fn get_connected_peers>(&self, peer_id: PeerId) -> Option { self.conn_map .read() @@ -821,6 +892,173 @@ impl SyncedRouteInfo { self.group_trust_map_cache .insert(my_peer_id, Arc::new(my_group_names)); } + + /// Collect trusted credential pubkeys from admin nodes (network_secret holders) + /// and verify credential peers. Returns set of peer_ids that should be removed. + /// Also returns a HashMap of trusted keys for synchronization to GlobalCtx. + fn verify_and_update_credential_trusts( + &self, + network_secret: Option<&str>, + ) -> ( + Vec, + HashMap, crate::common::global_ctx::TrustedKeyMetadata>, + ) { + use crate::common::global_ctx::{TrustedKeyMetadata, TrustedKeySource}; + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + // Step 1: Collect trusted credential pubkeys from admin nodes (take union) + // Only trust nodes whose secret_digest matches ours (i.e. they hold network_secret) + let mut all_trusted: HashMap, TrustedCredentialPubkey> = HashMap::new(); + // Also collect all peer pubkeys for GlobalCtx synchronization + let mut global_trusted_keys: HashMap, TrustedKeyMetadata> = HashMap::new(); + + let peer_infos = self.peer_infos.read(); + + for (_, info) in peer_infos.iter() { + if !self.is_admin_peer(info) { + continue; + } + // Collect all peer noise_static_pubkeys as trusted keys + if !info.noise_static_pubkey.is_empty() { + global_trusted_keys.insert( + info.noise_static_pubkey.clone(), + TrustedKeyMetadata { + source: TrustedKeySource::OspfNode, + expiry_unix: None, // Peer pubkeys never expire + }, + ); + } + for proof in &info.trusted_credential_pubkeys { + // If we have a network_secret, verify the HMAC as before. + // If we don't (e.g. credential nodes), accept proofs from admin peers + // based on the authenticated channel instead of local HMAC verification. + let hmac_valid = network_secret + .map(|secret| proof.verify_credential_hmac(secret)) + .unwrap_or(true); + if !hmac_valid { + continue; + } + let Some(tc) = proof.credential.as_ref() else { + continue; + }; + if tc.expiry_unix > now { + all_trusted + .entry(tc.pubkey.clone()) + .or_insert_with(|| tc.clone()); + // Also add to global trusted keys + global_trusted_keys.insert( + tc.pubkey.clone(), + TrustedKeyMetadata { + source: TrustedKeySource::OspfCredential, + expiry_unix: Some(tc.expiry_unix), + }, + ); + } + } + } + + // Save the previous trusted set to detect revoked credentials + let prev_trusted: HashSet> = self + .trusted_credential_pubkeys + .iter() + .map(|r| r.key().clone()) + .collect(); + + // Update the trusted_credential_pubkeys map + self.trusted_credential_pubkeys.clear(); + for (k, v) in &all_trusted { + self.trusted_credential_pubkeys.insert(k.clone(), v.clone()); + } + + // Step 2: Update group trust map for credential peers + // Credential peers get their groups from the TrustedCredentialPubkey declaration + for (_, info) in peer_infos.iter() { + if info.noise_static_pubkey.is_empty() { + continue; + } + if let Some(tc) = all_trusted.get(&info.noise_static_pubkey) { + // This peer is a credential peer, assign groups from credential declaration + if !tc.groups.is_empty() { + let mut group_map = HashMap::new(); + let mut group_names = Vec::new(); + for g in &tc.groups { + group_map.insert(g.clone(), Vec::new()); // no proof needed, admin-declared + group_names.push(g.clone()); + } + self.group_trust_map.insert(info.peer_id, group_map); + self.group_trust_map_cache + .insert(info.peer_id, Arc::new(group_names)); + } + } + } + + // Step 3: Find and remove peers with revoked/expired credentials. + // A peer is untrusted if: + // - Its noise_static_pubkey was in the PREVIOUS trusted set (it was a credential peer) + // - Its noise_static_pubkey is NOT in the CURRENT trusted set (credential revoked/expired) + let mut untrusted_peers = Vec::new(); + for (peer_id, info) in peer_infos.iter() { + if info.noise_static_pubkey.is_empty() || info.version == 0 { + continue; + } + // Only remove peers whose pubkey was previously trusted but no longer is + if prev_trusted.contains(&info.noise_static_pubkey) + && !all_trusted.contains_key(&info.noise_static_pubkey) + { + untrusted_peers.push(*peer_id); + } + } + + // Remove untrusted peers from peer_infos so they won't appear in route graph + if !untrusted_peers.is_empty() { + drop(peer_infos); // release read lock before writing + let mut peer_infos_write = self.peer_infos.write(); + for peer_id in &untrusted_peers { + tracing::warn!(?peer_id, "removing untrusted peer from route info"); + peer_infos_write.remove(peer_id); + self.raw_peer_infos.remove(peer_id); + } + drop(peer_infos_write); + // Also remove from conn_map + let mut conn_map = self.conn_map.write(); + for peer_id in &untrusted_peers { + conn_map.remove(peer_id); + } + self.version.inc(); + } + + (untrusted_peers, global_trusted_keys) + } + + fn is_admin_peer(&self, info: &RoutePeerInfo) -> bool { + if info.version == 0 { + return false; + } + !Self::is_credential_peer_info(info) + } + + fn is_credential_peer(&self, peer_id: PeerId) -> bool { + let peer_infos = self.peer_infos.read(); + peer_infos + .get(&peer_id) + .map(Self::is_credential_peer_info) + .unwrap_or(false) + } + + fn get_credential_info(&self, peer_id: PeerId) -> Option { + let peer_infos = self.peer_infos.read(); + let info = peer_infos.get(&peer_id)?; + if info.noise_static_pubkey.is_empty() { + return None; + } + self.trusted_credential_pubkeys + .get(&info.noise_static_pubkey) + .map(|r| r.value().clone()) + } } type PeerGraph = Graph; @@ -968,6 +1206,14 @@ impl RouteTable { start_node: &NodeIndex, version: Version, ) { + if graph.node_weight(*start_node).is_none() { + tracing::warn!( + ?start_node, + version, + "invalid start node for least-hop route rebuild" + ); + return; + } let normalize_edge_cost = |e: petgraph::graph::EdgeReference| { if *e.weight() >= AVOID_RELAY_COST { AVOID_RELAY_COST + 1 @@ -1011,6 +1257,14 @@ impl RouteTable { start_node: &NodeIndex, version: Version, ) { + if graph.node_weight(*start_node).is_none() { + tracing::warn!( + ?start_node, + version, + "invalid start node for least-cost route rebuild" + ); + return; + } let (costs, next_hops) = dijkstra_with_first_hop(&graph, *start_node, |e| *e.weight()); for (dst, (next_hop, path_len)) in next_hops.iter() { @@ -1049,6 +1303,18 @@ impl RouteTable { if graph.node_count() == 0 { tracing::warn!("no peer in graph, cannot build next hop map"); + self.next_hop_map_version.set_if_larger(version); + self.clean_expired_route_info(); + return; + } + if start_node == NodeIndex::end() { + tracing::warn!( + ?my_peer_id, + version, + "my peer id is missing in graph, skip next-hop rebuild this round" + ); + self.next_hop_map_version.set_if_larger(version); + self.clean_expired_route_info(); return; } @@ -1587,6 +1853,7 @@ impl PeerRouteServiceImpl { foreign_network: DashMap::new(), group_trust_map: DashMap::new(), group_trust_map_cache: DashMap::new(), + trusted_credential_pubkeys: DashMap::new(), version: AtomicVersion::new(), }, cached_local_conn_map: std::sync::Mutex::new(RouteConnBitmap::default()), @@ -1598,6 +1865,24 @@ impl PeerRouteServiceImpl { } } + fn get_my_secret_digest(&self) -> Option> { + let ni = self.global_ctx.get_network_identity(); + ni.network_secret_digest.map(|d| d.to_vec()) + } + + fn is_credential_node(&self) -> bool { + self.global_ctx + .get_network_identity() + .network_secret + .is_none() + && self + .global_ctx + .config + .get_secure_mode() + .map(|c| c.enabled) + .unwrap_or(false) + } + fn get_or_create_session(&self, dst_peer_id: PeerId) -> Arc { self.sessions .entry(dst_peer_id) @@ -1631,29 +1916,31 @@ impl PeerRouteServiceImpl { .collect() } + async fn get_peer_identity_type_from_interface( + &self, + peer_id: PeerId, + ) -> Option { + self.interface + .lock() + .await + .as_ref() + .unwrap() + .get_peer_identity_type(peer_id) + .await + } + fn update_my_peer_info(&self) -> bool { - if self.synced_route_info.update_my_peer_info( + self.synced_route_info.update_my_peer_info( self.my_peer_id, self.my_peer_route_id, &self.global_ctx, - ) { - self.update_route_table_and_cached_local_conn_bitmap(); - return true; - } - false + ) } async fn update_my_conn_info(&self) -> bool { let connected_peers: BTreeSet = self.list_peers_from_interface().await; - let updated = self - .synced_route_info - .update_my_conn_info(self.my_peer_id, connected_peers); - - if updated { - self.update_route_table_and_cached_local_conn_bitmap(); - } - - updated + self.synced_route_info + .update_my_conn_info(self.my_peer_id, connected_peers) } async fn update_my_foreign_network(&self) -> bool { @@ -1842,15 +2129,6 @@ impl PeerRouteServiceImpl { if let Some(last_update) = peer_info.last_update { let last_update = TryInto::::try_into(last_update).unwrap(); if last_sync_succ_timestamp.is_some_and(|t| last_update < t) { - tracing::debug!( - "ignore peer_info {:?} because last_update: {:?} is older than last_sync_succ_timestamp: {:?}, peer_infos_count: {}, my_peer_id: {:?}, session: {:?}", - peer_info, - last_update, - last_sync_succ_timestamp, - peer_infos.len(), - self.my_peer_id, - session - ); break; } } @@ -1921,15 +2199,6 @@ impl PeerRouteServiceImpl { // stop iter if last_update of conn info is older than session.last_sync_succ_timestamp let last_update = TryInto::::try_into(conn_info.last_update).unwrap(); if last_sync_succ_timestamp.is_some_and(|t| last_update < t) { - tracing::debug!( - "ignore conn info {:?} because last_update: {:?} is older than last_sync_succ_timestamp: {:?}, conn_map count: {}, my_peer_id: {:?}, session: {:?}", - conn_info, - last_update, - last_sync_succ_timestamp, - conn_map.len(), - self.my_peer_id, - session - ); break; } @@ -2012,7 +2281,21 @@ impl PeerRouteServiceImpl { let my_peer_info_updated = self.update_my_peer_info(); let my_conn_info_updated = self.update_my_conn_info().await; let my_foreign_network_updated = self.update_my_foreign_network().await; - if my_conn_info_updated || my_peer_info_updated { + let mut untrusted_changed = false; + if my_peer_info_updated { + let network_identity = self.global_ctx.get_network_identity(); + let network_secret = network_identity.network_secret.as_deref(); + let (untrusted, global_trusted_keys) = self + .synced_route_info + .verify_and_update_credential_trusts(network_secret); + self.global_ctx + .update_trusted_keys(global_trusted_keys, &network_identity.network_name); + self.disconnect_untrusted_peers(&untrusted).await; + untrusted_changed = !untrusted.is_empty(); + } + + if my_peer_info_updated || my_conn_info_updated || untrusted_changed { + self.update_route_table_and_cached_local_conn_bitmap(); self.update_foreign_network_owner_map(); } if my_peer_info_updated { @@ -2021,6 +2304,22 @@ impl PeerRouteServiceImpl { my_peer_info_updated || my_conn_info_updated || my_foreign_network_updated } + async fn disconnect_untrusted_peers(&self, untrusted_peers: &[PeerId]) { + if untrusted_peers.is_empty() { + return; + } + + let interface = self.interface.lock().await; + let Some(interface) = interface.as_ref() else { + return; + }; + + for peer_id in untrusted_peers { + tracing::warn!(?peer_id, "disconnecting untrusted peer"); + interface.close_peer(*peer_id).await; + } + } + fn build_sync_request( &self, session: &SyncRouteSession, @@ -2168,7 +2467,7 @@ impl PeerRouteServiceImpl { return true; } - tracing::debug!(?foreign_network, "sync_route request need send to peer. my_id {:?}, pper_id: {:?}, peer_infos: {:?}, conn_info: {:?}, synced_route_info: {:?} session: {:?}", + tracing::debug!(?foreign_network, "sync_route request need send to peer. my_id {:?}, dst_peer_id: {:?}, peer_infos: {:?}, conn_info: {:?}, synced_route_info: {:?} session: {:?}", my_peer_id, dst_peer_id, peer_infos, conn_info, self.synced_route_info, session); session @@ -2504,16 +2803,28 @@ impl RouteSessionManager { } // find peer_ids that are not initiators. - let initiator_candidates = peers - .iter() - .filter(|x| { - let Some(session) = service_impl.get_session(**x) else { - return true; - }; - !session.dst_is_initiator.load(Ordering::Relaxed) - }) - .copied() - .collect::>(); + let mut initiator_candidates = Vec::new(); + for peer_id in peers.iter().copied() { + // Step 9a: Filter OSPF session candidates based on direct auth level. + // - Credential nodes only initiate sessions to admin nodes (not other credential nodes) + // - Admin nodes don't initiate sessions to credential nodes + let identity_type = service_impl + .get_peer_identity_type_from_interface(peer_id) + .await + .unwrap_or(PeerIdentityType::Admin); + if matches!(identity_type, PeerIdentityType::Credential) { + continue; + } + + let Some(session) = service_impl.get_session(peer_id) else { + initiator_candidates.push(peer_id); + continue; + }; + + if !session.dst_is_initiator.load(Ordering::Relaxed) { + initiator_candidates.push(peer_id); + } + } if initiator_candidates.is_empty() { next_sleep_ms = 1000; @@ -2556,6 +2867,7 @@ impl RouteSessionManager { continue; }; session.update_initiator_flag(true); + self.sync_now("update_initiator_flag"); } // clear sessions that are neither dst_initiator or we_are_initiator. @@ -2625,6 +2937,13 @@ impl RouteSessionManager { let my_peer_id = service_impl.my_peer_id; let session = self.get_or_start_session(from_peer_id)?; + let from_identity_type = service_impl + .get_peer_identity_type_from_interface(from_peer_id) + .await + .unwrap_or(PeerIdentityType::Admin); + let from_is_credential = matches!(from_identity_type, PeerIdentityType::Credential); + let from_is_shared = matches!(from_identity_type, PeerIdentityType::SharedNode); + let _session_lock = session.lock.lock(); session.rpc_rx_count.fetch_add(1, Ordering::Relaxed); @@ -2632,40 +2951,141 @@ impl RouteSessionManager { session.update_dst_session_id(from_session_id); let mut need_update_route_table = false; + let mut untrusted_peers = Vec::new(); if let Some(peer_infos) = &peer_infos { + // Step 9b: credential peers can only propagate their own route info + let normalize_raw = |info: &RoutePeerInfo| { + let mut raw = DynamicMessage::new(RoutePeerInfo::default().descriptor()); + raw.transcode_from(info).unwrap(); + raw + }; + let normalized_peer_infos: Vec; + let normalized_raw_peer_infos: Vec; + let (pi, rpi) = if from_is_credential { + let allowed_cidrs = service_impl + .synced_route_info + .get_credential_info(from_peer_id) + .map(|tc| tc.allowed_proxy_cidrs.clone()) + .unwrap_or_default(); + normalized_peer_infos = peer_infos + .iter() + .filter(|info| info.peer_id == from_peer_id) + .cloned() + .map(|mut info| { + // Filter proxy_cidrs to only those allowed by credential + if !allowed_cidrs.is_empty() { + info.proxy_cidrs.retain(|cidr| { + allowed_cidrs + .iter() + .any(|allowed| cidr_is_subset_str(cidr, allowed)) + }); + } else { + // No allowed_proxy_cidrs → no proxy_cidrs allowed + info.proxy_cidrs.clear(); + } + SyncedRouteInfo::mark_credential_peer(&mut info, true); + info + }) + .collect(); + normalized_raw_peer_infos = normalized_peer_infos + .iter() + .map(|info| { + // Find original raw for this peer to preserve unknown fields + let orig_idx = peer_infos.iter().position(|p| p.peer_id == info.peer_id); + let mut raw = orig_idx + .and_then(|idx| raw_peer_infos.as_ref().map(|rpi| rpi[idx].clone())) + .unwrap_or_else(|| normalize_raw(info)); + patch_raw_from_info(&mut raw, info, &["proxy_cidrs", "feature_flag"]); + raw + }) + .collect(); + (&normalized_peer_infos, &normalized_raw_peer_infos) + } else { + let mut peer_infos_mut = peer_infos.clone(); + let mut raw_peer_infos_mut = raw_peer_infos + .as_ref() + .cloned() + .unwrap_or_else(|| peer_infos_mut.iter().map(normalize_raw).collect()); + if from_is_shared { + for (info, raw) in peer_infos_mut.iter_mut().zip(raw_peer_infos_mut.iter_mut()) + { + info.trusted_credential_pubkeys.clear(); + patch_raw_from_info(raw, info, &["trusted_credential_pubkeys"]); + } + } + if let Some((idx, info)) = peer_infos_mut + .iter() + .enumerate() + .find(|(_, info)| info.peer_id == from_peer_id) + { + let mut info = info.clone(); + SyncedRouteInfo::mark_credential_peer(&mut info, false); + peer_infos_mut[idx] = info.clone(); + patch_raw_from_info(&mut raw_peer_infos_mut[idx], &info, &["feature_flag"]); + } + normalized_peer_infos = peer_infos_mut; + normalized_raw_peer_infos = raw_peer_infos_mut; + (&normalized_peer_infos, &normalized_raw_peer_infos) + }; + service_impl.synced_route_info.update_peer_infos( my_peer_id, service_impl.my_peer_route_id, from_peer_id, - peer_infos, - raw_peer_infos.as_ref().unwrap(), + pi, + rpi, )?; service_impl .synced_route_info .verify_and_update_group_trusts( - peer_infos, + pi, &service_impl.global_ctx.get_acl_group_declarations(), ); - session.update_dst_saved_peer_info_version(peer_infos, from_peer_id); + session.update_dst_saved_peer_info_version(pi, from_peer_id); need_update_route_table = true; } + // Step 9b: credential peers' conn_info depends on allow_relay flag if let Some(conn_info) = &conn_info { - service_impl.synced_route_info.update_conn_info(conn_info); - session.update_dst_saved_conn_info_version(conn_info, from_peer_id); - need_update_route_table = true; + let accept_conn_info = if from_is_credential { + service_impl + .synced_route_info + .get_credential_info(from_peer_id) + .map(|tc| tc.allow_relay) + .unwrap_or(false) + } else { + true + }; + if accept_conn_info { + service_impl.synced_route_info.update_conn_info(conn_info); + session.update_dst_saved_conn_info_version(conn_info, from_peer_id); + need_update_route_table = true; + } } if need_update_route_table { + // Run credential verification and update route table + let network_identity = service_impl.global_ctx.get_network_identity(); + let (untrusted, global_trusted_keys) = service_impl + .synced_route_info + .verify_and_update_credential_trusts(network_identity.network_secret.as_deref()); + untrusted_peers = untrusted; + // Sync trusted keys to GlobalCtx for handshake verification + service_impl + .global_ctx + .update_trusted_keys(global_trusted_keys, &network_identity.network_name); service_impl.update_route_table_and_cached_local_conn_bitmap(); } if let Some(foreign_network) = &foreign_network { - service_impl - .synced_route_info - .update_foreign_network(foreign_network); - session.update_dst_saved_foreign_network_version(foreign_network, from_peer_id); + // Step 9b: credential peers' foreign_network_infos are always ignored + if !from_is_credential { + service_impl + .synced_route_info + .update_foreign_network(foreign_network); + session.update_dst_saved_foreign_network_version(foreign_network, from_peer_id); + } } if need_update_route_table || foreign_network.is_some() { @@ -2682,6 +3102,11 @@ impl RouteSessionManager { let is_initiator = session.we_are_initiator.load(Ordering::Relaxed); let session_id = session.my_session_id.load(Ordering::Relaxed); + drop(_session_lock); + service_impl + .disconnect_untrusted_peers(&untrusted_peers) + .await; + self.sync_now("sync_route_info"); Ok(SyncRouteInfoResponse { @@ -3040,12 +3465,15 @@ mod tests { create_packet_recv_chan, peer_manager::{PeerManager, RouteAlgoType}, peer_ospf_route::{PeerIdVersion, PeerRouteServiceImpl, FORCE_USE_CONN_LIST}, - route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface}, + route_trait::{NextHopPolicy, Route, RouteCostCalculatorInterface, RouteInterface}, tests::{connect_peer_manager, create_mock_peer_manager, wait_route_appear}, }, proto::{ - common::NatType, - peer_rpc::{RoutePeerInfo, RoutePeerInfos, SyncRouteInfoRequest}, + common::{NatType, PeerFeatureFlag}, + peer_rpc::{ + PeerIdentityType, RoutePeerInfo, RoutePeerInfos, SyncRouteInfoRequest, + TrustedCredentialPubkey, TrustedCredentialPubkeyProof, + }, }, tunnel::common::tests::wait_for_condition, }; @@ -3053,6 +3481,26 @@ mod tests { use super::PeerRoute; + struct AuthOnlyInterface { + my_peer_id: PeerId, + identity_type: DashMap, + } + + #[async_trait::async_trait] + impl RouteInterface for AuthOnlyInterface { + async fn list_peers(&self) -> Vec { + Vec::new() + } + + fn my_peer_id(&self) -> PeerId { + self.my_peer_id + } + + async fn get_peer_identity_type(&self, peer_id: PeerId) -> Option { + self.identity_type.get(&peer_id).map(|x| *x.value()) + } + } + async fn create_mock_route(peer_mgr: Arc) -> Arc { let peer_route = PeerRoute::new( peer_mgr.my_peer_id(), @@ -3097,6 +3545,287 @@ mod tests { assert!(rx1 <= max_rx); } + #[tokio::test] + async fn credential_flag_controls_role_classification() { + let service_impl = PeerRouteServiceImpl::new(1, get_mock_global_ctx()); + + let mut admin_info = RoutePeerInfo::new(); + admin_info.peer_id = 10; + admin_info.version = 1; + admin_info.feature_flag = Some(PeerFeatureFlag { + is_credential_peer: false, + ..Default::default() + }); + + let mut credential_info = RoutePeerInfo::new(); + credential_info.peer_id = 11; + credential_info.version = 1; + credential_info.feature_flag = Some(PeerFeatureFlag { + is_credential_peer: true, + ..Default::default() + }); + + { + let mut guard = service_impl.synced_route_info.peer_infos.write(); + guard.insert(admin_info.peer_id, admin_info.clone()); + guard.insert(credential_info.peer_id, credential_info.clone()); + } + + assert!(service_impl.synced_route_info.is_admin_peer(&admin_info)); + assert!(!service_impl + .synced_route_info + .is_admin_peer(&credential_info)); + assert!(service_impl + .synced_route_info + .is_credential_peer(credential_info.peer_id)); + assert!(!service_impl + .synced_route_info + .is_credential_peer(admin_info.peer_id)); + } + + #[tokio::test] + async fn trusted_credentials_only_from_admin_publishers() { + let service_impl = PeerRouteServiceImpl::new(1, get_mock_global_ctx()); + let network_secret = "sec1"; + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + let admin_key = vec![1; 32]; + let credential_key = vec![2; 32]; + + let mut admin_info = RoutePeerInfo::new(); + admin_info.peer_id = 20; + admin_info.version = 1; + admin_info.feature_flag = Some(PeerFeatureFlag { + is_credential_peer: false, + ..Default::default() + }); + admin_info.trusted_credential_pubkeys = vec![TrustedCredentialPubkeyProof::new_signed( + TrustedCredentialPubkey { + pubkey: admin_key.clone(), + expiry_unix: now + 600, + ..Default::default() + }, + network_secret, + )]; + + let mut credential_info = RoutePeerInfo::new(); + credential_info.peer_id = 21; + credential_info.version = 1; + credential_info.feature_flag = Some(PeerFeatureFlag { + is_credential_peer: true, + ..Default::default() + }); + credential_info.trusted_credential_pubkeys = + vec![TrustedCredentialPubkeyProof::new_signed( + TrustedCredentialPubkey { + pubkey: credential_key.clone(), + expiry_unix: now + 600, + ..Default::default() + }, + network_secret, + )]; + + { + let mut guard = service_impl.synced_route_info.peer_infos.write(); + guard.insert(admin_info.peer_id, admin_info); + guard.insert(credential_info.peer_id, credential_info); + } + + service_impl + .synced_route_info + .verify_and_update_credential_trusts(Some(network_secret)); + + assert!(service_impl + .synced_route_info + .trusted_credential_pubkeys + .contains_key(&admin_key)); + assert!(!service_impl + .synced_route_info + .trusted_credential_pubkeys + .contains_key(&credential_key)); + } + + #[tokio::test] + async fn sync_route_info_marks_credential_sender_and_filters_entries() { + let peer_mgr = create_mock_pmgr().await; + let route = create_mock_route(peer_mgr.clone()).await; + let from_peer_id: PeerId = 10001; + let forwarded_peer_id: PeerId = 10002; + + let identity_type = DashMap::new(); + identity_type.insert(from_peer_id, PeerIdentityType::Credential); + *route.service_impl.interface.lock().await = Some(Box::new(AuthOnlyInterface { + my_peer_id: peer_mgr.my_peer_id(), + identity_type, + })); + + let mut sender_info = RoutePeerInfo::new(); + sender_info.peer_id = from_peer_id; + sender_info.version = 1; + sender_info.proxy_cidrs = vec!["10.10.0.0/24".to_string()]; + + let mut forwarded_info = RoutePeerInfo::new(); + forwarded_info.peer_id = forwarded_peer_id; + forwarded_info.version = 1; + + let make_raw = |info: &RoutePeerInfo| { + let mut raw = DynamicMessage::new(RoutePeerInfo::default().descriptor()); + raw.transcode_from(info).unwrap(); + raw + }; + let raw_infos = vec![make_raw(&sender_info), make_raw(&forwarded_info)]; + + route + .session_mgr + .do_sync_route_info( + from_peer_id, + 1, + true, + Some(vec![sender_info, forwarded_info]), + Some(raw_infos), + None, + None, + ) + .await + .unwrap(); + + let guard = route.service_impl.synced_route_info.peer_infos.read(); + let stored = guard.get(&from_peer_id).unwrap(); + assert!(stored + .feature_flag + .as_ref() + .map(|x| x.is_credential_peer) + .unwrap_or(false)); + assert!(stored.proxy_cidrs.is_empty()); + assert!(guard.get(&forwarded_peer_id).is_none()); + } + + #[tokio::test] + async fn sync_route_info_shared_sender_cannot_publish_trusted_credentials() { + let peer_mgr = create_mock_pmgr().await; + let route = create_mock_route(peer_mgr.clone()).await; + let from_peer_id: PeerId = 10021; + let forwarded_peer_id: PeerId = 10022; + let credential_key = vec![9u8; 32]; + + let identity_type = DashMap::new(); + identity_type.insert(from_peer_id, PeerIdentityType::SharedNode); + *route.service_impl.interface.lock().await = Some(Box::new(AuthOnlyInterface { + my_peer_id: peer_mgr.my_peer_id(), + identity_type, + })); + + let mut sender_info = RoutePeerInfo::new(); + sender_info.peer_id = from_peer_id; + sender_info.version = 1; + + let mut forwarded_info = RoutePeerInfo::new(); + forwarded_info.peer_id = forwarded_peer_id; + forwarded_info.version = 1; + forwarded_info.trusted_credential_pubkeys = vec![TrustedCredentialPubkeyProof { + credential: Some(TrustedCredentialPubkey { + pubkey: credential_key.clone(), + expiry_unix: i64::MAX, + ..Default::default() + }), + credential_hmac: vec![1; 32], + }]; + + let make_raw = |info: &RoutePeerInfo| { + let mut raw = DynamicMessage::new(RoutePeerInfo::default().descriptor()); + raw.transcode_from(info).unwrap(); + raw + }; + let raw_infos = vec![make_raw(&sender_info), make_raw(&forwarded_info)]; + + route + .session_mgr + .do_sync_route_info( + from_peer_id, + 1, + true, + Some(vec![sender_info, forwarded_info]), + Some(raw_infos), + None, + None, + ) + .await + .unwrap(); + + let guard = route.service_impl.synced_route_info.peer_infos.read(); + assert!(guard + .get(&forwarded_peer_id) + .map(|x| x.trusted_credential_pubkeys.is_empty()) + .unwrap_or(false)); + drop(guard); + + assert!(!route + .service_impl + .synced_route_info + .trusted_credential_pubkeys + .contains_key(&credential_key)); + } + + #[tokio::test] + async fn sync_route_info_forces_non_credential_for_legacy_admin_sender() { + let peer_mgr = create_mock_pmgr().await; + let route = create_mock_route(peer_mgr.clone()).await; + let from_peer_id: PeerId = 10011; + let other_peer_id: PeerId = 10012; + + let identity_type = DashMap::new(); + identity_type.insert(from_peer_id, PeerIdentityType::Admin); + *route.service_impl.interface.lock().await = Some(Box::new(AuthOnlyInterface { + my_peer_id: peer_mgr.my_peer_id(), + identity_type, + })); + + let mut sender_info = RoutePeerInfo::new(); + sender_info.peer_id = from_peer_id; + sender_info.version = 1; + sender_info.feature_flag = Some(PeerFeatureFlag { + is_credential_peer: true, + ..Default::default() + }); + + let mut other_info = RoutePeerInfo::new(); + other_info.peer_id = other_peer_id; + other_info.version = 1; + + let make_raw = |info: &RoutePeerInfo| { + let mut raw = DynamicMessage::new(RoutePeerInfo::default().descriptor()); + raw.transcode_from(info).unwrap(); + raw + }; + let raw_infos = vec![make_raw(&sender_info), make_raw(&other_info)]; + + route + .session_mgr + .do_sync_route_info( + from_peer_id, + 1, + true, + Some(vec![sender_info, other_info]), + Some(raw_infos), + None, + None, + ) + .await + .unwrap(); + + let guard = route.service_impl.synced_route_info.peer_infos.read(); + let sender = guard.get(&from_peer_id).unwrap(); + assert!(!sender + .feature_flag + .as_ref() + .map(|x| x.is_credential_peer) + .unwrap_or(false)); + assert!(guard.get(&other_peer_id).is_some()); + } + #[rstest::rstest] #[tokio::test] async fn ospf_route_2node(#[values(true, false)] enable_conn_list_sync: bool) { @@ -3691,4 +4420,197 @@ mod tests { connect_peer_manager(p_b.clone(), p_c.clone()).await; wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap(); } + + /// Helper: create a raw DynamicMessage from a RoutePeerInfo with an extra + /// unknown field appended (field number 9999, varint value 42). + /// Returns the raw DynamicMessage and the encoded unknown field bytes. + fn make_raw_with_unknown_field(info: &RoutePeerInfo) -> (DynamicMessage, Vec) { + // Encode the info to bytes + let mut bytes = info.encode_to_vec(); + // Append an unknown field: field 9999, wire type 0 (varint), value 42 + // Tag = (9999 << 3) | 0 = 79992, encoded as varint + prost::encoding::encode_key(9999, prost::encoding::WireType::Varint, &mut bytes); + prost::encoding::encode_varint(42, &mut bytes); + let unknown_field_bytes = bytes[info.encoded_len()..].to_vec(); + // Decode as DynamicMessage — unknown fields are preserved + let raw = DynamicMessage::decode(RoutePeerInfo::default().descriptor(), bytes.as_slice()) + .unwrap(); + (raw, unknown_field_bytes) + } + + /// Check that a raw DynamicMessage still contains the unknown field bytes + /// by re-encoding and checking the suffix. + fn raw_has_unknown_bytes(raw: &DynamicMessage, unknown_bytes: &[u8]) -> bool { + let encoded = raw.encode_to_vec(); + // The unknown field bytes should appear somewhere in the encoded output + encoded + .windows(unknown_bytes.len()) + .any(|w| w == unknown_bytes) + } + + #[tokio::test] + async fn sync_route_preserves_unknown_fields_for_credential_sender() { + let peer_mgr = create_mock_pmgr().await; + let route = create_mock_route(peer_mgr.clone()).await; + let from_peer_id: PeerId = 20001; + + let identity_type = DashMap::new(); + identity_type.insert(from_peer_id, PeerIdentityType::Credential); + *route.service_impl.interface.lock().await = Some(Box::new(AuthOnlyInterface { + my_peer_id: peer_mgr.my_peer_id(), + identity_type, + })); + + let mut sender_info = RoutePeerInfo::new(); + sender_info.peer_id = from_peer_id; + sender_info.version = 1; + + let (raw, unknown_bytes) = make_raw_with_unknown_field(&sender_info); + + route + .session_mgr + .do_sync_route_info( + from_peer_id, + 1, + true, + Some(vec![sender_info]), + Some(vec![raw]), + None, + None, + ) + .await + .unwrap(); + + let stored_raw = route + .service_impl + .synced_route_info + .raw_peer_infos + .get(&from_peer_id) + .expect("raw peer info should be stored"); + assert!( + raw_has_unknown_bytes(stored_raw.value(), &unknown_bytes), + "unknown fields should be preserved for credential sender" + ); + } + + #[tokio::test] + async fn sync_route_preserves_unknown_fields_for_shared_sender() { + let peer_mgr = create_mock_pmgr().await; + let route = create_mock_route(peer_mgr.clone()).await; + let from_peer_id: PeerId = 20011; + let forwarded_peer_id: PeerId = 20012; + + let identity_type = DashMap::new(); + identity_type.insert(from_peer_id, PeerIdentityType::SharedNode); + *route.service_impl.interface.lock().await = Some(Box::new(AuthOnlyInterface { + my_peer_id: peer_mgr.my_peer_id(), + identity_type, + })); + + let mut sender_info = RoutePeerInfo::new(); + sender_info.peer_id = from_peer_id; + sender_info.version = 1; + + let mut forwarded_info = RoutePeerInfo::new(); + forwarded_info.peer_id = forwarded_peer_id; + forwarded_info.version = 1; + forwarded_info.trusted_credential_pubkeys = vec![TrustedCredentialPubkeyProof { + credential: Some(TrustedCredentialPubkey { + pubkey: vec![9u8; 32], + expiry_unix: i64::MAX, + ..Default::default() + }), + credential_hmac: vec![1; 32], + }]; + + let (raw_sender, unknown_sender) = make_raw_with_unknown_field(&sender_info); + let (raw_forwarded, unknown_forwarded) = make_raw_with_unknown_field(&forwarded_info); + + route + .session_mgr + .do_sync_route_info( + from_peer_id, + 1, + true, + Some(vec![sender_info, forwarded_info]), + Some(vec![raw_sender, raw_forwarded]), + None, + None, + ) + .await + .unwrap(); + + // Shared node: trusted_credential_pubkeys cleared but unknown fields preserved + let stored_sender = route + .service_impl + .synced_route_info + .raw_peer_infos + .get(&from_peer_id) + .expect("sender raw should be stored"); + assert!( + raw_has_unknown_bytes(stored_sender.value(), &unknown_sender), + "unknown fields should be preserved for shared sender's own info" + ); + + let stored_forwarded = route + .service_impl + .synced_route_info + .raw_peer_infos + .get(&forwarded_peer_id) + .expect("forwarded raw should be stored"); + assert!( + raw_has_unknown_bytes(stored_forwarded.value(), &unknown_forwarded), + "unknown fields should be preserved for shared sender's forwarded info" + ); + } + + #[tokio::test] + async fn sync_route_preserves_unknown_fields_for_admin_sender() { + let peer_mgr = create_mock_pmgr().await; + let route = create_mock_route(peer_mgr.clone()).await; + let from_peer_id: PeerId = 20021; + + let identity_type = DashMap::new(); + identity_type.insert(from_peer_id, PeerIdentityType::Admin); + *route.service_impl.interface.lock().await = Some(Box::new(AuthOnlyInterface { + my_peer_id: peer_mgr.my_peer_id(), + identity_type, + })); + + let mut sender_info = RoutePeerInfo::new(); + sender_info.peer_id = from_peer_id; + sender_info.version = 1; + // Set is_credential_peer=true so the mark_credential_peer(false) path triggers + sender_info.feature_flag = Some(PeerFeatureFlag { + is_credential_peer: true, + ..Default::default() + }); + + let (raw, unknown_bytes) = make_raw_with_unknown_field(&sender_info); + + route + .session_mgr + .do_sync_route_info( + from_peer_id, + 1, + true, + Some(vec![sender_info]), + Some(vec![raw]), + None, + None, + ) + .await + .unwrap(); + + let stored_raw = route + .service_impl + .synced_route_info + .raw_peer_infos + .get(&from_peer_id) + .expect("raw peer info should be stored"); + assert!( + raw_has_unknown_bytes(stored_raw.value(), &unknown_bytes), + "unknown fields should be preserved for admin sender (mark non-credential path)" + ); + } } diff --git a/easytier/src/peers/peer_session.rs b/easytier/src/peers/peer_session.rs index fc766cad8..decd6517c 100644 --- a/easytier/src/peers/peer_session.rs +++ b/easytier/src/peers/peer_session.rs @@ -1,6 +1,6 @@ use std::{ sync::{ - atomic::{AtomicU32, Ordering}, + atomic::{AtomicBool, AtomicU32, Ordering}, Arc, Mutex, RwLock, }, time::{SystemTime, UNIX_EPOCH}, @@ -36,7 +36,7 @@ pub enum PeerSessionAction { Create, } -#[derive(PartialEq, Clone, Eq, Hash)] +#[derive(PartialEq, Clone, Eq, Hash, Debug)] pub struct SessionKey { network_name: String, peer_id: PeerId, @@ -70,17 +70,46 @@ impl PeerSessionStore { } pub fn get(&self, key: &SessionKey) -> Option> { - self.sessions.get(key).map(|v| v.clone()) + let session = self.sessions.get(key)?.clone(); + if session.is_valid() { + Some(session) + } else { + self.sessions.remove(key); + None + } + } + + pub fn remove(&self, key: &SessionKey) { + self.sessions.remove(key); + } + + pub fn insert_session(&self, key: SessionKey, session: Arc) { + self.sessions.insert(key, session); + } + + /// Remove sessions that are no longer referenced by any PeerConn or RelayPeerMap. + /// A session with strong_count == 1 means only the store holds it — no active + /// connection is using it, so it can be safely cleaned up. + pub fn evict_unused_sessions(&self) { + self.sessions + .retain(|_key, session| Arc::strong_count(session) > 1); } + #[tracing::instrument(skip(self))] pub fn upsert_responder_session( &self, key: &SessionKey, a_session_generation: Option, send_algorithm: String, recv_algorithm: String, + peer_static_pubkey: Option<[u8; 32]>, ) -> Result { - let existing = self.sessions.get(key).map(|v| v.clone()); + tracing::event!(tracing::Level::INFO, "upsert_responder_session {:?}", key); + let existing = self + .sessions + .get(key) + .map(|v| v.clone()) + .filter(|s| s.is_valid()); match existing { None => { let root_key = PeerSession::new_root_key(); @@ -93,6 +122,7 @@ impl PeerSessionStore { initial_epoch, send_algorithm, recv_algorithm, + peer_static_pubkey, )); self.sessions.insert(key.clone(), session.clone()); Ok(UpsertResponderSessionReturn { @@ -105,6 +135,7 @@ impl PeerSessionStore { } Some(session) => { session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?; + session.check_or_set_peer_static_pubkey(peer_static_pubkey)?; let local_gen = session.session_generation(); if a_session_generation.is_some_and(|g| g == local_gen) { Ok(UpsertResponderSessionReturn { @@ -130,6 +161,7 @@ impl PeerSessionStore { } #[allow(clippy::too_many_arguments)] + #[tracing::instrument(skip(self))] pub fn apply_initiator_action( &self, key: &SessionKey, @@ -139,19 +171,16 @@ impl PeerSessionStore { initial_epoch: u32, send_algorithm: String, recv_algorithm: String, + peer_static_pubkey: Option<[u8; 32]>, ) -> Result, anyhow::Error> { - tracing::info!( - "apply_initiator_action {:?}, send_algorithm: {}, recv_algorithm: {}", - action, - send_algorithm, - recv_algorithm - ); + tracing::event!(tracing::Level::INFO, "apply_initiator_action {:?}", key); match action { PeerSessionAction::Join => { let Some(session) = self.get(key) else { return Err(anyhow!("no local session for JOIN")); }; session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?; + session.check_or_set_peer_static_pubkey(peer_static_pubkey)?; if session.session_generation() != b_session_generation { return Err(anyhow!("JOIN generation mismatch")); } @@ -159,6 +188,13 @@ impl PeerSessionStore { } PeerSessionAction::Sync | PeerSessionAction::Create => { let root_key = root_key_32.ok_or_else(|| anyhow!("missing root_key"))?; + // If the existing session is invalidated, remove it so we create a fresh one + if let Some(existing) = self.sessions.get(key) { + if !existing.is_valid() { + drop(existing); + self.sessions.remove(key); + } + } let session = self .sessions .entry(key.clone()) @@ -170,10 +206,12 @@ impl PeerSessionStore { initial_epoch, send_algorithm.clone(), recv_algorithm.clone(), + peer_static_pubkey, )) }) .clone(); session.check_encrypt_algo_same(&send_algorithm, &recv_algorithm)?; + session.check_or_set_peer_static_pubkey(peer_static_pubkey)?; session.sync_root_key(root_key, b_session_generation, initial_epoch); Ok(session) } @@ -261,9 +299,9 @@ impl ReplayWindow256 { if bit_shift > 0 { let mut carry = 0u8; - for b in self.bitmap.iter_mut().rev() { - let new_carry = *b << (8 - bit_shift); - *b = (*b >> bit_shift) | carry; + for b in self.bitmap.iter_mut() { + let new_carry = *b >> (8 - bit_shift); + *b = (*b << bit_shift) | carry; carry = new_carry; } } @@ -318,6 +356,7 @@ pub struct PeerSession { peer_id: PeerId, root_key: RwLock<[u8; 32]>, session_generation: AtomicU32, + peer_static_pubkey: RwLock>, send_epoch: AtomicU32, send_seq: [AtomicU64; 2], @@ -329,6 +368,12 @@ pub struct PeerSession { send_cipher_algorithm: String, recv_cipher_algorithm: String, + + /// Set to true when the session is detected as corrupted (persistent decrypt failures). + /// Holders of Arc can check this to know the session should be discarded. + invalidated: AtomicBool, + /// Consecutive decrypt failure counter. Auto-invalidates when threshold is reached. + decrypt_fail_count: AtomicU32, } impl std::fmt::Debug for PeerSession { @@ -337,6 +382,7 @@ impl std::fmt::Debug for PeerSession { .field("peer_id", &self.peer_id) .field("root_key", &self.root_key) .field("session_generation", &self.session_generation) + .field("peer_static_pubkey", &self.peer_static_pubkey) .field("send_epoch", &self.send_epoch) .field("send_seq", &self.send_seq) .field("send_epoch_started_ms", &self.send_epoch_started_ms) @@ -381,6 +427,7 @@ impl PeerSession { /// stricter security requirements may decrease it. const ROTATE_AFTER_MS: u64 = 10 * 60 * 1000; const MAX_ACCEPTED_RX_EPOCH_AHEAD: u32 = 3; + const DECRYPT_FAIL_THRESHOLD: u32 = 10; pub fn new( peer_id: PeerId, @@ -389,11 +436,8 @@ impl PeerSession { initial_epoch: u32, send_cipher_algorithm: String, recv_cipher_algorithm: String, + peer_static_pubkey: Option<[u8; 32]>, ) -> Self { - // let mut root_key_128 = [0u8; 16]; - // root_key_128.copy_from_slice(&root_key[..16]); - // let send_cipher = create_encryptor(&send_algorithm, root_key_128, root_key); - // let recv_cipher = create_encryptor(&recv_algorithm, root_key_128, root_key); let rx_slots = [ [EpochRxSlot::default(), EpochRxSlot::default()], [EpochRxSlot::default(), EpochRxSlot::default()], @@ -407,6 +451,7 @@ impl PeerSession { peer_id, root_key: RwLock::new(root_key), session_generation: AtomicU32::new(session_generation), + peer_static_pubkey: RwLock::new(peer_static_pubkey), send_epoch: AtomicU32::new(initial_epoch), send_seq: [AtomicU64::new(0), AtomicU64::new(0)], send_epoch_started_ms: AtomicU64::new(now_ms), @@ -415,6 +460,8 @@ impl PeerSession { key_cache: Mutex::new(key_cache), send_cipher_algorithm, recv_cipher_algorithm, + invalidated: AtomicBool::new(false), + decrypt_fail_count: AtomicU32::new(0), } } @@ -422,6 +469,15 @@ impl PeerSession { self.peer_id } + /// Mark this session as invalid. All holders of Arc will see this. + pub fn invalidate(&self) { + self.invalidated.store(true, Ordering::Relaxed); + } + + pub fn is_valid(&self) -> bool { + !self.invalidated.load(Ordering::Relaxed) + } + pub fn session_generation(&self) -> u32 { self.session_generation.load(Ordering::Relaxed) } @@ -466,6 +522,24 @@ impl PeerSession { Ok(()) } + pub fn check_or_set_peer_static_pubkey( + &self, + peer_static_pubkey: Option<[u8; 32]>, + ) -> Result<(), anyhow::Error> { + let Some(peer_static_pubkey) = peer_static_pubkey else { + return Ok(()); + }; + let mut guard = self.peer_static_pubkey.write().unwrap(); + if let Some(existing) = *guard { + if existing != peer_static_pubkey { + return Err(anyhow!("peer static pubkey mismatch")); + } + return Ok(()); + } + *guard = Some(peer_static_pubkey); + Ok(()) + } + pub fn sync_root_key(&self, root_key: [u8; 32], session_generation: u32, initial_epoch: u32) { { let mut g = self.root_key.write().unwrap(); @@ -484,12 +558,7 @@ impl PeerSession { { let mut rx = self.rx_slots.lock().unwrap(); for dir in 0..2 { - rx[dir][0] = EpochRxSlot { - epoch: initial_epoch, - window: ReplayWindow256::default(), - last_rx_ms: 0, - valid: true, - }; + rx[dir][0].clear(); rx[dir][1].clear(); } } @@ -703,12 +772,23 @@ impl PeerSession { receiver_peer_id: PeerId, pkt: &mut ZCPacket, ) -> Result<(), anyhow::Error> { + if !self.is_valid() { + return Err(anyhow!("session invalidated")); + } let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id); let (epoch, _seq, nonce_bytes) = self.next_nonce(dir); let encryptor = self .get_encryptor(epoch, dir, true) .ok_or_else(|| anyhow!("no key for epoch"))?; - let _ = encryptor.encrypt_with_nonce(pkt, Some(nonce_bytes.as_slice())); + if let Err(e) = encryptor.encrypt_with_nonce(pkt, Some(nonce_bytes.as_slice())) { + tracing::warn!( + peer_id = ?self.peer_id, + ?e, + "session encrypt failed, invalidating" + ); + self.invalidate(); + return Err(e.into()); + } Ok(()) } @@ -718,6 +798,9 @@ impl PeerSession { receiver_peer_id: PeerId, ciphertext_with_tail: &mut ZCPacket, ) -> Result<(), anyhow::Error> { + if !self.is_valid() { + return Err(anyhow!("session invalidated")); + } let dir = Self::dir_for_sender(sender_peer_id, receiver_peer_id); let nonce_bytes = Self::parse_tail(ciphertext_with_tail.payload()).ok_or_else(|| anyhow!("no tail"))?; @@ -726,13 +809,29 @@ impl PeerSession { let now_ms = now_ms(); if !self.check_replay(epoch, seq, dir, now_ms) { - return Err(anyhow!("replay rejected")); + return Err(anyhow!( + "replay rejected, sender_peer_id: {:?}, receiver_peer_id: {:?}", + sender_peer_id, + receiver_peer_id + )); } let encryptor = self .get_encryptor(epoch, dir, false) .ok_or_else(|| anyhow!("no key for epoch"))?; - encryptor.decrypt(ciphertext_with_tail)?; + if let Err(e) = encryptor.decrypt(ciphertext_with_tail) { + let count = self.decrypt_fail_count.fetch_add(1, Ordering::Relaxed) + 1; + if count >= Self::DECRYPT_FAIL_THRESHOLD { + self.invalidate(); + tracing::warn!( + peer_id = ?self.peer_id, + count, + "session auto-invalidated after consecutive decrypt failures" + ); + } + return Err(e.into()); + } + self.decrypt_fail_count.store(0, Ordering::Relaxed); Ok(()) } @@ -764,6 +863,7 @@ mod tests { initial_epoch, "aes-256-gcm".to_string(), "chacha20-poly1305".to_string(), + None, ); let sb = PeerSession::new( a, @@ -772,6 +872,7 @@ mod tests { initial_epoch, "chacha20-poly1305".to_string(), "aes-256-gcm".to_string(), + None, ); let plaintext1 = b"hello from a"; @@ -802,6 +903,7 @@ mod tests { initial_epoch, "aes-256-gcm".to_string(), "aes-256-gcm".to_string(), + None, ); let now = now_ms(); @@ -814,4 +916,71 @@ mod tests { assert!(s.check_replay(1, 1, 0, now + 1)); assert!(s.check_replay(1, 2, 0, now + 2)); } + + #[test] + fn replay_window_shift_preserves_bits() { + let mut w = ReplayWindow256::default(); + // Accept seqs 0..10 + for i in 0..10u64 { + assert!(w.accept(i), "seq {i} should be accepted"); + } + assert_eq!(w.max_seq, 9); + + // All seqs 0..10 should be marked as seen (replay) + for i in 0..10u64 { + assert!(!w.accept(i), "seq {i} should be rejected as replay"); + } + + // Seq 10 should still be accepted + assert!(w.accept(10)); + } + + #[test] + fn replay_window_out_of_order_within_window() { + let mut w = ReplayWindow256::default(); + // Accept even seqs 0,2,4,...,20 + for i in (0..=20u64).step_by(2) { + assert!(w.accept(i), "seq {i} should be accepted"); + } + // Now accept odd seqs 1,3,5,...,19 (out of order, within window) + for i in (1..=19u64).step_by(2) { + assert!(w.accept(i), "seq {i} should be accepted (out of order)"); + } + // All seqs 0..=20 should now be marked as seen + for i in 0..=20u64 { + assert!(!w.accept(i), "seq {i} should be rejected as replay"); + } + } + + #[test] + fn sync_root_key_allows_any_epoch_from_remote() { + // After sync_root_key, the remote peer may still be sending at an + // old epoch. The receiver should accept those packets. + let peer_id: PeerId = 10; + let root_key = PeerSession::new_root_key(); + let s = PeerSession::new( + peer_id, + root_key, + 1, + 0, + "aes-256-gcm".to_string(), + "aes-256-gcm".to_string(), + None, + ); + + // Simulate receiving some packets at epoch 0 + let now = now_ms(); + assert!(s.check_replay(0, 0, 0, now)); + assert!(s.check_replay(0, 1, 0, now)); + + // Sync with initial_epoch=2 (simulating a Sync action) + s.sync_root_key(root_key, 2, 2); + + // Remote peer is still sending at epoch 0 — should be accepted + // (rx_slots were cleared, so the first packet establishes the epoch) + assert!( + s.check_replay(0, 10, 0, now + 1), + "packets at old epoch should be accepted after sync" + ); + } } diff --git a/easytier/src/peers/relay_peer_map.rs b/easytier/src/peers/relay_peer_map.rs new file mode 100644 index 000000000..bc8e377ae --- /dev/null +++ b/easytier/src/peers/relay_peer_map.rs @@ -0,0 +1,675 @@ +use std::{sync::Arc, time::Instant}; + +use dashmap::DashMap; +use prost::Message; +use snow::params::NoiseParams; +use tokio::sync::{oneshot, Mutex, OwnedMutexGuard}; +use tokio::time::{timeout, Duration}; + +use crate::peers::foreign_network_client::ForeignNetworkClient; +use crate::{ + common::error::Error, + common::{global_ctx::ArcGlobalCtx, PeerId}, + peers::peer_map::PeerMap, + peers::peer_session::{PeerSession, PeerSessionAction, PeerSessionStore, SessionKey}, + peers::route_trait::NextHopPolicy, + proto::peer_rpc::{PeerConnSessionActionPb, RelayNoiseMsg1Pb, RelayNoiseMsg2Pb}, + tunnel::packet_def::{PacketType, ZCPacket}, +}; + +const RELAY_NOISE_VERSION: u32 = 1; +const RELAY_NOISE_PROLOGUE: &[u8] = b"easytier-relay-noise"; +const HANDSHAKE_TIMEOUT_SECS: u64 = 5; +const HANDSHAKE_RETRY_BASE_MS: u64 = 200; +const HANDSHAKE_MAX_ATTEMPTS: u32 = 3; +const MAX_PENDING_PACKETS_PER_PEER: usize = 32; + +#[derive(Clone)] +pub struct RelayPeerState { + pub last_active_at: Instant, + pub failure_count: u32, + pub next_retry_at: Option, +} + +impl Default for RelayPeerState { + fn default() -> Self { + Self { + last_active_at: Instant::now(), + failure_count: 0, + next_retry_at: None, + } + } +} + +pub struct RelayPeerMap { + peer_map: Arc, + foreign_network_client: Option>, + global_ctx: ArcGlobalCtx, + my_peer_id: PeerId, + peer_session_store: Arc, + states: DashMap, + pending_handshakes: DashMap>, + handshake_locks: DashMap>>, + pub(crate) pending_packets: DashMap>, + + is_secure_mode_enabled: bool, +} + +impl RelayPeerMap { + pub fn new( + peer_map: Arc, + foreign_network_client: Option>, + global_ctx: ArcGlobalCtx, + my_peer_id: PeerId, + peer_session_store: Arc, + ) -> Arc { + let is_secure_mode_enabled = global_ctx + .config + .get_secure_mode() + .map(|cfg| cfg.enabled) + .unwrap_or(false); + Arc::new(Self { + peer_map, + foreign_network_client, + global_ctx, + my_peer_id, + peer_session_store, + states: DashMap::new(), + pending_handshakes: DashMap::new(), + handshake_locks: DashMap::new(), + pending_packets: DashMap::new(), + is_secure_mode_enabled, + }) + } + + pub fn is_secure_mode_enabled(&self) -> bool { + self.is_secure_mode_enabled + } + + fn get_local_keypair(&self) -> Result<(Vec, Vec), Error> { + let cfg = self + .global_ctx + .config + .get_secure_mode() + .ok_or_else(|| Error::RouteError(Some("secure mode config not set".to_string())))?; + let private = cfg + .private_key() + .map_err(|e| Error::RouteError(Some(format!("invalid private key: {e:?}"))))?; + let public = cfg + .public_key() + .map_err(|e| Error::RouteError(Some(format!("invalid public key: {e:?}"))))?; + Ok((private.as_bytes().to_vec(), public.as_bytes().to_vec())) + } + + async fn get_remote_static_pubkey(&self, peer_id: PeerId) -> Result, Error> { + let info = self + .peer_map + .get_route_peer_info(peer_id) + .await + .ok_or_else(|| Error::RouteError(Some("route peer info not found".to_string())))?; + if info.noise_static_pubkey.is_empty() { + return Err(Error::RouteError(Some( + "remote static pubkey not found".to_string(), + ))); + } + Ok(info.noise_static_pubkey) + } + + fn get_handshake_lock(&self, peer_id: PeerId) -> Arc> { + self.handshake_locks + .entry(peer_id) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone() + } + + async fn send_handshake_packet( + &self, + payload: Vec, + packet_type: PacketType, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result<(), Error> { + let mut pkt = ZCPacket::new_with_payload(&payload); + pkt.fill_peer_manager_hdr(self.my_peer_id, dst_peer_id, packet_type as u8); + self.send_via_next_hop(pkt, dst_peer_id, policy).await + } + + async fn send_via_next_hop( + &self, + msg: ZCPacket, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result<(), Error> { + let Some(next_hop) = self.peer_map.get_gateway_peer_id(dst_peer_id, policy).await else { + return Err(Error::RouteError(Some(format!( + "next hop not found in route for peer {dst_peer_id:?}" + )))); + }; + if self.peer_map.has_peer(next_hop) { + self.peer_map.send_msg_directly(msg, next_hop).await + } else if let Some(foreign_network_client) = &self.foreign_network_client { + foreign_network_client.send_msg(msg, next_hop).await + } else { + Err(Error::RouteError(Some(format!( + "next hop not found in direct peer map: {next_hop:?}" + )))) + } + } + + pub async fn send_msg( + self: &Arc, + mut msg: ZCPacket, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result<(), Error> { + let now = Instant::now(); + + self.states.entry(dst_peer_id).or_default().last_active_at = now; + + if self.is_secure_mode_enabled() { + match self.ensure_session(dst_peer_id, policy.clone()).await { + Ok(session) => { + let my_peer_id = self.my_peer_id; + session + .encrypt_payload(my_peer_id, dst_peer_id, &mut msg) + .map_err(|e| Error::RouteError(Some(format!("{e:?}"))))?; + } + Err(_) => { + // Handshake in progress, buffer the packet instead of dropping it + self.buffer_pending_packet(dst_peer_id, msg, policy); + return Ok(()); + } + } + } + + self.send_via_next_hop(msg, dst_peer_id, policy).await + } + + fn buffer_pending_packet(&self, dst_peer_id: PeerId, pkt: ZCPacket, policy: NextHopPolicy) { + let mut entry = self.pending_packets.entry(dst_peer_id).or_default(); + if entry.len() < MAX_PENDING_PACKETS_PER_PEER { + entry.push((pkt, policy)); + } + // silently drop when buffer is full + } + + async fn flush_pending_packets(&self, dst_peer_id: PeerId, session: Arc) { + let packets = self.pending_packets.remove(&dst_peer_id).map(|(_, v)| v); + let Some(packets) = packets else { return }; + if packets.is_empty() { + return; + } + + tracing::debug!( + ?dst_peer_id, + count = packets.len(), + "flushing pending packets after relay handshake" + ); + + for (mut pkt, policy) in packets { + if session + .encrypt_payload(self.my_peer_id, dst_peer_id, &mut pkt) + .is_err() + { + continue; + } + let _ = self.send_via_next_hop(pkt, dst_peer_id, policy).await; + } + } + + pub fn has_session(&self, dst_peer_id: PeerId) -> bool { + self.peer_session_store + .get(&SessionKey::new( + self.global_ctx.get_network_identity().network_name.clone(), + dst_peer_id, + )) + .is_some() + } + + pub async fn ensure_session( + self: &Arc, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result, Error> { + let network = self.global_ctx.get_network_identity(); + let key = SessionKey::new(network.network_name.clone(), dst_peer_id); + if let Some(session) = self.peer_session_store.get(&key) { + return Ok(session); + } + + let lock = self.get_handshake_lock(dst_peer_id); + if let Ok(guard) = lock.try_lock_owned() { + let self_clone = self.clone(); + tokio::spawn(async move { + self_clone + .handshake_session(dst_peer_id, policy, Some(guard)) + .await + }); + }; + Err(Error::RouteError(Some( + "relay handshake in progress".to_string(), + ))) + } + + #[tracing::instrument(skip(self, _lock_guard), level = "debug", ret)] + pub async fn handshake_session( + &self, + dst_peer_id: PeerId, + policy: NextHopPolicy, + _lock_guard: Option>, + ) -> Result<(), Error> { + let network = self.global_ctx.get_network_identity(); + let key = SessionKey::new(network.network_name.clone(), dst_peer_id); + if let Some(session) = self.peer_session_store.get(&key) { + self.flush_pending_packets(dst_peer_id, session).await; + return Ok(()); + } + + if let Some(next_retry_at) = self.states.get(&dst_peer_id).and_then(|v| v.next_retry_at) { + if Instant::now() < next_retry_at { + self.pending_packets.remove(&dst_peer_id); + return Err(Error::RouteError(Some( + "relay handshake backoff".to_string(), + ))); + } + } + + let mut last_err = None; + for attempt in 0..HANDSHAKE_MAX_ATTEMPTS { + let ret = self + .handshake_session_once(dst_peer_id, policy.clone()) + .await; + match ret { + Ok(session) => { + self.register_handshake_success(dst_peer_id); + self.flush_pending_packets(dst_peer_id, session).await; + return Ok(()); + } + Err(e) => { + last_err = Some(e); + self.register_handshake_failure(dst_peer_id, attempt); + if attempt + 1 < HANDSHAKE_MAX_ATTEMPTS { + let backoff = HANDSHAKE_RETRY_BASE_MS.saturating_mul(1 << attempt); + tokio::time::sleep(Duration::from_millis(backoff)).await; + } + } + } + } + + // All attempts failed, drop buffered packets + self.pending_packets.remove(&dst_peer_id); + + Err(last_err + .unwrap_or_else(|| Error::RouteError(Some("relay handshake failed".to_string())))) + } + + #[tracing::instrument(skip(self), level = "debug", ret)] + async fn handshake_session_once( + &self, + dst_peer_id: PeerId, + policy: NextHopPolicy, + ) -> Result, Error> { + let network = self.global_ctx.get_network_identity(); + let session_key = SessionKey::new(network.network_name.clone(), dst_peer_id); + let (local_private_key, _local_public_key) = self.get_local_keypair()?; + let remote_static = self.get_remote_static_pubkey(dst_peer_id).await?; + let params: NoiseParams = "Noise_IK_25519_ChaChaPoly_SHA256" + .parse() + .map_err(|e| Error::RouteError(Some(format!("parse noise params failed: {e:?}"))))?; + + let builder = snow::Builder::new(params); + let mut hs = builder + .prologue(RELAY_NOISE_PROLOGUE) + .map_err(|e| Error::RouteError(Some(format!("set prologue failed: {e:?}"))))? + .local_private_key(&local_private_key) + .map_err(|e| Error::RouteError(Some(format!("set local key failed: {e:?}"))))? + .remote_public_key(&remote_static) + .map_err(|e| Error::RouteError(Some(format!("set remote key failed: {e:?}"))))? + .build_initiator() + .map_err(|e| Error::RouteError(Some(format!("build initiator failed: {e:?}"))))?; + + let a_session_generation = self + .peer_session_store + .get(&session_key) + .map(|s| s.session_generation()); + let a_conn_id = uuid::Uuid::new_v4(); + let msg1_pb = RelayNoiseMsg1Pb { + version: RELAY_NOISE_VERSION, + a_session_generation, + a_conn_id: Some(a_conn_id.into()), + client_encryption_algorithm: self.global_ctx.get_flags().encryption_algorithm.clone(), + }; + let payload = msg1_pb.encode_to_vec(); + let mut out = vec![0u8; 4096]; + let out_len = hs + .write_message(&payload, &mut out) + .map_err(|e| Error::RouteError(Some(format!("noise write msg1 failed: {e:?}"))))?; + let (tx, rx) = oneshot::channel(); + self.pending_handshakes.insert(dst_peer_id, tx); + + let send_res = self + .send_handshake_packet( + out[..out_len].to_vec(), + PacketType::RelayHandshake, + dst_peer_id, + policy, + ) + .await; + + if send_res.is_err() { + self.pending_handshakes.remove(&dst_peer_id); + } + send_res?; + let msg2_pkt = match timeout(Duration::from_secs(HANDSHAKE_TIMEOUT_SECS), rx).await { + Ok(Ok(pkt)) => pkt, + Ok(Err(_)) => { + self.pending_handshakes.remove(&dst_peer_id); + return Err(Error::RouteError(Some( + "relay handshake canceled".to_string(), + ))); + } + Err(_) => { + self.pending_handshakes.remove(&dst_peer_id); + return Err(Error::RouteError(Some( + "relay handshake timeout".to_string(), + ))); + } + }; + + let msg2_pb = self.decode_handshake_message::( + PacketType::RelayHandshakeAck, + &mut hs, + msg2_pkt, + )?; + if msg2_pb.a_conn_id_echo != Some(a_conn_id.into()) { + return Err(Error::RouteError(Some( + "relay msg2 conn_id_echo mismatch".to_string(), + ))); + } + + let action = PeerConnSessionActionPb::try_from(msg2_pb.action) + .map_err(|_| Error::RouteError(Some("invalid session action".to_string())))?; + let session_action = match action { + PeerConnSessionActionPb::Join => PeerSessionAction::Join, + PeerConnSessionActionPb::Sync => PeerSessionAction::Sync, + PeerConnSessionActionPb::Create => PeerSessionAction::Create, + }; + let remote_static_key = if remote_static.len() == 32 { + let mut key = [0u8; 32]; + key.copy_from_slice(&remote_static); + Some(key) + } else { + None + }; + let root_key_bytes = msg2_pb + .root_key_32 + .as_deref() + .filter(|v| v.len() == 32) + .map(|v| { + let mut key_bytes = [0u8; 32]; + key_bytes.copy_from_slice(v); + key_bytes + }); + let algo = self.global_ctx.get_flags().encryption_algorithm.clone(); + let session = self + .peer_session_store + .apply_initiator_action( + &session_key, + session_action, + msg2_pb.b_session_generation, + root_key_bytes, + msg2_pb.initial_epoch, + algo, + msg2_pb.server_encryption_algorithm.clone(), + remote_static_key, + ) + .map_err(|e| Error::RouteError(Some(format!("{e:?}"))))?; + + Ok(session) + } + + fn register_handshake_success(&self, dst_peer_id: PeerId) { + let mut entry = self.states.entry(dst_peer_id).or_default(); + entry.failure_count = 0; + entry.next_retry_at = None; + } + + fn register_handshake_failure(&self, dst_peer_id: PeerId, attempt: u32) { + let mut entry = self.states.entry(dst_peer_id).or_default(); + entry.failure_count = entry.failure_count.saturating_add(1); + let backoff = HANDSHAKE_RETRY_BASE_MS.saturating_mul(1 << attempt); + entry.next_retry_at = Some(Instant::now() + Duration::from_millis(backoff)); + } + + fn decode_handshake_message( + &self, + expected_type: PacketType, + hs: &mut snow::HandshakeState, + pkt: ZCPacket, + ) -> Result { + let hdr = pkt.peer_manager_header().ok_or_else(|| { + Error::RouteError(Some("packet without peer manager header".to_string())) + })?; + if hdr.packet_type != expected_type as u8 { + return Err(Error::RouteError(Some("packet type mismatch".to_string()))); + } + let mut out = vec![0u8; 4096]; + let out_len = hs + .read_message(pkt.payload(), &mut out) + .map_err(|e| Error::RouteError(Some(format!("noise read msg failed: {e:?}"))))?; + let msg = MsgT::decode(&out[..out_len]) + .map_err(|e| Error::RouteError(Some(format!("decode message failed: {e:?}"))))?; + Ok(msg) + } + + pub async fn handle_handshake_packet(&self, packet: ZCPacket) -> Result<(), Error> { + let hdr = packet + .peer_manager_header() + .ok_or_else(|| Error::RouteError(Some("packet without header".to_string())))?; + let src_peer_id = hdr.from_peer_id.get(); + match hdr.packet_type { + x if x == PacketType::RelayHandshake as u8 => { + tracing::debug!("handle_relay_msg1 from {:?}", src_peer_id); + self.handle_relay_msg1(packet, src_peer_id).await + } + x if x == PacketType::RelayHandshakeAck as u8 => { + if let Some((_, sender)) = self.pending_handshakes.remove(&src_peer_id) { + let _ = sender.send(packet); + } + Ok(()) + } + _ => Ok(()), + } + } + + async fn handle_relay_msg1(&self, msg1: ZCPacket, remote_peer_id: PeerId) -> Result<(), Error> { + // Check for bidirectional handshake race condition. + // If we are also waiting for a RelayHandshakeAck from this peer, + // use deterministic rule: the peer with smaller peer_id becomes initiator. + if self.pending_handshakes.contains_key(&remote_peer_id) { + // We have a pending handshake as initiator. + // If remote_peer_id < my_peer_id, remote should be initiator, we should be responder. + // Cancel our pending handshake and proceed as responder. + if remote_peer_id < self.my_peer_id { + tracing::debug!( + ?remote_peer_id, + my_peer_id = ?self.my_peer_id, + "bidirectional handshake race: yielding initiator role to smaller peer_id" + ); + // Remove our pending handshake + self.pending_handshakes.remove(&remote_peer_id); + } else { + // We have smaller peer_id, we should remain initiator. + // Ignore this RelayHandshake and let our initiator flow complete. + tracing::debug!( + ?remote_peer_id, + my_peer_id = ?self.my_peer_id, + "bidirectional handshake race: keeping initiator role due to smaller peer_id" + ); + return Err(Error::RouteError(Some( + "bidirectional handshake race: we are initiator".to_string(), + ))); + } + } + + let (local_private_key, _local_public_key) = self.get_local_keypair()?; + let params: NoiseParams = "Noise_IK_25519_ChaChaPoly_SHA256" + .parse() + .map_err(|e| Error::RouteError(Some(format!("parse noise params failed: {e:?}"))))?; + let builder = snow::Builder::new(params); + let mut hs = builder + .prologue(RELAY_NOISE_PROLOGUE) + .map_err(|e| Error::RouteError(Some(format!("set prologue failed: {e:?}"))))? + .local_private_key(&local_private_key) + .map_err(|e| Error::RouteError(Some(format!("set local key failed: {e:?}"))))? + .build_responder() + .map_err(|e| Error::RouteError(Some(format!("build responder failed: {e:?}"))))?; + + let msg1_pb = self.decode_handshake_message::( + PacketType::RelayHandshake, + &mut hs, + msg1, + )?; + let remote_static = hs + .get_remote_static() + .map(|x: &[u8]| x.to_vec()) + .unwrap_or_default(); + let remote_static_key = if remote_static.len() == 32 { + let mut key = [0u8; 32]; + key.copy_from_slice(&remote_static); + Some(key) + } else { + None + }; + + // Verify initiator's static public key matches the expected key from route info + let expected_pubkey = self.get_remote_static_pubkey(remote_peer_id).await?; + if remote_static != expected_pubkey { + return Err(Error::RouteError(Some(format!( + "responder: initiator static pubkey mismatch for peer {}, expected {} bytes, got {} bytes", + remote_peer_id, + expected_pubkey.len(), + remote_static.len() + )))); + } + + let server_network_name = self.global_ctx.get_network_name(); + let algo = self.global_ctx.get_flags().encryption_algorithm.clone(); + let key = SessionKey::new(server_network_name.clone(), remote_peer_id); + let upsert = self + .peer_session_store + .upsert_responder_session( + &key, + msg1_pb.a_session_generation, + algo.clone(), + msg1_pb.client_encryption_algorithm.clone(), + remote_static_key, + ) + .map_err(|e| Error::RouteError(Some(format!("{e:?}"))))?; + let msg2_pb = RelayNoiseMsg2Pb { + action: match upsert.action { + PeerSessionAction::Join => PeerConnSessionActionPb::Join as i32, + PeerSessionAction::Sync => PeerConnSessionActionPb::Sync as i32, + PeerSessionAction::Create => PeerConnSessionActionPb::Create as i32, + }, + b_session_generation: upsert.session_generation, + root_key_32: upsert.root_key.map(|k| k.to_vec()), + initial_epoch: upsert.initial_epoch, + b_conn_id: Some(uuid::Uuid::new_v4().into()), + a_conn_id_echo: msg1_pb.a_conn_id, + server_encryption_algorithm: algo, + }; + let payload = msg2_pb.encode_to_vec(); + let mut out = vec![0u8; 4096]; + let out_len = hs + .write_message(&payload, &mut out) + .map_err(|e| Error::RouteError(Some(format!("noise write msg2 failed: {e:?}"))))?; + + self.register_handshake_success(remote_peer_id); + + self.send_handshake_packet( + out[..out_len].to_vec(), + PacketType::RelayHandshakeAck, + remote_peer_id, + NextHopPolicy::LeastHop, + ) + .await?; + + // Flush any packets buffered while waiting for the handshake to complete + self.flush_pending_packets(remote_peer_id, upsert.session) + .await; + + Ok(()) + } + + pub async fn decrypt_if_needed(self: &Arc, packet: &mut ZCPacket) -> Result { + if !self.is_secure_mode_enabled() { + return Ok(false); + } + let hdr = packet + .peer_manager_header() + .ok_or_else(|| Error::RouteError(Some("packet without header".to_string())))?; + let from_peer_id = hdr.from_peer_id.get(); + let network = self.global_ctx.get_network_identity(); + let key = SessionKey::new(network.network_name.clone(), from_peer_id); + let Some(session) = self.peer_session_store.get(&key) else { + tracing::debug!( + "relay session not found for peer {}, try handshake", + from_peer_id + ); + self.ensure_session(from_peer_id, NextHopPolicy::LeastHop) + .await?; + return Ok(false); + }; + let now = Instant::now(); + let mut entry = self.states.entry(from_peer_id).or_default(); + entry.last_active_at = now; + session.decrypt_payload(from_peer_id, self.my_peer_id, packet)?; + Ok(true) + } + + pub fn evict_idle_sessions(&self, idle: Duration) { + let now = Instant::now(); + let mut to_remove = Vec::new(); + for entry in self.states.iter() { + if now.duration_since(entry.last_active_at) > idle { + to_remove.push(*entry.key()); + } + } + for peer_id in to_remove { + self.states.remove(&peer_id); + self.pending_handshakes.remove(&peer_id); + self.handshake_locks.remove(&peer_id); + self.pending_packets.remove(&peer_id); + } + } + + pub fn has_state(&self, peer_id: PeerId) -> bool { + self.states.contains_key(&peer_id) + } + + pub fn failure_count(&self, peer_id: PeerId) -> Option { + self.states.get(&peer_id).map(|v| v.failure_count) + } + + pub fn is_backoff_active(&self, peer_id: PeerId) -> bool { + self.states + .get(&peer_id) + .and_then(|v| v.next_retry_at) + .is_some_and(|ts| Instant::now() < ts) + } + + /// Remove relay-specific state for a specific peer. + /// This does NOT remove the session from PeerSessionStore, because the + /// session lifecycle is independent of any particular connection type + /// (relay or direct). The session may still be used by direct connections + /// or for fast reconnection (Join instead of Create). + pub fn remove_peer(&self, peer_id: PeerId) { + self.states.remove(&peer_id); + self.pending_handshakes.remove(&peer_id); + self.handshake_locks.remove(&peer_id); + self.pending_packets.remove(&peer_id); + + tracing::debug!(?peer_id, "RelayPeerMap removed peer relay state"); + } +} diff --git a/easytier/src/peers/route_trait.rs b/easytier/src/peers/route_trait.rs index ef9b99c4d..7dc0319e2 100644 --- a/easytier/src/peers/route_trait.rs +++ b/easytier/src/peers/route_trait.rs @@ -8,8 +8,8 @@ use dashmap::DashMap; use crate::{ common::{global_ctx::NetworkIdentity, PeerId}, proto::peer_rpc::{ - ForeignNetworkRouteInfoEntry, ForeignNetworkRouteInfoKey, RouteForeignNetworkInfos, - RouteForeignNetworkSummary, RoutePeerInfo, + ForeignNetworkRouteInfoEntry, ForeignNetworkRouteInfoKey, PeerIdentityType, + RouteForeignNetworkInfos, RouteForeignNetworkSummary, RoutePeerInfo, }, }; @@ -27,6 +27,10 @@ pub type ForeignNetworkRouteInfoMap = pub trait RouteInterface { async fn list_peers(&self) -> Vec; fn my_peer_id(&self) -> PeerId; + async fn close_peer(&self, _peer_id: PeerId) {} + async fn get_peer_identity_type(&self, _peer_id: PeerId) -> Option { + None + } async fn list_foreign_networks(&self) -> ForeignNetworkRouteInfoMap { DashMap::new() } diff --git a/easytier/src/peers/rpc_service.rs b/easytier/src/peers/rpc_service.rs index f5ffef670..61e1cda48 100644 --- a/easytier/src/peers/rpc_service.rs +++ b/easytier/src/peers/rpc_service.rs @@ -1,17 +1,21 @@ use std::{ ops::Deref, sync::{Arc, Weak}, + time::Duration, }; use crate::{ proto::{ api::instance::{ - AclManageRpc, DumpRouteRequest, DumpRouteResponse, GetAclStatsRequest, + AclManageRpc, CredentialManageRpc, DumpRouteRequest, DumpRouteResponse, + GenerateCredentialRequest, GenerateCredentialResponse, GetAclStatsRequest, GetAclStatsResponse, GetForeignNetworkSummaryRequest, GetForeignNetworkSummaryResponse, - GetWhitelistRequest, GetWhitelistResponse, ListForeignNetworkRequest, - ListForeignNetworkResponse, ListGlobalForeignNetworkRequest, - ListGlobalForeignNetworkResponse, ListPeerRequest, ListPeerResponse, ListRouteRequest, - ListRouteResponse, PeerInfo, PeerManageRpc, ShowNodeInfoRequest, ShowNodeInfoResponse, + GetWhitelistRequest, GetWhitelistResponse, ListCredentialsRequest, + ListCredentialsResponse, ListForeignNetworkRequest, ListForeignNetworkResponse, + ListGlobalForeignNetworkRequest, ListGlobalForeignNetworkResponse, ListPeerRequest, + ListPeerResponse, ListRouteRequest, ListRouteResponse, PeerInfo, PeerManageRpc, + RevokeCredentialRequest, RevokeCredentialResponse, ShowNodeInfoRequest, + ShowNodeInfoResponse, }, rpc_types::{self, controller::BaseController}, }, @@ -201,3 +205,82 @@ impl AclManageRpc for PeerManagerRpcService { }) } } + +#[async_trait::async_trait] +impl CredentialManageRpc for PeerManagerRpcService { + type Controller = BaseController; + + async fn generate_credential( + &self, + _: BaseController, + request: GenerateCredentialRequest, + ) -> Result { + let pm = weak_upgrade(&self.peer_manager)?; + let global_ctx = pm.get_global_ctx(); + + if global_ctx.get_network_identity().network_secret.is_none() { + return Err(rpc_types::error::Error::ExecutionError(anyhow::anyhow!( + "only admin nodes (with network_secret) can generate credentials" + ))); + } + + let ttl = if request.ttl_seconds > 0 { + Duration::from_secs(request.ttl_seconds as u64) + } else { + return Err(rpc_types::error::Error::ExecutionError(anyhow::anyhow!( + "ttl_seconds must be positive" + ))); + }; + + let (id, secret) = global_ctx.get_credential_manager().generate_credential( + request.groups, + request.allow_relay, + request.allowed_proxy_cidrs, + ttl, + ); + + global_ctx.issue_event(crate::common::global_ctx::GlobalCtxEvent::CredentialChanged); + + Ok(GenerateCredentialResponse { + credential_id: id, + credential_secret: secret, + }) + } + + async fn revoke_credential( + &self, + _: BaseController, + request: RevokeCredentialRequest, + ) -> Result { + let pm = weak_upgrade(&self.peer_manager)?; + let global_ctx = pm.get_global_ctx(); + if global_ctx.get_network_identity().network_secret.is_none() { + return Err(rpc_types::error::Error::ExecutionError(anyhow::anyhow!( + "only admin nodes (with network_secret) can revoke credentials" + ))); + } + + let success = global_ctx + .get_credential_manager() + .revoke_credential(&request.credential_id); + + if success { + global_ctx.issue_event(crate::common::global_ctx::GlobalCtxEvent::CredentialChanged); + } + + Ok(RevokeCredentialResponse { success }) + } + + async fn list_credentials( + &self, + _: BaseController, + _request: ListCredentialsRequest, + ) -> Result { + let pm = weak_upgrade(&self.peer_manager)?; + let global_ctx = pm.get_global_ctx(); + + Ok(ListCredentialsResponse { + credentials: global_ctx.get_credential_manager().list_credentials(), + }) + } +} diff --git a/easytier/src/peers/tests.rs b/easytier/src/peers/tests.rs index b09b400f5..81d303aad 100644 --- a/easytier/src/peers/tests.rs +++ b/easytier/src/peers/tests.rs @@ -1,4 +1,7 @@ use std::sync::Arc; +use std::time::Duration; + +use base64::Engine as _; use crate::{ common::{ @@ -9,12 +12,21 @@ use crate::{ }, PeerId, }, - tunnel::ring::create_ring_tunnel_pair, + tunnel::{ + common::tests::wait_for_condition, + packet_def::{PacketType, ZCPacket}, + ring::create_ring_tunnel_pair, + }, }; use super::{ create_packet_recv_chan, + peer_conn::tests::set_secure_mode_cfg, peer_manager::{PeerManager, RouteAlgoType}, + peer_map::PeerMap, + peer_session::{PeerSession, PeerSessionStore, SessionKey}, + relay_peer_map::RelayPeerMap, + route_trait::NextHopPolicy, }; pub async fn create_mock_peer_manager() -> Arc { @@ -37,6 +49,19 @@ pub async fn create_mock_peer_manager_with_name(network_name: String) -> Arc Arc { + let (s, _r) = create_packet_recv_chan(); + let g = + get_mock_global_ctx_with_network(Some(NetworkIdentity::new(network_name, network_secret))); + set_secure_mode_cfg(&g, true); + let peer_mgr = Arc::new(PeerManager::new(RouteAlgoType::Ospf, g, s)); + peer_mgr.run().await.unwrap(); + peer_mgr +} + pub async fn connect_peer_manager(client: Arc, server: Arc) { let (a_ring, b_ring) = create_ring_tunnel_pair(); let a_mgr_copy = client; @@ -127,3 +152,1124 @@ async fn foreign_mgr_stress_test() { } } } + +#[tokio::test] +async fn relay_peer_map_secure_session_decrypt() { + let (s, _r) = create_packet_recv_chan(); + let ctx = get_mock_global_ctx_with_network(Some(NetworkIdentity::new( + "net1".to_string(), + "sec1".to_string(), + ))); + set_secure_mode_cfg(&ctx, true); + let peer_map = Arc::new(PeerMap::new(s, ctx.clone(), 10)); + let store = Arc::new(PeerSessionStore::new()); + let relay_map = RelayPeerMap::new(peer_map, None, ctx.clone(), 10, store.clone()); + + let algo = ctx.get_flags().encryption_algorithm.clone(); + let root_key = [7u8; 32]; + let session = Arc::new(PeerSession::new( + 20, + root_key, + 1, + 1, + algo.clone(), + algo.clone(), + None, + )); + let key = SessionKey::new(ctx.get_network_identity().network_name, 20); + store.insert_session(key.clone(), session.clone()); + + relay_map + .ensure_session(20, NextHopPolicy::LeastHop) + .await + .unwrap(); + assert!(relay_map.has_session(20)); + + let mut packet = ZCPacket::new_with_payload(b"relay-hello"); + packet.fill_peer_manager_hdr(20, 10, PacketType::Data as u8); + session.encrypt_payload(20, 10, &mut packet).unwrap(); + assert!(relay_map.decrypt_if_needed(&mut packet).await.unwrap()); + assert_eq!(packet.payload(), b"relay-hello"); +} + +#[tokio::test] +async fn relay_peer_map_retry_backoff_and_evict() { + let (s, _r) = create_packet_recv_chan(); + let ctx_secure = get_mock_global_ctx(); + set_secure_mode_cfg(&ctx_secure, true); + let peer_map = Arc::new(PeerMap::new(s, ctx_secure.clone(), 10)); + let relay_map = RelayPeerMap::new( + peer_map, + None, + ctx_secure.clone(), + 10, + Arc::new(PeerSessionStore::new()), + ); + + let ret = relay_map + .handshake_session(20, NextHopPolicy::LeastHop, None) + .await; + assert!(ret.is_err()); + assert!(relay_map.failure_count(20).unwrap_or(0) >= 1); + assert!(relay_map.is_backoff_active(20)); + + let (s2, _r2) = create_packet_recv_chan(); + let ctx_plain = get_mock_global_ctx(); + let peer_map_plain = Arc::new(PeerMap::new(s2, ctx_plain.clone(), 30)); + let relay_map_plain = RelayPeerMap::new( + peer_map_plain, + None, + ctx_plain.clone(), + 30, + Arc::new(PeerSessionStore::new()), + ); + + let mut pkt = ZCPacket::new_with_payload(b"evict"); + pkt.fill_peer_manager_hdr(30, 40, PacketType::Data as u8); + let _ = relay_map_plain + .send_msg(pkt, 40, NextHopPolicy::LeastHop) + .await; + assert!(relay_map_plain.has_state(40)); + relay_map_plain.evict_idle_sessions(Duration::from_millis(0)); + assert!(!relay_map_plain.has_state(40)); +} + +#[tokio::test] +async fn relay_peer_map_pending_packet_buffer() { + // Verify that packets sent during handshake are buffered (not dropped), + // and flushed after handshake completes. + let (s, _r) = create_packet_recv_chan(); + let ctx = get_mock_global_ctx_with_network(Some(NetworkIdentity::new( + "net1".to_string(), + "sec1".to_string(), + ))); + set_secure_mode_cfg(&ctx, true); + let peer_map = Arc::new(PeerMap::new(s, ctx.clone(), 10)); + let store = Arc::new(PeerSessionStore::new()); + let relay_map = RelayPeerMap::new(peer_map, None, ctx.clone(), 10, store.clone()); + + // Send multiple packets while no session exists (handshake will fail, but packets should be buffered) + for i in 0..5u8 { + let mut pkt = ZCPacket::new_with_payload(&[i]); + pkt.fill_peer_manager_hdr(10, 20, PacketType::Data as u8); + let _ = relay_map.send_msg(pkt, 20, NextHopPolicy::LeastHop).await; + } + + // Verify packets were buffered + assert_eq!( + relay_map + .pending_packets + .get(&20) + .map(|v| v.len()) + .unwrap_or(0), + 5, + "5 packets should be buffered during handshake" + ); + + // Verify buffer respects capacity limit + for i in 0..50u8 { + let mut pkt = ZCPacket::new_with_payload(&[i]); + pkt.fill_peer_manager_hdr(10, 20, PacketType::Data as u8); + let _ = relay_map.send_msg(pkt, 20, NextHopPolicy::LeastHop).await; + } + + let buffered = relay_map + .pending_packets + .get(&20) + .map(|v| v.len()) + .unwrap_or(0); + assert!( + buffered <= 32, + "buffer should not exceed MAX_PENDING_PACKETS_PER_PEER, got {buffered}" + ); + + // Verify remove_peer clears pending packets + relay_map.remove_peer(20); + assert_eq!( + relay_map + .pending_packets + .get(&20) + .map(|v| v.len()) + .unwrap_or(0), + 0, + "pending packets should be cleared on peer removal" + ); +} + +#[tokio::test] +async fn relay_peer_map_pending_packets_flushed_on_handshake_success() { + // Test that pending packets are flushed after handshake succeeds. + // We pre-populate the buffer, then run handshake, and verify it's cleared. + let peer_a = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_b = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_c = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + + connect_peer_manager(peer_a.clone(), peer_b.clone()).await; + connect_peer_manager(peer_b.clone(), peer_c.clone()).await; + + let peer_a_id = peer_a.my_peer_id(); + let peer_c_id = peer_c.my_peer_id(); + + // Wait for routes to propagate + wait_for_condition( + || { + let peer_a = peer_a.clone(); + let peer_c = peer_c.clone(); + async move { wait_route_appear(peer_a.clone(), peer_c).await.is_ok() } + }, + Duration::from_secs(10), + ) + .await; + + // Wait for noise_static_pubkey to be available on both sides + wait_for_condition( + || { + let peer_a = peer_a.clone(); + async move { + peer_a + .get_peer_map() + .get_route_peer_info(peer_c_id) + .await + .map(|info| !info.noise_static_pubkey.is_empty()) + .unwrap_or(false) + } + }, + Duration::from_secs(10), + ) + .await; + + let relay_a = peer_a.get_relay_peer_map(); + + // Pre-populate pending packets buffer (simulating what send_msg does during handshake) + for i in 0..3u8 { + let mut pkt = ZCPacket::new_with_payload(&[i]); + pkt.fill_peer_manager_hdr(peer_a_id, peer_c_id, PacketType::Data as u8); + relay_a + .pending_packets + .entry(peer_c_id) + .or_default() + .push((pkt, NextHopPolicy::LeastHop)); + } + + assert_eq!( + relay_a + .pending_packets + .get(&peer_c_id) + .map(|v| v.len()) + .unwrap_or(0), + 3, + "3 packets should be in the buffer" + ); + + // Run handshake — on success it should flush the buffer + relay_a + .handshake_session(peer_c_id, NextHopPolicy::LeastHop, None) + .await + .unwrap(); + + // Verify session established and buffer cleared + assert!(relay_a.has_session(peer_c_id)); + assert_eq!( + relay_a + .pending_packets + .get(&peer_c_id) + .map(|v| v.len()) + .unwrap_or(0), + 0, + "pending packets should be flushed after successful handshake" + ); +} + +#[tokio::test] +async fn relay_peer_map_real_link_handshake_success() { + let peer_a = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_b = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_c = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + + connect_peer_manager(peer_a.clone(), peer_b.clone()).await; + connect_peer_manager(peer_b.clone(), peer_c.clone()).await; + + let peer_a_id = peer_a.my_peer_id(); + let peer_b_id = peer_b.my_peer_id(); + let peer_c_id = peer_c.my_peer_id(); + + wait_for_condition( + || { + let peer_a = peer_a.clone(); + let peer_c = peer_c.clone(); + async move { wait_route_appear(peer_a.clone(), peer_c).await.is_ok() } + }, + Duration::from_secs(10), + ) + .await; + + wait_for_condition( + || { + let peer_a = peer_a.clone(); + async move { + peer_a + .get_peer_map() + .get_gateway_peer_id(peer_c_id, NextHopPolicy::LeastHop) + .await + == Some(peer_b_id) + } + }, + Duration::from_secs(5), + ) + .await; + + wait_for_condition( + || { + let peer_a = peer_a.clone(); + async move { + peer_a + .get_peer_map() + .get_route_peer_info(peer_c_id) + .await + .map(|info| !info.noise_static_pubkey.is_empty()) + .unwrap_or(false) + } + }, + Duration::from_secs(10), + ) + .await; + + let relay_a = peer_a.get_relay_peer_map(); + let relay_c = peer_c.get_relay_peer_map(); + + relay_a + .handshake_session(peer_c_id, NextHopPolicy::LeastHop, None) + .await + .unwrap(); + + wait_for_condition( + || { + let relay_a = relay_a.clone(); + async move { relay_a.has_session(peer_c_id) } + }, + Duration::from_secs(5), + ) + .await; + + wait_for_condition( + || { + let relay_c = relay_c.clone(); + async move { relay_c.has_session(peer_a_id) } + }, + Duration::from_secs(5), + ) + .await; +} + +#[tokio::test] +async fn relay_peer_map_responder_rejects_mismatched_pubkey() { + // Create three peers: A -> B -> C + let peer_a = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_b = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_c = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + + connect_peer_manager(peer_a.clone(), peer_b.clone()).await; + connect_peer_manager(peer_b.clone(), peer_c.clone()).await; + + let peer_a_id = peer_a.my_peer_id(); + let peer_c_id = peer_c.my_peer_id(); + + // Wait for routes to propagate + wait_for_condition( + || { + let peer_a = peer_a.clone(); + let peer_c = peer_c.clone(); + async move { wait_route_appear(peer_a.clone(), peer_c).await.is_ok() } + }, + Duration::from_secs(10), + ) + .await; + + // Wait for noise_static_pubkey to be available + wait_for_condition( + || { + let peer_a = peer_a.clone(); + async move { + peer_a + .get_peer_map() + .get_route_peer_info(peer_c_id) + .await + .map(|info| !info.noise_static_pubkey.is_empty()) + .unwrap_or(false) + } + }, + Duration::from_secs(10), + ) + .await; + + // Get the original correct pubkey to verify it exists + let original_info = peer_a + .get_peer_map() + .get_route_peer_info(peer_c_id) + .await + .expect("should have route info for peer_c"); + assert!( + !original_info.noise_static_pubkey.is_empty(), + "noise_static_pubkey should be present" + ); + + // Attempt handshake - this should succeed because pubkeys match + let relay_a = peer_a.get_relay_peer_map(); + let result = relay_a + .handshake_session(peer_c_id, NextHopPolicy::LeastHop, None) + .await; + + // The handshake should succeed because the pubkeys match + assert!( + result.is_ok(), + "handshake should succeed with matching pubkeys" + ); + + // Verify session was established on both sides + wait_for_condition( + || { + let relay_a = relay_a.clone(); + async move { relay_a.has_session(peer_c_id) } + }, + Duration::from_secs(5), + ) + .await; + + let relay_c = peer_c.get_relay_peer_map(); + wait_for_condition( + || { + let relay_c = relay_c.clone(); + async move { relay_c.has_session(peer_a_id) } + }, + Duration::from_secs(5), + ) + .await; +} + +#[tokio::test] +async fn relay_peer_map_remove_peer() { + let (s, _r) = create_packet_recv_chan(); + let ctx = get_mock_global_ctx_with_network(Some(NetworkIdentity::new( + "net1".to_string(), + "sec1".to_string(), + ))); + set_secure_mode_cfg(&ctx, true); + let peer_map = Arc::new(PeerMap::new(s, ctx.clone(), 10)); + let store = Arc::new(PeerSessionStore::new()); + let relay_map = RelayPeerMap::new(peer_map, None, ctx.clone(), 10, store.clone()); + + let peer_1: PeerId = 100; + + // Add session for peer_1 + let root_key = [1u8; 32]; + let session = Arc::new(PeerSession::new( + peer_1, + root_key, + 1, + 0, + "aes-256-gcm".to_string(), + "aes-256-gcm".to_string(), + None, + )); + let key = SessionKey::new(ctx.get_network_name(), peer_1); + store.insert_session(key.clone(), session); + + assert!(store.get(&key).is_some()); + + // Remove the peer relay state + relay_map.remove_peer(peer_1); + + // Session should still be in the store (lifecycle is independent of relay state) + assert!( + store.get(&key).is_some(), + "session should persist after relay peer removal" + ); +} + +/// Test bidirectional handshake race resolution. +/// When both peers simultaneously initiate handshake, the one with smaller peer_id +/// should become initiator, and the other should yield and become responder. +#[tokio::test] +async fn relay_peer_map_bidirectional_handshake_race() { + // Create three peers: A -> B -> C + let peer_a = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_b = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_c = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + + connect_peer_manager(peer_a.clone(), peer_b.clone()).await; + connect_peer_manager(peer_b.clone(), peer_c.clone()).await; + + let peer_a_id = peer_a.my_peer_id(); + let peer_c_id = peer_c.my_peer_id(); + + // Wait for routes to propagate + wait_for_condition( + || { + let peer_a = peer_a.clone(); + let peer_c = peer_c.clone(); + async move { wait_route_appear(peer_a.clone(), peer_c).await.is_ok() } + }, + Duration::from_secs(10), + ) + .await; + + // Wait for noise_static_pubkey to be available + wait_for_condition( + || { + let peer_a = peer_a.clone(); + async move { + peer_a + .get_peer_map() + .get_route_peer_info(peer_c_id) + .await + .map(|info| !info.noise_static_pubkey.is_empty()) + .unwrap_or(false) + } + }, + Duration::from_secs(10), + ) + .await; + + wait_for_condition( + || { + let peer_c = peer_c.clone(); + async move { + peer_c + .get_peer_map() + .get_route_peer_info(peer_a_id) + .await + .map(|info| !info.noise_static_pubkey.is_empty()) + .unwrap_or(false) + } + }, + Duration::from_secs(10), + ) + .await; + + // Simulate bidirectional handshake race by having both sides initiate simultaneously + let relay_a = peer_a.get_relay_peer_map(); + let relay_c = peer_c.get_relay_peer_map(); + + // Both sides initiate handshake at the same time + let handle_a = tokio::spawn({ + let relay_a = relay_a.clone(); + async move { + relay_a + .handshake_session(peer_c_id, NextHopPolicy::LeastHop, None) + .await + } + }); + + let handle_c = tokio::spawn({ + let relay_c = relay_c.clone(); + async move { + relay_c + .handshake_session(peer_a_id, NextHopPolicy::LeastHop, None) + .await + } + }); + + // Wait for both handshakes to complete + let (result_a, result_c) = tokio::join!(handle_a, handle_c); + + // At least one should succeed (the initiator with smaller peer_id) + // Both could succeed if race resolution worked correctly + tracing::info!( + ?peer_a_id, + ?peer_c_id, + ?result_a, + ?result_c, + "bidirectional handshake results" + ); + + // Wait for sessions to be established + wait_for_condition( + || { + let relay_a = relay_a.clone(); + async move { relay_a.has_session(peer_c_id) } + }, + Duration::from_secs(5), + ) + .await; + + wait_for_condition( + || { + let relay_c = relay_c.clone(); + async move { relay_c.has_session(peer_a_id) } + }, + Duration::from_secs(5), + ) + .await; + + // Both sides should have sessions after race resolution + assert!( + relay_a.has_session(peer_c_id), + "peer_a should have session with peer_c" + ); + assert!( + relay_c.has_session(peer_a_id), + "peer_c should have session with peer_a" + ); +} + +/// Helper: create a secure peer manager for a credential node. +/// Uses the given X25519 private key as the Noise static key, with no network_secret. +pub async fn create_mock_peer_manager_credential( + network_name: String, + private_key: &x25519_dalek::StaticSecret, +) -> Arc { + use crate::common::config::NetworkIdentity; + use crate::proto::common::SecureModeConfig; + use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; + use base64::Engine; + + let (s, _r) = create_packet_recv_chan(); + let g = get_mock_global_ctx_with_network(Some(NetworkIdentity::new_credential(network_name))); + + let public = x25519_dalek::PublicKey::from(private_key); + g.config.set_secure_mode(Some(SecureModeConfig { + enabled: true, + local_private_key: Some(BASE64_STANDARD.encode(private_key.as_bytes())), + local_public_key: Some(BASE64_STANDARD.encode(public.as_bytes())), + })); + + let peer_mgr = Arc::new(PeerManager::new(RouteAlgoType::Ospf, g, s)); + peer_mgr.run().await.unwrap(); + peer_mgr +} + +/// Test: credential node joins a 2-admin network and routes appear. +/// Topology: Admin_A -- Credential_C, Admin_A -- Admin_B +/// Credential node connects to the admin that generated the credential. +#[tokio::test] +async fn credential_node_joins_network() { + let admin_a = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + let admin_b = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + + // Generate credential on admin_a + let (_cred_id, cred_secret) = admin_a + .get_global_ctx() + .get_credential_manager() + .generate_credential( + vec!["guest".to_string()], + false, + vec![], + std::time::Duration::from_secs(3600), + ); + + // Create credential node using the generated key + let privkey_bytes: [u8; 32] = base64::engine::general_purpose::STANDARD + .decode(&cred_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + let cred_c = create_mock_peer_manager_credential("net1".to_string(), &private).await; + + // Connect admins first + connect_peer_manager(admin_a.clone(), admin_b.clone()).await; + + // Admin A and B should discover each other + wait_route_appear(admin_a.clone(), admin_b.clone()) + .await + .unwrap(); + + // Now connect credential node to admin A (credential as client) + connect_peer_manager(cred_c.clone(), admin_a.clone()).await; + + // Credential node C should be reachable from admin B (via A) + let cred_c_id = cred_c.my_peer_id(); + wait_for_condition( + || { + let admin_b = admin_b.clone(); + async move { + admin_b + .list_routes() + .await + .iter() + .any(|r| r.peer_id == cred_c_id) + } + }, + Duration::from_secs(10), + ) + .await; + + // Credential node C should see admin B + wait_for_condition( + || { + let cred_c = cred_c.clone(); + let admin_b_id = admin_b.my_peer_id(); + async move { + cred_c + .list_routes() + .await + .iter() + .any(|r| r.peer_id == admin_b_id) + } + }, + Duration::from_secs(10), + ) + .await; +} + +/// Test: credential node is rejected when its pubkey is not in any admin's trusted list. +/// Topology: Admin_A -- Unknown_B (random key, not in trusted list) +#[tokio::test] +async fn unknown_credential_node_rejected() { + let admin_a = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + + // Create a credential node with a random key (NOT generated by admin) + let random_private = x25519_dalek::StaticSecret::random_from_rng(rand::rngs::OsRng); + let unknown_c = create_mock_peer_manager_credential("net1".to_string(), &random_private).await; + + // Try to connect: C -> A (unknown credential as client, admin as server) + connect_peer_manager(unknown_c.clone(), admin_a.clone()).await; + + // The handshake should fail so the connection won't establish. + // Wait a bit and verify no route appears. + tokio::time::sleep(Duration::from_secs(3)).await; + + let routes = admin_a.list_routes().await; + assert!( + !routes.iter().any(|r| r.peer_id == unknown_c.my_peer_id()), + "unknown credential node should NOT appear in admin's routes" + ); +} + +/// Test: after revocation, the credential node disappears from routes. +/// Topology: Admin_A -- Credential_C, Admin_A -- Admin_B +/// After revocation on A, C should be removed from B's route table. +#[tokio::test] +async fn credential_revocation_removes_from_routes() { + let admin_a = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + let admin_b = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + + let (cred_id, cred_secret) = admin_a + .get_global_ctx() + .get_credential_manager() + .generate_credential(vec![], false, vec![], std::time::Duration::from_secs(3600)); + + let privkey_bytes: [u8; 32] = base64::engine::general_purpose::STANDARD + .decode(&cred_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + let cred_c = create_mock_peer_manager_credential("net1".to_string(), &private).await; + + // Connect: A -- B, C -> A (credential node as client, admin as server) + connect_peer_manager(admin_a.clone(), admin_b.clone()).await; + connect_peer_manager(cred_c.clone(), admin_a.clone()).await; + + // Wait for credential node to appear in admin_b's routes + let cred_c_id = cred_c.my_peer_id(); + wait_for_condition( + || { + let admin_b = admin_b.clone(); + async move { + admin_b + .list_routes() + .await + .iter() + .any(|r| r.peer_id == cred_c_id) + } + }, + Duration::from_secs(10), + ) + .await; + + // Now revoke the credential + assert!(admin_a + .get_global_ctx() + .get_credential_manager() + .revoke_credential(&cred_id)); + // Issue event to trigger OSPF sync + admin_a + .get_global_ctx() + .issue_event(crate::common::global_ctx::GlobalCtxEvent::CredentialChanged); + + // Wait for credential node to disappear from admin_b's routes + wait_for_condition( + || { + let admin_b = admin_b.clone(); + async move { + !admin_b + .list_routes() + .await + .iter() + .any(|r| r.peer_id == cred_c_id) + } + }, + Duration::from_secs(15), + ) + .await; +} + +#[tokio::test] +async fn credential_expiry_disconnects_from_all_admins() { + let admin_a = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + let admin_b = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + + connect_peer_manager(admin_a.clone(), admin_b.clone()).await; + wait_route_appear(admin_a.clone(), admin_b.clone()) + .await + .unwrap(); + + let (_cred_id, cred_secret) = admin_a + .get_global_ctx() + .get_credential_manager() + .generate_credential(vec![], false, vec![], std::time::Duration::from_secs(2)); + + admin_a + .get_global_ctx() + .issue_event(crate::common::global_ctx::GlobalCtxEvent::CredentialChanged); + + let privkey_bytes: [u8; 32] = base64::engine::general_purpose::STANDARD + .decode(&cred_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + let cred_c = create_mock_peer_manager_credential("net1".to_string(), &private).await; + let cred_c_id = cred_c.my_peer_id(); + + connect_peer_manager(cred_c.clone(), admin_a.clone()).await; + + wait_for_condition( + || { + let admin_b = admin_b.clone(); + async move { + admin_b + .list_routes() + .await + .iter() + .any(|r| r.peer_id == cred_c_id) + } + }, + Duration::from_secs(10), + ) + .await; + + connect_peer_manager(cred_c.clone(), admin_b.clone()).await; + + wait_for_condition( + || { + let admin_b = admin_b.clone(); + async move { + admin_b + .get_peer_map() + .list_peer_conns(cred_c_id) + .await + .is_some_and(|conns| !conns.is_empty()) + } + }, + Duration::from_secs(10), + ) + .await; + + tokio::time::sleep(Duration::from_secs(3)).await; + admin_a + .get_global_ctx() + .issue_event(crate::common::global_ctx::GlobalCtxEvent::CredentialChanged); + + wait_for_condition( + || { + let admin_b = admin_b.clone(); + async move { + !admin_b + .list_routes() + .await + .iter() + .any(|r| r.peer_id == cred_c_id) + } + }, + Duration::from_secs(20), + ) + .await; + + wait_for_condition( + || { + let admin_b = admin_b.clone(); + async move { + admin_b + .get_peer_map() + .list_peer_conns(cred_c_id) + .await + .is_none_or(|conns| conns.is_empty()) + } + }, + Duration::from_secs(20), + ) + .await; +} + +/// Test: admin node with credential — credential node gets group assignment. +/// Verify that the credential node's groups appear in the OSPF sync data. +#[tokio::test] +async fn credential_node_group_assignment() { + let admin_a = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + let admin_b = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + + let (_cred_id, cred_secret) = admin_a + .get_global_ctx() + .get_credential_manager() + .generate_credential( + vec!["guest".to_string(), "limited".to_string()], + false, + vec![], + std::time::Duration::from_secs(3600), + ); + + let privkey_bytes: [u8; 32] = base64::engine::general_purpose::STANDARD + .decode(&cred_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + let cred_c = create_mock_peer_manager_credential("net1".to_string(), &private).await; + + connect_peer_manager(admin_a.clone(), admin_b.clone()).await; + connect_peer_manager(cred_c.clone(), admin_a.clone()).await; + + // Wait for credential node route to appear on admin_b (via OSPF through admin_a) + let cred_c_id = cred_c.my_peer_id(); + wait_for_condition( + || { + let admin_b = admin_b.clone(); + async move { + admin_b + .list_routes() + .await + .iter() + .any(|r| r.peer_id == cred_c_id) + } + }, + Duration::from_secs(10), + ) + .await; + + // Verify the credential node's groups are assigned via OSPF on admin_b + // (admin_b gets the groups from admin_a's TrustedCredentialPubkey via OSPF sync) + wait_for_condition( + || { + let admin_b = admin_b.clone(); + async move { + let g = admin_b.get_route().get_peer_groups(cred_c_id); + g.contains(&"guest".to_string()) && g.contains(&"limited".to_string()) + } + }, + Duration::from_secs(10), + ) + .await; +} + +/// Minimal test: two secure peers connect and discover each other's route. +#[tokio::test] +async fn two_secure_peers_route_appear() { + let peer_a = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + let peer_b = create_mock_peer_manager_secure("net1".to_string(), "sec1".to_string()).await; + + connect_peer_manager(peer_a.clone(), peer_b.clone()).await; + + wait_route_appear(peer_a.clone(), peer_b.clone()) + .await + .unwrap(); +} + +#[tokio::test] +async fn multi_admin_multi_credential_route_and_revocation_isolation() { + let admin_a = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + let admin_b = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + let admin_d = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + + connect_peer_manager(admin_a.clone(), admin_b.clone()).await; + connect_peer_manager(admin_b.clone(), admin_d.clone()).await; + connect_peer_manager(admin_a.clone(), admin_d.clone()).await; + + wait_route_appear(admin_a.clone(), admin_b.clone()) + .await + .unwrap(); + wait_route_appear(admin_b.clone(), admin_d.clone()) + .await + .unwrap(); + wait_route_appear(admin_a.clone(), admin_d.clone()) + .await + .unwrap(); + + let (cred1_id, cred1_secret) = admin_a + .get_global_ctx() + .get_credential_manager() + .generate_credential( + vec!["guest-a".to_string()], + false, + vec![], + std::time::Duration::from_secs(3600), + ); + let (_cred2_id, cred2_secret) = admin_b + .get_global_ctx() + .get_credential_manager() + .generate_credential( + vec!["guest-b".to_string()], + false, + vec![], + std::time::Duration::from_secs(3600), + ); + + let cred1_private: [u8; 32] = base64::engine::general_purpose::STANDARD + .decode(&cred1_secret) + .unwrap() + .try_into() + .unwrap(); + let cred2_private: [u8; 32] = base64::engine::general_purpose::STANDARD + .decode(&cred2_secret) + .unwrap() + .try_into() + .unwrap(); + let cred_1 = create_mock_peer_manager_credential( + "net1".to_string(), + &x25519_dalek::StaticSecret::from(cred1_private), + ) + .await; + let cred_2 = create_mock_peer_manager_credential( + "net1".to_string(), + &x25519_dalek::StaticSecret::from(cred2_private), + ) + .await; + + connect_peer_manager(cred_1.clone(), admin_a.clone()).await; + connect_peer_manager(cred_2.clone(), admin_b.clone()).await; + + let cred_1_id = cred_1.my_peer_id(); + let cred_2_id = cred_2.my_peer_id(); + + wait_for_condition( + || { + let admin_d = admin_d.clone(); + async move { + let routes = admin_d.list_routes().await; + routes.iter().any(|r| r.peer_id == cred_1_id) + && routes.iter().any(|r| r.peer_id == cred_2_id) + } + }, + Duration::from_secs(15), + ) + .await; + + wait_for_condition( + || { + let admin_d = admin_d.clone(); + async move { + let g1 = admin_d.get_route().get_peer_groups(cred_1_id); + let g2 = admin_d.get_route().get_peer_groups(cred_2_id); + g1.contains(&"guest-a".to_string()) && g2.contains(&"guest-b".to_string()) + } + }, + Duration::from_secs(15), + ) + .await; + + assert!(admin_a + .get_global_ctx() + .get_credential_manager() + .revoke_credential(&cred1_id)); + admin_a + .get_global_ctx() + .issue_event(crate::common::global_ctx::GlobalCtxEvent::CredentialChanged); + + wait_for_condition( + || { + let admin_d = admin_d.clone(); + async move { + let routes = admin_d.list_routes().await; + !routes.iter().any(|r| r.peer_id == cred_1_id) + && routes.iter().any(|r| r.peer_id == cred_2_id) + } + }, + Duration::from_secs(20), + ) + .await; +} + +#[tokio::test] +async fn unknown_credential_rejected_while_valid_credential_survives() { + let admin_a = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + let admin_b = create_mock_peer_manager_secure("net1".to_string(), "secret".to_string()).await; + + connect_peer_manager(admin_a.clone(), admin_b.clone()).await; + wait_route_appear(admin_a.clone(), admin_b.clone()) + .await + .unwrap(); + + let (_cred_id, cred_secret) = admin_a + .get_global_ctx() + .get_credential_manager() + .generate_credential( + vec!["stable".to_string()], + false, + vec![], + std::time::Duration::from_secs(3600), + ); + + let valid_private: [u8; 32] = base64::engine::general_purpose::STANDARD + .decode(&cred_secret) + .unwrap() + .try_into() + .unwrap(); + let valid_cred = create_mock_peer_manager_credential( + "net1".to_string(), + &x25519_dalek::StaticSecret::from(valid_private), + ) + .await; + let unknown_private = x25519_dalek::StaticSecret::random_from_rng(rand::rngs::OsRng); + let unknown_cred = + create_mock_peer_manager_credential("net1".to_string(), &unknown_private).await; + + connect_peer_manager(valid_cred.clone(), admin_a.clone()).await; + let (unknown_ring_client, unknown_ring_server) = create_ring_tunnel_pair(); + let unknown_connect_client = tokio::spawn({ + let unknown_cred = unknown_cred.clone(); + async move { + unknown_cred + .add_client_tunnel(unknown_ring_client, false) + .await + } + }); + let unknown_connect_server = tokio::spawn({ + let admin_a = admin_a.clone(); + async move { + admin_a + .add_tunnel_as_server(unknown_ring_server, true) + .await + } + }); + let (unknown_client_ret, unknown_server_ret) = + tokio::join!(unknown_connect_client, unknown_connect_server); + assert!( + unknown_client_ret.unwrap().is_err() || unknown_server_ret.unwrap().is_err(), + "unknown credential connection should fail on at least one side" + ); + + let valid_id = valid_cred.my_peer_id(); + let unknown_id = unknown_cred.my_peer_id(); + + wait_for_condition( + || { + let admin_b = admin_b.clone(); + async move { + admin_b + .list_routes() + .await + .iter() + .any(|r| r.peer_id == valid_id) + } + }, + Duration::from_secs(15), + ) + .await; + + tokio::time::sleep(Duration::from_secs(5)).await; + + let routes = admin_b.list_routes().await; + assert!(routes.iter().any(|r| r.peer_id == valid_id)); + assert!(!routes.iter().any(|r| r.peer_id == unknown_id)); +} diff --git a/easytier/src/proto/api_instance.proto b/easytier/src/proto/api_instance.proto index e6947cc82..1230adb95 100644 --- a/easytier/src/proto/api_instance.proto +++ b/easytier/src/proto/api_instance.proto @@ -44,6 +44,7 @@ message PeerConnInfo { bytes noise_local_static_pubkey = 11; bytes noise_remote_static_pubkey = 12; peer_rpc.SecureAuthLevel secure_auth_level = 13; + peer_rpc.PeerIdentityType peer_identity_type = 14; } message PeerInfo { @@ -291,3 +292,45 @@ service StatsRpc { rpc GetPrometheusStats(GetPrometheusStatsRequest) returns (GetPrometheusStatsResponse); } + +// Credential management messages + +message GenerateCredentialRequest { + repeated string groups = 1; // optional: ACL groups for this credential + bool allow_relay = 2; // optional: allow relay through credential node + repeated string allowed_proxy_cidrs = 3; // optional: restrict proxy_cidrs + int64 ttl_seconds = 4; // must be > 0: credential TTL in seconds (0 / omitted is invalid) +} + +message GenerateCredentialResponse { + string credential_id = 1; // UUID + string credential_secret = 2; // private key base64 +} + +message RevokeCredentialRequest { + string credential_id = 1; +} + +message RevokeCredentialResponse { + bool success = 1; +} + +message ListCredentialsRequest {} + +message CredentialInfo { + string credential_id = 1; // UUID + repeated string groups = 2; + bool allow_relay = 3; + int64 expiry_unix = 4; + repeated string allowed_proxy_cidrs = 5; +} + +message ListCredentialsResponse { + repeated CredentialInfo credentials = 1; +} + +service CredentialManageRpc { + rpc GenerateCredential(GenerateCredentialRequest) returns (GenerateCredentialResponse); + rpc RevokeCredential(RevokeCredentialRequest) returns (RevokeCredentialResponse); + rpc ListCredentials(ListCredentialsRequest) returns (ListCredentialsResponse); +} diff --git a/easytier/src/proto/common.proto b/easytier/src/proto/common.proto index bb9f82b93..1ef107b21 100644 --- a/easytier/src/proto/common.proto +++ b/easytier/src/proto/common.proto @@ -216,6 +216,7 @@ message PeerFeatureFlag { bool support_conn_list_sync = 5; bool quic_input = 6; bool no_relay_quic = 7; + bool is_credential_peer = 8; } enum SocketType { diff --git a/easytier/src/proto/peer_rpc.proto b/easytier/src/proto/peer_rpc.proto index 77861f741..5ededd122 100644 --- a/easytier/src/proto/peer_rpc.proto +++ b/easytier/src/proto/peer_rpc.proto @@ -5,6 +5,19 @@ import "common.proto"; package peer_rpc; +message TrustedCredentialPubkey { + bytes pubkey = 1; // X25519 public key (32 bytes) + repeated string groups = 2; // ACL groups this credential belongs to + bool allow_relay = 3; // whether this credential node can relay data + int64 expiry_unix = 4; // expiry time (Unix timestamp) + repeated string allowed_proxy_cidrs = 5; // allowed proxy_cidrs ranges +} + +message TrustedCredentialPubkeyProof { + TrustedCredentialPubkey credential = 1; + bytes credential_hmac = 2; +} + message RoutePeerInfo { // means next hop in route table. uint32 peer_id = 1; @@ -29,6 +42,10 @@ message RoutePeerInfo { repeated PeerGroupInfo groups = 16; common.NatType tcp_nat_type = 17; + bytes noise_static_pubkey = 18; + + // Trusted credential public keys published by admin nodes (holding network_secret) + repeated TrustedCredentialPubkeyProof trusted_credential_pubkeys = 19; } message PeerIdVersion { @@ -262,10 +279,16 @@ message KcpConnData { enum SecureAuthLevel { None = 0; EncryptedUnauthenticated = 1; - SharedNodePubkeyVerified = 2; + PeerVerified = 2; NetworkSecretConfirmed = 3; } +enum PeerIdentityType { + Admin = 0; + Credential = 1; + SharedNode = 2; +} + enum PeerConnSessionActionPb { Join = 0; Sync = 1; @@ -293,6 +316,23 @@ message PeerConnNoiseMsg2Pb { string server_encryption_algorithm = 10; } +message RelayNoiseMsg1Pb { + uint32 version = 1; + optional uint32 a_session_generation = 3; + common.UUID a_conn_id = 4; + string client_encryption_algorithm = 5; +} + +message RelayNoiseMsg2Pb { + PeerConnSessionActionPb action = 3; + uint32 b_session_generation = 4; + optional bytes root_key_32 = 5; + uint32 initial_epoch = 6; + common.UUID b_conn_id = 7; + common.UUID a_conn_id_echo = 8; + string server_encryption_algorithm = 10; +} + message PeerConnNoiseMsg3Pb { common.UUID a_conn_id_echo = 1; common.UUID b_conn_id_echo = 2; diff --git a/easytier/src/proto/peer_rpc.rs b/easytier/src/proto/peer_rpc.rs index f6a5fded2..f60a72a44 100644 --- a/easytier/src/proto/peer_rpc.rs +++ b/easytier/src/proto/peer_rpc.rs @@ -1,4 +1,5 @@ use hmac::{Hmac, Mac}; +use prost::Message; use sha2::Sha256; use crate::common::PeerId; @@ -38,6 +39,42 @@ impl PeerGroupInfo { } } +impl TrustedCredentialPubkeyProof { + pub fn generate_credential_hmac( + credential: &TrustedCredentialPubkey, + network_secret: &str, + ) -> Vec { + let mut mac = Hmac::::new_from_slice(network_secret.as_bytes()) + .expect("HMAC can take key of any size"); + mac.update(b"easytier credential proof"); + mac.update(&credential.encode_to_vec()); + mac.finalize().into_bytes().to_vec() + } + + pub fn new_signed(credential: TrustedCredentialPubkey, network_secret: &str) -> Self { + let credential_hmac = Self::generate_credential_hmac(&credential, network_secret); + Self { + credential: Some(credential), + credential_hmac, + } + } + + pub fn verify_credential_hmac(&self, network_secret: &str) -> bool { + let Some(credential) = self.credential.as_ref() else { + return false; + }; + if self.credential_hmac.is_empty() { + return false; + } + + let mut mac = Hmac::::new_from_slice(network_secret.as_bytes()) + .expect("HMAC can take key of any size"); + mac.update(b"easytier credential proof"); + mac.update(&credential.encode_to_vec()); + mac.verify_slice(&self.credential_hmac).is_ok() + } +} + impl From for sync_route_info_request::ConnInfo { fn from(val: RouteConnBitmap) -> Self { Self::ConnBitmap(val) @@ -254,4 +291,35 @@ mod tests { println!("verify took {:?} for {} iterations", duration, iterations); println!("Avg time per iteration: {:?}", duration / iterations as u32); } + + #[test] + fn test_trusted_credential_pubkey_hmac_valid() { + let credential = TrustedCredentialPubkey { + pubkey: vec![7u8; 32], + groups: vec!["ops".to_string(), "guest".to_string()], + allow_relay: true, + expiry_unix: 123456, + allowed_proxy_cidrs: vec!["10.0.0.0/24".to_string()], + }; + let tc = TrustedCredentialPubkeyProof::new_signed(credential, "sec-1"); + + assert!(tc.verify_credential_hmac("sec-1")); + assert!(!tc.verify_credential_hmac("sec-2")); + } + + #[test] + fn test_trusted_credential_pubkey_hmac_tampered() { + let credential = TrustedCredentialPubkey { + pubkey: vec![8u8; 32], + groups: vec!["g1".to_string()], + allow_relay: false, + expiry_unix: 1, + allowed_proxy_cidrs: vec![], + }; + let tc = TrustedCredentialPubkeyProof::new_signed(credential, "sec-1"); + + let mut tampered = tc.clone(); + tampered.credential.as_mut().unwrap().allow_relay = true; + assert!(!tampered.verify_credential_hmac("sec-1")); + } } diff --git a/easytier/src/proto/web.proto b/easytier/src/proto/web.proto index 68429bc9a..0b283254e 100644 --- a/easytier/src/proto/web.proto +++ b/easytier/src/proto/web.proto @@ -18,6 +18,13 @@ message HeartbeatRequest { message HeartbeatResponse {} +message GetFeatureRequest {} + +message GetFeatureResponse { + bool support_encryption = 1; +} + service WebServerService { rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse); -} \ No newline at end of file + rpc GetFeature(GetFeatureRequest) returns (GetFeatureResponse); +} diff --git a/easytier/src/rpc_service/api.rs b/easytier/src/rpc_service/api.rs index 37051e4df..23d2149ec 100644 --- a/easytier/src/rpc_service/api.rs +++ b/easytier/src/rpc_service/api.rs @@ -10,9 +10,9 @@ use crate::{ api::{ config::ConfigRpcServer, instance::{ - AclManageRpcServer, ConnectorManageRpcServer, MappedListenerManageRpcServer, - PeerManageRpcServer, PortForwardManageRpcServer, StatsRpcServer, TcpProxyRpcServer, - VpnPortalRpcServer, + AclManageRpcServer, ConnectorManageRpcServer, CredentialManageRpcServer, + MappedListenerManageRpcServer, PeerManageRpcServer, PortForwardManageRpcServer, + StatsRpcServer, TcpProxyRpcServer, VpnPortalRpcServer, }, logger::LoggerRpcServer, manage::WebClientServiceServer, @@ -23,8 +23,9 @@ use crate::{ }, rpc_service::{ acl_manage::AclManageRpcService, config::ConfigRpcService, - connector_manage::ConnectorManageRpcService, instance_manage::InstanceManageRpcService, - logger::LoggerRpcService, mapped_listener_manage::MappedListenerManageRpcService, + connector_manage::ConnectorManageRpcService, credential_manage::CredentialManageRpcService, + instance_manage::InstanceManageRpcService, logger::LoggerRpcService, + mapped_listener_manage::MappedListenerManageRpcService, peer_center::PeerCenterManageRpcService, peer_manage::PeerManageRpcService, port_forward_manage::PortForwardManageRpcService, proxy::TcpProxyRpcService, stats::StatsRpcService, vpn_portal::VpnPortalRpcService, @@ -156,6 +157,11 @@ fn register_api_rpc_service( PeerCenterRpcServer::new(PeerCenterManageRpcService::new(instance_manager.clone())), "", ); + + registry.register( + CredentialManageRpcServer::new(CredentialManageRpcService::new(instance_manager.clone())), + "", + ); } fn parse_rpc_portal(rpc_portal: Option) -> anyhow::Result { diff --git a/easytier/src/rpc_service/credential_manage.rs b/easytier/src/rpc_service/credential_manage.rs new file mode 100644 index 000000000..5b13d0bc3 --- /dev/null +++ b/easytier/src/rpc_service/credential_manage.rs @@ -0,0 +1,62 @@ +use std::sync::Arc; + +use crate::{ + instance_manager::NetworkInstanceManager, + proto::{ + api::instance::{ + CredentialManageRpc, GenerateCredentialRequest, GenerateCredentialResponse, + ListCredentialsRequest, ListCredentialsResponse, RevokeCredentialRequest, + RevokeCredentialResponse, + }, + rpc_types::controller::BaseController, + }, +}; + +#[derive(Clone)] +pub struct CredentialManageRpcService { + instance_manager: Arc, +} + +impl CredentialManageRpcService { + pub fn new(instance_manager: Arc) -> Self { + Self { instance_manager } + } +} + +#[async_trait::async_trait] +impl CredentialManageRpc for CredentialManageRpcService { + type Controller = BaseController; + + async fn generate_credential( + &self, + ctrl: Self::Controller, + req: GenerateCredentialRequest, + ) -> crate::proto::rpc_types::error::Result { + super::get_instance_service(&self.instance_manager, &None)? + .get_credential_manage_service() + .generate_credential(ctrl, req) + .await + } + + async fn revoke_credential( + &self, + ctrl: Self::Controller, + req: RevokeCredentialRequest, + ) -> crate::proto::rpc_types::error::Result { + super::get_instance_service(&self.instance_manager, &None)? + .get_credential_manage_service() + .revoke_credential(ctrl, req) + .await + } + + async fn list_credentials( + &self, + ctrl: Self::Controller, + req: ListCredentialsRequest, + ) -> crate::proto::rpc_types::error::Result { + super::get_instance_service(&self.instance_manager, &None)? + .get_credential_manage_service() + .list_credentials(ctrl, req) + .await + } +} diff --git a/easytier/src/rpc_service/mod.rs b/easytier/src/rpc_service/mod.rs index cddec6afd..e06d05abd 100644 --- a/easytier/src/rpc_service/mod.rs +++ b/easytier/src/rpc_service/mod.rs @@ -2,6 +2,7 @@ mod acl_manage; mod api; mod config; mod connector_manage; +mod credential_manage; mod mapped_listener_manage; mod peer_center; mod peer_manage; @@ -76,6 +77,11 @@ pub trait InstanceRpcService: Sync + Send { > + Send + Sync, >; + fn get_credential_manage_service( + &self, + ) -> &dyn crate::proto::api::instance::CredentialManageRpc< + Controller = crate::proto::rpc_types::controller::BaseController, + >; } fn get_instance_service( diff --git a/easytier/src/tests/credential_tests.rs b/easytier/src/tests/credential_tests.rs new file mode 100644 index 000000000..650daedc7 --- /dev/null +++ b/easytier/src/tests/credential_tests.rs @@ -0,0 +1,931 @@ +//! Credential system integration tests +//! +//! These tests verify the credential-based authentication system where: +//! - Admin nodes hold network_secret and can generate credentials +//! - Credential nodes use X25519 keypairs to authenticate without network_secret +//! - Credentials can be revoked and propagate across the network + +use std::time::Duration; + +use crate::{ + common::{ + config::{ConfigLoader, NetworkIdentity, TomlConfigLoader}, + global_ctx::GlobalCtxEvent, + }, + instance::instance::Instance, + tests::three_node::{generate_secure_mode_config, generate_secure_mode_config_with_key}, + tunnel::{common::tests::wait_for_condition, tcp::TcpTunnelConnector}, +}; + +use super::{add_ns_to_bridge, create_netns, del_netns, drop_insts, ping_test}; + +use rstest::rstest; + +/// Prepare network namespaces for credential tests +/// Topology: +/// br_a (10.1.1.0/24): ns_adm (10.1.1.1), ns_c1 (10.1.1.2), ns_c2 (10.1.1.3), ns_c3 (10.1.1.4) +/// br_b (10.1.2.0/24): ns_adm2 (10.1.2.1) - for multi-admin tests +/// Note: Using short names (max 15 chars for veth interfaces) +pub fn prepare_credential_network() { + // Clean up any existing namespaces + for ns in ["ns_adm", "ns_c1", "ns_c2", "ns_c3", "ns_adm2"] { + del_netns(ns); + } + + // Create bridge br_a for admin and credentials + let _ = std::process::Command::new("ip") + .args(["link", "del", "br_a"]) + .output(); + let _ = std::process::Command::new("brctl") + .args(["delbr", "br_a"]) + .output(); + let _ = std::process::Command::new("brctl") + .args(["addbr", "br_a"]) + .output() + .expect("Failed to create br_a"); + let _ = std::process::Command::new("ip") + .args(["link", "set", "br_a", "up"]) + .output(); + + // Create namespaces and add to bridge + create_netns("ns_adm", "10.1.1.1/24", "fd11::1/64"); + add_ns_to_bridge("br_a", "ns_adm"); + + create_netns("ns_c1", "10.1.1.2/24", "fd11::2/64"); + add_ns_to_bridge("br_a", "ns_c1"); + + create_netns("ns_c2", "10.1.1.3/24", "fd11::3/64"); + add_ns_to_bridge("br_a", "ns_c2"); + + // Create ns_c3 for relay tests (needs 4 nodes) + create_netns("ns_c3", "10.1.1.4/24", "fd11::4/64"); + add_ns_to_bridge("br_a", "ns_c3"); + + // Create bridge br_b for second admin (multi-admin tests) + let _ = std::process::Command::new("ip") + .args(["link", "del", "br_b"]) + .output(); + let _ = std::process::Command::new("brctl") + .args(["delbr", "br_b"]) + .output(); + let _ = std::process::Command::new("brctl") + .args(["addbr", "br_b"]) + .output() + .expect("Failed to create br_b"); + let _ = std::process::Command::new("ip") + .args(["link", "set", "br_b", "up"]) + .output(); + + create_netns("ns_adm2", "10.1.2.1/24", "fd12::1/64"); + add_ns_to_bridge("br_b", "ns_adm2"); +} + +/// Helper: Create credential node config with generated credential +async fn create_credential_config( + admin_inst: &Instance, + inst_name: &str, + ns: Option<&str>, + ipv4: &str, + ipv6: &str, +) -> TomlConfigLoader { + use base64::Engine as _; + + // Generate credential on admin + let (_cred_id, cred_secret) = admin_inst + .get_global_ctx() + .get_credential_manager() + .generate_credential(vec![], false, vec![], Duration::from_secs(3600)); + + // Decode private key + let privkey_bytes: [u8; 32] = base64::prelude::BASE64_STANDARD + .decode(&cred_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + + // Create config + let config = TomlConfigLoader::default(); + config.set_inst_name(inst_name.to_owned()); + config.set_netns(ns.map(|s| s.to_owned())); + config.set_ipv4(Some(ipv4.parse().unwrap())); + config.set_ipv6(Some(ipv6.parse().unwrap())); + config.set_listeners(vec![]); + config.set_network_identity(NetworkIdentity::new_credential( + admin_inst + .get_global_ctx() + .get_network_identity() + .network_name + .clone(), + )); + config.set_secure_mode(Some(generate_secure_mode_config_with_key(&private))); + + config +} + +/// Helper: Create admin node config +fn create_admin_config( + inst_name: &str, + ns: Option<&str>, + ipv4: &str, + ipv6: &str, +) -> TomlConfigLoader { + let config = TomlConfigLoader::default(); + config.set_inst_name(inst_name.to_owned()); + config.set_netns(ns.map(|s| s.to_owned())); + config.set_ipv4(Some(ipv4.parse().unwrap())); + config.set_ipv6(Some(ipv6.parse().unwrap())); + config.set_listeners(vec![ + "tcp://0.0.0.0:11010".parse().unwrap(), + "udp://0.0.0.0:11010".parse().unwrap(), + ]); + config.set_network_identity(NetworkIdentity::new( + "test_network".to_string(), + "test_secret".to_string(), + )); + config.set_secure_mode(Some(generate_secure_mode_config())); + + config +} + +fn create_shared_config( + inst_name: &str, + ns: Option<&str>, + ipv4: &str, + ipv6: &str, +) -> TomlConfigLoader { + let config = TomlConfigLoader::default(); + config.set_inst_name(inst_name.to_owned()); + config.set_netns(ns.map(|s| s.to_owned())); + config.set_ipv4(Some(ipv4.parse().unwrap())); + config.set_ipv6(Some(ipv6.parse().unwrap())); + config.set_listeners(vec![ + "tcp://0.0.0.0:11010".parse().unwrap(), + "udp://0.0.0.0:11010".parse().unwrap(), + ]); + config.set_network_identity(NetworkIdentity::new( + "shared_network".to_string(), + "".to_string(), + )); + config.set_secure_mode(Some(generate_secure_mode_config())); + config +} + +/// Test 1: Basic credential node connectivity +/// Topology: Admin ← Credential +/// Verifies that a credential node can connect to an admin node and appears in routes +#[tokio::test] +#[serial_test::serial] +async fn credential_basic_connectivity() { + prepare_credential_network(); + + // Create admin node + let admin_config = create_admin_config("admin", Some("ns_adm"), "10.144.144.1", "fd00::1/64"); + let mut admin_inst = Instance::new(admin_config); + admin_inst.run().await.unwrap(); + + // Create credential node + let cred_config = create_credential_config( + &admin_inst, + "cred", + Some("ns_c1"), + "10.144.144.2", + "fd00::2/64", + ) + .await; + let mut cred_inst = Instance::new(cred_config); + cred_inst.run().await.unwrap(); + + // Credential connects to admin + cred_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.1:11010".parse().unwrap(), + )); + + let cred_peer_id = cred_inst.peer_id(); + let admin_peer_id = admin_inst.peer_id(); + println!( + "Admin peer_id: {}, Credential peer_id: {}", + admin_peer_id, cred_peer_id + ); + + // Wait a bit for connection attempt + tokio::time::sleep(Duration::from_secs(2)).await; + + // Check peers and connections + let admin_peers = admin_inst.get_peer_manager().get_peer_map().list_peers(); + let cred_peers = cred_inst.get_peer_manager().get_peer_map().list_peers(); + println!("Admin peers: {:?}", admin_peers); + println!("Credential peers: {:?}", cred_peers); + + // Wait for credential to appear in admin's route table + wait_for_condition( + || async { + let routes = admin_inst.get_peer_manager().list_routes().await; + let cred_routes = cred_inst.get_peer_manager().list_routes().await; + let admin_peers = admin_inst.get_peer_manager().get_peer_map().list_peers(); + let cred_peers = cred_inst.get_peer_manager().get_peer_map().list_peers(); + println!( + "Admin peers: {:?}, routes: {:?}", + admin_peers, + routes + .iter() + .map(|r| (r.peer_id, r.ipv4_addr)) + .collect::>() + ); + println!( + "Credential peers: {:?}, routes: {:?}", + cred_peers, + cred_routes + .iter() + .map(|r| (r.peer_id, r.ipv4_addr)) + .collect::>() + ); + routes.iter().any(|r| r.peer_id == cred_peer_id) + }, + Duration::from_secs(10), + ) + .await; + + // Verify connectivity + wait_for_condition( + || async { ping_test("ns_adm", "10.144.144.2", None).await }, + Duration::from_secs(10), + ) + .await; + + wait_for_condition( + || async { ping_test("ns_c1", "10.144.144.1", None).await }, + Duration::from_secs(10), + ) + .await; + + drop_insts(vec![admin_inst, cred_inst]).await; +} + +/// Test 5-6: Credential relay capability with allow_relay parameter +/// Topology: Admin ← Credential_A, Admin ← Credential_B, Admin ← Credential_C(listener, allow_relay) +/// Verifies routing behavior based on allow_relay flag: +/// - allow_relay=true: A→B route goes through C (cost 2 via C) +/// - allow_relay=false: A→B route goes through Admin (cost 2 via Admin) +#[rstest] +#[case(true)] +#[case(false)] +#[tokio::test] +#[serial_test::serial] +async fn credential_relay_capability(#[case] allow_relay: bool) { + use crate::peers::route_trait::NextHopPolicy; + + prepare_credential_network(); + + // Create admin node + let admin_config = create_admin_config("admin", Some("ns_adm"), "10.144.144.1", "fd00::1/64"); + let mut admin_inst = Instance::new(admin_config); + let mut ff = admin_inst.get_global_ctx().get_feature_flags(); + // if cred c allow relay, we set admin inst avoid relay (if other same-cost path available, admin will not relay data) + ff.avoid_relay_data = allow_relay; + admin_inst.get_global_ctx().set_feature_flags(ff); + admin_inst.run().await.unwrap(); + + let admin_peer_id = admin_inst.peer_id(); + + // Generate credentials for A, B, C + // C has configurable allow_relay + let (_cred_a_id, cred_a_secret) = admin_inst + .get_global_ctx() + .get_credential_manager() + .generate_credential(vec![], false, vec![], Duration::from_secs(3600)); + + let (_cred_b_id, cred_b_secret) = admin_inst + .get_global_ctx() + .get_credential_manager() + .generate_credential(vec![], false, vec![], Duration::from_secs(3600)); + + let (_cred_c_id, cred_c_secret) = admin_inst + .get_global_ctx() + .get_credential_manager() + .generate_credential(vec![], allow_relay, vec![], Duration::from_secs(3600)); + + // Create credential A on ns_c1 + let cred_a_config = { + use base64::Engine as _; + let privkey_bytes: [u8; 32] = base64::prelude::BASE64_STANDARD + .decode(&cred_a_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + let config = TomlConfigLoader::default(); + config.set_inst_name("cred_a".to_string()); + config.set_netns(Some("ns_c1".to_string())); + config.set_ipv4(Some("10.144.144.2".parse().unwrap())); + config.set_ipv6(Some("fd00::2/64".parse().unwrap())); + config.set_network_identity(NetworkIdentity::new_credential( + admin_inst + .get_global_ctx() + .get_network_identity() + .network_name + .clone(), + )); + config.set_secure_mode(Some(generate_secure_mode_config_with_key(&private))); + config + }; + let mut cred_a_inst = Instance::new(cred_a_config); + cred_a_inst.run().await.unwrap(); + + // Create credential B on ns_c2 + let cred_b_config = { + use base64::Engine as _; + let privkey_bytes: [u8; 32] = base64::prelude::BASE64_STANDARD + .decode(&cred_b_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + let config = TomlConfigLoader::default(); + config.set_inst_name("cred_b".to_string()); + config.set_netns(Some("ns_c2".to_string())); + config.set_ipv4(Some("10.144.144.3".parse().unwrap())); + config.set_ipv6(Some("fd00::3/64".parse().unwrap())); + config.set_network_identity(NetworkIdentity::new_credential( + admin_inst + .get_global_ctx() + .get_network_identity() + .network_name + .clone(), + )); + config.set_secure_mode(Some(generate_secure_mode_config_with_key(&private))); + config + }; + let mut cred_b_inst = Instance::new(cred_b_config); + cred_b_inst.run().await.unwrap(); + + // Create credential C on ns_c3 WITH listener (so A and B can connect to it) + let cred_c_config = { + use base64::Engine as _; + let privkey_bytes: [u8; 32] = base64::prelude::BASE64_STANDARD + .decode(&cred_c_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + let config = TomlConfigLoader::default(); + config.set_inst_name("cred_c".to_string()); + config.set_netns(Some("ns_c3".to_string())); + config.set_ipv4(Some("10.144.144.4".parse().unwrap())); + config.set_ipv6(Some("fd00::4/64".parse().unwrap())); + // C has listener so A and B can connect to it + config.set_listeners(vec!["tcp://0.0.0.0:11020".parse().unwrap()]); + config.set_network_identity(NetworkIdentity::new_credential( + admin_inst + .get_global_ctx() + .get_network_identity() + .network_name + .clone(), + )); + config.set_secure_mode(Some(generate_secure_mode_config_with_key(&private))); + config + }; + let mut cred_c_inst = Instance::new(cred_c_config); + cred_c_inst.run().await.unwrap(); + + let cred_a_peer_id = cred_a_inst.peer_id(); + let cred_b_peer_id = cred_b_inst.peer_id(); + let cred_c_peer_id = cred_c_inst.peer_id(); + + // All credentials connect to admin + cred_a_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.1:11010".parse().unwrap(), + )); + cred_b_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.1:11010".parse().unwrap(), + )); + cred_c_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.1:11010".parse().unwrap(), + )); + + // A and B also connect to C (simulating P2P discovery and connection) + // C is on ns_c3 with IP 10.1.1.4, listener on port 11020 + cred_a_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.4:11020".parse().unwrap(), + )); + cred_b_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.4:11020".parse().unwrap(), + )); + // print all peer ids + println!("Admin peer id: {:?}", admin_peer_id); + println!("Cred A peer id: {:?}", cred_a_peer_id); + println!("Cred B peer id: {:?}", cred_b_peer_id); + println!("Cred C peer id: {:?}", cred_c_peer_id); + + // Wait for all nodes to appear in admin's route table + wait_for_condition( + || async { + let routes = admin_inst.get_peer_manager().list_routes().await; + let has_a = routes.iter().any(|r| r.peer_id == cred_a_peer_id); + let has_b = routes.iter().any(|r| r.peer_id == cred_b_peer_id); + let has_c = routes.iter().any(|r| r.peer_id == cred_c_peer_id); + println!("Admin routes: a={}, b={}, c={}", has_a, has_b, has_c); + has_a && has_b && has_c + }, + Duration::from_secs(30), + ) + .await; + + // Wait for P2P connections to establish + wait_for_condition( + || async { + let peers_a = cred_a_inst.get_peer_manager().get_peer_map().list_peers(); + let peers_b = cred_b_inst.get_peer_manager().get_peer_map().list_peers(); + let peers_c = cred_c_inst.get_peer_manager().get_peer_map().list_peers(); + + let a_connected_c = peers_a.contains(&cred_c_peer_id); + let b_connected_c = peers_b.contains(&cred_c_peer_id); + let c_connected_a = peers_c.contains(&cred_a_peer_id); + let c_connected_b = peers_c.contains(&cred_b_peer_id); + + println!( + "P2P: A->C={}, B->C={}, C->A={}, C->B={}, allow_relay={}", + a_connected_c, b_connected_c, c_connected_a, c_connected_b, allow_relay + ); + + if allow_relay { + a_connected_c && b_connected_c && c_connected_a && c_connected_b + } else { + a_connected_c && b_connected_c + } + }, + Duration::from_secs(30), + ) + .await; + + // Wait for routes to propagate + wait_for_condition( + || async { + let routes_a = cred_a_inst.get_peer_manager().list_routes().await; + let a_sees_b = routes_a.iter().any(|r| r.peer_id == cred_b_peer_id); + let cost_a_to_b = routes_a + .iter() + .find(|r| r.peer_id == cred_b_peer_id) + .map(|r| r.cost); + + println!("Routes: a_sees_b={} (cost={:?})", a_sees_b, cost_a_to_b); + a_sees_b + }, + Duration::from_secs(15), + ) + .await; + + wait_for_condition( + || async { + let next_hop_a_to_b = cred_a_inst + .get_peer_manager() + .get_route() + .get_next_hop_with_policy(cred_b_peer_id, NextHopPolicy::LeastCost) + .await; + println!( + "Next hop convergence A->B={:?} (admin={}, c={}), allow_relay={}", + next_hop_a_to_b, admin_peer_id, cred_c_peer_id, allow_relay + ); + if allow_relay { + next_hop_a_to_b == Some(cred_c_peer_id) + } else { + next_hop_a_to_b == Some(admin_peer_id) + } + }, + Duration::from_secs(20), + ) + .await; + + // wait 5s, make sure the routes are stable + tokio::time::sleep(Duration::from_secs(5)).await; + + // Verify next hop from A to B based on allow_relay flag + let next_hop_a_to_b = cred_a_inst + .get_peer_manager() + .get_route() + .get_next_hop_with_policy(cred_b_peer_id, NextHopPolicy::LeastCost) + .await; + + println!( + "Next hop A->B={:?} (admin={}, c={}), allow_relay={}", + next_hop_a_to_b, admin_peer_id, cred_c_peer_id, allow_relay + ); + + // When C has allow_relay=false, route should go through Admin + // When C has allow_relay=true, route may go through C or Admin depending on routing algorithm + if !allow_relay { + assert_eq!( + next_hop_a_to_b, + Some(admin_peer_id), + "Route from A to B should go through admin when allow_relay=false" + ); + } else { + assert_eq!( + next_hop_a_to_b, + Some(cred_c_peer_id), + "Route from A to B should go through C when allow_relay=true" + ); + } + + // Cleanup + drop_insts(vec![admin_inst, cred_a_inst, cred_b_inst, cred_c_inst]).await; +} + +/// Test 2: Two credential nodes connect to same admin +/// Topology: Admin ← Credential_A, Admin ← Credential_B +/// Verifies that multiple credential nodes can connect to the same admin +#[tokio::test] +#[serial_test::serial] +async fn credential_two_credentials_communicate_tcp() { + prepare_credential_network(); + + // Create admin node + let admin_config = create_admin_config("admin", Some("ns_adm"), "10.144.144.1", "fd00::1/64"); + let mut admin_inst = Instance::new(admin_config); + admin_inst.run().await.unwrap(); + + // Create credential1 on ns_c1 + let cred1_config = create_credential_config( + &admin_inst, + "cred1", + Some("ns_c1"), + "10.144.144.2", + "fd00::2/64", + ) + .await; + let mut cred1_inst = Instance::new(cred1_config); + cred1_inst.run().await.unwrap(); + + // Create credential2 on ns_c2 + let cred2_config = create_credential_config( + &admin_inst, + "cred2", + Some("ns_c2"), + "10.144.144.3", + "fd00::3/64", + ) + .await; + let mut cred2_inst = Instance::new(cred2_config); + cred2_inst.run().await.unwrap(); + + // Both credentials connect to admin + cred1_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.1:11010".parse().unwrap(), + )); + cred2_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.1:11010".parse().unwrap(), + )); + + let cred1_peer_id = cred1_inst.peer_id(); + let cred2_peer_id = cred2_inst.peer_id(); + + // Wait for both credentials to appear in admin's route table + wait_for_condition( + || async { + let routes = admin_inst.get_peer_manager().list_routes().await; + routes.iter().any(|r| r.peer_id == cred1_peer_id) + && routes.iter().any(|r| r.peer_id == cred2_peer_id) + }, + Duration::from_secs(10), + ) + .await; + + // Verify admin can ping both credentials + wait_for_condition( + || async { ping_test("ns_adm", "10.144.144.2", None).await }, + Duration::from_secs(10), + ) + .await; + + wait_for_condition( + || async { ping_test("ns_adm", "10.144.144.3", None).await }, + Duration::from_secs(10), + ) + .await; + + drop_insts(vec![admin_inst, cred1_inst, cred2_inst]).await; +} + +/// Test 3: Credential revocation removes credential from route table +/// Topology: Admin ← Credential +/// Verifies that when credential is revoked, it's removed from admin's route table +#[tokio::test] +#[serial_test::serial] +async fn credential_revocation_propagates() { + prepare_credential_network(); + + // Create admin on ns_adm (10.1.1.1) + let admin_config = create_admin_config("admin", Some("ns_adm"), "10.144.144.1", "fd00::1/64"); + let mut admin_inst = Instance::new(admin_config); + admin_inst.run().await.unwrap(); + + // Generate credential on admin + let (cred_id, cred_secret) = admin_inst + .get_global_ctx() + .get_credential_manager() + .generate_credential(vec![], false, vec![], Duration::from_secs(3600)); + + // Create credential node + let cred_config = { + use base64::Engine as _; + let privkey_bytes: [u8; 32] = base64::prelude::BASE64_STANDARD + .decode(&cred_secret) + .unwrap() + .try_into() + .unwrap(); + let private = x25519_dalek::StaticSecret::from(privkey_bytes); + + let config = TomlConfigLoader::default(); + config.set_inst_name("cred".to_string()); + config.set_netns(Some("ns_c1".to_string())); + config.set_ipv4(Some("10.144.144.2".parse().unwrap())); + config.set_ipv6(Some("fd00::2/64".parse().unwrap())); + config.set_listeners(vec![]); + config.set_network_identity(NetworkIdentity::new_credential( + admin_inst + .get_global_ctx() + .get_network_identity() + .network_name + .clone(), + )); + config.set_secure_mode(Some(generate_secure_mode_config_with_key(&private))); + config + }; + + let mut cred_inst = Instance::new(cred_config); + cred_inst.run().await.unwrap(); + + // Credential connects to admin + cred_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.1:11010".parse().unwrap(), + )); + + let cred_peer_id = cred_inst.peer_id(); + + // Wait for credential to appear in admin's route table + wait_for_condition( + || async { + admin_inst + .get_peer_manager() + .list_routes() + .await + .iter() + .any(|r| r.peer_id == cred_peer_id) + }, + Duration::from_secs(10), + ) + .await; + + // Verify connectivity before revocation + wait_for_condition( + || async { ping_test("ns_adm", "10.144.144.2", None).await }, + Duration::from_secs(10), + ) + .await; + + // Revoke the credential + assert!( + admin_inst + .get_global_ctx() + .get_credential_manager() + .revoke_credential(&cred_id), + "Credential should be revoked successfully" + ); + + // Trigger OSPF sync + admin_inst + .get_global_ctx() + .issue_event(GlobalCtxEvent::CredentialChanged); + + // Wait for credential to disappear from admin's route table + wait_for_condition( + || async { + !admin_inst + .get_peer_manager() + .list_routes() + .await + .iter() + .any(|r| r.peer_id == cred_peer_id) + }, + Duration::from_secs(15), + ) + .await; + + wait_for_condition( + || async { !ping_test("ns_adm", "10.144.144.2", None).await }, + Duration::from_secs(10), + ) + .await; + + wait_for_condition( + || async { !ping_test("ns_c1", "10.144.144.1", None).await }, + Duration::from_secs(10), + ) + .await; + + drop_insts(vec![admin_inst, cred_inst]).await; +} + +/// Test 4: Unknown credential (not in trusted list) is rejected +/// Topology: Admin +/// Verifies that credential nodes with unknown/random keys cannot connect +#[tokio::test] +#[serial_test::serial] +async fn credential_unknown_rejected() { + prepare_credential_network(); + + // Create admin node + let admin_config = create_admin_config("admin", Some("ns_adm"), "10.144.144.1", "fd00::1/64"); + let mut admin_inst = Instance::new(admin_config); + admin_inst.run().await.unwrap(); + + // Create credential node with random key (not generated by admin) + let random_private = x25519_dalek::StaticSecret::random_from_rng(rand::rngs::OsRng); + let cred_config = { + let config = TomlConfigLoader::default(); + config.set_inst_name("cred".to_string()); + config.set_netns(Some("ns_c1".to_string())); + config.set_ipv4(Some("10.144.144.2".parse().unwrap())); + config.set_ipv6(Some("fd00::2/64".parse().unwrap())); + config.set_listeners(vec![]); + config.set_network_identity(NetworkIdentity::new_credential( + admin_inst + .get_global_ctx() + .get_network_identity() + .network_name + .clone(), + )); + config.set_secure_mode(Some(generate_secure_mode_config_with_key(&random_private))); + config + }; + + let mut cred_inst = Instance::new(cred_config); + cred_inst.run().await.unwrap(); + + // Attempt to connect to admin + cred_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.1:11010".parse().unwrap(), + )); + + let cred_peer_id = cred_inst.peer_id(); + + // Wait a bit for connection attempt + tokio::time::sleep(Duration::from_secs(5)).await; + + // Verify credential does NOT appear in admin's route table + let routes = admin_inst.get_peer_manager().list_routes().await; + assert!( + !routes.iter().any(|r| r.peer_id == cred_peer_id), + "Unknown credential node should NOT appear in admin's route table" + ); + + // Verify no connectivity + let ping_result = ping_test("ns_adm", "10.144.144.2", None).await; + assert!( + !ping_result, + "Should NOT be able to ping unknown credential node" + ); + + drop_insts(vec![admin_inst, cred_inst]).await; +} + +#[rstest::rstest] +#[tokio::test] +#[serial_test::serial] +async fn credential_admin_shared_admin_credential_connectivity( + #[values(true, false)] connect_to_admin: bool, +) { + prepare_credential_network(); + + // 10.1.1.1 + let admin_a_config = + create_admin_config("admin_a", Some("ns_adm"), "10.144.144.1", "fd00::1/64"); + let mut admin_a_inst = Instance::new(admin_a_config); + admin_a_inst.run().await.unwrap(); + + // 10.1.1.2 + let shared_b_config = + create_shared_config("shared_b", Some("ns_c1"), "10.144.144.2", "fd00::2/64"); + let mut shared_b_inst = Instance::new(shared_b_config); + shared_b_inst.run().await.unwrap(); + + // 10.1.1.4 + let admin_c_config = + create_admin_config("admin_c", Some("ns_c3"), "10.144.144.4", "fd00::4/64"); + let mut admin_c_inst = Instance::new(admin_c_config); + admin_c_inst.run().await.unwrap(); + + admin_a_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.2:11010".parse().unwrap(), + )); + admin_c_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new( + "tcp://10.1.1.2:11010".parse().unwrap(), + )); + + // print all peer ids + println!("admin_a_peer_id: {:?}", admin_a_inst.peer_id()); + println!("shared_b_peer_id: {:?}", shared_b_inst.peer_id()); + println!("admin_c_peer_id: {:?}", admin_c_inst.peer_id()); + + let admin_c_peer_id = admin_c_inst.peer_id(); + wait_for_condition( + || async { + let a_routes = admin_a_inst.get_peer_manager().list_routes().await; + let c_routes = admin_c_inst.get_peer_manager().list_routes().await; + println!( + "bootstrap routes: a={:?} c={:?}", + a_routes.iter().map(|r| r.peer_id).collect::>(), + c_routes.iter().map(|r| r.peer_id).collect::>() + ); + a_routes.iter().any(|r| r.peer_id == admin_c_peer_id) + || c_routes.iter().any(|r| r.peer_id == admin_a_inst.peer_id()) + }, + Duration::from_secs(3), + ) + .await; + + let cred_d_config = create_credential_config( + &admin_a_inst, + "cred_d", + Some("ns_c2"), + "10.144.144.5", + "fd00::5/64", + ) + .await; + admin_a_inst + .get_global_ctx() + .issue_event(GlobalCtxEvent::CredentialChanged); + + let mut cred_d_inst = Instance::new(cred_d_config); + cred_d_inst.run().await.unwrap(); + let cred_d_peer_id = cred_d_inst.peer_id(); + + cred_d_inst + .get_conn_manager() + .add_connector(TcpTunnelConnector::new(if !connect_to_admin { + // connect to shared node + "tcp://10.1.1.2:11010".parse().unwrap() + } else { + // connect to admin node + "tcp://10.1.1.4:11010".parse().unwrap() + })); + // print all peer ids + println!("cred_d_peer_id: {:?}", cred_d_peer_id); + + wait_for_condition( + || async { + admin_c_inst + .get_peer_manager() + .list_routes() + .await + .iter() + .any(|r| r.peer_id == cred_d_peer_id) + }, + Duration::from_secs(60), + ) + .await; + + wait_for_condition( + || async { ping_test("ns_c3", "10.144.144.5", None).await }, + Duration::from_secs(15), + ) + .await; + + wait_for_condition( + || async { ping_test("ns_adm", "10.144.144.5", None).await }, + Duration::from_secs(15), + ) + .await; + + wait_for_condition( + || async { ping_test("ns_c2", "10.144.144.4", None).await }, + Duration::from_secs(15), + ) + .await; + + drop_insts(vec![admin_a_inst, shared_b_inst, admin_c_inst, cred_d_inst]).await; +} diff --git a/easytier/src/tests/mod.rs b/easytier/src/tests/mod.rs index 7996fe32d..ec55d6298 100644 --- a/easytier/src/tests/mod.rs +++ b/easytier/src/tests/mod.rs @@ -3,6 +3,11 @@ mod three_node; mod ipv6_test; +#[cfg(target_os = "linux")] +mod credential_tests; + +use std::io::IsTerminal as _; + use crate::common::PeerId; use crate::peers::peer_manager::PeerManager; @@ -126,9 +131,12 @@ pub fn enable_log() { .from_env() .unwrap() .add_directive("tarpc=error".parse().unwrap()); + let use_ansi = std::io::stderr().is_terminal(); tracing_subscriber::fmt::fmt() .pretty() + .with_ansi(use_ansi) .with_env_filter(filter) + .with_writer(std::io::stderr) .init(); } @@ -200,3 +208,45 @@ fn set_link_status(net_ns: &str, up: bool) { .unwrap(); tracing::info!("set link status: {:?}, net_ns: {}, up: {}", ret, net_ns, up); } + +pub async fn drop_insts(insts: Vec) { + let mut set = tokio::task::JoinSet::new(); + for mut inst in insts { + set.spawn(async move { + inst.clear_resources().await; + let pm = std::sync::Arc::downgrade(&inst.get_peer_manager()); + drop(inst); + let now = std::time::Instant::now(); + while now.elapsed().as_secs() < 5 && pm.strong_count() > 0 { + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + assert_eq!(pm.strong_count(), 0, "PeerManager should be dropped"); + }); + } + while set.join_next().await.is_some() {} +} + +pub async fn ping_test(from_netns: &str, target_ip: &str, payload_size: Option) -> bool { + use crate::common::netns::{NetNS, ROOT_NETNS_NAME}; + let _g = NetNS::new(Some(ROOT_NETNS_NAME.to_owned())).guard(); + let code = tokio::process::Command::new("ip") + .args([ + "netns", + "exec", + from_netns, + "ping", + "-c", + "1", + "-s", + payload_size.unwrap_or(56).to_string().as_str(), + "-W", + "1", + target_ip.to_string().as_str(), + ]) + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .status() + .await + .unwrap(); + code.code().unwrap() == 0 +} diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index 1b014a003..e1e0c897f 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -7,8 +7,9 @@ use std::{ time::Duration, }; -use rand::Rng; +use rand::{rngs::OsRng, Rng}; use tokio::{net::UdpSocket, task::JoinSet}; +use x25519_dalek::StaticSecret; use super::*; @@ -21,9 +22,14 @@ use crate::{ stats_manager::{LabelType, MetricName}, }, instance::instance::Instance, - proto::{api::instance::TcpProxyEntryTransportType, common::CompressionAlgoPb}, + proto::{ + api::instance::TcpProxyEntryTransportType, + common::{CompressionAlgoPb, SecureModeConfig}, + }, tunnel::{ - common::tests::{_tunnel_bench_netns, wait_for_condition}, + common::tests::{ + _tunnel_bench_netns, _tunnel_pingpong_netns_with_timeout, wait_for_condition, + }, ring::RingTunnelConnector, tcp::{TcpTunnelConnector, TcpTunnelListener}, udp::UdpTunnelConnector, @@ -415,8 +421,10 @@ pub async fn subnet_proxy_loop_prevention_test() { drop_insts(insts).await; } -async fn subnet_proxy_test_udp(listen_ip: &str, target_ip: &str) { - use crate::tunnel::{common::tests::_tunnel_pingpong_netns, udp::UdpTunnelListener}; +async fn subnet_proxy_test_udp(listen_ip: &str, target_ip: &str, timeout: Duration) { + use crate::tunnel::{ + common::tests::_tunnel_pingpong_netns_with_timeout, udp::UdpTunnelListener, + }; use rand::Rng; let udp_listener = @@ -434,14 +442,16 @@ async fn subnet_proxy_test_udp(listen_ip: &str, target_ip: &str) { "net_d" }; - _tunnel_pingpong_netns( + let result = _tunnel_pingpong_netns_with_timeout( udp_listener, udp_connector, NetNS::new(Some(ns_name.into())), NetNS::new(Some("net_a".into())), buf, + timeout, ) .await; + assert!(result.is_ok(), "{}", result.unwrap_err()); // no fragment let udp_listener = @@ -452,18 +462,22 @@ async fn subnet_proxy_test_udp(listen_ip: &str, target_ip: &str) { let mut buf = vec![0; 1024]; rand::thread_rng().fill(&mut buf[..]); - _tunnel_pingpong_netns( + let result = _tunnel_pingpong_netns_with_timeout( udp_listener, udp_connector, NetNS::new(Some(ns_name.into())), NetNS::new(Some("net_a".into())), buf, + timeout, ) .await; + assert!(result.is_ok(), "{}", result.unwrap_err()); } -async fn subnet_proxy_test_tcp(listen_ip: &str, connect_ip: &str) { - use crate::tunnel::{common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener}; +async fn subnet_proxy_test_tcp(listen_ip: &str, connect_ip: &str, timeout: Duration) { + use crate::tunnel::{ + common::tests::_tunnel_pingpong_netns_with_timeout, tcp::TcpTunnelListener, + }; use rand::Rng; let tcp_listener = TcpTunnelListener::new(format!("tcp://{listen_ip}:22223").parse().unwrap()); @@ -479,26 +493,28 @@ async fn subnet_proxy_test_tcp(listen_ip: &str, connect_ip: &str) { "net_d" }; - _tunnel_pingpong_netns( + let result = _tunnel_pingpong_netns_with_timeout( tcp_listener, tcp_connector, NetNS::new(Some(ns_name.into())), NetNS::new(Some("net_a".into())), buf, + timeout, ) .await; + assert!(result.is_ok(), "{}", result.unwrap_err()); } -async fn subnet_proxy_test_icmp(target_ip: &str) { +async fn subnet_proxy_test_icmp(target_ip: &str, timeout: Duration) { wait_for_condition( || async { ping_test("net_a", target_ip, None).await }, - Duration::from_secs(5), + timeout, ) .await; wait_for_condition( || async { ping_test("net_a", target_ip, Some(5 * 1024)).await }, - Duration::from_secs(5), + timeout, ) .await; } @@ -534,10 +550,10 @@ pub async fn quic_proxy() { let target_ip = "10.1.2.4"; - subnet_proxy_test_icmp(target_ip).await; - subnet_proxy_test_icmp("10.144.144.3").await; - subnet_proxy_test_tcp(target_ip, target_ip).await; - subnet_proxy_test_tcp("0.0.0.0", "10.144.144.3").await; + subnet_proxy_test_icmp(target_ip, Duration::from_secs(5)).await; + subnet_proxy_test_icmp("10.144.144.3", Duration::from_secs(5)).await; + subnet_proxy_test_tcp(target_ip, target_ip, Duration::from_secs(5)).await; + subnet_proxy_test_tcp("0.0.0.0", "10.144.144.3", Duration::from_secs(5)).await; let metrics = insts[0] .get_global_ctx() @@ -625,14 +641,14 @@ pub async fn subnet_proxy_three_node_test( .await; for target_ip in ["10.1.3.4", "10.1.2.4", "10.144.144.3"] { - subnet_proxy_test_icmp(target_ip).await; + subnet_proxy_test_icmp(target_ip, Duration::from_secs(5)).await; let listen_ip = if target_ip == "10.144.144.3" { "0.0.0.0" } else { "10.1.2.4" }; - subnet_proxy_test_tcp(listen_ip, target_ip).await; - subnet_proxy_test_udp(listen_ip, target_ip).await; + subnet_proxy_test_tcp(listen_ip, target_ip, Duration::from_secs(5)).await; + subnet_proxy_test_udp(listen_ip, target_ip, Duration::from_secs(5)).await; } if enable_quic_proxy && !disable_quic_input { let metrics = insts[0] @@ -1369,10 +1385,7 @@ pub async fn port_forward_test( ) .await; - use crate::tunnel::{ - common::tests::_tunnel_pingpong_netns, tcp::TcpTunnelListener, udp::UdpTunnelConnector, - udp::UdpTunnelListener, - }; + use crate::tunnel::{tcp::TcpTunnelListener, udp::UdpTunnelConnector, udp::UdpTunnelListener}; let tcp_listener = TcpTunnelListener::new("tcp://0.0.0.0:23456".parse().unwrap()); let tcp_connector = TcpTunnelConnector::new("tcp://127.0.0.1:23456".parse().unwrap()); @@ -1380,14 +1393,16 @@ pub async fn port_forward_test( let mut buf = vec![0; buf_size as usize]; rand::thread_rng().fill(&mut buf[..]); - _tunnel_pingpong_netns( + _tunnel_pingpong_netns_with_timeout( tcp_listener, tcp_connector, NetNS::new(Some("net_c".into())), NetNS::new(Some("net_a".into())), buf, + Duration::from_secs(1), ) - .await; + .await + .unwrap(); let tcp_listener = TcpTunnelListener::new("tcp://0.0.0.0:23457".parse().unwrap()); let tcp_connector = TcpTunnelConnector::new("tcp://127.0.0.1:23457".parse().unwrap()); @@ -1395,14 +1410,16 @@ pub async fn port_forward_test( let mut buf = vec![0; buf_size as usize]; rand::thread_rng().fill(&mut buf[..]); - _tunnel_pingpong_netns( + _tunnel_pingpong_netns_with_timeout( tcp_listener, tcp_connector, NetNS::new(Some("net_d".into())), NetNS::new(Some("net_a".into())), buf, + Duration::from_secs(1), ) - .await; + .await + .unwrap(); let udp_listener = UdpTunnelListener::new("udp://0.0.0.0:23458".parse().unwrap()); let udp_connector = UdpTunnelConnector::new("udp://127.0.0.1:23458".parse().unwrap()); @@ -1410,14 +1427,16 @@ pub async fn port_forward_test( let mut buf = vec![0; buf_size as usize]; rand::thread_rng().fill(&mut buf[..]); - _tunnel_pingpong_netns( + _tunnel_pingpong_netns_with_timeout( udp_listener, udp_connector, NetNS::new(Some("net_c".into())), NetNS::new(Some("net_a".into())), buf, + Duration::from_secs(1), ) - .await; + .await + .unwrap(); let udp_listener = UdpTunnelListener::new("udp://0.0.0.0:23459".parse().unwrap()); let udp_connector = UdpTunnelConnector::new("udp://127.0.0.1:23459".parse().unwrap()); @@ -1425,14 +1444,16 @@ pub async fn port_forward_test( let mut buf = vec![0; buf_size as usize]; rand::thread_rng().fill(&mut buf[..]); - _tunnel_pingpong_netns( + _tunnel_pingpong_netns_with_timeout( udp_listener, udp_connector, NetNS::new(Some("net_d".into())), NetNS::new(Some("net_a".into())), buf, + Duration::from_secs(1), ) - .await; + .await + .unwrap(); drop_insts(_insts).await; } @@ -1603,7 +1624,7 @@ pub async fn acl_rule_test_inbound( #[values(true, false)] enable_quic_proxy: bool, ) { use crate::tunnel::{ - common::tests::_tunnel_pingpong_netns, + common::tests::_tunnel_pingpong_netns_with_timeout, tcp::{TcpTunnelConnector, TcpTunnelListener}, udp::{UdpTunnelConnector, UdpTunnelListener}, }; @@ -1703,43 +1724,38 @@ pub async fn acl_rule_test_inbound( rand::thread_rng().fill(&mut buf[..]); // 5. 8081 应该可以 pingpong 成功 - _tunnel_pingpong_netns( + let result = _tunnel_pingpong_netns_with_timeout( listener_8081, connector_8081, NetNS::new(Some("net_c".into())), NetNS::new(Some("net_a".into())), buf.clone(), + Duration::from_secs(5), ) .await; + assert!(result.is_ok(), "{}", result.unwrap_err()); // 6. 8080 应该连接失败(被 ACL 拦截) - let result = tokio::spawn(tokio::time::timeout( - std::time::Duration::from_millis(200), - _tunnel_pingpong_netns( - listener_8080, - connector_8080, - NetNS::new(Some("net_c".into())), - NetNS::new(Some("net_a".into())), - buf.clone(), - ), - )) + let result = _tunnel_pingpong_netns_with_timeout( + listener_8080, + connector_8080, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf.clone(), + Duration::from_millis(500), + ) .await; - assert!( - result.is_err() || result.unwrap().is_err(), - "TCP 连接 8080 应被 ACL 拦截,不能成功" - ); + assert!(result.is_err(), "TCP 连接 8080 应被 ACL 拦截,不能成功"); // 7. 从 10.144.144.2 连接 8082 应该连接失败(被 ACL 拦截) - let result = tokio::time::timeout( - std::time::Duration::from_millis(200), - _tunnel_pingpong_netns( - listener_8082, - connector_8082, - NetNS::new(Some("net_c".into())), - NetNS::new(Some("net_b".into())), - buf.clone(), - ), + let result = _tunnel_pingpong_netns_with_timeout( + listener_8082, + connector_8082, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_b".into())), + buf.clone(), + Duration::from_millis(500), ) .await; @@ -1766,25 +1782,25 @@ pub async fn acl_rule_test_inbound( rand::thread_rng().fill(&mut buf[..]); // 4. 8081 应该可以 pingpong 成功 - _tunnel_pingpong_netns( + let result = _tunnel_pingpong_netns_with_timeout( listener_8081, connector_8081, NetNS::new(Some("net_c".into())), NetNS::new(Some("net_a".into())), buf.clone(), + Duration::from_secs(5), ) .await; + assert!(result.is_ok(), "{}", result.unwrap_err()); // 5. 8080 应该连接失败(被 ACL 拦截) - let result = tokio::time::timeout( - std::time::Duration::from_millis(200), - _tunnel_pingpong_netns( - listener_8080, - connector_8080, - NetNS::new(Some("net_c".into())), - NetNS::new(Some("net_a".into())), - buf.clone(), - ), + let result = _tunnel_pingpong_netns_with_timeout( + listener_8080, + connector_8080, + NetNS::new(Some("net_c".into())), + NetNS::new(Some("net_a".into())), + buf.clone(), + Duration::from_millis(500), ) .await; @@ -1811,7 +1827,7 @@ pub async fn acl_rule_test_subnet_proxy( #[values(true, false)] enable_quic_proxy: bool, ) { use crate::tunnel::{ - common::tests::_tunnel_pingpong_netns, + common::tests::_tunnel_pingpong_netns_with_timeout, tcp::{TcpTunnelConnector, TcpTunnelListener}, udp::{UdpTunnelConnector, UdpTunnelListener}, }; @@ -1928,48 +1944,46 @@ pub async fn acl_rule_test_subnet_proxy( rand::thread_rng().fill(&mut buf[..]); // 8082 应该可以连接成功(不被 ACL 拦截) - _tunnel_pingpong_netns( + let result = _tunnel_pingpong_netns_with_timeout( listener_8082, connector_8082, NetNS::new(Some("net_d".into())), NetNS::new(Some("net_a".into())), buf.clone(), + Duration::from_secs(5), ) .await; + assert!(result.is_ok(), "{}", result.unwrap_err()); // 8080 应该连接失败(被 ACL 拦截 - 禁止访问子网代理的 8080) - let result = tokio::spawn(tokio::time::timeout( - std::time::Duration::from_millis(200), - _tunnel_pingpong_netns( - listener_8080, - connector_8080, - NetNS::new(Some("net_d".into())), - NetNS::new(Some("net_a".into())), - buf.clone(), - ), - )) + let result = _tunnel_pingpong_netns_with_timeout( + listener_8080, + connector_8080, + NetNS::new(Some("net_d".into())), + NetNS::new(Some("net_a".into())), + buf.clone(), + Duration::from_millis(500), + ) .await; assert!( - result.is_err() || result.unwrap().is_err(), + result.is_err(), "TCP 连接子网代理 8080 应被 ACL 拦截,不能成功" ); // 8081 应该连接失败(被 ACL 拦截 - 禁止 inst1 访问子网代理的 8081) - let result = tokio::spawn(tokio::time::timeout( - std::time::Duration::from_millis(200), - _tunnel_pingpong_netns( - listener_8081, - connector_8081, - NetNS::new(Some("net_d".into())), - NetNS::new(Some("net_a".into())), - buf.clone(), - ), - )) + let result = _tunnel_pingpong_netns_with_timeout( + listener_8081, + connector_8081, + NetNS::new(Some("net_d".into())), + NetNS::new(Some("net_a".into())), + buf.clone(), + Duration::from_millis(500), + ) .await; assert!( - result.is_err() || result.unwrap().is_err(), + result.is_err(), "TCP 连接子网代理 8081 应被 ACL 拦截,不能成功" ); @@ -1989,25 +2003,25 @@ pub async fn acl_rule_test_subnet_proxy( rand::thread_rng().fill(&mut buf[..]); // 8082 应该可以连接成功 - _tunnel_pingpong_netns( + let result = _tunnel_pingpong_netns_with_timeout( listener_8082, connector_8082, NetNS::new(Some("net_d".into())), NetNS::new(Some("net_a".into())), buf.clone(), + Duration::from_secs(5), ) .await; + assert!(result.is_ok(), "{}", result.unwrap_err()); // 8080 应该连接失败(被 ACL 拦截) - let result = tokio::time::timeout( - std::time::Duration::from_millis(200), - _tunnel_pingpong_netns( - listener_8080, - connector_8080, - NetNS::new(Some("net_d".into())), - NetNS::new(Some("net_a".into())), - buf.clone(), - ), + let result = _tunnel_pingpong_netns_with_timeout( + listener_8080, + connector_8080, + NetNS::new(Some("net_d".into())), + NetNS::new(Some("net_a".into())), + buf.clone(), + Duration::from_millis(500), ) .await; @@ -2116,7 +2130,7 @@ pub async fn p2p_only_test( for target_ip in ["10.144.144.3", target_ip] { assert_panics_ext( || async { - subnet_proxy_test_icmp(target_ip).await; + subnet_proxy_test_icmp(target_ip, Duration::from_millis(100)).await; }, !has_p2p_conn, ) @@ -2129,7 +2143,7 @@ pub async fn p2p_only_test( }; assert_panics_ext( || async { - subnet_proxy_test_tcp(listen_ip, target_ip).await; + subnet_proxy_test_tcp(listen_ip, target_ip, Duration::from_millis(100)).await; }, !has_p2p_conn, ) @@ -2137,7 +2151,7 @@ pub async fn p2p_only_test( assert_panics_ext( || async { - subnet_proxy_test_udp(listen_ip, target_ip).await; + subnet_proxy_test_udp(listen_ip, target_ip, Duration::from_millis(100)).await; }, !has_p2p_conn, ) @@ -2759,3 +2773,208 @@ pub async fn config_patch_test() { drop_insts(insts).await; } + +/// Generate SecureModeConfig with specified x25519 private key +pub fn generate_secure_mode_config_with_key( + private_key: &x25519_dalek::StaticSecret, +) -> SecureModeConfig { + use base64::{prelude::BASE64_STANDARD, Engine}; + use x25519_dalek::PublicKey; + + let public = PublicKey::from(private_key); + + SecureModeConfig { + enabled: true, + local_private_key: Some(BASE64_STANDARD.encode(private_key.as_bytes())), + local_public_key: Some(BASE64_STANDARD.encode(public.as_bytes())), + } +} + +/// Generate SecureModeConfig with random x25519 keypair +pub fn generate_secure_mode_config() -> SecureModeConfig { + let private = StaticSecret::random_from_rng(OsRng); + generate_secure_mode_config_with_key(&private) +} + +/// Test relay peer end-to-end encryption with TCP +#[rstest::rstest] +#[tokio::test] +#[serial_test::serial] +pub async fn relay_peer_e2e_encryption(#[values("tcp", "udp")] proto: &str) { + use crate::peers::route_trait::NextHopPolicy; + + let insts = init_three_node_ex( + proto, + |cfg| { + cfg.set_secure_mode(Some(generate_secure_mode_config())); + cfg + }, + false, + ) + .await; + + let inst1_peer_id = insts[0].peer_id(); + let inst2_peer_id = insts[1].peer_id(); + let inst3_peer_id = insts[2].peer_id(); + + println!( + "Test topology: inst1({}) <-> inst2({}) <-> inst3({})", + inst1_peer_id, inst2_peer_id, inst3_peer_id + ); + + // Check secure mode is enabled + let secure_mode_1 = insts[0].get_global_ctx().config.get_secure_mode(); + let secure_mode_2 = insts[1].get_global_ctx().config.get_secure_mode(); + let secure_mode_3 = insts[2].get_global_ctx().config.get_secure_mode(); + println!( + "Secure mode enabled: inst1={}, inst2={}, inst3={}", + secure_mode_1.is_some(), + secure_mode_2.is_some(), + secure_mode_3.is_some() + ); + + // Wait for routes to be established + wait_for_condition( + || async { + let routes = insts[0].get_peer_manager().list_routes().await; + routes.len() == 2 + }, + Duration::from_secs(10), + ) + .await; + + // Verify inst1 sees inst3 via inst2 (non-direct path) + let next_hop_to_inst3 = insts[0] + .get_peer_manager() + .get_peer_map() + .get_gateway_peer_id(inst3_peer_id, NextHopPolicy::LeastHop) + .await; + println!("Next hop from inst1 to inst3: {:?}", next_hop_to_inst3); + assert_eq!( + next_hop_to_inst3, + Some(inst2_peer_id), + "inst1 should reach inst3 via inst2 (relay)" + ); + + // Verify inst1 has no direct connection to inst3 + assert!( + !insts[0] + .get_peer_manager() + .get_peer_map() + .has_peer(inst3_peer_id), + "inst1 should NOT have direct connection to inst3" + ); + + // Check if noise_static_pubkey is available for relay handshake + let route_info_inst3 = insts[0] + .get_peer_manager() + .get_peer_map() + .get_route_peer_info(inst3_peer_id) + .await; + println!( + "Route info for inst3 on inst1: noise_static_pubkey len = {:?}", + route_info_inst3 + .as_ref() + .map(|i| i.noise_static_pubkey.len()) + ); + + // Test basic connectivity through relay + println!("Starting ping test from net_a to 10.144.144.3..."); + + assert!( + ping_test("net_a", "10.144.144.3", None).await, + "Ping from net_a to inst3 should succeed" + ); + + // Verify relay sessions are established + let relay_map_1 = insts[0].get_peer_manager().get_relay_peer_map(); + let relay_map_3 = insts[2].get_peer_manager().get_relay_peer_map(); + + println!( + "Relay states after ping: inst1->inst3: {}, inst3->inst1: {}", + relay_map_1.has_state(inst3_peer_id), + relay_map_3.has_state(inst1_peer_id) + ); + + // Test bidirectional connectivity + assert!( + ping_test("net_a", "10.144.144.3", None).await, + "Ping from net_a to inst3 should work" + ); + assert!( + ping_test("net_c", "10.144.144.1", None).await, + "Ping from net_c to inst1 should work" + ); + + println!("Test completed successfully!"); + drop_insts(insts).await; +} + +/// Test Relay Peer session cleanup on relay failure - TCP +#[tokio::test] +#[serial_test::serial] +pub async fn relay_peer_session_cleanup() { + use crate::peers::route_trait::NextHopPolicy; + + let mut insts = init_three_node_ex( + "tcp", + |cfg| { + cfg.set_secure_mode(Some(generate_secure_mode_config())); + cfg + }, + false, + ) + .await; + + let inst2_peer_id = insts[1].peer_id(); + let inst3_peer_id = insts[2].peer_id(); + let relay_map_1 = insts[0].get_peer_manager().get_relay_peer_map(); + + wait_for_condition( + || async { ping_test("net_a", "10.144.144.3", None).await }, + Duration::from_secs(6), + ) + .await; + + wait_for_condition( + || async { relay_map_1.has_state(inst3_peer_id) && relay_map_1.has_session(inst3_peer_id) }, + Duration::from_secs(3), + ) + .await; + + let next_hop = insts[0] + .get_peer_manager() + .get_peer_map() + .get_gateway_peer_id(inst3_peer_id, NextHopPolicy::LeastHop) + .await; + assert_eq!(next_hop, Some(inst2_peer_id)); + + let mut inst2 = insts.remove(1); + inst2.clear_resources().await; + drop(inst2); + + wait_for_condition( + || async { + let routes = insts[0].get_peer_manager().list_routes().await; + !routes.iter().any(|r| r.peer_id == inst3_peer_id) + }, + Duration::from_secs(6), + ) + .await; + + relay_map_1.evict_idle_sessions(Duration::from_millis(0)); + assert!(!relay_map_1.has_state(inst3_peer_id)); + + insts[0] + .get_peer_manager() + .get_peer_session_store() + .evict_unused_sessions(); + + wait_for_condition( + || async { !relay_map_1.has_session(inst3_peer_id) }, + Duration::from_secs(1), + ) + .await; + + drop_insts(insts).await; +} diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index 50b81e6b7..b4c05ff56 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -495,17 +495,20 @@ pub mod tests { L: TunnelListener + Send + Sync + 'static, C: TunnelConnector + Send + Sync + 'static, { - _tunnel_pingpong_netns( + _tunnel_pingpong_netns_with_timeout( listener, connector, NetNS::new(None), NetNS::new(None), "12345678abcdefg".as_bytes().to_vec(), + // only used by tunnel test, so set a long timeout + tokio::time::Duration::from_secs(5), ) - .await; + .await + .unwrap(); } - pub(crate) async fn _tunnel_pingpong_netns( + async fn _tunnel_pingpong_netns( mut listener: L, mut connector: C, l_netns: NetNS, diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs index d5a464f15..c79b65433 100644 --- a/easytier/src/tunnel/packet_def.rs +++ b/easytier/src/tunnel/packet_def.rs @@ -77,6 +77,8 @@ pub enum PacketType { NoiseHandshakeMsg1 = 13, NoiseHandshakeMsg2 = 14, NoiseHandshakeMsg3 = 15, + RelayHandshake = 20, + RelayHandshakeAck = 21, // used internally, DataWithKcpSrcModified = 18, diff --git a/easytier/src/web_client/mod.rs b/easytier/src/web_client/mod.rs index d3dfb28d0..c2953079b 100644 --- a/easytier/src/web_client/mod.rs +++ b/easytier/src/web_client/mod.rs @@ -36,6 +36,7 @@ pub struct DefaultHooks; impl WebClientHooks for DefaultHooks {} pub mod controller; +pub mod security; pub mod session; use std::sync::atomic::{AtomicBool, Ordering}; @@ -52,6 +53,7 @@ impl WebClient { connector: T, token: S, hostname: H, + secure_mode: bool, manager: Arc, hooks: Option>, ) -> Self { @@ -68,7 +70,13 @@ impl WebClient { let controller_clone = controller.clone(); let connected_clone = connected.clone(); let tasks = ScopedTask::from(tokio::spawn(async move { - Self::routine(controller_clone, connected_clone, Box::new(connector)).await; + Self::routine( + controller_clone, + connected_clone, + secure_mode, + Box::new(connector), + ) + .await; })); WebClient { @@ -82,6 +90,7 @@ impl WebClient { async fn routine( controller: Arc, connected: Arc, + secure_mode: bool, mut connector: Box, ) { loop { @@ -99,6 +108,65 @@ impl WebClient { log::info!("Successfully connected to {:?}", conn.info()); let mut session = session::Session::new(conn, controller.clone()); + let support_encryption = match tokio::time::timeout( + std::time::Duration::from_secs(3), + session.get_feature(), + ) + .await + { + Ok(Ok(feature)) => feature.support_encryption, + Ok(Err(error)) => { + log::warn!(%error, "GetFeature rpc failed, fallback to legacy tunnel"); + false + } + Err(_) => { + log::warn!("GetFeature rpc timeout, fallback to legacy tunnel"); + false + } + }; + + if support_encryption { + log::info!("Server supports encryption, reconnecting with secure tunnel"); + drop(session); + + let conn = match connector.connect().await { + Ok(conn) => conn, + Err(error) => { + connected.store(false, Ordering::Release); + let wait = 1; + log::warn!(%error, "Failed to reconnect secure tunnel, retrying in {} seconds...", wait); + tokio::time::sleep(std::time::Duration::from_secs(wait)).await; + continue; + } + }; + + let conn = match security::upgrade_client_tunnel(conn).await { + Ok(conn) => conn, + Err(error) => { + connected.store(false, Ordering::Release); + let wait = 1; + log::warn!(%error, "Noise handshake failed, retrying in {} seconds...", wait); + tokio::time::sleep(std::time::Duration::from_secs(wait)).await; + continue; + } + }; + + let mut session = session::Session::new(conn, controller.clone()); + session.start_heartbeat().await; + session.wait().await; + connected.store(false, Ordering::Release); + continue; + } + + if secure_mode { + connected.store(false, Ordering::Release); + let wait = 1; + log::warn!("secure-mode enabled but server does not support encryption, retrying in {} seconds...", wait); + tokio::time::sleep(std::time::Duration::from_secs(wait)).await; + continue; + } + + session.start_heartbeat().await; session.wait().await; connected.store(false, Ordering::Release); } @@ -113,6 +181,7 @@ pub async fn run_web_client( config_server_url_s: &str, machine_id: Option, hostname: Option, + secure_mode: bool, manager: Arc, hooks: Option>, ) -> Result { @@ -160,6 +229,7 @@ pub async fn run_web_client( create_connector_by_url(c_url.as_str(), &global_ctx, IpVersion::Both).await?, token.to_string(), hostname, + secure_mode, manager.clone(), hooks, )) @@ -178,6 +248,7 @@ mod tests { format!("ring://{}/test", uuid::Uuid::new_v4()).as_str(), None, None, + false, manager.clone(), None, ) diff --git a/easytier/src/web_client/security.rs b/easytier/src/web_client/security.rs new file mode 100644 index 000000000..27ba38695 --- /dev/null +++ b/easytier/src/web_client/security.rs @@ -0,0 +1,229 @@ +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use bytes::BytesMut; +use futures::{SinkExt, StreamExt}; +use snow::{params::NoiseParams, Builder, TransportState}; + +use crate::{ + proto::common::TunnelInfo, + tunnel::{ + filter::{TunnelFilter, TunnelWithFilter}, + packet_def::{PacketType, ZCPacket, ZCPacketType}, + SplitTunnel, StreamItem, Tunnel, TunnelError, ZCPacketSink, ZCPacketStream, + }, +}; + +const NOISE_MAGIC: &[u8] = b"ET_WEB_NOISE_V1:"; +const NOISE_PROLOGUE: &[u8] = b"easytier-webclient-noise-v1"; +const NOISE_PATTERN: &str = "Noise_NN_25519_ChaChaPoly_SHA256"; + +struct RawSplitTunnel { + info: Option, + split: Mutex>, +} + +impl RawSplitTunnel { + fn new( + info: Option, + stream: std::pin::Pin>, + sink: std::pin::Pin>, + ) -> Self { + Self { + info, + split: Mutex::new(Some((stream, sink))), + } + } +} + +impl Tunnel for RawSplitTunnel { + fn split(&self) -> SplitTunnel { + self.split + .lock() + .unwrap() + .take() + .expect("split can only be called once") + } + + fn info(&self) -> Option { + self.info.clone() + } +} + +struct NoiseTunnelFilter { + transport: Arc>, +} + +impl TunnelFilter for NoiseTunnelFilter { + type FilterOutput = (); + + fn before_send(&self, data: ZCPacket) -> Option { + let plain = data.tunnel_payload(); + let mut encrypted = vec![0u8; plain.len() + 64]; + let len = self + .transport + .lock() + .unwrap() + .write_message(plain, &mut encrypted) + .ok()?; + let mut packet = ZCPacket::new_with_payload(&encrypted[..len]); + packet.fill_peer_manager_hdr(0, 0, PacketType::Data as u8); + Some(packet) + } + + fn after_received(&self, data: StreamItem) -> Option { + let packet = match data { + Ok(v) => v, + Err(e) => return Some(Err(e)), + }; + let cipher = packet.payload(); + let mut plain = vec![0u8; cipher.len() + 64]; + let len = match self + .transport + .lock() + .unwrap() + .read_message(cipher, &mut plain) + { + Ok(v) => v, + Err(e) => { + return Some(Err(TunnelError::InvalidPacket(format!( + "noise decrypt failed: {e}" + )))); + } + }; + Some(Ok(ZCPacket::new_from_buf( + BytesMut::from(&plain[..len]), + ZCPacketType::DummyTunnel, + ))) + } + + fn filter_output(&self) {} +} + +fn pack_control_packet(payload: &[u8]) -> ZCPacket { + let mut packet = ZCPacket::new_with_payload(payload); + packet.fill_peer_manager_hdr(0, 0, PacketType::Data as u8); + packet +} + +fn encode_noise_payload(buf: &[u8]) -> Vec { + let mut payload = Vec::with_capacity(NOISE_MAGIC.len() + buf.len()); + payload.extend_from_slice(NOISE_MAGIC); + payload.extend_from_slice(buf); + payload +} + +fn decode_noise_payload(payload: &[u8]) -> Option<&[u8]> { + payload.strip_prefix(NOISE_MAGIC) +} + +fn wrap_secure_tunnel( + info: Option, + stream: std::pin::Pin>, + sink: std::pin::Pin>, + transport: TransportState, +) -> Box { + let raw = RawSplitTunnel::new(info, stream, sink); + Box::new(TunnelWithFilter::new( + raw, + NoiseTunnelFilter { + transport: Arc::new(Mutex::new(transport)), + }, + )) +} + +pub async fn upgrade_client_tunnel( + tunnel: Box, +) -> Result, TunnelError> { + let info = tunnel.info(); + let (mut stream, mut sink) = tunnel.split(); + + let params: NoiseParams = NOISE_PATTERN + .parse() + .map_err(|e| TunnelError::InternalError(format!("parse noise params failed: {e}")))?; + let mut state = Builder::new(params) + .prologue(NOISE_PROLOGUE) + .map_err(|e| TunnelError::InternalError(format!("set prologue failed: {e}")))? + .build_initiator() + .map_err(|e| TunnelError::InternalError(format!("build initiator failed: {e}")))?; + + let mut msg1 = vec![0u8; 1024]; + let msg1_len = state + .write_message(&[], &mut msg1) + .map_err(|e| TunnelError::InternalError(format!("write noise msg1 failed: {e}")))?; + sink.send(pack_control_packet(&encode_noise_payload( + &msg1[..msg1_len], + ))) + .await?; + + let msg2_packet = stream.next().await.ok_or(TunnelError::Shutdown)??; + let msg2_cipher = decode_noise_payload(msg2_packet.payload()) + .ok_or_else(|| TunnelError::InvalidPacket("invalid noise msg2 magic".to_string()))?; + let mut msg2 = vec![0u8; 1024]; + state + .read_message(msg2_cipher, &mut msg2) + .map_err(|e| TunnelError::InvalidPacket(format!("read noise msg2 failed: {e}")))?; + + let transport = state + .into_transport_mode() + .map_err(|e| TunnelError::InternalError(format!("switch transport mode failed: {e}")))?; + + Ok(wrap_secure_tunnel(info, stream, sink, transport)) +} + +pub async fn accept_or_upgrade_server_tunnel( + tunnel: Box, +) -> Result<(Box, bool), TunnelError> { + let info = tunnel.info(); + let (stream, sink) = tunnel.split(); + let mut stream = stream; + let mut sink = sink; + + let first_packet = match tokio::time::timeout(Duration::from_secs(1), stream.next()).await { + Ok(Some(Ok(packet))) => packet, + Ok(Some(Err(error))) => return Err(error), + Ok(None) => return Err(TunnelError::Shutdown), + Err(_) => { + return Ok(( + Box::new(RawSplitTunnel::new(info, stream, sink)) as Box, + false, + )); + } + }; + let Some(msg1_cipher) = decode_noise_payload(first_packet.payload()) else { + let stream = Box::pin(futures::stream::once(async move { Ok(first_packet) }).chain(stream)); + return Ok(( + Box::new(RawSplitTunnel::new(info, stream, sink)) as Box, + false, + )); + }; + + let params: NoiseParams = NOISE_PATTERN + .parse() + .map_err(|e| TunnelError::InternalError(format!("parse noise params failed: {e}")))?; + let mut state = Builder::new(params) + .prologue(NOISE_PROLOGUE) + .map_err(|e| TunnelError::InternalError(format!("set prologue failed: {e}")))? + .build_responder() + .map_err(|e| TunnelError::InternalError(format!("build responder failed: {e}")))?; + + let mut msg1 = vec![0u8; 1024]; + state + .read_message(msg1_cipher, &mut msg1) + .map_err(|e| TunnelError::InvalidPacket(format!("read noise msg1 failed: {e}")))?; + + let mut msg2 = vec![0u8; 1024]; + let msg2_len = state + .write_message(&[], &mut msg2) + .map_err(|e| TunnelError::InternalError(format!("write noise msg2 failed: {e}")))?; + sink.send(pack_control_packet(&encode_noise_payload( + &msg2[..msg2_len], + ))) + .await?; + + let transport = state + .into_transport_mode() + .map_err(|e| TunnelError::InternalError(format!("switch transport mode failed: {e}")))?; + + Ok((wrap_secure_tunnel(info, stream, sink, transport), true)) +} diff --git a/easytier/src/web_client/session.rs b/easytier/src/web_client/session.rs index fdec67e8f..c4e0d2519 100644 --- a/easytier/src/web_client/session.rs +++ b/easytier/src/web_client/session.rs @@ -12,7 +12,10 @@ use crate::{ api::manage::WebClientServiceServer, rpc_impl::bidirect::BidirectRpcManager, rpc_types::controller::BaseController, - web::{HeartbeatRequest, HeartbeatResponse, WebServerServiceClientFactory}, + web::{ + GetFeatureRequest, GetFeatureResponse, HeartbeatRequest, HeartbeatResponse, + WebServerServiceClientFactory, + }, }, tunnel::Tunnel, }; @@ -30,6 +33,7 @@ pub struct Session { controller: Arc, heartbeat_ctx: HeartbeatCtx, + heartbeat_started: std::sync::atomic::AtomicBool, tasks: Mutex>, } @@ -44,15 +48,18 @@ impl Session { "", ); - let mut tasks: JoinSet<()> = JoinSet::new(); - let heartbeat_ctx = - Self::heartbeat_routine(&rpc_mgr, Arc::downgrade(&controller), &mut tasks); + let (tx, _rx1) = broadcast::channel(2); + let heartbeat_ctx = HeartbeatCtx { + notifier: Arc::new(tx), + resp: Arc::new(Mutex::new(None)), + }; Session { rpc_mgr, controller, heartbeat_ctx, - tasks: Mutex::new(tasks), + heartbeat_started: std::sync::atomic::AtomicBool::new(false), + tasks: Mutex::new(JoinSet::new()), } } @@ -60,14 +67,8 @@ impl Session { rpc_mgr: &BidirectRpcManager, controller: Weak, tasks: &mut JoinSet<()>, - ) -> HeartbeatCtx { - let (tx, _rx1) = broadcast::channel(2); - - let ctx = HeartbeatCtx { - notifier: Arc::new(tx), - resp: Arc::new(Mutex::new(None)), - }; - + ctx: HeartbeatCtx, + ) { let mid = get_machine_id(); let inst_id = uuid::Uuid::new_v4(); let token = controller.upgrade().unwrap().token(); @@ -118,8 +119,22 @@ impl Session { } } }); + } - ctx + pub async fn start_heartbeat(&self) { + if self + .heartbeat_started + .swap(true, std::sync::atomic::Ordering::AcqRel) + { + return; + } + let mut tasks = self.tasks.lock().await; + Self::heartbeat_routine( + &self.rpc_mgr, + Arc::downgrade(&self.controller), + &mut tasks, + self.heartbeat_ctx.clone(), + ); } async fn wait_routines(&self) { @@ -135,6 +150,18 @@ impl Session { } } + pub async fn get_feature( + &self, + ) -> Result { + let client = self + .rpc_mgr + .rpc_client() + .scoped_client::>(1, 1, "".to_string()); + client + .get_feature(BaseController::default(), GetFeatureRequest {}) + .await + } + pub async fn wait_next_heartbeat(&self) -> Option { let mut rx = self.heartbeat_ctx.notifier.subscribe(); rx.recv().await.ok() diff --git a/script/install.ps1 b/script/install.ps1 new file mode 100644 index 000000000..9a5891fc5 --- /dev/null +++ b/script/install.ps1 @@ -0,0 +1,234 @@ +<# +.SYNOPSIS + EasyTier Windows Installer + +.DESCRIPTION + Download EasyTier from GitHub Release and install it. + Copies binaries to the install directory and updates the system PATH. + +.PARAMETER Version + Target version: "latest", "stable", or a specific tag like "v2.5.0". + Default: "latest" + +.PARAMETER InstallDir + Directory to install EasyTier binaries. + Default: "$env:ProgramFiles\EasyTier" + +.EXAMPLE + .\install.ps1 + .\install.ps1 -Version v2.5.0 + .\install.ps1 -InstallDir "C:\EasyTier" + +.NOTES + Administrator privileges are required. + After installation, run: easytier-cli service install + to register EasyTier as a system service. +#> +param( + [Parameter(Position = 0)] + [ValidatePattern('^(stable|latest|v?\d+\.\d+\.\d+(-[^\s]+)?)$')] + [string]$Version = 'latest', + + [Parameter(Position = 1)] + [string]$InstallDir = "$env:ProgramFiles\EasyTier" +) + +Set-StrictMode -Version Latest +$ErrorActionPreference = 'Stop' +$ProgressPreference = 'SilentlyContinue' + +# Force TLS 1.2+ for GitHub API and download requests +[Net.ServicePointManager]::SecurityProtocol = [Net.ServicePointManager]::SecurityProtocol -bor [Net.SecurityProtocolType]::Tls12 + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +$GITHUB_REPO = 'EasyTier/EasyTier' +$GITHUB_API = "https://api.github.com/repos/$GITHUB_REPO" +$GITHUB_RELEASE_URL = "https://github.com/$GITHUB_REPO/releases" + +# --------------------------------------------------------------------------- +# Administrator check +# --------------------------------------------------------------------------- +$currentPrincipal = [Security.Principal.WindowsPrincipal][Security.Principal.WindowsIdentity]::GetCurrent() +if (-not $currentPrincipal.IsInRole([Security.Principal.WindowsBuiltInRole]::Administrator)) { + Write-Error 'Please run this script as Administrator.' + exit 1 +} + +# --------------------------------------------------------------------------- +# Architecture detection +# --------------------------------------------------------------------------- +# Check PROCESSOR_ARCHITEW6432 first to correctly identify 64-bit OS when +# running under 32-bit PowerShell (WoW64), where PROCESSOR_ARCHITECTURE +# reports 'x86' even on a 64-bit machine. +$cpuArch = if ($env:PROCESSOR_ARCHITEW6432) { $env:PROCESSOR_ARCHITEW6432 } else { $env:PROCESSOR_ARCHITECTURE } +switch ($cpuArch) { + 'AMD64' { $arch = 'x86_64' } + 'ARM64' { $arch = 'arm64' } + 'x86' { $arch = 'i686' } + default { + Write-Error "Unsupported processor architecture: $cpuArch" + exit 1 + } +} +$assetBaseName = "easytier-windows-$arch" + +Write-Host '' +Write-Host ' ===============================================' -ForegroundColor Cyan +Write-Host ' EasyTier Windows Installer ' -ForegroundColor Cyan +Write-Host ' ===============================================' -ForegroundColor Cyan +Write-Host '' +Write-Host " Architecture : $arch" -ForegroundColor White +Write-Host '' + +# --------------------------------------------------------------------------- +# Step 1 - Resolve release version +# --------------------------------------------------------------------------- +Write-Host '[1/5] Querying GitHub Release info...' -ForegroundColor Yellow + +try { + if ($Version -eq 'latest' -or $Version -eq 'stable') { + $releaseInfo = Invoke-RestMethod ` + -Uri "$GITHUB_API/releases/latest" ` + -Headers @{ 'User-Agent' = 'EasyTier-Installer/1.0' } ` + -ErrorAction Stop + } + else { + $tag = if ($Version -notmatch '^v') { "v$Version" } else { $Version } + $releaseInfo = Invoke-RestMethod ` + -Uri "$GITHUB_API/releases/tags/$tag" ` + -Headers @{ 'User-Agent' = 'EasyTier-Installer/1.0' } ` + -ErrorAction Stop + } +} +catch { + Write-Error "Failed to fetch release info from GitHub: $_`nPlease check your network or visit $GITHUB_RELEASE_URL to download manually." + exit 1 +} + +$releaseVersion = $releaseInfo.tag_name +$assetZipName = "$assetBaseName-$releaseVersion.zip" + +Write-Host " Version : $releaseVersion" -ForegroundColor Green + +# --------------------------------------------------------------------------- +# Step 2 - Find download URL +# --------------------------------------------------------------------------- +Write-Host '' +Write-Host '[2/5] Resolving download URL...' -ForegroundColor Yellow +$asset = $releaseInfo.assets | + Where-Object { $_.name -eq $assetZipName } | + Select-Object -First 1 + +if (-not $asset) { + $availableAssets = ($releaseInfo.assets | Select-Object -ExpandProperty name) -join ', ' + Write-Error "Asset '$assetZipName' not found in release $releaseVersion.`nAvailable: $availableAssets`nVisit $GITHUB_RELEASE_URL to download manually." + exit 1 +} + +$downloadUrl = $asset.browser_download_url +Write-Host " URL : $downloadUrl" -ForegroundColor DarkGray + +# --------------------------------------------------------------------------- +# Step 3 - Download ZIP +# --------------------------------------------------------------------------- +Write-Host '' +Write-Host "[3/5] Downloading $assetZipName ..." -ForegroundColor Yellow + +$tempDir = Join-Path $env:TEMP "easytier-install-$(Get-Random)" +$zipPath = Join-Path $tempDir $assetZipName + +New-Item -ItemType Directory -Force -Path $tempDir | Out-Null + +try { + Invoke-WebRequest -Uri $downloadUrl -OutFile $zipPath -ErrorAction Stop + $sizeMB = [math]::Round((Get-Item $zipPath).Length / 1MB, 2) + Write-Host " Download complete ($sizeMB MB)" -ForegroundColor Green +} +catch { + Write-Error "Download failed: $_" + Remove-Item -Recurse -Force $tempDir -ErrorAction SilentlyContinue + exit 1 +} + +# --------------------------------------------------------------------------- +# Step 4 - Extract & copy to install directory +# --------------------------------------------------------------------------- +Write-Host '' +Write-Host '[4/5] Extracting and copying files...' -ForegroundColor Yellow + +$extractDir = Join-Path $tempDir 'extracted' +New-Item -ItemType Directory -Force -Path $extractDir | Out-Null + +try { + Expand-Archive -Path $zipPath -DestinationPath $extractDir -Force -ErrorAction Stop +} +catch { + Write-Error "Extraction failed: $_" + Remove-Item -Recurse -Force $tempDir -ErrorAction SilentlyContinue + exit 1 +} + +# ZIP may contain a sub-directory; find exe files recursively and flatten +$exeFiles = Get-ChildItem -Path $extractDir -Filter '*.exe' -Recurse +if (-not $exeFiles) { + Remove-Item -Recurse -Force $tempDir -ErrorAction SilentlyContinue + Write-Error 'No .exe files found after extraction. The ZIP may be malformed.' + exit 1 +} + +$binSourceDir = $exeFiles[0].DirectoryName + +try { + New-Item -ItemType Directory -Force -Path $InstallDir | Out-Null + Get-ChildItem -Path $binSourceDir | Copy-Item -Destination $InstallDir -Force +} +catch { + Remove-Item -Recurse -Force $tempDir -ErrorAction SilentlyContinue + Write-Error "Failed to copy files to install directory: $_" + exit 1 +} + +Write-Host " Installed to: $InstallDir" -ForegroundColor Green + +# --------------------------------------------------------------------------- +# Step 5 - Update system PATH +# --------------------------------------------------------------------------- +Write-Host '' +Write-Host '[5/5] Updating system PATH...' -ForegroundColor Yellow + +$systemPath = [Environment]::GetEnvironmentVariable('PATH', 'Machine') +# Split on ';' and normalize (trim trailing backslash, case-insensitive) for an exact match +$pathEntries = $systemPath -split ';' | ForEach-Object { $_.TrimEnd('\') } +$normalizedInstallDir = $InstallDir.TrimEnd('\') +if ($pathEntries -inotcontains $normalizedInstallDir) { + [Environment]::SetEnvironmentVariable('PATH', "$systemPath;$InstallDir", 'Machine') + $env:PATH = "$env:PATH;$InstallDir" + Write-Host " Added $InstallDir to system PATH" -ForegroundColor Green +} +else { + Write-Host " $InstallDir is already in PATH, skipping" -ForegroundColor DarkGray +} + +# --------------------------------------------------------------------------- +# Cleanup +# --------------------------------------------------------------------------- +try { + Remove-Item -Recurse -Force $tempDir -ErrorAction SilentlyContinue +} +catch { + Write-Warning "Could not remove temp dir $tempDir : $_" +} + +# --------------------------------------------------------------------------- +# Done +# --------------------------------------------------------------------------- +Write-Host '' +Write-Host " [OK] EasyTier $releaseVersion installation complete!" -ForegroundColor Green +Write-Host '' +Write-Host " Install dir : $InstallDir" -ForegroundColor White +Write-Host ' User guide : https://easytier.cn/en/guide/network/decentralized-networking.html' -ForegroundColor DarkGray +Write-Host '' +Write-Host ' NOTE: If PATH was just updated, please restart your terminal.' -ForegroundColor DarkYellow +Write-Host ''