99 go_errors "errors"
1010 "io"
1111 "net"
12+ "net/netip"
1213 "strconv"
1314 "strings"
1415 "sync"
@@ -37,11 +38,12 @@ type packet struct {
3738}
3839
3940type xdnsConnClient struct {
40- conn net.PacketConn
41- resolverConns []net. PacketConn
41+ net.PacketConn
42+
4243 resolverAddrs []* net.UDPAddr
4344 resolverIdx uint32
4445 resolverSend []atomic.Uint32
46+ resolverNormalizedAddrs []netip.AddrPort
4547
4648 clientID []byte
4749 domains []Name
@@ -54,6 +56,11 @@ type xdnsConnClient struct {
5456 mutex sync.Mutex
5557}
5658
59+ func normalizeAddrPort (ap netip.AddrPort ) netip.AddrPort {
60+ ip6 := netip .AddrFrom16 (ap .Addr ().As16 ())
61+ return netip .AddrPortFrom (ip6 , ap .Port ())
62+ }
63+
5764func NewConnClient (c * Config , raw net.PacketConn ) (net.PacketConn , error ) {
5865 if len (c .Resolvers ) == 0 {
5966 return nil , errors .New ("empty resolvers" )
@@ -74,8 +81,8 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
7481 servers = append (servers , parts [1 ])
7582 }
7683
77- var resolverConns []net.PacketConn
7884 var resolverAddrs []* net.UDPAddr
85+ var resolverNormalizedAddrs []netip.AddrPort
7986 var resolverSend []atomic.Uint32
8087 for _ , rs := range servers {
8188 h , p , err := net .SplitHostPort (rs )
@@ -90,28 +97,19 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
9097 if port == 0 {
9198 return nil , errors .New ("invalid port" )
9299 }
93- var uc net.PacketConn
94- if ip .To4 () != nil {
95- uc , err = net .ListenPacket ("udp4" , ":0" )
96- } else {
97- uc , err = net .ListenPacket ("udp6" , ":0" )
98- }
99- if err != nil {
100- for _ , rc := range resolverConns {
101- rc .Close ()
102- }
103- return nil , errors .New ("failed to create resolver socket: " , err )
104- }
105- resolverConns = append (resolverConns , uc )
106- resolverAddrs = append (resolverAddrs , & net.UDPAddr {IP : ip , Port : port })
100+
101+ addr := & net.UDPAddr {IP : ip , Port : port }
102+ resolverAddrs = append (resolverAddrs , addr )
103+ resolverNormalizedAddrs = append (resolverNormalizedAddrs , normalizeAddrPort (addr .AddrPort ()))
107104 }
108- resolverSend = make ([]atomic.Uint32 , len (resolverConns ))
105+ resolverSend = make ([]atomic.Uint32 , len (resolverAddrs ))
109106
110107 conn := & xdnsConnClient {
111- conn : raw ,
112- resolverConns : resolverConns ,
108+ PacketConn : raw ,
109+
113110 resolverAddrs : resolverAddrs ,
114111 resolverSend : resolverSend ,
112+ resolverNormalizedAddrs : resolverNormalizedAddrs ,
115113
116114 clientID : make ([]byte , 8 ),
117115 domains : domains ,
@@ -130,69 +128,69 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
130128}
131129
132130func (c * xdnsConnClient ) recvLoop () {
133- var wg sync.WaitGroup
134-
135- for i , rc := range c .resolverConns {
136- wg .Add (1 )
137- go func () {
138- defer wg .Done ()
139-
140- var buf [finalmask .UDPSize ]byte
141-
142- for {
143- if c .closed {
144- break
145- }
146-
147- n , addr , err := rc .ReadFrom (buf [:])
148- if err != nil {
149- if go_errors .Is (err , net .ErrClosed ) {
150- break
151- }
152- continue
153- }
154-
155- resp , err := MessageFromWireFormat (buf [:n ])
156- if err != nil {
157- errors .LogDebug (context .Background (), addr , " xdns from wireformat err " , err )
158- continue
159- }
160-
161- payload := dnsResponsePayload (& resp , c .domains )
162-
163- r := bytes .NewReader (payload )
164- anyPacket := false
165- for {
166- p , err := nextPacket (r )
167- if err != nil {
168- break
169- }
170- anyPacket = true
171-
172- buf := make ([]byte , len (p ))
173- copy (buf , p )
174- select {
175- case c .readQueue <- & packet {
176- p : buf ,
177- addr : addr ,
178- }:
179- default :
180- errors .LogDebug (context .Background (), addr , " mask read err queue full" )
181- }
182- }
183-
184- if anyPacket {
185- c .resolverSend [i ].Store (0 )
186- select {
187- case c .pollChan <- struct {}{}:
188- default :
189- }
190- }
131+ var buf [finalmask .UDPSize ]byte
132+
133+ for {
134+ if c .closed {
135+ break
136+ }
137+
138+ n , addr , err := c .PacketConn .ReadFrom (buf [:])
139+ if err != nil {
140+ if go_errors .Is (err , net .ErrClosed ) {
141+ break
191142 }
192- }()
193- }
143+ continue
144+ }
145+
146+ resp , err := MessageFromWireFormat (buf [:n ])
147+ if err != nil {
148+ errors .LogDebug (context .Background (), addr , " xdns from wireformat err " , err )
149+ continue
150+ }
194151
195- wg .Wait ()
152+ payload , domain := dnsResponsePayload (& resp , c .domains )
153+
154+ rsIdx := - 1
155+ readAddr := normalizeAddrPort (addr .(* net.UDPAddr ).AddrPort ())
156+ for j , rsAddr := range c .resolverNormalizedAddrs {
157+ if readAddr == rsAddr && domain .EqualTo (c .domains [j ]) {
158+ rsIdx = j
159+ break
160+ }
161+ }
162+
163+ r := bytes .NewReader (payload )
164+ anyPacket := false
165+ for {
166+ p , err := nextPacket (r )
167+ if err != nil {
168+ break
169+ }
170+ anyPacket = true
171+
172+ buf := make ([]byte , len (p ))
173+ copy (buf , p )
174+ select {
175+ case c .readQueue <- & packet {
176+ p : buf ,
177+ addr : addr ,
178+ }:
179+ default :
180+ errors .LogDebug (context .Background (), addr , " mask read err queue full" )
181+ }
182+ }
183+
184+ if anyPacket {
185+ if rsIdx >= 0 {
186+ c .resolverSend [rsIdx ].Store (0 )
187+ }
188+ select {
189+ case c .pollChan <- struct {}{}:
190+ default :
191+ }
192+ }
193+ }
196194
197195 errors .LogDebug (context .Background (), "xdns closed" )
198196
@@ -255,10 +253,10 @@ func (c *xdnsConnClient) sendLoop() {
255253
256254 cur := c .resolverIdx
257255 curSend := c .resolverSend [c .resolverIdx ].Add (1 )
258- _ , _ = c .resolverConns [ c . resolverIdx ] .WriteTo (p .p , c .resolverAddrs [c .resolverIdx ])
256+ _ , _ = c .PacketConn .WriteTo (p .p , c .resolverAddrs [c .resolverIdx ])
259257 for {
260258 c .resolverIdx += 1
261- c .resolverIdx %= uint32 (len (c .resolverConns ))
259+ c .resolverIdx %= uint32 (len (c .resolverAddrs ))
262260 if c .resolverIdx == cur {
263261 break
264262 }
@@ -290,7 +288,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
290288 return 0 , io .ErrClosedPipe
291289 }
292290
293- encoded , err := encode (p , c .clientID , c .domains [c .resolverIdx % uint32 (len (c .resolverConns ))])
291+ encoded , err := encode (p , c .clientID , c .domains [c .resolverIdx % uint32 (len (c .resolverAddrs ))])
294292 if err != nil {
295293 errors .LogDebug (context .Background (), addr , " xdns wireformat err " , err , " " , len (p ))
296294 return 0 , nil
@@ -310,35 +308,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
310308
311309func (c * xdnsConnClient ) Close () error {
312310 c .closed = true
313- for _ , rc := range c .resolverConns {
314- rc .Close ()
315- }
316- return c .conn .Close ()
317- }
318-
319- func (c * xdnsConnClient ) LocalAddr () net.Addr {
320- return c .conn .LocalAddr ()
321- }
322-
323- func (c * xdnsConnClient ) SetDeadline (t time.Time ) error {
324- for _ , rc := range c .resolverConns {
325- rc .SetDeadline (t )
326- }
327- return c .conn .SetDeadline (t )
328- }
329-
330- func (c * xdnsConnClient ) SetReadDeadline (t time.Time ) error {
331- for _ , rc := range c .resolverConns {
332- rc .SetReadDeadline (t )
333- }
334- return c .conn .SetReadDeadline (t )
335- }
336-
337- func (c * xdnsConnClient ) SetWriteDeadline (t time.Time ) error {
338- for _ , rc := range c .resolverConns {
339- rc .SetWriteDeadline (t )
340- }
341- return c .conn .SetWriteDeadline (t )
311+ return c .PacketConn .Close ()
342312}
343313
344314func encode (p []byte , clientID []byte , domain Name ) ([]byte , error ) {
@@ -430,37 +400,39 @@ func nextPacket(r *bytes.Reader) ([]byte, error) {
430400 return p , err
431401}
432402
433- func dnsResponsePayload (resp * Message , domains []Name ) []byte {
403+ func dnsResponsePayload (resp * Message , domains []Name ) ( []byte , Name ) {
434404 if resp .Flags & 0x8000 != 0x8000 {
435- return nil
405+ return nil , nil
436406 }
437407 if resp .Flags & 0x000f != RcodeNoError {
438- return nil
408+ return nil , nil
439409 }
440410
441411 if len (resp .Answer ) != 1 {
442- return nil
412+ return nil , nil
443413 }
444414 answer := resp .Answer [0 ]
445415
446416 var ok bool
417+ var respDomain Name = nil
447418 for _ , domain := range domains {
448419 _ , ok = answer .Name .TrimSuffix (domain )
449420 if ok {
421+ respDomain = domain
450422 break
451423 }
452424 }
453425 if ! ok {
454- return nil
426+ return nil , nil
455427 }
456428
457429 if answer .Type != RRTypeTXT {
458- return nil
430+ return nil , nil
459431 }
460432 payload , err := DecodeRDataTXT (answer .Data )
461433 if err != nil {
462- return nil
434+ return nil , nil
463435 }
464436
465- return payload
437+ return payload , respDomain
466438}
0 commit comments