feat: support config file

This commit is contained in:
mzz2017
2023-01-28 01:50:21 +08:00
parent edbce81e88
commit 4c248e9e1a
35 changed files with 1168 additions and 898 deletions

View File

@ -17,6 +17,8 @@ import (
"github.com/v2rayA/dae/component/outbound"
"github.com/v2rayA/dae/component/outbound/dialer"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/pkg/config_parser"
"github.com/v2rayA/dae/pkg/pool"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
@ -49,9 +51,16 @@ type ControlPlane struct {
deferFuncs []func() error
}
func NewControlPlane(log *logrus.Logger, dialerGroups []*outbound.DialerGroup, routingA string) (*ControlPlane, error) {
func NewControlPlane(
log *logrus.Logger,
nodes []string,
groups []config.Group,
routingA *config.Routing,
dnsUpstream string,
checkUrl string,
) (c *ControlPlane, err error) {
// Allow the current process to lock memory for eBPF resources.
if err := rlimit.RemoveMemlock(); err != nil {
if err = rlimit.RemoveMemlock(); err != nil {
return nil, fmt.Errorf("rlimit.RemoveMemlock:%v", err)
}
pinPath := filepath.Join(consts.BpfPinRoot, consts.AppName)
@ -60,7 +69,7 @@ func NewControlPlane(log *logrus.Logger, dialerGroups []*outbound.DialerGroup, r
// Load pre-compiled programs and maps into the kernel.
var bpf bpfObjects
retryLoadBpf:
if err := loadBpfObjects(&bpf, &ebpf.CollectionOptions{
if err = loadBpfObjects(&bpf, &ebpf.CollectionOptions{
Maps: ebpf.MapOptions{
PinPath: pinPath,
},
@ -81,29 +90,48 @@ retryLoadBpf:
}
// Write params.
if err := bpf.ParamMap.Update(consts.DisableL4TxChecksumKey, consts.DisableL4ChecksumPolicy_SetZero, ebpf.UpdateAny); err != nil {
if err = bpf.ParamMap.Update(consts.DisableL4TxChecksumKey, consts.DisableL4ChecksumPolicy_SetZero, ebpf.UpdateAny); err != nil {
return nil, err
}
if err := bpf.ParamMap.Update(consts.DisableL4RxChecksumKey, consts.DisableL4ChecksumPolicy_SetZero, ebpf.UpdateAny); err != nil {
if err = bpf.ParamMap.Update(consts.DisableL4RxChecksumKey, consts.DisableL4ChecksumPolicy_SetZero, ebpf.UpdateAny); err != nil {
return nil, err
}
// DialerGroups (outbounds).
option := &dialer.GlobalOption{
Log: log,
CheckUrl: checkUrl,
}
outbounds := []*outbound.DialerGroup{
outbound.NewDialerGroup(log, consts.OutboundDirect.String(),
[]*dialer.Dialer{dialer.NewDirectDialer(log, true)},
outbound.NewDialerGroup(option, consts.OutboundDirect.String(),
[]*dialer.Dialer{dialer.NewDirectDialer(option, true)},
outbound.DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Fixed,
FixedIndex: 0,
}),
outbound.NewDialerGroup(log, consts.OutboundBlock.String(),
[]*dialer.Dialer{dialer.NewBlockDialer(log)},
outbound.NewDialerGroup(option, consts.OutboundBlock.String(),
[]*dialer.Dialer{dialer.NewBlockDialer(option)},
outbound.DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Fixed,
FixedIndex: 0,
}),
}
outbounds = append(outbounds, dialerGroups...)
// Filter out groups.
dialerSet := outbound.NewDialerSetFromLinks(option, nodes)
for _, group := range groups {
dialers, err := dialerSet.Filter(group.Param.Filter)
if err != nil {
return nil, fmt.Errorf(`failed to create group "%v": %w`, group.Name, err)
}
policy, err := outbound.NewDialerSelectionPolicyFromGroupParam(&group.Param)
if err != nil {
return nil, fmt.Errorf("failed to create group %v: %w", group.Name, err)
}
dialerGroup := outbound.NewDialerGroup(option, group.Name, dialers, *policy)
outbounds = append(outbounds, dialerGroup)
}
// Generate outboundName2Id from outbounds.
if len(outbounds) > 0xff {
return nil, fmt.Errorf("too many outbounds")
@ -115,11 +143,8 @@ retryLoadBpf:
builder := NewRoutingMatcherBuilder(outboundName2Id, &bpf)
// Routing.
rules, final, err := routing.Parse(routingA)
if err != nil {
return nil, fmt.Errorf("routingA error:\n%w", err)
}
if rules, err = routing.ApplyRulesOptimizers(rules,
var rules []*config_parser.RoutingRule
if rules, err = routing.ApplyRulesOptimizers(routingA.Rules,
&routing.RefineFunctionParamKeyOptimizer{},
&routing.DatReaderOptimizer{Logger: log},
&routing.MergeAndSortRulesOptimizer{},
@ -130,24 +155,26 @@ retryLoadBpf:
if log.IsLevelEnabled(logrus.TraceLevel) {
var debugBuilder strings.Builder
for _, rule := range rules {
debugBuilder.WriteString(rule.String(true))
debugBuilder.WriteString(rule.String(true) + "\n")
}
log.Tracef("RoutingA:\n%vfinal: %v\n", debugBuilder.String(), final)
log.Tracef("RoutingA:\n%vfinal: %v\n", debugBuilder.String(), routingA.Final)
}
if err := routing.ApplyMatcherBuilder(builder, rules, final); err != nil {
if err = routing.ApplyMatcherBuilder(builder, rules, routingA.Final); err != nil {
return nil, fmt.Errorf("ApplyMatcherBuilder: %w", err)
}
if err := builder.Build(); err != nil {
if err = builder.Build(); err != nil {
return nil, fmt.Errorf("RoutingMatcherBuilder.Build: %w", err)
}
// DNS upstream.
cfDnsAddr := netip.AddrFrom4([4]byte{1, 1, 1, 1})
cfDnsAddr16 := cfDnsAddr.As16()
cfDnsPort := uint16(53)
if err := bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{
Ip: common.Ipv6ByteSliceToUint32Array(cfDnsAddr16[:]),
Port: swap16(cfDnsPort),
dnsAddrPort, err := netip.ParseAddrPort(dnsUpstream)
if err != nil {
return nil, fmt.Errorf("failed to parse DNS upstream: %v: %w", dnsUpstream, err)
}
dnsAddr16 := dnsAddrPort.Addr().As16()
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{
Ip: common.Ipv6ByteSliceToUint32Array(dnsAddr16[:]),
Port: swap16(dnsAddrPort.Port()),
}, ebpf.UpdateAny); err != nil {
return nil, err
}
@ -159,10 +186,10 @@ retryLoadBpf:
bpf: &bpf,
SimulatedLpmTries: builder.SimulatedLpmTries,
SimulatedDomainSet: builder.SimulatedDomainSet,
Final: final,
Final: routingA.Final,
mutex: sync.Mutex{},
dnsCache: make(map[string]*dnsCache),
dnsUpstream: netip.AddrPortFrom(cfDnsAddr, cfDnsPort),
dnsUpstream: dnsAddrPort,
deferFuncs: []func() error{bpf.Close},
}, nil
}
@ -248,10 +275,13 @@ func (c *ControlPlane) BindLink(ifname string) error {
}
}
c.deferFuncs = append(c.deferFuncs, func() error {
return netlink.QdiscDel(qdisc)
if err := netlink.QdiscDel(qdisc); err != nil {
return fmt.Errorf("QdiscDel: %w", err)
}
return nil
})
filter := &netlink.BpfFilter{
filterIngress := &netlink.BpfFilter{
FilterAttrs: netlink.FilterAttrs{
LinkIndex: link.Attrs().Index,
Parent: netlink.HANDLE_MIN_INGRESS,
@ -263,12 +293,9 @@ func (c *ControlPlane) BindLink(ifname string) error {
Name: consts.AppName + "_ingress",
DirectAction: true,
}
if err := netlink.FilterAdd(filter); err != nil {
if err := netlink.FilterAdd(filterIngress); err != nil {
return fmt.Errorf("cannot attach ebpf object to filter ingress: %w", err)
}
c.deferFuncs = append(c.deferFuncs, func() error {
return netlink.FilterDel(filter)
})
filterEgress := &netlink.BpfFilter{
FilterAttrs: netlink.FilterAttrs{
LinkIndex: link.Attrs().Index,
@ -284,9 +311,6 @@ func (c *ControlPlane) BindLink(ifname string) error {
if err := netlink.FilterAdd(filterEgress); err != nil {
return fmt.Errorf("cannot attach ebpf object to filter ingress: %w", err)
}
c.deferFuncs = append(c.deferFuncs, func() error {
return netlink.FilterDel(filter)
})
return nil
}

View File

@ -97,11 +97,10 @@ func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEnd
}
udpConn, err := createOption.Dialer.Dial("udp", createOption.Target.String())
//udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
p.pool[lAddr] = &UdpEndpoint{
ue = &UdpEndpoint{
conn: udpConn.(net.PacketConn),
deadlineTimer: time.AfterFunc(createOption.NatTimeout, func() {
p.mu.Lock()
@ -114,7 +113,7 @@ func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEnd
handler: createOption.Handler,
NatTimeout: createOption.NatTimeout,
}
ue = p.pool[lAddr]
p.pool[lAddr] = ue
// Receive UDP messages.
go ue.start()
} else {

View File

@ -6,7 +6,6 @@
package dialer
import (
"github.com/sirupsen/logrus"
"net"
)
@ -16,6 +15,6 @@ func (*blockDialer) Dial(network string, addr string) (c net.Conn, err error) {
return nil, net.ErrClosed
}
func NewBlockDialer(log *logrus.Logger) *Dialer {
return newDialer(&blockDialer{}, log, true, "block", "block", "")
func NewBlockDialer(option *GlobalOption) *Dialer {
return newDialer(&blockDialer{}, option, true, "block", "block", "")
}

View File

@ -6,7 +6,6 @@ import (
"fmt"
"github.com/mzz2017/softwind/pkg/fastrand"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/consts"
"golang.org/x/net/proxy"
"net"
"net/http"
@ -18,14 +17,13 @@ import (
)
var (
ConnectivityTestFailedErr = fmt.Errorf("connectivity test failed")
ConnectivityTestFailedErr = fmt.Errorf("Connectivity Check failed")
UnexpectedFieldErr = fmt.Errorf("unexpected field")
InvalidParameterErr = fmt.Errorf("invalid parameters")
)
type Dialer struct {
log *logrus.Logger
*GlobalOption
proxy.Dialer
supportUDP bool
name string
@ -41,23 +39,28 @@ type Dialer struct {
ticker *time.Ticker
}
type GlobalOption struct {
Log *logrus.Logger
CheckUrl string
}
// NewDialer is for register in general.
func NewDialer(dialer proxy.Dialer, log *logrus.Logger, supportUDP bool, name string, protocol string, link string) *Dialer {
d := newDialer(dialer, log, supportUDP, name, protocol, link)
func NewDialer(dialer proxy.Dialer, option *GlobalOption, supportUDP bool, name string, protocol string, link string) *Dialer {
d := newDialer(dialer, option, supportUDP, name, protocol, link)
go d.aliveBackground()
return d
}
// newDialer does not run background tasks.
func newDialer(dialer proxy.Dialer, log *logrus.Logger, supportUDP bool, name string, protocol string, link string) *Dialer {
func newDialer(dialer proxy.Dialer, option *GlobalOption, supportUDP bool, name string, protocol string, link string) *Dialer {
d := &Dialer{
Dialer: dialer,
log: log,
supportUDP: supportUDP,
name: name,
protocol: protocol,
link: link,
Latencies10: NewLatenciesN(10),
Dialer: dialer,
GlobalOption: option,
supportUDP: supportUDP,
name: name,
protocol: protocol,
link: link,
Latencies10: NewLatenciesN(10),
// Set a very big cycle to wait for init.
ticker: time.NewTicker(time.Hour),
aliveDialerSetSet: make(map[*AliveDialerSet]int),
@ -68,8 +71,8 @@ func newDialer(dialer proxy.Dialer, log *logrus.Logger, supportUDP bool, name st
func (d *Dialer) aliveBackground() {
timeout := 10 * time.Second
cycle := 15 * time.Second
// Test once immediately.
go d.Test(timeout, consts.TestUrl)
// Check once immediately.
go d.Check(timeout, d.CheckUrl)
// Sleep to avoid avalanche.
time.Sleep(time.Duration(fastrand.Int63n(int64(cycle))))
@ -79,7 +82,7 @@ func (d *Dialer) aliveBackground() {
for range d.ticker.C {
// No need to test if there is no dialer selection policy using its latency.
if len(d.aliveDialerSetSet) > 0 {
d.Test(timeout, consts.TestUrl)
d.Check(timeout, d.CheckUrl)
}
}
}
@ -124,7 +127,7 @@ func (d *Dialer) Link() string {
return d.link
}
func (d *Dialer) Test(timeout time.Duration, url string) (ok bool, err error) {
func (d *Dialer) Check(timeout time.Duration, url string) (ok bool, err error) {
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
defer cancel()
start := time.Now()
@ -134,14 +137,13 @@ func (d *Dialer) Test(timeout time.Duration, url string) (ok bool, err error) {
if ok && err == nil {
// No error.
latency := time.Since(start)
// FIXME: Use log instead of logrus.
d.log.Debugf("Connectivity Test <%v>: %v", d.name, latency)
d.Log.Debugf("Connectivity Check [%v]: %v", d.name, latency)
d.Latencies10.AppendLatency(latency)
alive = true
} else {
// Append timeout if there is any error or unexpected status code.
if err != nil {
d.log.Debugf("Connectivity Test <%v>: %v", d.name, err.Error())
d.Log.Debugf("Connectivity Check [%v]: %v", d.name, err.Error())
}
d.Latencies10.AppendLatency(timeout)
}

View File

@ -1,7 +1,6 @@
package dialer
import (
"github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"net"
)
@ -9,23 +8,23 @@ import (
var SymmetricDirect = newDirect(false)
var FullconeDirect = newDirect(true)
func NewDirectDialer(log *logrus.Logger, fullcone bool) *Dialer {
func NewDirectDialer(option *GlobalOption, fullcone bool) *Dialer {
if fullcone {
return newDialer(FullconeDirect, log, true, "direct", "direct", "")
return newDialer(FullconeDirect, option, true, "direct", "direct", "")
} else {
return newDialer(SymmetricDirect, log, true, "direct", "direct", "")
return newDialer(SymmetricDirect, option, true, "direct", "direct", "")
}
}
type direct struct {
proxy.Dialer
netDialer net.Dialer
netDialer *net.Dialer
fullCone bool
}
func newDirect(fullCone bool) proxy.Dialer {
return &direct{
netDialer: net.Dialer{},
netDialer: &net.Dialer{},
fullCone: fullCone,
}
}

View File

@ -3,7 +3,6 @@ package http
import (
"fmt"
"github.com/mzz2017/softwind/protocol/http"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/component/outbound/dialer"
"net"
"net/url"
@ -25,12 +24,12 @@ type HTTP struct {
Protocol string `json:"protocol"`
}
func NewHTTP(log *logrus.Logger, link string) (*dialer.Dialer, error) {
func NewHTTP(option *dialer.GlobalOption, link string) (*dialer.Dialer, error) {
s, err := ParseHTTPURL(link)
if err != nil {
return nil, fmt.Errorf("%w: %v", dialer.InvalidParameterErr, err)
}
return s.Dialer(log)
return s.Dialer(option)
}
func ParseHTTPURL(link string) (data *HTTP, err error) {
@ -62,13 +61,13 @@ func ParseHTTPURL(link string) (data *HTTP, err error) {
}, nil
}
func (s *HTTP) Dialer(log *logrus.Logger) (*dialer.Dialer, error) {
func (s *HTTP) Dialer(option *dialer.GlobalOption) (*dialer.Dialer, error) {
u := s.URL()
d, err := http.NewHTTPProxy(&u, dialer.SymmetricDirect) // HTTP Proxy does not support full-cone.
if err != nil {
return nil, err
}
return dialer.NewDialer(d, log, false, s.Name, s.Protocol, u.String()), nil
return dialer.NewDialer(d, option, false, s.Name, s.Protocol, u.String()), nil
}
func (s *HTTP) URL() url.URL {

View File

@ -7,11 +7,10 @@ package dialer
import (
"fmt"
"github.com/sirupsen/logrus"
"net/url"
)
type FromLinkCreator func(log *logrus.Logger, link string) (dialer *Dialer, err error)
type FromLinkCreator func(option *GlobalOption, link string) (dialer *Dialer, err error)
var fromLinkCreators = make(map[string]FromLinkCreator)
@ -19,13 +18,13 @@ func FromLinkRegister(name string, creator FromLinkCreator) {
fromLinkCreators[name] = creator
}
func NewFromLink(log *logrus.Logger, link string) (dialer *Dialer, err error) {
func NewFromLink(option *GlobalOption, link string) (dialer *Dialer, err error) {
u, err := url.Parse(link)
if err != nil {
return nil, err
}
if creator, ok := fromLinkCreators[u.Scheme]; ok {
return creator(log, link)
return creator(option, link)
} else {
return nil, fmt.Errorf("unexpected link type: %v", u.Scheme)
}

View File

@ -5,7 +5,6 @@ import (
"fmt"
"github.com/mzz2017/softwind/protocol"
"github.com/mzz2017/softwind/protocol/shadowsocks"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/component/outbound/dialer"
"github.com/v2rayA/dae/component/outbound/transport/simpleobfs"
@ -34,15 +33,15 @@ type Shadowsocks struct {
Protocol string `json:"protocol"`
}
func NewShadowsocksFromLink(log *logrus.Logger, link string) (*dialer.Dialer, error) {
func NewShadowsocksFromLink(option *dialer.GlobalOption, link string) (*dialer.Dialer, error) {
s, err := ParseSSURL(link)
if err != nil {
return nil, err
}
return s.Dialer(log)
return s.Dialer(option)
}
func (s *Shadowsocks) Dialer(log *logrus.Logger) (*dialer.Dialer, error) {
func (s *Shadowsocks) Dialer(option *dialer.GlobalOption) (*dialer.Dialer, error) {
// FIXME: support plain/none.
switch s.Cipher {
case "aes-256-gcm", "aes-128-gcm", "chacha20-poly1305", "chacha20-ietf-poly1305":
@ -77,7 +76,7 @@ func (s *Shadowsocks) Dialer(log *logrus.Logger) (*dialer.Dialer, error) {
}
supportUDP = false
}
return dialer.NewDialer(d, log, supportUDP, s.Name, s.Protocol, s.ExportToURL()), nil
return dialer.NewDialer(d, option, supportUDP, s.Name, s.Protocol, s.ExportToURL()), nil
}
func ParseSSURL(u string) (data *Shadowsocks, err error) {

View File

@ -3,7 +3,6 @@ package shadowsocksr
import (
"encoding/base64"
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/component/outbound/dialer"
ssr "github.com/v2rayA/shadowsocksR/client"
@ -31,15 +30,15 @@ type ShadowsocksR struct {
Protocol string `json:"protocol"`
}
func NewShadowsocksR(log *logrus.Logger, link string) (*dialer.Dialer, error) {
func NewShadowsocksR(option *dialer.GlobalOption, link string) (*dialer.Dialer, error) {
s, err := ParseSSRURL(link)
if err != nil {
return nil, err
}
return s.Dialer(log)
return s.Dialer(option)
}
func (s *ShadowsocksR) Dialer(log *logrus.Logger) (*dialer.Dialer, error) {
func (s *ShadowsocksR) Dialer(option *dialer.GlobalOption) (*dialer.Dialer, error) {
u := url.URL{
Scheme: "ssr",
User: url.UserPassword(s.Cipher, s.Password),
@ -55,7 +54,7 @@ func (s *ShadowsocksR) Dialer(log *logrus.Logger) (*dialer.Dialer, error) {
if err != nil {
return nil, err
}
return dialer.NewDialer(d, log, false, s.Name, s.Protocol, s.ExportToURL()), nil
return dialer.NewDialer(d, option, false, s.Name, s.Protocol, s.ExportToURL()), nil
}
func ParseSSRURL(u string) (data *ShadowsocksR, err error) {

View File

@ -2,7 +2,6 @@ package socks
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/component/outbound/dialer"
//"github.com/mzz2017/softwind/protocol/socks4"
"github.com/mzz2017/softwind/protocol/socks5"
@ -27,15 +26,15 @@ type Socks struct {
Protocol string `json:"protocol"`
}
func NewSocks(log *logrus.Logger, link string) (*dialer.Dialer, error) {
func NewSocks(option *dialer.GlobalOption, link string) (*dialer.Dialer, error) {
s, err := ParseSocksURL(link)
if err != nil {
return nil, dialer.InvalidParameterErr
}
return s.Dialer(log)
return s.Dialer(option)
}
func (s *Socks) Dialer(log *logrus.Logger) (*dialer.Dialer, error) {
func (s *Socks) Dialer(option *dialer.GlobalOption) (*dialer.Dialer, error) {
link := s.ExportToURL()
switch s.Protocol {
case "", "socks", "socks5":
@ -43,7 +42,7 @@ func (s *Socks) Dialer(log *logrus.Logger) (*dialer.Dialer, error) {
if err != nil {
return nil, err
}
return dialer.NewDialer(d, log, true, s.Name, s.Protocol, link), nil
return dialer.NewDialer(d, option, true, s.Name, s.Protocol, link), nil
//case "socks4", "socks4a":
// d, err := socks4.NewSocks4Dialer(link, &proxy.Direct{})
// if err != nil {

View File

@ -4,7 +4,6 @@ import (
"fmt"
"github.com/mzz2017/softwind/protocol"
"github.com/mzz2017/softwind/transport/grpc"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/component/outbound/dialer"
"github.com/v2rayA/dae/component/outbound/transport/tls"
@ -35,15 +34,15 @@ type Trojan struct {
Protocol string `json:"protocol"`
}
func NewTrojan(log *logrus.Logger, link string) (*dialer.Dialer, error) {
func NewTrojan(option *dialer.GlobalOption, link string) (*dialer.Dialer, error) {
s, err := ParseTrojanURL(link)
if err != nil {
return nil, err
}
return s.Dialer(log)
return s.Dialer(option)
}
func (s *Trojan) Dialer(log *logrus.Logger) (*dialer.Dialer, error) {
func (s *Trojan) Dialer(option *dialer.GlobalOption) (*dialer.Dialer, error) {
d := dialer.FullconeDirect // Trojan Proxy supports full-cone.
u := url.URL{
Scheme: "tls",
@ -102,7 +101,7 @@ func (s *Trojan) Dialer(log *logrus.Logger) (*dialer.Dialer, error) {
}); err != nil {
return nil, err
}
return dialer.NewDialer(d, log, true, s.Name, s.Protocol, s.ExportToURL()), nil
return dialer.NewDialer(d, option, true, s.Name, s.Protocol, s.ExportToURL()), nil
}
func ParseTrojanURL(u string) (data *Trojan, err error) {

View File

@ -6,7 +6,6 @@ import (
jsoniter "github.com/json-iterator/go"
"github.com/mzz2017/softwind/protocol"
"github.com/mzz2017/softwind/transport/grpc"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/component/outbound/dialer"
"github.com/v2rayA/dae/component/outbound/transport/tls"
@ -42,7 +41,7 @@ type V2Ray struct {
Protocol string `json:"protocol"`
}
func NewV2Ray(log *logrus.Logger, link string) (*dialer.Dialer, error) {
func NewV2Ray(option *dialer.GlobalOption, link string) (*dialer.Dialer, error) {
var (
s *V2Ray
err error
@ -64,10 +63,10 @@ func NewV2Ray(log *logrus.Logger, link string) (*dialer.Dialer, error) {
default:
return nil, dialer.InvalidParameterErr
}
return s.Dialer(log)
return s.Dialer(option)
}
func (s *V2Ray) Dialer(log *logrus.Logger) (data *dialer.Dialer, err error) {
func (s *V2Ray) Dialer(option *dialer.GlobalOption) (data *dialer.Dialer, err error) {
var d proxy.Dialer
switch s.Protocol {
case "vmess":
@ -148,7 +147,7 @@ func (s *V2Ray) Dialer(log *logrus.Logger) (data *dialer.Dialer, err error) {
}); err != nil {
return nil, err
}
return dialer.NewDialer(d, log, true, s.Ps, s.Protocol, s.ExportToURL()), nil
return dialer.NewDialer(d, option, true, s.Ps, s.Protocol, s.ExportToURL()), nil
}
func ParseVlessURL(vless string) (data *V2Ray, err error) {

View File

@ -10,8 +10,12 @@ import (
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/outbound/dialer"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/pkg/config_parser"
"golang.org/x/net/proxy"
"log"
"net"
"strconv"
)
type DialerSelectionPolicy struct {
@ -19,6 +23,51 @@ type DialerSelectionPolicy struct {
FixedIndex int
}
func NewDialerSelectionPolicyFromGroupParam(param *config.GroupParam) (policy *DialerSelectionPolicy, err error) {
switch val := param.Policy.(type) {
case string:
switch consts.DialerSelectionPolicy(val) {
case consts.DialerSelectionPolicy_Random,
consts.DialerSelectionPolicy_MinAverage10Latencies,
consts.DialerSelectionPolicy_MinLastLatency:
return &DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy(val),
}, nil
case consts.DialerSelectionPolicy_Fixed:
return nil, fmt.Errorf("%v need to specify node index", val)
default:
return nil, fmt.Errorf("unexpected policy: %v", val)
}
case []*config_parser.Function:
if len(val) > 1 || len(val) == 0 {
logrus.Debugf("%@", val)
return nil, fmt.Errorf("policy should be exact 1 function: got %v", len(val))
}
f := val[0]
switch consts.DialerSelectionPolicy(f.Name) {
case consts.DialerSelectionPolicy_Fixed:
// Should be like:
// policy: fixed(0)
if len(f.Params) > 1 || f.Params[0].Key != "" {
return nil, fmt.Errorf(`invalid "%v" param format`, f.Name)
}
strIndex := f.Params[0].Val
index, err := strconv.Atoi(strIndex)
if len(f.Params) > 1 || f.Params[0].Key != "" {
return nil, fmt.Errorf(`invalid "%v" param format: %w`, f.Name, err)
}
return &DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy(f.Name),
FixedIndex: index,
}, nil
default:
return nil, fmt.Errorf("unexpected policy func: %v", f.Name)
}
default:
return nil, fmt.Errorf("unexpected param.Policy.(type): %T", val)
}
}
type DialerGroup struct {
proxy.Dialer
block *dialer.Dialer
@ -34,9 +83,9 @@ type DialerGroup struct {
selectionPolicy *DialerSelectionPolicy
}
func NewDialerGroup(log *logrus.Logger, name string, dialers []*dialer.Dialer, p DialerSelectionPolicy) *DialerGroup {
func NewDialerGroup(option *dialer.GlobalOption, name string, dialers []*dialer.Dialer, p DialerSelectionPolicy) *DialerGroup {
var registeredAliveDialerSet bool
a := dialer.NewAliveDialerSet(log, p.Policy, dialers, true)
a := dialer.NewAliveDialerSet(option.Log, p.Policy, dialers, true)
switch p.Policy {
case consts.DialerSelectionPolicy_Random,
@ -56,10 +105,10 @@ func NewDialerGroup(log *logrus.Logger, name string, dialers []*dialer.Dialer, p
}
return &DialerGroup{
log: log,
log: option.Log,
Name: name,
Dialers: dialers,
block: dialer.NewBlockDialer(log),
block: dialer.NewBlockDialer(option),
AliveDialerSet: a,
registeredAliveDialerSet: registeredAliveDialerSet,
selectionPolicy: &p,
@ -121,5 +170,6 @@ func (g *DialerGroup) Dial(network string, addr string) (c net.Conn, err error)
if err != nil {
return nil, err
}
g.log.Tracef("Group [%v] dial using [%v]", g.Name, d.Name())
return d.Dial(network, addr)
}

View File

@ -14,14 +14,22 @@ import (
"time"
)
const (
testCheckUrl = "https://connectivitycheck.gstatic.com/generate_204"
)
func TestDialerGroup_Select_Fixed(t *testing.T) {
log := logger.NewLogger(2)
option := &dialer.GlobalOption{
Log: log,
CheckUrl: testCheckUrl,
}
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(log, true),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(option, true),
dialer.NewDirectDialer(option, false),
}
fixedIndex := 1
g := NewDialerGroup(log, "test-group", dialers, DialerSelectionPolicy{
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Fixed,
FixedIndex: fixedIndex,
})
@ -50,19 +58,23 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
log := logger.NewLogger(2)
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
option := &dialer.GlobalOption{
Log: log,
CheckUrl: testCheckUrl,
}
g := NewDialerGroup(log, "test-group", dialers, DialerSelectionPolicy{
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_MinLastLatency,
})
@ -113,14 +125,18 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
func TestDialerGroup_Select_Random(t *testing.T) {
log := logger.NewLogger(2)
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
option := &dialer.GlobalOption{
Log: log,
CheckUrl: testCheckUrl,
}
g := NewDialerGroup(log, "test-group", dialers, DialerSelectionPolicy{
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Random,
})
count := make([]int, len(dialers))
@ -146,14 +162,18 @@ func TestDialerGroup_Select_Random(t *testing.T) {
func TestDialerGroup_SetAlive(t *testing.T) {
log := logger.NewLogger(2)
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
dialer.NewDirectDialer(log, false),
option := &dialer.GlobalOption{
Log: log,
CheckUrl: testCheckUrl,
}
g := NewDialerGroup(log, "test-group", dialers, DialerSelectionPolicy{
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
dialer.NewDirectDialer(option, false),
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Random,
})
zeroTarget := 3

View File

@ -0,0 +1,96 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2023, mzz2017 (mzz@tuta.io). All rights reserved.
*/
package outbound
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/component/outbound/dialer"
"github.com/v2rayA/dae/pkg/config_parser"
"regexp"
"strings"
)
const (
FilterInput_Name = "name"
FilterInput_Link = "link"
)
const (
FilterKey_Name_Regex = "regex"
FilterKey_Name_Keyword = "keyword"
)
type DialerSet struct {
Dialers []*dialer.Dialer
}
func NewDialerSetFromLinks(option *dialer.GlobalOption, nodes []string) *DialerSet {
s := &DialerSet{Dialers: make([]*dialer.Dialer, 0, len(nodes))}
for _, node := range nodes {
d, err := dialer.NewFromLink(option, node)
if err != nil {
option.Log.Infof("failed to parse node: %v: %v", node, err)
continue
}
s.Dialers = append(s.Dialers, d)
}
return s
}
func hit(dialer *dialer.Dialer, filters []*config_parser.Function) (hit bool, err error) {
// Example
// filter: name(regex:'^.*hk.*$', keyword:'sg') && name(keyword:'disney')
// And
for _, filter := range filters {
var subFilterHit bool
switch filter.Name {
case FilterInput_Name:
// Or
for _, param := range filter.Params {
switch param.Key {
case FilterKey_Name_Regex:
matched, _ := regexp.MatchString(param.Val, dialer.Name())
logrus.Warnln(param.Val, matched, dialer.Name())
if matched {
subFilterHit = true
break
}
case FilterKey_Name_Keyword:
if strings.Contains(dialer.Name(), param.Val) {
subFilterHit = true
break
}
case "":
return false, fmt.Errorf(`key of "filter: %v()" cannot be empty`, filter.Name)
default:
return false, fmt.Errorf(`unsupported filter key "%v" in "filter: %v()"`, param.Key, filter.Name)
}
}
default:
return false, fmt.Errorf(`unsupported filter input type: "%v"`, filter.Name)
}
if !subFilterHit {
return false, nil
}
}
return true, nil
}
func (s *DialerSet) Filter(filters []*config_parser.Function) (dialers []*dialer.Dialer, err error) {
for _, d := range s.Dialers {
hit, err := hit(d, filters)
if err != nil {
return nil, err
}
if hit {
dialers = append(dialers, d)
}
}
return dialers, nil
}

View File

@ -1,91 +0,0 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2022, mzz2017 (mzz@tuta.io). All rights reserved.
*/
package routing
import (
"fmt"
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
"reflect"
"strings"
)
type ErrorType string
const (
ErrorType_Unsupported ErrorType = "is not supported"
ErrorType_NotSet ErrorType = "is not set"
)
type ConsoleErrorListener struct {
ErrorBuilder strings.Builder
}
func NewConsoleErrorListener() *ConsoleErrorListener {
return &ConsoleErrorListener{}
}
func (d *ConsoleErrorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) {
// Do not accumulate errors.
if d.ErrorBuilder.Len() > 0 {
return
}
backtrack := column
if backtrack > 30 {
backtrack = 30
}
starting := fmt.Sprintf("line %v:%v ", line, column)
offset := len(starting) + backtrack
var (
simplyWrite bool
token antlr.Token
)
if offendingSymbol == nil {
simplyWrite = true
} else {
token = offendingSymbol.(antlr.Token)
simplyWrite = token.GetTokenType() == -1
}
if simplyWrite {
d.ErrorBuilder.WriteString(fmt.Sprintf("%v%v", starting, msg))
return
}
beginOfLine := token.GetStart() - backtrack
strPeek := token.GetInputStream().GetText(beginOfLine, token.GetStop()+30)
wrap := strings.IndexByte(strPeek, '\n')
if wrap == -1 {
wrap = token.GetStop() + 30
} else {
wrap += beginOfLine - 1
}
strLine := token.GetInputStream().GetText(beginOfLine, wrap)
d.ErrorBuilder.WriteString(fmt.Sprintf("%v%v\n%v%v: %v\n", starting, strLine, strings.Repeat(" ", offset), strings.Repeat("^", token.GetStop()-token.GetStart()+1), msg))
}
func (d *ConsoleErrorListener) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
}
func (d *ConsoleErrorListener) ReportAttemptingFullContext(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, conflictingAlts *antlr.BitSet, configs antlr.ATNConfigSet) {
}
func (d *ConsoleErrorListener) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex, prediction int, configs antlr.ATNConfigSet) {
}
func BaseContext(ctx interface{}) (baseCtx *antlr.BaseParserRuleContext) {
val := reflect.ValueOf(ctx)
for val.Kind() == reflect.Pointer && val.Type() != reflect.TypeOf(&antlr.BaseParserRuleContext{}) {
val = val.Elem()
}
if val.Type() == reflect.TypeOf(&antlr.BaseParserRuleContext{}) {
baseCtx = val.Interface().(*antlr.BaseParserRuleContext)
} else {
baseCtxVal := val.FieldByName("BaseParserRuleContext")
if !baseCtxVal.IsValid() {
panic("has no field BaseParserRuleContext")
}
baseCtx = baseCtxVal.Interface().(*antlr.BaseParserRuleContext)
}
return baseCtx
}

View File

@ -9,6 +9,7 @@ import (
"fmt"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/pkg/config_parser"
"net/netip"
"strings"
)
@ -30,7 +31,7 @@ type MatcherBuilder interface {
Build() (err error)
}
func GroupParamValuesByKey(params []*Param) map[string][]string {
func GroupParamValuesByKey(params []*config_parser.Param) map[string][]string {
groups := make(map[string][]string)
for _, param := range params {
groups[param.Key] = append(groups[param.Key], param.Val)
@ -53,7 +54,7 @@ func ParsePrefixes(values []string) (cidrs []netip.Prefix, err error) {
return cidrs, nil
}
func ApplyMatcherBuilder(builder MatcherBuilder, rules []RoutingRule, finalOutbound string) (err error) {
func ApplyMatcherBuilder(builder MatcherBuilder, rules []*config_parser.RoutingRule, finalOutbound string) (err error) {
for _, rule := range rules {
// rule is like: domain(domain:baidu.com) && port(443) -> proxy
for iFunc, f := range rule.AndFunctions {

View File

@ -7,25 +7,26 @@ package routing
import (
"fmt"
"github.com/mohae/deepcopy"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/assets"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/pkg/config_parser"
"github.com/v2rayA/dae/pkg/geodata"
"net/netip"
"sort"
"strings"
)
import "github.com/mohae/deepcopy"
type RulesOptimizer interface {
Optimize(rules []RoutingRule) ([]RoutingRule, error)
Optimize(rules []*config_parser.RoutingRule) ([]*config_parser.RoutingRule, error)
}
func DeepCloneRules(rules []RoutingRule) (newRules []RoutingRule) {
return deepcopy.Copy(rules).([]RoutingRule)
func DeepCloneRules(rules []*config_parser.RoutingRule) (newRules []*config_parser.RoutingRule) {
return deepcopy.Copy(rules).([]*config_parser.RoutingRule)
}
func ApplyRulesOptimizers(rules []RoutingRule, optimizers ...RulesOptimizer) ([]RoutingRule, error) {
func ApplyRulesOptimizers(rules []*config_parser.RoutingRule, optimizers ...RulesOptimizer) ([]*config_parser.RoutingRule, error) {
rules = DeepCloneRules(rules)
var err error
for _, opt := range optimizers {
@ -39,7 +40,7 @@ func ApplyRulesOptimizers(rules []RoutingRule, optimizers ...RulesOptimizer) ([]
type RefineFunctionParamKeyOptimizer struct {
}
func (o *RefineFunctionParamKeyOptimizer) Optimize(rules []RoutingRule) ([]RoutingRule, error) {
func (o *RefineFunctionParamKeyOptimizer) Optimize(rules []*config_parser.RoutingRule) ([]*config_parser.RoutingRule, error) {
for _, rule := range rules {
for _, function := range rule.AndFunctions {
for _, param := range function.Params {
@ -63,7 +64,7 @@ func (o *RefineFunctionParamKeyOptimizer) Optimize(rules []RoutingRule) ([]Routi
type MergeAndSortRulesOptimizer struct {
}
func (o *MergeAndSortRulesOptimizer) Optimize(rules []RoutingRule) ([]RoutingRule, error) {
func (o *MergeAndSortRulesOptimizer) Optimize(rules []*config_parser.RoutingRule) ([]*config_parser.RoutingRule, error) {
if len(rules) == 0 {
return rules, nil
}
@ -74,7 +75,7 @@ func (o *MergeAndSortRulesOptimizer) Optimize(rules []RoutingRule) ([]RoutingRul
})
}
// Merge singleton rules with the same outbound.
var newRules []RoutingRule
var newRules []*config_parser.RoutingRule
mergingRule := rules[0]
for i := 1; i < len(rules); i++ {
if len(mergingRule.AndFunctions) == 1 &&
@ -123,20 +124,20 @@ func (o *MergeAndSortRulesOptimizer) Optimize(rules []RoutingRule) ([]RoutingRul
type DeduplicateParamsOptimizer struct {
}
func deduplicateParams(list []*Param) []*Param {
res := make([]*Param, 0, len(list))
m := make(map[Param]struct{})
func deduplicateParams(list []*config_parser.Param) []*config_parser.Param {
res := make([]*config_parser.Param, 0, len(list))
m := make(map[string]struct{})
for _, v := range list {
if _, ok := m[*v]; ok {
if _, ok := m[v.String(true)]; ok {
continue
}
m[*v] = struct{}{}
m[v.String(true)] = struct{}{}
res = append(res, v)
}
return res
}
func (o *DeduplicateParamsOptimizer) Optimize(rules []RoutingRule) ([]RoutingRule, error) {
func (o *DeduplicateParamsOptimizer) Optimize(rules []*config_parser.RoutingRule) ([]*config_parser.RoutingRule, error) {
for _, rule := range rules {
for _, f := range rule.AndFunctions {
f.Params = deduplicateParams(f.Params)
@ -149,7 +150,7 @@ type DatReaderOptimizer struct {
Logger *logrus.Logger
}
func (o *DatReaderOptimizer) loadGeoSite(filename string, code string) (params []*Param, err error) {
func (o *DatReaderOptimizer) loadGeoSite(filename string, code string) (params []*config_parser.Param, err error) {
if !strings.HasSuffix(filename, ".dat") {
filename += ".dat"
}
@ -167,25 +168,25 @@ func (o *DatReaderOptimizer) loadGeoSite(filename string, code string) (params [
switch item.Type {
case geodata.Domain_Full:
// Full.
params = append(params, &Param{
params = append(params, &config_parser.Param{
Key: consts.RoutingDomain_Full,
Val: item.Value,
})
case geodata.Domain_RootDomain:
// Suffix.
params = append(params, &Param{
params = append(params, &config_parser.Param{
Key: consts.RoutingDomain_Suffix,
Val: item.Value,
})
case geodata.Domain_Plain:
// Keyword.
params = append(params, &Param{
params = append(params, &config_parser.Param{
Key: consts.RoutingDomain_Keyword,
Val: item.Value,
})
case geodata.Domain_Regex:
// Regex.
params = append(params, &Param{
params = append(params, &config_parser.Param{
Key: consts.RoutingDomain_Regex,
Val: item.Value,
})
@ -194,7 +195,7 @@ func (o *DatReaderOptimizer) loadGeoSite(filename string, code string) (params [
return params, nil
}
func (o *DatReaderOptimizer) loadGeoIp(filename string, code string) (params []*Param, err error) {
func (o *DatReaderOptimizer) loadGeoIp(filename string, code string) (params []*config_parser.Param, err error) {
if !strings.HasSuffix(filename, ".dat") {
filename += ".dat"
}
@ -219,7 +220,7 @@ func (o *DatReaderOptimizer) loadGeoIp(filename string, code string) (params []*
if !ok {
return nil, fmt.Errorf("bad geoip file: %v", filename)
}
params = append(params, &Param{
params = append(params, &config_parser.Param{
Key: "",
Val: netip.PrefixFrom(ip, int(item.Prefix)).String(),
})
@ -227,14 +228,14 @@ func (o *DatReaderOptimizer) loadGeoIp(filename string, code string) (params []*
return params, nil
}
func (o *DatReaderOptimizer) Optimize(rules []RoutingRule) ([]RoutingRule, error) {
func (o *DatReaderOptimizer) Optimize(rules []*config_parser.RoutingRule) ([]*config_parser.RoutingRule, error) {
var err error
for _, rule := range rules {
for _, f := range rule.AndFunctions {
var newParams []*Param
var newParams []*config_parser.Param
for _, param := range f.Params {
// Parse this param and replace it with more.
var params []*Param
var params []*config_parser.Param
switch param.Key {
case "geosite":
params, err = o.loadGeoSite("geosite", param.Val)
@ -250,7 +251,7 @@ func (o *DatReaderOptimizer) Optimize(rules []RoutingRule) ([]RoutingRule, error
}
default:
// Keep this param.
params = []*Param{param}
params = []*config_parser.Param{param}
}
if err != nil {
return nil, err

View File

@ -1,34 +0,0 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2022, mzz2017 (mzz@tuta.io). All rights reserved.
*/
package routing
import (
"fmt"
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/v2rayA/RoutingA-dist/go/routingA"
)
func Parse(in string) (routingRules []RoutingRule, finalOutbound string, err error) {
errorListener := NewConsoleErrorListener()
lexer := routingA.NewroutingALexer(antlr.NewInputStream(in))
lexer.RemoveErrorListeners()
lexer.AddErrorListener(errorListener)
input := antlr.NewCommonTokenStream(lexer, 0)
parser := routingA.NewroutingAParser(input)
parser.RemoveErrorListeners()
parser.AddErrorListener(errorListener)
parser.BuildParseTrees = true
tree := parser.Start()
walker := NewRoutingAWalker(parser)
antlr.ParseTreeWalkerDefault.Walk(walker, tree)
if errorListener.ErrorBuilder.Len() != 0 {
return nil, "", fmt.Errorf("%v", errorListener.ErrorBuilder.String())
}
return walker.RoutingRules, walker.FinalOutbound, nil
}

View File

@ -1,241 +0,0 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2022, mzz2017 (mzz@tuta.io). All rights reserved.
*/
package routing
import (
"fmt"
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/v2rayA/RoutingA-dist/go/routingA"
"github.com/v2rayA/dae/common/consts"
"strconv"
"strings"
)
type RoutingAWalker struct {
*routingA.BaseroutingAListener
parser antlr.Parser
FinalOutbound string
RoutingRules []RoutingRule
}
func NewRoutingAWalker(parser antlr.Parser) *RoutingAWalker {
return &RoutingAWalker{
parser: parser,
}
}
type RoutingRule struct {
AndFunctions []*Function
Outbound string
}
func (r *RoutingRule) String(calcN bool) string {
var builder strings.Builder
var n int
for _, f := range r.AndFunctions {
if builder.Len() != 0 {
builder.WriteString(" && ")
}
var paramBuilder strings.Builder
n += len(f.Params)
for _, p := range f.Params {
if paramBuilder.Len() != 0 {
paramBuilder.WriteString(", ")
}
if p.Key != "" {
paramBuilder.WriteString(p.Key + ": " + p.Val)
} else {
paramBuilder.WriteString(p.Val)
}
}
builder.WriteString(fmt.Sprintf("%v(%v)", f.Name, paramBuilder.String()))
}
builder.WriteString(" -> " + r.Outbound)
if calcN {
builder.WriteString(" [n = " + strconv.Itoa(n) + "]")
}
builder.WriteString("\n")
return builder.String()
}
type Function struct {
Name string
Params []*Param
}
type Param struct {
Key string
Val string
}
type paramParser struct {
list []*Param
}
func getValueFromLiteral(literal *routingA.LiteralContext) string {
quote := literal.Quote_literal()
if quote == nil {
return literal.GetText()
}
text := quote.GetText()
return text[1 : len(text)-1]
}
func (p *paramParser) parseParam(ctx *routingA.ParameterContext) *Param {
children := ctx.GetChildren()
if len(children) == 3 {
return &Param{
Key: children[0].(*antlr.TerminalNodeImpl).GetText(),
Val: getValueFromLiteral(children[2].(*routingA.LiteralContext)),
}
} else if len(children) == 1 {
return &Param{
Key: "",
Val: getValueFromLiteral(children[0].(*routingA.LiteralContext)),
}
}
panic("unexpected")
}
func (p *paramParser) parseNonEmptyParamList(ctx *routingA.NonEmptyParameterListContext) {
children := ctx.GetChildren()
if len(children) == 3 {
p.list = append(p.list, p.parseParam(children[2].(*routingA.ParameterContext)))
p.parseNonEmptyParamList(children[0].(*routingA.NonEmptyParameterListContext))
} else if len(children) == 1 {
p.list = append(p.list, p.parseParam(children[0].(*routingA.ParameterContext)))
}
}
func (s *RoutingAWalker) parseNonEmptyParamList(list *routingA.NonEmptyParameterListContext) []*Param {
paramParser := new(paramParser)
paramParser.parseNonEmptyParamList(list)
return paramParser.list
}
func (s *RoutingAWalker) reportKeyUnsupportedError(ctx interface{}, keyName, funcName string) {
s.ReportError(ctx, ErrorType_Unsupported, fmt.Sprintf("key %v in %v()", strconv.Quote(keyName), funcName))
}
func (s *RoutingAWalker) parseFunctionPrototype(ctx *routingA.FunctionPrototypeContext) *Function {
children := ctx.GetChildren()
funcName := children[0].(*antlr.TerminalNodeImpl).GetText()
paramList := children[2].(*routingA.OptParameterListContext)
children = paramList.GetChildren()
if len(children) == 0 {
s.ReportError(ctx, ErrorType_Unsupported, "empty parameter list")
return nil
}
nonEmptyParamList := children[0].(*routingA.NonEmptyParameterListContext)
params := s.parseNonEmptyParamList(nonEmptyParamList)
// Validate function name and param keys.
for _, param := range params {
switch funcName {
case consts.Function_Domain:
switch param.Key {
case "", consts.RoutingDomain_Suffix,
consts.RoutingDomain_Keyword,
consts.RoutingDomain_Full,
consts.RoutingDomain_Regex,
"geosite":
default:
s.reportKeyUnsupportedError(ctx, param.Key, funcName)
return nil
}
case consts.Function_Ip, consts.Function_SourceIp:
switch param.Key {
case "", "geoip":
default:
s.reportKeyUnsupportedError(ctx, param.Key, funcName)
return nil
}
case consts.Function_Port, consts.Function_SourcePort, consts.Function_Mac, consts.Function_L4Proto, consts.Function_IpVersion:
if param.Key != "" {
s.reportKeyUnsupportedError(ctx, param.Key, funcName)
return nil
}
default:
s.ReportError(ctx, ErrorType_Unsupported)
return nil
}
}
return &Function{
Name: funcName,
Params: params,
}
}
func (s *RoutingAWalker) ReportError(ctx interface{}, errorType ErrorType, target ...string) {
bCtx := BaseContext(ctx)
tgt := strconv.Quote(bCtx.GetStart().GetText())
if len(target) != 0 {
tgt = target[0]
}
if errorType == ErrorType_NotSet {
s.parser.NotifyErrorListeners(fmt.Sprintf("%v %v.", tgt, errorType), nil, nil)
return
}
s.parser.NotifyErrorListeners(fmt.Sprintf("%v %v.", tgt, errorType), bCtx.GetStart(), nil)
}
func (s *RoutingAWalker) EnterDeclaration(ctx *routingA.DeclarationContext) {
children := ctx.GetChildren()
key := children[0].(*antlr.TerminalNodeImpl).GetText()
switch valueCtx := children[2].(type) {
case *routingA.LiteralContext:
value := getValueFromLiteral(valueCtx)
if key == consts.Declaration_Final {
s.FinalOutbound = value
} else {
s.ReportError(ctx, ErrorType_Unsupported)
return
}
case *routingA.AssignmentExpressionContext:
s.ReportError(valueCtx, ErrorType_Unsupported)
return
default:
s.ReportError(valueCtx, ErrorType_Unsupported)
return
}
}
func (s *RoutingAWalker) EnterRoutingRule(ctx *routingA.RoutingRuleContext) {
children := ctx.GetChildren()
//logrus.Debugln(ctx.GetText(), children)
left, ok := children[0].(*routingA.RoutingRuleLeftContext)
if !ok {
s.ReportError(ctx, ErrorType_Unsupported, "not *RoutingRuleLeftContext: "+ctx.GetText())
return
}
outbound := children[2].(*routingA.Bare_literalContext).GetText()
// Parse functions.
var andFunctions []*Function
children = left.GetChildren()
functionList, ok := children[1].(*routingA.FunctionPrototypeExpressionContext)
if !ok {
s.ReportError(ctx, ErrorType_Unsupported, "not *FunctionPrototypeExpressionContext: "+ctx.GetText())
return
}
children = functionList.GetChildren()
for _, child := range children {
// And rules.
if child, ok := child.(*routingA.FunctionPrototypeContext); ok {
function := s.parseFunctionPrototype(child)
andFunctions = append(andFunctions, function)
}
}
s.RoutingRules = append(s.RoutingRules, RoutingRule{
AndFunctions: andFunctions,
Outbound: outbound,
})
}
func (s *RoutingAWalker) EnterRoutingRuleOrDeclarationList(ctx *routingA.RoutingRuleOrDeclarationListContext) {
s.ReportError(ctx, ErrorType_Unsupported)
}
func (s *RoutingAWalker) ExitStart(ctx *routingA.StartContext) {
if s.FinalOutbound == "" {
s.ReportError(ctx, ErrorType_NotSet, `"default"`)
}
}