diff --git a/common/consts/ebpf.go b/common/consts/ebpf.go index 4e94c9d..22db40f 100644 --- a/common/consts/ebpf.go +++ b/common/consts/ebpf.go @@ -15,7 +15,7 @@ const ( AppName = "dae" BpfPinRoot = "/sys/fs/bpf" - AddrHdrSize = 20 + AddrHdrSize = 24 TaskCommLen = 16 ) diff --git a/common/consts/routing.go b/common/consts/routing.go index b11cfc3..f7e5917 100644 --- a/common/consts/routing.go +++ b/common/consts/routing.go @@ -10,8 +10,8 @@ type RoutingDomainKey string const ( RoutingDomainKey_Full RoutingDomainKey = "full" RoutingDomainKey_Keyword RoutingDomainKey = "keyword" - RoutingDomainKey_Suffix RoutingDomainKey = "suffix" - RoutingDomainKey_Regex RoutingDomainKey = "regex" + RoutingDomainKey_Suffix RoutingDomainKey = "suffix" + RoutingDomainKey_Regex RoutingDomainKey = "regex" Function_Domain = "domain" Function_Ip = "ip" @@ -24,4 +24,6 @@ const ( Function_ProcessName = "pname" Declaration_Fallback = "fallback" + + OutboundParam_Mark = "mark" ) diff --git a/common/netutils/dns.go b/common/netutils/dns.go index 806ebec..fcac253 100644 --- a/common/netutils/dns.go +++ b/common/netutils/dns.go @@ -114,6 +114,7 @@ func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, ho return nil, err } if tcp { + // Put DNS request length buf := pool.Get(2 + len(b)) defer pool.Put(buf) binary.BigEndian.PutUint16(buf, uint16(len(b))) @@ -160,6 +161,7 @@ func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, ho buf := pool.Get(512) defer pool.Put(buf) if tcp { + // Read DNS response length _, err := io.ReadFull(c, buf[:2]) if err != nil { ch <- err diff --git a/common/utils.go b/common/utils.go index 69f1fb3..3dce960 100644 --- a/common/utils.go +++ b/common/utils.go @@ -283,9 +283,9 @@ func FuzzyDecode(to interface{}, val string) bool { v.SetUint(i) case reflect.Bool: switch strings.ToLower(val) { - case "true", "1", "y", "yes": + case "true", "t", "1", "y", "yes", "on": v.SetBool(true) - case "false", "0", "n", "no": + case "false", "f", "0", "n", "no", "off": v.SetBool(false) default: return false diff --git a/component/outbound/dialer/block.go b/component/outbound/dialer/block.go index f2cc014..d901f79 100644 --- a/component/outbound/dialer/block.go +++ b/component/outbound/dialer/block.go @@ -6,6 +6,7 @@ package dialer import ( + "fmt" "github.com/mzz2017/softwind/netproxy" "net" ) @@ -14,6 +15,21 @@ type blockDialer struct { DialCallback func() } +func (d *blockDialer) Dial(network, addr string) (c netproxy.Conn, err error) { + magicNetwork, err := netproxy.ParseMagicNetwork(network) + if err != nil { + return nil, err + } + switch magicNetwork.Network { + case "tcp": + return d.DialTcp(addr) + case "udp": + return d.DialUdp(addr) + default: + return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network) + } +} + func (d *blockDialer) DialTcp(addr string) (c netproxy.Conn, err error) { d.DialCallback() return nil, net.ErrClosed diff --git a/component/outbound/transport/simpleobfs/simpleobfs.go b/component/outbound/transport/simpleobfs/simpleobfs.go index 9863538..2d60ddb 100644 --- a/component/outbound/transport/simpleobfs/simpleobfs.go +++ b/component/outbound/transport/simpleobfs/simpleobfs.go @@ -56,6 +56,22 @@ func NewSimpleObfs(s string, d netproxy.Dialer) (*SimpleObfs, error) { return t, nil } + +func (s *SimpleObfs) Dial(network, addr string) (c netproxy.Conn, err error) { + magicNetwork, err := netproxy.ParseMagicNetwork(network) + if err != nil { + return nil, err + } + switch magicNetwork.Network { + case "tcp": + return s.DialTcp(addr) + case "udp": + return s.DialUdp(addr) + default: + return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network) + } +} + func (s *SimpleObfs) DialUdp(addr string) (conn netproxy.PacketConn, err error) { return nil, fmt.Errorf("%w: simpleobfs+udp", netproxy.UnsupportedTunnelTypeError) } diff --git a/component/outbound/transport/tls/tls.go b/component/outbound/transport/tls/tls.go index 392ab37..d8055ca 100644 --- a/component/outbound/transport/tls/tls.go +++ b/component/outbound/transport/tls/tls.go @@ -47,9 +47,25 @@ func NewTls(s string, d netproxy.Dialer) (*Tls, error) { return t, nil } +func (s *Tls) Dial(network, addr string) (c netproxy.Conn, err error) { + magicNetwork, err := netproxy.ParseMagicNetwork(network) + if err != nil { + return nil, err + } + switch magicNetwork.Network { + case "tcp": + return s.DialTcp(addr) + case "udp": + return s.DialUdp(addr) + default: + return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network) + } +} + func (s *Tls) DialUdp(addr string) (conn netproxy.PacketConn, err error) { return nil, fmt.Errorf("%w: tls+udp", netproxy.UnsupportedTunnelTypeError) } + func (s *Tls) DialTcp(addr string) (conn netproxy.Conn, err error) { rc, err := s.dialer.DialTcp(addr) if err != nil { diff --git a/component/outbound/transport/ws/ws.go b/component/outbound/transport/ws/ws.go index 3f2af0c..d29b057 100644 --- a/component/outbound/transport/ws/ws.go +++ b/component/outbound/transport/ws/ws.go @@ -67,6 +67,21 @@ func NewWs(s string, d netproxy.Dialer) (*Ws, error) { return t, nil } +func (s *Ws) Dial(network, addr string) (c netproxy.Conn, err error) { + magicNetwork, err := netproxy.ParseMagicNetwork(network) + if err != nil { + return nil, err + } + switch magicNetwork.Network { + case "tcp": + return s.DialTcp(addr) + case "udp": + return s.DialUdp(addr) + default: + return nil, fmt.Errorf("%w: %v", netproxy.UnsupportedTunnelTypeError, network) + } +} + func (s *Ws) DialUdp(addr string) (netproxy.PacketConn, error) { return nil, fmt.Errorf("%w: ws+udp", netproxy.UnsupportedTunnelTypeError) } diff --git a/component/routing/matcher_builder.go b/component/routing/matcher_builder.go index cacd199..b2bf102 100644 --- a/component/routing/matcher_builder.go +++ b/component/routing/matcher_builder.go @@ -12,6 +12,7 @@ import ( "github.com/v2rayA/dae/common/consts" "github.com/v2rayA/dae/pkg/config_parser" "net/netip" + "strconv" "strings" ) @@ -25,19 +26,24 @@ type DomainSet struct { Domains []string } +type Outbound struct { + Name string + Mark uint32 +} + type MatcherBuilder interface { - AddDomain(f *config_parser.Function, key string, values []string, outbound string) - AddIp(f *config_parser.Function, values []netip.Prefix, outbound string) - AddPort(f *config_parser.Function, values [][2]uint16, outbound string) - AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound string) - AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound string) - AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound string) - AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound string) - AddSourceMac(f *config_parser.Function, values [][6]byte, outbound string) - AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound string) - AddFallback(outbound string) - AddAnyBefore(f *config_parser.Function, key string, values []string, outbound string) - AddAnyAfter(f *config_parser.Function, key string, values []string, outbound string) + AddDomain(f *config_parser.Function, key string, values []string, outbound *Outbound) + AddIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) + AddPort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) + AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) + AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) + AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *Outbound) + AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *Outbound) + AddSourceMac(f *config_parser.Function, values [][6]byte, outbound *Outbound) + AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *Outbound) + AddFallback(outbound *Outbound) + AddAnyBefore(f *config_parser.Function, key string, values []string, outbound *Outbound) + AddAnyAfter(f *config_parser.Function, key string, values []string, outbound *Outbound) } func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) { @@ -72,9 +78,34 @@ func ToProcessName(processName string) (procName [consts.TaskCommLen]byte) { return procName } -func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*config_parser.RoutingRule, fallbackOutbound string) (err error) { +func parseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err error) { + outbound = &Outbound{ + Name: rawOutbound.Name, + Mark: 0, + } + for _, p := range rawOutbound.Params { + switch p.Key { + case consts.OutboundParam_Mark: + var _mark uint64 + _mark, err = strconv.ParseUint(p.Val, 0, 32) + if err != nil { + return nil, fmt.Errorf("failed to parse mark: %v", err) + } + outbound.Mark = uint32(_mark) + default: + return nil, fmt.Errorf("unknown outbound param: %v", p.Key) + } + } + return outbound, nil +} + +func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*config_parser.RoutingRule, fallbackOutbound interface{}) (err error) { for _, rule := range rules { log.Debugln("[rule]", rule.String(true)) + outbound, err := parseOutbound(&rule.Outbound) + if err != nil { + return err + } // rule is like: domain(domain:baidu.com) && port(443) -> proxy for iFunc, f := range rule.AndFunctions { @@ -82,12 +113,15 @@ func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*co paramValueGroups, keyOrder := GroupParamValuesByKey(f.Params) for jMatchSet, key := range keyOrder { paramValueGroup := paramValueGroups[key] - // Preprocess the outbound and pass FakeOutbound_AND to all but the last function. - outbound := FakeOutbound_OR + // Preprocess the outbound. + overrideOutbound := &Outbound{ + Name: FakeOutbound_OR, + Mark: outbound.Mark, + } if jMatchSet == len(keyOrder)-1 { - outbound = FakeOutbound_AND + overrideOutbound.Name = FakeOutbound_AND if iFunc == len(rule.AndFunctions)-1 { - outbound = rule.Outbound + overrideOutbound.Name = outbound.Name } } @@ -97,22 +131,22 @@ func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*co if f.Not { symNot = "!" } - log.Debugf("\t%v%v(%v) -> %v", symNot, f.Name, key, outbound) + log.Debugf("\t%v%v(%v) -> %v", symNot, f.Name, key, overrideOutbound) } - builder.AddAnyBefore(f, key, paramValueGroup, outbound) + builder.AddAnyBefore(f, key, paramValueGroup, overrideOutbound) switch f.Name { case consts.Function_Domain: - builder.AddDomain(f, key, paramValueGroup, outbound) + builder.AddDomain(f, key, paramValueGroup, overrideOutbound) case consts.Function_Ip, consts.Function_SourceIp: cidrs, err := ParsePrefixes(paramValueGroup) if err != nil { return err } if f.Name == consts.Function_Ip { - builder.AddIp(f, cidrs, outbound) + builder.AddIp(f, cidrs, overrideOutbound) } else { - builder.AddSourceIp(f, cidrs, outbound) + builder.AddSourceIp(f, cidrs, overrideOutbound) } case consts.Function_Mac: var macAddrs [][6]byte @@ -123,7 +157,7 @@ func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*co } macAddrs = append(macAddrs, mac) } - builder.AddSourceMac(f, macAddrs, outbound) + builder.AddSourceMac(f, macAddrs, overrideOutbound) case consts.Function_Port, consts.Function_SourcePort: var portRanges [][2]uint16 for _, v := range paramValueGroup { @@ -134,9 +168,9 @@ func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*co portRanges = append(portRanges, portRange) } if f.Name == consts.Function_Port { - builder.AddPort(f, portRanges, outbound) + builder.AddPort(f, portRanges, overrideOutbound) } else { - builder.AddSourcePort(f, portRanges, outbound) + builder.AddSourcePort(f, portRanges, overrideOutbound) } case consts.Function_L4Proto: var l4protoType consts.L4ProtoType @@ -148,7 +182,7 @@ func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*co l4protoType |= consts.L4ProtoType_UDP } } - builder.AddL4Proto(f, l4protoType, outbound) + builder.AddL4Proto(f, l4protoType, overrideOutbound) case consts.Function_IpVersion: var ipVersion consts.IpVersionType for _, v := range paramValueGroup { @@ -159,7 +193,7 @@ func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*co ipVersion |= consts.IpVersion_6 } } - builder.AddIpVersion(f, ipVersion, outbound) + builder.AddIpVersion(f, ipVersion, overrideOutbound) case consts.Function_ProcessName: var procNames [][consts.TaskCommLen]byte for _, v := range paramValueGroup { @@ -168,47 +202,60 @@ func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*co } procNames = append(procNames, ToProcessName(v)) } - builder.AddProcessName(f, procNames, outbound) + builder.AddProcessName(f, procNames, overrideOutbound) default: return fmt.Errorf("unsupported function name: %v", f.Name) } - builder.AddAnyAfter(f, key, paramValueGroup, outbound) + builder.AddAnyAfter(f, key, paramValueGroup, overrideOutbound) } } } + var rawFallback *config_parser.Function + switch fallback := fallbackOutbound.(type) { + case string: + rawFallback = &config_parser.Function{Name: fallback} + case *config_parser.Function: + rawFallback = fallback + default: + return fmt.Errorf("unknown type of 'fallback' in section routing: %T", fallback) + } + fallback, err := parseOutbound(rawFallback) + if err != nil { + return err + } builder.AddAnyBefore(&config_parser.Function{ Name: "fallback", - }, "", nil, fallbackOutbound) - builder.AddFallback(fallbackOutbound) + }, "", nil, fallback) + builder.AddFallback(fallback) builder.AddAnyAfter(&config_parser.Function{ Name: "fallback", - }, "", nil, fallbackOutbound) + }, "", nil, fallback) return nil } type DefaultMatcherBuilder struct { } -func (d *DefaultMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound string) { +func (d *DefaultMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound *Outbound) { } -func (d *DefaultMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound string) { +func (d *DefaultMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) { } -func (d *DefaultMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound string) { +func (d *DefaultMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) { } -func (d *DefaultMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound string) { +func (d *DefaultMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) { } -func (d *DefaultMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound string) { +func (d *DefaultMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) { } -func (d *DefaultMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound string) { +func (d *DefaultMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *Outbound) { } -func (d *DefaultMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound string) { +func (d *DefaultMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *Outbound) { } -func (d *DefaultMatcherBuilder) AddSourceMac(f *config_parser.Function, values [][6]byte, outbound string) { +func (d *DefaultMatcherBuilder) AddSourceMac(f *config_parser.Function, values [][6]byte, outbound *Outbound) { } -func (d *DefaultMatcherBuilder) AddFallback(outbound string) {} -func (d *DefaultMatcherBuilder) AddAnyBefore(f *config_parser.Function, key string, values []string, outbound string) { +func (d *DefaultMatcherBuilder) AddFallback(outbound *Outbound) {} +func (d *DefaultMatcherBuilder) AddAnyBefore(f *config_parser.Function, key string, values []string, outbound *Outbound) { } -func (d *DefaultMatcherBuilder) AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound string) { +func (d *DefaultMatcherBuilder) AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *Outbound) { } -func (d *DefaultMatcherBuilder) AddAnyAfter(f *config_parser.Function, key string, values []string, outbound string) { +func (d *DefaultMatcherBuilder) AddAnyAfter(f *config_parser.Function, key string, values []string, outbound *Outbound) { } diff --git a/component/routing/optimizer.go b/component/routing/optimizer.go index 90ef1fd..120a4be 100644 --- a/component/routing/optimizer.go +++ b/component/routing/optimizer.go @@ -81,7 +81,7 @@ func (o *MergeAndSortRulesOptimizer) Optimize(rules []*config_parser.RoutingRule if len(mergingRule.AndFunctions) == 1 && len(rules[i].AndFunctions) == 1 && mergingRule.AndFunctions[0].Name == rules[i].AndFunctions[0].Name && - rules[i].Outbound == mergingRule.Outbound { + rules[i].Outbound.String(true) == mergingRule.Outbound.String(true) { mergingRule.AndFunctions[0].Params = append(mergingRule.AndFunctions[0].Params, rules[i].AndFunctions[0].Params...) } else { newRules = append(newRules, mergingRule) diff --git a/config/config.go b/config/config.go index 901f68d..53697ad 100644 --- a/config/config.go +++ b/config/config.go @@ -42,8 +42,8 @@ type GroupParam struct { type Routing struct { Rules []*config_parser.RoutingRule `mapstructure:"_"` - Fallback string `mapstructure:"fallback"` - Final string `mapstructure:"final"` + Fallback interface{} `mapstructure:"fallback"` + Final interface{} `mapstructure:"final"` } type Params struct { diff --git a/config/patch.go b/config/patch.go index adb7c9f..a8d722f 100644 --- a/config/patch.go +++ b/config/patch.go @@ -5,7 +5,10 @@ package config -import "fmt" +import ( + "fmt" + "github.com/sirupsen/logrus" +) type patch func(params *Params) error @@ -15,11 +18,12 @@ var patches = []patch{ func patchRoutingFallback(params *Params) error { // We renamed final as fallback. So we apply this patch for compatibility with older users. - if params.Routing.Fallback == "" && params.Routing.Final != "" { + if params.Routing.Fallback == nil && params.Routing.Final != nil { params.Routing.Fallback = params.Routing.Final + logrus.Warnln("Name 'final' in section routing was deprecated and will be removed in the future; please rename it as 'fallback'") } // Fallback is required. - if params.Routing.Fallback == "" { + if params.Routing.Fallback == nil { return fmt.Errorf("fallback is required in routing") } return nil diff --git a/control/control_plane.go b/control/control_plane.go index 248d08c..66f8427 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -528,9 +528,9 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) { go func(data []byte, src netip.AddrPort) { defer pool.Put(data) var realDst netip.AddrPort - var outboundIndex consts.OutboundIndex + var routingResult *bpfRoutingResult pktDst := RetrieveOriginalDest(oob[:oobn]) - outboundIndex, err := c.core.RetrieveOutboundIndex(src, pktDst, unix.IPPROTO_UDP) + routingResult, err := c.core.RetrieveRoutingResult(src, pktDst, unix.IPPROTO_UDP) if err != nil { // WAN. Old method. addrHdr, dataOffset, err := ParseAddrHdr(data) @@ -539,13 +539,16 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) { return } copy(data, data[dataOffset:]) - outboundIndex = consts.OutboundIndex(addrHdr.Outbound) + routingResult = &bpfRoutingResult{ + Mark: addrHdr.Mark, + Outbound: addrHdr.Outbound, + } src = netip.AddrPortFrom(addrHdr.Dest.Addr(), src.Port()) realDst = addrHdr.Dest } else { realDst = pktDst } - if e := c.handlePkt(udpConn, data, src, pktDst, realDst, outboundIndex); e != nil { + if e := c.handlePkt(udpConn, data, src, pktDst, realDst, routingResult); e != nil { c.log.Warnln("handlePkt:", e) } }(newBuf, src) diff --git a/control/kern/tproxy.c b/control/kern/tproxy.c index 3a28bef..ef9b3aa 100644 --- a/control/kern/tproxy.c +++ b/control/kern/tproxy.c @@ -132,6 +132,7 @@ struct ip_port_outbound { __be16 port; __u8 outbound; __u8 unused; + __u32 mark; }; struct tuples { @@ -153,12 +154,18 @@ struct { __uint(max_entries, MAX_DST_MAPPING_NUM); /// NOTICE: It MUST be pinned. __uint(pinning, LIBBPF_PIN_BY_NAME); -} tcp_dst_map SEC(".maps"); +} tcp_dst_map + SEC(".maps"); // This map is only for old method (redirect mode in WAN). + +struct routing_result { + __u32 mark; + __u8 outbound; +}; struct { __uint(type, BPF_MAP_TYPE_LRU_HASH); __type(key, struct tuples); - __type(value, __u32); // outbound + __type(value, struct routing_result); // outbound __uint(max_entries, MAX_DST_MAPPING_NUM); /// NOTICE: It MUST be pinned. __uint(pinning, LIBBPF_PIN_BY_NAME); @@ -312,6 +319,8 @@ struct match_set { bool not ; // A subrule flag (this is not a match_set flag). enum MatchType type; __u8 outbound; // User-defined value range is [0, 252]. + __u8 unused; + __u32 mark; }; struct { __uint(type, BPF_MAP_TYPE_ARRAY); @@ -925,7 +934,8 @@ decap_after_udp_hdr(struct __sk_buff *skb, __u8 ipversion, __u8 ihl, } // Do not use __always_inline here because this function is too heavy. -static int __attribute__((noinline)) +// low -> high: outbound(8b) mark(32b) unused(23b) sign(1b) +static __s64 __attribute__((noinline)) routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], const __be32 _daddr[4], const __be32 mac[4]) { #define _l4proto_type flag[0] @@ -1153,9 +1163,9 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], if (match_set->outbound == OUTBOUND_DIRECT && h_dport == 53 && _l4proto_type == L4ProtoType_UDP) { // DNS packet should go through control plane. - return OUTBOUND_CONTROL_PLANE_DIRECT; + return OUTBOUND_CONTROL_PLANE_DIRECT | (match_set->mark << 8); } - return match_set->outbound; + return match_set->outbound | (match_set->mark << 8); } bad_rule = false; } @@ -1334,7 +1344,15 @@ new_connection: bpf_printk("shot routing: %d", ret); return TC_ACT_SHOT; } - __u32 outbound = ret; + struct routing_result routing_result = {0}; + routing_result.outbound = ret; + routing_result.mark = ret >> 8; + // Save routing result. + if ((ret = bpf_map_update_elem(&routing_tuples_map, &tuples, &routing_result, + BPF_ANY))) { + bpf_printk("shot save routing result: %d", ret); + return TC_ACT_SHOT; + } #if defined(__DEBUG_ROUTING) || defined(__PRINT_ROUTING_RESULT) if (l4proto == IPPROTO_TCP) { bpf_printk("tcp(lan): outbound: %u, target: %pI6:%u", outbound, @@ -1344,21 +1362,24 @@ new_connection: tuples.dip.u6_addr32, bpf_ntohs(tuples.dport)); } #endif - if (outbound == OUTBOUND_DIRECT || outbound == OUTBOUND_MUST_DIRECT) { + if (routing_result.outbound == OUTBOUND_DIRECT || + routing_result.outbound == OUTBOUND_MUST_DIRECT) { __u32 *nat; if ((nat = bpf_map_lookup_elem(¶m_map, &control_plane_nat_direct_key)) && *nat) { + // Do not mark if packet is sent to control_plane. goto control_plane_tproxy; } + skb->mark = routing_result.mark; goto direct; - } else if (unlikely(outbound == OUTBOUND_BLOCK)) { + } else if (unlikely(routing_result.outbound == OUTBOUND_BLOCK)) { goto block; } // Check outbound connectivity in specific ipversion and l4proto. struct outbound_connectivity_query q = {0}; - q.outbound = outbound; + q.outbound = routing_result.outbound; q.ipversion = ipversion; q.l4proto = l4proto; __u32 *alive; @@ -1370,12 +1391,6 @@ new_connection: } control_plane_tproxy: - // Save routing result. - if ((ret = bpf_map_update_elem(&routing_tuples_map, &tuples, &outbound, - BPF_ANY))) { - bpf_printk("shot save routing result: %d", ret); - return TC_ACT_SHOT; - } // Assign to control plane. @@ -1581,6 +1596,7 @@ int tproxy_wan_egress(struct __sk_buff *skb) { __builtin_memcpy(&key_src.ip, &tuples.dip, IPV6_BYTE_LENGTH); key_src.port = tcph.source; __u8 outbound; + __u32 mark; if (unlikely(tcp_state_syn)) { // New TCP connection. // bpf_printk("[%X]New Connection", bpf_ntohl(tcph.seq)); @@ -1612,6 +1628,7 @@ int tproxy_wan_egress(struct __sk_buff *skb) { } outbound = ret; + mark = ret >> 8; #if defined(__DEBUG_ROUTING) || defined(__PRINT_ROUTING_RESULT) // Print only new connection. @@ -1631,9 +1648,13 @@ int tproxy_wan_egress(struct __sk_buff *skb) { return TC_ACT_OK; } outbound = dst->outbound; + mark = dst->mark; } - if (outbound == OUTBOUND_DIRECT || outbound == OUTBOUND_MUST_DIRECT) { + if ((outbound == OUTBOUND_DIRECT || outbound == OUTBOUND_MUST_DIRECT) && + mark == 0 // If mark is not zero, we should re-route it, so we send it + // to control plane in WAN. + ) { return TC_ACT_OK; } else if (unlikely(outbound == OUTBOUND_BLOCK)) { return TC_ACT_SHOT; @@ -1659,6 +1680,7 @@ int tproxy_wan_egress(struct __sk_buff *skb) { __builtin_memcpy(value_dst.ip, &tuples.dip, IPV6_BYTE_LENGTH); value_dst.port = tcph.dest; value_dst.outbound = outbound; + value_dst.mark = mark; // bpf_printk("UPDATE: %pI6:%u", key_src.ip.u6_addr32, // bpf_ntohs(key_src.port)); bpf_map_update_elem(&tcp_dst_map, &key_src, &value_dst, BPF_ANY); @@ -1711,15 +1733,20 @@ int tproxy_wan_egress(struct __sk_buff *skb) { return TC_ACT_SHOT; } new_hdr.outbound = ret; + new_hdr.mark = ret >> 8; #if defined(__DEBUG_ROUTING) || defined(__PRINT_ROUTING_RESULT) - __u32 pid = pid_pname ? pid_pname->pid : 0; + __u32 pid = pid_pname ? pid_pname->pid : 0; bpf_printk("udp(wan): from %pI6:%u [PID %u]", tuples.sip.u6_addr32, bpf_ntohs(tuples.sport), pid); bpf_printk("udp(wan): outbound: %u, %pI6:%u", new_hdr.outbound, tuples.dip.u6_addr32, bpf_ntohs(tuples.dport)); #endif - if (new_hdr.outbound == OUTBOUND_DIRECT || new_hdr.outbound == OUTBOUND_MUST_DIRECT) { + if ((new_hdr.outbound == OUTBOUND_DIRECT || + new_hdr.outbound == OUTBOUND_MUST_DIRECT) && + new_hdr.mark == 0 // If mark is not zero, we should re-route it, so we + // send it to control plane in WAN. + ) { return TC_ACT_OK; } else if (unlikely(new_hdr.outbound == OUTBOUND_BLOCK)) { return TC_ACT_SHOT; diff --git a/control/routing_matcher_builder.go b/control/routing_matcher_builder.go index 5e2ef57..ef1f499 100644 --- a/control/routing_matcher_builder.go +++ b/control/routing_matcher_builder.go @@ -52,7 +52,7 @@ func (b *RoutingMatcherBuilder) OutboundToId(outbound string) uint8 { return outboundId } -func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound string) { +func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound *routing.Outbound) { if b.err != nil { return } @@ -73,11 +73,12 @@ func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string, b.rules = append(b.rules, bpfMatchSet{ Type: uint8(consts.MatchType_DomainSet), Not: f.Not, - Outbound: b.OutboundToId(outbound), + Outbound: b.OutboundToId(outbound.Name), + Mark: outbound.Mark, }) } -func (b *RoutingMatcherBuilder) AddSourceMac(f *config_parser.Function, macAddrs [][6]byte, outbound string) { +func (b *RoutingMatcherBuilder) AddSourceMac(f *config_parser.Function, macAddrs [][6]byte, outbound *routing.Outbound) { if b.err != nil { return } @@ -94,14 +95,15 @@ func (b *RoutingMatcherBuilder) AddSourceMac(f *config_parser.Function, macAddrs Value: [16]byte{}, Type: uint8(consts.MatchType_Mac), Not: f.Not, - Outbound: b.OutboundToId(outbound), + Outbound: b.OutboundToId(outbound.Name), + Mark: outbound.Mark, } binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex)) b.rules = append(b.rules, set) } -func (b *RoutingMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound string) { +func (b *RoutingMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) { if b.err != nil { return } @@ -111,17 +113,18 @@ func (b *RoutingMatcherBuilder) AddIp(f *config_parser.Function, values []netip. Value: [16]byte{}, Type: uint8(consts.MatchType_IpSet), Not: f.Not, - Outbound: b.OutboundToId(outbound), + Outbound: b.OutboundToId(outbound.Name), + Mark: outbound.Mark, } binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex)) b.rules = append(b.rules, set) } -func (b *RoutingMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, _outbound string) { +func (b *RoutingMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) { for i, value := range values { - outbound := routing.FakeOutbound_OR + outboundName := routing.FakeOutbound_OR if i == len(values)-1 { - outbound = _outbound + outboundName = outbound.Name } b.rules = append(b.rules, bpfMatchSet{ Type: uint8(consts.MatchType_Port), @@ -130,12 +133,13 @@ func (b *RoutingMatcherBuilder) AddPort(f *config_parser.Function, values [][2]u PortEnd: value[1], }.Encode(), Not: f.Not, - Outbound: b.OutboundToId(outbound), + Outbound: b.OutboundToId(outboundName), + Mark: outbound.Mark, }) } } -func (b *RoutingMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound string) { +func (b *RoutingMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) { if b.err != nil { return } @@ -145,17 +149,18 @@ func (b *RoutingMatcherBuilder) AddSourceIp(f *config_parser.Function, values [] Value: [16]byte{}, Type: uint8(consts.MatchType_SourceIpSet), Not: f.Not, - Outbound: b.OutboundToId(outbound), + Outbound: b.OutboundToId(outbound.Name), + Mark: outbound.Mark, } binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex)) b.rules = append(b.rules, set) } -func (b *RoutingMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, _outbound string) { +func (b *RoutingMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) { for i, value := range values { - outbound := routing.FakeOutbound_OR + outboundName := routing.FakeOutbound_OR if i == len(values)-1 { - outbound = _outbound + outboundName = outbound.Name } b.rules = append(b.rules, bpfMatchSet{ Type: uint8(consts.MatchType_SourcePort), @@ -164,12 +169,13 @@ func (b *RoutingMatcherBuilder) AddSourcePort(f *config_parser.Function, values PortEnd: value[1], }.Encode(), Not: f.Not, - Outbound: b.OutboundToId(outbound), + Outbound: b.OutboundToId(outboundName), + Mark: outbound.Mark, }) } } -func (b *RoutingMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound string) { +func (b *RoutingMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *routing.Outbound) { if b.err != nil { return } @@ -177,11 +183,12 @@ func (b *RoutingMatcherBuilder) AddL4Proto(f *config_parser.Function, values con Value: [16]byte{byte(values)}, Type: uint8(consts.MatchType_L4Proto), Not: f.Not, - Outbound: b.OutboundToId(outbound), + Outbound: b.OutboundToId(outbound.Name), + Mark: outbound.Mark, }) } -func (b *RoutingMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound string) { +func (b *RoutingMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *routing.Outbound) { if b.err != nil { return } @@ -189,33 +196,36 @@ func (b *RoutingMatcherBuilder) AddIpVersion(f *config_parser.Function, values c Value: [16]byte{byte(values)}, Type: uint8(consts.MatchType_IpVersion), Not: f.Not, - Outbound: b.OutboundToId(outbound), + Outbound: b.OutboundToId(outbound.Name), + Mark: outbound.Mark, }) } -func (b *RoutingMatcherBuilder) AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, _outbound string) { +func (b *RoutingMatcherBuilder) AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *routing.Outbound) { for i, value := range values { - outbound := routing.FakeOutbound_OR + outboundName := routing.FakeOutbound_OR if i == len(values)-1 { - outbound = _outbound + outboundName = outbound.Name } matchSet := bpfMatchSet{ Type: uint8(consts.MatchType_ProcessName), Not: f.Not, - Outbound: b.OutboundToId(outbound), + Outbound: b.OutboundToId(outboundName), + Mark: outbound.Mark, } copy(matchSet.Value[:], value[:]) b.rules = append(b.rules, matchSet) } } -func (b *RoutingMatcherBuilder) AddFallback(outbound string) { +func (b *RoutingMatcherBuilder) AddFallback(outbound *routing.Outbound) { if b.err != nil { return } b.rules = append(b.rules, bpfMatchSet{ Type: uint8(consts.MatchType_Fallback), - Outbound: b.OutboundToId(outbound), + Outbound: b.OutboundToId(outbound.Name), + Mark: outbound.Mark, }) } diff --git a/control/tcp.go b/control/tcp.go index 50d173e..50ca6dc 100644 --- a/control/tcp.go +++ b/control/tcp.go @@ -41,7 +41,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) { // Get tuples and outbound. src := lConn.RemoteAddr().(*net.TCPAddr).AddrPort() dst := lConn.LocalAddr().(*net.TCPAddr).AddrPort() - outboundIndex, err := c.core.RetrieveOutboundIndex(src, dst, unix.IPPROTO_TCP) + routingResult, err := c.core.RetrieveRoutingResult(src, dst, unix.IPPROTO_TCP) if err != nil { // WAN. Old method. var value bpfIpPortOutbound @@ -52,7 +52,10 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) { }, &value); e != nil { return fmt.Errorf("failed to retrieve target info %v: %v, %v", src.String(), err, e) } - outboundIndex = consts.OutboundIndex(value.Outbound) + routingResult = &bpfRoutingResult{ + Mark: value.Mark, + Outbound: value.Outbound, + } dstAddr, ok := netip.AddrFromSlice(common.Ipv6Uint32ArrayToByteSlice(value.Ip)) if !ok { @@ -60,6 +63,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) { } dst = netip.AddrPortFrom(dstAddr, internal.Htons(value.Port)) } + var outboundIndex = consts.OutboundIndex(routingResult.Outbound) switch outboundIndex { case consts.OutboundDirect: @@ -102,7 +106,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) { // Dial and relay. dst = netip.AddrPortFrom(common.ConvergeIp(dst.Addr()), dst.Port()) - rConn, err := d.DialTcp(c.ChooseDialTarget(outboundIndex, dst, domain)) + rConn, err := d.Dial(GetNetwork("tcp", routingResult.Mark), c.ChooseDialTarget(outboundIndex, dst, domain)) if err != nil { return fmt.Errorf("failed to dial %v: %w", dst, err) } diff --git a/control/udp.go b/control/udp.go index 1430558..8f77d01 100644 --- a/control/udp.go +++ b/control/udp.go @@ -47,6 +47,7 @@ func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Message, tim type AddrHdr struct { Dest netip.AddrPort Outbound uint8 + Mark uint32 } func ParseAddrHdr(data []byte) (hdr *AddrHdr, dataOffset int, err error) { @@ -58,9 +59,11 @@ func ParseAddrHdr(data []byte) (hdr *AddrHdr, dataOffset int, err error) { destAddr, _ := netip.AddrFromSlice(data[:ipSize]) port := binary.BigEndian.Uint16(data[ipSize:]) outbound := data[ipSize+2] + mark := binary.BigEndian.Uint32(data[ipSize+4:]) return &AddrHdr{ Dest: netip.AddrPortFrom(destAddr, port), Outbound: outbound, + Mark: mark, }, dataOffset, nil } @@ -71,12 +74,14 @@ func (hdr *AddrHdr) ToBytesFromPool() []byte { copy(buf, ip[:]) binary.BigEndian.PutUint16(buf[ipSize:], hdr.Dest.Port()) buf[ipSize+2] = hdr.Outbound + binary.BigEndian.PutUint32(buf[ipSize+4:], hdr.Mark) return buf } -func sendPktWithHdrWithFlag(data []byte, from netip.AddrPort, lConn *net.UDPConn, to netip.AddrPort, lanWanFlag consts.LanWanFlag) error { +func sendPktWithHdrWithFlag(data []byte, mark uint32, from netip.AddrPort, lConn *net.UDPConn, to netip.AddrPort, lanWanFlag consts.LanWanFlag) error { hdr := AddrHdr{ Dest: from, + Mark: mark, Outbound: uint8(lanWanFlag), // Pass some message to the kernel program. } bHdr := hdr.ToBytesFromPool() @@ -100,7 +105,7 @@ func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn if err != nil { if errors.Is(err, syscall.EADDRINUSE) { // Port collision, use traditional method. - return sendPktWithHdrWithFlag(data, from, lConn, to, lanWanFlag) + return sendPktWithHdrWithFlag(data, 0, from, lConn, to, lanWanFlag) } return err } @@ -140,7 +145,7 @@ func (c *ControlPlane) WriteToUDP(lanWanFlag consts.LanWanFlag, lConn *net.UDPCo } } -func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, realDst netip.AddrPort, outboundIndex consts.OutboundIndex) (err error) { +func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, realDst netip.AddrPort, routingResult *bpfRoutingResult) (err error) { var lanWanFlag consts.LanWanFlag var realSrc netip.AddrPort var domain string @@ -155,6 +160,7 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r } mustDirect := false + outboundIndex := consts.OutboundIndex(routingResult.Outbound) switch outboundIndex { case consts.OutboundDirect: case consts.OutboundMustDirect: @@ -326,13 +332,13 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r if retry > MaxRetry { return fmt.Errorf("touch max retry limit") } + ue, isNew, err = DefaultUdpEndpointPool.GetOrCreate(realSrc, &UdpEndpointOptions{ Handler: udpHandler, NatTimeout: natTimeout, - DialerFunc: func() (*dialer.Dialer, error) { - return dialerForNew, nil - }, - Target: tgtToSend, + Dialer: dialerForNew, + Network: GetNetwork("udp", routingResult.Mark), + Target: tgtToSend, }) if err != nil { return fmt.Errorf("failed to GetOrCreate (policy: %v): %w", outbound.GetSelectionPolicy(), err) @@ -382,7 +388,7 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r // We can block because we are in a coroutine. - conn, err := dialerForNew.DialTcp(tgtToSend) + conn, err := dialerForNew.Dial(GetNetwork("tcp", routingResult.Mark), tgtToSend) if err != nil { return fmt.Errorf("failed to dial proxy to tcp: %w", err) } diff --git a/control/udp_endpoint.go b/control/udp_endpoint.go index 2ff83cd..95ede36 100644 --- a/control/udp_endpoint.go +++ b/control/udp_endpoint.go @@ -22,7 +22,7 @@ type UdpEndpoint struct { conn netproxy.PacketConn // mu protects deadlineTimer mu sync.Mutex - deadlineTimer *time.Timer // nil means UdpEndpoint was closed + deadlineTimer *time.Timer handler UdpHandler NatTimeout time.Duration @@ -48,7 +48,7 @@ func (ue *UdpEndpoint) start() { } } ue.mu.Lock() - ue.Close() + ue.deadlineTimer.Stop() ue.mu.Unlock() } @@ -56,15 +56,13 @@ func (ue *UdpEndpoint) WriteTo(b []byte, addr string) (int, error) { return ue.conn.WriteTo(b, addr) } -func (ue *UdpEndpoint) Close() (err error) { +func (ue *UdpEndpoint) Close() error { ue.mu.Lock() if ue.deadlineTimer != nil { - err = ue.conn.Close() ue.deadlineTimer.Stop() - ue.deadlineTimer = nil } ue.mu.Unlock() - return err + return ue.conn.Close() } // UdpEndpointPool is a full-cone udp conn pool @@ -75,7 +73,9 @@ type UdpEndpointPool struct { type UdpEndpointOptions struct { Handler UdpHandler NatTimeout time.Duration - DialerFunc func() (*dialer.Dialer, error) + Dialer *dialer.Dialer + // Network is useful for MagicNetwork + Network string // Target is useful only if the underlay does not support Full-cone. Target string } @@ -118,12 +118,7 @@ func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEnd return nil, true, fmt.Errorf("createOption.Handler cannot be nil") } - d, err := createOption.DialerFunc() - if err != nil { - return nil, true, err - } - - udpConn, err := d.DialUdp(createOption.Target) + udpConn, err := createOption.Dialer.Dial(createOption.Network, createOption.Target) if err != nil { return nil, true, err } @@ -142,7 +137,7 @@ func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEnd }), handler: createOption.Handler, NatTimeout: createOption.NatTimeout, - Dialer: d, + Dialer: createOption.Dialer, } p.pool[lAddr] = ue // Receive UDP messages. @@ -151,9 +146,7 @@ func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEnd } else { // Postpone the deadline. ue.mu.Lock() - if ue.deadlineTimer != nil { - ue.deadlineTimer.Reset(ue.NatTimeout) - } + ue.deadlineTimer.Reset(ue.NatTimeout) ue.mu.Unlock() } return ue, isNew, nil diff --git a/control/tproxy_utils.go b/control/utils.go similarity index 71% rename from control/tproxy_utils.go rename to control/utils.go index e691c56..1c6e643 100644 --- a/control/tproxy_utils.go +++ b/control/utils.go @@ -9,6 +9,7 @@ import ( "bytes" "encoding/binary" "fmt" + "github.com/mzz2017/softwind/netproxy" "github.com/v2rayA/dae/common/consts" internal "github.com/v2rayA/dae/pkg/ebpf_internal" "golang.org/x/sys/unix" @@ -17,7 +18,7 @@ import ( "syscall" ) -func (c *ControlPlaneCore) RetrieveOutboundIndex(src, dst netip.AddrPort, l4proto uint8) (outboundIndex consts.OutboundIndex, err error) { +func (c *ControlPlaneCore) RetrieveRoutingResult(src, dst netip.AddrPort, l4proto uint8) (result *bpfRoutingResult, err error) { srcIp6 := src.Addr().As16() dstIp6 := dst.Addr().As16() @@ -29,14 +30,11 @@ func (c *ControlPlaneCore) RetrieveOutboundIndex(src, dst netip.AddrPort, l4prot L4proto: l4proto, } - var _outboundIndex uint32 - if err := c.bpf.RoutingTuplesMap.Lookup(tuples, &_outboundIndex); err != nil { - return 0, fmt.Errorf("reading map: key [%v, %v, %v]: %w", src.String(), l4proto, dst.String(), err) + var routingResult bpfRoutingResult + if err := c.bpf.RoutingTuplesMap.Lookup(tuples, &routingResult); err != nil { + return nil, fmt.Errorf("reading map: key [%v, %v, %v]: %w", src.String(), l4proto, dst.String(), err) } - if _outboundIndex > uint32(consts.OutboundMax) { - return 0, fmt.Errorf("bad outbound index") - } - return consts.OutboundIndex(_outboundIndex), nil + return &routingResult, nil } func RetrieveOriginalDest(oob []byte) netip.AddrPort { @@ -67,7 +65,7 @@ func checkIpforward(ifname string, ipversion consts.IpVersionStr) error { if bytes.Equal(bytes.TrimSpace(b), []byte("1")) { return nil } - return fmt.Errorf("ipforward on %v is off: %v", ifname, path) + return fmt.Errorf("ipforward on %v is off: %v; see https://github.com/v2rayA/dae#enable-ip-forwarding", ifname, path) } func CheckIpforward(ifname string) error { @@ -79,3 +77,14 @@ func CheckIpforward(ifname string) error { } return nil } + +func GetNetwork(network string, mark uint32) string { + if mark == 0 { + return network + } else { + return netproxy.MagicNetwork{ + Network: network, + Mark: mark, + }.Encode() + } +} diff --git a/go.mod b/go.mod index ebd8c41..6f45e76 100644 --- a/go.mod +++ b/go.mod @@ -10,12 +10,12 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/json-iterator/go v1.1.12 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 - github.com/mzz2017/softwind v0.0.0-20230217170818-542cba31602f + github.com/mzz2017/softwind v0.0.0-20230220064728-6940dc11777f github.com/safchain/ethtool v0.0.0-20230116090318-67cc41908669 github.com/sirupsen/logrus v1.9.0 github.com/spf13/cobra v1.6.1 github.com/v2rayA/ahocorasick-domain v0.0.0-20230218160829-122a074c48c8 - github.com/v2rayA/dae-config-dist/go/dae_config v0.0.0-20230201041341-1758ee5161c1 + github.com/v2rayA/dae-config-dist/go/dae_config v0.0.0-20230219173344-413f12027632 github.com/vishvananda/netlink v1.1.0 github.com/x-cray/logrus-prefixed-formatter v0.5.2 golang.org/x/crypto v0.5.0 diff --git a/go.sum b/go.sum index 55ea0a5..f1232cf 100644 --- a/go.sum +++ b/go.sum @@ -68,8 +68,8 @@ github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/mzz2017/disk-bloom v1.0.1 h1:rEF9MiXd9qMW3ibRpqcerLXULoTgRlM21yqqJl1B90M= github.com/mzz2017/disk-bloom v1.0.1/go.mod h1:JLHETtUu44Z6iBmsqzkOtFlRvXSlKnxjwiBRDapizDI= -github.com/mzz2017/softwind v0.0.0-20230217170818-542cba31602f h1:zxc1LkGfczEwEdnOlaGbIhyVsoL9dWHEL2WQ4pPgC0c= -github.com/mzz2017/softwind v0.0.0-20230217170818-542cba31602f/go.mod h1:V8GFOtdpTgzCJtCVXRqjmdDsY+PIhCCx4JpD0zq8Z7I= +github.com/mzz2017/softwind v0.0.0-20230220064728-6940dc11777f h1:Lmwy7FFI0PrWw0TgoQYtDiZBlCd/VZ1hBlySauTVWj4= +github.com/mzz2017/softwind v0.0.0-20230220064728-6940dc11777f/go.mod h1:V8GFOtdpTgzCJtCVXRqjmdDsY+PIhCCx4JpD0zq8Z7I= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= @@ -109,8 +109,8 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/v2rayA/ahocorasick-domain v0.0.0-20230218160829-122a074c48c8 h1:2Liq3JvM/acVQZ7Gq9U5PpznMzlFRPYMPQxC2yXSi74= github.com/v2rayA/ahocorasick-domain v0.0.0-20230218160829-122a074c48c8/go.mod h1:mWch8I826zic/bKaCyE9ZZbWtFgEW0ox3EQ0NGm5DGw= -github.com/v2rayA/dae-config-dist/go/dae_config v0.0.0-20230201041341-1758ee5161c1 h1:Ke91ZtZItOO8/SK8nhZ1tXfXcUxj4Meq5pET/L9bHII= -github.com/v2rayA/dae-config-dist/go/dae_config v0.0.0-20230201041341-1758ee5161c1/go.mod h1:JiTWeZybOkBfCqv/fy5jbFhXTxuLlyrI76gRNazz2sU= +github.com/v2rayA/dae-config-dist/go/dae_config v0.0.0-20230219173344-413f12027632 h1:MJ6+M3MpiVMdiZn3An88ZFNeLcLiN7hTaPsd2bVyduI= +github.com/v2rayA/dae-config-dist/go/dae_config v0.0.0-20230219173344-413f12027632/go.mod h1:JiTWeZybOkBfCqv/fy5jbFhXTxuLlyrI76gRNazz2sU= github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= diff --git a/pkg/config_parser/section.go b/pkg/config_parser/section.go index 85c6e2a..736a24d 100644 --- a/pkg/config_parser/section.go +++ b/pkg/config_parser/section.go @@ -178,7 +178,7 @@ func (p *paramAndFunctions) String(compact bool) string { type RoutingRule struct { AndFunctions []*Function - Outbound string + Outbound Function } func (r *RoutingRule) String(calcN bool) string { @@ -206,6 +206,6 @@ func (r *RoutingRule) String(calcN bool) string { } builder.WriteString(fmt.Sprintf("%v%v(%v)", symNot, f.Name, paramBuilder.String())) } - builder.WriteString(" -> " + r.Outbound) + builder.WriteString(" -> " + r.Outbound.String(true)) return builder.String() } diff --git a/pkg/config_parser/walker.go b/pkg/config_parser/walker.go index af58239..0d2da83 100644 --- a/pkg/config_parser/walker.go +++ b/pkg/config_parser/walker.go @@ -3,6 +3,8 @@ * Copyright (c) 2022-2023, v2rayA Organization */ +// This file should trace https://github.com/v2rayA/dae-config-dist/blob/main/dae_config.g4. + package config_parser import ( @@ -209,10 +211,18 @@ func (w *Walker) parseRoutingRule(ctx dae_config.IRoutingRuleContext) *RoutingRu andFunctions := w.parseFunctionPrototypeExpression(functionList, nil) // Parse outbound. - outbound := children[2].(*dae_config.Bare_literalContext).GetText() + outboundExpr := children[2].(*dae_config.OutboundExprContext) + var outbound *Function + if literal := outboundExpr.Bare_literal(); literal != nil { + outbound = &Function{Name: literal.GetText()} + } else if f := outboundExpr.FunctionPrototype(); f != nil { + outbound = w.parseFunctionPrototype(f.(*dae_config.FunctionPrototypeContext), nil) + } else { + panic("unknown outboundExpr") + } return &RoutingRule{ AndFunctions: andFunctions, - Outbound: outbound, + Outbound: *outbound, } } diff --git a/routing.md b/routing.md index 4fd922c..312c318 100644 --- a/routing.md +++ b/routing.md @@ -3,12 +3,15 @@ ## Examples: ```shell -# Built-in outbounds: block, direct +### Built-in outbounds: block, direct, must_direct +# The difference between "direct" and "must_direct" is that "direct" will intercept and process DNS request (for traffic +# split use), but "must_direct" will not. "must_direct" is useful when there are traffic loops of DNS requests. +### fallback outbound # If no rule matches, traffic will go through the outbound defined by fallback. fallback: my_group -# Domain rule +### Domain rule domain(suffix: v2raya.org) -> my_group # equals to domain(v2raya.org) -> my_group domain(full: dns.google) -> my_group @@ -17,63 +20,80 @@ domain(regexp: '\.goo.*\.com$') -> my_group domain(geosite:category-ads) -> block domain(geosite:cn)->direct -# Dest IP rule +### Dest IP rule ip(8.8.8.8) -> direct ip(101.97.0.0/16) -> direct ip(geoip:private) -> direct -# Source IP rule +### Source IP rule sip(192.168.0.0/24) -> my_group sip(192.168.50.0/24) -> direct -# Dest port rule +### Dest port rule port(80) -> direct port(10080-30000) -> direct -# Source port rule +### Source port rule sport(38563) -> direct sport(10080-30000) -> direct -# Level 4 protocol rule: +### Level 4 protocol rule: l4proto(tcp) -> my_group l4proto(udp) -> direct -# IP version rule: +### IP version rule: ipversion(4) -> block ipversion(6) -> ipv6_group -# Source MAC rule +### Source MAC rule mac('02:42:ac:11:00:02') -> direct -# Process Name rule (only support local process) +### Process Name rule (only support localhost process when binding to WAN) pname(curl) -> direct -# Multiple domains rule +### Multiple domains rule domain(keyword: google, suffix: www.twitter.com, suffix: v2raya.org) -> my_group -# Multiple IP rule +### Multiple IP rule ip(geoip:cn, geoip:private) -> direct ip(9.9.9.9, 223.5.5.5) -> direct sip(192.168.0.6, 192.168.0.10, 192.168.0.15) -> direct -# 'And' rule +### 'And' rule ip(geoip:cn) && port(80) -> direct ip(8.8.8.8) && l4proto(tcp) && port(1-1023, 8443) -> my_group ip(1.1.1.1) && sip(10.0.0.1, 172.20.0.0/16) -> direct -# 'Not' rule +### 'Not' rule !domain(geosite:google-scholar, geosite:category-scholar-!cn, geosite:category-scholar-cn ) -> my_group -# Little more complex rule +### Little more complex rule domain(geosite:geolocation-!cn) && !domain(geosite:google-scholar, geosite:category-scholar-!cn, geosite:category-scholar-cn ) -> my_group -# Customized DAT file +### Customized DAT file domain(ext:"yourdatfile.dat:yourtag")->direct ip(ext:"yourdatfile.dat:yourtag")->direct + +### Mark for direct/must_direct outbound +# Mark is useful when you want to redirect traffic to specific interface (such as wireguard) or other advanced uses. +# Traffic from LAN will not be forwarded by dae to archive higher performance if lan_nat_direct is off (you can set it +# off only if you are sure dae is on a bridge device). + +# An example of redirecting Disney traffic to wg0 is given here. +# You need set ip rule and ip table like this: +# 1. Set all traffic with mark 0x800/0x800 to use route table 1145: +# >> ip rule add fwmark 0x800/0x800 table 1145 +# >> ip -6 rule add fwmark 0x800/0x800 table 1145 +# 2. Set default route of route table 1145: +# >> ip route add default dev wg0 scope global table 1145 +# >> ip -6 route add default dev wg0 scope global table 1145 +# Notice that interface wg0, mark 0x800, table 1145 can be set by preferences, but cannot conflict. +# 3. Set routing rules in dae config file. +domain(geosite:disney) -> direct(mark: 0x800) ```