From 798cd4d2fd759a22329f0473f06ea21269692625 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Wed, 17 Jun 2026 11:50:34 +0000 Subject: [PATCH 1/4] Pull request 2676: AGDNS-3863-gopacket-dhcp-vol.26 Updates #4923. Squashed commit of the following: commit 072006739bf2efaef83689f1a2b2c706a31d29e1 Merge: 763b23a37 54e6e3002 Author: Eugene Burkov Date: Wed Jun 17 14:40:42 2026 +0300 Merge branch 'master' into AGDNS-3863-gopacket-dhcp-vol.26 commit 763b23a37bc3c5c379832df55021f5703b9a6b0b Author: Eugene Burkov Date: Tue Jun 16 17:28:05 2026 +0300 dhcpsvc: fix lease expiry logic commit 84cb6fde32b721b86e4ffdff8aebb30fe3be662a Author: Eugene Burkov Date: Mon Jun 15 19:13:42 2026 +0300 dhcpsvc: imp tests commit cbe23b49f21218d02c56cb7f8b965663834e9457 Author: Eugene Burkov Date: Tue Jun 9 20:45:59 2026 +0300 dhcpsvc: add v6 tests --- internal/dhcpsvc/dhcpsvc_test.go | 97 ++++- internal/dhcpsvc/handler4_test.go | 272 ++++++-------- internal/dhcpsvc/handler6.go | 7 +- internal/dhcpsvc/handler6_test.go | 337 ++++++++++++++++++ internal/dhcpsvc/interface.go | 2 + internal/dhcpsvc/options6.go | 94 ++--- internal/dhcpsvc/options6_test.go | 88 +++++ internal/dhcpsvc/server.go | 26 +- .../leases.json | 26 ++ internal/dhcpsvc/v6.go | 55 ++- 10 files changed, 755 insertions(+), 249 deletions(-) create mode 100644 internal/dhcpsvc/handler6_test.go create mode 100644 internal/dhcpsvc/options6_test.go create mode 100644 internal/dhcpsvc/testdata/TestDHCPServer_ServeEther6_solicit/leases.json diff --git a/internal/dhcpsvc/dhcpsvc_test.go b/internal/dhcpsvc/dhcpsvc_test.go index f2ec21fc7c8..7d0cf912bd3 100644 --- a/internal/dhcpsvc/dhcpsvc_test.go +++ b/internal/dhcpsvc/dhcpsvc_test.go @@ -41,9 +41,6 @@ const testTimeout = 10 * time.Second // testLeaseTTL is the lease duration used in tests. const testLeaseTTL = 24 * time.Hour -// testXid is a common transaction ID for DHCPv4 tests. -const testXid = 1 - // testLogger is a common logger for tests. var testLogger = slogutil.NewDiscardLogger() @@ -102,11 +99,15 @@ const ( const ( // testRangeStartV6Str is the string representation of the range start of // the IPv6 interface used in tests. - testRangeStartV6Str = "2001:db8::1" + testRangeStartV6Str = "2001:db8::2" // testAnotherRangeStartV6Str is the string representation of the range // start of the second IPv6 interface used in tests. testAnotherRangeStartV6Str = "2001:db9::1" + + // testIfaceAddrV6Str is the string representation of the interface's IPv6 + // address used in tests. + testIfaceAddrV6Str = "2001:db8::1" ) var ( @@ -133,10 +134,23 @@ var ( RASLAACOnly: true, } - // testIfaceAddr is a common valid IPv4 address of the test network + // disabledIPv4Conf is a configuration of IPv4 part of the interfaces + // configuration that is disabled. + disabledIPv4Conf = &dhcpsvc.IPv4Config{Enabled: false} + + // disabledIPv6Conf is a configuration of IPv6 part of the interfaces + // configuration that is disabled. + disabledIPv6Conf = &dhcpsvc.IPv6Config{Enabled: false} + + // testIfaceAddrV4 is a common valid IPv4 address of the test network // interface, compliant with [testIPv4Conf], i.e. outside of the range, // within the subnet, not equal to the gateway. - testIfaceAddr = netip.MustParseAddr(testIfaceAddrV4Str) + testIfaceAddrV4 = netip.MustParseAddr(testIfaceAddrV4Str) + + // testIfaceAddrV6 is a common valid IPv6 address of the test network + // interface, compliant with [testIPv6Conf], i.e. outside of the range, + // within the subnet, not equal to the gateway. + testIfaceAddrV6 = netip.MustParseAddr(testIfaceAddrV6Str) // testIfaceHWAddr is a common valid hardware address of the test network // interface. @@ -171,19 +185,44 @@ var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{ }, } -// disabledIPv6Config is a configuration of IPv6 part of the interfaces -// configuration that is disabled. -var disabledIPv6Config = &dhcpsvc.IPv6Config{Enabled: false} +// Hardware addresses for test cases. +// +// NOTE: Keep in sync with testdata. +var ( + // testHWUnknown is the test MAC address for an unknown client. + testHWUnknown = net.HardwareAddr{0x0, 0x1, 0x2, 0x3, 0x4, 0x5} -// fullLayersStack is the complete stack of layers expected to appear in the + // testHWStatic is the test MAC address for a known static lease. + testHWStatic = net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6} + + // testHWDynamic is the test MAC address for a known dynamic lease. + testHWDynamic = net.HardwareAddr{0x2, 0x3, 0x4, 0x5, 0x6, 0x7} + + // testHWExpired is the test MAC address for a known expired lease. + testHWExpired = net.HardwareAddr{0x3, 0x4, 0x5, 0x6, 0x7, 0x8} + + // testHWAnother is the test MAC address for a lease with another IP. + testHWAnother = net.HardwareAddr{0x4, 0x5, 0x6, 0x7, 0x8, 0x9} +) + +// fullLayersStack4 is the complete stack of layers expected to appear in the // DHCP response packets. -var fullLayersStack = []gopacket.LayerType{ +var fullLayersStack4 = []gopacket.LayerType{ layers.LayerTypeEthernet, layers.LayerTypeIPv4, layers.LayerTypeUDP, layers.LayerTypeDHCPv4, } +// fullLayersStack6 is the complete stack of layers expected to appear in the +// DHCPv6 response packets. +var fullLayersStack6 = []gopacket.LayerType{ + layers.LayerTypeEthernet, + layers.LayerTypeIPv6, + layers.LayerTypeUDP, + layers.LayerTypeDHCPv6, +} + // newTempDB copies the leases database file located in the testdata FS, under // tb.Name()/leases.json, to a temporary directory and returns the path to the // copied file. @@ -235,3 +274,39 @@ func newTestDHCPServer(tb testing.TB, conf *dhcpsvc.Config) (srv *dhcpsvc.DHCPSe func startTestDHCPServer(tb testing.TB, conf *dhcpsvc.Config) { servicetest.RequireRun(tb, newTestDHCPServer(tb, conf), testTimeout) } + +// newTestPacket creates a valid packet from ls using first as first layer +// decoder. +func newTestPacket( + tb testing.TB, + first gopacket.Decoder, + ls ...gopacket.SerializableLayer, +) (pkg gopacket.Packet) { + tb.Helper() + + buf := gopacket.NewSerializeBuffer() + + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + err := gopacket.SerializeLayers(buf, opts, ls...) + require.NoError(tb, err) + + return gopacket.NewPacket(buf.Bytes(), first, gopacket.Default) +} + +// assertNoResponse asserts that no response is received on the channel within +// the timeout. +// +// TODO(e.burkov): Improve the helper to not rely on timeout. +func assertNoResponse(tb testing.TB, outCh <-chan []byte, timeout time.Duration) { + tb.Helper() + + var resp []byte + require.Panics(tb, func() { + resp, _ = testutil.RequireReceive(testutil.NewPanicT(tb), outCh, timeout) + }) + + require.Nil(tb, resp) +} diff --git a/internal/dhcpsvc/handler4_test.go b/internal/dhcpsvc/handler4_test.go index 8e06c31f44d..f4b40154737 100644 --- a/internal/dhcpsvc/handler4_test.go +++ b/internal/dhcpsvc/handler4_test.go @@ -25,7 +25,7 @@ import ( var testIPv4InterfacesConf = map[string]*dhcpsvc.InterfaceConfig{ testIfaceName: { IPv4: testIPv4Conf, - IPv6: disabledIPv6Config, + IPv6: disabledIPv6Conf, }, } @@ -33,54 +33,42 @@ var testIPv4InterfacesConf = map[string]*dhcpsvc.InterfaceConfig{ // // NOTE: Keep in sync with testdata. const ( - // testLeaseHostnameStatic is the test hostname for the static lease. - testLeaseHostnameStatic = "static4" + // testLease4HostnameStatic is the test hostname for a static DHCPv4 lease. + testLease4HostnameStatic = "static4" - // testLeaseHostnameDynamic is the test hostname for the dynamic lease. - testLeaseHostnameDynamic = "dynamic4" + // testLease4HostnameDynamic is the test hostname for a dynamic DHCPv4 + // lease. + testLease4HostnameDynamic = "dynamic4" - // testLeaseHostnameExpired is the test hostname for the expired lease. - testLeaseHostnameExpired = "expired4" + // testLease4HostnameExpired is the test hostname for an expired DHCPv4 + // lease. + testLease4HostnameExpired = "expired4" ) -// Hardware addresses for test cases. +// testXid is a common transaction ID for DHCPv4 tests. // -// NOTE: Keep in sync with testdata. -var ( - // testHWUnknown is the test MAC address for an unknown client. - testHWUnknown = net.HardwareAddr{0x0, 0x1, 0x2, 0x3, 0x4, 0x5} - - // testHWStatic is the test MAC address for a known static lease. - testHWStatic = net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6} - - // testHWDynamic is the test MAC address for a known dynamic lease. - testHWDynamic = net.HardwareAddr{0x2, 0x3, 0x4, 0x5, 0x6, 0x7} - - // testHWExpired is the test MAC address for a known expired lease. - testHWExpired = net.HardwareAddr{0x3, 0x4, 0x5, 0x6, 0x7, 0x8} - - // testHWAnother is the test MAC address for a lease with another IP. - testHWAnother = net.HardwareAddr{0x4, 0x5, 0x6, 0x7, 0x8, 0x9} -) +// TODO(e.burkov): Generate unique IDs when they will be actually used. +const testXid = 1 // IP addresses for test cases. // // NOTE: Keep in sync with testdata. var ( - // testIPUnknown is the test IP address for an unknown client. - testIPUnknown = netip.MustParseAddr("192.0.2.142") + // testIPv4Unknown is the test IP address for an unknown client. + testIPv4Unknown = netip.MustParseAddr("192.0.2.142") - // testIPStatic is the test IP address for a known static lease. - testIPStatic = netip.MustParseAddr("192.0.2.101") + // testIPv4Static is the test IP address for a known static lease. + testIPv4Static = netip.MustParseAddr("192.0.2.101") - // testIPDynamic is the test IP address for a known dynamic lease. - testIPDynamic = netip.MustParseAddr("192.0.2.102") + // testIPv4Dynamic is the test IP address for a known dynamic lease. + testIPv4Dynamic = netip.MustParseAddr("192.0.2.102") - // testIPOtherSubnet is the test IP address for a client on another subnet. - testIPOtherSubnet = netip.MustParseAddr(testAnotherGatewayIPv4Str) + // testIPv4OtherSubnet is the test IP address for a client on another + // subnet. + testIPv4OtherSubnet = netip.MustParseAddr(testAnotherGatewayIPv4Str) - // testIPRelayAgent is the test IP address of the relay agent. - testIPRelayAgent = netip.MustParseAddr("10.0.0.1") + // testIPv4RelayAgent is the test IP address of the relay agent. + testIPv4RelayAgent = netip.MustParseAddr("10.0.0.1") ) // Time-related variables for test cases. @@ -106,7 +94,7 @@ func TestDHCPServer_ServeEther4_discover(t *testing.T) { in: newDHCPDISCOVER(t, testHWUnknown), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeOffer), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testLeaseTTL), }, }, { @@ -114,27 +102,27 @@ func TestDHCPServer_ServeEther4_discover(t *testing.T) { in: newDHCPDISCOVER(t, testHWStatic), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeOffer), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testLeaseTTL), - newOptHostname(t, testLeaseHostnameStatic), + newOptHostname(t, testLease4HostnameStatic), }, }, { name: "existing_dynamic", in: newDHCPDISCOVER(t, testHWDynamic), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeOffer), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testTTLDynamicLease), - newOptHostname(t, testLeaseHostnameDynamic), + newOptHostname(t, testLease4HostnameDynamic), }, }, { name: "existing_dynamic_expired", in: newDHCPDISCOVER(t, testHWExpired), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeOffer), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testLeaseTTL), - newOptHostname(t, testLeaseHostnameExpired), + newOptHostname(t, testLease4HostnameExpired), }, }} @@ -145,7 +133,7 @@ func TestDHCPServer_ServeEther4_discover(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddr) + ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddrV4) startTestDHCPServer(t, &dhcpsvc.Config{ Interfaces: testIPv4InterfacesConf, NetworkDeviceManager: ndMgr, @@ -155,7 +143,7 @@ func TestDHCPServer_ServeEther4_discover(t *testing.T) { testutil.RequireSend(t, inCh, tc.in, testTimeout) - assertValidResponse(t, req, outCh, tc.wantOpts) + assertValidResponse4(t, req, outCh, tc.wantOpts) }) } } @@ -166,7 +154,7 @@ func TestDHCPServer_ServeEther4_discoverExpired(t *testing.T) { pkt := newDHCPDISCOVER(t, testHWUnknown) req := testutil.RequireTypeAssert[*layers.DHCPv4](t, pkt.Layer(layers.LayerTypeDHCPv4)) - ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddr) + ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddrV4) startTestDHCPServer(t, &dhcpsvc.Config{ Interfaces: testIPv4InterfacesConf, @@ -177,9 +165,9 @@ func TestDHCPServer_ServeEther4_discoverExpired(t *testing.T) { testutil.RequireSend(t, inCh, pkt, testTimeout) - assertValidResponse(t, req, outCh, layers.DHCPOptions{ + assertValidResponse4(t, req, outCh, layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeOffer), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testLeaseTTL), }) } @@ -187,18 +175,18 @@ func TestDHCPServer_ServeEther4_discoverExpired(t *testing.T) { func TestDHCPServer_ServeEther4_release(t *testing.T) { t.Parallel() - ipMismatch := testIPDynamic.Next().Next() + ipMismatch := testIPv4Dynamic.Next().Next() testCases := []struct { req gopacket.Packet name string wantChange bool }{{ - req: newDHCPRELEASE(t, testHWDynamic, testIPDynamic), + req: newDHCPRELEASE(t, testHWDynamic, testIPv4Dynamic), name: "success", wantChange: true, }, { - req: newDHCPRELEASE(t, testHWUnknown, testIPDynamic), + req: newDHCPRELEASE(t, testHWUnknown, testIPv4Dynamic), name: "not_found", wantChange: false, }, { @@ -206,7 +194,7 @@ func TestDHCPServer_ServeEther4_release(t *testing.T) { name: "mismatch_ip", wantChange: false, }, { - req: newDHCPRELEASE(t, testHWDynamic, testIPOtherSubnet), + req: newDHCPRELEASE(t, testHWDynamic, testIPv4OtherSubnet), name: "bad_subnet", wantChange: false, }} @@ -217,7 +205,7 @@ func TestDHCPServer_ServeEther4_release(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ndMgr, inCh, _ := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddr) + ndMgr, inCh, _ := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddrV4) srv := newTestDHCPServer(t, &dhcpsvc.Config{ Interfaces: testIPv4InterfacesConf, NetworkDeviceManager: ndMgr, @@ -259,7 +247,7 @@ func TestDHCPServer_ServeEther4_requestSelecting(t *testing.T) { request: newDHCPREQUEST(t, &dhcpRequestConfig{ options: layers.DHCPOptions{ newOptRequestIP(t, testIPv4Conf.RangeStart), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), }, clientHWAddr: testHWUnknown, flags: dhcpsvc.FlagsBroadcast, @@ -267,15 +255,15 @@ func TestDHCPServer_ServeEther4_requestSelecting(t *testing.T) { name: "success", wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeAck), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testLeaseTTL), }, }, { discover: newDHCPDISCOVER(t, testHWStatic), request: newDHCPREQUEST(t, &dhcpRequestConfig{ options: layers.DHCPOptions{ - newOptRequestIP(t, testIPStatic), - newOptServerID(t, testIPOtherSubnet), + newOptRequestIP(t, testIPv4Static), + newOptServerID(t, testIPv4OtherSubnet), }, clientHWAddr: testHWStatic, flags: dhcpsvc.FlagsBroadcast, @@ -287,7 +275,7 @@ func TestDHCPServer_ServeEther4_requestSelecting(t *testing.T) { request: newDHCPREQUEST(t, &dhcpRequestConfig{ options: layers.DHCPOptions{ newOptRequestIP(t, testIPv4Conf.RangeEnd.Next()), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), }, clientHWAddr: testHWUnknown, flags: dhcpsvc.FlagsBroadcast, @@ -295,14 +283,14 @@ func TestDHCPServer_ServeEther4_requestSelecting(t *testing.T) { name: "no_lease", wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeNak), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), }, }, { discover: newDHCPDISCOVER(t, testHWStatic), request: newDHCPREQUEST(t, &dhcpRequestConfig{ options: layers.DHCPOptions{ newOptRequestIP(t, testIPv4Conf.RangeEnd.Next()), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), }, clientHWAddr: testHWStatic, flags: dhcpsvc.FlagsBroadcast, @@ -310,17 +298,17 @@ func TestDHCPServer_ServeEther4_requestSelecting(t *testing.T) { name: "wrong_ip", wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeNak), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), }, }, { discover: newDHCPDISCOVER(t, testHWStatic), request: newDHCPREQUEST(t, &dhcpRequestConfig{ options: layers.DHCPOptions{ - newOptRequestIP(t, testIPStatic), - newOptServerID(t, testIfaceAddr), + newOptRequestIP(t, testIPv4Static), + newOptServerID(t, testIfaceAddrV4), }, clientHWAddr: testHWStatic, - clientIP: testIPStatic, + clientIP: testIPv4Static, flags: dhcpsvc.FlagsBroadcast, }), name: "nonzero_ciaddr", @@ -333,7 +321,7 @@ func TestDHCPServer_ServeEther4_requestSelecting(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddr) + ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddrV4) startTestDHCPServer(t, &dhcpsvc.Config{ Logger: slogutil.NewDiscardLogger(), Interfaces: testIPv4InterfacesConf, @@ -351,7 +339,7 @@ func TestDHCPServer_ServeEther4_requestSelecting(t *testing.T) { testutil.RequireSend(t, inCh, tc.request, testTimeout) - assertValidResponse(t, dhcpv4FromPacket(t, tc.request), outCh, tc.wantOpts) + assertValidResponse4(t, dhcpv4FromPacket(t, tc.request), outCh, tc.wantOpts) }) } } @@ -366,31 +354,31 @@ func TestDHCPServer_ServeEther4_requestInitReboot(t *testing.T) { }{{ name: "success", req: newDHCPREQUEST(t, &dhcpRequestConfig{ - options: layers.DHCPOptions{newOptRequestIP(t, testIPStatic)}, + options: layers.DHCPOptions{newOptRequestIP(t, testIPv4Static)}, clientHWAddr: testHWStatic, flags: dhcpsvc.FlagsBroadcast, }), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeAck), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testLeaseTTL), - newOptHostname(t, testLeaseHostnameStatic), + newOptHostname(t, testLease4HostnameStatic), }, }, { name: "wrong_subnet", req: newDHCPREQUEST(t, &dhcpRequestConfig{ - options: layers.DHCPOptions{newOptRequestIP(t, testIPOtherSubnet)}, + options: layers.DHCPOptions{newOptRequestIP(t, testIPv4OtherSubnet)}, clientHWAddr: testHWStatic, flags: dhcpsvc.FlagsBroadcast, }), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeNak), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), }, }, { name: "no_lease", req: newDHCPREQUEST(t, &dhcpRequestConfig{ - options: layers.DHCPOptions{newOptRequestIP(t, testIPStatic)}, + options: layers.DHCPOptions{newOptRequestIP(t, testIPv4Static)}, clientHWAddr: testHWUnknown, flags: dhcpsvc.FlagsBroadcast, }), @@ -398,30 +386,30 @@ func TestDHCPServer_ServeEther4_requestInitReboot(t *testing.T) { }, { name: "wrong_ip", req: newDHCPREQUEST(t, &dhcpRequestConfig{ - options: layers.DHCPOptions{newOptRequestIP(t, testIPDynamic)}, + options: layers.DHCPOptions{newOptRequestIP(t, testIPv4Dynamic)}, clientHWAddr: testHWStatic, flags: dhcpsvc.FlagsBroadcast, }), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeNak), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), }, }, { name: "wrong_ip_no_broadcast", req: newDHCPREQUEST(t, &dhcpRequestConfig{ - options: layers.DHCPOptions{newOptRequestIP(t, testIPDynamic)}, + options: layers.DHCPOptions{newOptRequestIP(t, testIPv4Dynamic)}, clientHWAddr: testHWStatic, }), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeNak), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), }, }, { name: "nonzero_ciaddr", req: newDHCPREQUEST(t, &dhcpRequestConfig{ - options: layers.DHCPOptions{newOptRequestIP(t, testIPStatic)}, + options: layers.DHCPOptions{newOptRequestIP(t, testIPv4Static)}, clientHWAddr: testHWStatic, - clientIP: testIPStatic, + clientIP: testIPv4Static, flags: dhcpsvc.FlagsBroadcast, }), wantOpts: nil, @@ -433,7 +421,7 @@ func TestDHCPServer_ServeEther4_requestInitReboot(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddr) + ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddrV4) startTestDHCPServer(t, &dhcpsvc.Config{ Interfaces: testIPv4InterfacesConf, NetworkDeviceManager: ndMgr, @@ -443,7 +431,7 @@ func TestDHCPServer_ServeEther4_requestInitReboot(t *testing.T) { testutil.RequireSend(t, inCh, tc.req, testTimeout) - assertValidResponse(t, dhcpv4FromPacket(t, tc.req), outCh, tc.wantOpts) + assertValidResponse4(t, dhcpv4FromPacket(t, tc.req), outCh, tc.wantOpts) }) } } @@ -459,52 +447,52 @@ func TestDHCPServer_ServeEther4_requestRenewSuccess(t *testing.T) { name: "success", req: newDHCPREQUEST(t, &dhcpRequestConfig{ clientHWAddr: testHWDynamic, - clientIP: testIPDynamic, + clientIP: testIPv4Dynamic, flags: dhcpsvc.FlagsBroadcast, }), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeAck), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testTTLDynamicLease), - newOptHostname(t, testLeaseHostnameDynamic), + newOptHostname(t, testLease4HostnameDynamic), }, }, { name: "static", req: newDHCPREQUEST(t, &dhcpRequestConfig{ clientHWAddr: testHWStatic, - clientIP: testIPStatic, + clientIP: testIPv4Static, flags: dhcpsvc.FlagsBroadcast, }), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeAck), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testLeaseTTL), - newOptHostname(t, testLeaseHostnameStatic), + newOptHostname(t, testLease4HostnameStatic), }, }, { name: "relay_agent", req: newDHCPREQUEST(t, &dhcpRequestConfig{ clientHWAddr: testHWDynamic, - clientIP: testIPDynamic, - relayAgentIP: testIPRelayAgent, + clientIP: testIPv4Dynamic, + relayAgentIP: testIPv4RelayAgent, }), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeAck), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testTTLDynamicLease), - newOptHostname(t, testLeaseHostnameDynamic), + newOptHostname(t, testLease4HostnameDynamic), }, }, { name: "ciaddr_unicast", req: newDHCPREQUEST(t, &dhcpRequestConfig{ clientHWAddr: testHWDynamic, - clientIP: testIPDynamic, + clientIP: testIPv4Dynamic, }), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeAck), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), newOptLeaseTime(t, testTTLDynamicLease), - newOptHostname(t, testLeaseHostnameDynamic), + newOptHostname(t, testLease4HostnameDynamic), }, }} @@ -514,7 +502,7 @@ func TestDHCPServer_ServeEther4_requestRenewSuccess(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddr) + ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddrV4) startTestDHCPServer(t, &dhcpsvc.Config{ Interfaces: testIPv4InterfacesConf, NetworkDeviceManager: ndMgr, @@ -524,7 +512,7 @@ func TestDHCPServer_ServeEther4_requestRenewSuccess(t *testing.T) { testutil.RequireSend(t, inCh, tc.req, testTimeout) - assertValidResponse(t, dhcpv4FromPacket(t, tc.req), outCh, tc.wantOpts) + assertValidResponse4(t, dhcpv4FromPacket(t, tc.req), outCh, tc.wantOpts) }) } } @@ -540,7 +528,7 @@ func TestDHCPServer_ServeEther4_requestRenewFail(t *testing.T) { name: "wrong_subnet", req: newDHCPREQUEST(t, &dhcpRequestConfig{ clientHWAddr: testHWStatic, - clientIP: testIPOtherSubnet, + clientIP: testIPv4OtherSubnet, flags: dhcpsvc.FlagsBroadcast, }), wantOpts: nil, @@ -548,7 +536,7 @@ func TestDHCPServer_ServeEther4_requestRenewFail(t *testing.T) { name: "no_lease", req: newDHCPREQUEST(t, &dhcpRequestConfig{ clientHWAddr: testHWUnknown, - clientIP: testIPStatic, + clientIP: testIPv4Static, flags: dhcpsvc.FlagsBroadcast, }), wantOpts: nil, @@ -556,12 +544,12 @@ func TestDHCPServer_ServeEther4_requestRenewFail(t *testing.T) { name: "wrong_ip", req: newDHCPREQUEST(t, &dhcpRequestConfig{ clientHWAddr: testHWStatic, - clientIP: testIPDynamic, + clientIP: testIPv4Dynamic, flags: dhcpsvc.FlagsBroadcast, }), wantOpts: layers.DHCPOptions{ newOptMessageType(t, layers.DHCPMsgTypeNak), - newOptServerID(t, testIfaceAddr), + newOptServerID(t, testIfaceAddrV4), }, }} @@ -571,7 +559,7 @@ func TestDHCPServer_ServeEther4_requestRenewFail(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddr) + ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddrV4) startTestDHCPServer(t, &dhcpsvc.Config{ Interfaces: testIPv4InterfacesConf, NetworkDeviceManager: ndMgr, @@ -581,7 +569,7 @@ func TestDHCPServer_ServeEther4_requestRenewFail(t *testing.T) { testutil.RequireSend(t, inCh, tc.req, testTimeout) - assertValidResponse(t, dhcpv4FromPacket(t, tc.req), outCh, tc.wantOpts) + assertValidResponse4(t, dhcpv4FromPacket(t, tc.req), outCh, tc.wantOpts) }) } } @@ -594,19 +582,19 @@ func TestDHCPServer_ServeEther4_decline(t *testing.T) { name string wantChange bool }{{ - req: newDHCPDECLINE(t, testHWDynamic, testIPDynamic), + req: newDHCPDECLINE(t, testHWDynamic, testIPv4Dynamic), name: "success", wantChange: true, }, { - req: newDHCPDECLINE(t, testHWUnknown, testIPDynamic), + req: newDHCPDECLINE(t, testHWUnknown, testIPv4Dynamic), name: "not_found", wantChange: false, }, { - req: newDHCPDECLINE(t, testHWAnother, testIPUnknown), + req: newDHCPDECLINE(t, testHWAnother, testIPv4Unknown), name: "mismatch_ip", wantChange: false, }, { - req: newDHCPDECLINE(t, testHWDynamic, testIPOtherSubnet), + req: newDHCPDECLINE(t, testHWDynamic, testIPv4OtherSubnet), name: "bad_subnet", wantChange: false, }, { @@ -621,7 +609,7 @@ func TestDHCPServer_ServeEther4_decline(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ndMgr, inCh, _ := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddr) + ndMgr, inCh, _ := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddrV4) srv := newTestDHCPServer(t, &dhcpsvc.Config{ Interfaces: testIPv4InterfacesConf, NetworkDeviceManager: ndMgr, @@ -675,7 +663,7 @@ type dhcpRequestConfig struct { func newDHCPREQUEST(tb testing.TB, conf *dhcpRequestConfig) (pkt gopacket.Packet) { tb.Helper() - eth := newEthernet4Layer(tb, conf.clientHWAddr, nil) + eth := newEthernetLayer(tb, conf.clientHWAddr, nil, layers.EthernetTypeIPv4) ip, udp := newIPv4UDPLayer( tb, @@ -711,7 +699,7 @@ func newDHCPREQUEST(tb testing.TB, conf *dhcpRequestConfig) (pkt gopacket.Packet func newDHCPDISCOVER(tb testing.TB, clientHWAddr net.HardwareAddr) (pkt gopacket.Packet) { tb.Helper() - eth := newEthernet4Layer(tb, clientHWAddr, nil) + eth := newEthernetLayer(tb, clientHWAddr, nil, layers.EthernetTypeIPv4) ip, udp := newIPv4UDPLayer(tb, netip.AddrPort{}, netip.AddrPort{}) @@ -737,12 +725,12 @@ func newDHCPRELEASE( ) (pkt gopacket.Packet) { tb.Helper() - eth := newEthernet4Layer(tb, clientHWAddr, testIfaceHWAddr) + eth := newEthernetLayer(tb, clientHWAddr, testIfaceHWAddr, layers.EthernetTypeIPv4) ip, udp := newIPv4UDPLayer( tb, netip.AddrPortFrom(clientIP, uint16(dhcpsvc.ClientPortV4)), - netip.AddrPortFrom(testIfaceAddr, uint16(dhcpsvc.ServerPortV4)), + netip.AddrPortFrom(testIfaceAddrV4, uint16(dhcpsvc.ServerPortV4)), ) dhcp := &layers.DHCPv4{ @@ -768,7 +756,7 @@ func newDHCPDECLINE( ) (pkt gopacket.Packet) { tb.Helper() - eth := newEthernet4Layer(tb, clientHWAddr, nil) + eth := newEthernetLayer(tb, clientHWAddr, nil, layers.EthernetTypeIPv4) ip, udp := newIPv4UDPLayer(tb, netip.AddrPort{}, netip.AddrPort{}) @@ -828,47 +816,6 @@ func newIPv4UDPLayer(tb testing.TB, src, dst netip.AddrPort) (ip *layers.IPv4, u return ip, udp } -// newEthernet4Layer creates a new Ethernet layer for IPv4 packets. Nil src is -// replaced with an unspecified MAC address, nil dst is replaced with a -// broadcast MAC address. -func newEthernet4Layer(tb testing.TB, src, dst net.HardwareAddr) (eth *layers.Ethernet) { - tb.Helper() - - if src == nil { - src = net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - } - if dst == nil { - dst = net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff} - } - - return &layers.Ethernet{ - SrcMAC: src, - DstMAC: dst, - EthernetType: layers.EthernetTypeIPv4, - } -} - -// newTestPacket creates a valid packet from ls using first as first layer -// decoder. -func newTestPacket( - tb testing.TB, - first gopacket.Decoder, - ls ...gopacket.SerializableLayer, -) (pkg gopacket.Packet) { - tb.Helper() - - buf := gopacket.NewSerializeBuffer() - - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - err := gopacket.SerializeLayers(buf, opts, ls...) - require.NoError(tb, err) - - return gopacket.NewPacket(buf.Bytes(), first, gopacket.Default) -} - // requireEthernet requires data to contain an Ethernet layer and all layers // from ls. First of ls must be of type [layers.LayerTypeEthernet]. func requireEthernet( @@ -886,10 +833,10 @@ func requireEthernet( return types } -// assertValidResponse asserts that recvCh eventually gets the response with +// assertValidResponse4 asserts that recvCh eventually gets the response with // wantOpts for request. If wantOpts is nil, asserts that no response is sent. // request and recvCh must not be nil. -func assertValidResponse( +func assertValidResponse4( tb testing.TB, request *layers.DHCPv4, recvCh <-chan []byte, @@ -910,7 +857,7 @@ func assertValidResponse( udp := &layers.UDP{} resp := &layers.DHCPv4{} types := requireEthernet(tb, respData, &layers.Ethernet{}, ip, udp, resp) - require.Equal(tb, fullLayersStack, types) + require.Equal(tb, fullLayersStack4, types) assertValidDHCPv4(tb, request, resp, ip, udp) @@ -927,6 +874,8 @@ func assertValidResponse( // assertValidDHCPv4 asserts that the response is valid for the given request // according to RFC 2131. func assertValidDHCPv4(tb testing.TB, req, resp *layers.DHCPv4, ip *layers.IPv4, udp *layers.UDP) { + tb.Helper() + switch { case !req.RelayAgentIP.IsUnspecified(): assert.Equal(tb, req.RelayAgentIP.To4(), ip.DstIP) @@ -945,21 +894,6 @@ func assertValidDHCPv4(tb testing.TB, req, resp *layers.DHCPv4, ip *layers.IPv4, } } -// assertNoResponse asserts that no response is received on the channel within -// the timeout. -// -// TODO(e.burkov): Improve the helper to not rely on timeout. -func assertNoResponse(tb testing.TB, outCh <-chan []byte, timeout time.Duration) { - tb.Helper() - - var resp []byte - require.Panics(tb, func() { - resp, _ = testutil.RequireReceive(testutil.NewPanicT(tb), outCh, timeout) - }) - - require.Nil(tb, resp) -} - // dhcpv4FromPacket extracts the DHCPv4 layer from pkt, which is required to // contain one. func dhcpv4FromPacket(tb testing.TB, pkt gopacket.Packet) (msg *layers.DHCPv4) { diff --git a/internal/dhcpsvc/handler6.go b/internal/dhcpsvc/handler6.go index 12b376b3ed2..4c12b2f08f4 100644 --- a/internal/dhcpsvc/handler6.go +++ b/internal/dhcpsvc/handler6.go @@ -92,7 +92,6 @@ func (iface *dhcpInterfaceV6) handleSolicit( } if lease == nil { - l.DebugContext(ctx, "no ia_na in solicit or no addresses available") resp.Options = iface.newSolicitRespOpts(fd, req, cliID, iaid, nil, false) return respond6(fd, resp) @@ -110,6 +109,12 @@ func (iface *dhcpInterfaceV6) handleSolicit( if err != nil { l.WarnContext(ctx, "committing rapid leases", slogutil.KeyError, err) isRapidCommit = false + } else { + // The server will also send a Reply in response to a Solicit with a + // Rapid Commit option. + // + // See RFC 9915 Section 18.3. + resp.MsgType = layers.DHCPv6MsgTypeReply } resp.Options = iface.newSolicitRespOpts(fd, req, cliID, iaid, lease, isRapidCommit) diff --git a/internal/dhcpsvc/handler6_test.go b/internal/dhcpsvc/handler6_test.go new file mode 100644 index 00000000000..a491f9c9b9a --- /dev/null +++ b/internal/dhcpsvc/handler6_test.go @@ -0,0 +1,337 @@ +package dhcpsvc_test + +import ( + "net" + "net/netip" + "slices" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/testutil" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TODO(e.burkov): Add tests for wrong packets. + +// testIPv6InterfacesConf is the test interfaces configuration for the DHCPv6 +// part of the [DHCPServer]. +var testIPv6InterfacesConf = map[string]*dhcpsvc.InterfaceConfig{ + testIfaceName: { + IPv4: disabledIPv4Conf, + IPv6: testIPv6Conf, + }, +} + +// testIAID is a common IAID for IANA options in tests. +const testIAID = 1 + +// testTransactionID is a sample transaction ID for testing. +// +// TODO(e.burkov): Generate unique IDs when they will be actually used. +var testTransactionID = []byte{0x01, 0x02, 0x03} + +// IP addresses for test cases. +// +// NOTE: Keep in sync with testdata. +var ( + // testIPv6Unknown is the test IP address for an unknown client. + testIPv6Unknown = netip.MustParseAddr("2001:db8::64") + + // testIPv6Dynamic is the test IP address for a known dynamic lease. + testIPv6Dynamic = netip.MustParseAddr("2001:db8::66") + + // testIPv6Expired is the test IP address for a known expired lease. + testIPv6Expired = netip.MustParseAddr("2001:db8::67") + + // testIPv6Static is the test IP address for a known static lease. + testIPv6Static = netip.MustParseAddr("2001:db8::65") +) + +func TestDHCPServer_ServeEther6_solicit(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + in gopacket.Packet + wantOpts layers.DHCPv6Options + }{{ + name: "new", + in: newDHCPv6SOLICIT(t, testHWUnknown, testIPv6Unknown, false), + wantOpts: layers.DHCPv6Options{ + newOptServerDUID(t, testIfaceHWAddr), + newOptClientDUID(t, testHWUnknown), + newOptIANA(t, testIAID, testIPv6Conf.RangeStart), + newOptPreference(t, 0), + newOptSolMaxRT(t, dhcpsvc.DefaultSolMaxRT), + }, + }, { + name: "existing_static", + in: newDHCPv6SOLICIT(t, testHWStatic, testIPv6Static, false), + wantOpts: layers.DHCPv6Options{ + newOptServerDUID(t, testIfaceHWAddr), + newOptClientDUID(t, testHWStatic), + newOptIANA(t, testIAID, testIPv6Static), + newOptPreference(t, 0), + newOptSolMaxRT(t, dhcpsvc.DefaultSolMaxRT), + }, + }, { + name: "existing_dynamic", + in: newDHCPv6SOLICIT(t, testHWDynamic, testIPv6Dynamic, false), + wantOpts: layers.DHCPv6Options{ + newOptServerDUID(t, testIfaceHWAddr), + newOptClientDUID(t, testHWDynamic), + newOptIANA(t, testIAID, testIPv6Dynamic), + newOptPreference(t, 0), + newOptSolMaxRT(t, dhcpsvc.DefaultSolMaxRT), + }, + }, { + name: "existing_expired", + in: newDHCPv6SOLICIT(t, testHWExpired, testIPv6Expired, false), + wantOpts: layers.DHCPv6Options{ + newOptServerDUID(t, testIfaceHWAddr), + newOptClientDUID(t, testHWExpired), + newOptIANA(t, testIAID, testIPv6Expired), + newOptPreference(t, 0), + newOptSolMaxRT(t, dhcpsvc.DefaultSolMaxRT), + }, + }, { + name: "new_rapid_commit", + in: newDHCPv6SOLICIT(t, testHWUnknown, testIPv6Unknown, true), + wantOpts: layers.DHCPv6Options{ + newOptServerDUID(t, testIfaceHWAddr), + newOptClientDUID(t, testHWUnknown), + newOptIANA(t, testIAID, testIPv6Conf.RangeStart), + newOptPreference(t, 0), + newOptSolMaxRT(t, dhcpsvc.DefaultSolMaxRT), + layers.NewDHCPv6Option(layers.DHCPv6OptRapidCommit, []byte{}), + }, + }, { + name: "existing_rapid_commit", + in: newDHCPv6SOLICIT(t, testHWStatic, testIPv6Static, true), + wantOpts: layers.DHCPv6Options{ + newOptServerDUID(t, testIfaceHWAddr), + newOptClientDUID(t, testHWStatic), + newOptIANA(t, testIAID, testIPv6Static), + newOptPreference(t, 0), + newOptSolMaxRT(t, dhcpsvc.DefaultSolMaxRT), + layers.NewDHCPv6Option(layers.DHCPv6OptRapidCommit, []byte{}), + }, + }, { + name: "existing_dynamic_rapid_commit", + in: newDHCPv6SOLICIT(t, testHWDynamic, testIPv6Dynamic, true), + wantOpts: layers.DHCPv6Options{ + newOptServerDUID(t, testIfaceHWAddr), + newOptClientDUID(t, testHWDynamic), + newOptIANA(t, testIAID, testIPv6Dynamic), + newOptPreference(t, 0), + newOptSolMaxRT(t, dhcpsvc.DefaultSolMaxRT), + layers.NewDHCPv6Option(layers.DHCPv6OptRapidCommit, []byte{}), + }, + }, { + name: "existing_expired_rapid_commit", + in: newDHCPv6SOLICIT(t, testHWExpired, testIPv6Expired, true), + wantOpts: layers.DHCPv6Options{ + newOptServerDUID(t, testIfaceHWAddr), + newOptClientDUID(t, testHWExpired), + newOptIANA(t, testIAID, testIPv6Expired), + newOptPreference(t, 0), + newOptSolMaxRT(t, dhcpsvc.DefaultSolMaxRT), + layers.NewDHCPv6Option(layers.DHCPv6OptRapidCommit, []byte{}), + }, + }} + + for _, tc := range testCases { + req := testutil.RequireTypeAssert[*layers.DHCPv6](t, tc.in.Layer(layers.LayerTypeDHCPv6)) + dbFilePath := newTempDB(t) + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ndMgr, inCh, outCh := newTestNetworkDeviceManager(t, testIfaceName, testIfaceAddrV6) + startTestDHCPServer(t, &dhcpsvc.Config{ + Interfaces: testIPv6InterfacesConf, + NetworkDeviceManager: ndMgr, + Logger: testLogger, + DBFilePath: dbFilePath, + Enabled: true, + }) + + testutil.RequireSend(t, inCh, tc.in, testTimeout) + + assertValidResponse6(t, req, outCh, tc.wantOpts) + }) + } +} + +// newDHCPv6SOLICIT creates a new DHCPv6 SOLICIT packet for testing. +func newDHCPv6SOLICIT( + tb testing.TB, + hwAddr net.HardwareAddr, + reqIP netip.Addr, + rapidCommit bool, +) (pkt gopacket.Packet) { + tb.Helper() + + eth := newEthernetLayer(tb, hwAddr, nil, layers.EthernetTypeIPv6) + + ip, udp := newIPv6UDPLayer(tb, netip.AddrPort{}, netip.AddrPort{}) + + dhcp := &layers.DHCPv6{ + MsgType: layers.DHCPv6MsgTypeSolicit, + HopCount: 0, + // Don't specify link and peer addresses, as they are intended for relay + // messages. + LinkAddr: nil, + PeerAddr: nil, + TransactionID: testTransactionID, + Options: layers.DHCPv6Options{ + newOptClientDUID(tb, hwAddr), + }, + } + + if reqIP.IsValid() && reqIP.Is6() { + dhcp.Options = append(dhcp.Options, newOptIANA(tb, testIAID, reqIP)) + } + + if rapidCommit { + o := layers.NewDHCPv6Option(layers.DHCPv6OptRapidCommit, nil) + dhcp.Options = append(dhcp.Options, o) + } + + return newTestPacket(tb, layers.LinkTypeEthernet, eth, ip, udp, dhcp) +} + +// newIPv6UDPLayer creates IPv6 and UDP layers for testing. Invalid src is +// replaced with an unspecified address and client DHCPv6 port, invalid dst is +// replaced with the broadcast address and server DHCPv6 port. +func newIPv6UDPLayer(tb testing.TB, src, dst netip.AddrPort) (ip *layers.IPv6, udp *layers.UDP) { + tb.Helper() + + if !src.IsValid() { + src = netip.AddrPortFrom(netip.IPv6Unspecified(), uint16(dhcpsvc.ClientPortV6)) + } + + if !dst.IsValid() { + bcastAddr, ok := netip.AddrFromSlice(net.IPv6linklocalallnodes) + require.True(tb, ok) + + dst = netip.AddrPortFrom(bcastAddr, uint16(dhcpsvc.ServerPortV6)) + } + + ip = &layers.IPv6{ + Version: 6, + HopLimit: dhcpsvc.IPv6DefaultHopLimit, + SrcIP: src.Addr().AsSlice(), + DstIP: dst.Addr().AsSlice(), + NextHeader: layers.IPProtocolUDP, + } + udp = &layers.UDP{ + SrcPort: layers.UDPPort(src.Port()), + DstPort: layers.UDPPort(dst.Port()), + } + require.NoError(tb, udp.SetNetworkLayerForChecksum(ip)) + + return ip, udp +} + +// newEthernetLayer creates a new Ethernet layer for IP packets of the specified +// type. Nil src is replaced with an unspecified MAC address, nil dst is +// replaced with a broadcast MAC address, typ must be [layers.EthernetTypeIPv4] +// or [layers.EthernetTypeIPv6]. +func newEthernetLayer( + tb testing.TB, + src net.HardwareAddr, + dst net.HardwareAddr, + typ layers.EthernetType, +) (eth *layers.Ethernet) { + tb.Helper() + + if src == nil { + src = net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + } + if dst == nil { + dst = net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + } + + return &layers.Ethernet{ + SrcMAC: src, + DstMAC: dst, + EthernetType: typ, + } +} + +// assertValidResponse6 asserts that the response received on recvCh is a valid +// DHCPv6 response for the given request and contains the expected options. If +// wantOpts is nil, it asserts that no response is received. +func assertValidResponse6( + tb testing.TB, + req *layers.DHCPv6, + recvCh <-chan []byte, + wantOpts layers.DHCPv6Options, +) { + tb.Helper() + + if wantOpts == nil { + assertNoResponse(tb, recvCh, testTimeout/10) + + return + } + + respData, ok := testutil.RequireReceive(tb, recvCh, testTimeout) + require.True(tb, ok) + + ip := &layers.IPv6{} + udp := &layers.UDP{} + resp := &layers.DHCPv6{} + types := requireEthernet(tb, respData, &layers.Ethernet{}, ip, udp, resp) + require.Equal(tb, fullLayersStack6, types) + + assertValidDHCPv6(tb, req, resp) + + // TODO(e.burkov): Consider comparing the whole message instead of separate + // fields. + assert.Equal(tb, req.LinkAddr, resp.LinkAddr, "link address") + assert.Equal(tb, req.PeerAddr, resp.PeerAddr, "peer address") + assert.Equal(tb, req.TransactionID, resp.TransactionID, "transaction id") + assert.Equal(tb, wantOpts, resp.Options, "options") +} + +// assertValidDHCPv6 asserts that the response is valid for the given request +// according to RFC 9915. +// +// TODO(e.burkov): Add more checks involving other network layers. +func assertValidDHCPv6( + tb testing.TB, + req *layers.DHCPv6, + resp *layers.DHCPv6, +) { + tb.Helper() + + switch req.MsgType { + case + layers.DHCPv6MsgTypeRequest, + layers.DHCPv6MsgTypeConfirm, + layers.DHCPv6MsgTypeRenew, + layers.DHCPv6MsgTypeRebind, + layers.DHCPv6MsgTypeRelease, + layers.DHCPv6MsgTypeDecline, + layers.DHCPv6MsgTypeInformationRequest: + assert.Equal(tb, layers.DHCPv6MsgTypeReply, resp.MsgType) + case layers.DHCPv6MsgTypeSolicit: + isRapidCommit := slices.ContainsFunc(resp.Options, func(o layers.DHCPv6Option) (ok bool) { + return o.Code == layers.DHCPv6OptRapidCommit + }) + + if isRapidCommit { + assert.Equal(tb, layers.DHCPv6MsgTypeReply, resp.MsgType) + } else { + assert.Equal(tb, layers.DHCPv6MsgTypeAdverstise, resp.MsgType) + } + default: + tb.Errorf("request message type: %v: %s", errors.ErrUnexpectedValue, req.MsgType) + } +} diff --git a/internal/dhcpsvc/interface.go b/internal/dhcpsvc/interface.go index ad594900410..25b08684fc0 100644 --- a/internal/dhcpsvc/interface.go +++ b/internal/dhcpsvc/interface.go @@ -226,6 +226,8 @@ func (iface *netInterface) allocateLease( // reserveLease reserves a lease for a client by its MAC-address. lease is nil // if a new lease can't be allocated. mac must be a valid according to // [netutil.ValidateMAC]. iface.indexMu mutex must be locked. +// +// TODO(e.burkov): Pass the time moment instead of clock. func (iface *netInterface) reserveLease( ctx context.Context, mac net.HardwareAddr, diff --git a/internal/dhcpsvc/options6.go b/internal/dhcpsvc/options6.go index 1031292a9f9..c7cafdf83a5 100644 --- a/internal/dhcpsvc/options6.go +++ b/internal/dhcpsvc/options6.go @@ -18,34 +18,34 @@ import ( // See RFC 9915 Section 21.4. const iaNAMinLen = 12 -// iaNAOption represents a parsed IA_NA (Identity Association for Non-temporary +// IANAOption represents a parsed IA_NA (Identity Association for Non-temporary // Addresses) option. // // See RFC 9915 Section 21.4. -type iaNAOption struct { - // nested are the IA Address options nested within this IA_NA. - nested []iaAddrOption +type IANAOption struct { + // Nested are the IA Address options Nested within this IA_NA. + Nested []IAAddrOption - // iaid is the Identity Association IDentifier, a 4-octet value uniquely + // ID is the Identity Association Identifier, a 4-octet value uniquely // identifying this IA within the client. // // TODO(e.burkov): Add new type. - iaid uint32 + ID uint32 - // t1 is the time after which the client must contact the same server to + // T1 is the time after which the client must contact the same server to // extend the lifetimes of the addresses in this IA. - t1 time.Duration + T1 time.Duration - // t2 is the time after which the client may contact any available server to + // T2 is the time after which the client may contact any available server to // extend the lifetimes. - t2 time.Duration + T2 time.Duration } // type check -var _ encoding.BinaryUnmarshaler = (*iaNAOption)(nil) +var _ encoding.BinaryUnmarshaler = (*IANAOption)(nil) // UnmarshalBinary implements the [encoding.BinaryUnmarshaler] interface for -// *iaNAOption. data should have the following format: +// *IANAOption. data should have the following format: // // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 @@ -60,16 +60,16 @@ var _ encoding.BinaryUnmarshaler = (*iaNAOption)(nil) // . IA_NA-options . // . . // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -func (opt *iaNAOption) UnmarshalBinary(data []byte) (err error) { +func (opt *IANAOption) UnmarshalBinary(data []byte) (err error) { err = validate.NoLessThan("data length", len(data), iaNAMinLen) if err != nil { // Don't wrap the error, since it's informative enough as is. return err } - opt.iaid = binary.BigEndian.Uint32(data[0:4]) - opt.t1 = time.Duration(binary.BigEndian.Uint32(data[4:8])) * time.Second - opt.t2 = time.Duration(binary.BigEndian.Uint32(data[8:12])) * time.Second + opt.ID = binary.BigEndian.Uint32(data[0:4]) + opt.T1 = time.Duration(binary.BigEndian.Uint32(data[4:8])) * time.Second + opt.T2 = time.Duration(binary.BigEndian.Uint32(data[8:12])) * time.Second // Parse the nested options that follow the fixed fields. nested := data[iaNAMinLen:] @@ -83,13 +83,13 @@ func (opt *iaNAOption) UnmarshalBinary(data []byte) (err error) { } if code == layers.DHCPv6OptIAAddr { - addr := iaAddrOption{} + addr := IAAddrOption{} err = addr.UnmarshalBinary(nested[4 : 4+l]) if err != nil { return fmt.Errorf("nested ia_addr at index %d: %w", i, err) } - opt.nested = append(opt.nested, addr) + opt.Nested = append(opt.Nested, addr) } nested = nested[4+l:] @@ -98,21 +98,21 @@ func (opt *iaNAOption) UnmarshalBinary(data []byte) (err error) { return nil } -// Encode serializes ia into a DHCPv6 IA_NA option. Each contained -// [iaAddrOption] is encoded as a nested IA Address option. +// Encode serializes opt into a DHCPv6 IA_NA option. Each contained +// [IAAddrOption] is encoded as a nested IA Address option. // // TODO(e.burkov): Use. -func (opt iaNAOption) Encode() (iaOpt layers.DHCPv6Option) { +func (opt IANAOption) Encode() (iaOpt layers.DHCPv6Option) { // Each nested IA Address option: code (2) + length (2) + data (24). const nestedAddrSize = 2 + 2 + iaAddrDataLen - data := make([]byte, 0, iaNAMinLen+len(opt.nested)*nestedAddrSize) + data := make([]byte, 0, iaNAMinLen+len(opt.Nested)*nestedAddrSize) - data = binary.BigEndian.AppendUint32(data, opt.iaid) - data = binary.BigEndian.AppendUint32(data, uint32(opt.t1.Seconds())) - data = binary.BigEndian.AppendUint32(data, uint32(opt.t2.Seconds())) + data = binary.BigEndian.AppendUint32(data, opt.ID) + data = binary.BigEndian.AppendUint32(data, uint32(opt.T1.Seconds())) + data = binary.BigEndian.AppendUint32(data, uint32(opt.T2.Seconds())) - for _, addr := range opt.nested { + for _, addr := range opt.Nested { data = addr.appendTo(data) } @@ -125,27 +125,27 @@ func (opt iaNAOption) Encode() (iaOpt layers.DHCPv6Option) { // (4 bytes each). const iaAddrDataLen = 24 -// iaAddrOption represents a parsed IA Address option. +// IAAddrOption represents a parsed IA Address option. // // See RFC 9915 Section 21.6. -type iaAddrOption struct { - // addr is the IPv6 address. - addr netip.Addr +type IAAddrOption struct { + // Addr is the IPv6 address. + Addr netip.Addr - // preferredLifetime is the preferred lifetime of the address. When it is + // PreferredLifetime is the preferred lifetime of the address. When it is // zero, the address is deprecated. - preferredLifetime time.Duration + PreferredLifetime time.Duration - // validLifetime is the valid lifetime of the address. When it is zero, the + // ValidLifetime is the valid lifetime of the address. When it is zero, the // address is no longer valid. - validLifetime time.Duration + ValidLifetime time.Duration } // type check -var _ encoding.BinaryUnmarshaler = (*iaAddrOption)(nil) +var _ encoding.BinaryUnmarshaler = (*IAAddrOption)(nil) // UnmarshalBinary implements the [encoding.BinaryUnmarshaler] interface for -// *iaAddrOption. Nested options within IA Address, if any, are +// *IAAddrOption. Nested options within IA Address, if any, are // ignored. data should have the following format: // // 0 1 2 3 @@ -164,7 +164,7 @@ var _ encoding.BinaryUnmarshaler = (*iaAddrOption)(nil) // . IAaddr-options . // . . // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -func (ia *iaAddrOption) UnmarshalBinary(data []byte) (err error) { +func (ia *IAAddrOption) UnmarshalBinary(data []byte) (err error) { err = validate.NoLessThan("data length", len(data), iaAddrDataLen) if err != nil { // Don't wrap the error, since it's informative enough as is. @@ -172,30 +172,30 @@ func (ia *iaAddrOption) UnmarshalBinary(data []byte) (err error) { } var ok bool - ia.addr, ok = netip.AddrFromSlice(data[0:16]) + ia.Addr, ok = netip.AddrFromSlice(data[0:16]) if !ok { return fmt.Errorf("ia_addr: invalid ipv6 address bytes") } - ia.preferredLifetime = time.Duration(binary.BigEndian.Uint32(data[16:20])) * time.Second - ia.validLifetime = time.Duration(binary.BigEndian.Uint32(data[20:24])) * time.Second + ia.PreferredLifetime = time.Duration(binary.BigEndian.Uint32(data[16:20])) * time.Second + ia.ValidLifetime = time.Duration(binary.BigEndian.Uint32(data[20:24])) * time.Second return nil } // appendTo returns the data portion of the IA Address option encoding, suitable // for use as a nested option inside an IA_NA. -func (ia iaAddrOption) appendTo(orig []byte) (data []byte) { +func (ia IAAddrOption) appendTo(orig []byte) (data []byte) { data = orig data = binary.BigEndian.AppendUint16(data, uint16(layers.DHCPv6OptIAAddr)) data = binary.BigEndian.AppendUint16(data, uint16(iaAddrDataLen)) // [netip.Addr.AppendBinary] never returns errors. - data, _ = ia.addr.AppendBinary(data) + data, _ = ia.Addr.AppendBinary(data) - data = binary.BigEndian.AppendUint32(data, uint32(ia.preferredLifetime.Seconds())) - data = binary.BigEndian.AppendUint32(data, uint32(ia.validLifetime.Seconds())) + data = binary.BigEndian.AppendUint32(data, uint32(ia.PreferredLifetime.Seconds())) + data = binary.BigEndian.AppendUint32(data, uint32(ia.ValidLifetime.Seconds())) return data } @@ -236,11 +236,11 @@ func serverDUID6(opts layers.DHCPv6Options) (duid []byte, ok bool) { return findOption6(opts, layers.DHCPv6OptServerID) } -// solMaxRT is the recommended SOL_MAX_RT value sent to clients. It caps the -// client's solicit retransmission interval. +// DefaultSolMaxRT is the recommended SOL_MAX_RT value sent to clients. It caps +// the client's solicit retransmission interval. // // See RFC 9915 Section 21.24. -const solMaxRT = 1 * time.Hour +const DefaultSolMaxRT = 1 * time.Hour // newPreferenceOption returns a DHCPv6 Preference option with the given value. // diff --git a/internal/dhcpsvc/options6_test.go b/internal/dhcpsvc/options6_test.go new file mode 100644 index 00000000000..011102df6d3 --- /dev/null +++ b/internal/dhcpsvc/options6_test.go @@ -0,0 +1,88 @@ +package dhcpsvc_test + +import ( + "encoding/binary" + "net" + "net/netip" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/google/gopacket/layers" +) + +// newOptIANA creates a DHCPv6 Identity Association for Non-temporary Address +// (3) option containing an IA Address with the specified IAID and requested IP +// address. reqIP must be a valid IPv6 address. The option will have the T1 +// and T2 values set to the recommended values based on the [testLeaseTTL] +// constant, see the RFC reference in the +// [dhcpsvc.DHCPServer.newDHCPInterfaceV6]. +func newOptIANA(tb testing.TB, iaid uint32, reqIP netip.Addr) (opt layers.DHCPv6Option) { + tb.Helper() + + iana := &dhcpsvc.IANAOption{ + ID: iaid, + Nested: []dhcpsvc.IAAddrOption{{ + PreferredLifetime: testLeaseTTL, + ValidLifetime: testLeaseTTL, + Addr: reqIP, + }}, + T1: testLeaseTTL / 2, + T2: testLeaseTTL * 4 / 5, + } + + return iana.Encode() +} + +// newOptPreference creates a DHCPv6 Preference (7) option with the specified +// preference value. +func newOptPreference(tb testing.TB, pref uint8) (opt layers.DHCPv6Option) { + tb.Helper() + + return layers.NewDHCPv6Option(layers.DHCPv6OptPreference, []byte{pref}) +} + +// newOptSolMaxRT creates a DHCPv6 Solicit Message Maximum Retransmission Time +// (80) option with the specified maxRT value. +func newOptSolMaxRT(tb testing.TB, maxRT time.Duration) (opt layers.DHCPv6Option) { + tb.Helper() + + return layers.NewDHCPv6Option( + layers.DHCPv6OptSolMaxRt, + binary.BigEndian.AppendUint32(nil, uint32(maxRT.Seconds())), + ) +} + +// newOptClientDUID creates a DHCPv6 Client Identifier (1) option containing a +// DUID-LL made of cliHWAddr. +func newOptClientDUID(tb testing.TB, cliHWAddr net.HardwareAddr) (opt layers.DHCPv6Option) { + tb.Helper() + + return newOptDUIDLL(tb, layers.DHCPv6OptClientID, cliHWAddr) +} + +// newOptServerID creates a DHCPv6 Server Identifier (2) option containing a +// DUID-LL made of srvHWAddr. +func newOptServerDUID(tb testing.TB, srvHWAddr net.HardwareAddr) (opt layers.DHCPv6Option) { + tb.Helper() + + return newOptDUIDLL(tb, layers.DHCPv6OptServerID, srvHWAddr) +} + +// newOptDUIDLL creates a DHCPv6 option with the specified code containing a +// DUID-LL made of hwAddr and Ethernet hardware type. +func newOptDUIDLL( + tb testing.TB, + code layers.DHCPv6Opt, + hwAddr net.HardwareAddr, +) (opt layers.DHCPv6Option) { + tb.Helper() + + duid := &layers.DHCPv6DUID{ + Type: layers.DHCPv6DUIDTypeLL, + HardwareType: binary.BigEndian.AppendUint16(nil, uint16(layers.LinkTypeEthernet)), + LinkLayerAddress: hwAddr, + } + + return layers.NewDHCPv6Option(code, duid.Encode()) +} diff --git a/internal/dhcpsvc/server.go b/internal/dhcpsvc/server.go index d4709d3bac7..00dfb36462a 100644 --- a/internal/dhcpsvc/server.go +++ b/internal/dhcpsvc/server.go @@ -144,6 +144,9 @@ func (srv *DHCPServer) newInterfaces( func (srv *DHCPServer) Start(ctx context.Context) (err error) { srv.logger.DebugContext(ctx, "starting dhcp server") + // TODO(e.burkov): Create a single device for each network interface with + // dual-stack support when possible. + var errs []error for _, iface := range srv.interfaces4 { netDevName := iface.common.name @@ -153,7 +156,7 @@ func (srv *DHCPServer) Start(ctx context.Context) (err error) { Name: netDevName, }) if err != nil { - errs = append(errs, err) + errs = append(errs, fmt.Errorf("opening ipv4 device %q: %w", netDevName, err)) continue } @@ -166,7 +169,26 @@ func (srv *DHCPServer) Start(ctx context.Context) (err error) { go srv.serveEther4(context.WithoutCancel(ctx), iface, netDev) } - // TODO(e.burkov): Serve EthernetTypeIPv6. + for _, iface := range srv.interfaces6 { + netDevName := iface.common.name + + var netDev NetworkDevice + netDev, err = srv.deviceManager.Open(ctx, &NetworkDeviceConfig{ + Name: netDevName, + }) + if err != nil { + errs = append(errs, fmt.Errorf("opening ipv6 device %q: %w", netDevName, err)) + + continue + } + + srv.devices = append(srv.devices, container.KeyValue[string, NetworkDevice]{ + Key: netDevName, + Value: netDev, + }) + + go srv.serveEther6(context.WithoutCancel(ctx), iface, netDev) + } return errors.Join(errs...) } diff --git a/internal/dhcpsvc/testdata/TestDHCPServer_ServeEther6_solicit/leases.json b/internal/dhcpsvc/testdata/TestDHCPServer_ServeEther6_solicit/leases.json new file mode 100644 index 00000000000..0c15365924e --- /dev/null +++ b/internal/dhcpsvc/testdata/TestDHCPServer_ServeEther6_solicit/leases.json @@ -0,0 +1,26 @@ +{ + "leases": [ + { + "expires": "2025-01-01T10:01:01Z", + "ip": "2001:db8::66", + "hostname": "dynamic6", + "mac": "02:03:04:05:06:07", + "static": false + }, + { + "expires": "2025-01-01T01:01:00Z", + "ip": "2001:db8::67", + "hostname": "expired6", + "mac": "03:04:05:06:07:08", + "static": false + }, + { + "expires": "", + "ip": "2001:db8::65", + "hostname": "static6", + "mac": "01:02:03:04:05:06", + "static": true + } + ], + "version": 1 +} diff --git a/internal/dhcpsvc/v6.go b/internal/dhcpsvc/v6.go index f7d077e5a6e..2113cf22ce4 100644 --- a/internal/dhcpsvc/v6.go +++ b/internal/dhcpsvc/v6.go @@ -210,6 +210,13 @@ func (srv *DHCPServer) newDHCPInterfaceV6( // TODO(e.burkov): Use an ICMP implementation. addrChecker: noopAddressChecker{}, subnetPrefix: netip.PrefixFrom(conf.RangeStart, v6PrefLen), + // Recommended values for T1 and T2 are 0.5 and 0.8 times the shortest + // preferred lifetime of the addresses in the IA that the server is + // willing to extend, respectively. + // + // See RFC 9915 Section 21.4. + // + // TODO(e.burkov): Consider making configurable. t1: conf.LeaseDuration / 2, t2: conf.LeaseDuration * 4 / 5, raSLAACOnly: conf.RASLAACOnly, @@ -360,11 +367,11 @@ func clientIDMatchingServer( return cliID, nil } -// defaultHopLimit is the default hop limit for relaying DHCPv6 response +// IPv6DefaultHopLimit is the default hop limit for relaying DHCPv6 response // packets. // // See RFC 9915 Section 7.6. -const defaultHopLimit = 8 +const IPv6DefaultHopLimit = 8 // respond6 constructs and sends a DHCPv6 response to the client. func respond6(fd *frameData6, resp *layers.DHCPv6) (err error) { @@ -377,7 +384,7 @@ func respond6(fd *frameData6, resp *layers.DHCPv6) (err error) { ip := &layers.IPv6{ Version: 6, NextHeader: layers.IPProtocolUDP, - HopLimit: defaultHopLimit, + HopLimit: IPv6DefaultHopLimit, SrcIP: fd.localAddr.AsSlice(), // If the original message was received directly by the server, the // server unicasts the Advertise or Reply message directly to the client @@ -412,7 +419,7 @@ func respond6(fd *frameData6, resp *layers.DHCPv6) (err error) { // option is malformed. lease is nil if there is no address available for // leasing. mac must be a valid MAC address according to [netutil.ValidateMAC], // req must be a valid DHCPv6 message of SOLICIT type, iface.common.indexMu -// mutex must be locked. +// must be locked. // // TODO(e.burkov): Support allocating several leases at a time when the // database will migrate, see the BUG at [Lease]'s documentation. @@ -422,13 +429,14 @@ func (iface *dhcpInterfaceV6) allocateForSolicit( req *layers.DHCPv6, ) (lease *Lease, iaid uint32) { l := iface.common.logger + key := macToKey(mac) for _, reqOpt := range req.Options { if reqOpt.Code != layers.DHCPv6OptIANA { continue } - var iana iaNAOption + var iana IANAOption err := iana.UnmarshalBinary(reqOpt.Data) if err != nil { // TODO(e.burkov): Recheck the logic on malformed IA_NA options. @@ -437,21 +445,25 @@ func (iface *dhcpInterfaceV6) allocateForSolicit( continue } - // TODO(e.burkov): Test the case, where the lease exists and is - // expired. - // + var ok bool + if lease, ok = iface.common.leases[key]; ok { + return lease, iana.ID + } + // TODO(e.burkov): Support allocating the exact requested address if it // is available. lease, err = iface.common.allocateLease(ctx, mac, iface.addrChecker, iface.clock) if err != nil { - l.DebugContext(ctx, "no address available", "iaid", iana.iaid, slogutil.KeyError, err) + l.DebugContext(ctx, "no address available", "iaid", iana.ID, slogutil.KeyError, err) continue } - return lease, iana.iaid + return lease, iana.ID } + l.DebugContext(ctx, "no valid ia_na in solicit") + return nil, 0 } @@ -483,7 +495,7 @@ func (iface *dhcpInterfaceV6) newSolicitRespOpts( // // See RFC 9915 Section 18.3.9. opts = append(opts, newPreferenceOption(0)) - opts = append(opts, newSOLMaxRTOption(solMaxRT)) + opts = append(opts, newSOLMaxRTOption(DefaultSolMaxRT)) if rapidCommit { opts = append(opts, layers.NewDHCPv6Option(layers.DHCPv6OptRapidCommit, nil)) @@ -501,15 +513,15 @@ func (iface *dhcpInterfaceV6) iaNAFromLease(lease *Lease, iaid uint32) (iana lay return newIANAWithStatus(iaid, layers.DHCPv6StatusCodeNoAddrsAvail) } - return iaNAOption{ - nested: []iaAddrOption{{ - addr: lease.IP, - preferredLifetime: iface.common.leaseTTL, - validLifetime: iface.common.leaseTTL, + return IANAOption{ + Nested: []IAAddrOption{{ + Addr: lease.IP, + PreferredLifetime: iface.common.leaseTTL, + ValidLifetime: iface.common.leaseTTL, }}, - iaid: iaid, - t1: iface.t1, - t2: iface.t2, + ID: iaid, + T1: iface.t1, + T2: iface.t2, }.Encode() } @@ -532,6 +544,11 @@ func (iface *dhcpInterfaceV6) commit( lease.Hostname = aghnet.GenerateHostname(lease.IP) } + // TODO(e.burkov): Add the Lease.isExpired. method. + if lease.Expiry.Before(iface.clock.Now()) { + lease.updateExpiry(iface.clock, iface.common.leaseTTL) + } + err = iface.common.index.update(ctx, iface.common.logger, lease, iface.common) if err != nil { rmErr := iface.common.removeLease(lease) From af9142e98e1c49783c610f3c5383843cc5b58e53 Mon Sep 17 00:00:00 2001 From: Maksim Kazantsev Date: Wed, 17 Jun 2026 12:08:53 +0000 Subject: [PATCH 2/4] Pull request 2675: AGDNS-4104-move-tls-http-api-from-tls-manager Squashed commit of the following: commit 926a3fd13bdbb74ef9a253e61f403a9a63ebd9b2 Merge: 13e3b38d8 798cd4d2f Author: Maksim Kazantsev Date: Wed Jun 17 14:59:10 2026 +0300 Merge branch 'master' into AGDNS-4104-move-tls-http-api-from-tls-manager commit 13e3b38d823277690a9780d0d9c058620bb1b5c1 Author: Maksim Kazantsev Date: Tue Jun 16 19:17:34 2026 +0300 home: fix docs; commit c73ea0e062802f4fa8b192315d266b3f9ee8df12 Author: Maksim Kazantsev Date: Tue Jun 16 17:39:25 2026 +0300 home: imp code; add todo; commit 754e7d2afd41bf13179b729f98374bf871867c81 Merge: ef139f560 54e6e3002 Author: Maksim Kazantsev Date: Mon Jun 15 15:11:06 2026 +0300 Merge branch 'master' into AGDNS-4104-move-tls-http-api-from-tls-manager commit ef139f5601b8052172dd437f5eda0e37782e7270 Author: Maksim Kazantsev Date: Mon Jun 15 15:04:44 2026 +0300 home: imp code; imp docs; fix bugs; commit 5e7bd02bff7191582a79ae5606c6fed6e0681c01 Author: Maksim Kazantsev Date: Wed Jun 10 12:58:50 2026 +0300 home: fix private fields are not set in tls settings; commit 4109d46b789c45bb99fb415636af8690d31e5a36 Author: Maksim Kazantsev Date: Tue Jun 9 19:19:09 2026 +0300 home: mv tls endpoints to web api; --- internal/home/config.go | 11 +- internal/home/tls.go | 402 +++-------------------------- internal/home/tls_internal_test.go | 334 +----------------------- internal/home/web.go | 323 +++++++++++++++++++++++ internal/home/web_internal_test.go | 354 +++++++++++++++++++++++++ 5 files changed, 728 insertions(+), 696 deletions(-) create mode 100644 internal/home/web_internal_test.go diff --git a/internal/home/config.go b/internal/home/config.go index 31bfaee8a1f..7570e12b2cd 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -301,6 +301,9 @@ type pendingRequests struct { // and HTTPS. When adding new properties, update the [tlsConfigSettings.clone] // and [tlsConfigSettings.setPrivateFieldsAndCompare] methods as necessary. type tlsConfigSettings struct { + // Status is the current status of the configuration. + Status tlsConfigStatus `yaml:"-" json:"-"` + // Enabled indicates whether encryption (DoT/DoH/HTTPS) is enabled. Enabled bool `yaml:"enabled" json:"enabled"` @@ -360,6 +363,9 @@ type tlsConfigSettings struct { // StrictSNICheck controls if the connections with SNI mismatching the // certificate's ones should be rejected. StrictSNICheck bool `yaml:"strict_sni_check" json:"-"` + + // ServePlainDNS defines whether to serve a plain DNS. + ServePlainDNS bool `yaml:"-" json:"-"` } // clone returns a deep copy of c. @@ -371,12 +377,13 @@ func (c *tlsConfigSettings) clone() (clone *tlsConfigSettings) { clone.CertificateChainData = slices.Clone(c.CertificateChainData) clone.PrivateKeyData = slices.Clone(c.PrivateKeyData) + clone.Status.DNSNames = slices.Clone(c.Status.DNSNames) + return clone } // setPrivateFieldsAndCompare sets any missing properties in conf to match those -// in c and returns true if TLS configurations are equal. conf must not be be -// nil. +// in c and returns true if TLS configurations are equal. conf must not be nil. // It sets the following properties because these are not accepted from the // frontend: // diff --git a/internal/home/tls.go b/internal/home/tls.go index ac5071fbd89..cdd136e04b3 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -14,7 +14,6 @@ import ( "fmt" "log/slog" "net/http" - "net/netip" "os" "strings" "sync" @@ -23,7 +22,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/logutil/slogutil" @@ -36,12 +34,9 @@ type tlsManager struct { // logger is used for logging the operation of the TLS Manager. logger *slog.Logger - // mu protects status, certLastMod, extTLSConf, and servePlainDNS. + // mu protects certLastMod, extTLSConf. mu *sync.Mutex - // status is the current status of the configuration. It is never nil. - status *tlsConfigStatus - // certLastMod is the last modification time of the certificate file. certLastMod time.Time @@ -74,9 +69,6 @@ type tlsManager struct { // customCipherIDs are the IDs of the cipher suites that AdGuard Home must // use. customCipherIDs []uint16 - - // servePlainDNS defines if plain DNS is allowed for incoming requests. - servePlainDNS bool } // tlsManagerConfig contains the settings for initializing the TLS manager. @@ -109,18 +101,19 @@ type tlsManagerConfig struct { // [tlsManager.setWebAPI]. func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, err error) { m = &tlsManager{ - logger: conf.logger, - mu: &sync.Mutex{}, - confModifier: conf.confModifier, - httpReg: conf.httpReg, - manager: conf.manager, - status: &tlsConfigStatus{}, - extTLSConf: &conf.tlsSettings, - servePlainDNS: conf.servePlainDNS, + logger: conf.logger, + mu: &sync.Mutex{}, + confModifier: conf.confModifier, + httpReg: conf.httpReg, + manager: conf.manager, + extTLSConf: &conf.tlsSettings, } m.rootCerts = aghtls.SystemRootCAs(ctx, conf.logger) + m.extTLSConf.ServePlainDNS = conf.servePlainDNS + m.extTLSConf.Status = tlsConfigStatus{} + if len(conf.tlsSettings.OverrideTLSCiphers) > 0 { m.customCipherIDs, err = aghtls.ParseCiphers(config.TLS.OverrideTLSCiphers) if err != nil { @@ -149,7 +142,7 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, m.logger.ErrorContext(ctx, "setting tls files", slogutil.KeyError, err) } - err = m.loadTLSConfig(ctx, m.extTLSConf, m.status) + err = m.loadTLSConfig(ctx, m.extTLSConf, &m.extTLSConf.Status) if err != nil { m.extTLSConf.Enabled = false @@ -162,8 +155,8 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, } // setWebAPI stores the provided web API. It must be called before -// [tlsManager.start], [tlsManager.reload], [tlsManager.handleTLSConfigure], or -// [tlsManager.validateTLSSettings]. +// [tlsManager.start], [tlsManager.reload], [webAPI.handleTLSConfigure], or +// [webAPI.validateTLSSettings]. // // TODO(s.chzhen): Remove it once cyclic dependency is resolved. func (m *tlsManager) setWebAPI(webAPI *webAPI) { @@ -199,8 +192,6 @@ func (m *tlsManager) setCertFileTime(ctx context.Context) { // // TODO(s.chzhen): Use context. func (m *tlsManager) start(ctx context.Context) { - m.registerWebHandlers() - m.mu.Lock() defer m.mu.Unlock() @@ -271,12 +262,12 @@ func (m *tlsManager) reload(ctx context.Context) { return } - m.extTLSConf = &tlsConf - m.status = status + tlsConf.Status = *status + m.extTLSConf = &tlsConf m.certLastMod = fi.ModTime().UTC() - err = m.reconfigureDNSServer(ctx) + err = m.web.reconfigureDNSServer(ctx, m.extTLSConf) if err != nil { m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err) } @@ -287,31 +278,6 @@ func (m *tlsManager) reload(ctx context.Context) { m.web.tlsConfigChanged(context.Background(), m.extTLSConf) } -// reconfigureDNSServer updates the DNS server configuration using the stored -// TLS settings. m.mu is expected to be locked. -func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) { - newConf, err := newServerConfig( - &config.DNS, - config.Clients.Sources, - m.extTLSConf, - config.HTTPConfig.DoH, - m, - m.httpReg, - globalContext.clients.storage, - m.confModifier, - ) - if err != nil { - return fmt.Errorf("generating forwarding dns server config: %w", err) - } - - err = globalContext.dnsServer.Reconfigure(ctx, newConf) - if err != nil { - return fmt.Errorf("starting forwarding dns server: %w", err) - } - - return nil -} - // loadTLSConfig loads and validates the TLS configuration. It also sets // [tlsConfigSettings.CertificateChainData] and // [tlsConfigSettings.PrivateKeyData] properties. The returned error is also @@ -457,93 +423,26 @@ type tlsConfigSettingsExt struct { ServePlainDNS aghalg.NullBool `yaml:"-" json:"serve_plain_dns"` } -// handleTLSStatus is the handler for the GET /control/tls/status HTTP API. -func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) { - var tlsConf *tlsConfigSettings - var servePlainDNS bool - func() { - m.mu.Lock() - defer m.mu.Unlock() - - tlsConf = m.extTLSConf.clone() - servePlainDNS = m.servePlainDNS - }() - - data := &tlsConfig{ - tlsConfigSettingsExt: tlsConfigSettingsExt{ - tlsConfigSettings: *tlsConf, - ServePlainDNS: aghalg.BoolToNullBool(servePlainDNS), - }, - tlsConfigStatus: m.status, - } - - m.marshalTLS(r.Context(), w, r, data) -} - -// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API. -func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - setts, err := unmarshalTLS(r) - if err != nil { - // errFmt does not follow error message guidelines because it is sent - // directly to the frontend. - const errFmt = "Failed to unmarshal TLS config: %s" - - aghhttp.ErrorAndLog(ctx, m.logger, r, w, http.StatusBadRequest, errFmt, err) - - return - } - - m.mu.Lock() - defer m.mu.Unlock() - - if setts.PrivateKeySaved { - setts.PrivateKey = m.extTLSConf.PrivateKey - } - - if err = m.validateTLSSettings(setts); err != nil { - m.logger.InfoContext(ctx, "validating tls settings", slogutil.KeyError, err) - - aghhttp.ErrorAndLog(ctx, m.logger, r, w, http.StatusBadRequest, "%s", err) - - return - } - - // Skip the error check, since we are only interested in the value of - // status.WarningValidation. - status := &tlsConfigStatus{} - _ = m.loadTLSConfig(ctx, &setts.tlsConfigSettings, status) - resp := &tlsConfig{ - tlsConfigSettingsExt: setts, - tlsConfigStatus: status, - } - - m.marshalTLS(ctx, w, r, resp) -} - -// setConfig updates manager TLS configuration with the given one. m.mu is -// expected to be locked. +// setConfig updates manager TLS configuration with the given one. newConf must +// not be nil. func (m *tlsManager) setConfig( ctx context.Context, - newConf tlsConfigSettings, - status *tlsConfigStatus, + newConf *tlsConfigSettings, servePlain aghalg.NullBool, ) (restartHTTPS bool) { - if !m.extTLSConf.setPrivateFieldsAndCompare(&newConf) { + m.mu.Lock() + defer m.mu.Unlock() + + m.extTLSConf.updatePlainDNS(newConf, servePlain) + + if !m.extTLSConf.setPrivateFieldsAndCompare(newConf) { m.logger.InfoContext(ctx, "config has changed, restarting https server") restartHTTPS = true } else { m.logger.InfoContext(ctx, "config has not changed") } - m.extTLSConf = &newConf - - m.status = status - - if servePlain != aghalg.NBNull { - m.servePlainDNS = servePlain == aghalg.NBTrue - } + m.extTLSConf = newConf certPath, keyPath := "", "" if newConf.Enabled { @@ -558,181 +457,30 @@ func (m *tlsManager) setConfig( m.logger.ErrorContext(ctx, "setting tls files", slogutil.KeyError, err) } + m.setCertFileTime(ctx) + return restartHTTPS } -// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP -// API. -func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - req, err := unmarshalTLS(r) - if err != nil { - aghhttp.ErrorAndLog( - ctx, - m.logger, - r, - w, - http.StatusBadRequest, - "Failed to unmarshal TLS config: %s", - err, - ) - - return - } - - var restartHTTPS bool - defer func() { - if restartHTTPS { - m.confModifier.Apply(ctx) - } - }() - - m.mu.Lock() - defer m.mu.Unlock() - - if req.PrivateKeySaved { - req.PrivateKey = m.extTLSConf.PrivateKey - } - - req.StrictSNICheck = m.extTLSConf.StrictSNICheck - - if err = m.validateTLSSettings(req); err != nil { - aghhttp.ErrorAndLog(ctx, m.logger, r, w, http.StatusBadRequest, "%s", err) - - return - } - - status := &tlsConfigStatus{} - err = m.loadTLSConfig(ctx, &req.tlsConfigSettings, status) - if err != nil { - resp := &tlsConfig{ - tlsConfigSettingsExt: req, - tlsConfigStatus: status, - } - - m.marshalTLS(ctx, w, r, resp) - - return - } - - restartHTTPS = m.setConfig(ctx, req.tlsConfigSettings, status, req.ServePlainDNS) - m.setCertFileTime(ctx) - - if req.ServePlainDNS != aghalg.NBNull { +// updatePlainDNS checks the old value of [tlsConfigSettings.ServePlainDNS] in +// c and if it differs from servePlain, sets the value of servePlain in +// newTLSConf.ServePlainDNS. newTLSConf must not be nil. +func (c *tlsConfigSettings) updatePlainDNS( + newTLSConf *tlsConfigSettings, + servePlain aghalg.NullBool, +) { + if servePlain != aghalg.NBNull { func() { config.Lock() defer config.Unlock() - config.DNS.ServePlainDNS = req.ServePlainDNS == aghalg.NBTrue + config.DNS.ServePlainDNS = servePlain == aghalg.NBTrue }() - } - - err = m.reconfigureDNSServer(ctx) - if err != nil { - m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err) - - aghhttp.ErrorAndLog(ctx, m.logger, r, w, http.StatusInternalServerError, "%s", err) - - return - } - - resp := &tlsConfig{ - tlsConfigSettingsExt: req, - tlsConfigStatus: m.status, - } - - m.marshalTLS(ctx, w, r, resp) - rc := http.NewResponseController(w) - err = rc.Flush() - if err != nil { - m.logger.ErrorContext(ctx, "flushing response", slogutil.KeyError, err) - } - - // The background context is used because the TLSConfigChanged wraps context - // with timeout on its own and shuts down the server, which handles current - // request. It is also should be done in a separate goroutine due to the - // same reason. - if restartHTTPS { - go m.web.tlsConfigChanged(context.Background(), &req.tlsConfigSettings) - } -} - -// validateTLSSettings returns error if the setts are not valid. -func (m *tlsManager) validateTLSSettings(setts tlsConfigSettingsExt) (err error) { - if !setts.Enabled { - if setts.ServePlainDNS == aghalg.NBFalse { - // TODO(a.garipov): Support full disabling of all DNS. - return errors.Error("plain DNS is required in case encryption protocols are disabled") - } - - return nil - } - var ( - tlsConf tlsConfigSettings - webAPIAddr netip.Addr - webAPIPort uint16 - plainDNSPort uint16 - ) - - func() { - config.Lock() - defer config.Unlock() - - tlsConf = config.TLS - webAPIAddr = config.HTTPConfig.Address.Addr() - webAPIPort = config.HTTPConfig.Address.Port() - plainDNSPort = config.DNS.Port - }() - - err = validatePorts( - tcpPort(webAPIPort), - tcpPort(setts.PortHTTPS), - tcpPort(setts.PortDNSOverTLS), - tcpPort(setts.PortDNSCrypt), - udpPort(plainDNSPort), - udpPort(setts.PortDNSOverQUIC), - ) - if err != nil { - // Don't wrap the error because it's informative enough as is. - return err - } - - // Don't wrap the error because it's informative enough as is. - return m.checkPortAvailability(tlsConf, setts.tlsConfigSettings, webAPIAddr) -} - -// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home -// DNS protocols. -func validatePorts( - bindPort, dohPort, dotPort, dnscryptTCPPort tcpPort, - dnsPort, doqPort udpPort, -) (err error) { - tcpPorts := aghalg.UniqChecker[tcpPort]{} - addPorts( - tcpPorts, - bindPort, - dohPort, - dotPort, - dnscryptTCPPort, - tcpPort(dnsPort), - ) - - err = tcpPorts.Validate() - if err != nil { - return fmt.Errorf("validating tcp ports: %w", err) - } - - udpPorts := aghalg.UniqChecker[udpPort]{} - addPorts(udpPorts, dnsPort, doqPort) - - err = udpPorts.Validate() - if err != nil { - return fmt.Errorf("validating udp ports: %w", err) + newTLSConf.ServePlainDNS = servePlain == aghalg.NBTrue + } else { + newTLSConf.ServePlainDNS = c.ServePlainDNS } - - return nil } // validateCertChain verifies certs using the first as the main one and others @@ -775,67 +523,6 @@ func (m *tlsManager) validateCertChain( return nil } -// checkPortAvailability checks [tlsConfigSettings.PortHTTPS], -// [tlsConfigSettings.PortDNSOverTLS], and [tlsConfigSettings.PortDNSOverQUIC] -// are available for use. It checks the current configuration and, if needed, -// attempts to bind to the port. The function returns human-readable error -// messages for the frontend. This is best-effort check to prevent an "address -// already in use" error. -// -// TODO(a.garipov): Adapt for HTTP/3. -func (m *tlsManager) checkPortAvailability( - currConf tlsConfigSettings, - newConf tlsConfigSettings, - addr netip.Addr, -) (err error) { - const ( - networkTCP = "tcp" - networkUDP = "udp" - - protoHTTPS = "HTTPS" - protoDoT = "DNS-over-TLS" - protoDoQ = "DNS-over-QUIC" - ) - - needBindingCheck := []struct { - network string - proto string - currPort uint16 - newPort uint16 - }{{ - network: networkTCP, - proto: protoHTTPS, - currPort: currConf.PortHTTPS, - newPort: newConf.PortHTTPS, - }, { - network: networkTCP, - proto: protoDoT, - currPort: currConf.PortDNSOverTLS, - newPort: newConf.PortDNSOverTLS, - }, { - network: networkUDP, - proto: protoDoQ, - currPort: currConf.PortDNSOverQUIC, - newPort: newConf.PortDNSOverQUIC, - }} - - var errs []error - for _, v := range needBindingCheck { - port := v.newPort - if v.currPort == port { - continue - } - - addrPort := netip.AddrPortFrom(addr, port) - err = aghnet.CheckPort(v.network, addrPort) - if err != nil { - errs = append(errs, fmt.Errorf("port %d for %s is not available", port, v.proto)) - } - } - - return errors.Join(errs...) -} - // errNoIPInCert is the error that is returned from [tlsManager.parseCertChain] // if the leaf certificate doesn't contain IPs. const errNoIPInCert errors.Error = `certificates has no IP addresses; ` + @@ -1059,7 +746,7 @@ func parsePrivateKey(der []byte) (key crypto.PrivateKey, typ string, err error) return nil, "", errors.Error("tls: failed to parse private key") } -// unmarshalTLS handles base64-encoded certificates transparently +// unmarshalTLS handles base64-encoded certificates transparently. func unmarshalTLS(r *http.Request) (data tlsConfigSettingsExt, err error) { data = tlsConfigSettingsExt{} err = json.NewDecoder(r.Body).Decode(&data) @@ -1117,10 +804,3 @@ func (m *tlsManager) marshalTLS( aghhttp.WriteJSONResponseOK(ctx, m.logger, w, r, *data) } - -// registerWebHandlers registers HTTP handlers for TLS configuration. -func (m *tlsManager) registerWebHandlers() { - m.httpReg.Register(http.MethodGet, "/control/tls/status", m.handleTLSStatus) - m.httpReg.Register(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure) - m.httpReg.Register(http.MethodPost, "/control/tls/validate", m.handleTLSValidate) -} diff --git a/internal/home/tls_internal_test.go b/internal/home/tls_internal_test.go index 0cb219fe4e7..e3b10d2dd34 100644 --- a/internal/home/tls_internal_test.go +++ b/internal/home/tls_internal_test.go @@ -6,26 +6,17 @@ import ( "crypto/rsa" "crypto/tls" "crypto/x509" - "encoding/base64" - "encoding/json" "encoding/pem" - "fmt" "math/big" - "net" - "net/http" - "net/http/httptest" - "net/netip" "os" "path/filepath" "testing" "time" "github.com/AdguardTeam/AdGuardHome/internal/agh" - "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" - "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/stretchr/testify/assert" @@ -285,7 +276,7 @@ func TestTLSManager_Reload(t *testing.T) { }) require.NoError(t, err) - web := newTestWeb(t, &webConfig{}) + web := newTestWeb(t, &webConfig{tlsManager: m}) m.setWebAPI(web) extTLSConf := m.extendedTLSConfig() @@ -305,326 +296,3 @@ func TestTLSManager_Reload(t *testing.T) { extTLSConf = m.extendedTLSConfig() assertCertSerialNumber(t, extTLSConf, snAfter) } - -func TestTLSManager_HandleTLSStatus(t *testing.T) { - var ( - ctx = testutil.ContextWithTimeout(t, testTimeout) - err error - ) - - testCertChain := requireReadFile(t, testCertificatePath) - testPrivateKeyData := requireReadFile(t, testPrivateKeyPath) - - m, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: testLogger, - confModifier: agh.EmptyConfigModifier{}, - manager: aghtls.EmptyManager{}, - tlsSettings: tlsConfigSettings{ - Enabled: true, - CertificateChain: string(testCertChain), - PrivateKey: string(testPrivateKeyData), - }, - servePlainDNS: false, - }) - require.NoError(t, err) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, "/control/tls/status", nil) - m.handleTLSStatus(w, r) - - res := &tlsConfigSettingsExt{} - err = json.NewDecoder(w.Body).Decode(res) - require.NoError(t, err) - - wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChain) - assert.True(t, res.Enabled) - assert.Equal(t, wantCertificateChain, res.CertificateChain) - assert.True(t, res.PrivateKeySaved) -} - -func TestValidateTLSSettings(t *testing.T) { - storeGlobals(t) - - var ( - ctx = testutil.ContextWithTimeout(t, testTimeout) - err error - ) - - m, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: testLogger, - confModifier: agh.EmptyConfigModifier{}, - manager: aghtls.EmptyManager{}, - servePlainDNS: false, - }) - require.NoError(t, err) - - web := newTestWeb(t, &webConfig{}) - m.setWebAPI(web) - - tcpLn, err := net.Listen("tcp", ":0") - require.NoError(t, err) - - testutil.CleanupAndRequireSuccess(t, tcpLn.Close) - - tcpAddr := testutil.RequireTypeAssert[*net.TCPAddr](t, tcpLn.Addr()) - busyTCPPort := tcpAddr.Port - - udpLn, err := net.ListenPacket("udp", ":0") - require.NoError(t, err) - - testutil.CleanupAndRequireSuccess(t, udpLn.Close) - - udpAddr := testutil.RequireTypeAssert[*net.UDPAddr](t, udpLn.LocalAddr()) - busyUDPPort := udpAddr.Port - - testCases := []struct { - name string - wantErr string - setts tlsConfigSettingsExt - }{{ - name: "basic", - wantErr: "", - setts: tlsConfigSettingsExt{}, - }, { - name: "disabled_all", - wantErr: "plain DNS is required in case encryption protocols are disabled", - setts: tlsConfigSettingsExt{ - ServePlainDNS: aghalg.NBFalse, - }, - }, { - name: "busy_https_port", - wantErr: fmt.Sprintf("port %d for HTTPS is not available", busyTCPPort), - setts: tlsConfigSettingsExt{ - tlsConfigSettings: tlsConfigSettings{ - Enabled: true, - PortHTTPS: uint16(busyTCPPort), - }, - }, - }, { - name: "busy_dot_port", - wantErr: fmt.Sprintf("port %d for DNS-over-TLS is not available", busyTCPPort), - setts: tlsConfigSettingsExt{ - tlsConfigSettings: tlsConfigSettings{ - Enabled: true, - PortDNSOverTLS: uint16(busyTCPPort), - }, - }, - }, { - name: "busy_doq_port", - wantErr: fmt.Sprintf("port %d for DNS-over-QUIC is not available", busyUDPPort), - setts: tlsConfigSettingsExt{ - tlsConfigSettings: tlsConfigSettings{ - Enabled: true, - PortDNSOverQUIC: uint16(busyUDPPort), - }, - }, - }, { - name: "duplicate_port", - wantErr: "validating tcp ports: duplicated values: [4433]", - setts: tlsConfigSettingsExt{ - tlsConfigSettings: tlsConfigSettings{ - Enabled: true, - PortHTTPS: 4433, - PortDNSOverTLS: 4433, - }, - }, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err = m.validateTLSSettings(tc.setts) - testutil.AssertErrorMsg(t, tc.wantErr, err) - }) - } -} - -func TestTLSManager_HandleTLSValidate(t *testing.T) { - storeGlobals(t) - - var ( - ctx = testutil.ContextWithTimeout(t, testTimeout) - err error - ) - - m, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: testLogger, - confModifier: agh.EmptyConfigModifier{}, - manager: aghtls.EmptyManager{}, - tlsSettings: tlsConfigSettings{ - Enabled: true, - CertificatePath: testCertificatePath, - PrivateKeyPath: testPrivateKeyPath, - }, - servePlainDNS: false, - }) - require.NoError(t, err) - - web := newTestWeb(t, &webConfig{}) - m.setWebAPI(web) - - setts := &tlsConfigSettingsExt{ - tlsConfigSettings: tlsConfigSettings{ - Enabled: true, - CertificatePath: testCertificatePath, - PrivateKeyPath: testPrivateKeyPath, - }, - } - - req, err := json.Marshal(setts) - require.NoError(t, err) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodPost, "/control/tls/validate", bytes.NewReader(req)) - m.handleTLSValidate(w, r) - - res := &tlsConfigStatus{} - err = json.NewDecoder(w.Body).Decode(res) - require.NoError(t, err) - - testCertChainData := requireReadFile(t, testCertificatePath) - testPrivateKeyData := requireReadFile(t, testPrivateKeyPath) - - cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) - require.NoError(t, err) - - wantIssuer := cert.Leaf.Issuer.String() - assert.Equal(t, wantIssuer, res.Issuer) -} - -func TestTLSManager_HandleTLSConfigure(t *testing.T) { - // Store the global state before making any changes. - storeGlobals(t) - - var ( - ctx = testutil.ContextWithTimeout(t, testTimeout) - err error - ) - - globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ - Logger: testLogger, - }) - require.NoError(t, err) - - err = globalContext.dnsServer.Prepare( - testutil.ContextWithTimeout(t, testTimeout), - &dnsforward.ServerConfig{ - TLSConf: &dnsforward.TLSConfig{}, - Config: dnsforward.Config{ - UpstreamMode: dnsforward.UpstreamModeLoadBalance, - EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false}, - ClientsContainer: dnsforward.EmptyClientsContainer{}, - }, - ServePlainDNS: true, - }) - require.NoError(t, err) - - globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{ - BaseLogger: testLogger, - Logger: testLogger, - Clock: timeutil.SystemClock{}, - }) - require.NoError(t, err) - - config.DNS.BindHosts = []netip.Addr{netutil.IPv4Localhost()} - config.DNS.Port = 0 - - const wantSerialNumber int64 = 1 - - // Prepare the TLS manager configuration. - tmpDir := t.TempDir() - certPath := filepath.Join(tmpDir, "cert.pem") - keyPath := filepath.Join(tmpDir, "key.pem") - - certDER, key := newCertAndKey(t, wantSerialNumber) - writeCertAndKey(t, certDER, certPath, key, keyPath) - - // Initialize the TLS manager and assert its configuration. - m, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: testLogger, - confModifier: agh.EmptyConfigModifier{}, - manager: aghtls.EmptyManager{}, - tlsSettings: tlsConfigSettings{ - Enabled: true, - CertificatePath: certPath, - PrivateKeyPath: keyPath, - }, - servePlainDNS: true, - }) - require.NoError(t, err) - - web := newTestWeb(t, &webConfig{}) - m.setWebAPI(web) - - extTLSConf := m.extendedTLSConfig() - assertCertSerialNumber(t, extTLSConf, wantSerialNumber) - - // Prepare a request with the new TLS configuration. - setts := &tlsConfigSettingsExt{ - tlsConfigSettings: tlsConfigSettings{ - Enabled: true, - PortHTTPS: 4433, - CertificatePath: testCertificatePath, - PrivateKeyPath: testPrivateKeyPath, - }, - } - - req, err := json.Marshal(setts) - require.NoError(t, err) - - r := httptest.NewRequest(http.MethodPost, "/control/tls/configure", bytes.NewReader(req)) - w := httptest.NewRecorder() - - // Reconfigure the TLS manager. - m.handleTLSConfigure(w, r) - - // The [tlsManager.handleTLSConfigure] method will start the DNS server and - // it should be stopped after the test ends. - testutil.CleanupAndRequireSuccess(t, func() (err error) { - return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout)) - }) - - res := &tlsConfig{ - tlsConfigStatus: &tlsConfigStatus{}, - } - - err = json.NewDecoder(w.Body).Decode(res) - require.NoError(t, err) - - testCertChainData := requireReadFile(t, testCertificatePath) - testPrivateKeyData := requireReadFile(t, testPrivateKeyPath) - - cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) - require.NoError(t, err) - - wantIssuer := cert.Leaf.Issuer.String() - assert.Equal(t, wantIssuer, res.tlsConfigStatus.Issuer) - - // Assert that the Web API's TLS configuration has been updated. - // - // TODO(s.chzhen): Remove when [httpsServer.cond] is removed. - assert.Eventually(t, func() bool { - web.httpsServer.condLock.Lock() - defer web.httpsServer.condLock.Unlock() - - cert = web.httpsServer.cert - if cert.Leaf == nil { - return false - } - - assert.Equal(t, wantIssuer, cert.Leaf.Issuer.String()) - - return true - }, testTimeout, testTimeout/10) -} - -// requireReadFile reads the file at the specified path and returns its content. -// -// TODO(m.kazantsev): Move to golibs/testutil. -func requireReadFile(tb testing.TB, path string) (data []byte) { - tb.Helper() - - data, err := os.ReadFile(path) - require.NoError(tb, err) - - return data -} diff --git a/internal/home/web.go b/internal/home/web.go index 29ab155e19b..83f4b4ee4d3 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -13,7 +13,9 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/agh" + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/logutil/slogutil" @@ -202,6 +204,7 @@ func newWebAPI(ctx context.Context, conf *webAPIConfig) (w *webAPI) { mux.Handle("/install.html", w.preInstallHandler(clientFS)) w.registerInstallHandlers() } else { + w.registerTLSHandlers() w.registerControlHandlers() } @@ -464,3 +467,323 @@ func startPprof(baseLogger *slog.Logger, port uint16) { } }() } + +// registerTLSHandlers registers HTTP handlers for TLS configuration. +func (web *webAPI) registerTLSHandlers() { + web.httpReg.Register(http.MethodGet, "/control/tls/status", web.handleTLSStatus) + web.httpReg.Register(http.MethodPost, "/control/tls/configure", web.handleTLSConfigure) + web.httpReg.Register(http.MethodPost, "/control/tls/validate", web.handleTLSValidate) +} + +// handleTLSStatus is the handler for the GET /control/tls/status HTTP API. +func (web *webAPI) handleTLSStatus(w http.ResponseWriter, r *http.Request) { + tlsConf := web.tlsManager.extendedTLSConfig() + + data := &tlsConfig{ + tlsConfigSettingsExt: tlsConfigSettingsExt{ + tlsConfigSettings: *tlsConf, + ServePlainDNS: aghalg.BoolToNullBool(tlsConf.ServePlainDNS), + }, + tlsConfigStatus: &tlsConf.Status, + } + + web.tlsManager.marshalTLS(r.Context(), w, r, data) +} + +// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API. +func (web *webAPI) handleTLSValidate(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + setts, err := unmarshalTLS(r) + if err != nil { + // errFmt does not follow error message guidelines because it is sent + // directly to the frontend. + const errFmt = "Failed to unmarshal TLS config: %s" + + aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, errFmt, err) + + return + } + + extTLSConf := web.tlsManager.extendedTLSConfig() + + if setts.PrivateKeySaved { + setts.PrivateKey = extTLSConf.PrivateKey + } + + if err = web.validateTLSSettings(setts); err != nil { + web.logger.InfoContext(ctx, "validating tls settings", slogutil.KeyError, err) + + aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "%s", err) + + return + } + + // Skip the error check, since we are only interested in the value of + // status.WarningValidation. + status := &tlsConfigStatus{} + _ = web.tlsManager.loadTLSConfig(ctx, &setts.tlsConfigSettings, status) + resp := &tlsConfig{ + tlsConfigSettingsExt: setts, + tlsConfigStatus: status, + } + + web.tlsManager.marshalTLS(ctx, w, r, resp) +} + +// validateTLSSettings returns error if the setts are not valid. +func (web *webAPI) validateTLSSettings(setts tlsConfigSettingsExt) (err error) { + if !setts.Enabled { + if setts.ServePlainDNS == aghalg.NBFalse { + // TODO(a.garipov): Support full disabling of all DNS. + return errors.Error("plain DNS is required in case encryption protocols are disabled") + } + + return nil + } + + var ( + tlsConf tlsConfigSettings + webAPIAddr netip.Addr + webAPIPort uint16 + plainDNSPort uint16 + ) + + func() { + config.Lock() + defer config.Unlock() + + tlsConf = config.TLS + webAPIAddr = config.HTTPConfig.Address.Addr() + webAPIPort = config.HTTPConfig.Address.Port() + plainDNSPort = config.DNS.Port + }() + + err = validatePorts( + tcpPort(webAPIPort), + tcpPort(setts.PortHTTPS), + tcpPort(setts.PortDNSOverTLS), + tcpPort(setts.PortDNSCrypt), + udpPort(plainDNSPort), + udpPort(setts.PortDNSOverQUIC), + ) + if err != nil { + // Don't wrap the error because it's informative enough as is. + return err + } + + // Don't wrap the error because it's informative enough as is. + return checkPortAvailability(tlsConf, setts.tlsConfigSettings, webAPIAddr) +} + +// checkPortAvailability checks [tlsConfigSettings.PortHTTPS], +// [tlsConfigSettings.PortDNSOverTLS], and [tlsConfigSettings.PortDNSOverQUIC] +// are available for use. It checks the current configuration and, if needed, +// attempts to bind to the port. The function returns human-readable error +// messages for the frontend. This is best-effort check to prevent an "address +// already in use" error. +// +// TODO(a.garipov): Adapt for HTTP/3. +func checkPortAvailability( + currConf tlsConfigSettings, + newConf tlsConfigSettings, + addr netip.Addr, +) (err error) { + const ( + networkTCP = "tcp" + networkUDP = "udp" + + protoHTTPS = "HTTPS" + protoDoT = "DNS-over-TLS" + protoDoQ = "DNS-over-QUIC" + ) + + needBindingCheck := []struct { + network string + proto string + currPort uint16 + newPort uint16 + }{{ + network: networkTCP, + proto: protoHTTPS, + currPort: currConf.PortHTTPS, + newPort: newConf.PortHTTPS, + }, { + network: networkTCP, + proto: protoDoT, + currPort: currConf.PortDNSOverTLS, + newPort: newConf.PortDNSOverTLS, + }, { + network: networkUDP, + proto: protoDoQ, + currPort: currConf.PortDNSOverQUIC, + newPort: newConf.PortDNSOverQUIC, + }} + + var errs []error + for _, v := range needBindingCheck { + port := v.newPort + if v.currPort == port { + continue + } + + addrPort := netip.AddrPortFrom(addr, port) + err = aghnet.CheckPort(v.network, addrPort) + if err != nil { + errs = append(errs, fmt.Errorf("port %d for %s is not available", port, v.proto)) + } + } + + return errors.Join(errs...) +} + +// validatePorts validates the uniqueness of TCP and UDP ports for AdGuard Home +// DNS protocols. +func validatePorts( + bindPort, dohPort, dotPort, dnscryptTCPPort tcpPort, + dnsPort, doqPort udpPort, +) (err error) { + tcpPorts := aghalg.UniqChecker[tcpPort]{} + addPorts( + tcpPorts, + bindPort, + dohPort, + dotPort, + dnscryptTCPPort, + tcpPort(dnsPort), + ) + + err = tcpPorts.Validate() + if err != nil { + return fmt.Errorf("validating tcp ports: %w", err) + } + + udpPorts := aghalg.UniqChecker[udpPort]{} + addPorts(udpPorts, dnsPort, doqPort) + + err = udpPorts.Validate() + if err != nil { + return fmt.Errorf("validating udp ports: %w", err) + } + + return nil +} + +// handleTLSConfigure is the handler for the POST /control/tls/configure HTTP +// API. +// +// TODO(m.kazantsev): Improve maintainability. +func (web *webAPI) handleTLSConfigure(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + req, err := unmarshalTLS(r) + if err != nil { + aghhttp.ErrorAndLog( + ctx, + web.logger, + r, + w, + http.StatusBadRequest, + "Failed to unmarshal TLS config: %s", + err, + ) + + return + } + + var restartHTTPS bool + defer func() { + if restartHTTPS { + web.tlsManager.confModifier.Apply(ctx) + } + }() + + extTLSConf := web.tlsManager.extendedTLSConfig() + + if req.PrivateKeySaved { + req.PrivateKey = extTLSConf.PrivateKey + } + + req.StrictSNICheck = extTLSConf.StrictSNICheck + + if err = web.validateTLSSettings(req); err != nil { + aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "%s", err) + + return + } + + status := &tlsConfigStatus{} + err = web.tlsManager.loadTLSConfig(ctx, &req.tlsConfigSettings, status) + if err != nil { + resp := &tlsConfig{ + tlsConfigSettingsExt: req, + tlsConfigStatus: status, + } + + web.tlsManager.marshalTLS(ctx, w, r, resp) + + return + } + + newTLSConf := &req.tlsConfigSettings + newTLSConf.Status = *status + + restartHTTPS = web.tlsManager.setConfig(ctx, newTLSConf, req.ServePlainDNS) + + err = web.reconfigureDNSServer(ctx, newTLSConf) + if err != nil { + web.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err) + + aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusInternalServerError, "%s", err) + + return + } + + resp := &tlsConfig{ + tlsConfigSettingsExt: req, + tlsConfigStatus: status, + } + + web.tlsManager.marshalTLS(ctx, w, r, resp) + rc := http.NewResponseController(w) + err = rc.Flush() + if err != nil { + web.logger.ErrorContext(ctx, "flushing response", slogutil.KeyError, err) + } + + // The background context is used because the TLSConfigChanged wraps context + // with timeout on its own and shuts down the server, which handles current + // request. It is also should be done in a separate goroutine due to the + // same reason. + if restartHTTPS { + go web.tlsConfigChanged(context.Background(), &req.tlsConfigSettings) + } +} + +// reconfigureDNSServer updates the DNS server configuration using extTLSConf. +// extTLSConf must not be nil. +func (web *webAPI) reconfigureDNSServer( + ctx context.Context, + extTLSConf *tlsConfigSettings, +) (err error) { + newConf, err := newServerConfig( + &config.DNS, + config.Clients.Sources, + extTLSConf, + config.HTTPConfig.DoH, + web.tlsManager, + web.httpReg, + globalContext.clients.storage, + web.tlsManager.confModifier, + ) + if err != nil { + return fmt.Errorf("generating forwarding dns server config: %w", err) + } + + err = globalContext.dnsServer.Reconfigure(ctx, newConf) + if err != nil { + return fmt.Errorf("starting forwarding dns server: %w", err) + } + + return nil +} diff --git a/internal/home/web_internal_test.go b/internal/home/web_internal_test.go new file mode 100644 index 00000000000..23ff2b7d5da --- /dev/null +++ b/internal/home/web_internal_test.go @@ -0,0 +1,354 @@ +package home + +import ( + "bytes" + "crypto/tls" + "encoding/base64" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "os" + "path/filepath" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/agh" + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" + "github.com/AdguardTeam/AdGuardHome/internal/aghtls" + "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/golibs/timeutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWebAPI_HandleTLSConfigure(t *testing.T) { + // Store the global state before making any changes. + storeGlobals(t) + + var ( + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ + Logger: testLogger, + }) + require.NoError(t, err) + + err = globalContext.dnsServer.Prepare( + testutil.ContextWithTimeout(t, testTimeout), + &dnsforward.ServerConfig{ + TLSConf: &dnsforward.TLSConfig{}, + Config: dnsforward.Config{ + UpstreamMode: dnsforward.UpstreamModeLoadBalance, + EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false}, + ClientsContainer: dnsforward.EmptyClientsContainer{}, + }, + ServePlainDNS: true, + }) + require.NoError(t, err) + + globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{ + BaseLogger: testLogger, + Logger: testLogger, + Clock: timeutil.SystemClock{}, + }) + require.NoError(t, err) + + config.DNS.BindHosts = []netip.Addr{netutil.IPv4Localhost()} + config.DNS.Port = 0 + + const wantSerialNumber int64 = 1 + + // Prepare the TLS manager configuration. + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + certDER, key := newCertAndKey(t, wantSerialNumber) + writeCertAndKey(t, certDER, certPath, key, keyPath) + + // Initialize the TLS manager and assert its configuration. + m, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, + manager: aghtls.EmptyManager{}, + tlsSettings: tlsConfigSettings{ + Enabled: true, + CertificatePath: certPath, + PrivateKeyPath: keyPath, + ServePlainDNS: true, + }, + servePlainDNS: true, + }) + require.NoError(t, err) + + web := newTestWeb(t, &webConfig{tlsManager: m}) + m.setWebAPI(web) + + extTLSConf := m.extendedTLSConfig() + assertCertSerialNumber(t, extTLSConf, wantSerialNumber) + + // Prepare a request with the new TLS configuration. + setts := &tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortHTTPS: 4433, + CertificatePath: testCertificatePath, + PrivateKeyPath: testPrivateKeyPath, + }, + } + + req, err := json.Marshal(setts) + require.NoError(t, err) + + r := httptest.NewRequest(http.MethodPost, "/control/tls/configure", bytes.NewReader(req)) + w := httptest.NewRecorder() + + // Reconfigure the TLS manager. + web.handleTLSConfigure(w, r) + + // The [webAPI.handleTLSConfigure] method will start the DNS server and + // it should be stopped after the test ends. + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout)) + }) + + res := &tlsConfig{ + tlsConfigStatus: &tlsConfigStatus{}, + } + + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + testCertChainData := requireReadFile(t, testCertificatePath) + testPrivateKeyData := requireReadFile(t, testPrivateKeyPath) + + cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) + require.NoError(t, err) + + wantIssuer := cert.Leaf.Issuer.String() + assert.Equal(t, wantIssuer, res.tlsConfigStatus.Issuer) + + // Assert that the Web API's TLS configuration has been updated. + // + // TODO(s.chzhen): Remove when [httpsServer.cond] is removed. + assert.Eventually(t, func() bool { + web.httpsServer.condLock.Lock() + defer web.httpsServer.condLock.Unlock() + + cert = web.httpsServer.cert + if cert.Leaf == nil { + return false + } + + assert.Equal(t, wantIssuer, cert.Leaf.Issuer.String()) + + return true + }, testTimeout, testTimeout/10) +} + +func TestWebAPI_HandleTLSStatus(t *testing.T) { + var ( + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + testCertChain := requireReadFile(t, testCertificatePath) + testPrivateKeyData := requireReadFile(t, testPrivateKeyPath) + + m, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, + manager: aghtls.EmptyManager{}, + tlsSettings: tlsConfigSettings{ + Enabled: true, + CertificateChain: string(testCertChain), + PrivateKey: string(testPrivateKeyData), + }, + servePlainDNS: false, + }) + require.NoError(t, err) + + web := newTestWeb(t, &webConfig{tlsManager: m}) + m.setWebAPI(web) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/control/tls/status", nil) + web.handleTLSStatus(w, r) + + res := &tlsConfigSettingsExt{} + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChain) + assert.True(t, res.Enabled) + assert.Equal(t, wantCertificateChain, res.CertificateChain) + assert.True(t, res.PrivateKeySaved) +} + +func TestWebAPI_ValidateTLSSettings(t *testing.T) { + storeGlobals(t) + + var ( + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + m, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, + manager: aghtls.EmptyManager{}, + servePlainDNS: false, + }) + require.NoError(t, err) + + web := newTestWeb(t, &webConfig{tlsManager: m}) + m.setWebAPI(web) + + tcpLn, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + testutil.CleanupAndRequireSuccess(t, tcpLn.Close) + + tcpAddr := testutil.RequireTypeAssert[*net.TCPAddr](t, tcpLn.Addr()) + busyTCPPort := tcpAddr.Port + + udpLn, err := net.ListenPacket("udp", ":0") + require.NoError(t, err) + + testutil.CleanupAndRequireSuccess(t, udpLn.Close) + + udpAddr := testutil.RequireTypeAssert[*net.UDPAddr](t, udpLn.LocalAddr()) + busyUDPPort := udpAddr.Port + + testCases := []struct { + name string + wantErr string + setts tlsConfigSettingsExt + }{{ + name: "basic", + wantErr: "", + setts: tlsConfigSettingsExt{}, + }, { + name: "disabled_all", + wantErr: "plain DNS is required in case encryption protocols are disabled", + setts: tlsConfigSettingsExt{ + ServePlainDNS: aghalg.NBFalse, + }, + }, { + name: "busy_https_port", + wantErr: fmt.Sprintf("port %d for HTTPS is not available", busyTCPPort), + setts: tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortHTTPS: uint16(busyTCPPort), + }, + }, + }, { + name: "busy_dot_port", + wantErr: fmt.Sprintf("port %d for DNS-over-TLS is not available", busyTCPPort), + setts: tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortDNSOverTLS: uint16(busyTCPPort), + }, + }, + }, { + name: "busy_doq_port", + wantErr: fmt.Sprintf("port %d for DNS-over-QUIC is not available", busyUDPPort), + setts: tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortDNSOverQUIC: uint16(busyUDPPort), + }, + }, + }, { + name: "duplicate_port", + wantErr: "validating tcp ports: duplicated values: [4433]", + setts: tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + PortHTTPS: 4433, + PortDNSOverTLS: 4433, + }, + }, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err = web.validateTLSSettings(tc.setts) + testutil.AssertErrorMsg(t, tc.wantErr, err) + }) + } +} + +func TestWebAPI_HandleTLSValidate(t *testing.T) { + storeGlobals(t) + + var ( + ctx = testutil.ContextWithTimeout(t, testTimeout) + err error + ) + + m, err := newTLSManager(ctx, &tlsManagerConfig{ + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, + manager: aghtls.EmptyManager{}, + tlsSettings: tlsConfigSettings{ + Enabled: true, + CertificatePath: testCertificatePath, + PrivateKeyPath: testPrivateKeyPath, + }, + servePlainDNS: false, + }) + require.NoError(t, err) + + web := newTestWeb(t, &webConfig{tlsManager: m}) + m.setWebAPI(web) + + setts := &tlsConfigSettingsExt{ + tlsConfigSettings: tlsConfigSettings{ + Enabled: true, + CertificatePath: testCertificatePath, + PrivateKeyPath: testPrivateKeyPath, + }, + } + + req, err := json.Marshal(setts) + require.NoError(t, err) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/control/tls/validate", bytes.NewReader(req)) + web.handleTLSValidate(w, r) + + res := &tlsConfigStatus{} + err = json.NewDecoder(w.Body).Decode(res) + require.NoError(t, err) + + testCertChainData := requireReadFile(t, testCertificatePath) + testPrivateKeyData := requireReadFile(t, testPrivateKeyPath) + + cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData) + require.NoError(t, err) + + wantIssuer := cert.Leaf.Issuer.String() + assert.Equal(t, wantIssuer, res.Issuer) +} + +// requireReadFile reads the file at the specified path and returns its content. +// +// TODO(m.kazantsev): Move to golibs/testutil. +func requireReadFile(tb testing.TB, path string) (data []byte) { + tb.Helper() + + data, err := os.ReadFile(path) + require.NoError(tb, err) + + return data +} From 5d3859bd36e3eef84abfcf4ebc1185379c7bfd4e Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Wed, 17 Jun 2026 13:18:16 +0000 Subject: [PATCH 3/4] Pull request 2681: 8276-fix-ech-base64 Updates #8276. Squashed commit of the following: commit d92a630ba1c01a1a516f723c95421a9f4a9c53d3 Merge: b6fcfbb40 af9142e98 Author: Ainar Garipov Date: Wed Jun 17 16:07:38 2026 +0300 Merge branch 'master' into 8276-fix-ech-base64 commit b6fcfbb40683105807fd1a37cc023746c444c96e Author: Ainar Garipov Date: Wed Jun 17 14:17:25 2026 +0300 all: fix chlog commit 1f3ded96fd375ff5210d105d88383fac514faa45 Author: Ainar Garipov Date: Tue Jun 16 21:17:52 2026 +0300 dnsforward: do not require padding when parsing ech --- CHANGELOG.md | 17 ++++++++++------- internal/dnsforward/svcbmsg.go | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0464aa8d931..6cdd6c29978 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,12 +18,6 @@ See also the [v0.107.78 GitHub milestone][ms-v0.107.78]. NOTE: Add new changes BELOW THIS COMMENT. --> -### Added - -- Improved updater logging to give users more insight into the problem with version updating ([#8410]). - -[#8410]: https://github.com/AdguardTeam/AdGuardHome/issues/8410 - ### Security - Go version has been updated to prevent the possibility of exploiting the Go vulnerabilities fixed in [1.26.4][go-1.26.4]. @@ -32,6 +26,12 @@ NOTE: Add new changes BELOW THIS COMMENT. - The size of rulelists is limited. This is necessary to prevent a user's machine from becoming overloaded if the filter source misbehaves. +### Added + +- Improved updater logging to give users more insight into the problem with version updating ([#8410]). + +[#8410]: https://github.com/AdguardTeam/AdGuardHome/issues/8410 + ### Changed - The interval of filter updates can now be set to any number of ours between 0 and 365 days in the configuration file. @@ -42,10 +42,13 @@ NOTE: Add new changes BELOW THIS COMMENT. ### Fixed +- The parsing of the `ech` parameter in DNS rewrite rules for the HTTPS record type ([#8276]). + - Blocked services check on the Custom filtering rules page does not work properly without specifying of a client. -[rfc9113]: https://datatracker.ietf.org/doc/html/rfc9113 +[#8276]: https://github.com/AdguardTeam/AdGuardHome/issues/8276 [go-1.26.4]: https://groups.google.com/g/golang-announce/c/tKs3rmcBcKw +[rfc9113]: https://datatracker.ietf.org/doc/html/rfc9113