diff --git a/common/consts/ebpf.go b/common/consts/ebpf.go index 82f82b8..41ed6be 100644 --- a/common/consts/ebpf.go +++ b/common/consts/ebpf.go @@ -53,6 +53,7 @@ const ( MatchType_Mac MatchType_ProcessName MatchType_Fallback + MatchType_MustRules MatchType_Upstream MatchType_QType @@ -64,17 +65,19 @@ const ( OutboundDirect OutboundIndex = iota OutboundBlock + OutboundMustRules OutboundIndex = 0xFC OutboundControlPlaneRouting OutboundIndex = 0xFD OutboundLogicalOr OutboundIndex = 0xFE OutboundLogicalAnd OutboundIndex = 0xFF OutboundLogicalMask OutboundIndex = 0xFE - OutboundMax = OutboundLogicalAnd - OutboundUserDefinedMax = OutboundControlPlaneRouting - 1 + OutboundUserDefinedMax = OutboundMustRules - 1 ) func (i OutboundIndex) String() string { switch i { + case OutboundMustRules: + return "must_rules" case OutboundDirect: return "direct" case OutboundBlock: diff --git a/config/patch.go b/config/patch.go index 215600d..438f287 100644 --- a/config/patch.go +++ b/config/patch.go @@ -31,6 +31,10 @@ func patchEmptyDns(params *Config) error { func patchMustOutbound(params *Config) error { for i := range params.Routing.Rules { if strings.HasPrefix(params.Routing.Rules[i].Outbound.Name, "must_") { + if params.Routing.Rules[i].Outbound.Name == "must_rules" { + // Reserve must_rules. + continue + } params.Routing.Rules[i].Outbound.Name = strings.TrimPrefix(params.Routing.Rules[i].Outbound.Name, "must_") params.Routing.Rules[i].Outbound.Params = append(params.Routing.Rules[i].Outbound.Params, &config_parser.Param{ Val: "must", diff --git a/control/control_plane.go b/control/control_plane.go index ded2eac..e94eed8 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -753,7 +753,7 @@ func (c *ControlPlane) chooseBestDnsDialer( default: return nil, fmt.Errorf("unexpected ipversion: %v", ver) } - outboundIndex, mark, err := c.Route(req.realSrc, netip.AddrPortFrom(dAddr, dnsUpstream.Port), "", proto.ToL4ProtoType(), req.routingResult) + outboundIndex, mark, _, err := c.Route(req.realSrc, netip.AddrPortFrom(dAddr, dnsUpstream.Port), "", proto.ToL4ProtoType(), req.routingResult) if err != nil { return nil, err } diff --git a/control/kern/tproxy.c b/control/kern/tproxy.c index 8da8412..daf7941 100644 --- a/control/kern/tproxy.c +++ b/control/kern/tproxy.c @@ -60,6 +60,7 @@ #define OUTBOUND_DIRECT 0 #define OUTBOUND_BLOCK 1 +#define OUTBOUND_MUST_RULES 0xFC #define OUTBOUND_CONTROL_PLANE_ROUTING 0xFD #define OUTBOUND_LOGICAL_OR 0xFE #define OUTBOUND_LOGICAL_AND 0xFF @@ -1000,8 +1001,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], // Rule is like: domain(suffix:baidu.com, suffix:google.com) && port(443) -> // proxy Subrule is like: domain(suffix:baidu.com, suffix:google.com) Match // set is like: suffix:baidu.com - bool bad_rule = false; - bool good_subrule = false; + __u8 must_goodsubrule_badrule = 0; struct domain_routing *domain_routing; __u32 *p_u32; __u16 *p_u16; @@ -1015,7 +1015,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], if (unlikely(!match_set)) { return -EFAULT; } - if (bad_rule || good_subrule) { + if ((must_goodsubrule_badrule & 0b1) || (must_goodsubrule_badrule & 0b10)) { #ifdef __DEBUG_ROUTING key = match_set->type; bpf_printk("key(match_set->type): %llu", key); @@ -1041,7 +1041,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], } if (bpf_map_lookup_elem(lpm, lpm_key)) { // match_set hits. - good_subrule = true; + must_goodsubrule_badrule |= 0b10; } } else if ((p_u16 = bpf_map_lookup_elem(&h_port_map, &key))) { #ifdef __DEBUG_ROUTING @@ -1054,7 +1054,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], #endif if (*p_u16 >= match_set->port_range.port_start && *p_u16 <= match_set->port_range.port_end) { - good_subrule = true; + must_goodsubrule_badrule |= 0b10; } } else if ((p_u32 = bpf_map_lookup_elem(&l4proto_ipversion_map, &key))) { #ifdef __DEBUG_ROUTING @@ -1063,7 +1063,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], match_set->type, match_set->not, match_set->outbound); #endif if (*p_u32 & *(__u32 *)&match_set->__value) { - good_subrule = true; + must_goodsubrule_badrule |= 0b10; } } else if (match_set->type == MatchType_DomainSet) { #ifdef __DEBUG_ROUTING @@ -1081,17 +1081,17 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], // We use key instead of k to pass checker. if ((domain_routing->bitmap[i / 32] >> (i % 32)) & 1) { - good_subrule = true; + must_goodsubrule_badrule |= 0b10; } } else if (match_set->type == MatchType_ProcessName) { if (_is_wan && equal16(match_set->pname, _pname)) { - good_subrule = true; + must_goodsubrule_badrule |= 0b10; } } else if (match_set->type == MatchType_Fallback) { #ifdef __DEBUG_ROUTING bpf_printk("CHECK: hit fallback"); #endif - good_subrule = true; + must_goodsubrule_badrule |= 0b10; } else { #ifdef __DEBUG_ROUTING bpf_printk("CHECK: , match_set->type: %u, not: %d, " @@ -1110,13 +1110,13 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], // We are now at end of rule, or next match_set belongs to another // subrule. - if (good_subrule == match_set->not ) { + if ((must_goodsubrule_badrule & 0b10) > 0 == match_set->not ) { // This subrule does not hit. - bad_rule = true; + must_goodsubrule_badrule |= 0b1; } // Reset good_subrule. - good_subrule = false; + must_goodsubrule_badrule &= ~0b10; } #ifdef __DEBUG_ROUTING bpf_printk("_bad_rule: %d", bad_rule); @@ -1125,7 +1125,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], OUTBOUND_LOGICAL_MASK) { // Tail of a rule (line). // Decide whether to hit. - if (!bad_rule) { + if (!(must_goodsubrule_badrule & 0b1)) { #ifdef __DEBUG_ROUTING bpf_printk("MATCHED: match_set->type: %u, match_set->not: %d", match_set->type, match_set->not ); @@ -1133,6 +1133,13 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], // DNS requests should routed by control plane if outbound is not // must_direct. + if (match_set->outbound == OUTBOUND_MUST_RULES) { + must_goodsubrule_badrule |= 0b100; + continue; + } + if (must_goodsubrule_badrule & 0b100) { + match_set->must = true; + } if (!match_set->must && h_dport == 53 && _l4proto_type == L4ProtoType_UDP) { return (__s64)OUTBOUND_CONTROL_PLANE_ROUTING | @@ -1141,7 +1148,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], return (__s64)match_set->outbound | ((__s64)match_set->mark << 8) | ((__s64)match_set->must << 40); } - bad_rule = false; + must_goodsubrule_badrule &= ~0b1; } } bpf_printk("No match_set hits. Did coder forget to sync " diff --git a/control/routing_matcher_builder.go b/control/routing_matcher_builder.go index 7417d0d..8aa6329 100644 --- a/control/routing_matcher_builder.go +++ b/control/routing_matcher_builder.go @@ -60,6 +60,8 @@ func (b *RoutingMatcherBuilder) outboundToId(outbound string) (uint8, error) { outboundId = uint8(consts.OutboundLogicalOr) case consts.OutboundLogicalAnd.String(): outboundId = uint8(consts.OutboundLogicalAnd) + case consts.OutboundMustRules.String(): + outboundId = uint8(consts.OutboundMustRules) default: var ok bool outboundId, ok = b.outboundName2Id[outbound] diff --git a/control/routing_matcher_userspace.go b/control/routing_matcher_userspace.go index 7eeb471..b371d60 100644 --- a/control/routing_matcher_userspace.go +++ b/control/routing_matcher_userspace.go @@ -33,9 +33,9 @@ func (m *RoutingMatcher) Match( domain string, processName [16]uint8, mac []byte, -) (outboundIndex consts.OutboundIndex, mark uint32, err error) { +) (outboundIndex consts.OutboundIndex, mark uint32, must bool, err error) { if len(sourceAddr) != net.IPv6len || len(destAddr) != net.IPv6len || len(mac) != net.IPv6len { - return 0, 0, fmt.Errorf("bad address length") + return 0, 0, false, fmt.Errorf("bad address length") } lpmKeys := make([]*_bpfLpmKey, consts.MatchType_Mac+1) lpmKeys[consts.MatchType_IpSet] = &_bpfLpmKey{ @@ -110,7 +110,7 @@ func (m *RoutingMatcher) Match( case consts.MatchType_Fallback: goodSubrule = true default: - return 0, 0, fmt.Errorf("unknown match type: %v", match.Type) + return 0, 0, false, fmt.Errorf("unknown match type: %v", match.Type) } beforeNextLoop: outbound := consts.OutboundIndex(match.Outbound) @@ -133,10 +133,17 @@ func (m *RoutingMatcher) Match( // Tail of a rule (line). // Decide whether to hit. if !badRule { - return outbound, match.Mark, nil + if outbound == consts.OutboundMustRules { + must = true + continue + } + if must { + match.Must = true + } + return outbound, match.Mark, match.Must, nil } badRule = false } } - return 0, 0, fmt.Errorf("no match set hit") + return 0, 0, false, fmt.Errorf("no match set hit") } diff --git a/control/tcp.go b/control/tcp.go index 217b334..0d6f09c 100644 --- a/control/tcp.go +++ b/control/tcp.go @@ -71,7 +71,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) { switch outboundIndex { case consts.OutboundDirect: case consts.OutboundControlPlaneRouting: - if outboundIndex, routingResult.Mark, err = c.Route(src, dst, domain, consts.L4ProtoType_TCP, routingResult); err != nil { + if outboundIndex, routingResult.Mark, _, err = c.Route(src, dst, domain, consts.L4ProtoType_TCP, routingResult); err != nil { return err } routingResult.Outbound = uint8(outboundIndex) diff --git a/control/udp.go b/control/udp.go index fdad8d5..27841a4 100644 --- a/control/udp.go +++ b/control/udp.go @@ -154,7 +154,7 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r break } - if outboundIndex, routingResult.Mark, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil { + if outboundIndex, routingResult.Mark, _, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil { return err } routingResult.Outbound = uint8(outboundIndex) diff --git a/control/utils.go b/control/utils.go index 21f2359..6c8d81b 100644 --- a/control/utils.go +++ b/control/utils.go @@ -19,7 +19,7 @@ import ( "syscall" ) -func (c *ControlPlane) Route(src, dst netip.AddrPort, domain string, l4proto consts.L4ProtoType, routingResult *bpfRoutingResult) (outboundIndex consts.OutboundIndex, mark uint32, err error) { +func (c *ControlPlane) Route(src, dst netip.AddrPort, domain string, l4proto consts.L4ProtoType, routingResult *bpfRoutingResult) (outboundIndex consts.OutboundIndex, mark uint32, must bool, err error) { var ipVersion consts.IpVersionType if dst.Addr().Is4() || dst.Addr().Is4In6() { ipVersion = consts.IpVersion_4 @@ -28,7 +28,7 @@ func (c *ControlPlane) Route(src, dst netip.AddrPort, domain string, l4proto con } bSrc := src.Addr().As16() bDst := dst.Addr().As16() - if outboundIndex, mark, err = c.routingMatcher.Match( + if outboundIndex, mark, must, err = c.routingMatcher.Match( bSrc[:], bDst[:], src.Port(), @@ -39,10 +39,10 @@ func (c *ControlPlane) Route(src, dst netip.AddrPort, domain string, l4proto con routingResult.Pname, append([]uint8{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, routingResult.Mac[:]...), ); err != nil { - return 0, 0, err + return 0, 0, false, err } - return outboundIndex, mark, nil + return outboundIndex, mark, false, nil } func (c *controlPlaneCore) RetrieveRoutingResult(src, dst netip.AddrPort, l4proto uint8) (result *bpfRoutingResult, err error) { diff --git a/docs/routing.md b/docs/routing.md index 2e09795..4b306a6 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -3,10 +3,12 @@ ## Examples: ```shell -### Built-in outbounds: block, direct -# The difference between "direct" and "must_direct" is that "direct" will hijack and process DNS request (for traffic -# split use), but "must_direct" will not. "must_direct" is useful when there are traffic loops of DNS requests. -# "must_direct" can be written as "direct(must)". +### Built-in outbounds: block, direct, must_rules + +# must_rules means no redirecting DNS traffic to dae and continue to matching. +# For single rule, the difference between "direct" and "must_direct" is that "direct" will hijack and process DNS request +# (for traffic split use), but "must_direct" will not. "must_direct" is useful when there are traffic loops of DNS requests. +# "must_direct" can also be written as "direct(must)". # Similarly, "must_groupname" is also supported to NOT hijack and process DNS traffic, which equals to "groupname(must)". ### fallback outbound @@ -98,4 +100,12 @@ dip(ext:"yourdatfile.dat:yourtag")->direct # Notice that interface wg0, mark 0x800, table 1145 can be set by preferences, but cannot conflict. # 3. Set routing rules in dae config file. domain(geosite:disney) -> direct(mark: 0x800) + +### Must rules +# For following rules, DNS requests will be forcibly redirected to dae except from mosdns. +# Different from must_direct/must_my_group, traffic from mosdns will continue to match other rules. +pname(mosdns) -> must_rules +ip(geoip:cn) -> direct +domain(geosite:cn) -> direct +fallback: my_group ```