feat: support iptables tproxy (#80)

This commit is contained in:
mzz 2023-06-04 11:38:05 +08:00 committed by GitHub
parent cbcbec9a1a
commit ee09ae17e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 313 additions and 229 deletions

View File

@ -4,8 +4,6 @@ on:
push:
branches:
- main
- fix*
- feat*
paths:
- "**/*.go"
- "**/*.c"

View File

@ -23,6 +23,7 @@ else
STRIP_FLAG := -strip=$(STRIP_PATH)
endif
# Do NOT remove the line below. This line is for CI.
#export GOMODCACHE=$(PWD)/go-mod
# Get version from .git.

View File

@ -1,9 +1,12 @@
package cmd
import (
"context"
"errors"
"fmt"
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/protocol/direct"
"net"
"net/http"
"os"
@ -247,6 +250,20 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
if !conf.Global.DisableWaitingNetwork && len(conf.Subscription) > 0 {
epo := 5 * time.Second
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialer{Dialer: direct.SymmetricDirect}
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae), addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: conn,
LAddr: nil,
RAddr: nil,
}, nil
},
},
Timeout: epo,
}
log.Infoln("Waiting for network...")
@ -274,8 +291,25 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
if len(conf.Subscription) > 0 {
log.Infoln("Fetching subscriptions...")
}
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialer{Dialer: direct.SymmetricDirect}
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae), addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: conn,
LAddr: nil,
RAddr: nil,
}, nil
},
},
Timeout: 30 * time.Second,
}
for _, sub := range conf.Subscription {
tag, nodes, err := subscription.ResolveSubscription(log, filepath.Dir(cfgFile), string(sub))
tag, nodes, err := subscription.ResolveSubscription(log, &client, filepath.Dir(cfgFile), string(sub))
if err != nil {
log.Warnf(`failed to resolve subscription "%v": %v`, sub, err)
resolvingfailed = true

View File

@ -146,6 +146,7 @@ var (
const (
TproxyMark uint32 = 0x8000000
Recognize uint16 = 0x2017
LoopbackIfIndex = 1
)

View File

@ -19,7 +19,6 @@ import (
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/pool"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
@ -91,8 +90,8 @@ func SystemDns() (dns netip.AddrPort, err error) {
return systemDns, nil
}
func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, tcp bool) (addrs []netip.Addr, err error) {
resources, err := resolve(ctx, d, dns, host, typ, tcp)
func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (addrs []netip.Addr, err error) {
resources, err := resolve(ctx, d, dns, host, typ, network)
if err != nil {
return nil, err
}
@ -118,16 +117,14 @@ func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, ho
return addrs, nil
}
func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, tcp bool) (records []string, err error) {
func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, network string) (records []string, err error) {
typ := dnsmessage.TypeNS
resources, err := resolve(ctx, d, dns, host, typ, tcp)
resources, err := resolve(ctx, d, dns, host, typ, network)
if err != nil {
return nil, err
}
logrus.Println(host, len(resources))
for _, ans := range resources {
if ans.Header.Type != typ {
logrus.Println(host, ans.Header.Type)
continue
}
ns, ok := ans.Body.(*dnsmessage.NSResource)
@ -139,7 +136,7 @@ func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host
return records, nil
}
func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, tcp bool) (ans []dnsmessage.Resource, err error) {
func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (ans []dnsmessage.Resource, err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
fqdn := host
@ -202,7 +199,11 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
if err != nil {
return nil, err
}
if tcp {
magicNetwork, err := netproxy.ParseMagicNetwork(network)
if err != nil {
return nil, err
}
if magicNetwork.Network == "tcp" {
// Put DNS request length
buf := pool.Get(2 + len(b))
defer pool.Put(buf)
@ -213,12 +214,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
// Dial and write.
cd := &netproxy.ContextDialer{Dialer: d}
var c netproxy.Conn
if tcp {
c, err = cd.DialTcpContext(ctx, dns.String())
} else {
c, err = cd.DialUdpContext(ctx, dns.String())
}
c, err := cd.DialContext(ctx, network, dns.String())
if err != nil {
return nil, err
}
@ -228,7 +224,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
return nil, err
}
ch := make(chan error, 2)
if !tcp {
if magicNetwork.Network == "udp" {
go func() {
// Resend every 3 seconds for UDP.
for {
@ -249,7 +245,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
go func() {
buf := pool.Get(512)
defer pool.Put(buf)
if tcp {
if magicNetwork.Network == "tcp" {
// Read DNS response length
_, err := io.ReadFull(c, buf[:2])
if err != nil {

View File

@ -22,7 +22,7 @@ type Ip46 struct {
Ip6 netip.Addr
}
func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, tcp bool, race bool) (ipv46 *Ip46, err error) {
func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, network string, race bool) (ipv46 *Ip46, err error) {
var log *logrus.Logger
if _log := ctx.Value("logger"); _log != nil {
log = _log.(*logrus.Logger)
@ -49,7 +49,7 @@ func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort
}
}()
var e error
addrs4, e = ResolveNetip(ctx4, dialer, dns, host, dnsmessage.TypeA, tcp)
addrs4, e = ResolveNetip(ctx4, dialer, dns, host, dnsmessage.TypeA, network)
if err != nil && !errors.Is(e, context.Canceled) {
err4 = e
return
@ -67,7 +67,7 @@ func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort
}
}()
var e error
addrs6, e = ResolveNetip(ctx6, dialer, dns, host, dnsmessage.TypeAAAA, tcp)
addrs6, e = ResolveNetip(ctx6, dialer, dns, host, dnsmessage.TypeAAAA, network)
if err != nil && !errors.Is(e, context.Canceled) {
err6 = e
return

View File

@ -137,7 +137,7 @@ func ResolveFile(u *url.URL, configDir string) (b []byte, err error) {
return bytes.TrimSpace(b), err
}
func ResolveSubscription(log *logrus.Logger, configDir string, subscription string) (tag string, nodes []string, err error) {
func ResolveSubscription(log *logrus.Logger, client *http.Client, configDir string, subscription string) (tag string, nodes []string, err error) {
/// Get tag.
tag, subscription = common.GetTagFromLinkLikePlaintext(subscription)
@ -160,7 +160,7 @@ func ResolveSubscription(log *logrus.Logger, configDir string, subscription stri
goto resolve
default:
}
resp, err = http.Get(subscription)
resp, err = client.Get(subscription)
if err != nil {
return "", nil, err
}

View File

@ -12,6 +12,7 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
"github.com/mzz2017/softwind/netproxy"
"net/netip"
"net/url"
"path/filepath"
@ -221,25 +222,25 @@ func FuzzyDecode(to interface{}, val string) bool {
v := reflect.Indirect(reflect.ValueOf(to))
switch v.Kind() {
case reflect.Int:
i, err := strconv.ParseInt(val, 10, strconv.IntSize)
i, err := strconv.ParseInt(val, 0, strconv.IntSize)
if err != nil {
return false
}
v.SetInt(i)
case reflect.Int8:
i, err := strconv.ParseInt(val, 10, 8)
i, err := strconv.ParseInt(val, 0, 8)
if err != nil {
return false
}
v.SetInt(i)
case reflect.Int16:
i, err := strconv.ParseInt(val, 10, 16)
i, err := strconv.ParseInt(val, 0, 16)
if err != nil {
return false
}
v.SetInt(i)
case reflect.Int32:
i, err := strconv.ParseInt(val, 10, 32)
i, err := strconv.ParseInt(val, 0, 32)
if err != nil {
return false
}
@ -253,38 +254,38 @@ func FuzzyDecode(to interface{}, val string) bool {
}
v.Set(reflect.ValueOf(duration))
default:
i, err := strconv.ParseInt(val, 10, 64)
i, err := strconv.ParseInt(val, 0, 64)
if err != nil {
return false
}
v.SetInt(i)
}
case reflect.Uint:
i, err := strconv.ParseUint(val, 10, strconv.IntSize)
i, err := strconv.ParseUint(val, 0, strconv.IntSize)
if err != nil {
return false
}
v.SetUint(i)
case reflect.Uint8:
i, err := strconv.ParseUint(val, 10, 8)
i, err := strconv.ParseUint(val, 0, 8)
if err != nil {
return false
}
v.SetUint(i)
case reflect.Uint16:
i, err := strconv.ParseUint(val, 10, 16)
i, err := strconv.ParseUint(val, 0, 16)
if err != nil {
return false
}
v.SetUint(i)
case reflect.Uint32:
i, err := strconv.ParseUint(val, 10, 32)
i, err := strconv.ParseUint(val, 0, 32)
if err != nil {
return false
}
v.SetUint(i)
case reflect.Uint64:
i, err := strconv.ParseUint(val, 10, 64)
i, err := strconv.ParseUint(val, 0, 64)
if err != nil {
return false
}
@ -458,6 +459,17 @@ nextLink:
return Deduplicate(defaultIfs), nil
}
func MagicNetwork(network string, mark uint32) string {
if mark == 0 {
return network
} else {
return netproxy.MagicNetwork{
Network: network,
Mark: mark,
}.Encode()
}
}
func IsValidHttpMethod(method string) bool {
switch method {
case "GET", "POST", "PUT", "PATCH", "DELETE", "COPY", "HEAD", "OPTIONS", "LINK", "UNLINK", "PURGE", "LOCK", "UNLOCK", "PROPFIND", "CONNECT", "TRACE":

View File

@ -35,6 +35,7 @@ type NewOption struct {
Logger *logrus.Logger
LocationFinder *assets.LocationFinder
UpstreamReadyCallback func(dnsUpstream *Upstream) (err error)
UpstreamResolverNetwork string
}
func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
@ -63,6 +64,7 @@ func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
}
r := &UpstreamResolver{
Raw: u,
Network: opt.UpstreamResolverNetwork,
FinishInitCallback: func(i int) func(raw *url.URL, upstream *Upstream) (err error) {
return func(raw *url.URL, upstream *Upstream) (err error) {
if opt != nil && opt.UpstreamReadyCallback != nil {
@ -77,6 +79,9 @@ func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
return nil
}
}(i),
mu: sync.Mutex{},
upstream: nil,
init: false,
}
upstreamName2Id[tag] = uint8(len(s.upstream))
s.upstream = append(s.upstream, r)

View File

@ -72,7 +72,7 @@ type Upstream struct {
*netutils.Ip46
}
func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err error) {
func NewUpstream(ctx context.Context, upstream *url.URL, resolverNetwork string) (up *Upstream, err error) {
scheme, hostname, port, err := ParseRawUpstream(upstream)
if err != nil {
return nil, fmt.Errorf("%w: %v", FormatError, err)
@ -88,7 +88,7 @@ func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err erro
}
}()
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, false, false)
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, resolverNetwork, false)
if err != nil {
return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
}
@ -132,6 +132,7 @@ func (u *Upstream) String() string {
type UpstreamResolver struct {
Raw *url.URL
Network string
// FinishInitCallback may be invoked again if err is not nil
FinishInitCallback func(raw *url.URL, upstream *Upstream) (err error)
mu sync.Mutex
@ -154,7 +155,7 @@ func (u *UpstreamResolver) GetUpstream() (_ *Upstream, err error) {
}()
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
if u.upstream, err = NewUpstream(ctx, u.Raw); err != nil {
if u.upstream, err = NewUpstream(ctx, u.Raw, u.Network); err != nil {
return nil, fmt.Errorf("failed to init dns upstream: %w", err)
}
}

View File

@ -9,6 +9,7 @@ import (
"context"
"errors"
"fmt"
"github.com/daeuniverse/dae/common"
"net"
"net/http"
"net/netip"
@ -121,7 +122,7 @@ type TcpCheckOption struct {
Method string
}
func ParseTcpCheckOption(ctx context.Context, rawURL []string, method string) (opt *TcpCheckOption, err error) {
func ParseTcpCheckOption(ctx context.Context, rawURL []string, method string, resolverNetwork string) (opt *TcpCheckOption, err error) {
if method == "" {
method = http.MethodGet
}
@ -146,7 +147,7 @@ func ParseTcpCheckOption(ctx context.Context, rawURL []string, method string) (o
if len(rawURL) > 1 {
ip46 = parseIp46FromList(rawURL[1:])
} else {
ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false, false)
ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), resolverNetwork, false)
if err != nil {
return nil, err
}
@ -164,7 +165,7 @@ type CheckDnsOption struct {
*netutils.Ip46
}
func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string) (opt *CheckDnsOption, err error) {
func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string, resolverNetwork string) (opt *CheckDnsOption, err error) {
systemDns, err := netutils.SystemDns()
if err != nil {
return nil, err
@ -191,7 +192,7 @@ func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string) (opt *CheckD
if len(dnsHostPort) > 1 {
ip46 = parseIp46FromList(dnsHostPort[1:])
} else {
ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, false, false)
ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, resolverNetwork, false)
if err != nil {
return nil, err
}
@ -208,6 +209,7 @@ type TcpCheckOptionRaw struct {
mu sync.Mutex
Log *logrus.Logger
Raw []string
ResolverNetwork string
Method string
}
@ -218,7 +220,7 @@ func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
ctx = context.WithValue(ctx, "logger", c.Log)
tcpCheckOption, err := ParseTcpCheckOption(ctx, c.Raw, c.Method)
tcpCheckOption, err := ParseTcpCheckOption(ctx, c.Raw, c.Method, c.ResolverNetwork)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
@ -231,6 +233,7 @@ type CheckDnsOptionRaw struct {
opt *CheckDnsOption
mu sync.Mutex
Raw []string
ResolverNetwork string
}
func (c *CheckDnsOptionRaw) Option() (opt *CheckDnsOption, err error) {
@ -239,7 +242,7 @@ func (c *CheckDnsOptionRaw) Option() (opt *CheckDnsOption, err error) {
if c.opt == nil {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
udpCheckOption, err := ParseCheckDnsOption(ctx, c.Raw)
udpCheckOption, err := ParseCheckDnsOption(ctx, c.Raw, c.ResolverNetwork)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
@ -266,6 +269,10 @@ func (d *Dialer) ActivateCheck() {
func (d *Dialer) aliveBackground() {
timeout := 10 * time.Second
cycle := d.CheckInterval
var tcpSomark uint32
if network, err := netproxy.ParseMagicNetwork(d.TcpCheckOptionRaw.ResolverNetwork); err == nil {
tcpSomark = network.Mark
}
tcp4CheckOpt := &CheckOption{
networkType: &NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
@ -285,7 +292,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.HttpCheck(ctx, opt.Url, opt.Ip4, opt.Method)
return d.HttpCheck(ctx, opt.Url, opt.Ip4, opt.Method, tcpSomark)
},
}
tcp6CheckOpt := &CheckOption{
@ -307,7 +314,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.HttpCheck(ctx, opt.Url, opt.Ip6, opt.Method)
return d.HttpCheck(ctx, opt.Url, opt.Ip6, opt.Method, tcpSomark)
},
}
tcp4CheckDnsOpt := &CheckOption{
@ -329,7 +336,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), true)
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
},
}
tcp6CheckDnsOpt := &CheckOption{
@ -351,7 +358,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), true)
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
},
}
udp4CheckDnsOpt := &CheckOption{
@ -372,7 +379,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), false)
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
},
}
udp6CheckDnsOpt := &CheckOption{
@ -393,7 +400,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), false)
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
},
}
var CheckOpts = []*CheckOption{
@ -535,7 +542,7 @@ func (d *Dialer) Check(timeout time.Duration,
return ok, err
}
func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr, method string) (ok bool, err error) {
func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr, method string, soMark uint32) (ok bool, err error) {
// HTTP(S) check.
if method == "" {
method = http.MethodGet
@ -545,7 +552,7 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
// Force to dial "ip".
conn, err := cd.DialTcpContext(ctx, net.JoinHostPort(ip.String(), u.Port()))
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", soMark), net.JoinHostPort(ip.String(), u.Port()))
if err != nil {
return nil, err
}
@ -584,8 +591,8 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
}
}
func (d *Dialer) DnsCheck(ctx context.Context, dns netip.AddrPort, tcp bool) (ok bool, err error) {
addrs, err := netutils.ResolveNetip(ctx, d, dns, consts.UdpCheckLookupHost, dnsmessage.TypeA, tcp)
func (d *Dialer) DnsCheck(ctx context.Context, dns netip.AddrPort, network string) (ok bool, err error) {
addrs, err := netutils.ResolveNetip(ctx, d, dns, consts.UdpCheckLookupHost, dnsmessage.TypeA, network)
if err != nil {
return false, err
}

View File

@ -2,6 +2,7 @@ package trojan
import (
"fmt"
"github.com/daeuniverse/dae/component/outbound/transport/tls"
"net"
"net/url"
"strconv"
@ -9,7 +10,6 @@ import (
"github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/daeuniverse/dae/component/outbound/transport/tls"
"github.com/daeuniverse/dae/component/outbound/transport/ws"
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/protocol"

View File

@ -63,22 +63,7 @@ func (s *SimpleObfs) Dial(network, addr string) (c netproxy.Conn, err error) {
}
switch magicNetwork.Network {
case "tcp":
return s.DialTcp(addr)
case "udp":
return s.DialUdp(addr)
default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
}
}
func (s *SimpleObfs) DialUdp(addr string) (conn netproxy.PacketConn, err error) {
return nil, fmt.Errorf("%w: simpleobfs+udp", netproxy.UnsupportedTunnelTypeError)
}
// DialTcp connects to the address addr on the network net via the proxy.
func (s *SimpleObfs) DialTcp(addr string) (c netproxy.Conn, err error) {
rc, err := s.dialer.DialTcp(s.addr)
rc, err := s.dialer.Dial(network, s.addr)
if err != nil {
return nil, fmt.Errorf("[simpleobfs]: dial to %s: %w", s.addr, err)
}
@ -97,4 +82,9 @@ func (s *SimpleObfs) DialTcp(addr string) (c netproxy.Conn, err error) {
c = NewTLSObfs(rc, host)
}
return c, err
case "udp":
return nil, fmt.Errorf("%w: simpleobfs+udp", netproxy.UnsupportedTunnelTypeError)
default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
}
}

View File

@ -61,20 +61,7 @@ func (s *Tls) Dial(network, addr string) (c netproxy.Conn, err error) {
}
switch magicNetwork.Network {
case "tcp":
return s.DialTcp(addr)
case "udp":
return s.DialUdp(addr)
default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
}
}
func (s *Tls) DialUdp(addr string) (conn netproxy.PacketConn, err error) {
return nil, fmt.Errorf("%w: tls+udp", netproxy.UnsupportedTunnelTypeError)
}
func (s *Tls) DialTcp(addr string) (conn netproxy.Conn, err error) {
rc, err := s.dialer.DialTcp(addr)
rc, err := s.dialer.Dial(network, addr)
if err != nil {
return nil, fmt.Errorf("[Tls]: dial to %s: %w", s.addr, err)
}
@ -112,4 +99,9 @@ func (s *Tls) DialTcp(addr string) (conn netproxy.Conn, err error) {
return nil, err
}
return tlsConn, err
case "udp":
return nil, fmt.Errorf("%w: tls+udp", netproxy.UnsupportedTunnelTypeError)
default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
}
}

View File

@ -16,7 +16,7 @@ type Ws struct {
dialer netproxy.Dialer
wsAddr string
header http.Header
wsDialer *websocket.Dialer
tlsClientConfig *tls.Config
}
// NewWs returns a Ws infra.
@ -43,23 +43,9 @@ func NewWs(s string, d netproxy.Dialer) (*Ws, error) {
Host: u.Host,
}
t.wsAddr = wsUrl.String() + u.Path
t.wsDialer = &websocket.Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
c, err := d.DialTcp(addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: c,
LAddr: nil,
RAddr: nil,
}, nil
},
//Subprotocols: []string{"binary"},
}
if u.Scheme == "wss" {
skipVerify, _ := strconv.ParseBool(u.Query().Get("allowInsecure"))
t.wsDialer.TLSClientConfig = &tls.Config{
t.tlsClientConfig = &tls.Config{
ServerName: u.Query().Get("sni"),
InsecureSkipVerify: skipVerify,
}
@ -74,23 +60,28 @@ func (s *Ws) Dial(network, addr string) (c netproxy.Conn, err error) {
}
switch magicNetwork.Network {
case "tcp":
return s.DialTcp(addr)
case "udp":
return s.DialUdp(addr)
default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
wsDialer := &websocket.Dialer{
NetDial: func(_, addr string) (net.Conn, error) {
c, err := s.dialer.Dial(network, addr)
if err != nil {
return nil, err
}
}
func (s *Ws) DialUdp(addr string) (netproxy.PacketConn, error) {
return nil, fmt.Errorf("%w: ws+udp", netproxy.UnsupportedTunnelTypeError)
}
// DialTcp connects to the address addr on the network net via the infra.
func (s *Ws) DialTcp(addr string) (netproxy.Conn, error) {
rc, _, err := s.wsDialer.Dial(s.wsAddr, s.header)
return &netproxy.FakeNetConn{
Conn: c,
LAddr: nil,
RAddr: nil,
}, nil
},
//Subprotocols: []string{"binary"},
}
rc, _, err := wsDialer.Dial(s.wsAddr, s.header)
if err != nil {
return nil, fmt.Errorf("[Ws]: dial to %s: %w", s.wsAddr, err)
}
return newConn(rc), err
case "udp":
return nil, fmt.Errorf("%w: ws+udp", netproxy.UnsupportedTunnelTypeError)
default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
}
}

View File

@ -15,6 +15,8 @@ import (
type Global struct {
TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"`
TproxyPortProtect bool `mapstructure:"tproxy_port_protect" default:"true"`
SoMarkFromDae uint32 `mapstructure:"so_mark_from_dae"`
LogLevel string `mapstructure:"log_level" default:"info"`
// We use DirectTcpCheckUrl to check (tcp)*(ipv4/ipv6) connectivity for direct.
//DirectTcpCheckUrl string `mapstructure:"direct_tcp_check_url" default:"http://www.qualcomm.cn/generate_204"`

View File

@ -36,6 +36,8 @@ var SectionDescription = map[string]Desc{
var GlobalDesc = Desc{
"tproxy_port": "tproxy port to listen on. It is NOT a HTTP/SOCKS port, and is just used by eBPF program.\nIn normal case, you do not need to use it.",
"tproxy_port_protect": "Set it true to protect tproxy port from unsolicited traffic. Set it false to allow users to use self-managed iptables tproxy rules.",
"so_mark_from_dae": "If not zero, traffic sent from dae will be set SO_MARK. It is useful to avoid traffic loop with iptables tproxy rules.",
"log_level": "Log level: error, warn, info, debug, trace.",
"tcp_check_url": "Node connectivity check.\nHost of URL should have both IPv4 and IPv6 if you have double stack in local.\nConsidering traffic consumption, it is recommended to choose a site with anycast IP and less response.",
"tcp_check_http_method": "The HTTP request method to `tcp_check_url`. Use 'CONNECT' by default because some server implementations bypass accounting for this kind of traffic.",

View File

@ -68,6 +68,8 @@ type ControlPlane struct {
lanInterface []string
sniffingTimeout time.Duration
tproxyPortProtect bool
soMarkFromDae uint32
}
func NewControlPlane(
@ -227,8 +229,16 @@ func NewControlPlane(
}
option := &dialer.GlobalOption{
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: global.TcpCheckUrl, Log: log, Method: global.TcpCheckHttpMethod},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: global.UdpCheckDns},
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{
Raw: global.TcpCheckUrl,
Log: log,
ResolverNetwork: common.MagicNetwork("udp", global.SoMarkFromDae),
Method: global.TcpCheckHttpMethod,
},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{
Raw: global.UdpCheckDns,
ResolverNetwork: common.MagicNetwork("udp", global.SoMarkFromDae),
},
CheckInterval: global.CheckInterval,
CheckTolerance: global.CheckTolerance,
CheckDnsTcp: true,
@ -354,6 +364,8 @@ func NewControlPlane(
lanInterface: global.LanInterface,
wanInterface: global.WanInterface,
sniffingTimeout: sniffingTimeout,
tproxyPortProtect: global.TproxyPortProtect,
soMarkFromDae: global.SoMarkFromDae,
}
defer func() {
if err != nil {
@ -366,6 +378,7 @@ func NewControlPlane(
Logger: log,
LocationFinder: locationFinder,
UpstreamReadyCallback: plane.dnsUpstreamReadyCallback,
UpstreamResolverNetwork: common.MagicNetwork("udp", global.SoMarkFromDae),
})
if err != nil {
return nil, err
@ -559,7 +572,7 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip
// TODO: use DNS controller and re-route by control plane.
systemDns, err := netutils.SystemDns()
if err == nil {
if ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, domain, false, true); err == nil && (ip46.Ip4.IsValid() || ip46.Ip6.IsValid()) {
if ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, domain, common.MagicNetwork("udp", c.soMarkFromDae), true); err == nil && (ip46.Ip4.IsValid() || ip46.Ip6.IsValid()) {
// Has A/AAAA records. It is a real domain.
dialMode = consts.DialMode_Domain
// Add it to real-domain set.
@ -717,8 +730,21 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err
lastErr := err
addrHdr, dataOffset, err := ParseAddrHdr(data)
if err != nil {
if c.tproxyPortProtect {
c.log.Warnf("No AddrPort presented: %v, %v", lastErr, err)
return
} else {
routingResult = &bpfRoutingResult{
Mark: 0,
Must: 0,
Mac: [6]uint8{},
Outbound: uint8(consts.OutboundControlPlaneRouting),
Pname: [16]uint8{},
Pid: 0,
}
realDst = pktDst
goto destRetrieved
}
}
n := copy(data, data[dataOffset:])
data = data[:n]
@ -731,6 +757,7 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err
} else {
realDst = pktDst
}
destRetrieved:
if e := c.handlePkt(udpConn, data, common.ConvergeAddrPort(src), common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult); e != nil {
c.log.Warnln("handlePkt:", e)
}
@ -814,6 +841,9 @@ func (c *ControlPlane) chooseBestDnsDialer(
if err != nil {
return nil, err
}
if mark == 0 {
mark = c.soMarkFromDae
}
if int(outboundIndex) >= len(c.outbounds) {
return nil, fmt.Errorf("bad outbound index: %v", outboundIndex)
}

View File

@ -10,6 +10,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"github.com/daeuniverse/dae/common"
"io"
"math"
"net"
@ -652,7 +653,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
// TODO: connection pool.
conn, err = dialArgument.bestDialer.Dial(
MagicNetwork("udp", dialArgument.mark),
common.MagicNetwork("udp", dialArgument.mark),
dialArgument.bestTarget.String(),
)
if err != nil {
@ -714,7 +715,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
case consts.L4ProtoStr_TCP:
// We can block here because we are in a coroutine.
conn, err = dialArgument.bestDialer.Dial(MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String())
conn, err = dialArgument.bestDialer.Dial(common.MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String())
if err != nil {
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
}

View File

@ -64,6 +64,7 @@
#define IS_LAN 1
#define TPROXY_MARK 0x8000000
#define RECOGNIZE 0x2017
#define ESOCKTNOSUPPORT 94 /* Socket type not supported */
@ -139,6 +140,7 @@ struct routing_result {
struct dst_routing_result {
__be32 ip[4];
__be16 port;
__u16 recognize;
struct routing_result routing_result;
};
@ -1751,6 +1753,7 @@ int tproxy_wan_egress(struct __sk_buff *skb) {
__builtin_memset(&new_hdr, 0, sizeof(new_hdr));
__builtin_memcpy(new_hdr.ip, &tuples.dip, IPV6_BYTE_LENGTH);
new_hdr.port = udph.dest;
new_hdr.recognize = RECOGNIZE;
new_hdr.routing_result.outbound = s64_ret;
new_hdr.routing_result.mark = s64_ret >> 8;
new_hdr.routing_result.must = (s64_ret >> 40) & 1;

View File

@ -50,7 +50,19 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
Ip: struct{ U6Addr8 [16]uint8 }{U6Addr8: ip6},
Port: common.Htons(src.Port()),
}, &value); e != nil {
if c.tproxyPortProtect {
return fmt.Errorf("failed to retrieve target info %v: %v, %v", src.String(), err, e)
} else {
routingResult = &bpfRoutingResult{
Mark: 0,
Must: 0,
Mac: [6]uint8{},
Outbound: uint8(consts.OutboundControlPlaneRouting),
Pname: [16]uint8{},
Pid: 0,
}
goto destRetrieved
}
}
routingResult = &value.RoutingResult
@ -60,6 +72,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
}
dst = netip.AddrPortFrom(dstAddr, common.Htons(value.Port))
}
destRetrieved:
src = common.ConvergeAddrPort(src)
dst = common.ConvergeAddrPort(dst)
@ -92,6 +105,9 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
dialTarget, _ = c.ChooseDialTarget(outboundIndex, dst, domain)
default:
}
if routingResult.Mark == 0 {
routingResult.Mark = c.soMarkFromDae
}
// TODO: Set-up ip to domain mapping and show domain if possible.
if outboundIndex < 0 || int(outboundIndex) >= len(c.outbounds) {
return fmt.Errorf("outbound id from bpf is out of range: %v not in [0, %v]", outboundIndex, len(c.outbounds)-1)
@ -122,7 +138,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
}
// Dial and relay.
rConn, err := d.Dial(MagicNetwork("tcp", routingResult.Mark), dialTarget)
rConn, err := d.Dial(common.MagicNetwork("tcp", routingResult.Mark), dialTarget)
if err != nil {
return fmt.Errorf("failed to dial %v: %w", dst, err)
}

View File

@ -48,6 +48,9 @@ func ParseAddrHdr(data []byte) (hdr *bpfDstRoutingResult, dataOffset int, err er
return nil, 0, fmt.Errorf("data is too short to parse AddrHdr")
}
_hdr := *(*bpfDstRoutingResult)(unsafe.Pointer(&data[0]))
if _hdr.Recognize != consts.Recognize {
return nil, 0, fmt.Errorf("bad recognize")
}
_hdr.Port = common.Ntohs(_hdr.Port)
return &_hdr, dataOffset, nil
}
@ -173,6 +176,9 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
dialTarget, _ = c.ChooseDialTarget(outboundIndex, realDst, domain)
default:
}
if routingResult.Mark == 0 {
routingResult.Mark = c.soMarkFromDae
}
if isDns {
return c.dnsController.Handle_(dnsMessage, &udpRequest{
lanWanFlag: lanWanFlag,
@ -226,7 +232,7 @@ getNew:
},
NatTimeout: natTimeout,
Dialer: dialerForNew,
Network: MagicNetwork("udp", routingResult.Mark),
Network: common.MagicNetwork("udp", routingResult.Mark),
Target: dialTarget,
})
if err != nil {

View File

@ -16,7 +16,6 @@ import (
"github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/consts"
"github.com/mzz2017/softwind/netproxy"
"golang.org/x/sys/unix"
)
@ -160,17 +159,6 @@ func SetSendRedirects(ifname string, val string) {
_ = setSendRedirects(ifname, consts.IpVersionStr_4, val)
}
func MagicNetwork(network string, mark uint32) string {
if mark == 0 {
return network
} else {
return netproxy.MagicNetwork{
Network: network,
Mark: mark,
}.Encode()
}
}
func ProcessName2String(pname []uint8) string {
return string(bytes.TrimRight(pname[:], string([]byte{0})))
}

View File

@ -5,6 +5,14 @@ global {
# In normal case, you do not need to use it.
tproxy_port: 12345
# Set it true to protect tproxy port from unsolicited traffic. Set it false to allow users to use self-managed
# iptables tproxy rules.
tproxy_port_protect: true
# If not zero, traffic sent from dae will be set SO_MARK. It is useful to avoid traffic loop with iptables tproxy
# rules.
so_mark_from_dae: 0
# Log level: error, warn, info, debug, trace.
log_level: info

2
go.mod
View File

@ -11,7 +11,7 @@ require (
github.com/gorilla/websocket v1.5.0
github.com/json-iterator/go v1.1.12
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/mzz2017/softwind v0.0.0-20230501115403-98d9a7116d72
github.com/mzz2017/softwind v0.0.0-20230513064540-9e88f7ce1d9c
github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd
github.com/safchain/ethtool v0.0.0-20230116090318-67cc41908669
github.com/sirupsen/logrus v1.9.0

4
go.sum
View File

@ -78,8 +78,8 @@ github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
github.com/mzz2017/disk-bloom v1.0.1 h1:rEF9MiXd9qMW3ibRpqcerLXULoTgRlM21yqqJl1B90M=
github.com/mzz2017/disk-bloom v1.0.1/go.mod h1:JLHETtUu44Z6iBmsqzkOtFlRvXSlKnxjwiBRDapizDI=
github.com/mzz2017/softwind v0.0.0-20230501115403-98d9a7116d72 h1:h6xMzLtz5pW24T8E+GSdNJ9lRYh5cDpgL85d5c3/om0=
github.com/mzz2017/softwind v0.0.0-20230501115403-98d9a7116d72/go.mod h1:V8GFOtdpTgzCJtCVXRqjmdDsY+PIhCCx4JpD0zq8Z7I=
github.com/mzz2017/softwind v0.0.0-20230513064540-9e88f7ce1d9c h1:cVIRZXtrbp4Ef69/RcC6Kp/exJ+H1H3T46xfPYDYVCM=
github.com/mzz2017/softwind v0.0.0-20230513064540-9e88f7ce1d9c/go.mod h1:V8GFOtdpTgzCJtCVXRqjmdDsY+PIhCCx4JpD0zq8Z7I=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=