Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions client/internal/dns/resutil/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,35 @@ func FormatAnswers(answers []dns.RR) string {
}
return "[" + strings.Join(parts, ", ") + "]"
}

// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
// RFC 6891 a responder must not include an OPT RR toward a client that did not
// advertise EDNS0.
func StripOPT(msg *dns.Msg) {
if len(msg.Extra) == 0 {
return
}
out := msg.Extra[:0]
for _, rr := range msg.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
msg.Extra = out
}

// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
// the message, if present.
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
opt := msg.IsEdns0()
if opt == nil {
return nil, false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok {
return ede, true
}
}
return nil, false
}
39 changes: 39 additions & 0 deletions client/internal/dns/resutil/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,42 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {

assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
}

func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
StripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}

func TestExtractEDE(t *testing.T) {
t.Run("no edns", func(t *testing.T) {
_, ok := ExtractEDE(&dns.Msg{})
assert.False(t, ok, "message without OPT has no EDE")
})

t.Run("edns without ede", func(t *testing.T) {
rm := &dns.Msg{}
rm.SetEdns0(4096, false)
_, ok := ExtractEDE(rm)
assert.False(t, ok, "OPT without EDE option returns false")
})

t.Run("with ede", func(t *testing.T) {
rm := &dns.Msg{}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
rm.Extra = append(rm.Extra, opt)

ede, ok := ExtractEDE(rm)
assert.True(t, ok, "EDE option should be found")
assert.Equal(t, uint16(49152), ede.InfoCode)
assert.Equal(t, "upstream timeout", ede.ExtraText)
})
}
20 changes: 2 additions & 18 deletions client/internal/dns/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns {
stripOPT(rm)
resutil.StripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
Expand All @@ -462,7 +462,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
}

if !hadEdns {
stripOPT(rm)
resutil.StripOPT(rm)
}

u.markUpstreamOk(upstream)
Expand Down Expand Up @@ -520,22 +520,6 @@ func upstreamUDPSize() uint16 {
return dns.MinMsgSize
}

// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}

func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
return &upstreamFailure{upstream: upstream, reason: err.Error()}
Expand Down
13 changes: 0 additions & 13 deletions client/internal/dns/upstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -913,19 +913,6 @@ func TestEDEName(t *testing.T) {
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
}

func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}

func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
Expand Down
46 changes: 45 additions & 1 deletion client/internal/dnsfwd/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ import (
const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second

// EDE info codes the forwarder emits on upstream failures so the querying
// client can see the reason without inspecting this peer's logs. They live in
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
// real upstream EDE here, so these cannot collide with a genuine code.
const (
edeNetbirdUpstreamTimeout uint16 = 49152
edeNetbirdUpstreamFailure uint16 = 49153
)

type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
Expand Down Expand Up @@ -220,7 +229,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q

result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
return
}

Expand Down Expand Up @@ -333,6 +342,7 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg,
domain string,
result resutil.LookupResult,
reqHasEdns bool,
startTime time.Time,
) {
qType := question.Qtype
Expand Down Expand Up @@ -374,6 +384,10 @@ func (f *DNSForwarder) handleDNSError(
logger.Warnf(errResolveFailed, domain, result.Err)
}

if reqHasEdns {
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
}

f.writeResponse(logger, w, resp, domain, startTime)
}

Expand Down Expand Up @@ -414,3 +428,33 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar

return selectedResId, matches
}

// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
func edeCodeFor(dnsErr *net.DNSError) uint16 {
if dnsErr != nil && dnsErr.IsTimeout {
return edeNetbirdUpstreamTimeout
}
return edeNetbirdUpstreamFailure
}

// edeText builds the EDE extra-text describing the class of upstream failure.
// It deliberately omits the upstream server address, which may be an internal
// resolver and is exposed to any client permitted to use the route; the full
// detail stays in the forwarder's local log.
func edeText(dnsErr *net.DNSError) string {
if dnsErr != nil && dnsErr.IsTimeout {
return "netbird forwarder: upstream timeout"
}
return "netbird forwarder: upstream failure"
}

// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
// creating the OPT pseudo-record if the response does not already carry one.
func attachEDE(resp *dns.Msg, code uint16, text string) {
opt := resp.IsEdns0()
if opt == nil {
resp.SetEdns0(dns.DefaultMsgSize, false)
opt = resp.IsEdns0()
}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
}
80 changes: 80 additions & 0 deletions client/internal/dnsfwd/forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/require"

firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
Expand Down Expand Up @@ -617,6 +618,85 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
}
}

func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
tests := []struct {
name string
lookupErr error
reqEdns bool
wantEDE bool
wantCode uint16
wantTextHas string
}{
{
name: "timeout with edns0",
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
reqEdns: true,
wantEDE: true,
wantCode: edeNetbirdUpstreamTimeout,
wantTextHas: "netbird forwarder: upstream timeout",
},
{
name: "server failure with edns0",
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
reqEdns: true,
wantEDE: true,
wantCode: edeNetbirdUpstreamFailure,
wantTextHas: "netbird forwarder: upstream failure",
},
{
name: "no edns0 in request omits ede",
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
reqEdns: false,
wantEDE: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver

d, err := domain.FromString("example.com")
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})

mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr(nil), tt.lookupErr).Once()

query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
if tt.reqEdns {
query.SetEdns0(dns.DefaultMsgSize, false)
}

var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}

forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
mockResolver.AssertExpectations(t)

require.NotNil(t, writtenResp, "expected a response")
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")

ede, ok := resutil.ExtractEDE(writtenResp)
if !tt.wantEDE {
assert.False(t, ok, "response must not carry EDE")
return
}
require.True(t, ok, "response must carry EDE")
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
})
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}
Expand Down
15 changes: 15 additions & 0 deletions client/internal/routemanager/dnsinterceptor/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true
}

// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
// describing why a lookup failed. The OPT is stripped from the reply when
// the original client did not request EDNS0.
hadEdns := r.IsEdns0() != nil
if !hadEdns {
r.SetEdns0(dns.DefaultMsgSize, false)
}

upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
Expand All @@ -260,6 +268,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}

if ede, ok := resutil.ExtractEDE(reply); ok {
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
}
if !hadEdns {
resutil.StripOPT(reply)
}

resutil.SetMeta(w, "peer", peerKey)

reply.Id = r.Id
Expand Down
Loading