diff --git a/.gitignore b/.gitignore index d382f9d..3a83dd6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .vscode .idea *.o +*.tmp bpf_bpfeb.go bpf_bpfel.go dae diff --git a/common/consts/ebpf.go b/common/consts/ebpf.go index 469a293..744aea0 100644 --- a/common/consts/ebpf.go +++ b/common/consts/ebpf.go @@ -6,9 +6,8 @@ package consts const ( - AppName = "dae" - MaxInterfaceIpNum = 8 - BpfPinRoot = "/sys/fs/bpf" + AppName = "dae" + BpfPinRoot = "/sys/fs/bpf" AddrHdrSize = 20 ) @@ -49,7 +48,8 @@ type OutboundIndex uint8 const ( OutboundDirect OutboundIndex = 0 OutboundBlock OutboundIndex = 1 - OutboundControlPlaneDirect OutboundIndex = 0xFE + OutboundControlPlaneDirect OutboundIndex = 0xFD + OutboundLogicalOr OutboundIndex = 0xFE OutboundLogicalAnd OutboundIndex = 0xFF ) @@ -61,6 +61,8 @@ func (i OutboundIndex) String() string { return "block" case OutboundControlPlaneDirect: return "" + case OutboundLogicalOr: + return "" case OutboundLogicalAnd: return "" default: diff --git a/common/utils.go b/common/utils.go index b8d2abd..819fb21 100644 --- a/common/utils.go +++ b/common/utils.go @@ -106,7 +106,7 @@ func ParseMac(mac string) (addr [6]byte, err error) { return addr, nil } -func ParsePortRange(pr string) (portRange [2]int, err error) { +func ParsePortRange(pr string) (portRange [2]uint16, err error) { fields := strings.SplitN(pr, "-", 2) for i, field := range fields { if field == "" { @@ -119,7 +119,7 @@ func ParsePortRange(pr string) (portRange [2]int, err error) { if port < 0 || port > 0xffff { return portRange, fmt.Errorf("port %v exceeds uint16 range", port) } - portRange[i] = port + portRange[i] = uint16(port) } if len(fields) == 1 { portRange[1] = portRange[0] diff --git a/component/control/bpf_utils.go b/component/control/bpf_utils.go index 4a15414..5070304 100644 --- a/component/control/bpf_utils.go +++ b/component/control/bpf_utils.go @@ -6,6 +6,7 @@ package control import ( + "encoding/binary" "fmt" "github.com/cilium/ebpf" "github.com/v2rayA/dae/common" @@ -19,6 +20,18 @@ type _bpfLpmKey struct { Data [4]uint32 } +type _bpfPortRange struct { + PortStart uint16 + PortEnd uint16 +} + +func (r _bpfPortRange) Encode() uint32 { + var b [4]byte + binary.LittleEndian.PutUint16(b[:2], r.PortStart) + binary.LittleEndian.PutUint16(b[2:], r.PortEnd) + return binary.BigEndian.Uint32(b[:]) +} + func (o *bpfObjects) newLpmMap(keys []_bpfLpmKey, values []uint32) (m *ebpf.Map, err error) { m, err = ebpf.NewMap(&ebpf.MapSpec{ Type: ebpf.LPMTrie, diff --git a/component/control/control_plane.go b/component/control/control_plane.go index 13783b4..8ec9c66 100644 --- a/component/control/control_plane.go +++ b/component/control/control_plane.go @@ -26,6 +26,7 @@ import ( "net/netip" "os" "path/filepath" + "reflect" "strconv" "strings" "sync" @@ -70,11 +71,18 @@ func NewControlPlane( // Load pre-compiled programs and maps into the kernel. var bpf bpfObjects + var ProgramOptions ebpf.ProgramOptions + if log.IsLevelEnabled(logrus.TraceLevel) { + ProgramOptions = ebpf.ProgramOptions{ + LogLevel: ebpf.LogLevelStats, + } + } retryLoadBpf: if err = loadBpfObjects(&bpf, &ebpf.CollectionOptions{ Maps: ebpf.MapOptions{ PinPath: pinPath, }, + Programs: ProgramOptions, }); err != nil { if errors.Is(err, ebpf.ErrMapIncompatible) { // Map property is incompatible. Remove the old map and try again. @@ -88,7 +96,17 @@ retryLoadBpf: log.Warnf("New map format was incompatible with existing map %v, and the old one was removed.", mapName) goto retryLoadBpf } - return nil, fmt.Errorf("loading objects: %w", err) + // Get detailed log from ebpf.internal.*VerifierError + if log.IsLevelEnabled(logrus.TraceLevel) { + if v := reflect.Indirect(reflect.ValueOf(errors.Unwrap(errors.Unwrap(err)))); v.Kind() == reflect.Struct { + if log := v.FieldByName("Log"); log.IsValid() { + if strSlice, ok := log.Interface().([]string); ok { + err = fmt.Errorf("%v", strings.Join(strSlice, "\n")) + } + } + } + } + return nil, fmt.Errorf("loading objects: %v", err) } // Write params. if err = bpf.ParamMap.Update(consts.DisableL4TxChecksumKey, consts.DisableL4ChecksumPolicy_SetZero, ebpf.UpdateAny); err != nil { diff --git a/component/control/kern/tproxy.c b/component/control/kern/tproxy.c index 0188be2..eb1fa4b 100644 --- a/component/control/kern/tproxy.c +++ b/component/control/kern/tproxy.c @@ -42,15 +42,16 @@ #define OUTBOUND_DIRECT 0 #define OUTBOUND_BLOCK 1 -#define OUTBOUND_CONTROL_PLANE_DIRECT 0xFE +#define OUTBOUND_CONTROL_PLANE_DIRECT 0xFD +#define OUTBOUND_LOGICAL_OR 0xFE #define OUTBOUND_LOGICAL_AND 0xFF +#define OUTBOUND_LOGICAL_MASK 0xFE enum { DISABLE_L4_CHECKSUM_POLICY_ENABLE_L4_CHECKSUM, DISABLE_L4_CHECKSUM_POLICY_RESTORE, DISABLE_L4_CHECKSUM_POLICY_SET_ZERO, }; -#define OUTBOUND_LOGICAL_AND 0xFF // Param keys: static const __u32 zero_key = 0; @@ -208,7 +209,17 @@ struct port_range { __u16 port_start; __u16 port_end; }; -struct routing { + +/* + Look at following rule: + + domain(geosite:cn, suffix: google.com) && l4proto(tcp) -> my_group + + pseudocode: domain(geosite:cn || suffix:google.com) && l4proto(tcp) -> my_group + + A match_set can be: IP set geosite:cn, suffix google.com, tcp proto + */ +struct match_set { union { __u32 __value; // Placeholder for bpf2go. @@ -218,13 +229,13 @@ struct routing { enum IP_VERSION ip_version; }; enum ROUTING_TYPE type; - __u8 outbound; // 255 means logical AND. 254 means dirty. User-defined value - // range is [0, 253]. + bool not ; // A subrule flag (this is not a match_set flag). + __u8 outbound; // User-defined value range is [0, 252]. }; struct { __uint(type, BPF_MAP_TYPE_ARRAY); __type(key, __u32); - __type(value, struct routing); + __type(value, struct match_set); __uint(max_entries, MAX_ROUTING_LEN); // __uint(pinning, LIBBPF_PIN_BY_NAME); } routing_map SEC(".maps"); @@ -809,42 +820,45 @@ static long routing(__u32 flag[3], void *l4_hdr, __be32 saddr[4], bpf_map_update_elem(&lpm_key_map, &key, &lpm_key_mac, BPF_ANY); struct map_lpm_type *lpm; - struct routing *routing; - // Rule is like: domain(domain:baidu.com) && port(443) -> proxy + struct match_set *match_set; + // Rule is like: domain(suffix:baidu.com, suffix:google.com) && port(443) -> + // proxy Subrule is like: domain(suffix:baidu.com, suffix:google.com) Match + // set is like: suffix:baidu.com bool bad_rule = false; + bool good_subrule = false; struct domain_routing *domain_routing; __u32 *p_u32; #pragma unroll for (__u32 i = 0; i < MAX_ROUTING_LEN; i++) { __u32 k = i; // Clone to pass code checker. - routing = bpf_map_lookup_elem(&routing_map, &k); - if (!routing) { + match_set = bpf_map_lookup_elem(&routing_map, &k); + if (!match_set) { return -EFAULT; } - if (bad_rule) { + if (bad_rule || good_subrule) { goto before_next_loop; } - key = (key & (__u32)0) | (__u32)routing->type; + key = (key & (__u32)0) | (__u32)match_set->type; if ((lpm_key = bpf_map_lookup_elem(&lpm_key_map, &key))) { - lpm = bpf_map_lookup_elem(&lpm_array_map, &routing->index); + lpm = bpf_map_lookup_elem(&lpm_array_map, &match_set->index); if (unlikely(!lpm)) { return -EFAULT; } - if (!bpf_map_lookup_elem(lpm, lpm_key)) { - // Routing not hit. - bad_rule = true; + if (bpf_map_lookup_elem(lpm, lpm_key)) { + // match_set hits. + good_subrule = true; } } else if ((p_u32 = bpf_map_lookup_elem(&h_port_map, &key))) { - if (*p_u32 < routing->port_range.port_start || - *p_u32 > routing->port_range.port_end) { - bad_rule = true; + if (*p_u32 >= match_set->port_range.port_start && + *p_u32 <= match_set->port_range.port_end) { + good_subrule = true; } } else if ((p_u32 = bpf_map_lookup_elem(&l4proto_ipversion_map, &key))) { - if (!(*p_u32 & routing->__value)) { - bad_rule = true; + if (*p_u32 & match_set->__value) { + good_subrule = true; } - } else if (routing->type == ROUTING_TYPE_DOMAIN_SET) { + } else if (match_set->type == ROUTING_TYPE_DOMAIN_SET) { // Bottleneck of insns limit. // We fixed it by invoking bpf_map_lookup_elem here. @@ -852,31 +866,46 @@ static long routing(__u32 flag[3], void *l4_hdr, __be32 saddr[4], domain_routing = bpf_map_lookup_elem(&domain_routing_map, daddr); if (!domain_routing) { // No domain corresponding to IP. - bad_rule = true; goto before_next_loop; } // We use key instead of k to pass checker. - if (!((domain_routing->bitmap[i / 32] >> (i % 32)) & 1)) { - bad_rule = true; + if ((domain_routing->bitmap[i / 32] >> (i % 32)) & 1) { + good_subrule = true; } - } else if (routing->type == ROUTING_TYPE_FINAL) { - bad_rule = false; + } else if (match_set->type == ROUTING_TYPE_FINAL) { + good_subrule = true; } else { return -EINVAL; } before_next_loop: - if (routing->outbound != OUTBOUND_LOGICAL_AND) { + if (match_set->outbound != OUTBOUND_LOGICAL_OR && !bad_rule) { + // This match_set reaches the end of subrule. + // We are now at end of rule, or next match_set belongs to another + // subrule. + if (good_subrule == match_set->not ) { + // This subrule does not hit. + bad_rule = true; + } else { + // This subrule hits. + // Reset the good_subrule flag. + good_subrule = false; + } + } + if ((match_set->outbound & OUTBOUND_LOGICAL_MASK) != + OUTBOUND_LOGICAL_MASK) { // Tail of a rule (line). // Decide whether to hit. if (!bad_rule) { - if (routing->outbound == OUTBOUND_DIRECT && h_dport == 53 && + if (match_set->outbound == OUTBOUND_DIRECT && h_dport == 53 && _l4proto == L4PROTO_TYPE_UDP) { // DNS packet should go through control plane. return OUTBOUND_CONTROL_PLANE_DIRECT; } - return routing->outbound; + // bpf_printk("match_set->type: %d, match_set->not: %d", match_set->type, + // match_set->not ); + return match_set->outbound; } bad_rule = false; } diff --git a/component/control/routing_matcher_builder.go b/component/control/routing_matcher_builder.go index fcae444..49d5e44 100644 --- a/component/control/routing_matcher_builder.go +++ b/component/control/routing_matcher_builder.go @@ -11,6 +11,7 @@ import ( "github.com/v2rayA/dae/common" "github.com/v2rayA/dae/common/consts" "github.com/v2rayA/dae/component/routing" + "github.com/v2rayA/dae/pkg/config_parser" "net/netip" "strconv" ) @@ -25,7 +26,7 @@ type RoutingMatcherBuilder struct { *routing.DefaultMatcherBuilder outboundName2Id map[string]uint8 bpf *bpfObjects - rules []bpfRouting + rules []bpfMatchSet SimulatedLpmTries [][]netip.Prefix SimulatedDomainSet []DomainSet Final string @@ -39,9 +40,12 @@ func NewRoutingMatcherBuilder(outboundName2Id map[string]uint8, bpf *bpfObjects) func (b *RoutingMatcherBuilder) OutboundToId(outbound string) uint8 { var outboundId uint8 - if outbound == routing.FakeOutbound_AND { + switch outbound { + case routing.FakeOutbound_AND: outboundId = uint8(consts.OutboundLogicalAnd) - } else { + case routing.FakeOutbound_OR: + outboundId = uint8(consts.OutboundLogicalOr) + default: var ok bool outboundId, ok = b.outboundName2Id[outbound] if !ok { @@ -51,7 +55,7 @@ func (b *RoutingMatcherBuilder) OutboundToId(outbound string) uint8 { return outboundId } -func (b *RoutingMatcherBuilder) AddDomain(key string, values []string, outbound string) { +func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound string) { if b.err != nil { return } @@ -69,13 +73,14 @@ func (b *RoutingMatcherBuilder) AddDomain(key string, values []string, outbound RuleIndex: len(b.rules), Domains: values, }) - b.rules = append(b.rules, bpfRouting{ + b.rules = append(b.rules, bpfMatchSet{ Type: uint32(consts.RoutingType_DomainSet), + Not: f.Not, Outbound: b.OutboundToId(outbound), }) } -func (b *RoutingMatcherBuilder) AddSourceMac(macAddrs [][6]byte, outbound string) { +func (b *RoutingMatcherBuilder) AddSourceMac(f *config_parser.Function, macAddrs [][6]byte, outbound string) { if b.err != nil { return } @@ -88,58 +93,91 @@ func (b *RoutingMatcherBuilder) AddSourceMac(macAddrs [][6]byte, outbound string } lpmTrieIndex := len(b.SimulatedLpmTries) b.SimulatedLpmTries = append(b.SimulatedLpmTries, values) - b.rules = append(b.rules, bpfRouting{ + b.rules = append(b.rules, bpfMatchSet{ Type: uint32(consts.RoutingType_Mac), Value: uint32(lpmTrieIndex), + Not: f.Not, Outbound: b.OutboundToId(outbound), }) } -func (b *RoutingMatcherBuilder) AddIp(values []netip.Prefix, outbound string) { +func (b *RoutingMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound string) { if b.err != nil { return } lpmTrieIndex := len(b.SimulatedLpmTries) b.SimulatedLpmTries = append(b.SimulatedLpmTries, values) - b.rules = append(b.rules, bpfRouting{ + b.rules = append(b.rules, bpfMatchSet{ Type: uint32(consts.RoutingType_IpSet), Value: uint32(lpmTrieIndex), + Not: f.Not, Outbound: b.OutboundToId(outbound), }) } -func (b *RoutingMatcherBuilder) AddSourceIp(values []netip.Prefix, outbound string) { +func (b *RoutingMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound string) { + for _, value := range values { + b.rules = append(b.rules, bpfMatchSet{ + Type: uint32(consts.RoutingType_Port), + Value: _bpfPortRange{ + PortStart: value[0], + PortEnd: value[1], + }.Encode(), + Not: f.Not, + Outbound: b.OutboundToId(outbound), + }) + } +} + +func (b *RoutingMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound string) { if b.err != nil { return } lpmTrieIndex := len(b.SimulatedLpmTries) b.SimulatedLpmTries = append(b.SimulatedLpmTries, values) - b.rules = append(b.rules, bpfRouting{ + b.rules = append(b.rules, bpfMatchSet{ Type: uint32(consts.RoutingType_SourceIpSet), Value: uint32(lpmTrieIndex), + Not: f.Not, Outbound: b.OutboundToId(outbound), }) } -func (b *RoutingMatcherBuilder) AddL4Proto(values consts.L4ProtoType, outbound string) { +func (b *RoutingMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound string) { + for _, value := range values { + b.rules = append(b.rules, bpfMatchSet{ + Type: uint32(consts.RoutingType_SourcePort), + Value: _bpfPortRange{ + PortStart: value[0], + PortEnd: value[1], + }.Encode(), + Not: f.Not, + Outbound: b.OutboundToId(outbound), + }) + } +} + +func (b *RoutingMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound string) { if b.err != nil { return } - b.rules = append(b.rules, bpfRouting{ + b.rules = append(b.rules, bpfMatchSet{ Type: uint32(consts.RoutingType_L4Proto), Value: uint32(values), + Not: f.Not, Outbound: b.OutboundToId(outbound), }) } -func (b *RoutingMatcherBuilder) AddIpVersion(values consts.IpVersion, outbound string) { +func (b *RoutingMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersion, outbound string) { if b.err != nil { return } - b.rules = append(b.rules, bpfRouting{ + b.rules = append(b.rules, bpfMatchSet{ Type: uint32(consts.RoutingType_IpVersion), Value: uint32(values), + Not: f.Not, Outbound: b.OutboundToId(outbound), }) } @@ -149,7 +187,7 @@ func (b *RoutingMatcherBuilder) AddFinal(outbound string) { return } b.Final = outbound - b.rules = append(b.rules, bpfRouting{ + b.rules = append(b.rules, bpfMatchSet{ Type: uint32(consts.RoutingType_Final), Outbound: b.OutboundToId(outbound), }) @@ -193,3 +231,7 @@ func (b *RoutingMatcherBuilder) Build() (err error) { } return nil } + +//func (b *RoutingMatcherBuilder) AddAnyBefore(f *config_parser.Function, key string, values []string, outbound string) { +// logrus.Debugln(f.Not, f.Name, key, outbound) +//} diff --git a/component/outbound/dialer_group.go b/component/outbound/dialer_group.go index 5a2cf2d..d7b6120 100644 --- a/component/outbound/dialer_group.go +++ b/component/outbound/dialer_group.go @@ -10,63 +10,10 @@ import ( "github.com/sirupsen/logrus" "github.com/v2rayA/dae/common/consts" "github.com/v2rayA/dae/component/outbound/dialer" - "github.com/v2rayA/dae/config" - "github.com/v2rayA/dae/pkg/config_parser" "golang.org/x/net/proxy" "net" - "strconv" ) -type DialerSelectionPolicy struct { - Policy consts.DialerSelectionPolicy - FixedIndex int -} - -func NewDialerSelectionPolicyFromGroupParam(param *config.GroupParam) (policy *DialerSelectionPolicy, err error) { - switch val := param.Policy.(type) { - case string: - switch consts.DialerSelectionPolicy(val) { - case consts.DialerSelectionPolicy_Random, - consts.DialerSelectionPolicy_MinAverage10Latencies, - consts.DialerSelectionPolicy_MinLastLatency: - return &DialerSelectionPolicy{ - Policy: consts.DialerSelectionPolicy(val), - }, nil - case consts.DialerSelectionPolicy_Fixed: - return nil, fmt.Errorf("%v need to specify node index", val) - default: - return nil, fmt.Errorf("unexpected policy: %v", val) - } - case []*config_parser.Function: - if len(val) > 1 || len(val) == 0 { - logrus.Debugf("%@", val) - return nil, fmt.Errorf("policy should be exact 1 function: got %v", len(val)) - } - f := val[0] - switch consts.DialerSelectionPolicy(f.Name) { - case consts.DialerSelectionPolicy_Fixed: - // Should be like: - // policy: fixed(0) - if len(f.Params) > 1 || f.Params[0].Key != "" { - return nil, fmt.Errorf(`invalid "%v" param format`, f.Name) - } - strIndex := f.Params[0].Val - index, err := strconv.Atoi(strIndex) - if len(f.Params) > 1 || f.Params[0].Key != "" { - return nil, fmt.Errorf(`invalid "%v" param format: %w`, f.Name, err) - } - return &DialerSelectionPolicy{ - Policy: consts.DialerSelectionPolicy(f.Name), - FixedIndex: index, - }, nil - default: - return nil, fmt.Errorf("unexpected policy func: %v", f.Name) - } - default: - return nil, fmt.Errorf("unexpected param.Policy.(type): %T", val) - } -} - type DialerGroup struct { proxy.Dialer block *dialer.Dialer diff --git a/component/outbound/dialer_selection_policy.go b/component/outbound/dialer_selection_policy.go new file mode 100644 index 0000000..f35f69c --- /dev/null +++ b/component/outbound/dialer_selection_policy.go @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: AGPL-3.0-only + * Copyright (c) since 2023, mzz2017 + */ + +package outbound + +import ( + "fmt" + "github.com/sirupsen/logrus" + "github.com/v2rayA/dae/common/consts" + "github.com/v2rayA/dae/config" + "github.com/v2rayA/dae/pkg/config_parser" + "strconv" +) + +type DialerSelectionPolicy struct { + Policy consts.DialerSelectionPolicy + FixedIndex int +} + +func NewDialerSelectionPolicyFromGroupParam(param *config.GroupParam) (policy *DialerSelectionPolicy, err error) { + switch val := param.Policy.(type) { + case string: + switch consts.DialerSelectionPolicy(val) { + case consts.DialerSelectionPolicy_Random, + consts.DialerSelectionPolicy_MinAverage10Latencies, + consts.DialerSelectionPolicy_MinLastLatency: + return &DialerSelectionPolicy{ + Policy: consts.DialerSelectionPolicy(val), + }, nil + case consts.DialerSelectionPolicy_Fixed: + return nil, fmt.Errorf("%v need to specify node index", val) + default: + return nil, fmt.Errorf("unexpected policy: %v", val) + } + case []*config_parser.Function: + if len(val) > 1 || len(val) == 0 { + logrus.Debugf("%@", val) + return nil, fmt.Errorf("policy should be exact 1 function: got %v", len(val)) + } + f := val[0] + switch consts.DialerSelectionPolicy(f.Name) { + case consts.DialerSelectionPolicy_Fixed: + // Should be like: + // policy: fixed(0) + if f.Not { + return nil, fmt.Errorf("policy param does not support not operator: !%v()", f.Name) + } + if len(f.Params) > 1 || f.Params[0].Key != "" { + return nil, fmt.Errorf(`invalid "%v" param format`, f.Name) + } + strIndex := f.Params[0].Val + index, err := strconv.Atoi(strIndex) + if len(f.Params) > 1 || f.Params[0].Key != "" { + return nil, fmt.Errorf(`invalid "%v" param format: %w`, f.Name, err) + } + return &DialerSelectionPolicy{ + Policy: consts.DialerSelectionPolicy(f.Name), + FixedIndex: index, + }, nil + default: + return nil, fmt.Errorf("unexpected policy func: %v", f.Name) + } + default: + return nil, fmt.Errorf("unexpected param.Policy.(type): %T", val) + } +} diff --git a/component/outbound/filter.go b/component/outbound/filter.go index 660f129..82866c7 100644 --- a/component/outbound/filter.go +++ b/component/outbound/filter.go @@ -43,6 +43,7 @@ func NewDialerSetFromLinks(option *dialer.GlobalOption, nodes []string) *DialerS func hit(dialer *dialer.Dialer, filters []*config_parser.Function) (hit bool, err error) { // Example // filter: name(regex:'^.*hk.*$', keyword:'sg') && name(keyword:'disney') + // filter: !name(regex: 'HK|TW|SG') && name(keyword: disney) // And for _, filter := range filters { @@ -74,7 +75,8 @@ func hit(dialer *dialer.Dialer, filters []*config_parser.Function) (hit bool, er default: return false, fmt.Errorf(`unsupported filter input type: "%v"`, filter.Name) } - if !subFilterHit { + + if subFilterHit == filter.Not { return false, nil } } diff --git a/component/routing/matcher_builder.go b/component/routing/matcher_builder.go index 326bb43..e372b0c 100644 --- a/component/routing/matcher_builder.go +++ b/component/routing/matcher_builder.go @@ -15,28 +15,32 @@ import ( ) var FakeOutbound_AND = consts.OutboundLogicalAnd.String() +var FakeOutbound_OR = consts.OutboundLogicalOr.String() type MatcherBuilder interface { - AddDomain(key string, values []string, outbound string) - AddIp(values []netip.Prefix, outbound string) - AddPort(values [][2]int, outbound string) - AddSourceIp(values []netip.Prefix, outbound string) - AddSourcePort(values [][2]int, outbound string) - AddL4Proto(values consts.L4ProtoType, outbound string) - AddIpVersion(values consts.IpVersion, outbound string) - AddSourceMac(values [][6]byte, outbound string) + AddDomain(f *config_parser.Function, key string, values []string, outbound string) + AddIp(f *config_parser.Function, values []netip.Prefix, outbound string) + AddPort(f *config_parser.Function, values [][2]uint16, outbound string) + AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound string) + AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound string) + AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound string) + AddIpVersion(f *config_parser.Function, values consts.IpVersion, outbound string) + AddSourceMac(f *config_parser.Function, values [][6]byte, outbound string) AddFinal(outbound string) - AddAnyBefore(function string, key string, values []string, outbound string) - AddAnyAfter(function string, key string, values []string, outbound string) + AddAnyBefore(f *config_parser.Function, key string, values []string, outbound string) + AddAnyAfter(f *config_parser.Function, key string, values []string, outbound string) Build() (err error) } -func GroupParamValuesByKey(params []*config_parser.Param) map[string][]string { +func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) { groups := make(map[string][]string) for _, param := range params { + if _, ok := groups[param.Key]; !ok { + keyOrder = append(keyOrder, param.Key) + } groups[param.Key] = append(groups[param.Key], param.Val) } - return groups + return groups, keyOrder } func ParsePrefixes(values []string) (cidrs []netip.Prefix, err error) { @@ -59,26 +63,31 @@ func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingR // rule is like: domain(domain:baidu.com) && port(443) -> proxy for iFunc, f := range rule.AndFunctions { // f is like: domain(domain:baidu.com) - paramValueGroups := GroupParamValuesByKey(f.Params) - for key, paramValueGroup := range paramValueGroups { + paramValueGroups, keyOrder := GroupParamValuesByKey(f.Params) + for jMatchSet, key := range keyOrder { + paramValueGroup := paramValueGroups[key] // Preprocess the outbound and pass FakeOutbound_AND to all but the last function. - outbound := FakeOutbound_AND - if iFunc == len(rule.AndFunctions)-1 { - outbound = rule.Outbound + outbound := FakeOutbound_OR + if jMatchSet == len(keyOrder)-1 { + outbound = FakeOutbound_AND + if iFunc == len(rule.AndFunctions)-1 { + outbound = rule.Outbound + } } - builder.AddAnyBefore(f.Name, key, paramValueGroup, outbound) + + builder.AddAnyBefore(f, key, paramValueGroup, outbound) switch f.Name { case consts.Function_Domain: - builder.AddDomain(key, paramValueGroup, outbound) + builder.AddDomain(f, key, paramValueGroup, outbound) case consts.Function_Ip, consts.Function_SourceIp: cidrs, err := ParsePrefixes(paramValueGroup) if err != nil { return err } if f.Name == consts.Function_Ip { - builder.AddIp(cidrs, outbound) + builder.AddIp(f, cidrs, outbound) } else { - builder.AddSourceIp(cidrs, outbound) + builder.AddSourceIp(f, cidrs, outbound) } case consts.Function_Mac: var macAddrs [][6]byte @@ -89,9 +98,9 @@ func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingR } macAddrs = append(macAddrs, mac) } - builder.AddSourceMac(macAddrs, outbound) + builder.AddSourceMac(f, macAddrs, outbound) case consts.Function_Port, consts.Function_SourcePort: - var portRanges [][2]int + var portRanges [][2]uint16 for _, v := range paramValueGroup { portRange, err := common.ParsePortRange(v) if err != nil { @@ -100,9 +109,9 @@ func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingR portRanges = append(portRanges, portRange) } if f.Name == consts.Function_Port { - builder.AddPort(portRanges, outbound) + builder.AddPort(f, portRanges, outbound) } else { - builder.AddSourcePort(portRanges, outbound) + builder.AddSourcePort(f, portRanges, outbound) } case consts.Function_L4Proto: var l4protoType consts.L4ProtoType @@ -114,7 +123,7 @@ func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingR l4protoType |= consts.L4ProtoType_UDP } } - builder.AddL4Proto(l4protoType, outbound) + builder.AddL4Proto(f, l4protoType, outbound) case consts.Function_IpVersion: var ipVersion consts.IpVersion for _, v := range paramValueGroup { @@ -125,33 +134,46 @@ func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingR ipVersion |= consts.IpVersion_6 } } - builder.AddIpVersion(ipVersion, outbound) + builder.AddIpVersion(f, ipVersion, outbound) default: return fmt.Errorf("unsupported function name: %v", f.Name) } - builder.AddAnyAfter(f.Name, key, paramValueGroup, outbound) + builder.AddAnyAfter(f, key, paramValueGroup, outbound) } } } - builder.AddAnyBefore("final", "", nil, finalOutbound) + builder.AddAnyBefore(&config_parser.Function{ + Name: "final", + }, "", nil, finalOutbound) builder.AddFinal(finalOutbound) - builder.AddAnyAfter("final", "", nil, finalOutbound) + builder.AddAnyAfter(&config_parser.Function{ + Name: "final", + }, "", nil, finalOutbound) return nil } -type DefaultMatcherBuilder struct{} - -func (d *DefaultMatcherBuilder) AddDomain(values []string, outbound string) {} -func (d *DefaultMatcherBuilder) AddIp(values []netip.Prefix, outbound string) {} -func (d *DefaultMatcherBuilder) AddPort(values [][2]int, outbound string) {} -func (d *DefaultMatcherBuilder) AddSource(values []netip.Prefix, outbound string) {} -func (d *DefaultMatcherBuilder) AddSourcePort(values [][2]int, outbound string) {} -func (d *DefaultMatcherBuilder) AddL4Proto(values consts.L4ProtoType, outbound string) {} -func (d *DefaultMatcherBuilder) AddIpVersion(values consts.IpVersion, outbound string) {} -func (d *DefaultMatcherBuilder) AddMac(values [][6]byte, outbound string) {} -func (d *DefaultMatcherBuilder) AddFinal(outbound string) {} -func (d *DefaultMatcherBuilder) AddAnyBefore(function string, key string, values []string, outbound string) { +type DefaultMatcherBuilder struct { } -func (d *DefaultMatcherBuilder) AddAnyAfter(function string, key string, values []string, outbound string) { + +func (d *DefaultMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound string) { +} +func (d *DefaultMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound string) { +} +func (d *DefaultMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound string) { +} +func (d *DefaultMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound string) { +} +func (d *DefaultMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound string) { +} +func (d *DefaultMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound string) { +} +func (d *DefaultMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersion, outbound string) { +} +func (d *DefaultMatcherBuilder) AddSourceMac(f *config_parser.Function, values [][6]byte, outbound string) { +} +func (d *DefaultMatcherBuilder) AddFinal(outbound string) {} +func (d *DefaultMatcherBuilder) AddAnyBefore(f *config_parser.Function, key string, values []string, outbound string) { +} +func (d *DefaultMatcherBuilder) AddAnyAfter(f *config_parser.Function, key string, values []string, outbound string) { } func (d *DefaultMatcherBuilder) Build() (err error) { return nil } diff --git a/example.conf b/example.conf index 2d13598..347e13f 100644 --- a/example.conf +++ b/example.conf @@ -33,12 +33,15 @@ group { filter: name(keyword: HK) # Randomly select a node from the group for every connection. + # policy: random + + # Select the first node from the group for every connection. policy: fixed(0) } disney { # Pass node names as input of keyword/regex filter. - filter: name(regex:'HK|SG|TW', keyword:'JP') && name(keyword:'GCP') + filter: name(regex:'HK|SG|TW', keyword:'JP') && !name(keyword:'GCP') # Select the node with min average of the last 10 latencies from the group for every connection. policy: min_avg10 @@ -48,15 +51,18 @@ group { # Pass node names as input of keyword filter. filter: name(keyword:AWS) - # Select the first node from the group for every connection. + # Select the node with min last latency from the group for every connection. policy: min } } routing { + #ip(geoip:private)->direct + !port(443) -> direct + sport(123) -> direct + !sip(192.168.0.252/30) -> direct domain(geosite:category-ads) -> block l4proto(udp) && mac('02:42:ac:11:00:03') -> my_group - domain(geosite:category-ads) -> block domain(geosite:disney) -> disney domain(geosite:netflix) -> netflix ip(geoip:cn) -> direct diff --git a/pkg/config_parser/section.go b/pkg/config_parser/section.go index 2ce8239..3b38f2e 100644 --- a/pkg/config_parser/section.go +++ b/pkg/config_parser/section.go @@ -129,6 +129,7 @@ func (p *Param) String(compact bool) string { type Function struct { Name string + Not bool Params []*Param } diff --git a/pkg/config_parser/walker.go b/pkg/config_parser/walker.go index cd26d50..d12d495 100644 --- a/pkg/config_parser/walker.go +++ b/pkg/config_parser/walker.go @@ -80,8 +80,14 @@ type functionVerifier func(function *Function, ctx interface{}) bool func (w *Walker) parseFunctionPrototype(ctx *dae_config.FunctionPrototypeContext, verifier functionVerifier) *Function { children := ctx.GetChildren() - funcName := children[0].(*antlr.TerminalNodeImpl).GetText() - paramList := children[2].(*dae_config.OptParameterListContext) + not := false + offset := 0 + if children[0].(*antlr.TerminalNodeImpl).GetText() == "!" { + offset++ + not = true + } + funcName := children[offset+0].(*antlr.TerminalNodeImpl).GetText() + paramList := children[offset+2].(*dae_config.OptParameterListContext) children = paramList.GetChildren() if len(children) == 0 { w.ReportError(ctx, ErrorType_Unsupported, "empty parameter list") @@ -91,6 +97,7 @@ func (w *Walker) parseFunctionPrototype(ctx *dae_config.FunctionPrototypeContext params := w.parseNonEmptyParamList(nonEmptyParamList) f := &Function{ Name: funcName, + Not: not, Params: params, } // Verify function name and param keys. @@ -117,6 +124,14 @@ func (w *Walker) ReportError(ctx interface{}, errorType ErrorType, target ...str w.parser.NotifyErrorListeners(fmt.Sprintf("%v %v.", tgt, errorType), bCtx.GetStart(), nil) } +func (w *Walker) declarationFunctionVerifier(function *Function, ctx interface{}) bool { + //if function.Not { + // w.ReportError(ctx, ErrorType_Unsupported, "Not operator in param declaration") + // return false + //} + return true +} + func (w *Walker) parseDeclaration(ctx dae_config.IDeclarationContext) *Param { children := ctx.GetChildren() key := children[0].(*antlr.TerminalNodeImpl).GetText() @@ -128,7 +143,7 @@ func (w *Walker) parseDeclaration(ctx dae_config.IDeclarationContext) *Param { Val: value, } case *dae_config.FunctionPrototypeExpressionContext: - andFunctions := w.parseFunctionPrototypeExpression(valueCtx, nil) + andFunctions := w.parseFunctionPrototypeExpression(valueCtx, w.declarationFunctionVerifier) if andFunctions == nil { return nil }