diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f03397..9b46239 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,13 @@ The format is based on Keep a Changelog and this project uses Semantic Versionin ## [Unreleased] +### Fixed + +- Prevent removed accounts from reappearing after managed-home discovery on macOS and Windows +- Recover stale managed account credentials from newer matching auth homes before surfacing refresh-token errors +- Decode Plus account credit balances when the Codex usage API returns `credits.balance` as a string +- Keep distinct provider account IDs separate even when accounts share the same email or auth subject + ## [1.1.3] - 2026-04-23 ### Added diff --git a/Package.swift b/Package.swift index 79dfa02..3f92f9e 100644 --- a/Package.swift +++ b/Package.swift @@ -20,4 +20,8 @@ let package = Package( .linkedFramework("AppKit"), .linkedFramework("SwiftUI"), ]), + .testTarget( + name: "CodexControlTests", + dependencies: ["CodexControl"], + path: "Tests/CodexControlTests"), ]) diff --git a/Sources/CodexControl/App/AppModel.swift b/Sources/CodexControl/App/AppModel.swift index aed6cfa..da78d85 100644 --- a/Sources/CodexControl/App/AppModel.swift +++ b/Sources/CodexControl/App/AppModel.swift @@ -25,6 +25,7 @@ final class AppModel: ObservableObject { private let snapshotStore = SnapshotStore() private let accountManager = CodexAccountManager() private let desktopController = CodexDesktopControl() + private var removedAccounts: [RemovedAccountIdentity] = [] private let autoRefreshInterval: TimeInterval = 5 * 60 private var autoRefreshTask: Task? private var addAccountTask: Task? @@ -185,8 +186,12 @@ final class AppModel: ObservableObject { return } + let snapshot = self.accounts.filter { !self.requiresReauthentication(accountID: $0.id) } + guard !snapshot.isEmpty else { + return + } + self.isRefreshingAll = true - let snapshot = self.accounts for account in snapshot { var state = self.runtimeStates[account.id] ?? AccountRuntimeState() state.isLoading = true @@ -260,8 +265,9 @@ final class AppModel: ObservableObject { do { let account = try await self.accountManager.addManagedAccount() + self.restoreRemovedAccount(account) self.accounts = self.accountStore.merge(existing: self.accounts, incoming: [account]) - try self.accountStore.saveAccounts(self.accounts) + try self.accountStore.saveAccounts(self.accounts, removedAccounts: self.removedAccounts) self.selectedAccountID = self.accounts.first(where: { $0.matches(account) })?.id ?? account.id self.statusMessage = "\(account.displayName) added." if let selectedAccount = self.selectedAccount { @@ -287,8 +293,9 @@ final class AppModel: ObservableObject { do { let updated = try await self.accountManager.reauthenticate(account) + self.restoreRemovedAccount(updated) self.mergeAccount(updated) - try self.accountStore.saveAccounts(self.accounts) + try self.accountStore.saveAccounts(self.accounts, removedAccounts: self.removedAccounts) self.statusMessage = "\(updated.displayName) reauthenticated." if let refreshed = self.accounts.first(where: { $0.id == account.id }) { await self.refresh(account: refreshed) @@ -308,7 +315,7 @@ final class AppModel: ObservableObject { let result = try self.accountManager.switchActiveAccount(account, existing: self.accounts) if let materializedAccount = result.materializedAccount { self.mergeAccount(materializedAccount) - try self.accountStore.saveAccounts(self.accounts) + try self.accountStore.saveAccounts(self.accounts, removedAccounts: self.removedAccounts) } self.loadInitialAccounts() @@ -395,17 +402,20 @@ final class AppModel: ObservableObject { private func loadInitialAccounts() { do { - let loadedAccounts = try self.accountStore.loadAccounts() + let stored = try self.accountStore.loadAccountList() + self.removedAccounts = stored.removedAccounts + let loadedAccounts = stored.accounts.filter { !self.isRemoved($0) } let storedAccounts = loadedAccounts.filter { $0.source != .ambient } let discoveredManagedAccounts = try self.accountManager.discoverManagedAccounts(existing: loadedAccounts) var incomingAccounts = discoveredManagedAccounts if let ambientAccount = try self.accountManager.discoverAmbientAccount(existing: loadedAccounts) { incomingAccounts.insert(ambientAccount, at: 0) } + incomingAccounts.removeAll { self.isRemoved($0) } self.accounts = self.accountStore.merge(existing: storedAccounts, incoming: incomingAccounts) - if self.accounts != loadedAccounts { - try self.accountStore.saveAccounts(self.accounts) + if self.accounts != stored.accounts { + try self.accountStore.saveAccounts(self.accounts, removedAccounts: self.removedAccounts) } self.ensureSelection() self.refreshActiveIdentity() @@ -420,12 +430,16 @@ final class AppModel: ObservableObject { } private func remove(_ account: StoredAccount) { - self.accounts.removeAll { $0.id == account.id } + let removedIdentity = RemovedAccountIdentity(account: account) + self.removedAccounts.removeAll { $0.matches(account) } + self.removedAccounts.append(removedIdentity) + + self.accounts.removeAll { $0.id == account.id || removedIdentity.matches($0) } self.runtimeStates.removeValue(forKey: account.id) do { try self.accountManager.removeManagedFilesIfOwned(account) - try self.accountStore.saveAccounts(self.accounts) + try self.accountStore.saveAccounts(self.accounts, removedAccounts: self.removedAccounts) self.ensureSelection() self.statusMessage = "\(account.displayName) removed." } catch { @@ -478,12 +492,29 @@ final class AppModel: ObservableObject { private func persistAccountsSilently() { do { - try self.accountStore.saveAccounts(self.accounts) + try self.accountStore.saveAccounts(self.accounts, removedAccounts: self.removedAccounts) } catch { self.statusMessage = error.localizedDescription } } + private func isRemoved(_ account: StoredAccount) -> Bool { + self.removedAccounts.contains { $0.matches(account) } + } + + private func restoreRemovedAccount(_ account: StoredAccount) { + self.removedAccounts.removeAll { $0.matches(account) } + } + + private func requiresReauthentication(accountID: UUID) -> Bool { + guard let message = self.runtimeStates[accountID]?.errorMessage?.lowercased() else { + return false + } + + return message.contains("refresh token") + && message.contains("sign in again") + } + private func persistSnapshotsSilently() { let snapshots = self.runtimeStates.compactMapValues(\.snapshot) do { diff --git a/Sources/CodexControl/Models/AccountModels.swift b/Sources/CodexControl/Models/AccountModels.swift index f787af7..dbebf63 100644 --- a/Sources/CodexControl/Models/AccountModels.swift +++ b/Sources/CodexControl/Models/AccountModels.swift @@ -71,17 +71,28 @@ struct StoredAccount: Codable, Identifiable, Hashable, Sendable { Self.normalizeIdentifier(self.authSubject) } + var normalizedProviderAccountID: String? { + Self.normalizeIdentifier(self.providerAccountID) + } + var standardizedHomePath: String { URL(fileURLWithPath: self.codexHomePath, isDirectory: true).standardizedFileURL.path } func matches(_ other: StoredAccount) -> Bool { - if let normalizedAuthSubject, normalizedAuthSubject == other.normalizedAuthSubject - { + if self.standardizedHomePath == other.standardizedHomePath { return true } - if self.standardizedHomePath == other.standardizedHomePath { + if let normalizedProviderAccountID, let otherProviderAccountID = other.normalizedProviderAccountID { + return normalizedProviderAccountID == otherProviderAccountID + } + + if self.normalizedProviderAccountID != nil || other.normalizedProviderAccountID != nil { + return false + } + + if let normalizedAuthSubject, normalizedAuthSubject == other.normalizedAuthSubject { return true } @@ -163,6 +174,86 @@ struct StoredAccount: Codable, Identifiable, Hashable, Sendable { struct StoredAccountList: Codable, Sendable { let version: Int let accounts: [StoredAccount] + let removedAccounts: [RemovedAccountIdentity] + + init(version: Int, accounts: [StoredAccount], removedAccounts: [RemovedAccountIdentity] = []) { + self.version = version + self.accounts = accounts + self.removedAccounts = removedAccounts + } + + private enum CodingKeys: String, CodingKey { + case version + case accounts + case removedAccounts + } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.version = try container.decode(Int.self, forKey: .version) + self.accounts = try container.decode([StoredAccount].self, forKey: .accounts) + self.removedAccounts = try container.decodeIfPresent([RemovedAccountIdentity].self, forKey: .removedAccounts) ?? [] + } +} + +struct RemovedAccountIdentity: Codable, Hashable, Sendable { + let id: UUID + let emailHint: String? + let authSubject: String? + let providerAccountID: String? + let codexHomePath: String + let source: StoredAccountSource + let removedAt: Date + + init(account: StoredAccount, removedAt: Date = Date()) { + self.id = UUID() + self.emailHint = account.emailHint + self.authSubject = account.authSubject + self.providerAccountID = account.providerAccountID + self.codexHomePath = account.codexHomePath + self.source = account.source + self.removedAt = removedAt + } + + var normalizedProviderAccountID: String? { + StoredAccount.normalizeIdentifier(self.providerAccountID) + } + + var normalizedAuthSubject: String? { + StoredAccount.normalizeIdentifier(self.authSubject) + } + + var normalizedEmailHint: String? { + StoredAccount.normalizeEmail(self.emailHint) + } + + var standardizedHomePath: String { + URL(fileURLWithPath: self.codexHomePath, isDirectory: true).standardizedFileURL.path + } + + func matches(_ account: StoredAccount) -> Bool { + if self.standardizedHomePath == account.standardizedHomePath { + return true + } + + if let normalizedProviderAccountID, let accountProviderAccountID = account.normalizedProviderAccountID { + return normalizedProviderAccountID == accountProviderAccountID + } + + if self.normalizedProviderAccountID != nil || account.normalizedProviderAccountID != nil { + return false + } + + if let normalizedAuthSubject, normalizedAuthSubject == account.normalizedAuthSubject { + return true + } + + if let normalizedEmailHint, normalizedEmailHint == account.normalizedEmailHint { + return true + } + + return false + } } struct AccountRuntimeState: Sendable { diff --git a/Sources/CodexControl/Services/AccountStore.swift b/Sources/CodexControl/Services/AccountStore.swift index 637458c..591cbad 100644 --- a/Sources/CodexControl/Services/AccountStore.swift +++ b/Sources/CodexControl/Services/AccountStore.swift @@ -1,26 +1,41 @@ import Foundation struct AccountStore { - private static let currentVersion = 1 + private static let currentVersion = 2 func loadAccounts() throws -> [StoredAccount] { + try self.loadAccountList().accounts + } + + func loadRemovedAccounts() throws -> [RemovedAccountIdentity] { + try self.loadAccountList().removedAccounts + } + + func loadAccountList() throws -> StoredAccountList { guard FileManager.default.fileExists(atPath: FileLocations.accountsFile.path) else { - return [] + return StoredAccountList(version: Self.currentVersion, accounts: []) } let data = try Data(contentsOf: FileLocations.accountsFile) let decoder = JSONDecoder() decoder.dateDecodingStrategy = .iso8601 let stored = try decoder.decode(StoredAccountList.self, from: data) - return self.sorted(stored.accounts) + return StoredAccountList( + version: stored.version, + accounts: self.sorted(stored.accounts), + removedAccounts: stored.removedAccounts) } - func saveAccounts(_ accounts: [StoredAccount]) throws { + func saveAccounts(_ accounts: [StoredAccount], removedAccounts: [RemovedAccountIdentity]? = nil) throws { try FileLocations.ensureDirectories() let encoder = JSONEncoder() encoder.outputFormatting = [.prettyPrinted, .sortedKeys] encoder.dateEncodingStrategy = .iso8601 - let data = try encoder.encode(StoredAccountList(version: Self.currentVersion, accounts: self.sorted(accounts))) + let preservedRemovedAccounts = try removedAccounts ?? self.loadRemovedAccountsIfPresent() + let data = try encoder.encode(StoredAccountList( + version: Self.currentVersion, + accounts: self.sorted(accounts), + removedAccounts: preservedRemovedAccounts)) try data.write(to: FileLocations.accountsFile, options: .atomic) } @@ -47,4 +62,11 @@ struct AccountStore { return left < right } } + + private func loadRemovedAccountsIfPresent() throws -> [RemovedAccountIdentity] { + guard FileManager.default.fileExists(atPath: FileLocations.accountsFile.path) else { + return [] + } + return try self.loadRemovedAccounts() + } } diff --git a/Sources/CodexControl/Services/CodexAPI.swift b/Sources/CodexControl/Services/CodexAPI.swift index 4620c65..cdc3be2 100644 --- a/Sources/CodexControl/Services/CodexAPI.swift +++ b/Sources/CodexControl/Services/CodexAPI.swift @@ -113,6 +113,22 @@ private struct CreditDetails: Decodable { case unlimited case balance } + + init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.hasCredits = try container.decode(Bool.self, forKey: .hasCredits) + self.unlimited = try container.decode(Bool.self, forKey: .unlimited) + + if let value = try? container.decodeIfPresent(Double.self, forKey: .balance) { + self.balance = value + } else if let value = try? container.decodeIfPresent(String.self, forKey: .balance), + let parsed = Double(value) + { + self.balance = parsed + } else { + self.balance = nil + } + } } enum CodexAPI { @@ -129,6 +145,10 @@ enum CodexAPI { static func loadIdentity(codexHomePath: String) throws -> AuthBackedIdentity { let credentials = try self.loadCredentials(codexHomePath: codexHomePath) + return self.identity(from: credentials) + } + + private static func identity(from credentials: AuthCredentials) -> AuthBackedIdentity { let payload = credentials.idToken.flatMap(self.parseJWT) let auth = payload?["https://api.openai.com/auth"] as? [String: Any] let profile = payload?["https://api.openai.com/profile"] as? [String: Any] @@ -149,8 +169,15 @@ enum CodexAPI { var credentials = try self.loadCredentials(codexHomePath: account.codexHomePath) if credentials.needsRefresh, !credentials.refreshToken.isEmpty { - credentials = try await self.refresh(credentials) - try self.saveCredentials(credentials, codexHomePath: account.codexHomePath) + do { + credentials = try await self.refresh(credentials) + try self.saveCredentials(credentials, codexHomePath: account.codexHomePath) + } catch { + if let recovered = await self.recoverSnapshot(for: account, excluding: [credentials.refreshToken]) { + return recovered + } + throw error + } } do { @@ -162,8 +189,15 @@ enum CodexAPI { guard !credentials.refreshToken.isEmpty else { throw CodexAPIError.unauthorized } - credentials = try await self.refresh(credentials) - try self.saveCredentials(credentials, codexHomePath: account.codexHomePath) + do { + credentials = try await self.refresh(credentials) + try self.saveCredentials(credentials, codexHomePath: account.codexHomePath) + } catch { + if let recovered = await self.recoverSnapshot(for: account, excluding: [credentials.refreshToken]) { + return recovered + } + throw error + } return try await self.fetchVerifiedSnapshot( codexHomePath: account.codexHomePath, credentials: credentials, @@ -206,7 +240,7 @@ enum CodexAPI { credentials: AuthCredentials, fallbackEmail: String?) async throws -> AccountUsageSnapshot { - let identity = try? self.loadIdentity(codexHomePath: codexHomePath) + let identity = self.identity(from: credentials) let response = try await self.fetchUsage( accessToken: credentials.accessToken, accountId: credentials.accountId, @@ -214,9 +248,9 @@ enum CodexAPI { let windows = self.makeNormalizedWindows(response.rateLimit) return AccountUsageSnapshot( - email: identity?.email ?? fallbackEmail, - providerAccountID: identity?.providerAccountID ?? credentials.accountId, - plan: self.normalizeString(response.planType) ?? identity?.plan, + email: identity.email ?? fallbackEmail, + providerAccountID: identity.providerAccountID ?? credentials.accountId, + plan: self.normalizeString(response.planType) ?? identity.plan, allowed: response.rateLimit?.allowed, limitReached: response.rateLimit?.limitReached, primaryWindow: windows.primary, @@ -260,6 +294,7 @@ enum CodexAPI { let refreshToken = self.stringValue(in: tokens, key: "refresh_token") ?? "" let idToken = self.stringValue(in: tokens, key: "id_token") let accountId = self.stringValue(in: tokens, key: "account_id") + ?? self.accountID(fromIDToken: idToken) return AuthCredentials( accessToken: accessToken, @@ -343,7 +378,7 @@ enum CodexAPI { accessToken: (json["access_token"] as? String) ?? credentials.accessToken, refreshToken: (json["refresh_token"] as? String) ?? credentials.refreshToken, idToken: (json["id_token"] as? String) ?? credentials.idToken, - accountId: credentials.accountId, + accountId: credentials.accountId ?? self.accountID(fromIDToken: json["id_token"] as? String), lastRefresh: Date()) } catch let error as CodexAPIError { throw error @@ -602,6 +637,141 @@ enum CodexAPI { return formatter.date(from: value) } + private struct CredentialRecoveryCandidate { + let homePath: String + let credentials: AuthCredentials + let identity: AuthBackedIdentity + let freshness: Date + } + + private static func recoverSnapshot(for account: StoredAccount, excluding excludedRefreshTokens: Set) async -> AccountUsageSnapshot? { + for candidate in self.recoveryCandidates(for: account, excluding: excludedRefreshTokens) { + var credentials = candidate.credentials + + do { + if credentials.needsRefresh, !credentials.refreshToken.isEmpty { + credentials = try await self.refresh(credentials) + try self.saveCredentials(credentials, codexHomePath: candidate.homePath) + } + + let snapshot = try await self.fetchVerifiedSnapshot( + codexHomePath: candidate.homePath, + credentials: credentials, + fallbackEmail: account.emailHint) + try self.saveCredentials(credentials, codexHomePath: account.codexHomePath) + return snapshot + } catch CodexAPIError.unauthorized where !credentials.refreshToken.isEmpty { + do { + credentials = try await self.refresh(credentials) + try self.saveCredentials(credentials, codexHomePath: candidate.homePath) + let snapshot = try await self.fetchVerifiedSnapshot( + codexHomePath: candidate.homePath, + credentials: credentials, + fallbackEmail: account.emailHint) + try self.saveCredentials(credentials, codexHomePath: account.codexHomePath) + return snapshot + } catch { + continue + } + } catch { + continue + } + } + + return nil + } + + private static func recoveryCandidates(for account: StoredAccount, excluding excludedRefreshTokens: Set) -> [CredentialRecoveryCandidate] { + let candidateHomes = self.recoveryHomeURLs(for: account) + let accountHomePath = URL(fileURLWithPath: account.codexHomePath, isDirectory: true).standardizedFileURL.path + + return candidateHomes.compactMap { homeURL in + let homePath = homeURL.standardizedFileURL.path + guard homePath != accountHomePath, + let credentials = try? self.loadCredentials(codexHomePath: homePath), + !excludedRefreshTokens.contains(credentials.refreshToken) + else { + return nil + } + + let identity = self.identity(from: credentials) + guard self.identity(identity, at: homePath, matches: account) else { + return nil + } + + return CredentialRecoveryCandidate( + homePath: homePath, + credentials: credentials, + identity: identity, + freshness: self.credentialsFreshness(credentials, homePath: homePath)) + } + .sorted { $0.freshness > $1.freshness } + } + + private static func recoveryHomeURLs(for account: StoredAccount) -> [URL] { + var homes: [URL] = [] + if let managedHomes = try? FileManager.default.contentsOfDirectory( + at: FileLocations.managedHomesDirectory, + includingPropertiesForKeys: [.contentModificationDateKey], + options: [.skipsHiddenFiles]) + { + homes.append(contentsOf: managedHomes) + } + + homes.append(FileLocations.ambientCodexHome) + return homes + } + + private static func identity(_ identity: AuthBackedIdentity, at homePath: String, matches account: StoredAccount) -> Bool { + let accountHomePath = URL(fileURLWithPath: account.codexHomePath, isDirectory: true).standardizedFileURL.path + if homePath == accountHomePath { + return true + } + + let accountProviderAccountID = StoredAccount.normalizeIdentifier(account.providerAccountID) + let identityProviderAccountID = StoredAccount.normalizeIdentifier(identity.providerAccountID) + if let accountProviderAccountID, let identityProviderAccountID { + return accountProviderAccountID == identityProviderAccountID + } + + if accountProviderAccountID != nil || identityProviderAccountID != nil { + return false + } + + let accountSubject = StoredAccount.normalizeIdentifier(account.authSubject) + let identitySubject = StoredAccount.normalizeIdentifier(identity.authSubject) + if let accountSubject, let identitySubject, accountSubject == identitySubject { + return true + } + + let accountEmail = StoredAccount.normalizeEmail(account.emailHint) + let identityEmail = StoredAccount.normalizeEmail(identity.email) + if let accountEmail, let identityEmail, accountEmail == identityEmail { + return true + } + + return false + } + + private static func credentialsFreshness(_ credentials: AuthCredentials, homePath: String) -> Date { + if let lastRefresh = credentials.lastRefresh { + return lastRefresh + } + + let authURL = URL(fileURLWithPath: homePath, isDirectory: true).appendingPathComponent("auth.json", isDirectory: false) + let values = try? authURL.resourceValues(forKeys: [.contentModificationDateKey]) + return values?.contentModificationDate ?? .distantPast + } + + private static func accountID(fromIDToken idToken: String?) -> String? { + guard let payload = idToken.flatMap(self.parseJWT) else { + return nil + } + + let auth = payload["https://api.openai.com/auth"] as? [String: Any] + return self.normalizeString((auth?["chatgpt_account_id"] as? String) ?? (payload["chatgpt_account_id"] as? String)) + } + private static func stringValue(in dictionary: [String: Any], key: String) -> String? { if let value = dictionary[key] as? String, !value.isEmpty { return value diff --git a/Sources/CodexControl/Services/CodexAccountManager.swift b/Sources/CodexControl/Services/CodexAccountManager.swift index 156c90d..802a33c 100644 --- a/Sources/CodexControl/Services/CodexAccountManager.swift +++ b/Sources/CodexControl/Services/CodexAccountManager.swift @@ -222,14 +222,17 @@ struct CodexAccountManager { } let root = FileLocations.managedHomesDirectory.standardizedFileURL.path - let target = URL(fileURLWithPath: account.codexHomePath, isDirectory: true).standardizedFileURL.path + let targets = try self.managedHomePathsMatching(account) let prefix = root.hasSuffix("/") ? root : root + "/" - guard target.hasPrefix(prefix) else { - throw CodexAccountManagerError.unsafeDeletePath - } - if FileManager.default.fileExists(atPath: target) { - try FileManager.default.removeItem(atPath: target) + for target in targets { + guard target.hasPrefix(prefix) else { + throw CodexAccountManagerError.unsafeDeletePath + } + + if FileManager.default.fileExists(atPath: target) { + try FileManager.default.removeItem(atPath: target) + } } } @@ -461,12 +464,43 @@ struct CodexAccountManager { } private func directoryTimestamp(for homeURL: URL) -> Date { + let authURL = homeURL.appendingPathComponent("auth.json", isDirectory: false) + if let authValues = try? authURL.resourceValues(forKeys: [.contentModificationDateKey]), + let contentModificationDate = authValues.contentModificationDate + { + return contentModificationDate + } + let values = try? homeURL.resourceValues(forKeys: [.creationDateKey, .contentModificationDateKey]) return values?.contentModificationDate ?? values?.creationDate ?? Date() } + private func managedHomePathsMatching(_ account: StoredAccount) throws -> Set { + try FileLocations.ensureDirectories() + + var targets: Set = [ + URL(fileURLWithPath: account.codexHomePath, isDirectory: true).standardizedFileURL.path, + ] + + let homeURLs = try FileManager.default.contentsOfDirectory( + at: FileLocations.managedHomesDirectory, + includingPropertiesForKeys: nil, + options: [.skipsHiddenFiles]) + + for homeURL in homeURLs { + guard let discovered = self.discoveredManagedAccount(at: homeURL, existing: [account]), + discovered.matches(account) + else { + continue + } + targets.insert(homeURL.standardizedFileURL.path) + } + + return targets + } + private func timestampSlug() -> String { let formatter = DateFormatter() formatter.locale = Locale(identifier: "en_US_POSIX") diff --git a/Tests/CodexControlTests/AccountIdentityTests.swift b/Tests/CodexControlTests/AccountIdentityTests.swift new file mode 100644 index 0000000..71180ac --- /dev/null +++ b/Tests/CodexControlTests/AccountIdentityTests.swift @@ -0,0 +1,59 @@ +import XCTest +@testable import CodexControl + +final class AccountIdentityTests: XCTestCase { + func testMatchesKeepsDifferentProviderAccountsSeparate() { + let createdAt = Date(timeIntervalSince1970: 1_766_534_400) + let first = StoredAccount( + id: UUID(), + nickname: nil, + emailHint: "user@example.com", + authSubject: "auth0|same-user", + providerAccountID: "account-1", + codexHomePath: "/tmp/a", + source: .managedByApp, + createdAt: createdAt, + updatedAt: createdAt) + let second = StoredAccount( + id: UUID(), + nickname: nil, + emailHint: "user@example.com", + authSubject: "auth0|same-user", + providerAccountID: "account-2", + codexHomePath: "/tmp/b", + source: .managedByApp, + createdAt: createdAt, + updatedAt: createdAt) + + XCTAssertFalse(first.matches(second)) + } + + func testRemovedIdentityMatchesProviderNotSharedEmail() { + let createdAt = Date(timeIntervalSince1970: 1_766_534_400) + let removedAccount = StoredAccount( + id: UUID(), + nickname: nil, + emailHint: "user@example.com", + authSubject: "auth0|same-user", + providerAccountID: "account-1", + codexHomePath: "/tmp/a", + source: .managedByApp, + createdAt: createdAt, + updatedAt: createdAt) + let otherAccount = StoredAccount( + id: UUID(), + nickname: nil, + emailHint: "user@example.com", + authSubject: "auth0|same-user", + providerAccountID: "account-2", + codexHomePath: "/tmp/b", + source: .managedByApp, + createdAt: createdAt, + updatedAt: createdAt) + + let removed = RemovedAccountIdentity(account: removedAccount, removedAt: createdAt) + + XCTAssertTrue(removed.matches(removedAccount)) + XCTAssertFalse(removed.matches(otherAccount)) + } +} diff --git a/windows/codexcontrol_windows/account_manager.py b/windows/codexcontrol_windows/account_manager.py index 44937c4..007b051 100644 --- a/windows/codexcontrol_windows/account_manager.py +++ b/windows/codexcontrol_windows/account_manager.py @@ -162,15 +162,16 @@ def remove_managed_files_if_owned(self, account: StoredAccount) -> None: return root = MANAGED_HOMES_DIRECTORY.resolve(strict=False) - target = Path(account.codex_home_path).resolve(strict=False) + targets = self._managed_home_paths_matching(account) - try: - target.relative_to(root) - except ValueError as error: - raise CodexAccountManagerError("This path is not an app-managed home directory.") from error + for target in targets: + try: + target.relative_to(root) + except ValueError as error: + raise CodexAccountManagerError("This path is not an app-managed home directory.") from error - if target.exists(): - shutil.rmtree(target, ignore_errors=False) + if target.exists(): + shutil.rmtree(target, ignore_errors=False) def discover_managed_accounts(self, existing: list[StoredAccount]) -> list[StoredAccount]: ensure_directories() @@ -372,6 +373,25 @@ def _rewrite_creator_id( environment["creator_id"] = updated_creator_id path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8") + def _managed_home_paths_matching(self, account: StoredAccount) -> set[Path]: + ensure_directories() + targets = {Path(account.codex_home_path).resolve(strict=False)} + seen_keys = {_managed_home_key(next(iter(targets)))} + + for home_path in MANAGED_HOMES_DIRECTORY.iterdir(): + candidate = self._discovered_managed_account(home_path, [account]) + if candidate is None or not candidate.matches(account): + continue + + resolved = home_path.resolve(strict=False) + key = _managed_home_key(resolved) + if key in seen_keys: + continue + targets.add(resolved) + seen_keys.add(key) + + return targets + def _authenticate_account( self, home_path: Path, @@ -494,10 +514,18 @@ def _read_remaining_output(process: subprocess.Popen[str]) -> tuple[str, str]: def _directory_timestamp(path: Path): + auth_path = path / "auth.json" + if auth_path.exists(): + return datetime.fromtimestamp(auth_path.stat().st_mtime, tz=timezone.utc) + stat = path.stat() return datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc) +def _managed_home_key(path: Path) -> str: + return os.path.normcase(os.path.abspath(os.path.normpath(path))) + + def _updated_creator_id( creator_id: object, previous_account_id: str | None, diff --git a/windows/codexcontrol_windows/app.py b/windows/codexcontrol_windows/app.py index 28353c4..ff1dfd9 100644 --- a/windows/codexcontrol_windows/app.py +++ b/windows/codexcontrol_windows/app.py @@ -21,7 +21,7 @@ from .codex_api import AuthBackedIdentity from .codex_api import fetch_snapshot from .codex_desktop import CodexDesktopControlError, restart_codex_desktop -from .models import AccountRuntimeState, AccountUsageSnapshot, StoredAccount, StoredAccountSource, utc_now +from .models import AccountRuntimeState, AccountUsageSnapshot, RemovedAccountIdentity, StoredAccount, StoredAccountSource, utc_now from .presentation_logic import account_sort_key, is_active_account from .stores import AccountStore, SnapshotStore @@ -328,6 +328,7 @@ def __init__(self, start_hidden: bool = False) -> None: self.events: queue.Queue[tuple[Any, ...]] = queue.Queue() self.accounts: list[StoredAccount] = [] + self.removed_accounts: list[RemovedAccountIdentity] = [] self.runtime_states: dict[UUID, AccountRuntimeState] = {} self.nickname_drafts: dict[UUID, str] = {} self.selected_account_id: UUID | None = None @@ -463,9 +464,16 @@ def refresh_all(self) -> None: if not self.accounts or self.is_refreshing_all: return + refreshable_accounts = [ + account for account in self.accounts + if not self._requires_reauthentication(account.id) + ] + if not refreshable_accounts: + return + self.is_refreshing_all = True - self._group_refresh_pending = len(self.accounts) - for account in self.accounts: + self._group_refresh_pending = len(refreshable_accounts) + for account in refreshable_accounts: state = self.runtime_states.setdefault(account.id, AccountRuntimeState()) state.is_loading = True self.runtime_states[account.id] = state @@ -547,11 +555,6 @@ def update_nickname(self, account_id: UUID) -> None: self._render() def remove_account(self, account: StoredAccount) -> None: - if not account.source.owns_files: - self.status_message = "System accounts are auto-discovered and cannot be removed here." - self._render() - return - confirmed = messagebox.askyesno( "Remove Account", f"{account.display_name} will be removed from CodexControl.", @@ -560,14 +563,22 @@ def remove_account(self, account: StoredAccount) -> None: if not confirmed: return - self.accounts = [candidate for candidate in self.accounts if candidate.id != account.id] + removed_identity = RemovedAccountIdentity.from_account(account) + self.removed_accounts = [candidate for candidate in self.removed_accounts if not candidate.matches(account)] + self.removed_accounts.append(removed_identity) + + self.accounts = [ + candidate + for candidate in self.accounts + if candidate.id != account.id and not removed_identity.matches(candidate) + ] self.runtime_states.pop(account.id, None) self.nickname_drafts.pop(account.id, None) self._mark_accounts_dirty() self._mark_runtime_dirty() try: self.account_manager.remove_managed_files_if_owned(account) - self.account_store.save_accounts(self.accounts) + self.account_store.save_accounts(self.accounts, self.removed_accounts) self._ensure_selection() self.status_message = f"{account.display_name} removed." except CodexAccountManagerError as error: @@ -880,17 +891,19 @@ def _toggle_window(self) -> None: def _load_initial_state(self) -> None: try: - loaded_accounts = self.account_store.load_accounts() + loaded_accounts, self.removed_accounts = self.account_store.load_account_list() + loaded_accounts = [account for account in loaded_accounts if not self._is_removed(account)] stored_accounts = [account for account in loaded_accounts if account.source is not StoredAccountSource.AMBIENT] discovered_accounts = self.account_manager.discover_managed_accounts(loaded_accounts) ambient_account = self.account_manager.discover_ambient_account(loaded_accounts) incoming_accounts = list(discovered_accounts) if ambient_account is not None: incoming_accounts.insert(0, ambient_account) + incoming_accounts = [account for account in incoming_accounts if not self._is_removed(account)] self.accounts = self.account_store.merge(stored_accounts, incoming_accounts) if self.accounts != loaded_accounts: - self.account_store.save_accounts(self.accounts) + self.account_store.save_accounts(self.accounts, self.removed_accounts) except Exception as error: self.status_message = str(error) self.accounts = [] @@ -988,8 +1001,9 @@ def _apply_add_account_result(self, account: StoredAccount | None, error: Except self._add_handle = None if account is not None: + self._restore_removed_account(account) self.accounts = self.account_store.merge(self.accounts, [account]) - self.account_store.save_accounts(self.accounts) + self.account_store.save_accounts(self.accounts, self.removed_accounts) self._mark_accounts_dirty() matched = next((candidate for candidate in self.accounts if candidate.matches(account)), account) self.selected_account_id = matched.id @@ -1010,8 +1024,9 @@ def _apply_reauth_result( self._reauth_handle = None if account is not None: + self._restore_removed_account(account) self.accounts = self.account_store.merge(self.accounts, [account]) - self.account_store.save_accounts(self.accounts) + self.account_store.save_accounts(self.accounts, self.removed_accounts) self._mark_accounts_dirty() self.status_message = f"{account.display_name} reauthenticated." refreshed = next((candidate for candidate in self.accounts if candidate.id == original_account_id), None) @@ -1045,7 +1060,7 @@ def _update_account_metadata(self, account_id: UUID, snapshot: AccountUsageSnaps def _persist_accounts_silently(self) -> None: try: - self.account_store.save_accounts(self.accounts) + self.account_store.save_accounts(self.accounts, self.removed_accounts) except Exception as error: self.status_message = str(error) @@ -1097,6 +1112,20 @@ def _ensure_selection(self) -> None: def _refresh_active_identity(self) -> None: self.active_identity = self.account_manager.load_active_identity() + def _is_removed(self, account: StoredAccount) -> bool: + return any(removed.matches(account) for removed in self.removed_accounts) + + def _restore_removed_account(self, account: StoredAccount) -> None: + self.removed_accounts = [removed for removed in self.removed_accounts if not removed.matches(account)] + + def _requires_reauthentication(self, account_id: UUID) -> bool: + state = self.runtime_states.get(account_id) + if state is None or not state.error_message: + return False + + message = state.error_message.lower() + return "refresh token" in message and "sign in again" in message + def _replace_or_append_account(self, account: StoredAccount) -> None: replaced = False for index, existing in enumerate(self.accounts): diff --git a/windows/codexcontrol_windows/codex_api.py b/windows/codexcontrol_windows/codex_api.py index ecb36d0..0e1db66 100644 --- a/windows/codexcontrol_windows/codex_api.py +++ b/windows/codexcontrol_windows/codex_api.py @@ -10,6 +10,7 @@ import requests +from .file_locations import AMBIENT_CODEX_HOME, MANAGED_HOMES_DIRECTORY from .models import AccountUsageSnapshot, CreditsBalanceSnapshot, StoredAccount, UsageWindowSnapshot, parse_datetime @@ -19,6 +20,7 @@ REQUEST_TIMEOUT_SECONDS = 30 _SESSION_STATE = threading.local() _USAGE_URL_CACHE: dict[str, tuple[float | None, str]] = {} +UNAUTHORIZED_MESSAGE = "The Codex usage API request returned unauthorized." class CodexApiError(RuntimeError): @@ -48,6 +50,14 @@ def needs_refresh(self) -> bool: return datetime.now(timezone.utc) - self.last_refresh > timedelta(days=8) +@dataclass(slots=True) +class CredentialRecoveryCandidate: + home_path: str + credentials: AuthCredentials + identity: AuthBackedIdentity + freshness: datetime + + def load_identity(codex_home_path: str) -> AuthBackedIdentity: credentials = _load_credentials(codex_home_path) return _identity_from_credentials(credentials) @@ -85,8 +95,14 @@ def fetch_snapshot( credentials = _load_credentials(account.codex_home_path) if credentials.needs_refresh and credentials.refresh_token: - credentials = _refresh(credentials) - _save_credentials(credentials, account.codex_home_path) + try: + credentials = _refresh(credentials) + _save_credentials(credentials, account.codex_home_path) + except CodexApiError: + recovered = _recover_snapshot(account, {credentials.refresh_token}, verify_live_data) + if recovered is not None: + return recovered + raise try: if verify_live_data: @@ -101,11 +117,17 @@ def fetch_snapshot( fallback_email=account.email_hint, ) except CodexApiError as error: - if str(error) != "The Codex usage API request returned unauthorized." or not credentials.refresh_token: + if str(error) != UNAUTHORIZED_MESSAGE or not credentials.refresh_token: raise - credentials = _refresh(credentials) - _save_credentials(credentials, account.codex_home_path) + try: + credentials = _refresh(credentials) + _save_credentials(credentials, account.codex_home_path) + except CodexApiError: + recovered = _recover_snapshot(account, {credentials.refresh_token}, verify_live_data) + if recovered is not None: + return recovered + raise if verify_live_data: return _fetch_verified_snapshot( codex_home_path=account.codex_home_path, @@ -193,11 +215,13 @@ def _load_credentials(codex_home_path: str) -> AuthCredentials: if not access_token: raise CodexApiError("The required token fields are missing from `auth.json`.") + id_token = _string_value(tokens, "id_token") + return AuthCredentials( access_token=access_token, refresh_token=_string_value(tokens, "refresh_token") or "", - id_token=_string_value(tokens, "id_token"), - account_id=_string_value(tokens, "account_id"), + id_token=id_token, + account_id=_string_value(tokens, "account_id") or _account_id_from_id_token(id_token), last_refresh=parse_datetime(payload.get("last_refresh")), ) @@ -271,7 +295,7 @@ def _refresh(credentials: AuthCredentials) -> AuthCredentials: access_token=str(payload.get("access_token") or credentials.access_token), refresh_token=str(payload.get("refresh_token") or credentials.refresh_token), id_token=payload.get("id_token") or credentials.id_token, - account_id=credentials.account_id, + account_id=credentials.account_id or _account_id_from_id_token(payload.get("id_token")), last_refresh=datetime.now(timezone.utc), ) @@ -303,7 +327,7 @@ def _fetch_usage(access_token: str, account_id: str | None, codex_home_path: str raise CodexApiError("The Codex API response was not in the expected format.") if response.status_code in (401, 403): - raise CodexApiError("The Codex usage API request returned unauthorized.") + raise CodexApiError(UNAUTHORIZED_MESSAGE) message = response.text.strip() if message: @@ -525,6 +549,154 @@ def _role_for_window(window: UsageWindowSnapshot) -> str: return "unknown" +def _recover_snapshot( + account: StoredAccount, + excluded_refresh_tokens: set[str], + verify_live_data: bool, +) -> AccountUsageSnapshot | None: + for candidate in _recovery_candidates(account, excluded_refresh_tokens): + credentials = candidate.credentials + + try: + if credentials.needs_refresh and credentials.refresh_token: + credentials = _refresh(credentials) + _save_credentials(credentials, candidate.home_path) + + snapshot = _fetch_candidate_snapshot( + codex_home_path=candidate.home_path, + credentials=credentials, + fallback_email=account.email_hint, + verify_live_data=verify_live_data, + ) + _save_credentials(credentials, account.codex_home_path) + return snapshot + except CodexApiError as error: + if str(error) != UNAUTHORIZED_MESSAGE or not credentials.refresh_token: + continue + + try: + credentials = _refresh(credentials) + _save_credentials(credentials, candidate.home_path) + snapshot = _fetch_candidate_snapshot( + codex_home_path=candidate.home_path, + credentials=credentials, + fallback_email=account.email_hint, + verify_live_data=verify_live_data, + ) + _save_credentials(credentials, account.codex_home_path) + return snapshot + except CodexApiError: + continue + + return None + + +def _fetch_candidate_snapshot( + codex_home_path: str, + credentials: AuthCredentials, + fallback_email: str | None, + verify_live_data: bool, +) -> AccountUsageSnapshot: + if verify_live_data: + return _fetch_verified_snapshot(codex_home_path, credentials, fallback_email) + return _fetch_snapshot(codex_home_path, credentials, fallback_email) + + +def _recovery_candidates( + account: StoredAccount, + excluded_refresh_tokens: set[str], +) -> list[CredentialRecoveryCandidate]: + account_home_path = _standardized_path(Path(account.codex_home_path)) + candidates: list[CredentialRecoveryCandidate] = [] + + for home_path in _recovery_home_paths(): + standardized_home_path = _standardized_path(home_path) + if standardized_home_path == account_home_path: + continue + + try: + credentials = _load_credentials(str(home_path)) + except CodexApiError: + continue + + if credentials.refresh_token in excluded_refresh_tokens: + continue + + identity = _identity_from_credentials(credentials) + if not _identity_matches_account(identity, standardized_home_path, account): + continue + + candidates.append( + CredentialRecoveryCandidate( + home_path=str(home_path), + credentials=credentials, + identity=identity, + freshness=_credentials_freshness(credentials, home_path), + ) + ) + + return sorted(candidates, key=lambda candidate: candidate.freshness, reverse=True) + + +def _recovery_home_paths() -> list[Path]: + homes: list[Path] = [] + if MANAGED_HOMES_DIRECTORY.exists(): + homes.extend(path for path in MANAGED_HOMES_DIRECTORY.iterdir() if path.is_dir()) + homes.append(AMBIENT_CODEX_HOME) + return homes + + +def _identity_matches_account(identity: AuthBackedIdentity, home_path: str, account: StoredAccount) -> bool: + if home_path == _standardized_path(Path(account.codex_home_path)): + return True + + account_provider_id = _normalize_identifier(account.provider_account_id) + identity_provider_id = _normalize_identifier(identity.provider_account_id) + if account_provider_id and identity_provider_id: + return account_provider_id == identity_provider_id + if account_provider_id or identity_provider_id: + return False + + account_subject = _normalize_identifier(account.auth_subject) + identity_subject = _normalize_identifier(identity.auth_subject) + if account_subject and identity_subject and account_subject == identity_subject: + return True + + account_email = _normalize_identifier(account.email_hint) + identity_email = _normalize_identifier(identity.email) + return bool(account_email and identity_email and account_email == identity_email) + + +def _credentials_freshness(credentials: AuthCredentials, home_path: Path) -> datetime: + if credentials.last_refresh is not None: + return credentials.last_refresh + + auth_path = home_path / "auth.json" + try: + return datetime.fromtimestamp(auth_path.stat().st_mtime, tz=timezone.utc) + except OSError: + return datetime.min.replace(tzinfo=timezone.utc) + + +def _account_id_from_id_token(id_token: Any) -> str | None: + payload = _parse_jwt(id_token) if isinstance(id_token, str) else None + if not isinstance(payload, dict): + return None + + auth = payload.get("https://api.openai.com/auth") + auth = auth if isinstance(auth, dict) else {} + return _normalize_string(auth.get("chatgpt_account_id")) or _normalize_string(payload.get("chatgpt_account_id")) + + +def _standardized_path(path: Path) -> str: + return str(path.expanduser().resolve(strict=False)).casefold() + + +def _normalize_identifier(value: str | None) -> str | None: + normalized = _normalize_string(value) + return normalized.lower() if normalized else None + + def _string_value(dictionary: dict[str, Any], key: str) -> str | None: value = dictionary.get(key) if isinstance(value, str) and value: diff --git a/windows/codexcontrol_windows/models.py b/windows/codexcontrol_windows/models.py index 590d37a..1640aa4 100644 --- a/windows/codexcontrol_windows/models.py +++ b/windows/codexcontrol_windows/models.py @@ -5,7 +5,7 @@ from datetime import datetime, timezone from enum import Enum from typing import Any -from uuid import UUID +from uuid import UUID, uuid4 LEGACY_IMPORTED_VALUE = "".join(["imported", "Codex", "Bar"]) @@ -100,6 +100,10 @@ def normalized_email_hint(self) -> str | None: def normalized_auth_subject(self) -> str | None: return normalize_identifier(self.auth_subject) + @property + def normalized_provider_account_id(self) -> str | None: + return normalize_identifier(self.provider_account_id) + @property def standardized_home_path(self) -> str: return os.path.normcase(os.path.abspath(os.path.normpath(self.codex_home_path))) @@ -115,10 +119,14 @@ def recency_date(self) -> datetime: return self.last_authenticated_at or self.updated_at def matches(self, other: "StoredAccount") -> bool: - if self.normalized_auth_subject and self.normalized_auth_subject == other.normalized_auth_subject: - return True if self.standardized_home_path == other.standardized_home_path: return True + if self.normalized_provider_account_id and self.normalized_provider_account_id == other.normalized_provider_account_id: + return True + if self.normalized_provider_account_id or other.normalized_provider_account_id: + return False + if self.normalized_auth_subject and self.normalized_auth_subject == other.normalized_auth_subject: + return True if self.normalized_email_hint and self.normalized_email_hint == other.normalized_email_hint: return True return False @@ -183,6 +191,81 @@ def from_dict(cls, payload: dict[str, Any]) -> "StoredAccount": ) +@dataclass(slots=True) +class RemovedAccountIdentity: + id: UUID + email_hint: str | None + auth_subject: str | None + provider_account_id: str | None + codex_home_path: str + source: StoredAccountSource + removed_at: datetime + + @classmethod + def from_account(cls, account: StoredAccount) -> "RemovedAccountIdentity": + return cls( + id=uuid4(), + email_hint=account.email_hint, + auth_subject=account.auth_subject, + provider_account_id=account.provider_account_id, + codex_home_path=account.codex_home_path, + source=account.source, + removed_at=utc_now(), + ) + + @property + def normalized_email_hint(self) -> str | None: + return normalize_identifier(self.email_hint) + + @property + def normalized_auth_subject(self) -> str | None: + return normalize_identifier(self.auth_subject) + + @property + def normalized_provider_account_id(self) -> str | None: + return normalize_identifier(self.provider_account_id) + + @property + def standardized_home_path(self) -> str: + return os.path.normcase(os.path.abspath(os.path.normpath(self.codex_home_path))) + + def matches(self, account: StoredAccount) -> bool: + if self.standardized_home_path == account.standardized_home_path: + return True + if self.normalized_provider_account_id and self.normalized_provider_account_id == account.normalized_provider_account_id: + return True + if self.normalized_provider_account_id or account.normalized_provider_account_id: + return False + if self.normalized_auth_subject and self.normalized_auth_subject == account.normalized_auth_subject: + return True + if self.normalized_email_hint and self.normalized_email_hint == account.normalized_email_hint: + return True + return False + + def to_dict(self) -> dict[str, Any]: + return { + "id": str(self.id), + "emailHint": self.email_hint, + "authSubject": self.auth_subject, + "providerAccountID": self.provider_account_id, + "codexHomePath": self.codex_home_path, + "source": self.source.value, + "removedAt": format_datetime(self.removed_at), + } + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> "RemovedAccountIdentity": + return cls( + id=UUID(str(payload.get("id") or uuid4())), + email_hint=payload.get("emailHint"), + auth_subject=payload.get("authSubject"), + provider_account_id=payload.get("providerAccountID"), + codex_home_path=str(payload.get("codexHomePath") or ""), + source=StoredAccountSource.from_raw(str(payload.get("source") or StoredAccountSource.MANAGED_BY_APP.value)), + removed_at=parse_datetime(payload.get("removedAt")) or utc_now(), + ) + + @dataclass(slots=True) class AccountRuntimeState: snapshot: "AccountUsageSnapshot | None" = None diff --git a/windows/codexcontrol_windows/stores.py b/windows/codexcontrol_windows/stores.py index f208592..e2d0708 100644 --- a/windows/codexcontrol_windows/stores.py +++ b/windows/codexcontrol_windows/stores.py @@ -7,7 +7,7 @@ from uuid import UUID from .file_locations import ACCOUNTS_FILE, SNAPSHOTS_FILE, ensure_directories -from .models import AccountUsageSnapshot, StoredAccount +from .models import AccountUsageSnapshot, RemovedAccountIdentity, StoredAccount def _fold_text(value: str) -> str: @@ -16,21 +16,42 @@ def _fold_text(value: str) -> str: class AccountStore: - current_version = 1 + current_version = 2 def load_accounts(self) -> list[StoredAccount]: + accounts, _ = self.load_account_list() + return accounts + + def load_removed_accounts(self) -> list[RemovedAccountIdentity]: + _, removed_accounts = self.load_account_list() + return removed_accounts + + def load_account_list(self) -> tuple[list[StoredAccount], list[RemovedAccountIdentity]]: if not ACCOUNTS_FILE.exists(): - return [] + return [], [] payload = json.loads(ACCOUNTS_FILE.read_text(encoding="utf-8")) accounts = [StoredAccount.from_dict(item) for item in payload.get("accounts", [])] - return self._sorted(accounts) - - def save_accounts(self, accounts: Iterable[StoredAccount]) -> None: + removed_accounts = [ + RemovedAccountIdentity.from_dict(item) + for item in payload.get("removedAccounts", []) + if isinstance(item, dict) + ] + return self._sorted(accounts), removed_accounts + + def save_accounts( + self, + accounts: Iterable[StoredAccount], + removed_accounts: Iterable[RemovedAccountIdentity] | None = None, + ) -> None: ensure_directories() + if removed_accounts is None: + removed_accounts = self.load_removed_accounts() if ACCOUNTS_FILE.exists() else [] + payload = { "version": self.current_version, "accounts": [account.to_dict() for account in self._sorted(list(accounts))], + "removedAccounts": [removed.to_dict() for removed in removed_accounts], } ACCOUNTS_FILE.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") diff --git a/windows/tests/test_account_manager.py b/windows/tests/test_account_manager.py index 6ad7724..187229f 100644 --- a/windows/tests/test_account_manager.py +++ b/windows/tests/test_account_manager.py @@ -36,6 +36,53 @@ def _write_auth(home_path: Path, email: str, account_id: str) -> None: class CodexAccountManagerTests(unittest.TestCase): + def test_remove_managed_account_removes_duplicate_homes_for_same_provider(self) -> None: + account_id = "83c5ae92-f5ee-41f8-9528-199110d1d0f9" + now = datetime(2026, 4, 23, tzinfo=timezone.utc) + + with TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + managed_homes_dir = root / "managed-homes" + backups_dir = root / "auth-backups" + first_home = managed_homes_dir / "first" + duplicate_home = managed_homes_dir / "duplicate" + other_home = managed_homes_dir / "other" + for home in (first_home, duplicate_home, other_home): + home.mkdir(parents=True) + + _write_auth(first_home, "user@example.com", account_id) + _write_auth(duplicate_home, "user@example.com", account_id) + _write_auth(other_home, "user@example.com", "different-provider") + + manager = CodexAccountManager() + account = StoredAccount( + id=uuid4(), + nickname=None, + email_hint="user@example.com", + auth_subject=f"auth0|{account_id}", + provider_account_id=account_id, + codex_home_path=str(first_home), + source=StoredAccountSource.MANAGED_BY_APP, + created_at=now, + updated_at=now, + last_authenticated_at=now, + ) + + def ensure_dirs() -> None: + managed_homes_dir.mkdir(parents=True, exist_ok=True) + backups_dir.mkdir(parents=True, exist_ok=True) + + with ( + patch("codexcontrol_windows.account_manager.MANAGED_HOMES_DIRECTORY", managed_homes_dir), + patch("codexcontrol_windows.account_manager.AUTH_BACKUPS_DIRECTORY", backups_dir), + patch("codexcontrol_windows.account_manager.ensure_directories", side_effect=ensure_dirs), + ): + manager.remove_managed_files_if_owned(account) + + self.assertFalse(first_home.exists()) + self.assertFalse(duplicate_home.exists()) + self.assertTrue(other_home.exists()) + def test_switch_active_account_updates_global_state_creator_id(self) -> None: old_account_id = "1ea93d04-5c50-42e3-857b-3db850785967" new_account_id = "83c5ae92-f5ee-41f8-9528-199110d1d0f9" diff --git a/windows/tests/test_codex_api.py b/windows/tests/test_codex_api.py index 3e987c8..5fb050e 100644 --- a/windows/tests/test_codex_api.py +++ b/windows/tests/test_codex_api.py @@ -3,12 +3,16 @@ import base64 import json import unittest -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone +from pathlib import Path +from tempfile import TemporaryDirectory from unittest.mock import patch from uuid import uuid4 from codexcontrol_windows.codex_api import ( AuthCredentials, + CodexApiError, + _account_id_from_id_token, _fetch_snapshot, _identity_from_credentials, _normalize_window_roles, @@ -18,6 +22,32 @@ from codexcontrol_windows.models import AccountUsageSnapshot, StoredAccount, StoredAccountSource, UsageWindowSnapshot +def _id_token(email: str, subject: str, account_id: str) -> str: + payload = { + "email": email, + "sub": subject, + "https://api.openai.com/auth": { + "chatgpt_plan_type": "team", + "chatgpt_account_id": account_id, + }, + } + encoded = base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8")).decode("utf-8").rstrip("=") + return f"header.{encoded}.signature" + + +def _write_auth(home_path: Path, access_token: str, refresh_token: str, id_token: str, last_refresh: datetime) -> None: + home_path.mkdir(parents=True, exist_ok=True) + payload = { + "tokens": { + "access_token": access_token, + "refresh_token": refresh_token, + "id_token": id_token, + }, + "last_refresh": last_refresh.isoformat().replace("+00:00", "Z"), + } + (home_path / "auth.json").write_text(json.dumps(payload, indent=2), encoding="utf-8") + + class CodexApiTests(unittest.TestCase): def test_identity_from_credentials_uses_id_token_payload(self) -> None: payload = { @@ -44,6 +74,11 @@ def test_identity_from_credentials_uses_id_token_payload(self) -> None: self.assertEqual(identity.plan, "team") self.assertEqual(identity.provider_account_id, "provider-1") + def test_account_id_from_id_token_is_used_when_token_field_is_missing(self) -> None: + id_token = _id_token("user@example.com", "auth0|user", "provider-1") + + self.assertEqual(_account_id_from_id_token(id_token), "provider-1") + def test_parse_chatgpt_base_url(self) -> None: contents = """ # comment @@ -147,6 +182,68 @@ def test_raw_fetch_snapshot_uses_credentials_identity_without_reloading_auth(sel self.assertEqual(snapshot.provider_account_id, "provider-1") self.assertEqual(snapshot.plan, "team") + def test_fetch_snapshot_recovers_from_duplicate_home_when_refresh_token_is_stale(self) -> None: + now = datetime(2026, 4, 18, tzinfo=timezone.utc) + stale = now - timedelta(days=30) + provider_id = "provider-1" + token = _id_token("user@example.com", "auth0|user", provider_id) + snapshot = AccountUsageSnapshot( + email="user@example.com", + provider_account_id=provider_id, + plan="team", + allowed=True, + limit_reached=False, + primary_window=UsageWindowSnapshot(used_percent=30.0, reset_at=now, limit_window_seconds=18_000), + secondary_window=None, + credits=None, + updated_at=now, + ) + + with TemporaryDirectory() as temp_dir: + root = Path(temp_dir) + selected_home = root / "managed-homes" / "selected" + recovery_home = root / "managed-homes" / "recovery" + ambient_home = root / ".codex" + _write_auth(selected_home, "stale-access", "stale-refresh", token, stale) + _write_auth(recovery_home, "recovery-access", "recovery-refresh", token, now) + + account = StoredAccount( + id=uuid4(), + nickname=None, + email_hint="user@example.com", + auth_subject="auth0|user", + provider_account_id=provider_id, + codex_home_path=str(selected_home), + source=StoredAccountSource.MANAGED_BY_APP, + created_at=now, + updated_at=now, + last_authenticated_at=now, + ) + + def refresh(credentials: AuthCredentials) -> AuthCredentials: + if credentials.refresh_token == "stale-refresh": + raise CodexApiError("The refresh token can no longer be reused. Sign in again for this account.") + return AuthCredentials( + access_token="fresh-access", + refresh_token="fresh-refresh", + id_token=credentials.id_token, + account_id=credentials.account_id, + last_refresh=now, + ) + + with ( + patch("codexcontrol_windows.codex_api.MANAGED_HOMES_DIRECTORY", root / "managed-homes"), + patch("codexcontrol_windows.codex_api.AMBIENT_CODEX_HOME", ambient_home), + patch("codexcontrol_windows.codex_api._refresh", side_effect=refresh), + patch("codexcontrol_windows.codex_api._fetch_snapshot", return_value=snapshot), + ): + result = fetch_snapshot(account, verify_live_data=False) + + self.assertIs(result, snapshot) + selected_payload = json.loads((selected_home / "auth.json").read_text(encoding="utf-8")) + self.assertEqual(selected_payload["tokens"]["refresh_token"], "fresh-refresh") + self.assertEqual(selected_payload["tokens"]["account_id"], provider_id) + if __name__ == "__main__": unittest.main() diff --git a/windows/tests/test_models.py b/windows/tests/test_models.py index a8e1204..5097a08 100644 --- a/windows/tests/test_models.py +++ b/windows/tests/test_models.py @@ -6,6 +6,7 @@ from codexcontrol_windows.models import ( AccountUsageSnapshot, + RemovedAccountIdentity, StoredAccount, StoredAccountSource, UsageWindowSnapshot, @@ -46,6 +47,63 @@ def test_merge_prefers_managed_account_identity(self) -> None: self.assertEqual(original.provider_account_id, "account-1") self.assertEqual(original.codex_home_path, "C:/temp/b") + def test_matches_keeps_different_provider_accounts_separate(self) -> None: + created_at = datetime(2026, 4, 18, tzinfo=timezone.utc) + first = StoredAccount( + id=uuid4(), + nickname=None, + email_hint="user@example.com", + auth_subject="auth0|same-user", + provider_account_id="account-1", + codex_home_path="C:/temp/a", + source=StoredAccountSource.MANAGED_BY_APP, + created_at=created_at, + updated_at=created_at, + ) + second = StoredAccount( + id=uuid4(), + nickname=None, + email_hint="user@example.com", + auth_subject="auth0|same-user", + provider_account_id="account-2", + codex_home_path="C:/temp/b", + source=StoredAccountSource.MANAGED_BY_APP, + created_at=created_at, + updated_at=created_at, + ) + + self.assertFalse(first.matches(second)) + + def test_removed_identity_matches_provider_not_shared_email(self) -> None: + created_at = datetime(2026, 4, 18, tzinfo=timezone.utc) + removed_account = StoredAccount( + id=uuid4(), + nickname=None, + email_hint="user@example.com", + auth_subject="auth0|same-user", + provider_account_id="account-1", + codex_home_path="C:/temp/a", + source=StoredAccountSource.MANAGED_BY_APP, + created_at=created_at, + updated_at=created_at, + ) + other_account = StoredAccount( + id=uuid4(), + nickname=None, + email_hint="user@example.com", + auth_subject="auth0|same-user", + provider_account_id="account-2", + codex_home_path="C:/temp/b", + source=StoredAccountSource.MANAGED_BY_APP, + created_at=created_at, + updated_at=created_at, + ) + + removed = RemovedAccountIdentity.from_account(removed_account) + + self.assertTrue(removed.matches(removed_account)) + self.assertFalse(removed.matches(other_account)) + class SnapshotTests(unittest.TestCase): def test_snapshot_prefers_lowest_remaining_window(self) -> None: