feat: support iptables tproxy (#80)

This commit is contained in:
mzz 2023-06-04 11:38:05 +08:00 committed by GitHub
parent cbcbec9a1a
commit ee09ae17e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 313 additions and 229 deletions

View File

@ -4,8 +4,6 @@ on:
push: push:
branches: branches:
- main - main
- fix*
- feat*
paths: paths:
- "**/*.go" - "**/*.go"
- "**/*.c" - "**/*.c"

View File

@ -23,6 +23,7 @@ else
STRIP_FLAG := -strip=$(STRIP_PATH) STRIP_FLAG := -strip=$(STRIP_PATH)
endif endif
# Do NOT remove the line below. This line is for CI.
#export GOMODCACHE=$(PWD)/go-mod #export GOMODCACHE=$(PWD)/go-mod
# Get version from .git. # Get version from .git.
@ -41,7 +42,7 @@ dae: export GOOS=linux
dae: ebpf dae: ebpf
go build -o $(OUTPUT) -trimpath -ldflags "-s -w -X github.com/daeuniverse/dae/cmd.Version=$(VERSION) -X github.com/daeuniverse/dae/common/consts.MaxMatchSetLen_=$(MAX_MATCH_SET_LEN)" . go build -o $(OUTPUT) -trimpath -ldflags "-s -w -X github.com/daeuniverse/dae/cmd.Version=$(VERSION) -X github.com/daeuniverse/dae/common/consts.MaxMatchSetLen_=$(MAX_MATCH_SET_LEN)" .
clean-ebpf: clean-ebpf:
@rm -f control/bpf_bpf*.go && \ @rm -f control/bpf_bpf*.go && \
rm -f control/bpf_bpf*.o rm -f control/bpf_bpf*.o
fmt: fmt:

View File

@ -1,9 +1,12 @@
package cmd package cmd
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand" "github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/protocol/direct"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -247,6 +250,20 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
if !conf.Global.DisableWaitingNetwork && len(conf.Subscription) > 0 { if !conf.Global.DisableWaitingNetwork && len(conf.Subscription) > 0 {
epo := 5 * time.Second epo := 5 * time.Second
client := http.Client{ client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialer{Dialer: direct.SymmetricDirect}
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae), addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: conn,
LAddr: nil,
RAddr: nil,
}, nil
},
},
Timeout: epo, Timeout: epo,
} }
log.Infoln("Waiting for network...") log.Infoln("Waiting for network...")
@ -274,8 +291,25 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
if len(conf.Subscription) > 0 { if len(conf.Subscription) > 0 {
log.Infoln("Fetching subscriptions...") log.Infoln("Fetching subscriptions...")
} }
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialer{Dialer: direct.SymmetricDirect}
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae), addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: conn,
LAddr: nil,
RAddr: nil,
}, nil
},
},
Timeout: 30 * time.Second,
}
for _, sub := range conf.Subscription { for _, sub := range conf.Subscription {
tag, nodes, err := subscription.ResolveSubscription(log, filepath.Dir(cfgFile), string(sub)) tag, nodes, err := subscription.ResolveSubscription(log, &client, filepath.Dir(cfgFile), string(sub))
if err != nil { if err != nil {
log.Warnf(`failed to resolve subscription "%v": %v`, sub, err) log.Warnf(`failed to resolve subscription "%v": %v`, sub, err)
resolvingfailed = true resolvingfailed = true

View File

@ -146,6 +146,7 @@ var (
const ( const (
TproxyMark uint32 = 0x8000000 TproxyMark uint32 = 0x8000000
Recognize uint16 = 0x2017
LoopbackIfIndex = 1 LoopbackIfIndex = 1
) )

View File

@ -19,7 +19,6 @@ import (
"github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand" "github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/pool" "github.com/mzz2017/softwind/pool"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
) )
@ -91,8 +90,8 @@ func SystemDns() (dns netip.AddrPort, err error) {
return systemDns, nil return systemDns, nil
} }
func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, tcp bool) (addrs []netip.Addr, err error) { func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (addrs []netip.Addr, err error) {
resources, err := resolve(ctx, d, dns, host, typ, tcp) resources, err := resolve(ctx, d, dns, host, typ, network)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -118,16 +117,14 @@ func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, ho
return addrs, nil return addrs, nil
} }
func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, tcp bool) (records []string, err error) { func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, network string) (records []string, err error) {
typ := dnsmessage.TypeNS typ := dnsmessage.TypeNS
resources, err := resolve(ctx, d, dns, host, typ, tcp) resources, err := resolve(ctx, d, dns, host, typ, network)
if err != nil { if err != nil {
return nil, err return nil, err
} }
logrus.Println(host, len(resources))
for _, ans := range resources { for _, ans := range resources {
if ans.Header.Type != typ { if ans.Header.Type != typ {
logrus.Println(host, ans.Header.Type)
continue continue
} }
ns, ok := ans.Body.(*dnsmessage.NSResource) ns, ok := ans.Body.(*dnsmessage.NSResource)
@ -139,7 +136,7 @@ func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host
return records, nil return records, nil
} }
func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, tcp bool) (ans []dnsmessage.Resource, err error) { func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (ans []dnsmessage.Resource, err error) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
fqdn := host fqdn := host
@ -202,7 +199,11 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
if err != nil { if err != nil {
return nil, err return nil, err
} }
if tcp { magicNetwork, err := netproxy.ParseMagicNetwork(network)
if err != nil {
return nil, err
}
if magicNetwork.Network == "tcp" {
// Put DNS request length // Put DNS request length
buf := pool.Get(2 + len(b)) buf := pool.Get(2 + len(b))
defer pool.Put(buf) defer pool.Put(buf)
@ -213,12 +214,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
// Dial and write. // Dial and write.
cd := &netproxy.ContextDialer{Dialer: d} cd := &netproxy.ContextDialer{Dialer: d}
var c netproxy.Conn c, err := cd.DialContext(ctx, network, dns.String())
if tcp {
c, err = cd.DialTcpContext(ctx, dns.String())
} else {
c, err = cd.DialUdpContext(ctx, dns.String())
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -228,7 +224,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
return nil, err return nil, err
} }
ch := make(chan error, 2) ch := make(chan error, 2)
if !tcp { if magicNetwork.Network == "udp" {
go func() { go func() {
// Resend every 3 seconds for UDP. // Resend every 3 seconds for UDP.
for { for {
@ -249,7 +245,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
go func() { go func() {
buf := pool.Get(512) buf := pool.Get(512)
defer pool.Put(buf) defer pool.Put(buf)
if tcp { if magicNetwork.Network == "tcp" {
// Read DNS response length // Read DNS response length
_, err := io.ReadFull(c, buf[:2]) _, err := io.ReadFull(c, buf[:2])
if err != nil { if err != nil {

View File

@ -22,7 +22,7 @@ type Ip46 struct {
Ip6 netip.Addr Ip6 netip.Addr
} }
func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, tcp bool, race bool) (ipv46 *Ip46, err error) { func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, network string, race bool) (ipv46 *Ip46, err error) {
var log *logrus.Logger var log *logrus.Logger
if _log := ctx.Value("logger"); _log != nil { if _log := ctx.Value("logger"); _log != nil {
log = _log.(*logrus.Logger) log = _log.(*logrus.Logger)
@ -49,7 +49,7 @@ func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort
} }
}() }()
var e error var e error
addrs4, e = ResolveNetip(ctx4, dialer, dns, host, dnsmessage.TypeA, tcp) addrs4, e = ResolveNetip(ctx4, dialer, dns, host, dnsmessage.TypeA, network)
if err != nil && !errors.Is(e, context.Canceled) { if err != nil && !errors.Is(e, context.Canceled) {
err4 = e err4 = e
return return
@ -67,7 +67,7 @@ func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort
} }
}() }()
var e error var e error
addrs6, e = ResolveNetip(ctx6, dialer, dns, host, dnsmessage.TypeAAAA, tcp) addrs6, e = ResolveNetip(ctx6, dialer, dns, host, dnsmessage.TypeAAAA, network)
if err != nil && !errors.Is(e, context.Canceled) { if err != nil && !errors.Is(e, context.Canceled) {
err6 = e err6 = e
return return

View File

@ -137,7 +137,7 @@ func ResolveFile(u *url.URL, configDir string) (b []byte, err error) {
return bytes.TrimSpace(b), err return bytes.TrimSpace(b), err
} }
func ResolveSubscription(log *logrus.Logger, configDir string, subscription string) (tag string, nodes []string, err error) { func ResolveSubscription(log *logrus.Logger, client *http.Client, configDir string, subscription string) (tag string, nodes []string, err error) {
/// Get tag. /// Get tag.
tag, subscription = common.GetTagFromLinkLikePlaintext(subscription) tag, subscription = common.GetTagFromLinkLikePlaintext(subscription)
@ -160,7 +160,7 @@ func ResolveSubscription(log *logrus.Logger, configDir string, subscription stri
goto resolve goto resolve
default: default:
} }
resp, err = http.Get(subscription) resp, err = client.Get(subscription)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }

View File

@ -12,6 +12,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"github.com/mzz2017/softwind/netproxy"
"net/netip" "net/netip"
"net/url" "net/url"
"path/filepath" "path/filepath"
@ -221,25 +222,25 @@ func FuzzyDecode(to interface{}, val string) bool {
v := reflect.Indirect(reflect.ValueOf(to)) v := reflect.Indirect(reflect.ValueOf(to))
switch v.Kind() { switch v.Kind() {
case reflect.Int: case reflect.Int:
i, err := strconv.ParseInt(val, 10, strconv.IntSize) i, err := strconv.ParseInt(val, 0, strconv.IntSize)
if err != nil { if err != nil {
return false return false
} }
v.SetInt(i) v.SetInt(i)
case reflect.Int8: case reflect.Int8:
i, err := strconv.ParseInt(val, 10, 8) i, err := strconv.ParseInt(val, 0, 8)
if err != nil { if err != nil {
return false return false
} }
v.SetInt(i) v.SetInt(i)
case reflect.Int16: case reflect.Int16:
i, err := strconv.ParseInt(val, 10, 16) i, err := strconv.ParseInt(val, 0, 16)
if err != nil { if err != nil {
return false return false
} }
v.SetInt(i) v.SetInt(i)
case reflect.Int32: case reflect.Int32:
i, err := strconv.ParseInt(val, 10, 32) i, err := strconv.ParseInt(val, 0, 32)
if err != nil { if err != nil {
return false return false
} }
@ -253,38 +254,38 @@ func FuzzyDecode(to interface{}, val string) bool {
} }
v.Set(reflect.ValueOf(duration)) v.Set(reflect.ValueOf(duration))
default: default:
i, err := strconv.ParseInt(val, 10, 64) i, err := strconv.ParseInt(val, 0, 64)
if err != nil { if err != nil {
return false return false
} }
v.SetInt(i) v.SetInt(i)
} }
case reflect.Uint: case reflect.Uint:
i, err := strconv.ParseUint(val, 10, strconv.IntSize) i, err := strconv.ParseUint(val, 0, strconv.IntSize)
if err != nil { if err != nil {
return false return false
} }
v.SetUint(i) v.SetUint(i)
case reflect.Uint8: case reflect.Uint8:
i, err := strconv.ParseUint(val, 10, 8) i, err := strconv.ParseUint(val, 0, 8)
if err != nil { if err != nil {
return false return false
} }
v.SetUint(i) v.SetUint(i)
case reflect.Uint16: case reflect.Uint16:
i, err := strconv.ParseUint(val, 10, 16) i, err := strconv.ParseUint(val, 0, 16)
if err != nil { if err != nil {
return false return false
} }
v.SetUint(i) v.SetUint(i)
case reflect.Uint32: case reflect.Uint32:
i, err := strconv.ParseUint(val, 10, 32) i, err := strconv.ParseUint(val, 0, 32)
if err != nil { if err != nil {
return false return false
} }
v.SetUint(i) v.SetUint(i)
case reflect.Uint64: case reflect.Uint64:
i, err := strconv.ParseUint(val, 10, 64) i, err := strconv.ParseUint(val, 0, 64)
if err != nil { if err != nil {
return false return false
} }
@ -458,6 +459,17 @@ nextLink:
return Deduplicate(defaultIfs), nil return Deduplicate(defaultIfs), nil
} }
func MagicNetwork(network string, mark uint32) string {
if mark == 0 {
return network
} else {
return netproxy.MagicNetwork{
Network: network,
Mark: mark,
}.Encode()
}
}
func IsValidHttpMethod(method string) bool { func IsValidHttpMethod(method string) bool {
switch method { switch method {
case "GET", "POST", "PUT", "PATCH", "DELETE", "COPY", "HEAD", "OPTIONS", "LINK", "UNLINK", "PURGE", "LOCK", "UNLOCK", "PROPFIND", "CONNECT", "TRACE": case "GET", "POST", "PUT", "PATCH", "DELETE", "COPY", "HEAD", "OPTIONS", "LINK", "UNLINK", "PURGE", "LOCK", "UNLOCK", "PROPFIND", "CONNECT", "TRACE":

View File

@ -32,9 +32,10 @@ type Dns struct {
} }
type NewOption struct { type NewOption struct {
Logger *logrus.Logger Logger *logrus.Logger
LocationFinder *assets.LocationFinder LocationFinder *assets.LocationFinder
UpstreamReadyCallback func(dnsUpstream *Upstream) (err error) UpstreamReadyCallback func(dnsUpstream *Upstream) (err error)
UpstreamResolverNetwork string
} }
func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) { func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
@ -62,7 +63,8 @@ func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
return nil, fmt.Errorf("%w: %v", BadUpstreamFormatError, err) return nil, fmt.Errorf("%w: %v", BadUpstreamFormatError, err)
} }
r := &UpstreamResolver{ r := &UpstreamResolver{
Raw: u, Raw: u,
Network: opt.UpstreamResolverNetwork,
FinishInitCallback: func(i int) func(raw *url.URL, upstream *Upstream) (err error) { FinishInitCallback: func(i int) func(raw *url.URL, upstream *Upstream) (err error) {
return func(raw *url.URL, upstream *Upstream) (err error) { return func(raw *url.URL, upstream *Upstream) (err error) {
if opt != nil && opt.UpstreamReadyCallback != nil { if opt != nil && opt.UpstreamReadyCallback != nil {
@ -77,6 +79,9 @@ func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
return nil return nil
} }
}(i), }(i),
mu: sync.Mutex{},
upstream: nil,
init: false,
} }
upstreamName2Id[tag] = uint8(len(s.upstream)) upstreamName2Id[tag] = uint8(len(s.upstream))
s.upstream = append(s.upstream, r) s.upstream = append(s.upstream, r)

View File

@ -72,7 +72,7 @@ type Upstream struct {
*netutils.Ip46 *netutils.Ip46
} }
func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err error) { func NewUpstream(ctx context.Context, upstream *url.URL, resolverNetwork string) (up *Upstream, err error) {
scheme, hostname, port, err := ParseRawUpstream(upstream) scheme, hostname, port, err := ParseRawUpstream(upstream)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %v", FormatError, err) return nil, fmt.Errorf("%w: %v", FormatError, err)
@ -88,7 +88,7 @@ func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err erro
} }
}() }()
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, false, false) ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, resolverNetwork, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err) return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
} }
@ -131,7 +131,8 @@ func (u *Upstream) String() string {
} }
type UpstreamResolver struct { type UpstreamResolver struct {
Raw *url.URL Raw *url.URL
Network string
// FinishInitCallback may be invoked again if err is not nil // FinishInitCallback may be invoked again if err is not nil
FinishInitCallback func(raw *url.URL, upstream *Upstream) (err error) FinishInitCallback func(raw *url.URL, upstream *Upstream) (err error)
mu sync.Mutex mu sync.Mutex
@ -154,7 +155,7 @@ func (u *UpstreamResolver) GetUpstream() (_ *Upstream, err error) {
}() }()
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel() defer cancel()
if u.upstream, err = NewUpstream(ctx, u.Raw); err != nil { if u.upstream, err = NewUpstream(ctx, u.Raw, u.Network); err != nil {
return nil, fmt.Errorf("failed to init dns upstream: %w", err) return nil, fmt.Errorf("failed to init dns upstream: %w", err)
} }
} }

View File

@ -9,6 +9,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/daeuniverse/dae/common"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@ -121,7 +122,7 @@ type TcpCheckOption struct {
Method string Method string
} }
func ParseTcpCheckOption(ctx context.Context, rawURL []string, method string) (opt *TcpCheckOption, err error) { func ParseTcpCheckOption(ctx context.Context, rawURL []string, method string, resolverNetwork string) (opt *TcpCheckOption, err error) {
if method == "" { if method == "" {
method = http.MethodGet method = http.MethodGet
} }
@ -146,7 +147,7 @@ func ParseTcpCheckOption(ctx context.Context, rawURL []string, method string) (o
if len(rawURL) > 1 { if len(rawURL) > 1 {
ip46 = parseIp46FromList(rawURL[1:]) ip46 = parseIp46FromList(rawURL[1:])
} else { } else {
ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false, false) ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), resolverNetwork, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -164,7 +165,7 @@ type CheckDnsOption struct {
*netutils.Ip46 *netutils.Ip46
} }
func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string) (opt *CheckDnsOption, err error) { func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string, resolverNetwork string) (opt *CheckDnsOption, err error) {
systemDns, err := netutils.SystemDns() systemDns, err := netutils.SystemDns()
if err != nil { if err != nil {
return nil, err return nil, err
@ -191,7 +192,7 @@ func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string) (opt *CheckD
if len(dnsHostPort) > 1 { if len(dnsHostPort) > 1 {
ip46 = parseIp46FromList(dnsHostPort[1:]) ip46 = parseIp46FromList(dnsHostPort[1:])
} else { } else {
ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, false, false) ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, resolverNetwork, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -204,11 +205,12 @@ func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string) (opt *CheckD
} }
type TcpCheckOptionRaw struct { type TcpCheckOptionRaw struct {
opt *TcpCheckOption opt *TcpCheckOption
mu sync.Mutex mu sync.Mutex
Log *logrus.Logger Log *logrus.Logger
Raw []string Raw []string
Method string ResolverNetwork string
Method string
} }
func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) { func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) {
@ -218,7 +220,7 @@ func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel() defer cancel()
ctx = context.WithValue(ctx, "logger", c.Log) ctx = context.WithValue(ctx, "logger", c.Log)
tcpCheckOption, err := ParseTcpCheckOption(ctx, c.Raw, c.Method) tcpCheckOption, err := ParseTcpCheckOption(ctx, c.Raw, c.Method, c.ResolverNetwork)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err) return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
} }
@ -228,9 +230,10 @@ func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) {
} }
type CheckDnsOptionRaw struct { type CheckDnsOptionRaw struct {
opt *CheckDnsOption opt *CheckDnsOption
mu sync.Mutex mu sync.Mutex
Raw []string Raw []string
ResolverNetwork string
} }
func (c *CheckDnsOptionRaw) Option() (opt *CheckDnsOption, err error) { func (c *CheckDnsOptionRaw) Option() (opt *CheckDnsOption, err error) {
@ -239,7 +242,7 @@ func (c *CheckDnsOptionRaw) Option() (opt *CheckDnsOption, err error) {
if c.opt == nil { if c.opt == nil {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel() defer cancel()
udpCheckOption, err := ParseCheckDnsOption(ctx, c.Raw) udpCheckOption, err := ParseCheckDnsOption(ctx, c.Raw, c.ResolverNetwork)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err) return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
} }
@ -266,6 +269,10 @@ func (d *Dialer) ActivateCheck() {
func (d *Dialer) aliveBackground() { func (d *Dialer) aliveBackground() {
timeout := 10 * time.Second timeout := 10 * time.Second
cycle := d.CheckInterval cycle := d.CheckInterval
var tcpSomark uint32
if network, err := netproxy.ParseMagicNetwork(d.TcpCheckOptionRaw.ResolverNetwork); err == nil {
tcpSomark = network.Mark
}
tcp4CheckOpt := &CheckOption{ tcp4CheckOpt := &CheckOption{
networkType: &NetworkType{ networkType: &NetworkType{
L4Proto: consts.L4ProtoStr_TCP, L4Proto: consts.L4ProtoStr_TCP,
@ -285,7 +292,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.") }).Debugln("Skip check due to no DNS record.")
return false, nil return false, nil
} }
return d.HttpCheck(ctx, opt.Url, opt.Ip4, opt.Method) return d.HttpCheck(ctx, opt.Url, opt.Ip4, opt.Method, tcpSomark)
}, },
} }
tcp6CheckOpt := &CheckOption{ tcp6CheckOpt := &CheckOption{
@ -307,7 +314,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.") }).Debugln("Skip check due to no DNS record.")
return false, nil return false, nil
} }
return d.HttpCheck(ctx, opt.Url, opt.Ip6, opt.Method) return d.HttpCheck(ctx, opt.Url, opt.Ip6, opt.Method, tcpSomark)
}, },
} }
tcp4CheckDnsOpt := &CheckOption{ tcp4CheckDnsOpt := &CheckOption{
@ -329,7 +336,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.") }).Debugln("Skip check due to no DNS record.")
return false, nil return false, nil
} }
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), true) return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
}, },
} }
tcp6CheckDnsOpt := &CheckOption{ tcp6CheckDnsOpt := &CheckOption{
@ -351,7 +358,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.") }).Debugln("Skip check due to no DNS record.")
return false, nil return false, nil
} }
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), true) return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
}, },
} }
udp4CheckDnsOpt := &CheckOption{ udp4CheckDnsOpt := &CheckOption{
@ -372,7 +379,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.") }).Debugln("Skip check due to no DNS record.")
return false, nil return false, nil
} }
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), false) return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
}, },
} }
udp6CheckDnsOpt := &CheckOption{ udp6CheckDnsOpt := &CheckOption{
@ -393,7 +400,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.") }).Debugln("Skip check due to no DNS record.")
return false, nil return false, nil
} }
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), false) return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
}, },
} }
var CheckOpts = []*CheckOption{ var CheckOpts = []*CheckOption{
@ -535,7 +542,7 @@ func (d *Dialer) Check(timeout time.Duration,
return ok, err return ok, err
} }
func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr, method string) (ok bool, err error) { func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr, method string, soMark uint32) (ok bool, err error) {
// HTTP(S) check. // HTTP(S) check.
if method == "" { if method == "" {
method = http.MethodGet method = http.MethodGet
@ -545,7 +552,7 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
// Force to dial "ip". // Force to dial "ip".
conn, err := cd.DialTcpContext(ctx, net.JoinHostPort(ip.String(), u.Port())) conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", soMark), net.JoinHostPort(ip.String(), u.Port()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -584,8 +591,8 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
} }
} }
func (d *Dialer) DnsCheck(ctx context.Context, dns netip.AddrPort, tcp bool) (ok bool, err error) { func (d *Dialer) DnsCheck(ctx context.Context, dns netip.AddrPort, network string) (ok bool, err error) {
addrs, err := netutils.ResolveNetip(ctx, d, dns, consts.UdpCheckLookupHost, dnsmessage.TypeA, tcp) addrs, err := netutils.ResolveNetip(ctx, d, dns, consts.UdpCheckLookupHost, dnsmessage.TypeA, network)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@ -2,6 +2,7 @@ package trojan
import ( import (
"fmt" "fmt"
"github.com/daeuniverse/dae/component/outbound/transport/tls"
"net" "net"
"net/url" "net/url"
"strconv" "strconv"
@ -9,7 +10,6 @@ import (
"github.com/daeuniverse/dae/common" "github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/component/outbound/dialer" "github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/daeuniverse/dae/component/outbound/transport/tls"
"github.com/daeuniverse/dae/component/outbound/transport/ws" "github.com/daeuniverse/dae/component/outbound/transport/ws"
"github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/protocol" "github.com/mzz2017/softwind/protocol"

View File

@ -63,38 +63,28 @@ func (s *SimpleObfs) Dial(network, addr string) (c netproxy.Conn, err error) {
} }
switch magicNetwork.Network { switch magicNetwork.Network {
case "tcp": case "tcp":
return s.DialTcp(addr) rc, err := s.dialer.Dial(network, s.addr)
if err != nil {
return nil, fmt.Errorf("[simpleobfs]: dial to %s: %w", s.addr, err)
}
host, port, err := net.SplitHostPort(s.addr)
if err != nil {
return nil, err
}
if s.host != "" {
host = s.host
}
switch s.obfstype {
case HTTP:
c = NewHTTPObfs(rc, host, port, s.path)
case TLS:
c = NewTLSObfs(rc, host)
}
return c, err
case "udp": case "udp":
return s.DialUdp(addr) return nil, fmt.Errorf("%w: simpleobfs+udp", netproxy.UnsupportedTunnelTypeError)
default: default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network) return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
} }
} }
func (s *SimpleObfs) DialUdp(addr string) (conn netproxy.PacketConn, err error) {
return nil, fmt.Errorf("%w: simpleobfs+udp", netproxy.UnsupportedTunnelTypeError)
}
// DialTcp connects to the address addr on the network net via the proxy.
func (s *SimpleObfs) DialTcp(addr string) (c netproxy.Conn, err error) {
rc, err := s.dialer.DialTcp(s.addr)
if err != nil {
return nil, fmt.Errorf("[simpleobfs]: dial to %s: %w", s.addr, err)
}
host, port, err := net.SplitHostPort(s.addr)
if err != nil {
return nil, err
}
if s.host != "" {
host = s.host
}
switch s.obfstype {
case HTTP:
c = NewHTTPObfs(rc, host, port, s.path)
case TLS:
c = NewTLSObfs(rc, host)
}
return c, err
}

View File

@ -61,55 +61,47 @@ func (s *Tls) Dial(network, addr string) (c netproxy.Conn, err error) {
} }
switch magicNetwork.Network { switch magicNetwork.Network {
case "tcp": case "tcp":
return s.DialTcp(addr) rc, err := s.dialer.Dial(network, addr)
if err != nil {
return nil, fmt.Errorf("[Tls]: dial to %s: %w", s.addr, err)
}
var tlsConn interface {
netproxy.Conn
Handshake() error
}
switch s.tlsImplentation {
case "tls":
tlsConn = tls.Client(&netproxy.FakeNetConn{
Conn: rc,
LAddr: nil,
RAddr: nil,
}, s.tlsConfig)
case "utls":
clientHelloID, err := nameToUtlsClientHelloID(s.utlsImitate)
if err != nil {
return nil, err
}
tlsConn = utls.UClient(&netproxy.FakeNetConn{
Conn: rc,
LAddr: nil,
RAddr: nil,
}, uTLSConfigFromTLSConfig(s.tlsConfig), *clientHelloID)
default:
return nil, fmt.Errorf("unknown tls implementation: %v", s.tlsImplentation)
}
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
return tlsConn, err
case "udp": case "udp":
return s.DialUdp(addr) return nil, fmt.Errorf("%w: tls+udp", netproxy.UnsupportedTunnelTypeError)
default: default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network) return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
} }
} }
func (s *Tls) DialUdp(addr string) (conn netproxy.PacketConn, err error) {
return nil, fmt.Errorf("%w: tls+udp", netproxy.UnsupportedTunnelTypeError)
}
func (s *Tls) DialTcp(addr string) (conn netproxy.Conn, err error) {
rc, err := s.dialer.DialTcp(addr)
if err != nil {
return nil, fmt.Errorf("[Tls]: dial to %s: %w", s.addr, err)
}
var tlsConn interface {
netproxy.Conn
Handshake() error
}
switch s.tlsImplentation {
case "tls":
tlsConn = tls.Client(&netproxy.FakeNetConn{
Conn: rc,
LAddr: nil,
RAddr: nil,
}, s.tlsConfig)
case "utls":
clientHelloID, err := nameToUtlsClientHelloID(s.utlsImitate)
if err != nil {
return nil, err
}
tlsConn = utls.UClient(&netproxy.FakeNetConn{
Conn: rc,
LAddr: nil,
RAddr: nil,
}, uTLSConfigFromTLSConfig(s.tlsConfig), *clientHelloID)
default:
return nil, fmt.Errorf("unknown tls implementation: %v", s.tlsImplentation)
}
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
return tlsConn, err
}

View File

@ -13,10 +13,10 @@ import (
// Ws is a base Ws struct // Ws is a base Ws struct
type Ws struct { type Ws struct {
dialer netproxy.Dialer dialer netproxy.Dialer
wsAddr string wsAddr string
header http.Header header http.Header
wsDialer *websocket.Dialer tlsClientConfig *tls.Config
} }
// NewWs returns a Ws infra. // NewWs returns a Ws infra.
@ -43,23 +43,9 @@ func NewWs(s string, d netproxy.Dialer) (*Ws, error) {
Host: u.Host, Host: u.Host,
} }
t.wsAddr = wsUrl.String() + u.Path t.wsAddr = wsUrl.String() + u.Path
t.wsDialer = &websocket.Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
c, err := d.DialTcp(addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: c,
LAddr: nil,
RAddr: nil,
}, nil
},
//Subprotocols: []string{"binary"},
}
if u.Scheme == "wss" { if u.Scheme == "wss" {
skipVerify, _ := strconv.ParseBool(u.Query().Get("allowInsecure")) skipVerify, _ := strconv.ParseBool(u.Query().Get("allowInsecure"))
t.wsDialer.TLSClientConfig = &tls.Config{ t.tlsClientConfig = &tls.Config{
ServerName: u.Query().Get("sni"), ServerName: u.Query().Get("sni"),
InsecureSkipVerify: skipVerify, InsecureSkipVerify: skipVerify,
} }
@ -74,23 +60,28 @@ func (s *Ws) Dial(network, addr string) (c netproxy.Conn, err error) {
} }
switch magicNetwork.Network { switch magicNetwork.Network {
case "tcp": case "tcp":
return s.DialTcp(addr) wsDialer := &websocket.Dialer{
NetDial: func(_, addr string) (net.Conn, error) {
c, err := s.dialer.Dial(network, addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: c,
LAddr: nil,
RAddr: nil,
}, nil
},
//Subprotocols: []string{"binary"},
}
rc, _, err := wsDialer.Dial(s.wsAddr, s.header)
if err != nil {
return nil, fmt.Errorf("[Ws]: dial to %s: %w", s.wsAddr, err)
}
return newConn(rc), err
case "udp": case "udp":
return s.DialUdp(addr) return nil, fmt.Errorf("%w: ws+udp", netproxy.UnsupportedTunnelTypeError)
default: default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network) return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
} }
} }
func (s *Ws) DialUdp(addr string) (netproxy.PacketConn, error) {
return nil, fmt.Errorf("%w: ws+udp", netproxy.UnsupportedTunnelTypeError)
}
// DialTcp connects to the address addr on the network net via the infra.
func (s *Ws) DialTcp(addr string) (netproxy.Conn, error) {
rc, _, err := s.wsDialer.Dial(s.wsAddr, s.header)
if err != nil {
return nil, fmt.Errorf("[Ws]: dial to %s: %w", s.wsAddr, err)
}
return newConn(rc), err
}

