Skip to content

Commit a59c68d

Browse files
committed
XDNS finalmask: use raw client socket for all resolvers to respect sockopt and sendThrough settings.
1 parent b465036 commit a59c68d

2 files changed

Lines changed: 105 additions & 121 deletions

File tree

transport/internet/finalmask/xdns/client.go

Lines changed: 93 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
go_errors "errors"
1010
"io"
1111
"net"
12+
"net/netip"
1213
"strconv"
1314
"strings"
1415
"sync"
@@ -37,11 +38,12 @@ type packet struct {
3738
}
3839

3940
type xdnsConnClient struct {
40-
conn net.PacketConn
41-
resolverConns []net.PacketConn
41+
net.PacketConn
42+
4243
resolverAddrs []*net.UDPAddr
4344
resolverIdx uint32
4445
resolverSend []atomic.Uint32
46+
resolverNormalizedAddrs []netip.AddrPort
4547

4648
clientID []byte
4749
domains []Name
@@ -54,6 +56,11 @@ type xdnsConnClient struct {
5456
mutex sync.Mutex
5557
}
5658

59+
func normalizeAddrPort(ap netip.AddrPort) netip.AddrPort {
60+
ip6 := netip.AddrFrom16(ap.Addr().As16())
61+
return netip.AddrPortFrom(ip6, ap.Port())
62+
}
63+
5764
func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
5865
if len(c.Resolvers) == 0 {
5966
return nil, errors.New("empty resolvers")
@@ -74,8 +81,8 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
7481
servers = append(servers, parts[1])
7582
}
7683

77-
var resolverConns []net.PacketConn
7884
var resolverAddrs []*net.UDPAddr
85+
var resolverNormalizedAddrs []netip.AddrPort
7986
var resolverSend []atomic.Uint32
8087
for _, rs := range servers {
8188
h, p, err := net.SplitHostPort(rs)
@@ -90,28 +97,19 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
9097
if port == 0 {
9198
return nil, errors.New("invalid port")
9299
}
93-
var uc net.PacketConn
94-
if ip.To4() != nil {
95-
uc, err = net.ListenPacket("udp4", ":0")
96-
} else {
97-
uc, err = net.ListenPacket("udp6", ":0")
98-
}
99-
if err != nil {
100-
for _, rc := range resolverConns {
101-
rc.Close()
102-
}
103-
return nil, errors.New("failed to create resolver socket: ", err)
104-
}
105-
resolverConns = append(resolverConns, uc)
106-
resolverAddrs = append(resolverAddrs, &net.UDPAddr{IP: ip, Port: port})
100+
101+
addr := &net.UDPAddr{IP: ip, Port: port}
102+
resolverAddrs = append(resolverAddrs, addr)
103+
resolverNormalizedAddrs = append(resolverNormalizedAddrs, normalizeAddrPort(addr.AddrPort()))
107104
}
108-
resolverSend = make([]atomic.Uint32, len(resolverConns))
105+
resolverSend = make([]atomic.Uint32, len(resolverAddrs))
109106

110107
conn := &xdnsConnClient{
111-
conn: raw,
112-
resolverConns: resolverConns,
108+
PacketConn: raw,
109+
113110
resolverAddrs: resolverAddrs,
114111
resolverSend: resolverSend,
112+
resolverNormalizedAddrs: resolverNormalizedAddrs,
115113

116114
clientID: make([]byte, 8),
117115
domains: domains,
@@ -130,69 +128,69 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
130128
}
131129

132130
func (c *xdnsConnClient) recvLoop() {
133-
var wg sync.WaitGroup
134-
135-
for i, rc := range c.resolverConns {
136-
wg.Add(1)
137-
go func() {
138-
defer wg.Done()
139-
140-
var buf [finalmask.UDPSize]byte
141-
142-
for {
143-
if c.closed {
144-
break
145-
}
146-
147-
n, addr, err := rc.ReadFrom(buf[:])
148-
if err != nil {
149-
if go_errors.Is(err, net.ErrClosed) {
150-
break
151-
}
152-
continue
153-
}
154-
155-
resp, err := MessageFromWireFormat(buf[:n])
156-
if err != nil {
157-
errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err)
158-
continue
159-
}
160-
161-
payload := dnsResponsePayload(&resp, c.domains)
162-
163-
r := bytes.NewReader(payload)
164-
anyPacket := false
165-
for {
166-
p, err := nextPacket(r)
167-
if err != nil {
168-
break
169-
}
170-
anyPacket = true
171-
172-
buf := make([]byte, len(p))
173-
copy(buf, p)
174-
select {
175-
case c.readQueue <- &packet{
176-
p: buf,
177-
addr: addr,
178-
}:
179-
default:
180-
errors.LogDebug(context.Background(), addr, " mask read err queue full")
181-
}
182-
}
183-
184-
if anyPacket {
185-
c.resolverSend[i].Store(0)
186-
select {
187-
case c.pollChan <- struct{}{}:
188-
default:
189-
}
190-
}
131+
var buf [finalmask.UDPSize]byte
132+
133+
for {
134+
if c.closed {
135+
break
136+
}
137+
138+
n, addr, err := c.PacketConn.ReadFrom(buf[:])
139+
if err != nil {
140+
if go_errors.Is(err, net.ErrClosed) {
141+
break
191142
}
192-
}()
193-
}
143+
continue
144+
}
145+
146+
resp, err := MessageFromWireFormat(buf[:n])
147+
if err != nil {
148+
errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err)
149+
continue
150+
}
194151

195-
wg.Wait()
152+
payload, domain := dnsResponsePayload(&resp, c.domains)
153+
154+
rsIdx := -1
155+
readAddr := normalizeAddrPort(addr.(*net.UDPAddr).AddrPort())
156+
for j, rsAddr := range c.resolverNormalizedAddrs {
157+
if readAddr == rsAddr && domain.EqualTo(c.domains[j]) {
158+
rsIdx = j
159+
break
160+
}
161+
}
162+
163+
r := bytes.NewReader(payload)
164+
anyPacket := false
165+
for {
166+
p, err := nextPacket(r)
167+
if err != nil {
168+
break
169+
}
170+
anyPacket = true
171+
172+
buf := make([]byte, len(p))
173+
copy(buf, p)
174+
select {
175+
case c.readQueue <- &packet{
176+
p: buf,
177+
addr: addr,
178+
}:
179+
default:
180+
errors.LogDebug(context.Background(), addr, " mask read err queue full")
181+
}
182+
}
183+
184+
if anyPacket {
185+
if rsIdx >= 0 {
186+
c.resolverSend[rsIdx].Store(0)
187+
}
188+
select {
189+
case c.pollChan <- struct{}{}:
190+
default:
191+
}
192+
}
193+
}
196194

197195
errors.LogDebug(context.Background(), "xdns closed")
198196

@@ -255,10 +253,10 @@ func (c *xdnsConnClient) sendLoop() {
255253

256254
cur := c.resolverIdx
257255
curSend := c.resolverSend[c.resolverIdx].Add(1)
258-
_, _ = c.resolverConns[c.resolverIdx].WriteTo(p.p, c.resolverAddrs[c.resolverIdx])
256+
_, _ = c.PacketConn.WriteTo(p.p, c.resolverAddrs[c.resolverIdx])
259257
for {
260258
c.resolverIdx += 1
261-
c.resolverIdx %= uint32(len(c.resolverConns))
259+
c.resolverIdx %= uint32(len(c.resolverAddrs))
262260
if c.resolverIdx == cur {
263261
break
264262
}
@@ -290,7 +288,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
290288
return 0, io.ErrClosedPipe
291289
}
292290

293-
encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverConns))])
291+
encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverAddrs))])
294292
if err != nil {
295293
errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p))
296294
return 0, nil
@@ -310,35 +308,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
310308

