feat: add virtual must_rules outbound

This commit is contained in:
mzz2017 2023-04-02 12:02:57 +08:00
parent 006b7fbfd2
commit 648710a40e
10 changed files with 65 additions and 32 deletions

View File

@ -53,6 +53,7 @@ const (
MatchType_Mac MatchType_Mac
MatchType_ProcessName MatchType_ProcessName
MatchType_Fallback MatchType_Fallback
MatchType_MustRules
MatchType_Upstream MatchType_Upstream
MatchType_QType MatchType_QType
@ -64,17 +65,19 @@ const (
OutboundDirect OutboundIndex = iota OutboundDirect OutboundIndex = iota
OutboundBlock OutboundBlock
OutboundMustRules OutboundIndex = 0xFC
OutboundControlPlaneRouting OutboundIndex = 0xFD OutboundControlPlaneRouting OutboundIndex = 0xFD
OutboundLogicalOr OutboundIndex = 0xFE OutboundLogicalOr OutboundIndex = 0xFE
OutboundLogicalAnd OutboundIndex = 0xFF OutboundLogicalAnd OutboundIndex = 0xFF
OutboundLogicalMask OutboundIndex = 0xFE OutboundLogicalMask OutboundIndex = 0xFE
OutboundMax = OutboundLogicalAnd OutboundUserDefinedMax = OutboundMustRules - 1
OutboundUserDefinedMax = OutboundControlPlaneRouting - 1
) )
func (i OutboundIndex) String() string { func (i OutboundIndex) String() string {
switch i { switch i {
case OutboundMustRules:
return "must_rules"
case OutboundDirect: case OutboundDirect:
return "direct" return "direct"
case OutboundBlock: case OutboundBlock:

View File

@ -31,6 +31,10 @@ func patchEmptyDns(params *Config) error {
func patchMustOutbound(params *Config) error { func patchMustOutbound(params *Config) error {
for i := range params.Routing.Rules { for i := range params.Routing.Rules {
if strings.HasPrefix(params.Routing.Rules[i].Outbound.Name, "must_") { if strings.HasPrefix(params.Routing.Rules[i].Outbound.Name, "must_") {
if params.Routing.Rules[i].Outbound.Name == "must_rules" {
// Reserve must_rules.
continue
}
params.Routing.Rules[i].Outbound.Name = strings.TrimPrefix(params.Routing.Rules[i].Outbound.Name, "must_") params.Routing.Rules[i].Outbound.Name = strings.TrimPrefix(params.Routing.Rules[i].Outbound.Name, "must_")
params.Routing.Rules[i].Outbound.Params = append(params.Routing.Rules[i].Outbound.Params, &config_parser.Param{ params.Routing.Rules[i].Outbound.Params = append(params.Routing.Rules[i].Outbound.Params, &config_parser.Param{
Val: "must", Val: "must",

View File

@ -753,7 +753,7 @@ func (c *ControlPlane) chooseBestDnsDialer(
default: default:
return nil, fmt.Errorf("unexpected ipversion: %v", ver) return nil, fmt.Errorf("unexpected ipversion: %v", ver)
} }
outboundIndex, mark, err := c.Route(req.realSrc, netip.AddrPortFrom(dAddr, dnsUpstream.Port), "", proto.ToL4ProtoType(), req.routingResult) outboundIndex, mark, _, err := c.Route(req.realSrc, netip.AddrPortFrom(dAddr, dnsUpstream.Port), "", proto.ToL4ProtoType(), req.routingResult)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -60,6 +60,7 @@
#define OUTBOUND_DIRECT 0 #define OUTBOUND_DIRECT 0
#define OUTBOUND_BLOCK 1 #define OUTBOUND_BLOCK 1
#define OUTBOUND_MUST_RULES 0xFC
#define OUTBOUND_CONTROL_PLANE_ROUTING 0xFD #define OUTBOUND_CONTROL_PLANE_ROUTING 0xFD
#define OUTBOUND_LOGICAL_OR 0xFE #define OUTBOUND_LOGICAL_OR 0xFE
#define OUTBOUND_LOGICAL_AND 0xFF #define OUTBOUND_LOGICAL_AND 0xFF
@ -1000,8 +1001,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
// Rule is like: domain(suffix:baidu.com, suffix:google.com) && port(443) -> // Rule is like: domain(suffix:baidu.com, suffix:google.com) && port(443) ->
// proxy Subrule is like: domain(suffix:baidu.com, suffix:google.com) Match // proxy Subrule is like: domain(suffix:baidu.com, suffix:google.com) Match
// set is like: suffix:baidu.com // set is like: suffix:baidu.com
bool bad_rule = false; __u8 must_goodsubrule_badrule = 0;
bool good_subrule = false;
struct domain_routing *domain_routing; struct domain_routing *domain_routing;
__u32 *p_u32; __u32 *p_u32;
__u16 *p_u16; __u16 *p_u16;
@ -1015,7 +1015,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
if (unlikely(!match_set)) { if (unlikely(!match_set)) {
return -EFAULT; return -EFAULT;
} }
if (bad_rule || good_subrule) { if ((must_goodsubrule_badrule & 0b1) || (must_goodsubrule_badrule & 0b10)) {
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
key = match_set->type; key = match_set->type;
bpf_printk("key(match_set->type): %llu", key); bpf_printk("key(match_set->type): %llu", key);
@ -1041,7 +1041,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
} }
if (bpf_map_lookup_elem(lpm, lpm_key)) { if (bpf_map_lookup_elem(lpm, lpm_key)) {
// match_set hits. // match_set hits.
good_subrule = true; must_goodsubrule_badrule |= 0b10;
} }
} else if ((p_u16 = bpf_map_lookup_elem(&h_port_map, &key))) { } else if ((p_u16 = bpf_map_lookup_elem(&h_port_map, &key))) {
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
@ -1054,7 +1054,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
#endif #endif
if (*p_u16 >= match_set->port_range.port_start && if (*p_u16 >= match_set->port_range.port_start &&
*p_u16 <= match_set->port_range.port_end) { *p_u16 <= match_set->port_range.port_end) {
good_subrule = true; must_goodsubrule_badrule |= 0b10;
} }
} else if ((p_u32 = bpf_map_lookup_elem(&l4proto_ipversion_map, &key))) { } else if ((p_u32 = bpf_map_lookup_elem(&l4proto_ipversion_map, &key))) {
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
@ -1063,7 +1063,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
match_set->type, match_set->not, match_set->outbound); match_set->type, match_set->not, match_set->outbound);
#endif #endif
if (*p_u32 & *(__u32 *)&match_set->__value) { if (*p_u32 & *(__u32 *)&match_set->__value) {
good_subrule = true; must_goodsubrule_badrule |= 0b10;
} }
} else if (match_set->type == MatchType_DomainSet) { } else if (match_set->type == MatchType_DomainSet) {
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
@ -1081,17 +1081,17 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
// We use key instead of k to pass checker. // We use key instead of k to pass checker.
if ((domain_routing->bitmap[i / 32] >> (i % 32)) & 1) { if ((domain_routing->bitmap[i / 32] >> (i % 32)) & 1) {
good_subrule = true; must_goodsubrule_badrule |= 0b10;
} }
} else if (match_set->type == MatchType_ProcessName) { } else if (match_set->type == MatchType_ProcessName) {
if (_is_wan && equal16(match_set->pname, _pname)) { if (_is_wan && equal16(match_set->pname, _pname)) {
good_subrule = true; must_goodsubrule_badrule |= 0b10;
} }
} else if (match_set->type == MatchType_Fallback) { } else if (match_set->type == MatchType_Fallback) {
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
bpf_printk("CHECK: hit fallback"); bpf_printk("CHECK: hit fallback");
#endif #endif
good_subrule = true; must_goodsubrule_badrule |= 0b10;
} else { } else {
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
bpf_printk("CHECK: <unknown>, match_set->type: %u, not: %d, " bpf_printk("CHECK: <unknown>, match_set->type: %u, not: %d, "
@ -1110,13 +1110,13 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
// We are now at end of rule, or next match_set belongs to another // We are now at end of rule, or next match_set belongs to another
// subrule. // subrule.
if (good_subrule == match_set->not ) { if ((must_goodsubrule_badrule & 0b10) > 0 == match_set->not ) {
// This subrule does not hit. // This subrule does not hit.
bad_rule = true; must_goodsubrule_badrule |= 0b1;
} }
// Reset good_subrule. // Reset good_subrule.
good_subrule = false; must_goodsubrule_badrule &= ~0b10;
} }
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
bpf_printk("_bad_rule: %d", bad_rule); bpf_printk("_bad_rule: %d", bad_rule);
@ -1125,7 +1125,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
OUTBOUND_LOGICAL_MASK) { OUTBOUND_LOGICAL_MASK) {
// Tail of a rule (line). // Tail of a rule (line).
// Decide whether to hit. // Decide whether to hit.
if (!bad_rule) { if (!(must_goodsubrule_badrule & 0b1)) {
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
bpf_printk("MATCHED: match_set->type: %u, match_set->not: %d", bpf_printk("MATCHED: match_set->type: %u, match_set->not: %d",
match_set->type, match_set->not ); match_set->type, match_set->not );
@ -1133,6 +1133,13 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
// DNS requests should routed by control plane if outbound is not // DNS requests should routed by control plane if outbound is not
// must_direct. // must_direct.
if (match_set->outbound == OUTBOUND_MUST_RULES) {
must_goodsubrule_badrule |= 0b100;
continue;
}
if (must_goodsubrule_badrule & 0b100) {
match_set->must = true;
}
if (!match_set->must && h_dport == 53 && if (!match_set->must && h_dport == 53 &&
_l4proto_type == L4ProtoType_UDP) { _l4proto_type == L4ProtoType_UDP) {
return (__s64)OUTBOUND_CONTROL_PLANE_ROUTING | return (__s64)OUTBOUND_CONTROL_PLANE_ROUTING |
@ -1141,7 +1148,7 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
return (__s64)match_set->outbound | ((__s64)match_set->mark << 8) | return (__s64)match_set->outbound | ((__s64)match_set->mark << 8) |
((__s64)match_set->must << 40); ((__s64)match_set->must << 40);
} }
bad_rule = false; must_goodsubrule_badrule &= ~0b1;
} }
} }
bpf_printk("No match_set hits. Did coder forget to sync " bpf_printk("No match_set hits. Did coder forget to sync "

View File

@ -60,6 +60,8 @@ func (b *RoutingMatcherBuilder) outboundToId(outbound string) (uint8, error) {
outboundId = uint8(consts.OutboundLogicalOr) outboundId = uint8(consts.OutboundLogicalOr)
case consts.OutboundLogicalAnd.String(): case consts.OutboundLogicalAnd.String():
outboundId = uint8(consts.OutboundLogicalAnd) outboundId = uint8(consts.OutboundLogicalAnd)
case consts.OutboundMustRules.String():
outboundId = uint8(consts.OutboundMustRules)
default: default:
var ok bool var ok bool
outboundId, ok = b.outboundName2Id[outbound] outboundId, ok = b.outboundName2Id[outbound]

View File

@ -33,9 +33,9 @@ func (m *RoutingMatcher) Match(
domain string, domain string,
processName [16]uint8, processName [16]uint8,
mac []byte, mac []byte,
) (outboundIndex consts.OutboundIndex, mark uint32, err error) { ) (outboundIndex consts.OutboundIndex, mark uint32, must bool, err error) {
if len(sourceAddr) != net.IPv6len || len(destAddr) != net.IPv6len || len(mac) != net.IPv6len { if len(sourceAddr) != net.IPv6len || len(destAddr) != net.IPv6len || len(mac) != net.IPv6len {
return 0, 0, fmt.Errorf("bad address length") return 0, 0, false, fmt.Errorf("bad address length")
} }
lpmKeys := make([]*_bpfLpmKey, consts.MatchType_Mac+1) lpmKeys := make([]*_bpfLpmKey, consts.MatchType_Mac+1)
lpmKeys[consts.MatchType_IpSet] = &_bpfLpmKey{ lpmKeys[consts.MatchType_IpSet] = &_bpfLpmKey{
@ -110,7 +110,7 @@ func (m *RoutingMatcher) Match(
case consts.MatchType_Fallback: case consts.MatchType_Fallback:
goodSubrule = true goodSubrule = true
default: default:
return 0, 0, fmt.Errorf("unknown match type: %v", match.Type) return 0, 0, false, fmt.Errorf("unknown match type: %v", match.Type)
} }
beforeNextLoop: beforeNextLoop:
outbound := consts.OutboundIndex(match.Outbound) outbound := consts.OutboundIndex(match.Outbound)
@ -133,10 +133,17 @@ func (m *RoutingMatcher) Match(
// Tail of a rule (line). // Tail of a rule (line).
// Decide whether to hit. // Decide whether to hit.
if !badRule { if !badRule {
return outbound, match.Mark, nil if outbound == consts.OutboundMustRules {
must = true
continue
}
if must {
match.Must = true
}
return outbound, match.Mark, match.Must, nil
} }
badRule = false badRule = false
} }
} }
return 0, 0, fmt.Errorf("no match set hit") return 0, 0, false, fmt.Errorf("no match set hit")
} }

View File

@ -71,7 +71,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
switch outboundIndex { switch outboundIndex {
case consts.OutboundDirect: case consts.OutboundDirect:
case consts.OutboundControlPlaneRouting: case consts.OutboundControlPlaneRouting:
if outboundIndex, routingResult.Mark, err = c.Route(src, dst, domain, consts.L4ProtoType_TCP, routingResult); err != nil { if outboundIndex, routingResult.Mark, _, err = c.Route(src, dst, domain, consts.L4ProtoType_TCP, routingResult); err != nil {
return err return err
} }
routingResult.Outbound = uint8(outboundIndex) routingResult.Outbound = uint8(outboundIndex)

View File

@ -154,7 +154,7 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
break break
} }
if outboundIndex, routingResult.Mark, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil { if outboundIndex, routingResult.Mark, _, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil {
return err return err
} }
routingResult.Outbound = uint8(outboundIndex) routingResult.Outbound = uint8(outboundIndex)

View File

@ -19,7 +19,7 @@ import (
"syscall" "syscall"
) )
func (c *ControlPlane) Route(src, dst netip.AddrPort, domain string, l4proto consts.L4ProtoType, routingResult *bpfRoutingResult) (outboundIndex consts.OutboundIndex, mark uint32, err error) { func (c *ControlPlane) Route(src, dst netip.AddrPort, domain string, l4proto consts.L4ProtoType, routingResult *bpfRoutingResult) (outboundIndex consts.OutboundIndex, mark uint32, must bool, err error) {
var ipVersion consts.IpVersionType var ipVersion consts.IpVersionType
if dst.Addr().Is4() || dst.Addr().Is4In6() { if dst.Addr().Is4() || dst.Addr().Is4In6() {
ipVersion = consts.IpVersion_4 ipVersion = consts.IpVersion_4
@ -28,7 +28,7 @@ func (c *ControlPlane) Route(src, dst netip.AddrPort, domain string, l4proto con
} }
bSrc := src.Addr().As16() bSrc := src.Addr().As16()
bDst := dst.Addr().As16() bDst := dst.Addr().As16()
if outboundIndex, mark, err = c.routingMatcher.Match( if outboundIndex, mark, must, err = c.routingMatcher.Match(
bSrc[:], bSrc[:],
bDst[:], bDst[:],
src.Port(), src.Port(),
@ -39,10 +39,10 @@ func (c *ControlPlane) Route(src, dst netip.AddrPort, domain string, l4proto con
routingResult.Pname, routingResult.Pname,
append([]uint8{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, routingResult.Mac[:]...), append([]uint8{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, routingResult.Mac[:]...),
); err != nil { ); err != nil {
return 0, 0, err return 0, 0, false, err
} }
return outboundIndex, mark, nil return outboundIndex, mark, false, nil
} }
func (c *controlPlaneCore) RetrieveRoutingResult(src, dst netip.AddrPort, l4proto uint8) (result *bpfRoutingResult, err error) { func (c *controlPlaneCore) RetrieveRoutingResult(src, dst netip.AddrPort, l4proto uint8) (result *bpfRoutingResult, err error) {

View File

@ -3,10 +3,12 @@
## Examples: ## Examples:
```shell ```shell
### Built-in outbounds: block, direct ### Built-in outbounds: block, direct, must_rules
# The difference between "direct" and "must_direct" is that "direct" will hijack and process DNS request (for traffic
# split use), but "must_direct" will not. "must_direct" is useful when there are traffic loops of DNS requests. # must_rules means no redirecting DNS traffic to dae and continue to matching.
# "must_direct" can be written as "direct(must)". # For single rule, the difference between "direct" and "must_direct" is that "direct" will hijack and process DNS request
# (for traffic split use), but "must_direct" will not. "must_direct" is useful when there are traffic loops of DNS requests.
# "must_direct" can also be written as "direct(must)".
# Similarly, "must_groupname" is also supported to NOT hijack and process DNS traffic, which equals to "groupname(must)". # Similarly, "must_groupname" is also supported to NOT hijack and process DNS traffic, which equals to "groupname(must)".
### fallback outbound ### fallback outbound
@ -98,4 +100,12 @@ dip(ext:"yourdatfile.dat:yourtag")->direct
# Notice that interface wg0, mark 0x800, table 1145 can be set by preferences, but cannot conflict. # Notice that interface wg0, mark 0x800, table 1145 can be set by preferences, but cannot conflict.
# 3. Set routing rules in dae config file. # 3. Set routing rules in dae config file.
domain(geosite:disney) -> direct(mark: 0x800) domain(geosite:disney) -> direct(mark: 0x800)
### Must rules
# For following rules, DNS requests will be forcibly redirected to dae except from mosdns.
# Different from must_direct/must_my_group, traffic from mosdns will continue to match other rules.
pname(mosdns) -> must_rules
ip(geoip:cn) -> direct
domain(geosite:cn) -> direct
fallback: my_group
``` ```