View File

@ -14,8 +14,10 @@ import (
) )
type Global struct { type Global struct {
TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"` TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"`
LogLevel string `mapstructure:"log_level" default:"info"` TproxyPortProtect bool `mapstructure:"tproxy_port_protect" default:"true"`
SoMarkFromDae uint32 `mapstructure:"so_mark_from_dae"`
LogLevel string `mapstructure:"log_level" default:"info"`
// We use DirectTcpCheckUrl to check (tcp)*(ipv4/ipv6) connectivity for direct. // We use DirectTcpCheckUrl to check (tcp)*(ipv4/ipv6) connectivity for direct.
//DirectTcpCheckUrl string `mapstructure:"direct_tcp_check_url" default:"http://www.qualcomm.cn/generate_204"` //DirectTcpCheckUrl string `mapstructure:"direct_tcp_check_url" default:"http://www.qualcomm.cn/generate_204"`
TcpCheckUrl []string `mapstructure:"tcp_check_url" default:"http://cp.cloudflare.com,1.1.1.1,2606:4700:4700::1111"` TcpCheckUrl []string `mapstructure:"tcp_check_url" default:"http://cp.cloudflare.com,1.1.1.1,2606:4700:4700::1111"`

View File

@ -36,6 +36,8 @@ var SectionDescription = map[string]Desc{
var GlobalDesc = Desc{ var GlobalDesc = Desc{
"tproxy_port": "tproxy port to listen on. It is NOT a HTTP/SOCKS port, and is just used by eBPF program.\nIn normal case, you do not need to use it.", "tproxy_port": "tproxy port to listen on. It is NOT a HTTP/SOCKS port, and is just used by eBPF program.\nIn normal case, you do not need to use it.",
"tproxy_port_protect": "Set it true to protect tproxy port from unsolicited traffic. Set it false to allow users to use self-managed iptables tproxy rules.",
"so_mark_from_dae": "If not zero, traffic sent from dae will be set SO_MARK. It is useful to avoid traffic loop with iptables tproxy rules.",
"log_level": "Log level: error, warn, info, debug, trace.", "log_level": "Log level: error, warn, info, debug, trace.",
"tcp_check_url": "Node connectivity check.\nHost of URL should have both IPv4 and IPv6 if you have double stack in local.\nConsidering traffic consumption, it is recommended to choose a site with anycast IP and less response.", "tcp_check_url": "Node connectivity check.\nHost of URL should have both IPv4 and IPv6 if you have double stack in local.\nConsidering traffic consumption, it is recommended to choose a site with anycast IP and less response.",
"tcp_check_http_method": "The HTTP request method to `tcp_check_url`. Use 'CONNECT' by default because some server implementations bypass accounting for this kind of traffic.", "tcp_check_http_method": "The HTTP request method to `tcp_check_url`. Use 'CONNECT' by default because some server implementations bypass accounting for this kind of traffic.",

View File

@ -67,7 +67,9 @@ type ControlPlane struct {
wanInterface []string wanInterface []string
lanInterface []string lanInterface []string
sniffingTimeout time.Duration sniffingTimeout time.Duration
tproxyPortProtect bool
soMarkFromDae uint32
} }
func NewControlPlane( func NewControlPlane(
@ -226,9 +228,17 @@ func NewControlPlane(
log.Warnln("AllowInsecure is enabled, but it is not recommended. Please make sure you have to turn it on.") log.Warnln("AllowInsecure is enabled, but it is not recommended. Please make sure you have to turn it on.")
} }
option := &dialer.GlobalOption{ option := &dialer.GlobalOption{
Log: log, Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: global.TcpCheckUrl, Log: log, Method: global.TcpCheckHttpMethod}, TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: global.UdpCheckDns}, Raw: global.TcpCheckUrl,
Log: log,
ResolverNetwork: common.MagicNetwork("udp", global.SoMarkFromDae),
Method: global.TcpCheckHttpMethod,
},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{
Raw: global.UdpCheckDns,
ResolverNetwork: common.MagicNetwork("udp", global.SoMarkFromDae),
},
CheckInterval: global.CheckInterval, CheckInterval: global.CheckInterval,
CheckTolerance: global.CheckTolerance, CheckTolerance: global.CheckTolerance,
CheckDnsTcp: true, CheckDnsTcp: true,
@ -337,23 +347,25 @@ func NewControlPlane(
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
plane := &ControlPlane{ plane := &ControlPlane{
log: log, log: log,
core: core, core: core,
deferFuncs: deferFuncs, deferFuncs: deferFuncs,
listenIp: "0.0.0.0", listenIp: "0.0.0.0",
outbounds: outbounds, outbounds: outbounds,
dnsController: nil, dnsController: nil,
onceNetworkReady: sync.Once{}, onceNetworkReady: sync.Once{},
dialMode: dialMode, dialMode: dialMode,
routingMatcher: routingMatcher, routingMatcher: routingMatcher,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
ready: make(chan struct{}), ready: make(chan struct{}),
muRealDomainSet: sync.Mutex{}, muRealDomainSet: sync.Mutex{},
realDomainSet: bloom.NewWithEstimates(2048, 0.001), realDomainSet: bloom.NewWithEstimates(2048, 0.001),
lanInterface: global.LanInterface, lanInterface: global.LanInterface,
wanInterface: global.WanInterface, wanInterface: global.WanInterface,
sniffingTimeout: sniffingTimeout, sniffingTimeout: sniffingTimeout,
tproxyPortProtect: global.TproxyPortProtect,
soMarkFromDae: global.SoMarkFromDae,
} }
defer func() { defer func() {
if err != nil { if err != nil {
@ -363,9 +375,10 @@ func NewControlPlane(
/// DNS upstream. /// DNS upstream.
dnsUpstream, err := dns.New(dnsConfig, &dns.NewOption{ dnsUpstream, err := dns.New(dnsConfig, &dns.NewOption{
Logger: log, Logger: log,
LocationFinder: locationFinder, LocationFinder: locationFinder,
UpstreamReadyCallback: plane.dnsUpstreamReadyCallback, UpstreamReadyCallback: plane.dnsUpstreamReadyCallback,
UpstreamResolverNetwork: common.MagicNetwork("udp", global.SoMarkFromDae),
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -559,7 +572,7 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip
// TODO: use DNS controller and re-route by control plane. // TODO: use DNS controller and re-route by control plane.
systemDns, err := netutils.SystemDns() systemDns, err := netutils.SystemDns()
if err == nil { if err == nil {
if ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, domain, false, true); err == nil && (ip46.Ip4.IsValid() || ip46.Ip6.IsValid()) { if ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, domain, common.MagicNetwork("udp", c.soMarkFromDae), true); err == nil && (ip46.Ip4.IsValid() || ip46.Ip6.IsValid()) {
// Has A/AAAA records. It is a real domain. // Has A/AAAA records. It is a real domain.
dialMode = consts.DialMode_Domain dialMode = consts.DialMode_Domain
// Add it to real-domain set. // Add it to real-domain set.
@ -717,8 +730,21 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err
lastErr := err lastErr := err
addrHdr, dataOffset, err := ParseAddrHdr(data) addrHdr, dataOffset, err := ParseAddrHdr(data)
if err != nil { if err != nil {
c.log.Warnf("No AddrPort presented: %v, %v", lastErr, err) if c.tproxyPortProtect {
return c.log.Warnf("No AddrPort presented: %v, %v", lastErr, err)
return
} else {
routingResult = &bpfRoutingResult{
Mark: 0,
Must: 0,
Mac: [6]uint8{},
Outbound: uint8(consts.OutboundControlPlaneRouting),
Pname: [16]uint8{},
Pid: 0,
}
realDst = pktDst
goto destRetrieved
}
} }
n := copy(data, data[dataOffset:]) n := copy(data, data[dataOffset:])
data = data[:n] data = data[:n]
@ -731,6 +757,7 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err
} else { } else {
realDst = pktDst realDst = pktDst
} }
destRetrieved:
if e := c.handlePkt(udpConn, data, common.ConvergeAddrPort(src), common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult); e != nil { if e := c.handlePkt(udpConn, data, common.ConvergeAddrPort(src), common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult); e != nil {
c.log.Warnln("handlePkt:", e) c.log.Warnln("handlePkt:", e)
} }
@ -814,6 +841,9 @@ func (c *ControlPlane) chooseBestDnsDialer(
if err != nil { if err != nil {
return nil, err return nil, err
} }
if mark == 0 {
mark = c.soMarkFromDae
}
if int(outboundIndex) >= len(c.outbounds) { if int(outboundIndex) >= len(c.outbounds) {
return nil, fmt.Errorf("bad outbound index: %v", outboundIndex) return nil, fmt.Errorf("bad outbound index: %v", outboundIndex)
} }

View File

@ -10,6 +10,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/daeuniverse/dae/common"
"io" "io"
"math" "math"
"net" "net"
@ -652,7 +653,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
// TODO: connection pool. // TODO: connection pool.
conn, err = dialArgument.bestDialer.Dial( conn, err = dialArgument.bestDialer.Dial(
MagicNetwork("udp", dialArgument.mark), common.MagicNetwork("udp", dialArgument.mark),
dialArgument.bestTarget.String(), dialArgument.bestTarget.String(),
) )
if err != nil { if err != nil {
@ -714,7 +715,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
case consts.L4ProtoStr_TCP: case consts.L4ProtoStr_TCP:
// We can block here because we are in a coroutine. // We can block here because we are in a coroutine.
conn, err = dialArgument.bestDialer.Dial(MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String()) conn, err = dialArgument.bestDialer.Dial(common.MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String())
if err != nil { if err != nil {
return fmt.Errorf("failed to dial proxy to tcp: %w", err) return fmt.Errorf("failed to dial proxy to tcp: %w", err)
} }

View File

@ -64,6 +64,7 @@
#define IS_LAN 1 #define IS_LAN 1
#define TPROXY_MARK 0x8000000 #define TPROXY_MARK 0x8000000
#define RECOGNIZE 0x2017
#define ESOCKTNOSUPPORT 94 /* Socket type not supported */ #define ESOCKTNOSUPPORT 94 /* Socket type not supported */
@ -139,6 +140,7 @@ struct routing_result {
struct dst_routing_result { struct dst_routing_result {
__be32 ip[4]; __be32 ip[4];
__be16 port; __be16 port;
__u16 recognize;
struct routing_result routing_result; struct routing_result routing_result;
}; };
@ -1751,6 +1753,7 @@ int tproxy_wan_egress(struct __sk_buff *skb) {
__builtin_memset(&new_hdr, 0, sizeof(new_hdr)); __builtin_memset(&new_hdr, 0, sizeof(new_hdr));
__builtin_memcpy(new_hdr.ip, &tuples.dip, IPV6_BYTE_LENGTH); __builtin_memcpy(new_hdr.ip, &tuples.dip, IPV6_BYTE_LENGTH);
new_hdr.port = udph.dest; new_hdr.port = udph.dest;
new_hdr.recognize = RECOGNIZE;
new_hdr.routing_result.outbound = s64_ret; new_hdr.routing_result.outbound = s64_ret;
new_hdr.routing_result.mark = s64_ret >> 8; new_hdr.routing_result.mark = s64_ret >> 8;
new_hdr.routing_result.must = (s64_ret >> 40) & 1; new_hdr.routing_result.must = (s64_ret >> 40) & 1;

View File

@ -50,7 +50,19 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
Ip: struct{ U6Addr8 [16]uint8 }{U6Addr8: ip6}, Ip: struct{ U6Addr8 [16]uint8 }{U6Addr8: ip6},
Port: common.Htons(src.Port()), Port: common.Htons(src.Port()),
}, &value); e != nil { }, &value); e != nil {
return fmt.Errorf("failed to retrieve target info %v: %v, %v", src.String(), err, e) if c.tproxyPortProtect {
return fmt.Errorf("failed to retrieve target info %v: %v, %v", src.String(), err, e)
} else {
routingResult = &bpfRoutingResult{
Mark: 0,
Must: 0,
Mac: [6]uint8{},
Outbound: uint8(consts.OutboundControlPlaneRouting),
Pname: [16]uint8{},
Pid: 0,
}
goto destRetrieved
}
} }
routingResult = &value.RoutingResult routingResult = &value.RoutingResult
@ -60,6 +72,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
} }
dst = netip.AddrPortFrom(dstAddr, common.Htons(value.Port)) dst = netip.AddrPortFrom(dstAddr, common.Htons(value.Port))
} }
destRetrieved:
src = common.ConvergeAddrPort(src) src = common.ConvergeAddrPort(src)
dst = common.ConvergeAddrPort(dst) dst = common.ConvergeAddrPort(dst)
@ -92,6 +105,9 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
dialTarget, _ = c.ChooseDialTarget(outboundIndex, dst, domain) dialTarget, _ = c.ChooseDialTarget(outboundIndex, dst, domain)
default: default:
} }
if routingResult.Mark == 0 {
routingResult.Mark = c.soMarkFromDae
}
// TODO: Set-up ip to domain mapping and show domain if possible. // TODO: Set-up ip to domain mapping and show domain if possible.
if outboundIndex < 0 || int(outboundIndex) >= len(c.outbounds) { if outboundIndex < 0 || int(outboundIndex) >= len(c.outbounds) {
return fmt.Errorf("outbound id from bpf is out of range: %v not in [0, %v]", outboundIndex, len(c.outbounds)-1) return fmt.Errorf("outbound id from bpf is out of range: %v not in [0, %v]", outboundIndex, len(c.outbounds)-1)
@ -122,7 +138,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
} }
// Dial and relay. // Dial and relay.
rConn, err := d.Dial(MagicNetwork("tcp", routingResult.Mark), dialTarget) rConn, err := d.Dial(common.MagicNetwork("tcp", routingResult.Mark), dialTarget)
if err != nil { if err != nil {
return fmt.Errorf("failed to dial %v: %w", dst, err) return fmt.Errorf("failed to dial %v: %w", dst, err)
} }

View File

@ -48,6 +48,9 @@ func ParseAddrHdr(data []byte) (hdr *bpfDstRoutingResult, dataOffset int, err er
return nil, 0, fmt.Errorf("data is too short to parse AddrHdr") return nil, 0, fmt.Errorf("data is too short to parse AddrHdr")
} }
_hdr := *(*bpfDstRoutingResult)(unsafe.Pointer(&data[0])) _hdr := *(*bpfDstRoutingResult)(unsafe.Pointer(&data[0]))
if _hdr.Recognize != consts.Recognize {
return nil, 0, fmt.Errorf("bad recognize")
}
_hdr.Port = common.Ntohs(_hdr.Port) _hdr.Port = common.Ntohs(_hdr.Port)
return &_hdr, dataOffset, nil return &_hdr, dataOffset, nil
} }
@ -173,6 +176,9 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
dialTarget, _ = c.ChooseDialTarget(outboundIndex, realDst, domain) dialTarget, _ = c.ChooseDialTarget(outboundIndex, realDst, domain)
default: default:
} }
if routingResult.Mark == 0 {
routingResult.Mark = c.soMarkFromDae
}
if isDns { if isDns {
return c.dnsController.Handle_(dnsMessage, &udpRequest{ return c.dnsController.Handle_(dnsMessage, &udpRequest{
lanWanFlag: lanWanFlag, lanWanFlag: lanWanFlag,
@ -226,7 +232,7 @@ getNew:
}, },
NatTimeout: natTimeout, NatTimeout: natTimeout,
Dialer: dialerForNew, Dialer: dialerForNew,
Network: MagicNetwork("udp", routingResult.Mark), Network: common.MagicNetwork("udp", routingResult.Mark),
Target: dialTarget, Target: dialTarget,
}) })
if err != nil { if err != nil {

View File

@ -16,7 +16,6 @@ import (
"github.com/daeuniverse/dae/common" "github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
"github.com/mzz2017/softwind/netproxy"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -160,17 +159,6 @@ func SetSendRedirects(ifname string, val string) {
_ = setSendRedirects(ifname, consts.IpVersionStr_4, val) _ = setSendRedirects(ifname, consts.IpVersionStr_4, val)
} }
func MagicNetwork(network string, mark uint32) string {
if mark == 0 {
return network
} else {
return netproxy.MagicNetwork{
Network: network,
Mark: mark,
}.Encode()
}
}
func ProcessName2String(pname []uint8) string { func ProcessName2String(pname []uint8) string {
return string(bytes.TrimRight(pname[:], string([]byte{0}))) return string(bytes.TrimRight(pname[:], string([]byte{0})))
} }

View File

@ -5,6 +5,14 @@ global {
# In normal case, you do not need to use it. # In normal case, you do not need to use it.
tproxy_port: 12345 tproxy_port: 12345
# Set it true to protect tproxy port from unsolicited traffic. Set it false to allow users to use self-managed
# iptables tproxy rules.
tproxy_port_protect: true
# If not zero, traffic sent from dae will be set SO_MARK. It is useful to avoid traffic loop with iptables tproxy
# rules.
so_mark_from_dae: 0
# Log level: error, warn, info, debug, trace. # Log level: error, warn, info, debug, trace.
log_level: info log_level: info

2
go.mod
View File

@ -11,7 +11,7 @@ require (
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/json-iterator/go v1.1.12 github.com/json-iterator/go v1.1.12
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/mzz2017/softwind v0.0.0-20230501115403-98d9a7116d72 github.com/mzz2017/softwind v0.0.0-20230513064540-9e88f7ce1d9c
github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd
github.com/safchain/ethtool v0.0.0-20230116090318-67cc41908669 github.com/safchain/ethtool v0.0.0-20230116090318-67cc41908669
github.com/sirupsen/logrus v1.9.0 github.com/sirupsen/logrus v1.9.0

4
go.sum
View File

@ -78,8 +78,8 @@ github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
github.com/mzz2017/disk-bloom v1.0.1 h1:rEF9MiXd9qMW3ibRpqcerLXULoTgRlM21yqqJl1B90M= github.com/mzz2017/disk-bloom v1.0.1 h1:rEF9MiXd9qMW3ibRpqcerLXULoTgRlM21yqqJl1B90M=
github.com/mzz2017/disk-bloom v1.0.1/go.mod h1:JLHETtUu44Z6iBmsqzkOtFlRvXSlKnxjwiBRDapizDI= github.com/mzz2017/disk-bloom v1.0.1/go.mod h1:JLHETtUu44Z6iBmsqzkOtFlRvXSlKnxjwiBRDapizDI=
github.com/mzz2017/softwind v0.0.0-20230501115403-98d9a7116d72 h1:h6xMzLtz5pW24T8E+GSdNJ9lRYh5cDpgL85d5c3/om0= github.com/mzz2017/softwind v0.0.0-20230513064540-9e88f7ce1d9c h1:cVIRZXtrbp4Ef69/RcC6Kp/exJ+H1H3T46xfPYDYVCM=
github.com/mzz2017/softwind v0.0.0-20230501115403-98d9a7116d72/go.mod h1:V8GFOtdpTgzCJtCVXRqjmdDsY+PIhCCx4JpD0zq8Z7I= github.com/mzz2017/softwind v0.0.0-20230513064540-9e88f7ce1d9c/go.mod h1:V8GFOtdpTgzCJtCVXRqjmdDsY+PIhCCx4JpD0zq8Z7I=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=