311309
func (c *xdnsConnClient) Close() error {
312310
c.closed = true
313-
for _, rc := range c.resolverConns {
314-
rc.Close()
315-
}
316-
return c.conn.Close()
317-
}
318-
319-
func (c *xdnsConnClient) LocalAddr() net.Addr {
320-
return c.conn.LocalAddr()
321-
}
322-
323-
func (c *xdnsConnClient) SetDeadline(t time.Time) error {
324-
for _, rc := range c.resolverConns {
325-
rc.SetDeadline(t)
326-
}
327-
return c.conn.SetDeadline(t)
328-
}
329-
330-
func (c *xdnsConnClient) SetReadDeadline(t time.Time) error {
331-
for _, rc := range c.resolverConns {
332-
rc.SetReadDeadline(t)
333-
}
334-
return c.conn.SetReadDeadline(t)
335-
}
336-
337-
func (c *xdnsConnClient) SetWriteDeadline(t time.Time) error {
338-
for _, rc := range c.resolverConns {
339-
rc.SetWriteDeadline(t)
340-
}
341-
return c.conn.SetWriteDeadline(t)
311+
return c.PacketConn.Close()
342312
}
343313

344314
func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
@@ -430,37 +400,39 @@ func nextPacket(r *bytes.Reader) ([]byte, error) {
430400
return p, err
431401
}
432402

433-
func dnsResponsePayload(resp *Message, domains []Name) []byte {
403+
func dnsResponsePayload(resp *Message, domains []Name) ([]byte, Name) {
434404
if resp.Flags&0x8000 != 0x8000 {
435-
return nil
405+
return nil, nil
436406
}
437407
if resp.Flags&0x000f != RcodeNoError {
438-
return nil
408+
return nil, nil
439409
}
440410

441411
if len(resp.Answer) != 1 {
442-
return nil
412+
return nil, nil
443413
}
444414
answer := resp.Answer[0]
445415

446416
var ok bool
417+
var respDomain Name = nil
447418
for _, domain := range domains {
448419
_, ok = answer.Name.TrimSuffix(domain)
449420
if ok {
421+
respDomain = domain
450422
break
451423
}
452424
}
453425
if !ok {
454-
return nil
426+
return nil, nil
455427
}
456428

457429
if answer.Type != RRTypeTXT {
458-
return nil
430+
return nil, nil
459431
}
460432
payload, err := DecodeRDataTXT(answer.Data)
461433
if err != nil {
462-
return nil
434+
return nil, nil
463435
}
464436

465-
return payload
437+
return payload, respDomain
466438
}

transport/internet/finalmask/xdns/dns.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,18 @@ func (name Name) TrimSuffix(suffix Name) (Name, bool) {
150150
return fore, true
151151
}
152152

153+
func (name Name) EqualTo(another Name) bool {
154+
if len(name) != len(another) {
155+
return false
156+
}
157+
for i, label := range name {
158+
if !bytes.Equal(label, another[i]) {
159+
return false
160+
}
161+
}
162+
return true
163+
}
164+
153165
// Message represents a DNS message.
154166
//
155167
// https://tools.ietf.org/html/rfc1035#section-4.1

0 commit comments

Comments
 (0)