diff --git a/Makefile b/Makefile index d94e7c36..f8326891 100644 --- a/Makefile +++ b/Makefile @@ -26,10 +26,10 @@ vet: go vet ./... frps: - env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -o bin/frps ./cmd/frps + env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags frps -o bin/frps ./cmd/frps frpc: - env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -o bin/frpc ./cmd/frpc + env CGO_ENABLED=0 go build -trimpath -ldflags "$(LDFLAGS)" -tags frpc -o bin/frpc ./cmd/frpc test: gotest diff --git a/client/connector.go b/client/connector.go new file mode 100644 index 00000000..2ff9b491 --- /dev/null +++ b/client/connector.go @@ -0,0 +1,223 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "crypto/tls" + "io" + "net" + "strconv" + "strings" + "time" + + libdial "github.com/fatedier/golib/net/dial" + fmux "github.com/hashicorp/yamux" + quic "github.com/quic-go/quic-go" + "github.com/samber/lo" + + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/transport" + utilnet "github.com/fatedier/frp/pkg/util/net" + "github.com/fatedier/frp/pkg/util/xlog" +) + +// Connector is a interface for establishing connections to the server. +type Connector interface { + Open() error + Connect() (net.Conn, error) + Close() error +} + +// defaultConnectorImpl is the default implementation of Connector for normal frpc. +type defaultConnectorImpl struct { + ctx context.Context + cfg *v1.ClientCommonConfig + + muxSession *fmux.Session + quicConn quic.Connection +} + +func NewConnector(ctx context.Context, cfg *v1.ClientCommonConfig) Connector { + return &defaultConnectorImpl{ + ctx: ctx, + cfg: cfg, + } +} + +// Open opens a underlying connection to the server. +// The underlying connection is either a TCP connection or a QUIC connection. +// After the underlying connection is established, you can call Connect() to get a stream. +// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect(). +func (c *defaultConnectorImpl) Open() error { + xl := xlog.FromContextSafe(c.ctx) + + // special for quic + if strings.EqualFold(c.cfg.Transport.Protocol, "quic") { + var tlsConfig *tls.Config + var err error + sn := c.cfg.Transport.TLS.ServerName + if sn == "" { + sn = c.cfg.ServerAddr + } + if lo.FromPtr(c.cfg.Transport.TLS.Enable) { + tlsConfig, err = transport.NewClientTLSConfig( + c.cfg.Transport.TLS.CertFile, + c.cfg.Transport.TLS.KeyFile, + c.cfg.Transport.TLS.TrustedCaFile, + sn) + } else { + tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn) + } + if err != nil { + xl.Warn("fail to build tls configuration, err: %v", err) + return err + } + tlsConfig.NextProtos = []string{"frp"} + + conn, err := quic.DialAddr( + c.ctx, + net.JoinHostPort(c.cfg.ServerAddr, strconv.Itoa(c.cfg.ServerPort)), + tlsConfig, &quic.Config{ + MaxIdleTimeout: time.Duration(c.cfg.Transport.QUIC.MaxIdleTimeout) * time.Second, + MaxIncomingStreams: int64(c.cfg.Transport.QUIC.MaxIncomingStreams), + KeepAlivePeriod: time.Duration(c.cfg.Transport.QUIC.KeepalivePeriod) * time.Second, + }) + if err != nil { + return err + } + c.quicConn = conn + return nil + } + + if !lo.FromPtr(c.cfg.Transport.TCPMux) { + return nil + } + + conn, err := c.realConnect() + if err != nil { + return err + } + + fmuxCfg := fmux.DefaultConfig() + fmuxCfg.KeepAliveInterval = time.Duration(c.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second + fmuxCfg.LogOutput = io.Discard + fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024 + session, err := fmux.Client(conn, fmuxCfg) + if err != nil { + return err + } + c.muxSession = session + return nil +} + +// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled. +func (c *defaultConnectorImpl) Connect() (net.Conn, error) { + if c.quicConn != nil { + stream, err := c.quicConn.OpenStreamSync(context.Background()) + if err != nil { + return nil, err + } + return utilnet.QuicStreamToNetConn(stream, c.quicConn), nil + } else if c.muxSession != nil { + stream, err := c.muxSession.OpenStream() + if err != nil { + return nil, err + } + return stream, nil + } + + return c.realConnect() +} + +func (c *defaultConnectorImpl) realConnect() (net.Conn, error) { + xl := xlog.FromContextSafe(c.ctx) + var tlsConfig *tls.Config + var err error + tlsEnable := lo.FromPtr(c.cfg.Transport.TLS.Enable) + if c.cfg.Transport.Protocol == "wss" { + tlsEnable = true + } + if tlsEnable { + sn := c.cfg.Transport.TLS.ServerName + if sn == "" { + sn = c.cfg.ServerAddr + } + + tlsConfig, err = transport.NewClientTLSConfig( + c.cfg.Transport.TLS.CertFile, + c.cfg.Transport.TLS.KeyFile, + c.cfg.Transport.TLS.TrustedCaFile, + sn) + if err != nil { + xl.Warn("fail to build tls configuration, err: %v", err) + return nil, err + } + } + + proxyType, addr, auth, err := libdial.ParseProxyURL(c.cfg.Transport.ProxyURL) + if err != nil { + xl.Error("fail to parse proxy url") + return nil, err + } + dialOptions := []libdial.DialOption{} + protocol := c.cfg.Transport.Protocol + switch protocol { + case "websocket": + protocol = "tcp" + dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")})) + dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{ + Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)), + })) + dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig)) + case "wss": + protocol = "tcp" + dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig)) + // Make sure that if it is wss, the websocket hook is executed after the tls hook. + dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110})) + default: + dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{ + Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(c.cfg.Transport.TLS.DisableCustomTLSFirstByte)), + })) + dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig)) + } + + if c.cfg.Transport.ConnectServerLocalIP != "" { + dialOptions = append(dialOptions, libdial.WithLocalAddr(c.cfg.Transport.ConnectServerLocalIP)) + } + dialOptions = append(dialOptions, + libdial.WithProtocol(protocol), + libdial.WithTimeout(time.Duration(c.cfg.Transport.DialServerTimeout)*time.Second), + libdial.WithKeepAlive(time.Duration(c.cfg.Transport.DialServerKeepAlive)*time.Second), + libdial.WithProxy(proxyType, addr), + libdial.WithProxyAuth(auth), + ) + conn, err := libdial.DialContext( + c.ctx, + net.JoinHostPort(c.cfg.ServerAddr, strconv.Itoa(c.cfg.ServerPort)), + dialOptions..., + ) + return conn, err +} + +func (c *defaultConnectorImpl) Close() error { + if c.quicConn != nil { + _ = c.quicConn.CloseWithError(0, "") + } + if c.muxSession != nil { + _ = c.muxSession.Close() + } + return nil +} diff --git a/client/control.go b/client/control.go index c8d186ca..be028ec4 100644 --- a/client/control.go +++ b/client/control.go @@ -58,8 +58,8 @@ type Control struct { // control connection. Once conn is closed, the msgDispatcher and the entire Control will exit. conn net.Conn - // use cm to create new connections, which could be real TCP connections or virtual streams. - cm *ConnectionManager + // use connector to create new connections, which could be real TCP connections or virtual streams. + connector Connector doneCh chan struct{} @@ -77,7 +77,7 @@ type Control struct { } func NewControl( - ctx context.Context, runID string, conn net.Conn, cm *ConnectionManager, + ctx context.Context, runID string, conn net.Conn, connector Connector, clientCfg *v1.ClientCommonConfig, pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer, @@ -92,7 +92,7 @@ func NewControl( runID: runID, pxyCfgs: pxyCfgs, conn: conn, - cm: cm, + connector: connector, doneCh: make(chan struct{}), } ctl.lastPong.Store(time.Now()) @@ -122,6 +122,10 @@ func (ctl *Control) Run() { go ctl.vm.Run() } +func (ctl *Control) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { + ctl.pm.SetInWorkConnCallback(cb) +} + func (ctl *Control) handleReqWorkConn(_ msg.Message) { xl := ctl.xl workConn, err := ctl.connectServer() @@ -207,7 +211,7 @@ func (ctl *Control) GracefulClose(d time.Duration) error { time.Sleep(d) ctl.conn.Close() - ctl.cm.Close() + ctl.connector.Close() return nil } @@ -218,7 +222,7 @@ func (ctl *Control) Done() <-chan struct{} { // connectServer return a new connection to frps func (ctl *Control) connectServer() (conn net.Conn, err error) { - return ctl.cm.Connect() + return ctl.connector.Connect() } func (ctl *Control) registerMsgHandlers() { @@ -282,7 +286,7 @@ func (ctl *Control) worker() { ctl.pm.Close() ctl.vm.Close() - ctl.cm.Close() + ctl.connector.Close() close(ctl.doneCh) } diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go index 5ba63f94..396539c0 100644 --- a/client/proxy/proxy.go +++ b/client/proxy/proxy.go @@ -47,10 +47,9 @@ func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy, v // Proxy defines how to handle work connections for different proxy type. type Proxy interface { Run() error - // InWorkConn accept work connections registered to server. InWorkConn(net.Conn, *msg.StartWorkConn) - + SetInWorkConnCallback(func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) /* continue */ bool) Close() } @@ -89,7 +88,8 @@ type BaseProxy struct { limiter *rate.Limiter // proxyPlugin is used to handle connections instead of dialing to local service. // It's only validate for TCP protocol now. - proxyPlugin plugin.Plugin + proxyPlugin plugin.Plugin + inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) /* continue */ bool mu sync.RWMutex xl *xlog.Logger @@ -113,7 +113,16 @@ func (pxy *BaseProxy) Close() { } } +func (pxy *BaseProxy) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { + pxy.inWorkConnCallback = cb +} + func (pxy *BaseProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { + if pxy.inWorkConnCallback != nil { + if !pxy.inWorkConnCallback(pxy.baseCfg, conn, m) { + return + } + } pxy.HandleTCPWorkConnection(conn, m, []byte(pxy.clientCfg.Auth.Token)) } @@ -132,7 +141,7 @@ func (pxy *BaseProxy) HandleTCPWorkConnection(workConn net.Conn, m *msg.StartWor }) } - xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t", + xl.Trace("handle tcp work connection, useEncryption: %t, useCompression: %t", baseCfg.Transport.UseEncryption, baseCfg.Transport.UseCompression) if baseCfg.Transport.UseEncryption { remote, err = libio.WithEncryption(remote, encKey) diff --git a/client/proxy/proxy_manager.go b/client/proxy/proxy_manager.go index db66cb26..dadf6481 100644 --- a/client/proxy/proxy_manager.go +++ b/client/proxy/proxy_manager.go @@ -31,8 +31,9 @@ import ( ) type Manager struct { - proxies map[string]*Wrapper - msgTransporter transport.MessageTransporter + proxies map[string]*Wrapper + msgTransporter transport.MessageTransporter + inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool closed bool mu sync.RWMutex @@ -71,6 +72,10 @@ func (pm *Manager) StartProxy(name string, remoteAddr string, serverRespErr stri return nil } +func (pm *Manager) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { + pm.inWorkConnCallback = cb +} + func (pm *Manager) Close() { pm.mu.Lock() defer pm.mu.Unlock() @@ -146,6 +151,9 @@ func (pm *Manager) Reload(pxyCfgs []v1.ProxyConfigurer) { name := cfg.GetBaseConfig().Name if _, ok := pm.proxies[name]; !ok { pxy := NewWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.msgTransporter) + if pm.inWorkConnCallback != nil { + pxy.SetInWorkConnCallback(pm.inWorkConnCallback) + } pm.proxies[name] = pxy addPxyNames = append(addPxyNames, name) diff --git a/client/proxy/proxy_wrapper.go b/client/proxy/proxy_wrapper.go index 346c6d07..84f24abb 100644 --- a/client/proxy/proxy_wrapper.go +++ b/client/proxy/proxy_wrapper.go @@ -121,6 +121,10 @@ func NewWrapper( return pw } +func (pw *Wrapper) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { + pw.pxy.SetInWorkConnCallback(cb) +} + func (pw *Wrapper) SetRunningStatus(remoteAddr string, respErr string) error { pw.mu.Lock() defer pw.mu.Unlock() diff --git a/client/proxy/sudp.go b/client/proxy/sudp.go index e67a3397..f9fe53bc 100644 --- a/client/proxy/sudp.go +++ b/client/proxy/sudp.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !frps + package proxy import ( diff --git a/client/proxy/udp.go b/client/proxy/udp.go index d7a790c1..d8590f68 100644 --- a/client/proxy/udp.go +++ b/client/proxy/udp.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !frps + package proxy import ( diff --git a/client/proxy/xtcp.go b/client/proxy/xtcp.go index 8271099b..b286a931 100644 --- a/client/proxy/xtcp.go +++ b/client/proxy/xtcp.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !frps + package proxy import ( diff --git a/client/service.go b/client/service.go index 66a642c1..7c3cd039 100644 --- a/client/service.go +++ b/client/service.go @@ -16,30 +16,22 @@ package client import ( "context" - "crypto/tls" "errors" "fmt" - "io" "net" "runtime" "strconv" - "strings" "sync" "time" "github.com/fatedier/golib/crypto" - libdial "github.com/fatedier/golib/net/dial" - fmux "github.com/hashicorp/yamux" - quic "github.com/quic-go/quic-go" "github.com/samber/lo" "github.com/fatedier/frp/assets" "github.com/fatedier/frp/pkg/auth" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/msg" - "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/log" - utilnet "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/pkg/util/version" "github.com/fatedier/frp/pkg/util/wait" "github.com/fatedier/frp/pkg/util/xlog" @@ -75,6 +67,9 @@ type Service struct { // call cancel to stop service cancel context.CancelFunc gracefulDuration time.Duration + + connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector + inWorkConnCallback func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool } func NewService( @@ -84,15 +79,24 @@ func NewService( cfgFile string, ) *Service { return &Service{ - authSetter: auth.NewAuthSetter(cfg.Auth), - cfg: cfg, - cfgFile: cfgFile, - pxyCfgs: pxyCfgs, - visitorCfgs: visitorCfgs, - ctx: context.Background(), + authSetter: auth.NewAuthSetter(cfg.Auth), + cfg: cfg, + cfgFile: cfgFile, + pxyCfgs: pxyCfgs, + visitorCfgs: visitorCfgs, + ctx: context.Background(), + connectorCreator: NewConnector, } } +func (svr *Service) SetConnectorCreator(h func(context.Context, *v1.ClientCommonConfig) Connector) { + svr.connectorCreator = h +} + +func (svr *Service) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { + svr.inWorkConnCallback = cb +} + func (svr *Service) GetController() *Control { svr.ctlMu.RLock() defer svr.ctlMu.RUnlock() @@ -101,7 +105,7 @@ func (svr *Service) GetController() *Control { func (svr *Service) Run(ctx context.Context) error { ctx, cancel := context.WithCancel(ctx) - svr.ctx = xlog.NewContext(ctx, xlog.New()) + svr.ctx = xlog.NewContext(ctx, xlog.FromContextSafe(ctx)) svr.cancel = cancel // set custom DNSServer @@ -173,21 +177,20 @@ func (svr *Service) keepControllerWorking() { // login creates a connection to frps and registers it self as a client // conn: control connection // session: if it's not nil, using tcp mux -func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) { +func (svr *Service) login() (conn net.Conn, connector Connector, err error) { xl := xlog.FromContextSafe(svr.ctx) - cm = NewConnectionManager(svr.ctx, svr.cfg) - - if err = cm.OpenConnection(); err != nil { + connector = svr.connectorCreator(svr.ctx, svr.cfg) + if err = connector.Open(); err != nil { return nil, nil, err } defer func() { if err != nil { - cm.Close() + connector.Close() } }() - conn, err = cm.Connect() + conn, err = connector.Connect() if err != nil { return } @@ -226,8 +229,7 @@ func (svr *Service) login() (conn net.Conn, cm *ConnectionManager, err error) { } svr.runID = loginRespMsg.RunID - xl.ResetPrefixes() - xl.AppendPrefix(svr.runID) + xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID}) xl.Info("login to server success, get run id [%s]", loginRespMsg.RunID) return @@ -239,7 +241,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE loginFunc := func() error { xl.Info("try to connect to server...") - conn, cm, err := svr.login() + conn, connector, err := svr.login() if err != nil { xl.Warn("connect to server error: %v", err) if firstLoginExit { @@ -248,13 +250,14 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE return err } - ctl, err := NewControl(svr.ctx, svr.runID, conn, cm, + ctl, err := NewControl(svr.ctx, svr.runID, conn, connector, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.authSetter) if err != nil { conn.Close() xl.Error("NewControl error: %v", err) return err } + ctl.SetInWorkConnCallback(svr.inWorkConnCallback) ctl.Run() // close and replace previous control @@ -314,184 +317,3 @@ func (svr *Service) stop() { svr.ctl = nil } } - -// ConnectionManager is a wrapper for establishing connections to the server. -type ConnectionManager struct { - ctx context.Context - cfg *v1.ClientCommonConfig - - muxSession *fmux.Session - quicConn quic.Connection -} - -func NewConnectionManager(ctx context.Context, cfg *v1.ClientCommonConfig) *ConnectionManager { - return &ConnectionManager{ - ctx: ctx, - cfg: cfg, - } -} - -// OpenConnection opens a underlying connection to the server. -// The underlying connection is either a TCP connection or a QUIC connection. -// After the underlying connection is established, you can call Connect() to get a stream. -// If TCPMux isn't enabled, the underlying connection is nil, you will get a new real TCP connection every time you call Connect(). -func (cm *ConnectionManager) OpenConnection() error { - xl := xlog.FromContextSafe(cm.ctx) - - // special for quic - if strings.EqualFold(cm.cfg.Transport.Protocol, "quic") { - var tlsConfig *tls.Config - var err error - sn := cm.cfg.Transport.TLS.ServerName - if sn == "" { - sn = cm.cfg.ServerAddr - } - if lo.FromPtr(cm.cfg.Transport.TLS.Enable) { - tlsConfig, err = transport.NewClientTLSConfig( - cm.cfg.Transport.TLS.CertFile, - cm.cfg.Transport.TLS.KeyFile, - cm.cfg.Transport.TLS.TrustedCaFile, - sn) - } else { - tlsConfig, err = transport.NewClientTLSConfig("", "", "", sn) - } - if err != nil { - xl.Warn("fail to build tls configuration, err: %v", err) - return err - } - tlsConfig.NextProtos = []string{"frp"} - - conn, err := quic.DialAddr( - cm.ctx, - net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)), - tlsConfig, &quic.Config{ - MaxIdleTimeout: time.Duration(cm.cfg.Transport.QUIC.MaxIdleTimeout) * time.Second, - MaxIncomingStreams: int64(cm.cfg.Transport.QUIC.MaxIncomingStreams), - KeepAlivePeriod: time.Duration(cm.cfg.Transport.QUIC.KeepalivePeriod) * time.Second, - }) - if err != nil { - return err - } - cm.quicConn = conn - return nil - } - - if !lo.FromPtr(cm.cfg.Transport.TCPMux) { - return nil - } - - conn, err := cm.realConnect() - if err != nil { - return err - } - - fmuxCfg := fmux.DefaultConfig() - fmuxCfg.KeepAliveInterval = time.Duration(cm.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second - fmuxCfg.LogOutput = io.Discard - fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024 - session, err := fmux.Client(conn, fmuxCfg) - if err != nil { - return err - } - cm.muxSession = session - return nil -} - -// Connect returns a stream from the underlying connection, or a new TCP connection if TCPMux isn't enabled. -func (cm *ConnectionManager) Connect() (net.Conn, error) { - if cm.quicConn != nil { - stream, err := cm.quicConn.OpenStreamSync(context.Background()) - if err != nil { - return nil, err - } - return utilnet.QuicStreamToNetConn(stream, cm.quicConn), nil - } else if cm.muxSession != nil { - stream, err := cm.muxSession.OpenStream() - if err != nil { - return nil, err - } - return stream, nil - } - - return cm.realConnect() -} - -func (cm *ConnectionManager) realConnect() (net.Conn, error) { - xl := xlog.FromContextSafe(cm.ctx) - var tlsConfig *tls.Config - var err error - tlsEnable := lo.FromPtr(cm.cfg.Transport.TLS.Enable) - if cm.cfg.Transport.Protocol == "wss" { - tlsEnable = true - } - if tlsEnable { - sn := cm.cfg.Transport.TLS.ServerName - if sn == "" { - sn = cm.cfg.ServerAddr - } - - tlsConfig, err = transport.NewClientTLSConfig( - cm.cfg.Transport.TLS.CertFile, - cm.cfg.Transport.TLS.KeyFile, - cm.cfg.Transport.TLS.TrustedCaFile, - sn) - if err != nil { - xl.Warn("fail to build tls configuration, err: %v", err) - return nil, err - } - } - - proxyType, addr, auth, err := libdial.ParseProxyURL(cm.cfg.Transport.ProxyURL) - if err != nil { - xl.Error("fail to parse proxy url") - return nil, err - } - dialOptions := []libdial.DialOption{} - protocol := cm.cfg.Transport.Protocol - switch protocol { - case "websocket": - protocol = "tcp" - dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, "")})) - dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{ - Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(cm.cfg.Transport.TLS.DisableCustomTLSFirstByte)), - })) - dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig)) - case "wss": - protocol = "tcp" - dialOptions = append(dialOptions, libdial.WithTLSConfigAndPriority(100, tlsConfig)) - // Make sure that if it is wss, the websocket hook is executed after the tls hook. - dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{Hook: utilnet.DialHookWebsocket(protocol, tlsConfig.ServerName), Priority: 110})) - default: - dialOptions = append(dialOptions, libdial.WithAfterHook(libdial.AfterHook{ - Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, lo.FromPtr(cm.cfg.Transport.TLS.DisableCustomTLSFirstByte)), - })) - dialOptions = append(dialOptions, libdial.WithTLSConfig(tlsConfig)) - } - - if cm.cfg.Transport.ConnectServerLocalIP != "" { - dialOptions = append(dialOptions, libdial.WithLocalAddr(cm.cfg.Transport.ConnectServerLocalIP)) - } - dialOptions = append(dialOptions, - libdial.WithProtocol(protocol), - libdial.WithTimeout(time.Duration(cm.cfg.Transport.DialServerTimeout)*time.Second), - libdial.WithKeepAlive(time.Duration(cm.cfg.Transport.DialServerKeepAlive)*time.Second), - libdial.WithProxy(proxyType, addr), - libdial.WithProxyAuth(auth), - ) - conn, err := libdial.DialContext( - cm.ctx, - net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)), - dialOptions..., - ) - return conn, err -} - -func (cm *ConnectionManager) Close() error { - if cm.quicConn != nil { - _ = cm.quicConn.CloseWithError(0, "") - } - if cm.muxSession != nil { - _ = cm.muxSession.Close() - } - return nil -} diff --git a/cmd/frpc/sub/proxy.go b/cmd/frpc/sub/proxy.go index 7ae8d353..96050943 100644 --- a/cmd/frpc/sub/proxy.go +++ b/cmd/frpc/sub/proxy.go @@ -21,6 +21,7 @@ import ( "github.com/samber/lo" "github.com/spf13/cobra" + "github.com/fatedier/frp/pkg/config" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/config/v1/validation" ) @@ -50,8 +51,8 @@ func init() { } clientCfg := v1.ClientCommonConfig{} cmd := NewProxyCommand(string(typ), c, &clientCfg) - RegisterClientCommonConfigFlags(cmd, &clientCfg) - RegisterProxyFlags(cmd, c) + config.RegisterClientCommonConfigFlags(cmd, &clientCfg) + config.RegisterProxyFlags(cmd, c) // add sub command for visitor if lo.Contains(visitorTypes, v1.VisitorType(typ)) { @@ -60,7 +61,7 @@ func init() { panic("visitor type: " + typ + " not support") } visitorCmd := NewVisitorCommand(string(typ), vc, &clientCfg) - RegisterVisitorFlags(visitorCmd, vc) + config.RegisterVisitorFlags(visitorCmd, vc) cmd.AddCommand(visitorCmd) } rootCmd.AddCommand(cmd) diff --git a/cmd/frps/flags.go b/cmd/frps/flags.go deleted file mode 100644 index 50170684..00000000 --- a/cmd/frps/flags.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2023 The frp Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "strconv" - - "github.com/spf13/cobra" - - "github.com/fatedier/frp/pkg/config/types" - v1 "github.com/fatedier/frp/pkg/config/v1" -) - -type PortsRangeSliceFlag struct { - V *[]types.PortsRange -} - -func (f *PortsRangeSliceFlag) String() string { - if f.V == nil { - return "" - } - return types.PortsRangeSlice(*f.V).String() -} - -func (f *PortsRangeSliceFlag) Set(s string) error { - slice, err := types.NewPortsRangeSliceFromString(s) - if err != nil { - return err - } - *f.V = slice - return nil -} - -func (f *PortsRangeSliceFlag) Type() string { - return "string" -} - -type BoolFuncFlag struct { - TrueFunc func() - FalseFunc func() - - v bool -} - -func (f *BoolFuncFlag) String() string { - return strconv.FormatBool(f.v) -} - -func (f *BoolFuncFlag) Set(s string) error { - f.v = strconv.FormatBool(f.v) == "true" - - if !f.v { - if f.FalseFunc != nil { - f.FalseFunc() - } - return nil - } - - if f.TrueFunc != nil { - f.TrueFunc() - } - return nil -} - -func (f *BoolFuncFlag) Type() string { - return "bool" -} - -func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig) { - cmd.PersistentFlags().StringVarP(&c.BindAddr, "bind_addr", "", "0.0.0.0", "bind address") - cmd.PersistentFlags().IntVarP(&c.BindPort, "bind_port", "p", 7000, "bind port") - cmd.PersistentFlags().IntVarP(&c.KCPBindPort, "kcp_bind_port", "", 0, "kcp bind udp port") - cmd.PersistentFlags().StringVarP(&c.ProxyBindAddr, "proxy_bind_addr", "", "0.0.0.0", "proxy bind address") - cmd.PersistentFlags().IntVarP(&c.VhostHTTPPort, "vhost_http_port", "", 0, "vhost http port") - cmd.PersistentFlags().IntVarP(&c.VhostHTTPSPort, "vhost_https_port", "", 0, "vhost https port") - cmd.PersistentFlags().Int64VarP(&c.VhostHTTPTimeout, "vhost_http_timeout", "", 60, "vhost http response header timeout") - cmd.PersistentFlags().StringVarP(&c.WebServer.Addr, "dashboard_addr", "", "0.0.0.0", "dashboard address") - cmd.PersistentFlags().IntVarP(&c.WebServer.Port, "dashboard_port", "", 0, "dashboard port") - cmd.PersistentFlags().StringVarP(&c.WebServer.User, "dashboard_user", "", "admin", "dashboard user") - cmd.PersistentFlags().StringVarP(&c.WebServer.Password, "dashboard_pwd", "", "admin", "dashboard password") - cmd.PersistentFlags().BoolVarP(&c.EnablePrometheus, "enable_prometheus", "", false, "enable prometheus dashboard") - cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "log file") - cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level") - cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log max days") - cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console") - cmd.PersistentFlags().StringVarP(&c.Auth.Token, "token", "t", "", "auth token") - cmd.PersistentFlags().StringVarP(&c.SubDomainHost, "subdomain_host", "", "", "subdomain host") - cmd.PersistentFlags().VarP(&PortsRangeSliceFlag{V: &c.AllowPorts}, "allow_ports", "", "allow ports") - cmd.PersistentFlags().Int64VarP(&c.MaxPortsPerClient, "max_ports_per_client", "", 0, "max ports per client") - cmd.PersistentFlags().BoolVarP(&c.Transport.TLS.Force, "tls_only", "", false, "frps tls only") - - webServerTLS := v1.TLSConfig{} - cmd.PersistentFlags().StringVarP(&webServerTLS.CertFile, "dashboard_tls_cert_file", "", "", "dashboard tls cert file") - cmd.PersistentFlags().StringVarP(&webServerTLS.KeyFile, "dashboard_tls_key_file", "", "", "dashboard tls key file") - cmd.PersistentFlags().VarP(&BoolFuncFlag{ - TrueFunc: func() { c.WebServer.TLS = &webServerTLS }, - }, "dashboard_tls_mode", "", "if enable dashboard tls mode") -} diff --git a/cmd/frps/root.go b/cmd/frps/root.go index 5f32fe9c..0cf8e4e7 100644 --- a/cmd/frps/root.go +++ b/cmd/frps/root.go @@ -42,7 +42,7 @@ func init() { rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frps") rootCmd.PersistentFlags().BoolVarP(&strictConfigMode, "strict_config", "", false, "strict config parsing mode, unknown fields will cause error") - RegisterServerConfigFlags(rootCmd, &serverCfg) + config.RegisterServerConfigFlags(rootCmd, &serverCfg) } var rootCmd = &cobra.Command{ diff --git a/go.mod b/go.mod index 8d0055e6..d11e1ef4 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/quic-go/quic-go v0.37.4 github.com/rodaine/table v1.1.0 github.com/samber/lo v1.38.1 - github.com/spf13/cobra v1.7.0 + github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.15.0 golang.org/x/net v0.17.0 diff --git a/go.sum b/go.sum index 49cef0b2..56966be2 100644 --- a/go.sum +++ b/go.sum @@ -16,7 +16,7 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/coreos/go-oidc/v3 v3.6.0 h1:AKVxfYw1Gmkn/w96z0DbT/B/xFnzTd3MkZvWLjF4n/o= github.com/coreos/go-oidc/v3 v3.6.0/go.mod h1:ZpHUsHBucTUj6WOkrP4E20UPynbLZzhTQ1XKCXkxyPc= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -128,8 +128,8 @@ github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUz github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/cmd/frpc/sub/flags.go b/pkg/config/flags.go similarity index 61% rename from cmd/frpc/sub/flags.go rename to pkg/config/flags.go index eb3cc010..0c37e608 100644 --- a/cmd/frpc/sub/flags.go +++ b/pkg/config/flags.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sub +package config import ( "fmt" + "strconv" "github.com/spf13/cobra" @@ -123,3 +124,89 @@ func RegisterClientCommonConfigFlags(cmd *cobra.Command, c *v1.ClientCommonConfi c.Transport.TLS.Enable = cmd.PersistentFlags().BoolP("tls_enable", "", true, "enable frpc tls") } + +type PortsRangeSliceFlag struct { + V *[]types.PortsRange +} + +func (f *PortsRangeSliceFlag) String() string { + if f.V == nil { + return "" + } + return types.PortsRangeSlice(*f.V).String() +} + +func (f *PortsRangeSliceFlag) Set(s string) error { + slice, err := types.NewPortsRangeSliceFromString(s) + if err != nil { + return err + } + *f.V = slice + return nil +} + +func (f *PortsRangeSliceFlag) Type() string { + return "string" +} + +type BoolFuncFlag struct { + TrueFunc func() + FalseFunc func() + + v bool +} + +func (f *BoolFuncFlag) String() string { + return strconv.FormatBool(f.v) +} + +func (f *BoolFuncFlag) Set(s string) error { + f.v = strconv.FormatBool(f.v) == "true" + + if !f.v { + if f.FalseFunc != nil { + f.FalseFunc() + } + return nil + } + + if f.TrueFunc != nil { + f.TrueFunc() + } + return nil +} + +func (f *BoolFuncFlag) Type() string { + return "bool" +} + +func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig) { + cmd.PersistentFlags().StringVarP(&c.BindAddr, "bind_addr", "", "0.0.0.0", "bind address") + cmd.PersistentFlags().IntVarP(&c.BindPort, "bind_port", "p", 7000, "bind port") + cmd.PersistentFlags().IntVarP(&c.KCPBindPort, "kcp_bind_port", "", 0, "kcp bind udp port") + cmd.PersistentFlags().StringVarP(&c.ProxyBindAddr, "proxy_bind_addr", "", "0.0.0.0", "proxy bind address") + cmd.PersistentFlags().IntVarP(&c.VhostHTTPPort, "vhost_http_port", "", 0, "vhost http port") + cmd.PersistentFlags().IntVarP(&c.VhostHTTPSPort, "vhost_https_port", "", 0, "vhost https port") + cmd.PersistentFlags().Int64VarP(&c.VhostHTTPTimeout, "vhost_http_timeout", "", 60, "vhost http response header timeout") + cmd.PersistentFlags().StringVarP(&c.WebServer.Addr, "dashboard_addr", "", "0.0.0.0", "dashboard address") + cmd.PersistentFlags().IntVarP(&c.WebServer.Port, "dashboard_port", "", 0, "dashboard port") + cmd.PersistentFlags().StringVarP(&c.WebServer.User, "dashboard_user", "", "admin", "dashboard user") + cmd.PersistentFlags().StringVarP(&c.WebServer.Password, "dashboard_pwd", "", "admin", "dashboard password") + cmd.PersistentFlags().BoolVarP(&c.EnablePrometheus, "enable_prometheus", "", false, "enable prometheus dashboard") + cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "log file") + cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level") + cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log max days") + cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console") + cmd.PersistentFlags().StringVarP(&c.Auth.Token, "token", "t", "", "auth token") + cmd.PersistentFlags().StringVarP(&c.SubDomainHost, "subdomain_host", "", "", "subdomain host") + cmd.PersistentFlags().VarP(&PortsRangeSliceFlag{V: &c.AllowPorts}, "allow_ports", "", "allow ports") + cmd.PersistentFlags().Int64VarP(&c.MaxPortsPerClient, "max_ports_per_client", "", 0, "max ports per client") + cmd.PersistentFlags().BoolVarP(&c.Transport.TLS.Force, "tls_only", "", false, "frps tls only") + + webServerTLS := v1.TLSConfig{} + cmd.PersistentFlags().StringVarP(&webServerTLS.CertFile, "dashboard_tls_cert_file", "", "", "dashboard tls cert file") + cmd.PersistentFlags().StringVarP(&webServerTLS.KeyFile, "dashboard_tls_key_file", "", "", "dashboard tls key file") + cmd.PersistentFlags().VarP(&BoolFuncFlag{ + TrueFunc: func() { c.WebServer.TLS = &webServerTLS }, + }, "dashboard_tls_mode", "", "if enable dashboard tls mode") +} diff --git a/pkg/config/v1/server.go b/pkg/config/v1/server.go index f562be8e..03b05d9d 100644 --- a/pkg/config/v1/server.go +++ b/pkg/config/v1/server.go @@ -16,21 +16,11 @@ package v1 import ( "github.com/samber/lo" - "golang.org/x/crypto/ssh" "github.com/fatedier/frp/pkg/config/types" "github.com/fatedier/frp/pkg/util/util" ) -type SSHTunnelGateway struct { - BindPort int `json:"bindPort,omitempty" validate:"gte=0,lte=65535"` - PrivateKeyFilePath string `json:"privateKeyFilePath,omitempty"` - PublicKeyFilesPath string `json:"publicKeyFilesPath,omitempty"` - - // store all public key file. load all when init - PublicKeyFilesMap map[string]ssh.PublicKey -} - type ServerConfig struct { APIMetadata @@ -41,9 +31,6 @@ type ServerConfig struct { // BindPort specifies the port that the server listens on. By default, this // value is 7000. BindPort int `json:"bindPort,omitempty"` - - SSHTunnelGateway SSHTunnelGateway `json:"sshGatewayConfig,omitempty"` - // KCPBindPort specifies the KCP port that the server listens on. If this // value is 0, the server will not listen for KCP connections. KCPBindPort int `json:"kcpBindPort,omitempty"` @@ -80,6 +67,8 @@ type ServerConfig struct { // value is "", a default page will be displayed. Custom404Page string `json:"custom404Page,omitempty"` + SSHTunnelGateway SSHTunnelGateway `json:"sshTunnelGateway,omitempty"` + WebServer WebServerConfig `json:"webServer,omitempty"` // EnablePrometheus will export prometheus metrics on webserver address // in /metrics api. @@ -114,6 +103,7 @@ func (c *ServerConfig) Complete() { c.Log.Complete() c.Transport.Complete() c.WebServer.Complete() + c.SSHTunnelGateway.Complete() c.BindAddr = util.EmptyOr(c.BindAddr, "0.0.0.0") c.BindPort = util.EmptyOr(c.BindPort, 7000) @@ -202,3 +192,14 @@ type TLSServerConfig struct { TLSConfig } + +type SSHTunnelGateway struct { + BindPort int `json:"bindPort,omitempty"` + PrivateKeyFile string `json:"privateKeyFile,omitempty"` + AutoGenPrivateKeyPath string `json:"autoGenPrivateKeyPath,omitempty"` + AuthorizedKeysFile string `json:"authorizedKeysFile,omitempty"` +} + +func (c *SSHTunnelGateway) Complete() { + c.AutoGenPrivateKeyPath = util.EmptyOr(c.AutoGenPrivateKeyPath, "./.autogen_ssh_key") +} diff --git a/pkg/config/v1/ssh.go b/pkg/config/v1/ssh.go deleted file mode 100644 index 440305d4..00000000 --- a/pkg/config/v1/ssh.go +++ /dev/null @@ -1,72 +0,0 @@ -package v1 - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "os" - "path/filepath" - - "golang.org/x/crypto/ssh" -) - -const ( - // custom define - SSHClientLoginUserPrefix = "_frpc_ssh_client_" -) - -// encodePrivateKeyToPEM encodes Private Key from RSA to PEM format -func GeneratePrivateKey() ([]byte, error) { - privateKey, err := generatePrivateKey() - if err != nil { - return nil, errors.New("gen private key error") - } - - privBlock := pem.Block{ - Type: "RSA PRIVATE KEY", - Headers: nil, - Bytes: x509.MarshalPKCS1PrivateKey(privateKey), - } - - return pem.EncodeToMemory(&privBlock), nil -} - -// generatePrivateKey creates a RSA Private Key of specified byte size -func generatePrivateKey() (*rsa.PrivateKey, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, err - } - - err = privateKey.Validate() - if err != nil { - return nil, err - } - return privateKey, nil -} - -func LoadSSHPublicKeyFilesInDir(dirPath string) (map[string]ssh.PublicKey, error) { - fileMap := make(map[string]ssh.PublicKey) - files, err := os.ReadDir(dirPath) - if err != nil { - return nil, err - } - - for _, file := range files { - filePath := filepath.Join(dirPath, file.Name()) - content, err := os.ReadFile(filePath) - if err != nil { - return nil, err - } - - parsedAuthorizedKey, _, _, _, err := ssh.ParseAuthorizedKey(content) - if err != nil { - continue - } - fileMap[ssh.FingerprintSHA256(parsedAuthorizedKey)] = parsedAuthorizedKey - } - - return fileMap, nil -} diff --git a/pkg/plugin/client/http2https.go b/pkg/plugin/client/http2https.go index 7f093af1..fd3e44b4 100644 --- a/pkg/plugin/client/http2https.go +++ b/pkg/plugin/client/http2https.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !frps + package plugin import ( diff --git a/pkg/plugin/client/http_proxy.go b/pkg/plugin/client/http_proxy.go index 06c6296a..65abf19d 100644 --- a/pkg/plugin/client/http_proxy.go +++ b/pkg/plugin/client/http_proxy.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !frps + package plugin import ( diff --git a/pkg/plugin/client/https2http.go b/pkg/plugin/client/https2http.go index aa498f3f..4a1c85b9 100644 --- a/pkg/plugin/client/https2http.go +++ b/pkg/plugin/client/https2http.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !frps + package plugin import ( diff --git a/pkg/plugin/client/https2https.go b/pkg/plugin/client/https2https.go index fc38f62b..81386ac6 100644 --- a/pkg/plugin/client/https2https.go +++ b/pkg/plugin/client/https2https.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !frps + package plugin import ( diff --git a/pkg/plugin/client/socks5.go b/pkg/plugin/client/socks5.go index c2e253d2..33e87b53 100644 --- a/pkg/plugin/client/socks5.go +++ b/pkg/plugin/client/socks5.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !frps + package plugin import ( diff --git a/pkg/plugin/client/static_file.go b/pkg/plugin/client/static_file.go index 20b79a09..faf03f7d 100644 --- a/pkg/plugin/client/static_file.go +++ b/pkg/plugin/client/static_file.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !frps + package plugin import ( diff --git a/pkg/plugin/client/unix_domain_socket.go b/pkg/plugin/client/unix_domain_socket.go index f186ec92..df68ffb4 100644 --- a/pkg/plugin/client/unix_domain_socket.go +++ b/pkg/plugin/client/unix_domain_socket.go @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !frps + package plugin import ( diff --git a/pkg/ssh/gateway.go b/pkg/ssh/gateway.go new file mode 100644 index 00000000..8f87e998 --- /dev/null +++ b/pkg/ssh/gateway.go @@ -0,0 +1,149 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "fmt" + "net" + "os" + "strconv" + "strings" + + "golang.org/x/crypto/ssh" + + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/transport" + "github.com/fatedier/frp/pkg/util/log" + utilnet "github.com/fatedier/frp/pkg/util/net" +) + +type Gateway struct { + bindPort int + ln net.Listener + + serverPeerListener *utilnet.InternalListener + + sshConfig *ssh.ServerConfig +} + +func NewGateway( + cfg v1.SSHTunnelGateway, bindAddr string, + serverPeerListener *utilnet.InternalListener, +) (*Gateway, error) { + sshConfig := &ssh.ServerConfig{} + + // privateKey + var ( + privateKeyBytes []byte + err error + ) + if cfg.PrivateKeyFile != "" { + privateKeyBytes, err = os.ReadFile(cfg.PrivateKeyFile) + } else { + if cfg.AutoGenPrivateKeyPath != "" { + privateKeyBytes, _ = os.ReadFile(cfg.AutoGenPrivateKeyPath) + } + if len(privateKeyBytes) == 0 { + privateKeyBytes, err = transport.NewRandomPrivateKey() + if err == nil && cfg.AutoGenPrivateKeyPath != "" { + err = os.WriteFile(cfg.AutoGenPrivateKeyPath, privateKeyBytes, 0o600) + } + } + } + if err != nil { + return nil, err + } + privateKey, err := ssh.ParsePrivateKey(privateKeyBytes) + if err != nil { + return nil, err + } + sshConfig.AddHostKey(privateKey) + + sshConfig.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if cfg.AuthorizedKeysFile == "" { + return &ssh.Permissions{ + Extensions: map[string]string{ + "user": "", + }, + }, nil + } + + authorizedKeysMap, err := loadAuthorizedKeysFromFile(cfg.AuthorizedKeysFile) + if err != nil { + return nil, fmt.Errorf("internal error") + } + + user, ok := authorizedKeysMap[string(key.Marshal())] + if !ok { + return nil, fmt.Errorf("unknown public key for remoteAddr %q", conn.RemoteAddr()) + } + return &ssh.Permissions{ + Extensions: map[string]string{ + "user": user, + }, + }, nil + } + + ln, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(cfg.BindPort))) + if err != nil { + return nil, err + } + return &Gateway{ + bindPort: cfg.BindPort, + ln: ln, + serverPeerListener: serverPeerListener, + sshConfig: sshConfig, + }, nil +} + +func (g *Gateway) Run() { + for { + conn, err := g.ln.Accept() + if err != nil { + return + } + go g.handleConn(conn) + } +} + +func (g *Gateway) handleConn(conn net.Conn) { + defer conn.Close() + + ts, err := NewTunnelServer(conn, g.sshConfig, g.serverPeerListener) + if err != nil { + return + } + if err := ts.Run(); err != nil { + log.Error("ssh tunnel server run error: %v", err) + } +} + +func loadAuthorizedKeysFromFile(path string) (map[string]string, error) { + authorizedKeysMap := make(map[string]string) // value is username + authorizedKeysBytes, err := os.ReadFile(path) + if err != nil { + return nil, err + } + for len(authorizedKeysBytes) > 0 { + pubKey, comment, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + return nil, err + } + + authorizedKeysMap[string(pubKey.Marshal())] = strings.TrimSpace(comment) + authorizedKeysBytes = rest + } + return authorizedKeysMap, nil +} diff --git a/pkg/ssh/server.go b/pkg/ssh/server.go new file mode 100644 index 00000000..13c87b68 --- /dev/null +++ b/pkg/ssh/server.go @@ -0,0 +1,279 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "context" + "encoding/binary" + "fmt" + "net" + "strings" + "time" + + libio "github.com/fatedier/golib/io" + "github.com/samber/lo" + "github.com/spf13/cobra" + "golang.org/x/crypto/ssh" + + "github.com/fatedier/frp/pkg/config" + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/msg" + utilnet "github.com/fatedier/frp/pkg/util/net" + "github.com/fatedier/frp/pkg/util/util" + "github.com/fatedier/frp/pkg/util/xlog" + "github.com/fatedier/frp/pkg/virtual" +) + +const ( + // https://datatracker.ietf.org/doc/html/rfc4254#page-16 + ChannelTypeServerOpenChannel = "forwarded-tcpip" + RequestTypeForward = "tcpip-forward" +) + +type tcpipForward struct { + Host string + Port uint32 +} + +// https://datatracker.ietf.org/doc/html/rfc4254#page-16 +type forwardedTCPPayload struct { + Addr string + Port uint32 + + // can be default empty value but do not delete it + // because ssh protocol shoule be reserved + OriginAddr string + OriginPort uint32 +} + +type TunnelServer struct { + underlyingConn net.Conn + sshConn *ssh.ServerConn + sc *ssh.ServerConfig + + vc *virtual.Client + serverPeerListener *utilnet.InternalListener + doneCh chan struct{} +} + +func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, serverPeerListener *utilnet.InternalListener) (*TunnelServer, error) { + s := &TunnelServer{ + underlyingConn: conn, + sc: sc, + serverPeerListener: serverPeerListener, + doneCh: make(chan struct{}), + } + return s, nil +} + +func (s *TunnelServer) Run() error { + sshConn, channels, requests, err := ssh.NewServerConn(s.underlyingConn, s.sc) + if err != nil { + return err + } + s.sshConn = sshConn + + addr, extraPayload, err := s.waitForwardAddrAndExtraPayload(channels, requests, 3*time.Second) + if err != nil { + return err + } + + clientCfg, pc, err := s.parseClientAndProxyConfigurer(addr, extraPayload) + if err != nil { + return err + } + clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User) + pc.Complete(clientCfg.User) + + s.vc = virtual.NewClient(clientCfg) + // join workConn and ssh channel + s.vc.SetInWorkConnCallback(func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool { + c, err := s.openConn(addr) + if err != nil { + return false + } + libio.Join(c, workConn) + return false + }) + // transfer connection from virtual client to server peer listener + go func() { + l := s.vc.PeerListener() + for { + conn, err := l.Accept() + if err != nil { + return + } + _ = s.serverPeerListener.PutConn(conn) + } + }() + xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100}) + ctx := xlog.NewContext(context.Background(), xl) + go func() { + _ = s.vc.Run(ctx) + }() + + s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc}) + + _ = sshConn.Wait() + _ = sshConn.Close() + s.vc.Close() + close(s.doneCh) + return nil +} + +func (s *TunnelServer) waitForwardAddrAndExtraPayload( + channels <-chan ssh.NewChannel, + requests <-chan *ssh.Request, + timeout time.Duration, +) (*tcpipForward, string, error) { + addrCh := make(chan *tcpipForward, 1) + extraPayloadCh := make(chan string, 1) + + // get forward address + go func() { + addrGot := false + for req := range requests { + switch req.Type { + case RequestTypeForward: + if !addrGot { + payload := tcpipForward{} + if err := ssh.Unmarshal(req.Payload, &payload); err != nil { + return + } + addrGot = true + addrCh <- &payload + } + default: + if req.WantReply { + _ = req.Reply(true, nil) + } + } + } + }() + + // get extra payload + go func() { + for newChannel := range channels { + // extraPayload will send to extraPayloadCh + go s.handleNewChannel(newChannel, extraPayloadCh) + } + }() + + var ( + addr *tcpipForward + extraPayload string + ) + + timer := time.NewTimer(timeout) + defer timer.Stop() + for { + select { + case v := <-addrCh: + addr = v + case extra := <-extraPayloadCh: + extraPayload = extra + case <-timer.C: + return nil, "", fmt.Errorf("get addr and extra payload timeout") + } + if addr != nil && extraPayload != "" { + break + } + } + return addr, extraPayload, nil +} + +func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, error) { + cmd := &cobra.Command{} + args := strings.Split(extraPayload, " ") + if len(args) < 1 { + return nil, nil, fmt.Errorf("invalid extra payload") + } + proxyType := strings.TrimSpace(args[0]) + supportTypes := []string{"tcp", "http", "https", "tcpmux", "stcp"} + if !lo.Contains(supportTypes, proxyType) { + return nil, nil, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes) + } + pc := v1.NewProxyConfigurerByType(v1.ProxyType(proxyType)) + if pc == nil { + return nil, nil, fmt.Errorf("new proxy configurer error") + } + config.RegisterProxyFlags(cmd, pc) + + clientCfg := v1.ClientCommonConfig{} + config.RegisterClientCommonConfigFlags(cmd, &clientCfg) + + if err := cmd.ParseFlags(args); err != nil { + return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err) + } + return &clientCfg, pc, nil +} + +func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh chan string) { + ch, reqs, err := channel.Accept() + if err != nil { + return + } + go s.keepAlive(ch) + + for req := range reqs { + if req.Type != "exec" { + continue + } + if len(req.Payload) <= 4 { + continue + } + end := 4 + binary.BigEndian.Uint32(req.Payload[:4]) + if len(req.Payload) < int(end) { + continue + } + extraPayload := string(req.Payload[4:end]) + select { + case extraPayloadCh <- extraPayload: + default: + } + } +} + +func (s *TunnelServer) keepAlive(ch ssh.Channel) { + tk := time.NewTicker(time.Second * 30) + defer tk.Stop() + + for { + select { + case <-tk.C: + _, err := ch.SendRequest("heartbeat", false, nil) + if err != nil { + return + } + case <-s.doneCh: + return + } + } +} + +func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) { + payload := forwardedTCPPayload{ + Addr: addr.Host, + Port: addr.Port, + } + channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload)) + if err != nil { + return nil, fmt.Errorf("open ssh channel error: %v", err) + } + go ssh.DiscardRequests(reqs) + + conn := utilnet.WrapReadWriteCloserToConn(channel, s.underlyingConn) + return conn, nil +} diff --git a/pkg/ssh/service.go b/pkg/ssh/service.go deleted file mode 100644 index ce0bc52c..00000000 --- a/pkg/ssh/service.go +++ /dev/null @@ -1,497 +0,0 @@ -package ssh - -import ( - "encoding/binary" - "errors" - "flag" - "fmt" - "io" - "net" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - gerror "github.com/fatedier/golib/errors" - "golang.org/x/crypto/ssh" - - v1 "github.com/fatedier/frp/pkg/config/v1" - "github.com/fatedier/frp/pkg/util/log" -) - -const ( - // ssh protocol define - // https://datatracker.ietf.org/doc/html/rfc4254#page-16 - ChannelTypeServerOpenChannel = "forwarded-tcpip" - RequestTypeForward = "tcpip-forward" - - // golang ssh package define. - // https://pkg.go.dev/golang.org/x/crypto/ssh - RequestTypeHeartbeat = "keepalive@openssh.com" -) - -// 当 proxy 失败会返回该错误 -type VProxyError struct{} - -// ssh protocol define -// https://datatracker.ietf.org/doc/html/rfc4254#page-16 -// parse ssh client cmds input -type forwardedTCPPayload struct { - Addr string - Port uint32 - - // can be default empty value but do not delete it - // because ssh protocol shoule be reserved - OriginAddr string - OriginPort uint32 -} - -// custom define -// parse ssh client cmds input -type CmdPayload struct { - Address string - Port uint32 -} - -// custom define -// with frp control cmds -type ExtraPayload struct { - Type string - - // TODO port can be set by extra message and priority to ssh raw cmd - Address string - Port uint32 -} - -type Service struct { - tcpConn net.Conn - cfg *ssh.ServerConfig - - sshConn *ssh.ServerConn - gChannel <-chan ssh.NewChannel - gReq <-chan *ssh.Request - - addrPayloadCh chan CmdPayload - extraPayloadCh chan ExtraPayload - - proxyPayloadCh chan v1.ProxyConfigurer - replyCh chan interface{} - - closeCh chan struct{} - exit int32 -} - -func NewSSHService( - tcpConn net.Conn, - cfg *ssh.ServerConfig, - proxyPayloadCh chan v1.ProxyConfigurer, - replyCh chan interface{}, -) (ss *Service, err error) { - ss = &Service{ - tcpConn: tcpConn, - cfg: cfg, - - addrPayloadCh: make(chan CmdPayload), - extraPayloadCh: make(chan ExtraPayload), - - proxyPayloadCh: proxyPayloadCh, - replyCh: replyCh, - - closeCh: make(chan struct{}), - exit: 0, - } - - ss.sshConn, ss.gChannel, ss.gReq, err = ssh.NewServerConn(tcpConn, cfg) - if err != nil { - log.Error("ssh handshake error: %v", err) - return nil, err - } - - log.Info("ssh connection success") - - return ss, nil -} - -func (ss *Service) Run() { - go ss.loopGenerateProxy() - go ss.loopParseCmdPayload() - go ss.loopParseExtraPayload() - go ss.loopReply() -} - -func (ss *Service) Exit() <-chan struct{} { - return ss.closeCh -} - -func (ss *Service) Close() { - if atomic.LoadInt32(&ss.exit) == 1 { - return - } - - select { - case <-ss.closeCh: - return - default: - } - - close(ss.closeCh) - close(ss.addrPayloadCh) - close(ss.extraPayloadCh) - - _ = ss.sshConn.Wait() - - ss.sshConn.Close() - ss.tcpConn.Close() - - atomic.StoreInt32(&ss.exit, 1) - - log.Info("ssh service close") -} - -func (ss *Service) loopParseCmdPayload() { - for { - select { - case req, ok := <-ss.gReq: - if !ok { - log.Info("global request is close") - ss.Close() - return - } - - switch req.Type { - case RequestTypeForward: - var addrPayload CmdPayload - if err := ssh.Unmarshal(req.Payload, &addrPayload); err != nil { - log.Error("ssh unmarshal error: %v", err) - return - } - _ = gerror.PanicToError(func() { - ss.addrPayloadCh <- addrPayload - }) - default: - if req.Type == RequestTypeHeartbeat { - log.Debug("ssh heartbeat data") - } else { - log.Info("default req, data: %v", req) - } - } - if req.WantReply { - err := req.Reply(true, nil) - if err != nil { - log.Error("reply to ssh client error: %v", err) - } - } - case <-ss.closeCh: - log.Info("loop parse cmd payload close") - return - } - } -} - -func (ss *Service) loopSendHeartbeat(ch ssh.Channel) { - tk := time.NewTicker(time.Second * 60) - defer tk.Stop() - - for { - select { - case <-tk.C: - ok, err := ch.SendRequest("heartbeat", false, nil) - if err != nil { - log.Error("channel send req error: %v", err) - if err == io.EOF { - ss.Close() - return - } - continue - } - log.Debug("heartbeat send success, ok: %v", ok) - case <-ss.closeCh: - return - } - } -} - -func (ss *Service) loopParseExtraPayload() { - log.Info("loop parse extra payload start") - - for newChannel := range ss.gChannel { - ch, req, err := newChannel.Accept() - if err != nil { - log.Error("channel accept error: %v", err) - return - } - - go ss.loopSendHeartbeat(ch) - - go func(req <-chan *ssh.Request) { - for r := range req { - if len(r.Payload) <= 4 { - log.Info("r.payload is less than 4") - continue - } - if !strings.Contains(string(r.Payload), "tcp") && !strings.Contains(string(r.Payload), "http") { - log.Info("ssh protocol exchange data") - continue - } - - // [4byte data_len|data] - end := 4 + binary.BigEndian.Uint32(r.Payload[:4]) - if end > uint32(len(r.Payload)) { - end = uint32(len(r.Payload)) - } - p := string(r.Payload[4:end]) - - msg, err := parseSSHExtraMessage(p) - if err != nil { - log.Error("parse ssh extra message error: %v, payload: %v", err, r.Payload) - continue - } - _ = gerror.PanicToError(func() { - ss.extraPayloadCh <- msg - }) - return - } - }(req) - } -} - -func (ss *Service) SSHConn() *ssh.ServerConn { - return ss.sshConn -} - -func (ss *Service) TCPConn() net.Conn { - return ss.tcpConn -} - -func (ss *Service) loopReply() { - for { - select { - case <-ss.closeCh: - log.Info("loop reply close") - return - case req := <-ss.replyCh: - switch req.(type) { - case *VProxyError: - log.Error("run frp proxy error, close ssh service") - ss.Close() - default: - // TODO - } - } - } -} - -func (ss *Service) loopGenerateProxy() { - log.Info("loop generate proxy start") - - for { - if atomic.LoadInt32(&ss.exit) == 1 { - return - } - - wg := new(sync.WaitGroup) - wg.Add(2) - - var p1 CmdPayload - var p2 ExtraPayload - - go func() { - defer wg.Done() - for { - select { - case <-ss.closeCh: - return - case p1 = <-ss.addrPayloadCh: - return - } - } - }() - - go func() { - defer wg.Done() - for { - select { - case <-ss.closeCh: - return - case p2 = <-ss.extraPayloadCh: - return - } - } - }() - - wg.Wait() - - if atomic.LoadInt32(&ss.exit) == 1 { - return - } - - switch p2.Type { - case "http": - case "tcp": - ss.proxyPayloadCh <- &v1.TCPProxyConfig{ - ProxyBaseConfig: v1.ProxyBaseConfig{ - Name: fmt.Sprintf("ssh-proxy-%v-%v", ss.tcpConn.RemoteAddr().String(), time.Now().UnixNano()), - Type: p2.Type, - - ProxyBackend: v1.ProxyBackend{ - LocalIP: p1.Address, - }, - }, - RemotePort: int(p1.Port), - } - default: - log.Warn("invalid frp proxy type: %v", p2.Type) - } - } -} - -func parseSSHExtraMessage(s string) (p ExtraPayload, err error) { - sn := len(s) - - log.Info("parse ssh extra message: %v", s) - - ss := strings.Fields(s) - if len(ss) == 0 { - if sn != 0 { - ss = append(ss, s) - } else { - return p, fmt.Errorf("invalid ssh input, args: %v", ss) - } - } - - for i, v := range ss { - ss[i] = strings.TrimSpace(v) - } - - if ss[0] != "tcp" && ss[0] != "http" { - return p, fmt.Errorf("only support tcp/http now") - } - - switch ss[0] { - case "tcp": - tcpCmd, err := ParseTCPCommand(ss) - if err != nil { - return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err) - } - - port, _ := strconv.Atoi(tcpCmd.Port) - - p = ExtraPayload{ - Type: "tcp", - Address: tcpCmd.Address, - Port: uint32(port), - } - case "http": - httpCmd, err := ParseHTTPCommand(ss) - if err != nil { - return ExtraPayload{}, fmt.Errorf("invalid ssh input: %v", err) - } - - _ = httpCmd - - p = ExtraPayload{ - Type: "http", - } - } - - return p, nil -} - -type HTTPCommand struct { - Domain string - BasicAuthUser string - BasicAuthPass string -} - -func ParseHTTPCommand(params []string) (*HTTPCommand, error) { - if len(params) < 2 { - return nil, errors.New("invalid HTTP command") - } - - var ( - basicAuth string - domainURL string - basicAuthUser string - basicAuthPass string - ) - - fs := flag.NewFlagSet("http", flag.ContinueOnError) - fs.StringVar(&basicAuth, "basic-auth", "", "") - fs.StringVar(&domainURL, "domain", "", "") - - fs.SetOutput(&nullWriter{}) // Disables usage output - - err := fs.Parse(params[2:]) - if err != nil { - if !errors.Is(err, flag.ErrHelp) { - return nil, err - } - } - - if basicAuth != "" { - authParts := strings.SplitN(basicAuth, ":", 2) - basicAuthUser = authParts[0] - if len(authParts) > 1 { - basicAuthPass = authParts[1] - } - } - - httpCmd := &HTTPCommand{ - Domain: domainURL, - BasicAuthUser: basicAuthUser, - BasicAuthPass: basicAuthPass, - } - return httpCmd, nil -} - -type TCPCommand struct { - Address string - Port string -} - -func ParseTCPCommand(params []string) (*TCPCommand, error) { - if len(params) == 0 || params[0] != "tcp" { - return nil, errors.New("invalid TCP command") - } - - if len(params) == 1 { - return &TCPCommand{}, nil - } - - var ( - address string - port string - ) - - fs := flag.NewFlagSet("tcp", flag.ContinueOnError) - fs.StringVar(&address, "address", "", "The IP address to listen on") - fs.StringVar(&port, "port", "", "The port to listen on") - fs.SetOutput(&nullWriter{}) // Disables usage output - - args := params[1:] - err := fs.Parse(args) - if err != nil { - if !errors.Is(err, flag.ErrHelp) { - return nil, err - } - } - - parsedAddr, err := net.ResolveIPAddr("ip", address) - if err != nil { - return nil, err - } - if _, err := net.LookupPort("tcp", port); err != nil { - return nil, err - } - - tcpCmd := &TCPCommand{ - Address: parsedAddr.String(), - Port: port, - } - return tcpCmd, nil -} - -type nullWriter struct{} - -func (w *nullWriter) Write(p []byte) (n int, err error) { return len(p), nil } diff --git a/pkg/ssh/vclient.go b/pkg/ssh/vclient.go deleted file mode 100644 index e78c8284..00000000 --- a/pkg/ssh/vclient.go +++ /dev/null @@ -1,185 +0,0 @@ -package ssh - -import ( - "context" - "fmt" - "net" - "sync/atomic" - "time" - - "golang.org/x/crypto/ssh" - - "github.com/fatedier/frp/pkg/config" - v1 "github.com/fatedier/frp/pkg/config/v1" - "github.com/fatedier/frp/pkg/msg" - plugin "github.com/fatedier/frp/pkg/plugin/server" - "github.com/fatedier/frp/pkg/util/log" - frp_net "github.com/fatedier/frp/pkg/util/net" - "github.com/fatedier/frp/pkg/util/util" - "github.com/fatedier/frp/pkg/util/xlog" - "github.com/fatedier/frp/server/controller" - "github.com/fatedier/frp/server/proxy" -) - -// VirtualService is a client VirtualService run in frps -type VirtualService struct { - clientCfg v1.ClientCommonConfig - pxyCfg v1.ProxyConfigurer - serverCfg v1.ServerConfig - - sshSvc *Service - - // uniq id got from frps, attach it in loginMsg - runID string - loginMsg *msg.Login - - // All resource managers and controllers - rc *controller.ResourceController - - exit uint32 // 0 means not exit - // SSHService context - ctx context.Context - // call cancel to stop SSHService - cancel context.CancelFunc - - replyCh chan interface{} - pxy proxy.Proxy -} - -func NewVirtualService( - ctx context.Context, - clientCfg v1.ClientCommonConfig, - serverCfg v1.ServerConfig, - logMsg msg.Login, - rc *controller.ResourceController, - pxyCfg v1.ProxyConfigurer, - sshSvc *Service, - replyCh chan interface{}, -) (svr *VirtualService, err error) { - svr = &VirtualService{ - clientCfg: clientCfg, - serverCfg: serverCfg, - rc: rc, - - loginMsg: &logMsg, - - sshSvc: sshSvc, - pxyCfg: pxyCfg, - - ctx: ctx, - exit: 0, - - replyCh: replyCh, - } - - svr.runID, err = util.RandID() - if err != nil { - return nil, err - } - - go svr.loopCheck() - - return -} - -func (svr *VirtualService) Run(ctx context.Context) (err error) { - ctx, cancel := context.WithCancel(ctx) - svr.ctx = xlog.NewContext(ctx, xlog.New()) - svr.cancel = cancel - - remoteAddr, err := svr.RegisterProxy(&msg.NewProxy{ - ProxyName: svr.pxyCfg.(*v1.TCPProxyConfig).Name, - ProxyType: svr.pxyCfg.(*v1.TCPProxyConfig).Type, - RemotePort: svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort, - }) - if err != nil { - return err - } - - log.Info("run a reverse proxy on port: %v", remoteAddr) - - return nil -} - -func (svr *VirtualService) Close() { - svr.GracefulClose(time.Duration(0)) -} - -func (svr *VirtualService) GracefulClose(d time.Duration) { - atomic.StoreUint32(&svr.exit, 1) - svr.pxy.Close() - - if svr.cancel != nil { - svr.cancel() - } - - svr.replyCh <- &VProxyError{} -} - -func (svr *VirtualService) loopCheck() { - <-svr.sshSvc.Exit() - svr.pxy.Close() - log.Info("virtual client service close") -} - -func (svr *VirtualService) RegisterProxy(pxyMsg *msg.NewProxy) (remoteAddr string, err error) { - var pxyConf v1.ProxyConfigurer - pxyConf, err = config.NewProxyConfigurerFromMsg(pxyMsg, &svr.serverCfg) - if err != nil { - return - } - - // User info - userInfo := plugin.UserInfo{ - User: svr.loginMsg.User, - Metas: svr.loginMsg.Metas, - RunID: svr.runID, - } - - svr.pxy, err = proxy.NewProxy(svr.ctx, &proxy.Options{ - LoginMsg: svr.loginMsg, - UserInfo: userInfo, - Configurer: pxyConf, - ResourceController: svr.rc, - - GetWorkConnFn: svr.GetWorkConn, - PoolCount: 10, - - ServerCfg: &svr.serverCfg, - }) - if err != nil { - return remoteAddr, err - } - - remoteAddr, err = svr.pxy.Run() - if err != nil { - log.Warn("proxy run error: %v", err) - return - } - - defer func() { - if err != nil { - log.Warn("proxy close") - svr.pxy.Close() - } - }() - - return -} - -func (svr *VirtualService) GetWorkConn() (workConn net.Conn, err error) { - // tell ssh client open a new stream for work - payload := forwardedTCPPayload{ - Addr: svr.serverCfg.BindAddr, // TODO refine - Port: uint32(svr.pxyCfg.(*v1.TCPProxyConfig).RemotePort), - } - - channel, reqs, err := svr.sshSvc.SSHConn().OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(payload)) - if err != nil { - return nil, fmt.Errorf("open ssh channel error: %v", err) - } - go ssh.DiscardRequests(reqs) - - workConn = frp_net.WrapReadWriteCloserToConn(channel, svr.sshSvc.tcpConn) - return workConn, nil -} diff --git a/pkg/transport/tls.go b/pkg/transport/tls.go index d92b1a82..5bc75921 100644 --- a/pkg/transport/tls.go +++ b/pkg/transport/tls.go @@ -128,3 +128,15 @@ func NewClientTLSConfig(certPath, keyPath, caPath, serverName string) (*tls.Conf return base, nil } + +func NewRandomPrivateKey() ([]byte, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + return keyPEM, nil +} diff --git a/pkg/util/xlog/xlog.go b/pkg/util/xlog/xlog.go index b5746f9d..7b69dcaf 100644 --- a/pkg/util/xlog/xlog.go +++ b/pkg/util/xlog/xlog.go @@ -15,40 +15,81 @@ package xlog import ( + "sort" + "github.com/fatedier/frp/pkg/util/log" ) +type LogPrefix struct { + // Name is the name of the prefix, it won't be displayed in log but used to identify the prefix. + Name string + // Value is the value of the prefix, it will be displayed in log. + Value string + // The prefix with higher priority will be displayed first, default is 10. + Priority int +} + // Logger is not thread safety for operations on prefix type Logger struct { - prefixes []string + prefixes []LogPrefix prefixString string } func New() *Logger { return &Logger{ - prefixes: make([]string, 0), + prefixes: make([]LogPrefix, 0), } } -func (l *Logger) ResetPrefixes() (old []string) { +func (l *Logger) ResetPrefixes() (old []LogPrefix) { old = l.prefixes - l.prefixes = make([]string, 0) + l.prefixes = make([]LogPrefix, 0) l.prefixString = "" return } func (l *Logger) AppendPrefix(prefix string) *Logger { - l.prefixes = append(l.prefixes, prefix) - l.prefixString += "[" + prefix + "] " + return l.AddPrefix(LogPrefix{ + Name: prefix, + Value: prefix, + Priority: 10, + }) +} + +func (l *Logger) AddPrefix(prefix LogPrefix) *Logger { + found := false + if prefix.Priority <= 0 { + prefix.Priority = 10 + } + for _, p := range l.prefixes { + if p.Name == prefix.Name { + found = true + p.Value = prefix.Value + p.Priority = prefix.Priority + } + } + if !found { + l.prefixes = append(l.prefixes, prefix) + } + l.renderPrefixString() return l } +func (l *Logger) renderPrefixString() { + sort.SliceStable(l.prefixes, func(i, j int) bool { + return l.prefixes[i].Priority < l.prefixes[j].Priority + }) + l.prefixString = "" + for _, v := range l.prefixes { + l.prefixString += "[" + v.Value + "] " + } +} + func (l *Logger) Spawn() *Logger { nl := New() - for _, v := range l.prefixes { - nl.AppendPrefix(v) - } + nl.prefixes = append(nl.prefixes, l.prefixes...) + nl.renderPrefixString() return nl } diff --git a/pkg/virtual/client.go b/pkg/virtual/client.go new file mode 100644 index 00000000..d0369a1a --- /dev/null +++ b/pkg/virtual/client.go @@ -0,0 +1,92 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package virtual + +import ( + "context" + "net" + + "github.com/fatedier/frp/client" + v1 "github.com/fatedier/frp/pkg/config/v1" + "github.com/fatedier/frp/pkg/msg" + utilnet "github.com/fatedier/frp/pkg/util/net" +) + +type Client struct { + l *utilnet.InternalListener + svr *client.Service +} + +func NewClient(cfg *v1.ClientCommonConfig) *Client { + cfg.Complete() + + ln := utilnet.NewInternalListener() + + svr := client.NewService(cfg, nil, nil, "") + svr.SetConnectorCreator(func(context.Context, *v1.ClientCommonConfig) client.Connector { + return &pipeConnector{ + peerListener: ln, + } + }) + + return &Client{ + l: ln, + svr: svr, + } +} + +func (c *Client) PeerListener() net.Listener { + return c.l +} + +func (c *Client) SetInWorkConnCallback(cb func(*v1.ProxyBaseConfig, net.Conn, *msg.StartWorkConn) bool) { + c.svr.SetInWorkConnCallback(cb) +} + +func (c *Client) UpdateProxyConfigurer(proxyCfgs []v1.ProxyConfigurer) { + _ = c.svr.ReloadConf(proxyCfgs, nil) +} + +func (c *Client) Run(ctx context.Context) error { + return c.svr.Run(ctx) +} + +func (c *Client) Close() { + c.l.Close() + c.svr.Close() +} + +type pipeConnector struct { + peerListener *utilnet.InternalListener +} + +func (pc *pipeConnector) Open() error { + return nil +} + +func (pc *pipeConnector) Connect() (net.Conn, error) { + c1, c2 := net.Pipe() + if err := pc.peerListener.PutConn(c1); err != nil { + c1.Close() + c2.Close() + return nil, err + } + return c2, nil +} + +func (pc *pipeConnector) Close() error { + pc.peerListener.Close() + return nil +} diff --git a/server/proxy/proxy.go b/server/proxy/proxy.go index 5ea99f1e..fe6f781b 100644 --- a/server/proxy/proxy.go +++ b/server/proxy/proxy.go @@ -21,7 +21,6 @@ import ( "net" "reflect" "strconv" - "strings" "sync" "time" @@ -230,14 +229,8 @@ func (pxy *BaseProxy) handleUserTCPConnection(userConn net.Conn) { return } - var workConn net.Conn - // try all connections from the pool - if strings.HasPrefix(pxy.GetLoginMsg().User, v1.SSHClientLoginUserPrefix) { - workConn, err = pxy.getWorkConnFn() - } else { - workConn, err = pxy.GetWorkConnFromPool(userConn.RemoteAddr(), userConn.LocalAddr()) - } + workConn, err := pxy.GetWorkConnFromPool(userConn.RemoteAddr(), userConn.LocalAddr()) if err != nil { return } diff --git a/server/service.go b/server/service.go index 2ca501be..02efec91 100644 --- a/server/service.go +++ b/server/service.go @@ -18,13 +18,10 @@ import ( "bytes" "context" "crypto/tls" - "errors" "fmt" "io" "net" "net/http" - "os" - "reflect" "strconv" "time" @@ -32,7 +29,6 @@ import ( fmux "github.com/hashicorp/yamux" quic "github.com/quic-go/quic-go" "github.com/samber/lo" - "golang.org/x/crypto/ssh" "github.com/fatedier/frp/assets" "github.com/fatedier/frp/pkg/auth" @@ -41,7 +37,7 @@ import ( "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/nathole" plugin "github.com/fatedier/frp/pkg/plugin/server" - frpssh "github.com/fatedier/frp/pkg/ssh" + "github.com/fatedier/frp/pkg/ssh" "github.com/fatedier/frp/pkg/transport" "github.com/fatedier/frp/pkg/util/log" utilnet "github.com/fatedier/frp/pkg/util/net" @@ -71,10 +67,6 @@ type Service struct { // Accept connections from client listener net.Listener - // Accept connections using ssh - sshListener net.Listener - sshConfig *ssh.ServerConfig - // Accept connections using kcp kcpListener net.Listener @@ -87,6 +79,8 @@ type Service struct { // Accept frp tls connections tlsListener net.Listener + virtualListener *utilnet.InternalListener + // Manage all controllers ctlManager *ControlManager @@ -102,6 +96,8 @@ type Service struct { // All resource managers and controllers rc *controller.ResourceController + sshTunnelGateway *ssh.Gateway + // Verifies authentication based on selected method authVerifier auth.Verifier @@ -133,6 +129,7 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) { TCPPortManager: ports.NewManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts), UDPPortManager: ports.NewManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts), }, + virtualListener: utilnet.NewInternalListener(), httpVhostRouter: vhost.NewRouters(), authVerifier: auth.NewAuthVerifier(cfg.Auth), tlsConfig: tlsConfig, @@ -208,67 +205,6 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) { svr.listener = ln log.Info("frps tcp listen on %s", address) - if cfg.SSHTunnelGateway.BindPort > 0 { - - if cfg.SSHTunnelGateway.PublicKeyFilesPath != "" { - cfg.SSHTunnelGateway.PublicKeyFilesMap, err = v1.LoadSSHPublicKeyFilesInDir(cfg.SSHTunnelGateway.PublicKeyFilesPath) - if err != nil { - return nil, fmt.Errorf("load ssh all public key files error: %v", err) - } - log.Info("load %v public key files success", cfg.SSHTunnelGateway.PublicKeyFilesPath) - } - - svr.sshConfig = &ssh.ServerConfig{ - NoClientAuth: lo.If(cfg.SSHTunnelGateway.PublicKeyFilesPath == "", true).Else(false), - - PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - parsedAuthorizedKey, ok := cfg.SSHTunnelGateway.PublicKeyFilesMap[ssh.FingerprintSHA256(key)] - if !ok { - return nil, errors.New("cannot find public key file") - } - - if key.Type() == parsedAuthorizedKey.Type() && reflect.DeepEqual(parsedAuthorizedKey, key) { - return &ssh.Permissions{ - Extensions: map[string]string{}, - }, nil - } - return nil, fmt.Errorf("unknown public key for %q", conn.User()) - }, - } - - var privateBytes []byte - if cfg.SSHTunnelGateway.PrivateKeyFilePath != "" { - privateBytes, err = os.ReadFile(cfg.SSHTunnelGateway.PrivateKeyFilePath) - if err != nil { - log.Error("Failed to load private key") - return nil, err - } - log.Info("load %v private key file success", cfg.SSHTunnelGateway.PrivateKeyFilePath) - } else { - privateBytes, err = v1.GeneratePrivateKey() - if err != nil { - log.Error("Failed to load private key") - return nil, err - } - log.Info("auto gen private key file success") - } - private, err := ssh.ParsePrivateKey(privateBytes) - if err != nil { - log.Error("Failed to parse private key, error: %v", err) - return nil, err - } - - svr.sshConfig.AddHostKey(private) - - sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHTunnelGateway.BindPort)) - svr.sshListener, err = net.Listen("tcp", sshAddr) - if err != nil { - log.Error("Failed to listen on %v, error: %v", sshAddr, err) - return nil, err - } - log.Info("ssh server listening on %v", sshAddr) - } - // Listen for accepting connections from client using kcp protocol. if cfg.KCPBindPort > 0 { address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.KCPBindPort)) @@ -293,7 +229,17 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) { err = fmt.Errorf("listen on quic udp address %s error: %v", address, err) return } - log.Info("frps quic listen on quic %s", address) + log.Info("frps quic listen on %s", address) + } + + if cfg.SSHTunnelGateway.BindPort > 0 { + sshGateway, err := ssh.NewGateway(cfg.SSHTunnelGateway, cfg.ProxyBindAddr, svr.virtualListener) + if err != nil { + err = fmt.Errorf("create ssh gateway error: %v", err) + return nil, err + } + svr.sshTunnelGateway = sshGateway + log.Info("frps sshTunnelGateway listen on port %d", cfg.SSHTunnelGateway.BindPort) } // Listen for accepting connections from client using websocket protocol. @@ -396,23 +342,26 @@ func (svr *Service) Run(ctx context.Context) { svr.ctx = ctx svr.cancel = cancel - if svr.sshListener != nil { - go svr.HandleSSHListener(svr.sshListener) - } + go svr.HandleListener(svr.virtualListener, true) if svr.kcpListener != nil { - go svr.HandleListener(svr.kcpListener) + go svr.HandleListener(svr.kcpListener, false) } if svr.quicListener != nil { go svr.HandleQUICListener(svr.quicListener) } - go svr.HandleListener(svr.websocketListener) - go svr.HandleListener(svr.tlsListener) + go svr.HandleListener(svr.websocketListener, false) + go svr.HandleListener(svr.tlsListener, false) if svr.rc.NatHoleController != nil { go svr.rc.NatHoleController.CleanWorker(svr.ctx) } - svr.HandleListener(svr.listener) + + if svr.sshTunnelGateway != nil { + go svr.sshTunnelGateway.Run() + } + + svr.HandleListener(svr.listener, false) <-svr.ctx.Done() // service context may not be canceled by svr.Close(), we should call it here to release resources @@ -422,10 +371,6 @@ func (svr *Service) Run(ctx context.Context) { } func (svr *Service) Close() error { - if svr.sshListener != nil { - svr.sshListener.Close() - svr.sshListener = nil - } if svr.kcpListener != nil { svr.kcpListener.Close() svr.kcpListener = nil @@ -516,7 +461,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) { } } -func (svr *Service) HandleListener(l net.Listener) { +func (svr *Service) HandleListener(l net.Listener, internal bool) { // Listen for incoming connections from client. for { c, err := l.Accept() @@ -532,8 +477,9 @@ func (svr *Service) HandleListener(l net.Listener) { log.Trace("start check TLS connection...") originConn := c + forceTLS := svr.cfg.Transport.TLS.Force && !internal var isTLS, custom bool - c, isTLS, custom, err = utilnet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.Transport.TLS.Force, connReadTimeout) + c, isTLS, custom, err = utilnet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, forceTLS, connReadTimeout) if err != nil { log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err) originConn.Close() @@ -543,7 +489,7 @@ func (svr *Service) HandleListener(l net.Listener) { // Start a new goroutine to handle connection. go func(ctx context.Context, frpConn net.Conn) { - if lo.FromPtr(svr.cfg.Transport.TCPMux) { + if lo.FromPtr(svr.cfg.Transport.TCPMux) && !internal { fmuxCfg := fmux.DefaultConfig() fmuxCfg.KeepAliveInterval = time.Duration(svr.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second fmuxCfg.LogOutput = io.Discard @@ -571,52 +517,6 @@ func (svr *Service) HandleListener(l net.Listener) { } } -func (svr *Service) HandleSSHListener(listener net.Listener) { - for { - tcpConn, err := listener.Accept() - if err != nil { - log.Error("failed to accept incoming ssh connection (%s)", err) - return - } - log.Info("new tcp conn connected: %v", tcpConn.RemoteAddr().String()) - - pxyPayloadCh := make(chan v1.ProxyConfigurer) - replyCh := make(chan interface{}) - - ss, err := frpssh.NewSSHService(tcpConn, svr.sshConfig, pxyPayloadCh, replyCh) - if err != nil { - log.Error("new ssh service error: %v", err) - continue - } - ss.Run() - - go func() { - for { - pxyCfg := <-pxyPayloadCh - - ctx := context.Background() - - // TODO fill client common config and login msg - vs, err := frpssh.NewVirtualService(ctx, v1.ClientCommonConfig{}, *svr.cfg, - msg.Login{User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String()}, - svr.rc, pxyCfg, ss, replyCh) - if err != nil { - log.Error("new virtual service error: %v", err) - ss.Close() - return - } - - err = vs.Run(ctx) - if err != nil { - log.Error("proxy run error: %v", err) - vs.Close() - return - } - } - }() - } -} - func (svr *Service) HandleQUICListener(l *quic.Listener) { // Listen for incoming connections from client. for {