feat: support not operator and port, sport routing func

This commit is contained in:
mzz2017 2023-01-29 07:31:52 +08:00
parent 49afec8079
commit c8d11cbecf
14 changed files with 322 additions and 156 deletions

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
.vscode
.idea
*.o
*.tmp
bpf_bpfeb.go
bpf_bpfel.go
dae

View File

@ -7,7 +7,6 @@ package consts
const (
AppName = "dae"
MaxInterfaceIpNum = 8
BpfPinRoot = "/sys/fs/bpf"
AddrHdrSize = 20
@ -49,7 +48,8 @@ type OutboundIndex uint8
const (
OutboundDirect OutboundIndex = 0
OutboundBlock OutboundIndex = 1
OutboundControlPlaneDirect OutboundIndex = 0xFE
OutboundControlPlaneDirect OutboundIndex = 0xFD
OutboundLogicalOr OutboundIndex = 0xFE
OutboundLogicalAnd OutboundIndex = 0xFF
)
@ -61,6 +61,8 @@ func (i OutboundIndex) String() string {
return "block"
case OutboundControlPlaneDirect:
return "<Control Plane Direct>"
case OutboundLogicalOr:
return "<OR>"
case OutboundLogicalAnd:
return "<AND>"
default:

View File

@ -106,7 +106,7 @@ func ParseMac(mac string) (addr [6]byte, err error) {
return addr, nil
}
func ParsePortRange(pr string) (portRange [2]int, err error) {
func ParsePortRange(pr string) (portRange [2]uint16, err error) {
fields := strings.SplitN(pr, "-", 2)
for i, field := range fields {
if field == "" {
@ -119,7 +119,7 @@ func ParsePortRange(pr string) (portRange [2]int, err error) {
if port < 0 || port > 0xffff {
return portRange, fmt.Errorf("port %v exceeds uint16 range", port)
}
portRange[i] = port
portRange[i] = uint16(port)
}
if len(fields) == 1 {
portRange[1] = portRange[0]

View File

@ -6,6 +6,7 @@
package control
import (
"encoding/binary"
"fmt"
"github.com/cilium/ebpf"
"github.com/v2rayA/dae/common"
@ -19,6 +20,18 @@ type _bpfLpmKey struct {
Data [4]uint32
}
type _bpfPortRange struct {
PortStart uint16
PortEnd uint16
}
func (r _bpfPortRange) Encode() uint32 {
var b [4]byte
binary.LittleEndian.PutUint16(b[:2], r.PortStart)
binary.LittleEndian.PutUint16(b[2:], r.PortEnd)
return binary.BigEndian.Uint32(b[:])
}
func (o *bpfObjects) newLpmMap(keys []_bpfLpmKey, values []uint32) (m *ebpf.Map, err error) {
m, err = ebpf.NewMap(&ebpf.MapSpec{
Type: ebpf.LPMTrie,

View File

@ -26,6 +26,7 @@ import (
"net/netip"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"sync"
@ -70,11 +71,18 @@ func NewControlPlane(
// Load pre-compiled programs and maps into the kernel.
var bpf bpfObjects
var ProgramOptions ebpf.ProgramOptions
if log.IsLevelEnabled(logrus.TraceLevel) {
ProgramOptions = ebpf.ProgramOptions{
LogLevel: ebpf.LogLevelStats,
}
}
retryLoadBpf:
if err = loadBpfObjects(&bpf, &ebpf.CollectionOptions{
Maps: ebpf.MapOptions{
PinPath: pinPath,
},
Programs: ProgramOptions,
}); err != nil {
if errors.Is(err, ebpf.ErrMapIncompatible) {
// Map property is incompatible. Remove the old map and try again.
@ -88,7 +96,17 @@ retryLoadBpf:
log.Warnf("New map format was incompatible with existing map %v, and the old one was removed.", mapName)
goto retryLoadBpf
}
return nil, fmt.Errorf("loading objects: %w", err)
// Get detailed log from ebpf.internal.*VerifierError
if log.IsLevelEnabled(logrus.TraceLevel) {
if v := reflect.Indirect(reflect.ValueOf(errors.Unwrap(errors.Unwrap(err)))); v.Kind() == reflect.Struct {
if log := v.FieldByName("Log"); log.IsValid() {
if strSlice, ok := log.Interface().([]string); ok {
err = fmt.Errorf("%v", strings.Join(strSlice, "\n"))
}
}
}
}
return nil, fmt.Errorf("loading objects: %v", err)
}
// Write params.
if err = bpf.ParamMap.Update(consts.DisableL4TxChecksumKey, consts.DisableL4ChecksumPolicy_SetZero, ebpf.UpdateAny); err != nil {

View File

@ -42,15 +42,16 @@
#define OUTBOUND_DIRECT 0
#define OUTBOUND_BLOCK 1
#define OUTBOUND_CONTROL_PLANE_DIRECT 0xFE
#define OUTBOUND_CONTROL_PLANE_DIRECT 0xFD
#define OUTBOUND_LOGICAL_OR 0xFE
#define OUTBOUND_LOGICAL_AND 0xFF
#define OUTBOUND_LOGICAL_MASK 0xFE
enum {
DISABLE_L4_CHECKSUM_POLICY_ENABLE_L4_CHECKSUM,
DISABLE_L4_CHECKSUM_POLICY_RESTORE,
DISABLE_L4_CHECKSUM_POLICY_SET_ZERO,
};
#define OUTBOUND_LOGICAL_AND 0xFF
// Param keys:
static const __u32 zero_key = 0;
@ -208,7 +209,17 @@ struct port_range {
__u16 port_start;
__u16 port_end;
};
struct routing {
/*
Look at following rule:
domain(geosite:cn, suffix: google.com) && l4proto(tcp) -> my_group
pseudocode: domain(geosite:cn || suffix:google.com) && l4proto(tcp) -> my_group
A match_set can be: IP set geosite:cn, suffix google.com, tcp proto
*/
struct match_set {
union {
__u32 __value; // Placeholder for bpf2go.
@ -218,13 +229,13 @@ struct routing {
enum IP_VERSION ip_version;
};
enum ROUTING_TYPE type;
__u8 outbound; // 255 means logical AND. 254 means dirty. User-defined value
// range is [0, 253].
bool not ; // A subrule flag (this is not a match_set flag).
__u8 outbound; // User-defined value range is [0, 252].
};
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
__type(key, __u32);
__type(value, struct routing);
__type(value, struct match_set);
__uint(max_entries, MAX_ROUTING_LEN);
// __uint(pinning, LIBBPF_PIN_BY_NAME);
} routing_map SEC(".maps");
@ -809,42 +820,45 @@ static long routing(__u32 flag[3], void *l4_hdr, __be32 saddr[4],
bpf_map_update_elem(&lpm_key_map, &key, &lpm_key_mac, BPF_ANY);
struct map_lpm_type *lpm;
struct routing *routing;
// Rule is like: domain(domain:baidu.com) && port(443) -> proxy
struct match_set *match_set;
// Rule is like: domain(suffix:baidu.com, suffix:google.com) && port(443) ->
// proxy Subrule is like: domain(suffix:baidu.com, suffix:google.com) Match
// set is like: suffix:baidu.com
bool bad_rule = false;
bool good_subrule = false;
struct domain_routing *domain_routing;
__u32 *p_u32;
#pragma unroll
for (__u32 i = 0; i < MAX_ROUTING_LEN; i++) {
__u32 k = i; // Clone to pass code checker.
routing = bpf_map_lookup_elem(&routing_map, &k);
if (!routing) {
match_set = bpf_map_lookup_elem(&routing_map, &k);
if (!match_set) {
return -EFAULT;
}
if (bad_rule) {
if (bad_rule || good_subrule) {
goto before_next_loop;
}
key = (key & (__u32)0) | (__u32)routing->type;
key = (key & (__u32)0) | (__u32)match_set->type;
if ((lpm_key = bpf_map_lookup_elem(&lpm_key_map, &key))) {
lpm = bpf_map_lookup_elem(&lpm_array_map, &routing->index);
lpm = bpf_map_lookup_elem(&lpm_array_map, &match_set->index);
if (unlikely(!lpm)) {
return -EFAULT;
}
if (!bpf_map_lookup_elem(lpm, lpm_key)) {
// Routing not hit.
bad_rule = true;
if (bpf_map_lookup_elem(lpm, lpm_key)) {
// match_set hits.
good_subrule = true;
}
} else if ((p_u32 = bpf_map_lookup_elem(&h_port_map, &key))) {
if (*p_u32 < routing->port_range.port_start ||
*p_u32 > routing->port_range.port_end) {
bad_rule = true;
if (*p_u32 >= match_set->port_range.port_start &&
*p_u32 <= match_set->port_range.port_end) {
good_subrule = true;
}
} else if ((p_u32 = bpf_map_lookup_elem(&l4proto_ipversion_map, &key))) {
if (!(*p_u32 & routing->__value)) {
bad_rule = true;
if (*p_u32 & match_set->__value) {
good_subrule = true;
}
} else if (routing->type == ROUTING_TYPE_DOMAIN_SET) {
} else if (match_set->type == ROUTING_TYPE_DOMAIN_SET) {
// Bottleneck of insns limit.
// We fixed it by invoking bpf_map_lookup_elem here.
@ -852,31 +866,46 @@ static long routing(__u32 flag[3], void *l4_hdr, __be32 saddr[4],
domain_routing = bpf_map_lookup_elem(&domain_routing_map, daddr);
if (!domain_routing) {
// No domain corresponding to IP.
bad_rule = true;
goto before_next_loop;
}
// We use key instead of k to pass checker.
if (!((domain_routing->bitmap[i / 32] >> (i % 32)) & 1)) {
bad_rule = true;
if ((domain_routing->bitmap[i / 32] >> (i % 32)) & 1) {
good_subrule = true;
}
} else if (routing->type == ROUTING_TYPE_FINAL) {
bad_rule = false;
} else if (match_set->type == ROUTING_TYPE_FINAL) {
good_subrule = true;
} else {
return -EINVAL;
}
before_next_loop:
if (routing->outbound != OUTBOUND_LOGICAL_AND) {
if (match_set->outbound != OUTBOUND_LOGICAL_OR && !bad_rule) {
// This match_set reaches the end of subrule.
// We are now at end of rule, or next match_set belongs to another
// subrule.
if (good_subrule == match_set->not ) {
// This subrule does not hit.
bad_rule = true;
} else {
// This subrule hits.
// Reset the good_subrule flag.
good_subrule = false;
}
}
if ((match_set->outbound & OUTBOUND_LOGICAL_MASK) !=
OUTBOUND_LOGICAL_MASK) {
// Tail of a rule (line).
// Decide whether to hit.
if (!bad_rule) {
if (routing->outbound == OUTBOUND_DIRECT && h_dport == 53 &&
if (match_set->outbound == OUTBOUND_DIRECT && h_dport == 53 &&
_l4proto == L4PROTO_TYPE_UDP) {
// DNS packet should go through control plane.
return OUTBOUND_CONTROL_PLANE_DIRECT;
}
return routing->outbound;
// bpf_printk("match_set->type: %d, match_set->not: %d", match_set->type,
// match_set->not );
return match_set->outbound;
}
bad_rule = false;
}

View File

@ -11,6 +11,7 @@ import (
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/pkg/config_parser"
"net/netip"
"strconv"
)
@ -25,7 +26,7 @@ type RoutingMatcherBuilder struct {
*routing.DefaultMatcherBuilder
outboundName2Id map[string]uint8
bpf *bpfObjects
rules []bpfRouting
rules []bpfMatchSet
SimulatedLpmTries [][]netip.Prefix
SimulatedDomainSet []DomainSet
Final string
@ -39,9 +40,12 @@ func NewRoutingMatcherBuilder(outboundName2Id map[string]uint8, bpf *bpfObjects)
func (b *RoutingMatcherBuilder) OutboundToId(outbound string) uint8 {
var outboundId uint8
if outbound == routing.FakeOutbound_AND {
switch outbound {
case routing.FakeOutbound_AND:
outboundId = uint8(consts.OutboundLogicalAnd)
} else {
case routing.FakeOutbound_OR:
outboundId = uint8(consts.OutboundLogicalOr)
default:
var ok bool
outboundId, ok = b.outboundName2Id[outbound]
if !ok {
@ -51,7 +55,7 @@ func (b *RoutingMatcherBuilder) OutboundToId(outbound string) uint8 {
return outboundId
}
func (b *RoutingMatcherBuilder) AddDomain(key string, values []string, outbound string) {
func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound string) {
if b.err != nil {
return
}
@ -69,13 +73,14 @@ func (b *RoutingMatcherBuilder) AddDomain(key string, values []string, outbound
RuleIndex: len(b.rules),
Domains: values,
})
b.rules = append(b.rules, bpfRouting{
b.rules = append(b.rules, bpfMatchSet{
Type: uint32(consts.RoutingType_DomainSet),
Not: f.Not,
Outbound: b.OutboundToId(outbound),
})
}
func (b *RoutingMatcherBuilder) AddSourceMac(macAddrs [][6]byte, outbound string) {
func (b *RoutingMatcherBuilder) AddSourceMac(f *config_parser.Function, macAddrs [][6]byte, outbound string) {
if b.err != nil {
return
}
@ -88,58 +93,91 @@ func (b *RoutingMatcherBuilder) AddSourceMac(macAddrs [][6]byte, outbound string
}
lpmTrieIndex := len(b.SimulatedLpmTries)
b.SimulatedLpmTries = append(b.SimulatedLpmTries, values)
b.rules = append(b.rules, bpfRouting{
b.rules = append(b.rules, bpfMatchSet{
Type: uint32(consts.RoutingType_Mac),
Value: uint32(lpmTrieIndex),
Not: f.Not,
Outbound: b.OutboundToId(outbound),
})
}
func (b *RoutingMatcherBuilder) AddIp(values []netip.Prefix, outbound string) {
func (b *RoutingMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound string) {
if b.err != nil {
return
}
lpmTrieIndex := len(b.SimulatedLpmTries)
b.SimulatedLpmTries = append(b.SimulatedLpmTries, values)
b.rules = append(b.rules, bpfRouting{
b.rules = append(b.rules, bpfMatchSet{
Type: uint32(consts.RoutingType_IpSet),
Value: uint32(lpmTrieIndex),
Not: f.Not,
Outbound: b.OutboundToId(outbound),
})
}
func (b *RoutingMatcherBuilder) AddSourceIp(values []netip.Prefix, outbound string) {
func (b *RoutingMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound string) {
for _, value := range values {
b.rules = append(b.rules, bpfMatchSet{
Type: uint32(consts.RoutingType_Port),
Value: _bpfPortRange{
PortStart: value[0],
PortEnd: value[1],
}.Encode(),
Not: f.Not,
Outbound: b.OutboundToId(outbound),
})
}
}
func (b *RoutingMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound string) {
if b.err != nil {
return
}
lpmTrieIndex := len(b.SimulatedLpmTries)
b.SimulatedLpmTries = append(b.SimulatedLpmTries, values)
b.rules = append(b.rules, bpfRouting{
b.rules = append(b.rules, bpfMatchSet{
Type: uint32(consts.RoutingType_SourceIpSet),
Value: uint32(lpmTrieIndex),
Not: f.Not,
Outbound: b.OutboundToId(outbound),
})
}
func (b *RoutingMatcherBuilder) AddL4Proto(values consts.L4ProtoType, outbound string) {
func (b *RoutingMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound string) {
for _, value := range values {
b.rules = append(b.rules, bpfMatchSet{
Type: uint32(consts.RoutingType_SourcePort),
Value: _bpfPortRange{
PortStart: value[0],
PortEnd: value[1],
}.Encode(),
Not: f.Not,
Outbound: b.OutboundToId(outbound),
})
}
}
func (b *RoutingMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound string) {
if b.err != nil {
return
}
b.rules = append(b.rules, bpfRouting{
b.rules = append(b.rules, bpfMatchSet{
Type: uint32(consts.RoutingType_L4Proto),
Value: uint32(values),
Not: f.Not,
Outbound: b.OutboundToId(outbound),
})
}
func (b *RoutingMatcherBuilder) AddIpVersion(values consts.IpVersion, outbound string) {
func (b *RoutingMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersion, outbound string) {
if b.err != nil {
return
}
b.rules = append(b.rules, bpfRouting{
b.rules = append(b.rules, bpfMatchSet{
Type: uint32(consts.RoutingType_IpVersion),
Value: uint32(values),
Not: f.Not,
Outbound: b.OutboundToId(outbound),
})
}
@ -149,7 +187,7 @@ func (b *RoutingMatcherBuilder) AddFinal(outbound string) {
return
}
b.Final = outbound
b.rules = append(b.rules, bpfRouting{
b.rules = append(b.rules, bpfMatchSet{
Type: uint32(consts.RoutingType_Final),
Outbound: b.OutboundToId(outbound),
})
@ -193,3 +231,7 @@ func (b *RoutingMatcherBuilder) Build() (err error) {
}
return nil
}
//func (b *RoutingMatcherBuilder) AddAnyBefore(f *config_parser.Function, key string, values []string, outbound string) {
// logrus.Debugln(f.Not, f.Name, key, outbound)
//}

View File

@ -10,63 +10,10 @@ import (
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/outbound/dialer"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/pkg/config_parser"
"golang.org/x/net/proxy"
"net"
"strconv"
)
type DialerSelectionPolicy struct {
Policy consts.DialerSelectionPolicy
FixedIndex int
}
func NewDialerSelectionPolicyFromGroupParam(param *config.GroupParam) (policy *DialerSelectionPolicy, err error) {
switch val := param.Policy.(type) {
case string:
switch consts.DialerSelectionPolicy(val) {
case consts.DialerSelectionPolicy_Random,
consts.DialerSelectionPolicy_MinAverage10Latencies,
consts.DialerSelectionPolicy_MinLastLatency:
return &DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy(val),
}, nil
case consts.DialerSelectionPolicy_Fixed:
return nil, fmt.Errorf("%v need to specify node index", val)
default:
return nil, fmt.Errorf("unexpected policy: %v", val)
}
case []*config_parser.Function:
if len(val) > 1 || len(val) == 0 {
logrus.Debugf("%@", val)
return nil, fmt.Errorf("policy should be exact 1 function: got %v", len(val))
}
f := val[0]
switch consts.DialerSelectionPolicy(f.Name) {
case consts.DialerSelectionPolicy_Fixed:
// Should be like:
// policy: fixed(0)
if len(f.Params) > 1 || f.Params[0].Key != "" {
return nil, fmt.Errorf(`invalid "%v" param format`, f.Name)
}
strIndex := f.Params[0].Val
index, err := strconv.Atoi(strIndex)
if len(f.Params) > 1 || f.Params[0].Key != "" {
return nil, fmt.Errorf(`invalid "%v" param format: %w`, f.Name, err)
}
return &DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy(f.Name),
FixedIndex: index,
}, nil
default:
return nil, fmt.Errorf("unexpected policy func: %v", f.Name)
}
default:
return nil, fmt.Errorf("unexpected param.Policy.(type): %T", val)
}
}
type DialerGroup struct {
proxy.Dialer
block *dialer.Dialer

View File

@ -0,0 +1,68 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2023, mzz2017 <mzz@tuta.io>
*/
package outbound
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/pkg/config_parser"
"strconv"
)
type DialerSelectionPolicy struct {
Policy consts.DialerSelectionPolicy
FixedIndex int
}
func NewDialerSelectionPolicyFromGroupParam(param *config.GroupParam) (policy *DialerSelectionPolicy, err error) {
switch val := param.Policy.(type) {
case string:
switch consts.DialerSelectionPolicy(val) {
case consts.DialerSelectionPolicy_Random,
consts.DialerSelectionPolicy_MinAverage10Latencies,
consts.DialerSelectionPolicy_MinLastLatency:
return &DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy(val),
}, nil
case consts.DialerSelectionPolicy_Fixed:
return nil, fmt.Errorf("%v need to specify node index", val)
default:
return nil, fmt.Errorf("unexpected policy: %v", val)
}
case []*config_parser.Function:
if len(val) > 1 || len(val) == 0 {
logrus.Debugf("%@", val)
return nil, fmt.Errorf("policy should be exact 1 function: got %v", len(val))
}
f := val[0]
switch consts.DialerSelectionPolicy(f.Name) {
case consts.DialerSelectionPolicy_Fixed:
// Should be like:
// policy: fixed(0)
if f.Not {
return nil, fmt.Errorf("policy param does not support not operator: !%v()", f.Name)
}
if len(f.Params) > 1 || f.Params[0].Key != "" {
return nil, fmt.Errorf(`invalid "%v" param format`, f.Name)
}
strIndex := f.Params[0].Val
index, err := strconv.Atoi(strIndex)
if len(f.Params) > 1 || f.Params[0].Key != "" {
return nil, fmt.Errorf(`invalid "%v" param format: %w`, f.Name, err)
}
return &DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy(f.Name),
FixedIndex: index,
}, nil
default:
return nil, fmt.Errorf("unexpected policy func: %v", f.Name)
}
default:
return nil, fmt.Errorf("unexpected param.Policy.(type): %T", val)
}
}

View File

@ -43,6 +43,7 @@ func NewDialerSetFromLinks(option *dialer.GlobalOption, nodes []string) *DialerS
func hit(dialer *dialer.Dialer, filters []*config_parser.Function) (hit bool, err error) {
// Example
// filter: name(regex:'^.*hk.*$', keyword:'sg') && name(keyword:'disney')
// filter: !name(regex: 'HK|TW|SG') && name(keyword: disney)
// And
for _, filter := range filters {
@ -74,7 +75,8 @@ func hit(dialer *dialer.Dialer, filters []*config_parser.Function) (hit bool, er
default:
return false, fmt.Errorf(`unsupported filter input type: "%v"`, filter.Name)
}
if !subFilterHit {
if subFilterHit == filter.Not {
return false, nil
}
}

View File

@ -15,28 +15,32 @@ import (
)
var FakeOutbound_AND = consts.OutboundLogicalAnd.String()
var FakeOutbound_OR = consts.OutboundLogicalOr.String()
type MatcherBuilder interface {
AddDomain(key string, values []string, outbound string)
AddIp(values []netip.Prefix, outbound string)
AddPort(values [][2]int, outbound string)
AddSourceIp(values []netip.Prefix, outbound string)
AddSourcePort(values [][2]int, outbound string)
AddL4Proto(values consts.L4ProtoType, outbound string)
AddIpVersion(values consts.IpVersion, outbound string)
AddSourceMac(values [][6]byte, outbound string)
AddDomain(f *config_parser.Function, key string, values []string, outbound string)
AddIp(f *config_parser.Function, values []netip.Prefix, outbound string)
AddPort(f *config_parser.Function, values [][2]uint16, outbound string)
AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound string)
AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound string)
AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound string)
AddIpVersion(f *config_parser.Function, values consts.IpVersion, outbound string)
AddSourceMac(f *config_parser.Function, values [][6]byte, outbound string)
AddFinal(outbound string)
AddAnyBefore(function string, key string, values []string, outbound string)
AddAnyAfter(function string, key string, values []string, outbound string)
AddAnyBefore(f *config_parser.Function, key string, values []string, outbound string)
AddAnyAfter(f *config_parser.Function, key string, values []string, outbound string)
Build() (err error)
}
func GroupParamValuesByKey(params []*config_parser.Param) map[string][]string {
func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) {
groups := make(map[string][]string)
for _, param := range params {
if _, ok := groups[param.Key]; !ok {
keyOrder = append(keyOrder, param.Key)
}
groups[param.Key] = append(groups[param.Key], param.Val)
}
return groups
return groups, keyOrder
}
func ParsePrefixes(values []string) (cidrs []netip.Prefix, err error) {
@ -59,26 +63,31 @@ func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingR
// rule is like: domain(domain:baidu.com) && port(443) -> proxy
for iFunc, f := range rule.AndFunctions {
// f is like: domain(domain:baidu.com)
paramValueGroups := GroupParamValuesByKey(f.Params)
for key, paramValueGroup := range paramValueGroups {
paramValueGroups, keyOrder := GroupParamValuesByKey(f.Params)
for jMatchSet, key := range keyOrder {
paramValueGroup := paramValueGroups[key]
// Preprocess the outbound and pass FakeOutbound_AND to all but the last function.
outbound := FakeOutbound_AND
outbound := FakeOutbound_OR
if jMatchSet == len(keyOrder)-1 {
outbound = FakeOutbound_AND
if iFunc == len(rule.AndFunctions)-1 {
outbound = rule.Outbound
}
builder.AddAnyBefore(f.Name, key, paramValueGroup, outbound)
}
builder.AddAnyBefore(f, key, paramValueGroup, outbound)
switch f.Name {
case consts.Function_Domain:
builder.AddDomain(key, paramValueGroup, outbound)
builder.AddDomain(f, key, paramValueGroup, outbound)
case consts.Function_Ip, consts.Function_SourceIp:
cidrs, err := ParsePrefixes(paramValueGroup)
if err != nil {
return err
}
if f.Name == consts.Function_Ip {
builder.AddIp(cidrs, outbound)
builder.AddIp(f, cidrs, outbound)
} else {
builder.AddSourceIp(cidrs, outbound)
builder.AddSourceIp(f, cidrs, outbound)
}
case consts.Function_Mac:
var macAddrs [][6]byte
@ -89,9 +98,9 @@ func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingR
}
macAddrs = append(macAddrs, mac)
}
builder.AddSourceMac(macAddrs, outbound)
builder.AddSourceMac(f, macAddrs, outbound)
case consts.Function_Port, consts.Function_SourcePort:
var portRanges [][2]int
var portRanges [][2]uint16
for _, v := range paramValueGroup {
portRange, err := common.ParsePortRange(v)
if err != nil {
@ -100,9 +109,9 @@ func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingR
portRanges = append(portRanges, portRange)
}
if f.Name == consts.Function_Port {
builder.AddPort(portRanges, outbound)
builder.AddPort(f, portRanges, outbound)
} else {
builder.AddSourcePort(portRanges, outbound)
builder.AddSourcePort(f, portRanges, outbound)
}
case consts.Function_L4Proto:
var l4protoType consts.L4ProtoType
@ -114,7 +123,7 @@ func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingR
l4protoType |= consts.L4ProtoType_UDP
}
}
builder.AddL4Proto(l4protoType, outbound)
builder.AddL4Proto(f, l4protoType, outbound)
case consts.Function_IpVersion:
var ipVersion consts.IpVersion
for _, v := range paramValueGroup {
@ -125,33 +134,46 @@ func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingR
ipVersion |= consts.IpVersion_6
}
}
builder.AddIpVersion(ipVersion, outbound)
builder.AddIpVersion(f, ipVersion, outbound)
default:
return fmt.Errorf("unsupported function name: %v", f.Name)
}
builder.AddAnyAfter(f.Name, key, paramValueGroup, outbound)
builder.AddAnyAfter(f, key, paramValueGroup, outbound)
}
}
}
builder.AddAnyBefore("final", "", nil, finalOutbound)
builder.AddAnyBefore(&config_parser.Function{
Name: "final",
}, "", nil, finalOutbound)
builder.AddFinal(finalOutbound)
builder.AddAnyAfter("final", "", nil, finalOutbound)
builder.AddAnyAfter(&config_parser.Function{
Name: "final",
}, "", nil, finalOutbound)
return nil
}
type DefaultMatcherBuilder struct{}
func (d *DefaultMatcherBuilder) AddDomain(values []string, outbound string) {}
func (d *DefaultMatcherBuilder) AddIp(values []netip.Prefix, outbound string) {}
func (d *DefaultMatcherBuilder) AddPort(values [][2]int, outbound string) {}
func (d *DefaultMatcherBuilder) AddSource(values []netip.Prefix, outbound string) {}
func (d *DefaultMatcherBuilder) AddSourcePort(values [][2]int, outbound string) {}
func (d *DefaultMatcherBuilder) AddL4Proto(values consts.L4ProtoType, outbound string) {}
func (d *DefaultMatcherBuilder) AddIpVersion(values consts.IpVersion, outbound string) {}
func (d *DefaultMatcherBuilder) AddMac(values [][6]byte, outbound string) {}
func (d *DefaultMatcherBuilder) AddFinal(outbound string) {}
func (d *DefaultMatcherBuilder) AddAnyBefore(function string, key string, values []string, outbound string) {
type DefaultMatcherBuilder struct {
}
func (d *DefaultMatcherBuilder) AddAnyAfter(function string, key string, values []string, outbound string) {
func (d *DefaultMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound string) {
}
func (d *DefaultMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound string) {
}
func (d *DefaultMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound string) {
}
func (d *DefaultMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound string) {
}
func (d *DefaultMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound string) {
}
func (d *DefaultMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound string) {
}
func (d *DefaultMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersion, outbound string) {
}
func (d *DefaultMatcherBuilder) AddSourceMac(f *config_parser.Function, values [][6]byte, outbound string) {
}
func (d *DefaultMatcherBuilder) AddFinal(outbound string) {}
func (d *DefaultMatcherBuilder) AddAnyBefore(f *config_parser.Function, key string, values []string, outbound string) {
}
func (d *DefaultMatcherBuilder) AddAnyAfter(f *config_parser.Function, key string, values []string, outbound string) {
}
func (d *DefaultMatcherBuilder) Build() (err error) { return nil }

View File

@ -33,12 +33,15 @@ group {
filter: name(keyword: HK)
# Randomly select a node from the group for every connection.
# policy: random
# Select the first node from the group for every connection.
policy: fixed(0)
}
disney {
# Pass node names as input of keyword/regex filter.
filter: name(regex:'HK|SG|TW', keyword:'JP') && name(keyword:'GCP')
filter: name(regex:'HK|SG|TW', keyword:'JP') && !name(keyword:'GCP')
# Select the node with min average of the last 10 latencies from the group for every connection.
policy: min_avg10
@ -48,15 +51,18 @@ group {
# Pass node names as input of keyword filter.
filter: name(keyword:AWS)
# Select the first node from the group for every connection.
# Select the node with min last latency from the group for every connection.
policy: min
}
}
routing {
#ip(geoip:private)->direct
!port(443) -> direct
sport(123) -> direct
!sip(192.168.0.252/30) -> direct
domain(geosite:category-ads) -> block
l4proto(udp) && mac('02:42:ac:11:00:03') -> my_group
domain(geosite:category-ads) -> block
domain(geosite:disney) -> disney
domain(geosite:netflix) -> netflix
ip(geoip:cn) -> direct

View File

@ -129,6 +129,7 @@ func (p *Param) String(compact bool) string {
type Function struct {
Name string
Not bool
Params []*Param
}

View File

@ -80,8 +80,14 @@ type functionVerifier func(function *Function, ctx interface{}) bool
func (w *Walker) parseFunctionPrototype(ctx *dae_config.FunctionPrototypeContext, verifier functionVerifier) *Function {
children := ctx.GetChildren()
funcName := children[0].(*antlr.TerminalNodeImpl).GetText()
paramList := children[2].(*dae_config.OptParameterListContext)
not := false
offset := 0
if children[0].(*antlr.TerminalNodeImpl).GetText() == "!" {
offset++
not = true
}
funcName := children[offset+0].(*antlr.TerminalNodeImpl).GetText()
paramList := children[offset+2].(*dae_config.OptParameterListContext)
children = paramList.GetChildren()
if len(children) == 0 {
w.ReportError(ctx, ErrorType_Unsupported, "empty parameter list")
@ -91,6 +97,7 @@ func (w *Walker) parseFunctionPrototype(ctx *dae_config.FunctionPrototypeContext
params := w.parseNonEmptyParamList(nonEmptyParamList)
f := &Function{
Name: funcName,
Not: not,
Params: params,
}
// Verify function name and param keys.
@ -117,6 +124,14 @@ func (w *Walker) ReportError(ctx interface{}, errorType ErrorType, target ...str
w.parser.NotifyErrorListeners(fmt.Sprintf("%v %v.", tgt, errorType), bCtx.GetStart(), nil)
}
func (w *Walker) declarationFunctionVerifier(function *Function, ctx interface{}) bool {
//if function.Not {
// w.ReportError(ctx, ErrorType_Unsupported, "Not operator in param declaration")
// return false
//}
return true
}
func (w *Walker) parseDeclaration(ctx dae_config.IDeclarationContext) *Param {
children := ctx.GetChildren()
key := children[0].(*antlr.TerminalNodeImpl).GetText()
@ -128,7 +143,7 @@ func (w *Walker) parseDeclaration(ctx dae_config.IDeclarationContext) *Param {
Val: value,
}
case *dae_config.FunctionPrototypeExpressionContext:
andFunctions := w.parseFunctionPrototypeExpression(valueCtx, nil)
andFunctions := w.parseFunctionPrototypeExpression(valueCtx, w.declarationFunctionVerifier)
if andFunctions == nil {
return nil
}