From 4de2a1584359df264425684334e139b71b6f4f12 Mon Sep 17 00:00:00 2001 From: John Logan Date: Thu, 12 Mar 2026 14:16:16 -0700 Subject: [PATCH 1/6] Use Sendable DNS types. - Closes #1268. - The types we were using weren't very usable with Swift 6 structured concurrency. - Use notImplemented instead of formatError for unknown record types. - Use pure actor for LocalhostDNSHandler now that we have sendable types. --- Package.resolved | 20 +- Package.swift | 8 +- .../Handlers/HostTableResolver.swift | 27 +- .../DNSServer/Handlers/NxDomainResolver.swift | 22 +- Sources/DNSServer/Records/Bindable.swift | 100 +++++ Sources/DNSServer/Records/DNSEnums.swift | 167 ++++++++ Sources/DNSServer/Records/DNSName.swift | 142 +++++++ .../DNSServer/Records/IPAddressProtocol.swift | 34 ++ Sources/DNSServer/Records/Message.swift | 257 ++++++++++++ Sources/DNSServer/Records/Question.swift | 96 +++++ .../DNSServer/Records/ResourceRecord.swift | 97 +++++ Sources/DNSServer/Types.swift | 8 - .../APIServer/ContainerDNSHandler.swift | 29 +- .../APIServer/LocalhostDNSHandler.swift | 43 +- .../CompositeResolverTest.swift | 10 +- .../HostTableResolverTest.swift | 26 +- Tests/DNSServerTests/MockHandlers.swift | 14 +- .../DNSServerTests/NxDomainResolverTest.swift | 1 - Tests/DNSServerTests/RecordsTests.swift | 389 ++++++++++++++++++ .../StandardQueryValidatorTest.swift | 6 +- 20 files changed, 1333 insertions(+), 163 deletions(-) create mode 100644 Sources/DNSServer/Records/Bindable.swift create mode 100644 Sources/DNSServer/Records/DNSEnums.swift create mode 100644 Sources/DNSServer/Records/DNSName.swift create mode 100644 Sources/DNSServer/Records/IPAddressProtocol.swift create mode 100644 Sources/DNSServer/Records/Message.swift create mode 100644 Sources/DNSServer/Records/Question.swift create mode 100644 Sources/DNSServer/Records/ResourceRecord.swift create mode 100644 Tests/DNSServerTests/RecordsTests.swift diff --git a/Package.resolved b/Package.resolved index d1aabdfd4..e3cf72916 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "040dd2e2c8649defb737f900e4032d270935f0af91ebf4a87d64391ecd4ea40b", + "originHash" : "4ec05f4e83999a89d3397d0657536924d4a425d7f0e3f0fd6a3578e34c924502", "pins" : [ { "identity" : "async-http-client", @@ -19,24 +19,6 @@ "version" : "0.27.0" } }, - { - "identity" : "dns", - "kind" : "remoteSourceControl", - "location" : "https://github.com/Bouke/DNS.git", - "state" : { - "revision" : "78bbd1589890a90b202d11d5f9e1297050cf0eb2", - "version" : "1.2.0" - } - }, - { - "identity" : "dnsclient", - "kind" : "remoteSourceControl", - "location" : "https://github.com/orlandos-nl/DNSClient.git", - "state" : { - "revision" : "551fbddbf4fa728d4cd86f6a5208fe4f925f0549", - "version" : "2.4.4" - } - }, { "identity" : "grpc-swift", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 3a24b7971..8a8ee49bb 100644 --- a/Package.swift +++ b/Package.swift @@ -47,7 +47,6 @@ let package = Package( .library(name: "TerminalProgress", targets: ["TerminalProgress"]), ], dependencies: [ - .package(url: "https://github.com/Bouke/DNS.git", from: "1.2.0"), .package(url: "https://github.com/apple/containerization.git", exact: Version(stringLiteral: scVersion)), .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.3.0"), .package(url: "https://github.com/apple/swift-collections.git", from: "1.2.0"), @@ -56,7 +55,6 @@ let package = Package( .package(url: "https://github.com/apple/swift-protobuf.git", from: "1.29.0"), .package(url: "https://github.com/apple/swift-system.git", from: "1.4.0"), .package(url: "https://github.com/grpc/grpc-swift.git", from: "1.26.0"), - .package(url: "https://github.com/orlandos-nl/DNSClient.git", from: "2.4.1"), .package(url: "https://github.com/swift-server/async-http-client.git", from: "1.20.1"), .package(url: "https://github.com/swiftlang/swift-docc-plugin.git", from: "1.1.0"), ], @@ -427,17 +425,15 @@ let package = Package( dependencies: [ .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), - .product(name: "DNSClient", package: "DNSClient"), - .product(name: "DNS", package: "DNS"), .product(name: "Logging", package: "swift-log"), + .product(name: "ContainerizationExtras", package: "containerization"), .product(name: "ContainerizationOS", package: "containerization"), ] ), .testTarget( name: "DNSServerTests", dependencies: [ - .product(name: "DNS", package: "DNS"), - "DNSServer", + "DNSServer" ] ), .testTarget( diff --git a/Sources/DNSServer/Handlers/HostTableResolver.swift b/Sources/DNSServer/Handlers/HostTableResolver.swift index 0bc247609..ea8d2dce4 100644 --- a/Sources/DNSServer/Handlers/HostTableResolver.swift +++ b/Sources/DNSServer/Handlers/HostTableResolver.swift @@ -14,14 +14,14 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS +import ContainerizationExtras /// Handler that uses table lookup to resolve hostnames. public struct HostTableResolver: DNSHandler { - public let hosts4: [String: IPv4] + public let hosts4: [String: IPv4Address] private let ttl: UInt32 - public init(hosts4: [String: IPv4], ttl: UInt32 = 300) { + public init(hosts4: [String: IPv4Address], ttl: UInt32 = 300) { self.hosts4 = hosts4 self.ttl = ttl } @@ -48,28 +48,11 @@ public struct HostTableResolver: DNSHandler { } // If hostname doesn't exist, return nil which will become NXDOMAIN return nil - case ResourceRecordType.nameServer, - ResourceRecordType.alias, - ResourceRecordType.startOfAuthority, - ResourceRecordType.pointer, - ResourceRecordType.mailExchange, - ResourceRecordType.text, - ResourceRecordType.service, - ResourceRecordType.incrementalZoneTransfer, - ResourceRecordType.standardZoneTransfer, - ResourceRecordType.all: - return Message( - id: query.id, - type: .response, - returnCode: .notImplemented, - questions: query.questions, - answers: [] - ) default: return Message( id: query.id, type: .response, - returnCode: .formatError, + returnCode: .notImplemented, questions: query.questions, answers: [] ) @@ -93,6 +76,6 @@ public struct HostTableResolver: DNSHandler { return nil } - return HostRecord(name: question.name, ttl: ttl, ip: ip) + return HostRecord(name: question.name, ttl: ttl, ip: ip) } } diff --git a/Sources/DNSServer/Handlers/NxDomainResolver.swift b/Sources/DNSServer/Handlers/NxDomainResolver.swift index 68b36bf1d..8fa9c05b4 100644 --- a/Sources/DNSServer/Handlers/NxDomainResolver.swift +++ b/Sources/DNSServer/Handlers/NxDomainResolver.swift @@ -14,8 +14,6 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS - /// Handler that returns NXDOMAIN for all hostnames. public struct NxDomainResolver: DNSHandler { private let ttl: UInt32 @@ -35,29 +33,11 @@ public struct NxDomainResolver: DNSHandler { questions: query.questions, answers: [] ) - case ResourceRecordType.nameServer, - ResourceRecordType.alias, - ResourceRecordType.startOfAuthority, - ResourceRecordType.pointer, - ResourceRecordType.mailExchange, - ResourceRecordType.text, - ResourceRecordType.host6, - ResourceRecordType.service, - ResourceRecordType.incrementalZoneTransfer, - ResourceRecordType.standardZoneTransfer, - ResourceRecordType.all: - return Message( - id: query.id, - type: .response, - returnCode: .notImplemented, - questions: query.questions, - answers: [] - ) default: return Message( id: query.id, type: .response, - returnCode: .formatError, + returnCode: .notImplemented, questions: query.questions, answers: [] ) diff --git a/Sources/DNSServer/Records/Bindable.swift b/Sources/DNSServer/Records/Bindable.swift new file mode 100644 index 000000000..775d38f31 --- /dev/null +++ b/Sources/DNSServer/Records/Bindable.swift @@ -0,0 +1,100 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import Foundation + +// TODO: Look for a way that we can make use of the +// bit-fiddling types from ContainerizationExtras, instead +// of copying them here. + +/// Errors that can occur during DNS message serialization/deserialization. +public enum DNSBindError: Error, CustomStringConvertible { + case marshalFailure(type: String, field: String) + case unmarshalFailure(type: String, field: String) + + public var description: String { + switch self { + case .marshalFailure(let type, let field): + return "failed to marshal \(type).\(field)" + case .unmarshalFailure(let type, let field): + return "failed to unmarshal \(type).\(field)" + } + } +} + +/// Protocol for types that can be serialized to/from a byte buffer. +protocol Bindable: Sendable { + /// The fixed size of this type in bytes, if applicable. + static var size: Int { get } + + /// Serialize this value into the buffer at the given offset. + /// - Returns: The new offset after writing. + func appendBuffer(_ buffer: inout [UInt8], offset: Int) throws -> Int + + /// Deserialize this value from the buffer at the given offset. + /// - Returns: The new offset after reading. + mutating func bindBuffer(_ buffer: inout [UInt8], offset: Int) throws -> Int +} + +extension [UInt8] { + /// Copy a value into the buffer at the given offset. + /// - Returns: The new offset after writing, or nil if the buffer is too small. + package mutating func copyIn(as type: T.Type, value: T, offset: Int = 0) -> Int? { + let size = MemoryLayout.size + guard self.count >= size + offset else { + return nil + } + return self.withUnsafeMutableBytes { + $0.baseAddress?.advanced(by: offset).assumingMemoryBound(to: T.self).pointee = value + return offset + size + } + } + + /// Copy a value out of the buffer at the given offset. + /// - Returns: A tuple of (new offset, value), or nil if the buffer is too small. + package func copyOut(as type: T.Type, offset: Int = 0) -> (Int, T)? { + let size = MemoryLayout.size + guard self.count >= size + offset else { + return nil + } + return self.withUnsafeBytes { + guard let value = $0.baseAddress?.advanced(by: offset).assumingMemoryBound(to: T.self).pointee else { + return nil + } + return (offset + size, value) + } + } + + /// Copy a byte array into the buffer at the given offset. + /// - Returns: The new offset after writing, or nil if the buffer is too small. + package mutating func copyIn(buffer: [UInt8], offset: Int = 0) -> Int? { + guard offset + buffer.count <= self.count else { + return nil + } + self[offset.. Int? { + guard offset + buffer.count <= self.count else { + return nil + } + buffer[0.. Int { + var offset = offset + + for label in labels { + let bytes = Array(label.utf8) + guard bytes.count <= 63 else { + throw DNSBindError.marshalFailure(type: "DNSName", field: "label too long") + } + + guard let newOffset = buffer.copyIn(as: UInt8.self, value: UInt8(bytes.count), offset: offset) else { + throw DNSBindError.marshalFailure(type: "DNSName", field: "label length") + } + offset = newOffset + + guard let newOffset = buffer.copyIn(buffer: bytes, offset: offset) else { + throw DNSBindError.marshalFailure(type: "DNSName", field: "label") + } + offset = newOffset + } + + // Null terminator + guard let newOffset = buffer.copyIn(as: UInt8.self, value: 0, offset: offset) else { + throw DNSBindError.marshalFailure(type: "DNSName", field: "terminator") + } + + return newOffset + } + + /// Deserialize a name from the buffer at the given offset. + /// + /// - Parameters: + /// - buffer: The buffer to read from. + /// - offset: The offset to start reading. + /// - messageStart: The start of the DNS message (for compression pointer resolution). + /// - Returns: The new offset after reading. + public mutating func bindBuffer( + _ buffer: inout [UInt8], + offset: Int, + messageStart: Int = 0 + ) throws -> Int { + var offset = offset + labels = [] + var jumped = false + var returnOffset = offset + + while true { + guard offset < buffer.count else { + throw DNSBindError.unmarshalFailure(type: "DNSName", field: "unexpected end") + } + + let length = buffer[offset] + + // Check for compression pointer (top 2 bits set) + if (length & 0xC0) == 0xC0 { + guard offset + 1 < buffer.count else { + throw DNSBindError.unmarshalFailure(type: "DNSName", field: "compression pointer") + } + + if !jumped { + returnOffset = offset + 2 + } + + // Calculate pointer offset from message start + let pointer = Int(length & 0x3F) << 8 | Int(buffer[offset + 1]) + offset = messageStart + pointer + jumped = true + continue + } + + offset += 1 + + // Null terminator - end of name + if length == 0 { + break + } + + guard offset + Int(length) <= buffer.count else { + throw DNSBindError.unmarshalFailure(type: "DNSName", field: "label") + } + + let labelBytes = Array(buffer[offset..> 11) & 0x0F)) ?? .query + self.authoritativeAnswer = (flags & 0x0400) != 0 + self.truncation = (flags & 0x0200) != 0 + self.recursionDesired = (flags & 0x0100) != 0 + self.recursionAvailable = (flags & 0x0080) != 0 + self.returnCode = ReturnCode(rawValue: UInt8(flags & 0x000F)) ?? .noError + + // Read counts + guard let (newOffset, rawQdCount) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Message", field: "qdcount") + } + let qdCount = UInt16(bigEndian: rawQdCount) + offset = newOffset + + guard let (newOffset, rawAnCount) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Message", field: "ancount") + } + let anCount = UInt16(bigEndian: rawAnCount) + offset = newOffset + + guard let (newOffset, rawNsCount) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Message", field: "nscount") + } + // nsCount not used for now, but we need to read past it + _ = UInt16(bigEndian: rawNsCount) + offset = newOffset + + guard let (newOffset, rawArCount) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Message", field: "arcount") + } + // arCount not used for now, but we need to read past it + _ = UInt16(bigEndian: rawArCount) + offset = newOffset + + // Read questions + self.questions = [] + for _ in 0.. Data { + // Calculate buffer size (estimate) + var bufferSize = Self.headerSize + for question in questions { + bufferSize += DNSName(question.name).size + 4 // name + type + class + } + for answer in answers { + bufferSize += DNSName(answer.name).size + 10 + 16 // name + type + class + ttl + rdlen + rdata (max) + } + bufferSize += 64 // padding for safety + + var buffer = [UInt8](repeating: 0, count: bufferSize) + var offset = 0 + + // Write ID + guard let newOffset = buffer.copyIn(as: UInt16.self, value: id.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "id") + } + offset = newOffset + + // Build and write flags + var flags: UInt16 = 0 + flags |= type == .response ? 0x8000 : 0 + flags |= UInt16(operationCode.rawValue) << 11 + flags |= authoritativeAnswer ? 0x0400 : 0 + flags |= truncation ? 0x0200 : 0 + flags |= recursionDesired ? 0x0100 : 0 + flags |= recursionAvailable ? 0x0080 : 0 + flags |= UInt16(returnCode.rawValue) + + guard let newOffset = buffer.copyIn(as: UInt16.self, value: flags.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "flags") + } + offset = newOffset + + // Write counts + guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(questions.count).bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "qdcount") + } + offset = newOffset + + guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(answers.count).bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "ancount") + } + offset = newOffset + + guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(authorities.count).bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "nscount") + } + offset = newOffset + + guard let newOffset = buffer.copyIn(as: UInt16.self, value: UInt16(additional.count).bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Message", field: "arcount") + } + offset = newOffset + + // Write questions + for question in questions { + offset = try question.appendBuffer(&buffer, offset: offset) + } + + // Write answers + for answer in answers { + offset = try answer.appendBuffer(&buffer, offset: offset) + } + + // Write authorities + for authority in authorities { + offset = try authority.appendBuffer(&buffer, offset: offset) + } + + // Write additional + for record in additional { + offset = try record.appendBuffer(&buffer, offset: offset) + } + + return Data(buffer[0.. Int { + var offset = offset + + // Write name + let dnsName = DNSName(name) + offset = try dnsName.appendBuffer(&buffer, offset: offset) + + // Write type (big-endian) + guard let newOffset = buffer.copyIn(as: UInt16.self, value: type.rawValue.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Question", field: "type") + } + offset = newOffset + + // Write class (big-endian) + guard let newOffset = buffer.copyIn(as: UInt16.self, value: recordClass.rawValue.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "Question", field: "class") + } + + return newOffset + } + + /// Deserialize a question from the buffer. + public mutating func bindBuffer(_ buffer: inout [UInt8], offset: Int, messageStart: Int = 0) throws -> Int { + var offset = offset + + // Read name + var dnsName = DNSName() + offset = try dnsName.bindBuffer(&buffer, offset: offset, messageStart: messageStart) + self.name = dnsName.description + + // Read type (big-endian) + guard let (newOffset, rawType) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Question", field: "type") + } + guard let qtype = ResourceRecordType(rawValue: UInt16(bigEndian: rawType)) else { + throw DNSBindError.unmarshalFailure(type: "Question", field: "type value") + } + self.type = qtype + offset = newOffset + + // Read class (big-endian) + guard let (newOffset, rawClass) = buffer.copyOut(as: UInt16.self, offset: offset) else { + throw DNSBindError.unmarshalFailure(type: "Question", field: "class") + } + guard let qclass = ResourceRecordClass(rawValue: UInt16(bigEndian: rawClass)) else { + throw DNSBindError.unmarshalFailure(type: "Question", field: "class value") + } + self.recordClass = qclass + + return newOffset + } +} diff --git a/Sources/DNSServer/Records/ResourceRecord.swift b/Sources/DNSServer/Records/ResourceRecord.swift new file mode 100644 index 000000000..a193dbd24 --- /dev/null +++ b/Sources/DNSServer/Records/ResourceRecord.swift @@ -0,0 +1,97 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import Foundation + +/// Protocol for DNS resource records. +public protocol ResourceRecord: Sendable { + /// The domain name this record applies to. + var name: String { get } + + /// The record type. + var type: ResourceRecordType { get } + + /// The record class. + var recordClass: ResourceRecordClass { get } + + /// Time to live in seconds. + var ttl: UInt32 { get } + + /// Serialize this record into the buffer. + func appendBuffer(_ buffer: inout [UInt8], offset: Int) throws -> Int +} + +/// A host record (A or AAAA) containing an IP address. +public struct HostRecord: ResourceRecord { + public let name: String + public let type: ResourceRecordType + public let recordClass: ResourceRecordClass + public let ttl: UInt32 + public let ip: T + + public init( + name: String, + ttl: UInt32 = 300, + ip: T, + recordClass: ResourceRecordClass = .internet + ) { + self.name = name + self.type = T.recordType + self.recordClass = recordClass + self.ttl = ttl + self.ip = ip + } + + public func appendBuffer(_ buffer: inout [UInt8], offset: Int) throws -> Int { + var offset = offset + + // Write name + let dnsName = DNSName(name) + offset = try dnsName.appendBuffer(&buffer, offset: offset) + + // Write type (big-endian) + guard let newOffset = buffer.copyIn(as: UInt16.self, value: type.rawValue.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "HostRecord", field: "type") + } + offset = newOffset + + // Write class (big-endian) + guard let newOffset = buffer.copyIn(as: UInt16.self, value: recordClass.rawValue.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "HostRecord", field: "class") + } + offset = newOffset + + // Write TTL (big-endian) + guard let newOffset = buffer.copyIn(as: UInt32.self, value: ttl.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "HostRecord", field: "ttl") + } + offset = newOffset + + // Write rdlength (big-endian) + let rdlength = UInt16(T.size) + guard let newOffset = buffer.copyIn(as: UInt16.self, value: rdlength.bigEndian, offset: offset) else { + throw DNSBindError.marshalFailure(type: "HostRecord", field: "rdlength") + } + offset = newOffset + + // Write IP address bytes + guard let newOffset = buffer.copyIn(buffer: ip.bytes, offset: offset) else { + throw DNSBindError.marshalFailure(type: "HostRecord", field: "rdata") + } + + return newOffset + } +} diff --git a/Sources/DNSServer/Types.swift b/Sources/DNSServer/Types.swift index 60e5cd6e0..796d61911 100644 --- a/Sources/DNSServer/Types.swift +++ b/Sources/DNSServer/Types.swift @@ -14,16 +14,8 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS import Foundation -public typealias Message = DNS.Message -public typealias ResourceRecord = DNS.ResourceRecord -public typealias HostRecord = DNS.HostRecord -public typealias IPv4 = DNS.IPv4 -public typealias IPv6 = DNS.IPv6 -public typealias ReturnCode = DNS.ReturnCode - public enum DNSResolverError: Swift.Error, CustomStringConvertible { case serverError(_ msg: String) case invalidHandlerSpec(_ spec: String) diff --git a/Sources/Helpers/APIServer/ContainerDNSHandler.swift b/Sources/Helpers/APIServer/ContainerDNSHandler.swift index 4fcf8a2c4..161572e5b 100644 --- a/Sources/Helpers/APIServer/ContainerDNSHandler.swift +++ b/Sources/Helpers/APIServer/ContainerDNSHandler.swift @@ -15,7 +15,7 @@ //===----------------------------------------------------------------------===// import ContainerAPIService -import DNS +import ContainerizationExtras import DNSServer /// Handler that uses table lookup to resolve hostnames. @@ -50,28 +50,11 @@ struct ContainerDNSHandler: DNSHandler { ) } record = result.record - case ResourceRecordType.nameServer, - ResourceRecordType.alias, - ResourceRecordType.startOfAuthority, - ResourceRecordType.pointer, - ResourceRecordType.mailExchange, - ResourceRecordType.text, - ResourceRecordType.service, - ResourceRecordType.incrementalZoneTransfer, - ResourceRecordType.standardZoneTransfer, - ResourceRecordType.all: - return Message( - id: query.id, - type: .response, - returnCode: .notImplemented, - questions: query.questions, - answers: [] - ) default: return Message( id: query.id, type: .response, - returnCode: .formatError, + returnCode: .notImplemented, questions: query.questions, answers: [] ) @@ -95,11 +78,11 @@ struct ContainerDNSHandler: DNSHandler { return nil } let ipv4 = ipAllocation.ipv4Address.address.description - guard let ip = IPv4(ipv4) else { + guard let ip = try? IPv4Address(ipv4) else { throw DNSResolverError.serverError("failed to parse IP address: \(ipv4)") } - return HostRecord(name: question.name, ttl: ttl, ip: ip) + return HostRecord(name: question.name, ttl: ttl, ip: ip) } private func answerHost6(question: Question) async throws -> (record: ResourceRecord?, hostnameExists: Bool) { @@ -110,10 +93,10 @@ struct ContainerDNSHandler: DNSHandler { return (nil, true) } let ipv6 = ipv6Address.address.description - guard let ip = IPv6(ipv6) else { + guard let ip = try? IPv6Address(ipv6) else { throw DNSResolverError.serverError("failed to parse IPv6 address: \(ipv6)") } - return (HostRecord(name: question.name, ttl: ttl, ip: ip), true) + return (HostRecord(name: question.name, ttl: ttl, ip: ip), true) } } diff --git a/Sources/Helpers/APIServer/LocalhostDNSHandler.swift b/Sources/Helpers/APIServer/LocalhostDNSHandler.swift index cf87badb4..292eea230 100644 --- a/Sources/Helpers/APIServer/LocalhostDNSHandler.swift +++ b/Sources/Helpers/APIServer/LocalhostDNSHandler.swift @@ -18,7 +18,7 @@ import ContainerAPIClient import ContainerOS import ContainerPersistence import ContainerizationError -import DNS +import ContainerizationExtras import DNSServer import Foundation import Logging @@ -28,42 +28,41 @@ actor LocalhostDNSHandler: DNSHandler { private let ttl: UInt32 private let watcher: DirectoryWatcher - private let dns: Mutex<[String: IPv4]> + private var dns: [String: IPv4Address] public init(resolversURL: URL = HostDNSResolver.defaultConfigPath, ttl: UInt32 = 5, log: Logger) { self.ttl = ttl self.watcher = DirectoryWatcher(directoryURL: resolversURL, log: log) - self.dns = Mutex([:]) + self.dns = [:] } public func monitorResolvers() async { - await self.watcher.startWatching { fileURLs in - var dns: [String: String] = [:] + await self.watcher.startWatching { [weak self] fileURLs in + var dns: [String: IPv4Address] = [:] let regex = try Regex(HostDNSResolver.localhostOptionsRegex) for file in fileURLs.filter({ $0.lastPathComponent.starts(with: HostDNSResolver.containerizationPrefix) }) { let content = try String(contentsOf: file, encoding: .utf8) if let match = content.firstMatch(of: regex), - let ipv4 = (match[1].substring.map { String($0) }) + let ipv4 = (match[1].substring.map { try? IPv4Address(String($0)) }) { let name = String(file.lastPathComponent.dropFirst(HostDNSResolver.containerizationPrefix.count)) dns[name + "."] = ipv4 } } - self.dns.withLock { $0 = dns.compactMapValues { IPv4($0) } } + Task { await self?.updateDNS(dns) } } } - nonisolated public func answer(query: Message) async throws -> Message? { + public func answer(query: Message) async throws -> Message? { let question = query.questions[0] var record: ResourceRecord? switch question.type { case ResourceRecordType.host: - let dns = dns.withLock { $0 } if let ip = dns[question.name] { - record = HostRecord(name: question.name, ttl: ttl, ip: ip) + record = HostRecord(name: question.name, ttl: ttl, ip: ip) } case ResourceRecordType.host6: return Message( @@ -73,28 +72,11 @@ actor LocalhostDNSHandler: DNSHandler { questions: query.questions, answers: [] ) - case ResourceRecordType.nameServer, - ResourceRecordType.alias, - ResourceRecordType.startOfAuthority, - ResourceRecordType.pointer, - ResourceRecordType.mailExchange, - ResourceRecordType.text, - ResourceRecordType.service, - ResourceRecordType.incrementalZoneTransfer, - ResourceRecordType.standardZoneTransfer, - ResourceRecordType.all: - return Message( - id: query.id, - type: .response, - returnCode: .notImplemented, - questions: query.questions, - answers: [] - ) default: return Message( id: query.id, type: .response, - returnCode: .formatError, + returnCode: .notImplemented, questions: query.questions, answers: [] ) @@ -112,4 +94,9 @@ actor LocalhostDNSHandler: DNSHandler { answers: [record] ) } + + private func updateDNS(_ dns: [String: IPv4Address]) { + self.dns = dns + } + } diff --git a/Tests/DNSServerTests/CompositeResolverTest.swift b/Tests/DNSServerTests/CompositeResolverTest.swift index 227c6b997..6698e8296 100644 --- a/Tests/DNSServerTests/CompositeResolverTest.swift +++ b/Tests/DNSServerTests/CompositeResolverTest.swift @@ -14,7 +14,7 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS +import ContainerizationExtras import Testing @testable import DNSServer @@ -36,8 +36,8 @@ struct CompositeResolverTest { #expect(.noError == fooResponse?.returnCode) #expect(1 == fooResponse?.id) #expect(1 == fooResponse?.answers.count) - let fooAnswer = fooResponse?.answers[0] as? HostRecord - #expect(IPv4("1.2.3.4") == fooAnswer?.ip) + let fooAnswer = fooResponse?.answers[0] as? HostRecord + #expect(try IPv4Address("1.2.3.4") == fooAnswer?.ip) let barQuery = Message( id: UInt16(1), @@ -50,8 +50,8 @@ struct CompositeResolverTest { #expect(.noError == barResponse?.returnCode) #expect(1 == barResponse?.id) #expect(1 == barResponse?.answers.count) - let barAnswer = barResponse?.answers[0] as? HostRecord - #expect(IPv4("5.6.7.8") == barAnswer?.ip) + let barAnswer = barResponse?.answers[0] as? HostRecord + #expect(try IPv4Address("5.6.7.8") == barAnswer?.ip) let otherQuery = Message( id: UInt16(1), diff --git a/Tests/DNSServerTests/HostTableResolverTest.swift b/Tests/DNSServerTests/HostTableResolverTest.swift index 1a7aff2cc..141fc73db 100644 --- a/Tests/DNSServerTests/HostTableResolverTest.swift +++ b/Tests/DNSServerTests/HostTableResolverTest.swift @@ -14,16 +14,14 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS +import ContainerizationExtras import Testing @testable import DNSServer struct HostTableResolverTest { @Test func testUnsupportedQuestionType() async throws { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } + let ip = try IPv4Address("1.2.3.4") let handler = HostTableResolver(hosts4: ["foo": ip]) let query = Message( @@ -43,9 +41,7 @@ struct HostTableResolverTest { } @Test func testAAAAQueryReturnsNoDataWhenARecordExists() async throws { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } + let ip = try IPv4Address("1.2.3.4") let handler = HostTableResolver(hosts4: ["foo": ip]) let query = Message( @@ -67,9 +63,7 @@ struct HostTableResolverTest { } @Test func testAAAAQueryReturnsNilWhenHostDoesNotExist() async throws { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } + let ip = try IPv4Address("1.2.3.4") let handler = HostTableResolver(hosts4: ["foo": ip]) let query = Message( @@ -86,9 +80,7 @@ struct HostTableResolverTest { } @Test func testHostNotPresent() async throws { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } + let ip = try IPv4Address("1.2.3.4") let handler = HostTableResolver(hosts4: ["foo": ip]) let query = Message( @@ -104,9 +96,7 @@ struct HostTableResolverTest { } @Test func testHostPresent() async throws { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } + let ip = try IPv4Address("1.2.3.4") let handler = HostTableResolver(hosts4: ["foo": ip]) let query = Message( @@ -125,7 +115,7 @@ struct HostTableResolverTest { #expect("foo" == response?.questions[0].name) #expect(.host == response?.questions[0].type) #expect(1 == response?.answers.count) - let answer = response?.answers[0] as? HostRecord - #expect(IPv4("1.2.3.4") == answer?.ip) + let answer = response?.answers[0] as? HostRecord + #expect(try IPv4Address("1.2.3.4") == answer?.ip) } } diff --git a/Tests/DNSServerTests/MockHandlers.swift b/Tests/DNSServerTests/MockHandlers.swift index 6496c7e04..306540436 100644 --- a/Tests/DNSServerTests/MockHandlers.swift +++ b/Tests/DNSServerTests/MockHandlers.swift @@ -14,7 +14,7 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS +import ContainerizationExtras import Testing @testable import DNSServer @@ -22,15 +22,13 @@ import Testing struct FooHandler: DNSHandler { public func answer(query: Message) async throws -> Message? { if query.questions[0].name == "foo" { - guard let ip = IPv4("1.2.3.4") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } + let ip = try IPv4Address("1.2.3.4") return Message( id: query.id, type: .response, returnCode: .noError, questions: query.questions, - answers: [HostRecord(name: query.questions[0].name, ttl: 0, ip: ip)] + answers: [HostRecord(name: query.questions[0].name, ttl: 0, ip: ip)] ) } return nil @@ -41,15 +39,13 @@ struct BarHandler: DNSHandler { public func answer(query: Message) async throws -> Message? { let question = query.questions[0] if question.name == "foo" || question.name == "bar" { - guard let ip = IPv4("5.6.7.8") else { - throw DNSResolverError.serverError("cannot create IP address in test") - } + let ip = try IPv4Address("5.6.7.8") return Message( id: query.id, type: .response, returnCode: .noError, questions: query.questions, - answers: [HostRecord(name: query.questions[0].name, ttl: 0, ip: ip)] + answers: [HostRecord(name: query.questions[0].name, ttl: 0, ip: ip)] ) } return nil diff --git a/Tests/DNSServerTests/NxDomainResolverTest.swift b/Tests/DNSServerTests/NxDomainResolverTest.swift index db592e56d..8eead6906 100644 --- a/Tests/DNSServerTests/NxDomainResolverTest.swift +++ b/Tests/DNSServerTests/NxDomainResolverTest.swift @@ -14,7 +14,6 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS import Testing @testable import DNSServer diff --git a/Tests/DNSServerTests/RecordsTests.swift b/Tests/DNSServerTests/RecordsTests.swift new file mode 100644 index 000000000..ae2a7712a --- /dev/null +++ b/Tests/DNSServerTests/RecordsTests.swift @@ -0,0 +1,389 @@ +//===----------------------------------------------------------------------===// +// Copyright © 2026 Apple Inc. and the container project authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//===----------------------------------------------------------------------===// + +import ContainerizationExtras +import Foundation +import Testing + +@testable import DNSServer + +@Suite("DNS Records Tests") +struct RecordsTests { + + // MARK: - DNSName Tests + + @Suite("DNSName") + struct DNSNameTests { + @Test("Create from string") + func createFromString() { + let name = DNSName("example.com") + #expect(name.labels == ["example", "com"]) + } + + @Test("Create from string with trailing dot") + func createFromStringTrailingDot() { + let name = DNSName("example.com.") + #expect(name.labels == ["example", "com"]) + } + + @Test("Description includes trailing dot") + func descriptionTrailingDot() { + let name = DNSName("example.com") + #expect(name.description == "example.com.") + } + + @Test("Root domain") + func rootDomain() { + let name = DNSName("") + #expect(name.labels == []) + #expect(name.description == ".") + } + + @Test("Size calculation") + func sizeCalculation() { + let name = DNSName("example.com") + // [7]example[3]com[0] = 1 + 7 + 1 + 3 + 1 = 13 + #expect(name.size == 13) + } + + @Test("Serialize and deserialize") + func serializeDeserialize() throws { + let original = DNSName("test.example.com") + var buffer = [UInt8](repeating: 0, count: 64) + + let endOffset = try original.appendBuffer(&buffer, offset: 0) + + var parsed = DNSName() + let readOffset = try parsed.bindBuffer(&buffer, offset: 0) + + #expect(readOffset == endOffset) + #expect(parsed.labels == original.labels) + } + + @Test("Serialize subdomain") + func serializeSubdomain() throws { + let name = DNSName("a.b.c.d.example.com") + var buffer = [UInt8](repeating: 0, count: 64) + + _ = try name.appendBuffer(&buffer, offset: 0) + + var parsed = DNSName() + _ = try parsed.bindBuffer(&buffer, offset: 0) + + #expect(parsed.labels == ["a", "b", "c", "d", "example", "com"]) + } + + @Test("Reject label too long") + func rejectLabelTooLong() { + let longLabel = String(repeating: "a", count: 64) + let name = DNSName(longLabel + ".com") + var buffer = [UInt8](repeating: 0, count: 128) + + #expect(throws: DNSBindError.self) { + _ = try name.appendBuffer(&buffer, offset: 0) + } + } + } + + // MARK: - Question Tests + + @Suite("Question") + struct QuestionTests { + @Test("Create question") + func create() { + let q = Question(name: "example.com.", type: .host, recordClass: .internet) + #expect(q.name == "example.com.") + #expect(q.type == .host) + #expect(q.recordClass == .internet) + } + + @Test("Serialize and deserialize A record question") + func serializeDeserializeA() throws { + let original = Question(name: "example.com.", type: .host, recordClass: .internet) + var buffer = [UInt8](repeating: 0, count: 64) + + let endOffset = try original.appendBuffer(&buffer, offset: 0) + + var parsed = Question(name: "") + let readOffset = try parsed.bindBuffer(&buffer, offset: 0) + + #expect(readOffset == endOffset) + #expect(parsed.type == .host) + #expect(parsed.recordClass == .internet) + } + + @Test("Serialize and deserialize AAAA record question") + func serializeDeserializeAAAA() throws { + let original = Question(name: "example.com.", type: .host6, recordClass: .internet) + var buffer = [UInt8](repeating: 0, count: 64) + + _ = try original.appendBuffer(&buffer, offset: 0) + + var parsed = Question(name: "") + _ = try parsed.bindBuffer(&buffer, offset: 0) + + #expect(parsed.type == .host6) + } + } + + // MARK: - HostRecord Tests + + @Suite("HostRecord") + struct HostRecordTests { + @Test("Create A record") + func createARecord() throws { + let ip = try IPv4Address("192.168.1.1") + let record = HostRecord(name: "example.com.", ttl: 300, ip: ip) + + #expect(record.name == "example.com.") + #expect(record.type == .host) + #expect(record.ttl == 300) + #expect(record.ip == ip) + } + + @Test("Create AAAA record") + func createAAAARecord() throws { + let ip = try IPv6Address("::1") + let record = HostRecord(name: "example.com.", ttl: 600, ip: ip) + + #expect(record.name == "example.com.") + #expect(record.type == .host6) + #expect(record.ttl == 600) + } + + @Test("Serialize A record") + func serializeARecord() throws { + let ip = try IPv4Address("10.0.0.1") + let record = HostRecord(name: "test.com.", ttl: 300, ip: ip) + var buffer = [UInt8](repeating: 0, count: 64) + + let endOffset = try record.appendBuffer(&buffer, offset: 0) + + // Should have written: name + type(2) + class(2) + ttl(4) + rdlen(2) + rdata(4) + #expect(endOffset > 0) + + // Verify IP bytes are at the end + let ipStart = endOffset - 4 + #expect(buffer[ipStart] == 10) + #expect(buffer[ipStart + 1] == 0) + #expect(buffer[ipStart + 2] == 0) + #expect(buffer[ipStart + 3] == 1) + } + + @Test("Serialize AAAA record") + func serializeAAAARecord() throws { + let ip = try IPv6Address("::1") + let record = HostRecord(name: "test.com.", ttl: 300, ip: ip) + var buffer = [UInt8](repeating: 0, count: 64) + + let endOffset = try record.appendBuffer(&buffer, offset: 0) + + // Verify last byte is 1 (::1) + #expect(buffer[endOffset - 1] == 1) + } + } + + // MARK: - Message Tests + + @Suite("Message") + struct MessageTests { + @Test("Create query message") + func createQuery() { + let msg = Message( + id: 0x1234, + type: .query, + questions: [Question(name: "example.com.", type: .host)] + ) + + #expect(msg.id == 0x1234) + #expect(msg.type == .query) + #expect(msg.questions.count == 1) + } + + @Test("Create response message") + func createResponse() throws { + let ip = try IPv4Address("192.168.1.1") + let msg = Message( + id: 0x1234, + type: .response, + returnCode: .noError, + questions: [Question(name: "example.com.", type: .host)], + answers: [HostRecord(name: "example.com.", ttl: 300, ip: ip)] + ) + + #expect(msg.type == .response) + #expect(msg.returnCode == .noError) + #expect(msg.answers.count == 1) + } + + @Test("Serialize and deserialize query") + func serializeDeserializeQuery() throws { + let original = Message( + id: 0xABCD, + type: .query, + recursionDesired: true, + questions: [Question(name: "example.com.", type: .host)] + ) + + let data = try original.serialize() + let parsed = try Message(deserialize: data) + + #expect(parsed.id == 0xABCD) + #expect(parsed.type == .query) + #expect(parsed.recursionDesired == true) + #expect(parsed.questions.count == 1) + #expect(parsed.questions[0].type == .host) + } + + @Test("Serialize response with answer") + func serializeResponse() throws { + let ip = try IPv4Address("10.0.0.1") + let msg = Message( + id: 0x1234, + type: .response, + authoritativeAnswer: true, + returnCode: .noError, + questions: [Question(name: "test.com.", type: .host)], + answers: [HostRecord(name: "test.com.", ttl: 300, ip: ip)] + ) + + let data = try msg.serialize() + + // Verify we can at least parse the header back + let parsed = try Message(deserialize: data) + #expect(parsed.id == 0x1234) + #expect(parsed.type == .response) + #expect(parsed.authoritativeAnswer == true) + #expect(parsed.returnCode == .noError) + } + + @Test("Serialize NXDOMAIN response") + func serializeNxdomain() throws { + let msg = Message( + id: 0x1234, + type: .response, + returnCode: .nonExistentDomain, + questions: [Question(name: "unknown.com.", type: .host)], + answers: [] + ) + + let data = try msg.serialize() + let parsed = try Message(deserialize: data) + + #expect(parsed.returnCode == .nonExistentDomain) + #expect(parsed.answers.count == 0) + } + + @Test("Serialize NODATA response (empty answers with noError)") + func serializeNodata() throws { + let msg = Message( + id: 0x1234, + type: .response, + returnCode: .noError, + questions: [Question(name: "example.com.", type: .host6)], + answers: [] + ) + + let data = try msg.serialize() + let parsed = try Message(deserialize: data) + + #expect(parsed.returnCode == .noError) + #expect(parsed.answers.count == 0) + } + + @Test("Multiple questions") + func multipleQuestions() throws { + let msg = Message( + id: 0x1234, + type: .query, + questions: [ + Question(name: "a.com.", type: .host), + Question(name: "b.com.", type: .host6), + ] + ) + + let data = try msg.serialize() + let parsed = try Message(deserialize: data) + + #expect(parsed.questions.count == 2) + #expect(parsed.questions[0].type == .host) + #expect(parsed.questions[1].type == .host6) + } + } + + // MARK: - Wire Format Tests + + @Suite("Wire Format") + struct WireFormatTests { + @Test("Parse real DNS query bytes") + func parseRealQuery() throws { + // A minimal DNS query for "example.com" A record + // Header: ID=0x1234, QR=0, OPCODE=0, RD=1, QDCOUNT=1 + let queryBytes: [UInt8] = [ + 0x12, 0x34, // ID + 0x01, 0x00, // Flags: RD=1 + 0x00, 0x01, // QDCOUNT=1 + 0x00, 0x00, // ANCOUNT=0 + 0x00, 0x00, // NSCOUNT=0 + 0x00, 0x00, // ARCOUNT=0 + // Question: example.com A IN + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // "example" + 0x03, 0x63, 0x6f, 0x6d, // "com" + 0x00, // null terminator + 0x00, 0x01, // QTYPE=A + 0x00, 0x01, // QCLASS=IN + ] + + let msg = try Message(deserialize: Data(queryBytes)) + + #expect(msg.id == 0x1234) + #expect(msg.type == .query) + #expect(msg.recursionDesired == true) + #expect(msg.questions.count == 1) + #expect(msg.questions[0].type == .host) + #expect(msg.questions[0].recordClass == .internet) + } + + @Test("Roundtrip preserves data") + func roundtrip() throws { + let ip = try IPv4Address("1.2.3.4") + let original = Message( + id: 0xBEEF, + type: .response, + operationCode: .query, + authoritativeAnswer: true, + truncation: false, + recursionDesired: true, + recursionAvailable: true, + returnCode: .noError, + questions: [Question(name: "test.example.com.", type: .host)], + answers: [HostRecord(name: "test.example.com.", ttl: 3600, ip: ip)] + ) + + let data = try original.serialize() + let parsed = try Message(deserialize: data) + + #expect(parsed.id == original.id) + #expect(parsed.type == original.type) + #expect(parsed.authoritativeAnswer == original.authoritativeAnswer) + #expect(parsed.truncation == original.truncation) + #expect(parsed.recursionDesired == original.recursionDesired) + #expect(parsed.recursionAvailable == original.recursionAvailable) + #expect(parsed.returnCode == original.returnCode) + #expect(parsed.questions.count == original.questions.count) + } + } +} diff --git a/Tests/DNSServerTests/StandardQueryValidatorTest.swift b/Tests/DNSServerTests/StandardQueryValidatorTest.swift index 185b00b44..0c8947a25 100644 --- a/Tests/DNSServerTests/StandardQueryValidatorTest.swift +++ b/Tests/DNSServerTests/StandardQueryValidatorTest.swift @@ -14,7 +14,7 @@ // limitations under the License. //===----------------------------------------------------------------------===// -import DNS +import ContainerizationExtras import Testing @testable import DNSServer @@ -110,7 +110,7 @@ struct StandardQueryValidatorTest { #expect("foo" == response?.questions[0].name) #expect(.host == response?.questions[0].type) #expect(1 == response?.answers.count) - let answer = response?.answers[0] as? HostRecord - #expect(IPv4("1.2.3.4") == answer?.ip) + let answer = response?.answers[0] as? HostRecord + #expect(try IPv4Address("1.2.3.4") == answer?.ip) } } From f7168007c5883e5d44113ca2666639dfdf5b1211 Mon Sep 17 00:00:00 2001 From: John Logan Date: Thu, 12 Mar 2026 17:34:26 -0700 Subject: [PATCH 2/6] Additional PR feedback - domain name handling refinement. - Require all input domain names to carry trailing dot. --- .../Handlers/HostTableResolver.swift | 9 ++++++++ Sources/DNSServer/Records/Question.swift | 13 ++++++++++- .../APIServer/ContainerDNSHandler.swift | 4 ++-- .../Server/Networks/NetworksService.swift | 8 +++++-- .../CompositeResolverTest.swift | 6 ++--- .../HostTableResolverTest.swift | 22 +++++++++---------- Tests/DNSServerTests/MockHandlers.swift | 4 ++-- .../DNSServerTests/NxDomainResolverTest.swift | 4 ++-- .../StandardQueryValidatorTest.swift | 20 ++++++++--------- 9 files changed, 57 insertions(+), 33 deletions(-) diff --git a/Sources/DNSServer/Handlers/HostTableResolver.swift b/Sources/DNSServer/Handlers/HostTableResolver.swift index ea8d2dce4..f51dd8b1c 100644 --- a/Sources/DNSServer/Handlers/HostTableResolver.swift +++ b/Sources/DNSServer/Handlers/HostTableResolver.swift @@ -17,10 +17,19 @@ import ContainerizationExtras /// Handler that uses table lookup to resolve hostnames. +/// +/// All keys in `hosts4` must be canonical DNS names — fully-qualified with a +/// trailing dot (e.g. `"example.com."`). This matches the canonical form used +/// by `Question.name` when decoded from the wire. public struct HostTableResolver: DNSHandler { public let hosts4: [String: IPv4Address] private let ttl: UInt32 + /// Creates a resolver backed by a static IPv4 host table. + /// + /// - Parameter hosts4: A dictionary mapping fully-qualified domain names (with trailing dot) + /// to IPv4 addresses. Keys without a trailing dot will not match wire-decoded queries. + /// - Parameter ttl: The TTL in seconds to set on answer records. public init(hosts4: [String: IPv4Address], ttl: UInt32 = 300) { self.hosts4 = hosts4 self.ttl = ttl diff --git a/Sources/DNSServer/Records/Question.swift b/Sources/DNSServer/Records/Question.swift index e2cf0fce6..6afe1ecd5 100644 --- a/Sources/DNSServer/Records/Question.swift +++ b/Sources/DNSServer/Records/Question.swift @@ -17,8 +17,12 @@ import Foundation /// A DNS question (query). +/// +/// All domain names in a `Question` are canonical DNS names — fully-qualified +/// with a trailing dot (e.g. `"example.com."`). This invariant is not enforced +/// automatically; callers are responsible for supplying canonical names. public struct Question: Sendable, CustomStringConvertible { - /// The domain name being queried. + /// The fully-qualified domain name being queried, with a trailing dot (e.g. `"example.com."`). public var name: String /// The record type being requested. @@ -27,6 +31,13 @@ public struct Question: Sendable, CustomStringConvertible { /// The record class (usually .internet). public var recordClass: ResourceRecordClass + /// Creates a DNS question. + /// + /// - Parameter name: The fully-qualified domain name to query, with a trailing dot + /// (e.g. `"example.com."`). Supplying a name without a trailing dot will produce + /// lookup mismatches against canonically-stored records. + /// - Parameter type: The record type being requested. + /// - Parameter recordClass: The record class (usually `.internet`). public init( name: String, type: ResourceRecordType = .host, diff --git a/Sources/Helpers/APIServer/ContainerDNSHandler.swift b/Sources/Helpers/APIServer/ContainerDNSHandler.swift index 161572e5b..70fbe0452 100644 --- a/Sources/Helpers/APIServer/ContainerDNSHandler.swift +++ b/Sources/Helpers/APIServer/ContainerDNSHandler.swift @@ -74,7 +74,7 @@ struct ContainerDNSHandler: DNSHandler { } private func answerHost(question: Question) async throws -> ResourceRecord? { - guard let ipAllocation = try await networkService.lookup(hostname: question.name) else { + guard let ipAllocation = try await networkService.lookup(dnsHostname: question.name) else { return nil } let ipv4 = ipAllocation.ipv4Address.address.description @@ -86,7 +86,7 @@ struct ContainerDNSHandler: DNSHandler { } private func answerHost6(question: Question) async throws -> (record: ResourceRecord?, hostnameExists: Bool) { - guard let ipAllocation = try await networkService.lookup(hostname: question.name) else { + guard let ipAllocation = try await networkService.lookup(dnsHostname: question.name) else { return (nil, false) } guard let ipv6Address = ipAllocation.ipv6Address else { diff --git a/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift b/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift index 7966f77d2..a5d8577e4 100644 --- a/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift +++ b/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift @@ -339,8 +339,12 @@ public actor NetworksService { } /// Perform a hostname lookup on all networks. - public func lookup(hostname: String) async throws -> Attachment? { - try await self.stateLock.withLock { _ in + /// + /// - Parameter dnsHostname: A DNS-format hostname, optionally with a trailing dot + /// (e.g. `"example.com."` or `"example.com"`). The trailing dot is stripped before lookup. + public func lookup(dnsHostname: String) async throws -> Attachment? { + let hostname = dnsHostname.hasSuffix(".") ? String(dnsHostname.dropLast()) : dnsHostname + return try await self.stateLock.withLock { _ in for state in await self.serviceStates.values { guard let allocation = try await state.client.lookup(hostname: hostname) else { continue diff --git a/Tests/DNSServerTests/CompositeResolverTest.swift b/Tests/DNSServerTests/CompositeResolverTest.swift index 6698e8296..df9a4fe1b 100644 --- a/Tests/DNSServerTests/CompositeResolverTest.swift +++ b/Tests/DNSServerTests/CompositeResolverTest.swift @@ -29,7 +29,7 @@ struct CompositeResolverTest { id: UInt16(1), type: .query, questions: [ - Question(name: "foo", type: .host) + Question(name: "foo.", type: .host) ]) let fooResponse = try await resolver.answer(query: fooQuery) @@ -43,7 +43,7 @@ struct CompositeResolverTest { id: UInt16(1), type: .query, questions: [ - Question(name: "bar", type: .host) + Question(name: "bar.", type: .host) ]) let barResponse = try await resolver.answer(query: barQuery) @@ -57,7 +57,7 @@ struct CompositeResolverTest { id: UInt16(1), type: .query, questions: [ - Question(name: "other", type: .host) + Question(name: "other.", type: .host) ]) let otherResponse = try await resolver.answer(query: otherQuery) diff --git a/Tests/DNSServerTests/HostTableResolverTest.swift b/Tests/DNSServerTests/HostTableResolverTest.swift index 141fc73db..a38d328de 100644 --- a/Tests/DNSServerTests/HostTableResolverTest.swift +++ b/Tests/DNSServerTests/HostTableResolverTest.swift @@ -22,13 +22,13 @@ import Testing struct HostTableResolverTest { @Test func testUnsupportedQuestionType() async throws { let ip = try IPv4Address("1.2.3.4") - let handler = HostTableResolver(hosts4: ["foo": ip]) + let handler = HostTableResolver(hosts4: ["foo.": ip]) let query = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "foo", type: .mailExchange) + Question(name: "foo.", type: .mailExchange) ]) let response = try await handler.answer(query: query) @@ -42,13 +42,13 @@ struct HostTableResolverTest { @Test func testAAAAQueryReturnsNoDataWhenARecordExists() async throws { let ip = try IPv4Address("1.2.3.4") - let handler = HostTableResolver(hosts4: ["foo": ip]) + let handler = HostTableResolver(hosts4: ["foo.": ip]) let query = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "foo", type: .host6) + Question(name: "foo.", type: .host6) ]) let response = try await handler.answer(query: query) @@ -64,13 +64,13 @@ struct HostTableResolverTest { @Test func testAAAAQueryReturnsNilWhenHostDoesNotExist() async throws { let ip = try IPv4Address("1.2.3.4") - let handler = HostTableResolver(hosts4: ["foo": ip]) + let handler = HostTableResolver(hosts4: ["foo.": ip]) let query = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "bar", type: .host6) + Question(name: "bar.", type: .host6) ]) let response = try await handler.answer(query: query) @@ -81,13 +81,13 @@ struct HostTableResolverTest { @Test func testHostNotPresent() async throws { let ip = try IPv4Address("1.2.3.4") - let handler = HostTableResolver(hosts4: ["foo": ip]) + let handler = HostTableResolver(hosts4: ["foo.": ip]) let query = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "bar", type: .host) + Question(name: "bar.", type: .host) ]) let response = try await handler.answer(query: query) @@ -97,13 +97,13 @@ struct HostTableResolverTest { @Test func testHostPresent() async throws { let ip = try IPv4Address("1.2.3.4") - let handler = HostTableResolver(hosts4: ["foo": ip]) + let handler = HostTableResolver(hosts4: ["foo.": ip]) let query = Message( id: UInt16(1), type: .query, questions: [ - Question(name: "foo", type: .host) + Question(name: "foo.", type: .host) ]) let response = try await handler.answer(query: query) @@ -112,7 +112,7 @@ struct HostTableResolverTest { #expect(1 == response?.id) #expect(.response == response?.type) #expect(1 == response?.questions.count) - #expect("foo" == response?.questions[0].name) + #expect("foo." == response?.questions[0].name) #expect(.host == response?.questions[0].type) #expect(1 == response?.answers.count) let answer = response?.answers[0] as? HostRecord diff --git a/Tests/DNSServerTests/MockHandlers.swift b/Tests/DNSServerTests/MockHandlers.swift index 306540436..b0a4740cb 100644 --- a/Tests/DNSServerTests/MockHandlers.swift +++ b/Tests/DNSServerTests/MockHandlers.swift @@ -21,7 +21,7 @@ import Testing struct FooHandler: DNSHandler { public func answer(query: Message) async throws -> Message? { - if query.questions[0].name == "foo" { + if query.questions[0].name == "foo." { let ip = try IPv4Address("1.2.3.4") return Message( id: query.id, @@ -38,7 +38,7 @@ struct FooHandler: DNSHandler { struct BarHandler: DNSHandler { public func answer(query: Message) async throws -> Message? { let question = query.questions[0] - if question.name == "foo" || question.name == "bar" { + if question.name == "foo." || question.name == "bar." { let ip = try IPv4Address("5.6.7.8") return Message( id: query.id, diff --git a/Tests/DNSServerTests/NxDomainResolverTest.swift b/Tests/DNSServerTests/NxDomainResolverTest.swift index 8eead6906..27264c158 100644 --- a/Tests/DNSServerTests/NxDomainResolverTest.swift +++ b/Tests/DNSServerTests/NxDomainResolverTest.swift @@ -26,7 +26,7 @@ struct NxDomainResolverTest { id: UInt16(1), type: .query, questions: [ - Question(name: "foo", type: .host6) + Question(name: "foo.", type: .host6) ]) let response = try await handler.answer(query: query) @@ -45,7 +45,7 @@ struct NxDomainResolverTest { id: UInt16(1), type: .query, questions: [ - Question(name: "bar", type: .host) + Question(name: "bar.", type: .host) ]) let response = try await handler.answer(query: query) diff --git a/Tests/DNSServerTests/StandardQueryValidatorTest.swift b/Tests/DNSServerTests/StandardQueryValidatorTest.swift index 0c8947a25..8a1caeb44 100644 --- a/Tests/DNSServerTests/StandardQueryValidatorTest.swift +++ b/Tests/DNSServerTests/StandardQueryValidatorTest.swift @@ -28,7 +28,7 @@ struct StandardQueryValidatorTest { id: UInt16(1), type: .response, questions: [ - Question(name: "foo", type: .host) + Question(name: "foo.", type: .host) ]) let response = try await handler.answer(query: query) @@ -37,7 +37,7 @@ struct StandardQueryValidatorTest { #expect(1 == response?.id) #expect(.response == response?.type) #expect(1 == response?.questions.count) - #expect("foo" == response?.questions[0].name) + #expect("foo." == response?.questions[0].name) #expect(.host == response?.questions[0].type) #expect(0 == response?.answers.count) } @@ -51,7 +51,7 @@ struct StandardQueryValidatorTest { type: .query, operationCode: .notify, questions: [ - Question(name: "foo", type: .host) + Question(name: "foo.", type: .host) ]) let response = try await handler.answer(query: query) @@ -60,7 +60,7 @@ struct StandardQueryValidatorTest { #expect(2 == response?.id) #expect(.response == response?.type) #expect(1 == response?.questions.count) - #expect("foo" == response?.questions[0].name) + #expect("foo." == response?.questions[0].name) #expect(.host == response?.questions[0].type) #expect(0 == response?.answers.count) } @@ -73,8 +73,8 @@ struct StandardQueryValidatorTest { id: UInt16(2), type: .query, questions: [ - Question(name: "foo", type: .host), - Question(name: "bar", type: .host), + Question(name: "foo.", type: .host), + Question(name: "bar.", type: .host), ]) let response = try await handler.answer(query: query) @@ -83,9 +83,9 @@ struct StandardQueryValidatorTest { #expect(2 == response?.id) #expect(.response == response?.type) #expect(2 == response?.questions.count) - #expect("foo" == response?.questions[0].name) + #expect("foo." == response?.questions[0].name) #expect(.host == response?.questions[0].type) - #expect("bar" == response?.questions[1].name) + #expect("bar." == response?.questions[1].name) #expect(.host == response?.questions[1].type) #expect(0 == response?.answers.count) } @@ -98,7 +98,7 @@ struct StandardQueryValidatorTest { id: UInt16(2), type: .query, questions: [ - Question(name: "foo", type: .host) + Question(name: "foo.", type: .host) ]) let response = try await handler.answer(query: query) @@ -107,7 +107,7 @@ struct StandardQueryValidatorTest { #expect(2 == response?.id) #expect(.response == response?.type) #expect(1 == response?.questions.count) - #expect("foo" == response?.questions[0].name) + #expect("foo." == response?.questions[0].name) #expect(.host == response?.questions[0].type) #expect(1 == response?.answers.count) let answer = response?.answers[0] as? HostRecord From c778d70e14d68d5078e5be509830e1fbd7cedfb2 Mon Sep 17 00:00:00 2001 From: John Logan Date: Thu, 12 Mar 2026 17:37:03 -0700 Subject: [PATCH 3/6] Additional PR feedback. - Fold labels to lowercase. - Reject compression pointer recursion. --- Sources/DNSServer/Records/DNSName.swift | 10 +++-- Tests/DNSServerTests/RecordsTests.swift | 54 +++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/Sources/DNSServer/Records/DNSName.swift b/Sources/DNSServer/Records/DNSName.swift index df7b5aa6a..faae72e50 100644 --- a/Sources/DNSServer/Records/DNSName.swift +++ b/Sources/DNSServer/Records/DNSName.swift @@ -33,7 +33,7 @@ public struct DNSName: Sendable, Hashable, CustomStringConvertible { public init(_ string: String) { // Remove trailing dot if present, then split let normalized = string.hasSuffix(".") ? String(string.dropLast()) : string - self.labels = normalized.isEmpty ? [] : normalized.split(separator: ".").map(String.init) + self.labels = normalized.isEmpty ? [] : normalized.split(separator: ".").map { String($0).lowercased() } } /// The wire format size of this name in bytes. @@ -112,7 +112,11 @@ public struct DNSName: Sendable, Hashable, CustomStringConvertible { // Calculate pointer offset from message start let pointer = Int(length & 0x3F) << 8 | Int(buffer[offset + 1]) - offset = messageStart + pointer + let pointerTarget = messageStart + pointer + guard pointerTarget < offset else { + throw DNSBindError.unmarshalFailure(type: "DNSName", field: "compression pointer not prior") + } + offset = pointerTarget jumped = true continue } @@ -133,7 +137,7 @@ public struct DNSName: Sendable, Hashable, CustomStringConvertible { throw DNSBindError.unmarshalFailure(type: "DNSName", field: "label encoding") } - labels.append(label) + labels.append(label.lowercased()) offset += Int(length) } diff --git a/Tests/DNSServerTests/RecordsTests.swift b/Tests/DNSServerTests/RecordsTests.swift index ae2a7712a..74fd5077f 100644 --- a/Tests/DNSServerTests/RecordsTests.swift +++ b/Tests/DNSServerTests/RecordsTests.swift @@ -96,6 +96,60 @@ struct RecordsTests { _ = try name.appendBuffer(&buffer, offset: 0) } } + + @Test("Lowercase labels on init") + func lowercaseLabelsOnInit() { + let name = DNSName("EXAMPLE.COM") + #expect(name.labels == ["example", "com"]) + } + + @Test("Lowercase labels on init with trailing dot") + func lowercaseLabelsOnInitTrailingDot() { + let name = DNSName("Example.Com.") + #expect(name.labels == ["example", "com"]) + } + + @Test("Lowercase labels from wire format") + func lowercaseLabelsFromWire() throws { + // Wire-encode "EXAMPLE.COM" with uppercase bytes, then decode + let upper = DNSName(labels: ["EXAMPLE", "COM"]) + var buffer = [UInt8](repeating: 0, count: 64) + _ = try upper.appendBuffer(&buffer, offset: 0) + + var parsed = DNSName() + _ = try parsed.bindBuffer(&buffer, offset: 0) + #expect(parsed.labels == ["example", "com"]) + } + + @Test("Reject forward compression pointer") + func rejectForwardCompressionPointer() throws { + // Craft a packet with a forward compression pointer at offset 12 pointing to offset 20 + // Header (12 bytes) + pointer bytes + var buffer = [UInt8](repeating: 0, count: 32) + // At offset 0: compression pointer to offset 20 (forward) + buffer[0] = 0xC0 + buffer[1] = 0x14 // points to offset 20, which is > 0 + + #expect(throws: DNSBindError.self) { + var b = buffer + var name = DNSName() + _ = try name.bindBuffer(&b, offset: 0) + } + } + + @Test("Reject self-referential compression pointer") + func rejectSelfReferentialCompressionPointer() throws { + var buffer = [UInt8](repeating: 0, count: 16) + // At offset 0: compression pointer pointing back to offset 0 (same location) + buffer[0] = 0xC0 + buffer[1] = 0x00 // points to offset 0 == current offset, not prior + + #expect(throws: DNSBindError.self) { + var b = buffer + var name = DNSName() + _ = try name.bindBuffer(&b, offset: 0) + } + } } // MARK: - Question Tests From a1b63250c07334a00e222b762a96109cb35276b1 Mon Sep 17 00:00:00 2001 From: John Logan Date: Mon, 16 Mar 2026 08:52:59 -0700 Subject: [PATCH 4/6] Rename API parameter to avoid breaking change. --- Sources/Helpers/APIServer/ContainerDNSHandler.swift | 4 ++-- .../Server/Networks/NetworksService.swift | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Sources/Helpers/APIServer/ContainerDNSHandler.swift b/Sources/Helpers/APIServer/ContainerDNSHandler.swift index 70fbe0452..161572e5b 100644 --- a/Sources/Helpers/APIServer/ContainerDNSHandler.swift +++ b/Sources/Helpers/APIServer/ContainerDNSHandler.swift @@ -74,7 +74,7 @@ struct ContainerDNSHandler: DNSHandler { } private func answerHost(question: Question) async throws -> ResourceRecord? { - guard let ipAllocation = try await networkService.lookup(dnsHostname: question.name) else { + guard let ipAllocation = try await networkService.lookup(hostname: question.name) else { return nil } let ipv4 = ipAllocation.ipv4Address.address.description @@ -86,7 +86,7 @@ struct ContainerDNSHandler: DNSHandler { } private func answerHost6(question: Question) async throws -> (record: ResourceRecord?, hostnameExists: Bool) { - guard let ipAllocation = try await networkService.lookup(dnsHostname: question.name) else { + guard let ipAllocation = try await networkService.lookup(hostname: question.name) else { return (nil, false) } guard let ipv6Address = ipAllocation.ipv6Address else { diff --git a/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift b/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift index a5d8577e4..2f9148aed 100644 --- a/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift +++ b/Sources/Services/ContainerAPIService/Server/Networks/NetworksService.swift @@ -340,10 +340,10 @@ public actor NetworksService { /// Perform a hostname lookup on all networks. /// - /// - Parameter dnsHostname: A DNS-format hostname, optionally with a trailing dot + /// - Parameter hostname: A DNS-format hostname, optionally with a trailing dot /// (e.g. `"example.com."` or `"example.com"`). The trailing dot is stripped before lookup. - public func lookup(dnsHostname: String) async throws -> Attachment? { - let hostname = dnsHostname.hasSuffix(".") ? String(dnsHostname.dropLast()) : dnsHostname + public func lookup(hostname: String) async throws -> Attachment? { + let hostname = hostname.hasSuffix(".") ? String(hostname.dropLast()) : hostname return try await self.stateLock.withLock { _ in for state in await self.serviceStates.values { guard let allocation = try await state.client.lookup(hostname: hostname) else { From ea1e492c621444569d3c97445043f0c7af606a07 Mon Sep 17 00:00:00 2001 From: John Logan Date: Mon, 16 Mar 2026 09:10:19 -0700 Subject: [PATCH 5/6] Add default ttl to docc comment. --- Sources/DNSServer/Handlers/HostTableResolver.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/DNSServer/Handlers/HostTableResolver.swift b/Sources/DNSServer/Handlers/HostTableResolver.swift index f51dd8b1c..708ec6488 100644 --- a/Sources/DNSServer/Handlers/HostTableResolver.swift +++ b/Sources/DNSServer/Handlers/HostTableResolver.swift @@ -29,7 +29,7 @@ public struct HostTableResolver: DNSHandler { /// /// - Parameter hosts4: A dictionary mapping fully-qualified domain names (with trailing dot) /// to IPv4 addresses. Keys without a trailing dot will not match wire-decoded queries. - /// - Parameter ttl: The TTL in seconds to set on answer records. + /// - Parameter ttl: The TTL in seconds to set on answer records (default is 300). public init(hosts4: [String: IPv4Address], ttl: UInt32 = 300) { self.hosts4 = hosts4 self.ttl = ttl From a108c643780f217179b4a4de53109887d008c5ce Mon Sep 17 00:00:00 2001 From: John Logan Date: Mon, 16 Mar 2026 11:37:58 -0700 Subject: [PATCH 6/6] Fix case folding in HostTableResolver. --- .../Handlers/HostTableResolver.swift | 18 +++---- Sources/DNSServer/Records/DNSName.swift | 12 +++-- .../HostTableResolverTest.swift | 48 +++++++++++++++++++ 3 files changed, 65 insertions(+), 13 deletions(-) diff --git a/Sources/DNSServer/Handlers/HostTableResolver.swift b/Sources/DNSServer/Handlers/HostTableResolver.swift index 708ec6488..74d719063 100644 --- a/Sources/DNSServer/Handlers/HostTableResolver.swift +++ b/Sources/DNSServer/Handlers/HostTableResolver.swift @@ -18,20 +18,20 @@ import ContainerizationExtras /// Handler that uses table lookup to resolve hostnames. /// -/// All keys in `hosts4` must be canonical DNS names — fully-qualified with a -/// trailing dot (e.g. `"example.com."`). This matches the canonical form used -/// by `Question.name` when decoded from the wire. +/// Keys in `hosts4` are normalized to `DNSName` on construction, so lookups +/// are case-insensitive and trailing dots are optional. public struct HostTableResolver: DNSHandler { - public let hosts4: [String: IPv4Address] + public let hosts4: [DNSName: IPv4Address] private let ttl: UInt32 /// Creates a resolver backed by a static IPv4 host table. /// - /// - Parameter hosts4: A dictionary mapping fully-qualified domain names (with trailing dot) - /// to IPv4 addresses. Keys without a trailing dot will not match wire-decoded queries. + /// - Parameter hosts4: A dictionary mapping domain names to IPv4 addresses. + /// Keys are normalized to `DNSName` (lowercased, trailing dot stripped), so + /// `"FOO."`, `"foo."`, and `"foo"` all refer to the same entry. /// - Parameter ttl: The TTL in seconds to set on answer records (default is 300). public init(hosts4: [String: IPv4Address], ttl: UInt32 = 300) { - self.hosts4 = hosts4 + self.hosts4 = Dictionary(uniqueKeysWithValues: hosts4.map { (DNSName($0.key), $0.value) }) self.ttl = ttl } @@ -46,7 +46,7 @@ public struct HostTableResolver: DNSHandler { // This is required because musl libc has issues when A record exists but AAAA returns NXDOMAIN. // musl treats NXDOMAIN on AAAA as "domain doesn't exist" and fails DNS resolution entirely. // NODATA correctly indicates "no IPv6 address available, but domain exists". - if hosts4[question.name] != nil { + if hosts4[DNSName(question.name)] != nil { return Message( id: query.id, type: .response, @@ -81,7 +81,7 @@ public struct HostTableResolver: DNSHandler { } private func answerHost(question: Question) -> ResourceRecord? { - guard let ip = hosts4[question.name] else { + guard let ip = hosts4[DNSName(question.name)] else { return nil } diff --git a/Sources/DNSServer/Records/DNSName.swift b/Sources/DNSServer/Records/DNSName.swift index faae72e50..63fe24066 100644 --- a/Sources/DNSServer/Records/DNSName.swift +++ b/Sources/DNSServer/Records/DNSName.swift @@ -25,15 +25,19 @@ public struct DNSName: Sendable, Hashable, CustomStringConvertible { public var labels: [String] /// Creates a DNS name from an array of labels. + /// + /// Labels are lowercased to normalize for case-insensitive DNS comparison. public init(labels: [String] = []) { - self.labels = labels + self.labels = labels.map { $0.lowercased() } } - /// Creates a DNS name from a dot-separated string (e.g., "example.com."). + /// Creates a DNS name from a dot-separated string (e.g., "example.com." or "example.com"). + /// + /// A trailing dot is accepted but not required. Labels are lowercased to normalize + /// for case-insensitive DNS comparison. public init(_ string: String) { - // Remove trailing dot if present, then split let normalized = string.hasSuffix(".") ? String(string.dropLast()) : string - self.labels = normalized.isEmpty ? [] : normalized.split(separator: ".").map { String($0).lowercased() } + self.init(labels: normalized.isEmpty ? [] : normalized.split(separator: ".").map { String($0) }) } /// The wire format size of this name in bytes. diff --git a/Tests/DNSServerTests/HostTableResolverTest.swift b/Tests/DNSServerTests/HostTableResolverTest.swift index a38d328de..ceeba01d8 100644 --- a/Tests/DNSServerTests/HostTableResolverTest.swift +++ b/Tests/DNSServerTests/HostTableResolverTest.swift @@ -118,4 +118,52 @@ struct HostTableResolverTest { let answer = response?.answers[0] as? HostRecord #expect(try IPv4Address("1.2.3.4") == answer?.ip) } + + @Test func testHostPresentUppercaseTable() async throws { + let ip = try IPv4Address("1.2.3.4") + let handler = HostTableResolver(hosts4: ["FOO.": ip]) + + let query = Message( + id: UInt16(1), + type: .query, + questions: [ + Question(name: "foo.", type: .host) + ]) + + let response = try await handler.answer(query: query) + + #expect(.noError == response?.returnCode) + #expect(1 == response?.id) + #expect(.response == response?.type) + #expect(1 == response?.questions.count) + #expect("foo." == response?.questions[0].name) + #expect(.host == response?.questions[0].type) + #expect(1 == response?.answers.count) + let answer = response?.answers[0] as? HostRecord + #expect(try IPv4Address("1.2.3.4") == answer?.ip) + } + + @Test func testHostPresentUppercaseQuestion() async throws { + let ip = try IPv4Address("1.2.3.4") + let handler = HostTableResolver(hosts4: ["foo.": ip]) + + let query = Message( + id: UInt16(1), + type: .query, + questions: [ + Question(name: "FOO.", type: .host) + ]) + + let response = try await handler.answer(query: query) + + #expect(.noError == response?.returnCode) + #expect(1 == response?.id) + #expect(.response == response?.type) + #expect(1 == response?.questions.count) + #expect("FOO." == response?.questions[0].name) + #expect(.host == response?.questions[0].type) + #expect(1 == response?.answers.count) + let answer = response?.answers[0] as? HostRecord + #expect(try IPv4Address("1.2.3.4") == answer?.ip) + } }