fix: potential panic and check upstreams format before using them

This commit is contained in:
mzz2017
2023-03-21 18:27:30 +08:00
parent 948cfd3f7f
commit 4592be2320
5 changed files with 41 additions and 12 deletions

View File

@ -21,6 +21,7 @@ import (
var BadUpstreamFormatError = fmt.Errorf("bad upstream format")
type Dns struct {
log *logrus.Logger
upstream []*UpstreamResolver
upstream2IndexMu sync.Mutex
upstream2Index map[*Upstream]int
@ -34,6 +35,7 @@ type NewOption struct {
func New(log *logrus.Logger, dns *config.Dns, opt *NewOption) (s *Dns, err error) {
s = &Dns{
log: log,
upstream2Index: map[*Upstream]int{
nil: int(consts.DnsRequestOutboundIndex_AsIs),
},
@ -115,10 +117,26 @@ func New(log *logrus.Logger, dns *config.Dns, opt *NewOption) (s *Dns, err error
return s, nil
}
func (s *Dns) InitUpstreams() {
func (s *Dns) CheckUpstreamsFormat() error {
for _, upstream := range s.upstream {
upstream.GetUpstream()
_, _, _, err := ParseRawUpstream(upstream.Raw)
if err != nil {
return err
}
}
return nil
}
func (s *Dns) InitUpstreams() {
var wg sync.WaitGroup
for _, upstream := range s.upstream {
wg.Add(1)
go func(upstream *UpstreamResolver) {
upstream.GetUpstream()
wg.Done()
}(upstream)
}
wg.Wait()
}
func (s *Dns) RequestSelect(msg *dnsmessage.Message) (upstream *Upstream, err error) {

View File

@ -8,9 +8,9 @@ package dns
import (
"context"
"fmt"
"github.com/mzz2017/softwind/protocol/direct"
"github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/common/netutils"
"github.com/mzz2017/softwind/protocol/direct"
"net"
"net/url"
"strconv"
@ -18,12 +18,17 @@ import (
"time"
)
var (
FormatError = fmt.Errorf("format error")
)
type UpstreamScheme string
const (
UpstreamScheme_TCP UpstreamScheme = "tcp"
UpstreamScheme_UDP UpstreamScheme = "udp"
UpstreamScheme_TCP_UDP UpstreamScheme = "tcp+udp"
UpstreamScheme_TCP UpstreamScheme = "tcp"
UpstreamScheme_UDP UpstreamScheme = "udp"
UpstreamScheme_TCP_UDP UpstreamScheme = "tcp+udp"
upstreamScheme_TCP_UDP_Alias UpstreamScheme = "udp+tcp"
)
func (s UpstreamScheme) ContainsTcp() bool {
@ -39,13 +44,16 @@ func (s UpstreamScheme) ContainsTcp() bool {
func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, port uint16, err error) {
var __port string
switch scheme = UpstreamScheme(raw.Scheme); scheme {
case upstreamScheme_TCP_UDP_Alias:
scheme = UpstreamScheme_TCP_UDP
fallthrough
case UpstreamScheme_TCP, UpstreamScheme_UDP, UpstreamScheme_TCP_UDP:
__port = raw.Port()
if __port == "" {
__port = "53"
}
default:
return "", "", 0, fmt.Errorf("unexpected dns_upstream format")
return "", "", 0, fmt.Errorf("unexpected scheme: %v", raw.Scheme)
}
_port, err := strconv.ParseUint(__port, 10, 16)
if err != nil {
@ -66,7 +74,7 @@ type Upstream struct {
func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err error) {
scheme, hostname, port, err := ParseRawUpstream(upstream)
if err != nil {
return nil, err
return nil, fmt.Errorf("%w: %v", FormatError, err)
}
systemDns, err := netutils.SystemDns()
@ -146,7 +154,7 @@ func (u *UpstreamResolver) GetUpstream() (_ *Upstream, err error) {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
if u.upstream, err = NewUpstream(ctx, u.Raw); err != nil {
return nil, fmt.Errorf("failed to init dns upstream: %v", err)
return nil, fmt.Errorf("failed to init dns upstream: %w", err)
}
}
return u.upstream, nil