From de85c9455a4577ed8635f65806ac23858e1b3820 Mon Sep 17 00:00:00 2001 From: fatedier Date: Fri, 2 Jun 2023 16:06:29 +0800 Subject: [PATCH] stcp, xtcp, sudp: support allow_users and specified server user (#3472) --- .gitignore | 1 + Makefile | 15 +++ hack/download.sh | 63 ++++++++++ hack/run-e2e.sh | 26 ++-- pkg/config/client.go | 2 +- pkg/config/proxy.go | 216 ++++++++++++++++++---------------- pkg/config/visitor.go | 6 +- pkg/msg/msg.go | 5 +- pkg/nathole/controller.go | 28 +++-- server/control.go | 8 +- server/proxy/stcp.go | 7 +- server/proxy/sudp.go | 8 +- server/proxy/xtcp.go | 8 +- server/service.go | 11 +- server/visitor/visitor.go | 43 ++++--- test/e2e/basic/basic.go | 82 +++++++++---- test/e2e/framework/process.go | 2 +- 17 files changed, 355 insertions(+), 176 deletions(-) create mode 100755 hack/download.sh diff --git a/.gitignore b/.gitignore index 0d8ca50d..f6df315b 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ packages/ release/ test/bin/ vendor/ +lastversion/ dist/ .idea/ .vscode/ diff --git a/Makefile b/Makefile index e9b9ec54..d94e7c36 100644 --- a/Makefile +++ b/Makefile @@ -46,8 +46,23 @@ e2e: e2e-trace: DEBUG=true LOG_LEVEL=trace ./hack/run-e2e.sh +e2e-compatibility-last-frpc: + if [ ! -d "./lastversion" ]; then \ + TARGET_DIRNAME=lastversion ./hack/download.sh; \ + fi + FRPC_PATH="`pwd`/lastversion/frpc" ./hack/run-e2e.sh + rm -r ./lastversion + +e2e-compatibility-last-frps: + if [ ! -d "./lastversion" ]; then \ + TARGET_DIRNAME=lastversion ./hack/download.sh; \ + fi + FRPS_PATH="`pwd`/lastversion/frps" ./hack/run-e2e.sh + rm -r ./lastversion + alltest: vet gotest e2e clean: rm -f ./bin/frpc rm -f ./bin/frps + rm -rf ./lastversion diff --git a/hack/download.sh b/hack/download.sh new file mode 100755 index 00000000..acdf033a --- /dev/null +++ b/hack/download.sh @@ -0,0 +1,63 @@ +#!/bin/sh + +OS="$(go env GOOS)" +ARCH="$(go env GOARCH)" + +if [ "${TARGET_OS}" ]; then + OS="${TARGET_OS}" +fi +if [ "${TARGET_ARCH}" ]; then + ARCH="${TARGET_ARCH}" +fi + +# Determine the latest version by version number ignoring alpha, beta, and rc versions. +if [ "${FRP_VERSION}" = "" ] ; then + FRP_VERSION="$(curl -sL https://github.com/fatedier/frp/releases | \ + grep -o 'releases/tag/v[0-9]*.[0-9]*.[0-9]*"' | sort -V | \ + tail -1 | awk -F'/' '{ print $3}')" + FRP_VERSION="${FRP_VERSION%?}" + FRP_VERSION="${FRP_VERSION#?}" +fi + +if [ "${FRP_VERSION}" = "" ] ; then + printf "Unable to get latest frp version. Set FRP_VERSION env var and re-run. For example: export FRP_VERSION=1.0.0" + exit 1; +fi + +SUFFIX=".tar.gz" +if [ "${OS}" = "windows" ] ; then + SUFFIX=".zip" +fi +NAME="frp_${FRP_VERSION}_${OS}_${ARCH}${SUFFIX}" +DIR_NAME="frp_${FRP_VERSION}_${OS}_${ARCH}" +URL="https://github.com/fatedier/frp/releases/download/v${FRP_VERSION}/${NAME}" + +download_and_extract() { + printf "Downloading %s from %s ...\n" "$NAME" "${URL}" + if ! curl -o /dev/null -sIf "${URL}"; then + printf "\n%s is not found, please specify a valid FRP_VERSION\n" "${URL}" + exit 1 + fi + curl -fsLO "${URL}" + filename=$NAME + + if [ "${OS}" = "windows" ]; then + unzip "${filename}" + else + tar -xzf "${filename}" + fi + rm "${filename}" + + if [ "${TARGET_DIRNAME}" ]; then + mv "${DIR_NAME}" "${TARGET_DIRNAME}" + DIR_NAME="${TARGET_DIRNAME}" + fi +} + +download_and_extract + +printf "" +printf "\nfrp %s Download Complete!\n" "$FRP_VERSION" +printf "\n" +printf "frp has been successfully downloaded into the %s folder on your system.\n" "$DIR_NAME" +printf "\n" diff --git a/hack/run-e2e.sh b/hack/run-e2e.sh index e5dfecf5..953df0ee 100755 --- a/hack/run-e2e.sh +++ b/hack/run-e2e.sh @@ -1,20 +1,30 @@ -#!/usr/bin/env bash +#!/bin/sh -ROOT=$(unset CDPATH && cd $(dirname "${BASH_SOURCE[0]}")/.. && pwd) +SCRIPT=$(readlink -f "$0") +ROOT=$(unset CDPATH && cd "$(dirname "$SCRIPT")/.." && pwd) -which ginkgo &> /dev/null -if [ $? -ne 0 ]; then +ginkgo_command=$(which ginkgo 2>/dev/null) +if [ -z "$ginkgo_command" ]; then echo "ginkgo not found, try to install..." go install github.com/onsi/ginkgo/v2/ginkgo@v2.8.3 fi debug=false -if [ x${DEBUG} == x"true" ]; then +if [ "x${DEBUG}" = "xtrue" ]; then debug=true fi logLevel=debug -if [ x${LOG_LEVEL} != x"" ]; then - logLevel=${LOG_LEVEL} +if [ "${LOG_LEVEL}" ]; then + logLevel="${LOG_LEVEL}" fi -ginkgo -nodes=8 --poll-progress-after=30s ${ROOT}/test/e2e -- -frpc-path=${ROOT}/bin/frpc -frps-path=${ROOT}/bin/frps -log-level=${logLevel} -debug=${debug} +frpcPath=${ROOT}/bin/frpc +if [ "${FRPC_PATH}" ]; then + frpcPath="${FRPC_PATH}" +fi +frpsPath=${ROOT}/bin/frps +if [ "${FRPS_PATH}" ]; then + frpsPath="${FRPS_PATH}" +fi + +ginkgo -nodes=8 --poll-progress-after=60s ${ROOT}/test/e2e -- -frpc-path=${frpcPath} -frps-path=${frpsPath} -log-level=${logLevel} -debug=${debug} diff --git a/pkg/config/client.go b/pkg/config/client.go index 50a46d51..080f9a18 100644 --- a/pkg/config/client.go +++ b/pkg/config/client.go @@ -352,7 +352,7 @@ func LoadAllProxyConfsFromIni( case "visitor": newConf, newErr := NewVisitorConfFromIni(prefix, name, section) if newErr != nil { - return nil, nil, newErr + return nil, nil, fmt.Errorf("failed to parse visitor %s, err: %v", name, newErr) } visitorConfs[prefix+name] = newConf default: diff --git a/pkg/config/proxy.go b/pkg/config/proxy.go index 9cc89492..ae1d077c 100644 --- a/pkg/config/proxy.go +++ b/pkg/config/proxy.go @@ -178,6 +178,16 @@ func (cfg *RoleServerCommonConf) setDefaultValues() { cfg.Role = "server" } +func (cfg *RoleServerCommonConf) marshalToMsg(m *msg.NewProxy) { + m.Sk = cfg.Sk + m.AllowUsers = cfg.AllowUsers +} + +func (cfg *RoleServerCommonConf) unmarshalFromMsg(m *msg.NewProxy) { + cfg.Sk = m.Sk + cfg.AllowUsers = m.AllowUsers +} + // HTTP type HTTPProxyConf struct { BaseProxyConf `ini:",extends"` @@ -260,7 +270,7 @@ func NewProxyConfFromIni(prefix, name string, section *ini.Section) (ProxyConf, conf := DefaultProxyConf(proxyType) if conf == nil { - return nil, fmt.Errorf("proxy %s has invalid type [%s]", name, proxyType) + return nil, fmt.Errorf("invalid type [%s]", proxyType) } if err := conf.UnmarshalFromIni(prefix, name, section); err != nil { @@ -274,17 +284,17 @@ func NewProxyConfFromIni(prefix, name string, section *ini.Section) (ProxyConf, } // Proxy loaded from msg -func NewProxyConfFromMsg(pMsg *msg.NewProxy, serverCfg ServerCommonConf) (ProxyConf, error) { - if pMsg.ProxyType == "" { - pMsg.ProxyType = consts.TCPProxy +func NewProxyConfFromMsg(m *msg.NewProxy, serverCfg ServerCommonConf) (ProxyConf, error) { + if m.ProxyType == "" { + m.ProxyType = consts.TCPProxy } - conf := DefaultProxyConf(pMsg.ProxyType) + conf := DefaultProxyConf(m.ProxyType) if conf == nil { - return nil, fmt.Errorf("proxy [%s] type [%s] error", pMsg.ProxyName, pMsg.ProxyType) + return nil, fmt.Errorf("proxy [%s] type [%s] error", m.ProxyName, m.ProxyType) } - conf.UnmarshalFromMsg(pMsg) + conf.UnmarshalFromMsg(m) err := conf.ValidateForServer(serverCfg) if err != nil { @@ -341,35 +351,35 @@ func (cfg *BaseProxyConf) decorate(prefix string, name string, section *ini.Sect return nil } -func (cfg *BaseProxyConf) marshalToMsg(pMsg *msg.NewProxy) { - pMsg.ProxyName = cfg.ProxyName - pMsg.ProxyType = cfg.ProxyType - pMsg.UseEncryption = cfg.UseEncryption - pMsg.UseCompression = cfg.UseCompression - pMsg.BandwidthLimit = cfg.BandwidthLimit.String() +func (cfg *BaseProxyConf) marshalToMsg(m *msg.NewProxy) { + m.ProxyName = cfg.ProxyName + m.ProxyType = cfg.ProxyType + m.UseEncryption = cfg.UseEncryption + m.UseCompression = cfg.UseCompression + m.BandwidthLimit = cfg.BandwidthLimit.String() // leave it empty for default value to reduce traffic if cfg.BandwidthLimitMode != "client" { - pMsg.BandwidthLimitMode = cfg.BandwidthLimitMode + m.BandwidthLimitMode = cfg.BandwidthLimitMode } - pMsg.Group = cfg.Group - pMsg.GroupKey = cfg.GroupKey - pMsg.Metas = cfg.Metas + m.Group = cfg.Group + m.GroupKey = cfg.GroupKey + m.Metas = cfg.Metas } -func (cfg *BaseProxyConf) unmarshalFromMsg(pMsg *msg.NewProxy) { - cfg.ProxyName = pMsg.ProxyName - cfg.ProxyType = pMsg.ProxyType - cfg.UseEncryption = pMsg.UseEncryption - cfg.UseCompression = pMsg.UseCompression - if pMsg.BandwidthLimit != "" { - cfg.BandwidthLimit, _ = NewBandwidthQuantity(pMsg.BandwidthLimit) +func (cfg *BaseProxyConf) unmarshalFromMsg(m *msg.NewProxy) { + cfg.ProxyName = m.ProxyName + cfg.ProxyType = m.ProxyType + cfg.UseEncryption = m.UseEncryption + cfg.UseCompression = m.UseCompression + if m.BandwidthLimit != "" { + cfg.BandwidthLimit, _ = NewBandwidthQuantity(m.BandwidthLimit) } - if pMsg.BandwidthLimitMode != "" { - cfg.BandwidthLimitMode = pMsg.BandwidthLimitMode + if m.BandwidthLimitMode != "" { + cfg.BandwidthLimitMode = m.BandwidthLimitMode } - cfg.Group = pMsg.Group - cfg.GroupKey = pMsg.GroupKey - cfg.Metas = pMsg.Metas + cfg.Group = m.Group + cfg.GroupKey = m.GroupKey + cfg.Metas = m.Metas } func (cfg *BaseProxyConf) validateForClient() (err error) { @@ -482,11 +492,11 @@ func preUnmarshalFromIni(cfg ProxyConf, prefix string, name string, section *ini } // TCP -func (cfg *TCPProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.unmarshalFromMsg(pMsg) +func (cfg *TCPProxyConf) UnmarshalFromMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.unmarshalFromMsg(m) // Add custom logic unmarshal if exists - cfg.RemotePort = pMsg.RemotePort + cfg.RemotePort = m.RemotePort } func (cfg *TCPProxyConf) UnmarshalFromIni(prefix string, name string, section *ini.Section) error { @@ -500,11 +510,11 @@ func (cfg *TCPProxyConf) UnmarshalFromIni(prefix string, name string, section *i return nil } -func (cfg *TCPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.marshalToMsg(pMsg) +func (cfg *TCPProxyConf) MarshalToMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.marshalToMsg(m) // Add custom logic marshal if exists - pMsg.RemotePort = cfg.RemotePort + m.RemotePort = cfg.RemotePort } func (cfg *TCPProxyConf) ValidateForClient() (err error) { @@ -536,28 +546,28 @@ func (cfg *TCPMuxProxyConf) UnmarshalFromIni(prefix string, name string, section return nil } -func (cfg *TCPMuxProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.unmarshalFromMsg(pMsg) +func (cfg *TCPMuxProxyConf) UnmarshalFromMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.unmarshalFromMsg(m) // Add custom logic unmarshal if exists - cfg.CustomDomains = pMsg.CustomDomains - cfg.SubDomain = pMsg.SubDomain - cfg.Multiplexer = pMsg.Multiplexer - cfg.HTTPUser = pMsg.HTTPUser - cfg.HTTPPwd = pMsg.HTTPPwd - cfg.RouteByHTTPUser = pMsg.RouteByHTTPUser + cfg.CustomDomains = m.CustomDomains + cfg.SubDomain = m.SubDomain + cfg.Multiplexer = m.Multiplexer + cfg.HTTPUser = m.HTTPUser + cfg.HTTPPwd = m.HTTPPwd + cfg.RouteByHTTPUser = m.RouteByHTTPUser } -func (cfg *TCPMuxProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.marshalToMsg(pMsg) +func (cfg *TCPMuxProxyConf) MarshalToMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.marshalToMsg(m) // Add custom logic marshal if exists - pMsg.CustomDomains = cfg.CustomDomains - pMsg.SubDomain = cfg.SubDomain - pMsg.Multiplexer = cfg.Multiplexer - pMsg.HTTPUser = cfg.HTTPUser - pMsg.HTTPPwd = cfg.HTTPPwd - pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser + m.CustomDomains = cfg.CustomDomains + m.SubDomain = cfg.SubDomain + m.Multiplexer = cfg.Multiplexer + m.HTTPUser = cfg.HTTPUser + m.HTTPPwd = cfg.HTTPPwd + m.RouteByHTTPUser = cfg.RouteByHTTPUser } func (cfg *TCPMuxProxyConf) ValidateForClient() (err error) { @@ -610,18 +620,18 @@ func (cfg *UDPProxyConf) UnmarshalFromIni(prefix string, name string, section *i return nil } -func (cfg *UDPProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.unmarshalFromMsg(pMsg) +func (cfg *UDPProxyConf) UnmarshalFromMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.unmarshalFromMsg(m) // Add custom logic unmarshal if exists - cfg.RemotePort = pMsg.RemotePort + cfg.RemotePort = m.RemotePort } -func (cfg *UDPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.marshalToMsg(pMsg) +func (cfg *UDPProxyConf) MarshalToMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.marshalToMsg(m) // Add custom logic marshal if exists - pMsg.RemotePort = cfg.RemotePort + m.RemotePort = cfg.RemotePort } func (cfg *UDPProxyConf) ValidateForClient() (err error) { @@ -653,32 +663,32 @@ func (cfg *HTTPProxyConf) UnmarshalFromIni(prefix string, name string, section * return nil } -func (cfg *HTTPProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.unmarshalFromMsg(pMsg) +func (cfg *HTTPProxyConf) UnmarshalFromMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.unmarshalFromMsg(m) // Add custom logic unmarshal if exists - cfg.CustomDomains = pMsg.CustomDomains - cfg.SubDomain = pMsg.SubDomain - cfg.Locations = pMsg.Locations - cfg.HostHeaderRewrite = pMsg.HostHeaderRewrite - cfg.HTTPUser = pMsg.HTTPUser - cfg.HTTPPwd = pMsg.HTTPPwd - cfg.Headers = pMsg.Headers - cfg.RouteByHTTPUser = pMsg.RouteByHTTPUser + cfg.CustomDomains = m.CustomDomains + cfg.SubDomain = m.SubDomain + cfg.Locations = m.Locations + cfg.HostHeaderRewrite = m.HostHeaderRewrite + cfg.HTTPUser = m.HTTPUser + cfg.HTTPPwd = m.HTTPPwd + cfg.Headers = m.Headers + cfg.RouteByHTTPUser = m.RouteByHTTPUser } -func (cfg *HTTPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.marshalToMsg(pMsg) +func (cfg *HTTPProxyConf) MarshalToMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.marshalToMsg(m) // Add custom logic marshal if exists - pMsg.CustomDomains = cfg.CustomDomains - pMsg.SubDomain = cfg.SubDomain - pMsg.Locations = cfg.Locations - pMsg.HostHeaderRewrite = cfg.HostHeaderRewrite - pMsg.HTTPUser = cfg.HTTPUser - pMsg.HTTPPwd = cfg.HTTPPwd - pMsg.Headers = cfg.Headers - pMsg.RouteByHTTPUser = cfg.RouteByHTTPUser + m.CustomDomains = cfg.CustomDomains + m.SubDomain = cfg.SubDomain + m.Locations = cfg.Locations + m.HostHeaderRewrite = cfg.HostHeaderRewrite + m.HTTPUser = cfg.HTTPUser + m.HTTPPwd = cfg.HTTPPwd + m.Headers = cfg.Headers + m.RouteByHTTPUser = cfg.RouteByHTTPUser } func (cfg *HTTPProxyConf) ValidateForClient() (err error) { @@ -722,20 +732,20 @@ func (cfg *HTTPSProxyConf) UnmarshalFromIni(prefix string, name string, section return nil } -func (cfg *HTTPSProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.unmarshalFromMsg(pMsg) +func (cfg *HTTPSProxyConf) UnmarshalFromMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.unmarshalFromMsg(m) // Add custom logic unmarshal if exists - cfg.CustomDomains = pMsg.CustomDomains - cfg.SubDomain = pMsg.SubDomain + cfg.CustomDomains = m.CustomDomains + cfg.SubDomain = m.SubDomain } -func (cfg *HTTPSProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.marshalToMsg(pMsg) +func (cfg *HTTPSProxyConf) MarshalToMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.marshalToMsg(m) // Add custom logic marshal if exists - pMsg.CustomDomains = cfg.CustomDomains - pMsg.SubDomain = cfg.SubDomain + m.CustomDomains = cfg.CustomDomains + m.SubDomain = cfg.SubDomain } func (cfg *HTTPSProxyConf) ValidateForClient() (err error) { @@ -784,18 +794,18 @@ func (cfg *SUDPProxyConf) UnmarshalFromIni(prefix string, name string, section * } // Only for role server. -func (cfg *SUDPProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.unmarshalFromMsg(pMsg) +func (cfg *SUDPProxyConf) UnmarshalFromMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.unmarshalFromMsg(m) // Add custom logic unmarshal if exists - cfg.Sk = pMsg.Sk + cfg.RoleServerCommonConf.unmarshalFromMsg(m) } -func (cfg *SUDPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.marshalToMsg(pMsg) +func (cfg *SUDPProxyConf) MarshalToMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.marshalToMsg(m) // Add custom logic marshal if exists - pMsg.Sk = cfg.Sk + cfg.RoleServerCommonConf.marshalToMsg(m) } func (cfg *SUDPProxyConf) ValidateForClient() (err error) { @@ -838,18 +848,18 @@ func (cfg *STCPProxyConf) UnmarshalFromIni(prefix string, name string, section * } // Only for role server. -func (cfg *STCPProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.unmarshalFromMsg(pMsg) +func (cfg *STCPProxyConf) UnmarshalFromMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.unmarshalFromMsg(m) // Add custom logic unmarshal if exists - cfg.Sk = pMsg.Sk + cfg.RoleServerCommonConf.unmarshalFromMsg(m) } -func (cfg *STCPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.marshalToMsg(pMsg) +func (cfg *STCPProxyConf) MarshalToMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.marshalToMsg(m) // Add custom logic marshal if exists - pMsg.Sk = cfg.Sk + cfg.RoleServerCommonConf.marshalToMsg(m) } func (cfg *STCPProxyConf) ValidateForClient() (err error) { @@ -892,18 +902,18 @@ func (cfg *XTCPProxyConf) UnmarshalFromIni(prefix string, name string, section * } // Only for role server. -func (cfg *XTCPProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.unmarshalFromMsg(pMsg) +func (cfg *XTCPProxyConf) UnmarshalFromMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.unmarshalFromMsg(m) // Add custom logic unmarshal if exists - cfg.Sk = pMsg.Sk + cfg.RoleServerCommonConf.unmarshalFromMsg(m) } -func (cfg *XTCPProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { - cfg.BaseProxyConf.marshalToMsg(pMsg) +func (cfg *XTCPProxyConf) MarshalToMsg(m *msg.NewProxy) { + cfg.BaseProxyConf.marshalToMsg(m) // Add custom logic marshal if exists - pMsg.Sk = cfg.Sk + cfg.RoleServerCommonConf.marshalToMsg(m) } func (cfg *XTCPProxyConf) ValidateForClient() (err error) { diff --git a/pkg/config/visitor.go b/pkg/config/visitor.go index 1f388bad..31a8a02b 100644 --- a/pkg/config/visitor.go +++ b/pkg/config/visitor.go @@ -94,16 +94,16 @@ func NewVisitorConfFromIni(prefix string, name string, section *ini.Section) (Vi visitorType := section.Key("type").String() if visitorType == "" { - return nil, fmt.Errorf("visitor [%s] type shouldn't be empty", name) + return nil, fmt.Errorf("type shouldn't be empty") } conf := DefaultVisitorConf(visitorType) if conf == nil { - return nil, fmt.Errorf("visitor [%s] type [%s] error", name, visitorType) + return nil, fmt.Errorf("type [%s] error", visitorType) } if err := conf.UnmarshalFromIni(prefix, name, section); err != nil { - return nil, fmt.Errorf("visitor [%s] type [%s] error", name, visitorType) + return nil, fmt.Errorf("type [%s] error", visitorType) } if err := conf.Validate(); err != nil { diff --git a/pkg/msg/msg.go b/pkg/msg/msg.go index 2cb291ac..7a865785 100644 --- a/pkg/msg/msg.go +++ b/pkg/msg/msg.go @@ -110,8 +110,9 @@ type NewProxy struct { Headers map[string]string `json:"headers,omitempty"` RouteByHTTPUser string `json:"route_by_http_user,omitempty"` - // stcp - Sk string `json:"sk,omitempty"` + // stcp, sudp, xtcp + Sk string `json:"sk,omitempty"` + AllowUsers []string `json:"allow_users,omitempty"` // tcpmux Multiplexer string `json:"multiplexer,omitempty"` diff --git a/pkg/nathole/controller.go b/pkg/nathole/controller.go index a04006b9..6f97455a 100644 --- a/pkg/nathole/controller.go +++ b/pkg/nathole/controller.go @@ -43,9 +43,10 @@ func NewTransactionID() string { } type ClientCfg struct { - name string - sk string - sidCh chan string + name string + sk string + allowUsers []string + sidCh chan string } type Session struct { @@ -120,11 +121,12 @@ func (c *Controller) CleanWorker(ctx context.Context) { } } -func (c *Controller) ListenClient(name string, sk string) chan string { +func (c *Controller) ListenClient(name string, sk string, allowUsers []string) chan string { cfg := &ClientCfg{ - name: name, - sk: sk, - sidCh: make(chan string), + name: name, + sk: sk, + allowUsers: allowUsers, + sidCh: make(chan string), } c.mu.Lock() defer c.mu.Unlock() @@ -144,14 +146,18 @@ func (c *Controller) GenSid() string { return fmt.Sprintf("%d%s", t, id) } -func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter) { +func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport.MessageTransporter, visitorUser string) { if m.PreCheck { - _, ok := c.clientCfgs[m.ProxyName] + cfg, ok := c.clientCfgs[m.ProxyName] if !ok { _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName))) - } else { - _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, "")) + return } + if !lo.Contains(cfg.allowUsers, visitorUser) && !lo.Contains(cfg.allowUsers, "*") { + _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, fmt.Sprintf("xtcp visitor user [%s] not allowed for [%s]", visitorUser, m.ProxyName))) + return + } + _ = transporter.Send(c.GenNatHoleResponse(m.TransactionID, nil, "")) return } diff --git a/server/control.go b/server/control.go index 01075484..4d6b802c 100644 --- a/server/control.go +++ b/server/control.go @@ -524,7 +524,7 @@ func (ctl *Control) manager() { } func (ctl *Control) HandleNatHoleVisitor(m *msg.NatHoleVisitor) { - ctl.rc.NatHoleController.HandleVisitor(m, ctl.msgTransporter) + ctl.rc.NatHoleController.HandleVisitor(m, ctl.msgTransporter, ctl.loginMsg.User) } func (ctl *Control) HandleNatHoleClient(m *msg.NatHoleClient) { @@ -537,7 +537,7 @@ func (ctl *Control) HandleNatHoleReport(m *msg.NatHoleReport) { func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { var pxyConf config.ProxyConf - // Load configures from NewProxy message and check. + // Load configures from NewProxy message and validate. pxyConf, err = config.NewProxyConfFromMsg(pxyMsg, ctl.serverCfg) if err != nil { return @@ -550,8 +550,8 @@ func (ctl *Control) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err RunID: ctl.runID, } - // NewProxy will return a interface Proxy. - // In fact it create different proxies by different proxy type, we just call run() here. + // NewProxy will return an interface Proxy. + // In fact, it creates different proxies based on the proxy type. We just call run() here. pxy, err := proxy.NewProxy(ctl.ctx, userInfo, ctl.rc, ctl.poolCount, ctl.GetWorkConn, pxyConf, ctl.serverCfg, ctl.loginMsg) if err != nil { return remoteAddr, err diff --git a/server/proxy/stcp.go b/server/proxy/stcp.go index 2ece4057..f5311be4 100644 --- a/server/proxy/stcp.go +++ b/server/proxy/stcp.go @@ -27,7 +27,12 @@ type STCPProxy struct { func (pxy *STCPProxy) Run() (remoteAddr string, err error) { xl := pxy.xl - listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Sk) + allowUsers := pxy.cfg.AllowUsers + // if allowUsers is empty, only allow same user from proxy + if len(allowUsers) == 0 { + allowUsers = []string{pxy.GetUserInfo().User} + } + listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Sk, allowUsers) if errRet != nil { err = errRet return diff --git a/server/proxy/sudp.go b/server/proxy/sudp.go index 93707f23..82bf8d23 100644 --- a/server/proxy/sudp.go +++ b/server/proxy/sudp.go @@ -27,8 +27,12 @@ type SUDPProxy struct { func (pxy *SUDPProxy) Run() (remoteAddr string, err error) { xl := pxy.xl - - listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Sk) + allowUsers := pxy.cfg.AllowUsers + // if allowUsers is empty, only allow same user from proxy + if len(allowUsers) == 0 { + allowUsers = []string{pxy.GetUserInfo().User} + } + listener, errRet := pxy.rc.VisitorManager.Listen(pxy.GetName(), pxy.cfg.Sk, allowUsers) if errRet != nil { err = errRet return diff --git a/server/proxy/xtcp.go b/server/proxy/xtcp.go index 8b2717bf..9f4b9f41 100644 --- a/server/proxy/xtcp.go +++ b/server/proxy/xtcp.go @@ -35,11 +35,15 @@ func (pxy *XTCPProxy) Run() (remoteAddr string, err error) { xl := pxy.xl if pxy.rc.NatHoleController == nil { - xl.Error("udp port for xtcp is not specified.") err = fmt.Errorf("xtcp is not supported in frps") return } - sidCh := pxy.rc.NatHoleController.ListenClient(pxy.GetName(), pxy.cfg.Sk) + allowUsers := pxy.cfg.AllowUsers + // if allowUsers is empty, only allow same user from proxy + if len(allowUsers) == 0 { + allowUsers = []string{pxy.GetUserInfo().User} + } + sidCh := pxy.rc.NatHoleController.ListenClient(pxy.GetName(), pxy.cfg.Sk, allowUsers) go func() { for { select { diff --git a/server/service.go b/server/service.go index 4378b95c..e67ee73d 100644 --- a/server/service.go +++ b/server/service.go @@ -587,6 +587,15 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) } func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVisitorConn) error { + visitorUser := "" + // TODO: Compatible with old versions, can be without runID, user is empty. In later versions, it will be mandatory to include runID. + if newMsg.RunID != "" { + ctl, exist := svr.ctlManager.GetByID(newMsg.RunID) + if !exist { + return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID) + } + visitorUser = ctl.loginMsg.User + } return svr.rc.VisitorManager.NewConn(newMsg.ProxyName, visitorConn, newMsg.Timestamp, newMsg.SignKey, - newMsg.UseEncryption, newMsg.UseCompression) + newMsg.UseEncryption, newMsg.UseCompression, visitorUser) } diff --git a/server/visitor/visitor.go b/server/visitor/visitor.go index e66f7a07..c76bcee1 100644 --- a/server/visitor/visitor.go +++ b/server/visitor/visitor.go @@ -21,57 +21,69 @@ import ( "sync" libio "github.com/fatedier/golib/io" + "github.com/samber/lo" utilnet "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/pkg/util/util" ) +type listenerBundle struct { + l *utilnet.InternalListener + sk string + allowUsers []string +} + // Manager for visitor listeners. type Manager struct { - visitorListeners map[string]*utilnet.InternalListener - skMap map[string]string + listeners map[string]*listenerBundle mu sync.RWMutex } func NewManager() *Manager { return &Manager{ - visitorListeners: make(map[string]*utilnet.InternalListener), - skMap: make(map[string]string), + listeners: make(map[string]*listenerBundle), } } -func (vm *Manager) Listen(name string, sk string) (l *utilnet.InternalListener, err error) { +func (vm *Manager) Listen(name string, sk string, allowUsers []string) (l *utilnet.InternalListener, err error) { vm.mu.Lock() defer vm.mu.Unlock() - if _, ok := vm.visitorListeners[name]; ok { + if _, ok := vm.listeners[name]; ok { err = fmt.Errorf("custom listener for [%s] is repeated", name) return } l = utilnet.NewInternalListener() - vm.visitorListeners[name] = l - vm.skMap[name] = sk + vm.listeners[name] = &listenerBundle{ + l: l, + sk: sk, + allowUsers: allowUsers, + } return } func (vm *Manager) NewConn(name string, conn net.Conn, timestamp int64, signKey string, - useEncryption bool, useCompression bool, + useEncryption bool, useCompression bool, visitorUser string, ) (err error) { vm.mu.RLock() defer vm.mu.RUnlock() - if l, ok := vm.visitorListeners[name]; ok { - var sk string - if sk = vm.skMap[name]; util.GetAuthKey(sk, timestamp) != signKey { + if l, ok := vm.listeners[name]; ok { + if util.GetAuthKey(l.sk, timestamp) != signKey { err = fmt.Errorf("visitor connection of [%s] auth failed", name) return } + if !lo.Contains(l.allowUsers, visitorUser) && !lo.Contains(l.allowUsers, "*") { + err = fmt.Errorf("visitor connection of [%s] user [%s] not allowed", name, visitorUser) + return + } + var rwc io.ReadWriteCloser = conn if useEncryption { - if rwc, err = libio.WithEncryption(rwc, []byte(sk)); err != nil { + if rwc, err = libio.WithEncryption(rwc, []byte(l.sk)); err != nil { err = fmt.Errorf("create encryption connection failed: %v", err) return } @@ -79,7 +91,7 @@ func (vm *Manager) NewConn(name string, conn net.Conn, timestamp int64, signKey if useCompression { rwc = libio.WithCompression(rwc) } - err = l.PutConn(utilnet.WrapReadWriteCloserToConn(rwc, conn)) + err = l.l.PutConn(utilnet.WrapReadWriteCloserToConn(rwc, conn)) } else { err = fmt.Errorf("custom listener for [%s] doesn't exist", name) return @@ -91,6 +103,5 @@ func (vm *Manager) CloseListener(name string) { vm.mu.Lock() defer vm.mu.Unlock() - delete(vm.visitorListeners, name) - delete(vm.skMap, name) + delete(vm.listeners, name) } diff --git a/test/e2e/basic/basic.go b/test/e2e/basic/basic.go index 9f5914de..c2a5d274 100644 --- a/test/e2e/basic/basic.go +++ b/test/e2e/basic/basic.go @@ -282,8 +282,9 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() { proxyType := t ginkgo.It(fmt.Sprintf("Expose echo server with %s", strings.ToUpper(proxyType)), func() { serverConf := consts.DefaultServerConfig - clientServerConf := consts.DefaultClientConfig - clientVisitorConf := consts.DefaultClientConfig + clientServerConf := consts.DefaultClientConfig + "\nuser = user1" + clientVisitorConf := consts.DefaultClientConfig + "\nuser = user1" + clientUser2VisitorConf := consts.DefaultClientConfig + "\nuser = user2" localPortName := "" protocol := "tcp" @@ -312,7 +313,7 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() { `+extra, proxyName, proxyType, correctSK, localPortName) } getProxyVisitorConf := func(proxyName string, portName, visitorSK, extra string) string { - return fmt.Sprintf(` + out := fmt.Sprintf(` [%s] type = %s role = visitor @@ -320,14 +321,22 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() { sk = %s bind_port = {{ .%s }} `+extra, proxyName, proxyType, proxyName, visitorSK, portName) + if proxyType == "xtcp" { + // Set keep_tunnel_open to reduce testing time. + out += "\nkeep_tunnel_open = true" + } + return out } tests := []struct { - proxyName string - bindPortName string - visitorSK string - extraConfig string - expectError bool + proxyName string + bindPortName string + visitorSK string + commonExtraConfig string + proxyExtraConfig string + visitorExtraConfig string + expectError bool + user2 bool }{ { proxyName: "normal", @@ -335,22 +344,22 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() { visitorSK: correctSK, }, { - proxyName: "with-encryption", - bindPortName: port.GenName("WithEncryption"), - visitorSK: correctSK, - extraConfig: "use_encryption = true", + proxyName: "with-encryption", + bindPortName: port.GenName("WithEncryption"), + visitorSK: correctSK, + commonExtraConfig: "use_encryption = true", }, { - proxyName: "with-compression", - bindPortName: port.GenName("WithCompression"), - visitorSK: correctSK, - extraConfig: "use_compression = true", + proxyName: "with-compression", + bindPortName: port.GenName("WithCompression"), + visitorSK: correctSK, + commonExtraConfig: "use_compression = true", }, { proxyName: "with-encryption-and-compression", bindPortName: port.GenName("WithEncryptionAndCompression"), visitorSK: correctSK, - extraConfig: ` + commonExtraConfig: ` use_encryption = true use_compression = true `, @@ -361,22 +370,53 @@ var _ = ginkgo.Describe("[Feature: Basic]", func() { visitorSK: wrongSK, expectError: true, }, + { + proxyName: "allowed-user", + bindPortName: port.GenName("AllowedUser"), + visitorSK: correctSK, + proxyExtraConfig: "allow_users = another, user2", + visitorExtraConfig: "server_user = user1", + user2: true, + }, + { + proxyName: "not-allowed-user", + bindPortName: port.GenName("NotAllowedUser"), + visitorSK: correctSK, + proxyExtraConfig: "allow_users = invalid", + visitorExtraConfig: "server_user = user1", + expectError: true, + }, + { + proxyName: "allow-all", + bindPortName: port.GenName("AllowAll"), + visitorSK: correctSK, + proxyExtraConfig: "allow_users = *", + visitorExtraConfig: "server_user = user1", + user2: true, + }, } // build all client config for _, test := range tests { - clientServerConf += getProxyServerConf(test.proxyName, test.extraConfig) + "\n" + clientServerConf += getProxyServerConf(test.proxyName, test.commonExtraConfig+"\n"+test.proxyExtraConfig) + "\n" } for _, test := range tests { - clientVisitorConf += getProxyVisitorConf(test.proxyName, test.bindPortName, test.visitorSK, test.extraConfig) + "\n" + config := getProxyVisitorConf( + test.proxyName, test.bindPortName, test.visitorSK, test.commonExtraConfig+"\n"+test.visitorExtraConfig, + ) + "\n" + if test.user2 { + clientUser2VisitorConf += config + } else { + clientVisitorConf += config + } } // run frps and frpc - f.RunProcesses([]string{serverConf}, []string{clientServerConf, clientVisitorConf}) + f.RunProcesses([]string{serverConf}, []string{clientServerConf, clientVisitorConf, clientUser2VisitorConf}) for _, test := range tests { framework.NewRequestExpect(f). RequestModify(func(r *request.Request) { - r.Timeout(10 * time.Second) + r.Timeout(3 * time.Second) }). Protocol(protocol). PortName(test.bindPortName). diff --git a/test/e2e/framework/process.go b/test/e2e/framework/process.go index dba809dd..ca717e25 100644 --- a/test/e2e/framework/process.go +++ b/test/e2e/framework/process.go @@ -56,7 +56,7 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str ExpectNoError(err) time.Sleep(500 * time.Millisecond) } - time.Sleep(2 * time.Second) + time.Sleep(3 * time.Second) return currentServerProcesses, currentClientProcesses }