sshTunnelGateway refactor (#3784)

This commit is contained in:
fatedier
2023-11-21 11:19:35 +08:00
parent 8b432e179d
commit d5b41f1e14
34 changed files with 1036 additions and 1255 deletions

View File

@ -21,7 +21,6 @@ import (
"net"
"reflect"
"strconv"
"strings"
"sync"
"time"
@ -230,14 +229,8 @@ func (pxy *BaseProxy) handleUserTCPConnection(userConn net.Conn) {
return
}
var workConn net.Conn
// try all connections from the pool
if strings.HasPrefix(pxy.GetLoginMsg().User, v1.SSHClientLoginUserPrefix) {
workConn, err = pxy.getWorkConnFn()
} else {
workConn, err = pxy.GetWorkConnFromPool(userConn.RemoteAddr(), userConn.LocalAddr())
}
workConn, err := pxy.GetWorkConnFromPool(userConn.RemoteAddr(), userConn.LocalAddr())
if err != nil {
return
}

View File

@ -18,13 +18,10 @@ import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"reflect"
"strconv"
"time"
@ -32,7 +29,6 @@ import (
fmux "github.com/hashicorp/yamux"
quic "github.com/quic-go/quic-go"
"github.com/samber/lo"
"golang.org/x/crypto/ssh"
"github.com/fatedier/frp/assets"
"github.com/fatedier/frp/pkg/auth"
@ -41,7 +37,7 @@ import (
"github.com/fatedier/frp/pkg/msg"
"github.com/fatedier/frp/pkg/nathole"
plugin "github.com/fatedier/frp/pkg/plugin/server"
frpssh "github.com/fatedier/frp/pkg/ssh"
"github.com/fatedier/frp/pkg/ssh"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net"
@ -71,10 +67,6 @@ type Service struct {
// Accept connections from client
listener net.Listener
// Accept connections using ssh
sshListener net.Listener
sshConfig *ssh.ServerConfig
// Accept connections using kcp
kcpListener net.Listener
@ -87,6 +79,8 @@ type Service struct {
// Accept frp tls connections
tlsListener net.Listener
virtualListener *utilnet.InternalListener
// Manage all controllers
ctlManager *ControlManager
@ -102,6 +96,8 @@ type Service struct {
// All resource managers and controllers
rc *controller.ResourceController
sshTunnelGateway *ssh.Gateway
// Verifies authentication based on selected method
authVerifier auth.Verifier
@ -133,6 +129,7 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
TCPPortManager: ports.NewManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts),
UDPPortManager: ports.NewManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts),
},
virtualListener: utilnet.NewInternalListener(),
httpVhostRouter: vhost.NewRouters(),
authVerifier: auth.NewAuthVerifier(cfg.Auth),
tlsConfig: tlsConfig,
@ -208,67 +205,6 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
svr.listener = ln
log.Info("frps tcp listen on %s", address)
if cfg.SSHTunnelGateway.BindPort > 0 {
if cfg.SSHTunnelGateway.PublicKeyFilesPath != "" {
cfg.SSHTunnelGateway.PublicKeyFilesMap, err = v1.LoadSSHPublicKeyFilesInDir(cfg.SSHTunnelGateway.PublicKeyFilesPath)
if err != nil {
return nil, fmt.Errorf("load ssh all public key files error: %v", err)
}
log.Info("load %v public key files success", cfg.SSHTunnelGateway.PublicKeyFilesPath)
}
svr.sshConfig = &ssh.ServerConfig{
NoClientAuth: lo.If(cfg.SSHTunnelGateway.PublicKeyFilesPath == "", true).Else(false),
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
parsedAuthorizedKey, ok := cfg.SSHTunnelGateway.PublicKeyFilesMap[ssh.FingerprintSHA256(key)]
if !ok {
return nil, errors.New("cannot find public key file")
}
if key.Type() == parsedAuthorizedKey.Type() && reflect.DeepEqual(parsedAuthorizedKey, key) {
return &ssh.Permissions{
Extensions: map[string]string{},
}, nil
}
return nil, fmt.Errorf("unknown public key for %q", conn.User())
},
}
var privateBytes []byte
if cfg.SSHTunnelGateway.PrivateKeyFilePath != "" {
privateBytes, err = os.ReadFile(cfg.SSHTunnelGateway.PrivateKeyFilePath)
if err != nil {
log.Error("Failed to load private key")
return nil, err
}
log.Info("load %v private key file success", cfg.SSHTunnelGateway.PrivateKeyFilePath)
} else {
privateBytes, err = v1.GeneratePrivateKey()
if err != nil {
log.Error("Failed to load private key")
return nil, err
}
log.Info("auto gen private key file success")
}
private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil {
log.Error("Failed to parse private key, error: %v", err)
return nil, err
}
svr.sshConfig.AddHostKey(private)
sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHTunnelGateway.BindPort))
svr.sshListener, err = net.Listen("tcp", sshAddr)
if err != nil {
log.Error("Failed to listen on %v, error: %v", sshAddr, err)
return nil, err
}
log.Info("ssh server listening on %v", sshAddr)
}
// Listen for accepting connections from client using kcp protocol.
if cfg.KCPBindPort > 0 {
address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.KCPBindPort))
@ -293,7 +229,17 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
err = fmt.Errorf("listen on quic udp address %s error: %v", address, err)
return
}
log.Info("frps quic listen on quic %s", address)
log.Info("frps quic listen on %s", address)
}
if cfg.SSHTunnelGateway.BindPort > 0 {
sshGateway, err := ssh.NewGateway(cfg.SSHTunnelGateway, cfg.ProxyBindAddr, svr.virtualListener)
if err != nil {
err = fmt.Errorf("create ssh gateway error: %v", err)
return nil, err
}
svr.sshTunnelGateway = sshGateway
log.Info("frps sshTunnelGateway listen on port %d", cfg.SSHTunnelGateway.BindPort)
}
// Listen for accepting connections from client using websocket protocol.
@ -396,23 +342,26 @@ func (svr *Service) Run(ctx context.Context) {
svr.ctx = ctx
svr.cancel = cancel
if svr.sshListener != nil {
go svr.HandleSSHListener(svr.sshListener)
}
go svr.HandleListener(svr.virtualListener, true)
if svr.kcpListener != nil {
go svr.HandleListener(svr.kcpListener)
go svr.HandleListener(svr.kcpListener, false)
}
if svr.quicListener != nil {
go svr.HandleQUICListener(svr.quicListener)
}
go svr.HandleListener(svr.websocketListener)
go svr.HandleListener(svr.tlsListener)
go svr.HandleListener(svr.websocketListener, false)
go svr.HandleListener(svr.tlsListener, false)
if svr.rc.NatHoleController != nil {
go svr.rc.NatHoleController.CleanWorker(svr.ctx)
}
svr.HandleListener(svr.listener)
if svr.sshTunnelGateway != nil {
go svr.sshTunnelGateway.Run()
}
svr.HandleListener(svr.listener, false)
<-svr.ctx.Done()
// service context may not be canceled by svr.Close(), we should call it here to release resources
@ -422,10 +371,6 @@ func (svr *Service) Run(ctx context.Context) {
}
func (svr *Service) Close() error {
if svr.sshListener != nil {
svr.sshListener.Close()
svr.sshListener = nil
}
if svr.kcpListener != nil {
svr.kcpListener.Close()
svr.kcpListener = nil
@ -516,7 +461,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) {
}
}
func (svr *Service) HandleListener(l net.Listener) {
func (svr *Service) HandleListener(l net.Listener, internal bool) {
// Listen for incoming connections from client.
for {
c, err := l.Accept()
@ -532,8 +477,9 @@ func (svr *Service) HandleListener(l net.Listener) {
log.Trace("start check TLS connection...")
originConn := c
forceTLS := svr.cfg.Transport.TLS.Force && !internal
var isTLS, custom bool
c, isTLS, custom, err = utilnet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.Transport.TLS.Force, connReadTimeout)
c, isTLS, custom, err = utilnet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, forceTLS, connReadTimeout)
if err != nil {
log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err)
originConn.Close()
@ -543,7 +489,7 @@ func (svr *Service) HandleListener(l net.Listener) {
// Start a new goroutine to handle connection.
go func(ctx context.Context, frpConn net.Conn) {
if lo.FromPtr(svr.cfg.Transport.TCPMux) {
if lo.FromPtr(svr.cfg.Transport.TCPMux) && !internal {
fmuxCfg := fmux.DefaultConfig()
fmuxCfg.KeepAliveInterval = time.Duration(svr.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
fmuxCfg.LogOutput = io.Discard
@ -571,52 +517,6 @@ func (svr *Service) HandleListener(l net.Listener) {
}
}
func (svr *Service) HandleSSHListener(listener net.Listener) {
for {
tcpConn, err := listener.Accept()
if err != nil {
log.Error("failed to accept incoming ssh connection (%s)", err)
return
}
log.Info("new tcp conn connected: %v", tcpConn.RemoteAddr().String())
pxyPayloadCh := make(chan v1.ProxyConfigurer)
replyCh := make(chan interface{})
ss, err := frpssh.NewSSHService(tcpConn, svr.sshConfig, pxyPayloadCh, replyCh)
if err != nil {
log.Error("new ssh service error: %v", err)
continue
}
ss.Run()
go func() {
for {
pxyCfg := <-pxyPayloadCh
ctx := context.Background()
// TODO fill client common config and login msg
vs, err := frpssh.NewVirtualService(ctx, v1.ClientCommonConfig{}, *svr.cfg,
msg.Login{User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String()},
svr.rc, pxyCfg, ss, replyCh)
if err != nil {
log.Error("new virtual service error: %v", err)
ss.Close()
return
}
err = vs.Run(ctx)
if err != nil {
log.Error("proxy run error: %v", err)
vs.Close()
return
}
}
}()
}
}
func (svr *Service) HandleQUICListener(l *quic.Listener) {
// Listen for incoming connections from client.
for {