feat: support iptables tproxy (#80)

This commit is contained in:
mzz
2023-06-04 11:38:05 +08:00
committed by GitHub
parent cbcbec9a1a
commit ee09ae17e3
26 changed files with 313 additions and 229 deletions

View File

@ -32,9 +32,10 @@ type Dns struct {
}
type NewOption struct {
Logger *logrus.Logger
LocationFinder *assets.LocationFinder
UpstreamReadyCallback func(dnsUpstream *Upstream) (err error)
Logger *logrus.Logger
LocationFinder *assets.LocationFinder
UpstreamReadyCallback func(dnsUpstream *Upstream) (err error)
UpstreamResolverNetwork string
}
func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
@ -62,7 +63,8 @@ func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
return nil, fmt.Errorf("%w: %v", BadUpstreamFormatError, err)
}
r := &UpstreamResolver{
Raw: u,
Raw: u,
Network: opt.UpstreamResolverNetwork,
FinishInitCallback: func(i int) func(raw *url.URL, upstream *Upstream) (err error) {
return func(raw *url.URL, upstream *Upstream) (err error) {
if opt != nil && opt.UpstreamReadyCallback != nil {
@ -77,6 +79,9 @@ func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
return nil
}
}(i),
mu: sync.Mutex{},
upstream: nil,
init: false,
}
upstreamName2Id[tag] = uint8(len(s.upstream))
s.upstream = append(s.upstream, r)

View File

@ -72,7 +72,7 @@ type Upstream struct {
*netutils.Ip46
}
func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err error) {
func NewUpstream(ctx context.Context, upstream *url.URL, resolverNetwork string) (up *Upstream, err error) {
scheme, hostname, port, err := ParseRawUpstream(upstream)
if err != nil {
return nil, fmt.Errorf("%w: %v", FormatError, err)
@ -88,7 +88,7 @@ func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err erro
}
}()
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, false, false)
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, resolverNetwork, false)
if err != nil {
return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
}
@ -131,7 +131,8 @@ func (u *Upstream) String() string {
}
type UpstreamResolver struct {
Raw *url.URL
Raw *url.URL
Network string
// FinishInitCallback may be invoked again if err is not nil
FinishInitCallback func(raw *url.URL, upstream *Upstream) (err error)
mu sync.Mutex
@ -154,7 +155,7 @@ func (u *UpstreamResolver) GetUpstream() (_ *Upstream, err error) {
}()
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
if u.upstream, err = NewUpstream(ctx, u.Raw); err != nil {
if u.upstream, err = NewUpstream(ctx, u.Raw, u.Network); err != nil {
return nil, fmt.Errorf("failed to init dns upstream: %w", err)
}
}

View File

@ -9,6 +9,7 @@ import (
"context"
"errors"
"fmt"
"github.com/daeuniverse/dae/common"
"net"
"net/http"
"net/netip"
@ -121,7 +122,7 @@ type TcpCheckOption struct {
Method string
}
func ParseTcpCheckOption(ctx context.Context, rawURL []string, method string) (opt *TcpCheckOption, err error) {
func ParseTcpCheckOption(ctx context.Context, rawURL []string, method string, resolverNetwork string) (opt *TcpCheckOption, err error) {
if method == "" {
method = http.MethodGet
}
@ -146,7 +147,7 @@ func ParseTcpCheckOption(ctx context.Context, rawURL []string, method string) (o
if len(rawURL) > 1 {
ip46 = parseIp46FromList(rawURL[1:])
} else {
ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false, false)
ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), resolverNetwork, false)
if err != nil {
return nil, err
}
@ -164,7 +165,7 @@ type CheckDnsOption struct {
*netutils.Ip46
}
func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string) (opt *CheckDnsOption, err error) {
func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string, resolverNetwork string) (opt *CheckDnsOption, err error) {
systemDns, err := netutils.SystemDns()
if err != nil {
return nil, err
@ -191,7 +192,7 @@ func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string) (opt *CheckD
if len(dnsHostPort) > 1 {
ip46 = parseIp46FromList(dnsHostPort[1:])
} else {
ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, false, false)
ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, resolverNetwork, false)
if err != nil {
return nil, err
}
@ -204,11 +205,12 @@ func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string) (opt *CheckD
}
type TcpCheckOptionRaw struct {
opt *TcpCheckOption
mu sync.Mutex
Log *logrus.Logger
Raw []string
Method string
opt *TcpCheckOption
mu sync.Mutex
Log *logrus.Logger
Raw []string
ResolverNetwork string
Method string
}
func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) {
@ -218,7 +220,7 @@ func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
ctx = context.WithValue(ctx, "logger", c.Log)
tcpCheckOption, err := ParseTcpCheckOption(ctx, c.Raw, c.Method)
tcpCheckOption, err := ParseTcpCheckOption(ctx, c.Raw, c.Method, c.ResolverNetwork)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
@ -228,9 +230,10 @@ func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) {
}
type CheckDnsOptionRaw struct {
opt *CheckDnsOption
mu sync.Mutex
Raw []string
opt *CheckDnsOption
mu sync.Mutex
Raw []string
ResolverNetwork string
}
func (c *CheckDnsOptionRaw) Option() (opt *CheckDnsOption, err error) {
@ -239,7 +242,7 @@ func (c *CheckDnsOptionRaw) Option() (opt *CheckDnsOption, err error) {
if c.opt == nil {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
udpCheckOption, err := ParseCheckDnsOption(ctx, c.Raw)
udpCheckOption, err := ParseCheckDnsOption(ctx, c.Raw, c.ResolverNetwork)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
@ -266,6 +269,10 @@ func (d *Dialer) ActivateCheck() {
func (d *Dialer) aliveBackground() {
timeout := 10 * time.Second
cycle := d.CheckInterval
var tcpSomark uint32
if network, err := netproxy.ParseMagicNetwork(d.TcpCheckOptionRaw.ResolverNetwork); err == nil {
tcpSomark = network.Mark
}
tcp4CheckOpt := &CheckOption{
networkType: &NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
@ -285,7 +292,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.HttpCheck(ctx, opt.Url, opt.Ip4, opt.Method)
return d.HttpCheck(ctx, opt.Url, opt.Ip4, opt.Method, tcpSomark)
},
}
tcp6CheckOpt := &CheckOption{
@ -307,7 +314,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.HttpCheck(ctx, opt.Url, opt.Ip6, opt.Method)
return d.HttpCheck(ctx, opt.Url, opt.Ip6, opt.Method, tcpSomark)
},
}
tcp4CheckDnsOpt := &CheckOption{
@ -329,7 +336,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), true)
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
},
}
tcp6CheckDnsOpt := &CheckOption{
@ -351,7 +358,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), true)
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
},
}
udp4CheckDnsOpt := &CheckOption{
@ -372,7 +379,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), false)
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
},
}
udp6CheckDnsOpt := &CheckOption{
@ -393,7 +400,7 @@ func (d *Dialer) aliveBackground() {
}).Debugln("Skip check due to no DNS record.")
return false, nil
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), false)
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), d.CheckDnsOptionRaw.ResolverNetwork)
},
}
var CheckOpts = []*CheckOption{
@ -535,7 +542,7 @@ func (d *Dialer) Check(timeout time.Duration,
return ok, err
}
func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr, method string) (ok bool, err error) {
func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr, method string, soMark uint32) (ok bool, err error) {
// HTTP(S) check.
if method == "" {
method = http.MethodGet
@ -545,7 +552,7 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
// Force to dial "ip".
conn, err := cd.DialTcpContext(ctx, net.JoinHostPort(ip.String(), u.Port()))
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", soMark), net.JoinHostPort(ip.String(), u.Port()))
if err != nil {
return nil, err
}
@ -584,8 +591,8 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
}
}
func (d *Dialer) DnsCheck(ctx context.Context, dns netip.AddrPort, tcp bool) (ok bool, err error) {
addrs, err := netutils.ResolveNetip(ctx, d, dns, consts.UdpCheckLookupHost, dnsmessage.TypeA, tcp)
func (d *Dialer) DnsCheck(ctx context.Context, dns netip.AddrPort, network string) (ok bool, err error) {
addrs, err := netutils.ResolveNetip(ctx, d, dns, consts.UdpCheckLookupHost, dnsmessage.TypeA, network)
if err != nil {
return false, err
}

View File

@ -2,6 +2,7 @@ package trojan
import (
"fmt"
"github.com/daeuniverse/dae/component/outbound/transport/tls"
"net"
"net/url"
"strconv"
@ -9,7 +10,6 @@ import (
"github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/daeuniverse/dae/component/outbound/transport/tls"
"github.com/daeuniverse/dae/component/outbound/transport/ws"
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/protocol"

View File

@ -63,38 +63,28 @@ func (s *SimpleObfs) Dial(network, addr string) (c netproxy.Conn, err error) {
}
switch magicNetwork.Network {
case "tcp":
return s.DialTcp(addr)
rc, err := s.dialer.Dial(network, s.addr)
if err != nil {
return nil, fmt.Errorf("[simpleobfs]: dial to %s: %w", s.addr, err)
}
host, port, err := net.SplitHostPort(s.addr)
if err != nil {
return nil, err
}
if s.host != "" {
host = s.host
}
switch s.obfstype {
case HTTP:
c = NewHTTPObfs(rc, host, port, s.path)
case TLS:
c = NewTLSObfs(rc, host)
}
return c, err
case "udp":
return s.DialUdp(addr)
return nil, fmt.Errorf("%w: simpleobfs+udp", netproxy.UnsupportedTunnelTypeError)
default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
}
}
func (s *SimpleObfs) DialUdp(addr string) (conn netproxy.PacketConn, err error) {
return nil, fmt.Errorf("%w: simpleobfs+udp", netproxy.UnsupportedTunnelTypeError)
}
// DialTcp connects to the address addr on the network net via the proxy.
func (s *SimpleObfs) DialTcp(addr string) (c netproxy.Conn, err error) {
rc, err := s.dialer.DialTcp(s.addr)
if err != nil {
return nil, fmt.Errorf("[simpleobfs]: dial to %s: %w", s.addr, err)
}
host, port, err := net.SplitHostPort(s.addr)
if err != nil {
return nil, err
}
if s.host != "" {
host = s.host
}
switch s.obfstype {
case HTTP:
c = NewHTTPObfs(rc, host, port, s.path)
case TLS:
c = NewTLSObfs(rc, host)
}
return c, err
}

View File

@ -61,55 +61,47 @@ func (s *Tls) Dial(network, addr string) (c netproxy.Conn, err error) {
}
switch magicNetwork.Network {
case "tcp":
return s.DialTcp(addr)
rc, err := s.dialer.Dial(network, addr)
if err != nil {
return nil, fmt.Errorf("[Tls]: dial to %s: %w", s.addr, err)
}
var tlsConn interface {
netproxy.Conn
Handshake() error
}
switch s.tlsImplentation {
case "tls":
tlsConn = tls.Client(&netproxy.FakeNetConn{
Conn: rc,
LAddr: nil,
RAddr: nil,
}, s.tlsConfig)
case "utls":
clientHelloID, err := nameToUtlsClientHelloID(s.utlsImitate)
if err != nil {
return nil, err
}
tlsConn = utls.UClient(&netproxy.FakeNetConn{
Conn: rc,
LAddr: nil,
RAddr: nil,
}, uTLSConfigFromTLSConfig(s.tlsConfig), *clientHelloID)
default:
return nil, fmt.Errorf("unknown tls implementation: %v", s.tlsImplentation)
}
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
return tlsConn, err
case "udp":
return s.DialUdp(addr)
return nil, fmt.Errorf("%w: tls+udp", netproxy.UnsupportedTunnelTypeError)
default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
}
}
func (s *Tls) DialUdp(addr string) (conn netproxy.PacketConn, err error) {
return nil, fmt.Errorf("%w: tls+udp", netproxy.UnsupportedTunnelTypeError)
}
func (s *Tls) DialTcp(addr string) (conn netproxy.Conn, err error) {
rc, err := s.dialer.DialTcp(addr)
if err != nil {
return nil, fmt.Errorf("[Tls]: dial to %s: %w", s.addr, err)
}
var tlsConn interface {
netproxy.Conn
Handshake() error
}
switch s.tlsImplentation {
case "tls":
tlsConn = tls.Client(&netproxy.FakeNetConn{
Conn: rc,
LAddr: nil,
RAddr: nil,
}, s.tlsConfig)
case "utls":
clientHelloID, err := nameToUtlsClientHelloID(s.utlsImitate)
if err != nil {
return nil, err
}
tlsConn = utls.UClient(&netproxy.FakeNetConn{
Conn: rc,
LAddr: nil,
RAddr: nil,
}, uTLSConfigFromTLSConfig(s.tlsConfig), *clientHelloID)
default:
return nil, fmt.Errorf("unknown tls implementation: %v", s.tlsImplentation)
}
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
return tlsConn, err
}

View File

@ -13,10 +13,10 @@ import (
// Ws is a base Ws struct
type Ws struct {
dialer netproxy.Dialer
wsAddr string
header http.Header
wsDialer *websocket.Dialer
dialer netproxy.Dialer
wsAddr string
header http.Header
tlsClientConfig *tls.Config
}
// NewWs returns a Ws infra.
@ -43,23 +43,9 @@ func NewWs(s string, d netproxy.Dialer) (*Ws, error) {
Host: u.Host,
}
t.wsAddr = wsUrl.String() + u.Path
t.wsDialer = &websocket.Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
c, err := d.DialTcp(addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: c,
LAddr: nil,
RAddr: nil,
}, nil
},
//Subprotocols: []string{"binary"},
}
if u.Scheme == "wss" {
skipVerify, _ := strconv.ParseBool(u.Query().Get("allowInsecure"))
t.wsDialer.TLSClientConfig = &tls.Config{
t.tlsClientConfig = &tls.Config{
ServerName: u.Query().Get("sni"),
InsecureSkipVerify: skipVerify,
}
@ -74,23 +60,28 @@ func (s *Ws) Dial(network, addr string) (c netproxy.Conn, err error) {
}
switch magicNetwork.Network {
case "tcp":
return s.DialTcp(addr)
wsDialer := &websocket.Dialer{
NetDial: func(_, addr string) (net.Conn, error) {
c, err := s.dialer.Dial(network, addr)
if err != nil {
return nil, err
}
return &netproxy.FakeNetConn{
Conn: c,
LAddr: nil,
RAddr: nil,
}, nil
},
//Subprotocols: []string{"binary"},
}
rc, _, err := wsDialer.Dial(s.wsAddr, s.header)
if err != nil {
return nil, fmt.Errorf("[Ws]: dial to %s: %w", s.wsAddr, err)
}
return newConn(rc), err
case "udp":
return s.DialUdp(addr)
return nil, fmt.Errorf("%w: ws+udp", netproxy.UnsupportedTunnelTypeError)
default:
return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network)
}
}
func (s *Ws) DialUdp(addr string) (netproxy.PacketConn, error) {
return nil, fmt.Errorf("%w: ws+udp", netproxy.UnsupportedTunnelTypeError)
}
// DialTcp connects to the address addr on the network net via the infra.
func (s *Ws) DialTcp(addr string) (netproxy.Conn, error) {
rc, _, err := s.wsDialer.Dial(s.wsAddr, s.header)
if err != nil {
return nil, fmt.Errorf("[Ws]: dial to %s: %w", s.wsAddr, err)
}
return newConn(rc), err
}