mirror of
https://github.com/fatedier/frp.git
synced 2025-07-20 12:50:04 +07:00
sshTunnelGateway refactor (#3784)
This commit is contained in:
@ -21,7 +21,6 @@ import (
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -230,14 +229,8 @@ func (pxy *BaseProxy) handleUserTCPConnection(userConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
var workConn net.Conn
|
||||
|
||||
// try all connections from the pool
|
||||
if strings.HasPrefix(pxy.GetLoginMsg().User, v1.SSHClientLoginUserPrefix) {
|
||||
workConn, err = pxy.getWorkConnFn()
|
||||
} else {
|
||||
workConn, err = pxy.GetWorkConnFromPool(userConn.RemoteAddr(), userConn.LocalAddr())
|
||||
}
|
||||
workConn, err := pxy.GetWorkConnFromPool(userConn.RemoteAddr(), userConn.LocalAddr())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -18,13 +18,10 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@ -32,7 +29,6 @@ import (
|
||||
fmux "github.com/hashicorp/yamux"
|
||||
quic "github.com/quic-go/quic-go"
|
||||
"github.com/samber/lo"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/fatedier/frp/assets"
|
||||
"github.com/fatedier/frp/pkg/auth"
|
||||
@ -41,7 +37,7 @@ import (
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/nathole"
|
||||
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
||||
frpssh "github.com/fatedier/frp/pkg/ssh"
|
||||
"github.com/fatedier/frp/pkg/ssh"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||
@ -71,10 +67,6 @@ type Service struct {
|
||||
// Accept connections from client
|
||||
listener net.Listener
|
||||
|
||||
// Accept connections using ssh
|
||||
sshListener net.Listener
|
||||
sshConfig *ssh.ServerConfig
|
||||
|
||||
// Accept connections using kcp
|
||||
kcpListener net.Listener
|
||||
|
||||
@ -87,6 +79,8 @@ type Service struct {
|
||||
// Accept frp tls connections
|
||||
tlsListener net.Listener
|
||||
|
||||
virtualListener *utilnet.InternalListener
|
||||
|
||||
// Manage all controllers
|
||||
ctlManager *ControlManager
|
||||
|
||||
@ -102,6 +96,8 @@ type Service struct {
|
||||
// All resource managers and controllers
|
||||
rc *controller.ResourceController
|
||||
|
||||
sshTunnelGateway *ssh.Gateway
|
||||
|
||||
// Verifies authentication based on selected method
|
||||
authVerifier auth.Verifier
|
||||
|
||||
@ -133,6 +129,7 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
|
||||
TCPPortManager: ports.NewManager("tcp", cfg.ProxyBindAddr, cfg.AllowPorts),
|
||||
UDPPortManager: ports.NewManager("udp", cfg.ProxyBindAddr, cfg.AllowPorts),
|
||||
},
|
||||
virtualListener: utilnet.NewInternalListener(),
|
||||
httpVhostRouter: vhost.NewRouters(),
|
||||
authVerifier: auth.NewAuthVerifier(cfg.Auth),
|
||||
tlsConfig: tlsConfig,
|
||||
@ -208,67 +205,6 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
|
||||
svr.listener = ln
|
||||
log.Info("frps tcp listen on %s", address)
|
||||
|
||||
if cfg.SSHTunnelGateway.BindPort > 0 {
|
||||
|
||||
if cfg.SSHTunnelGateway.PublicKeyFilesPath != "" {
|
||||
cfg.SSHTunnelGateway.PublicKeyFilesMap, err = v1.LoadSSHPublicKeyFilesInDir(cfg.SSHTunnelGateway.PublicKeyFilesPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load ssh all public key files error: %v", err)
|
||||
}
|
||||
log.Info("load %v public key files success", cfg.SSHTunnelGateway.PublicKeyFilesPath)
|
||||
}
|
||||
|
||||
svr.sshConfig = &ssh.ServerConfig{
|
||||
NoClientAuth: lo.If(cfg.SSHTunnelGateway.PublicKeyFilesPath == "", true).Else(false),
|
||||
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
parsedAuthorizedKey, ok := cfg.SSHTunnelGateway.PublicKeyFilesMap[ssh.FingerprintSHA256(key)]
|
||||
if !ok {
|
||||
return nil, errors.New("cannot find public key file")
|
||||
}
|
||||
|
||||
if key.Type() == parsedAuthorizedKey.Type() && reflect.DeepEqual(parsedAuthorizedKey, key) {
|
||||
return &ssh.Permissions{
|
||||
Extensions: map[string]string{},
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unknown public key for %q", conn.User())
|
||||
},
|
||||
}
|
||||
|
||||
var privateBytes []byte
|
||||
if cfg.SSHTunnelGateway.PrivateKeyFilePath != "" {
|
||||
privateBytes, err = os.ReadFile(cfg.SSHTunnelGateway.PrivateKeyFilePath)
|
||||
if err != nil {
|
||||
log.Error("Failed to load private key")
|
||||
return nil, err
|
||||
}
|
||||
log.Info("load %v private key file success", cfg.SSHTunnelGateway.PrivateKeyFilePath)
|
||||
} else {
|
||||
privateBytes, err = v1.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
log.Error("Failed to load private key")
|
||||
return nil, err
|
||||
}
|
||||
log.Info("auto gen private key file success")
|
||||
}
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
if err != nil {
|
||||
log.Error("Failed to parse private key, error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svr.sshConfig.AddHostKey(private)
|
||||
|
||||
sshAddr := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.SSHTunnelGateway.BindPort))
|
||||
svr.sshListener, err = net.Listen("tcp", sshAddr)
|
||||
if err != nil {
|
||||
log.Error("Failed to listen on %v, error: %v", sshAddr, err)
|
||||
return nil, err
|
||||
}
|
||||
log.Info("ssh server listening on %v", sshAddr)
|
||||
}
|
||||
|
||||
// Listen for accepting connections from client using kcp protocol.
|
||||
if cfg.KCPBindPort > 0 {
|
||||
address := net.JoinHostPort(cfg.BindAddr, strconv.Itoa(cfg.KCPBindPort))
|
||||
@ -293,7 +229,17 @@ func NewService(cfg *v1.ServerConfig) (svr *Service, err error) {
|
||||
err = fmt.Errorf("listen on quic udp address %s error: %v", address, err)
|
||||
return
|
||||
}
|
||||
log.Info("frps quic listen on quic %s", address)
|
||||
log.Info("frps quic listen on %s", address)
|
||||
}
|
||||
|
||||
if cfg.SSHTunnelGateway.BindPort > 0 {
|
||||
sshGateway, err := ssh.NewGateway(cfg.SSHTunnelGateway, cfg.ProxyBindAddr, svr.virtualListener)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("create ssh gateway error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
svr.sshTunnelGateway = sshGateway
|
||||
log.Info("frps sshTunnelGateway listen on port %d", cfg.SSHTunnelGateway.BindPort)
|
||||
}
|
||||
|
||||
// Listen for accepting connections from client using websocket protocol.
|
||||
@ -396,23 +342,26 @@ func (svr *Service) Run(ctx context.Context) {
|
||||
svr.ctx = ctx
|
||||
svr.cancel = cancel
|
||||
|
||||
if svr.sshListener != nil {
|
||||
go svr.HandleSSHListener(svr.sshListener)
|
||||
}
|
||||
go svr.HandleListener(svr.virtualListener, true)
|
||||
|
||||
if svr.kcpListener != nil {
|
||||
go svr.HandleListener(svr.kcpListener)
|
||||
go svr.HandleListener(svr.kcpListener, false)
|
||||
}
|
||||
if svr.quicListener != nil {
|
||||
go svr.HandleQUICListener(svr.quicListener)
|
||||
}
|
||||
go svr.HandleListener(svr.websocketListener)
|
||||
go svr.HandleListener(svr.tlsListener)
|
||||
go svr.HandleListener(svr.websocketListener, false)
|
||||
go svr.HandleListener(svr.tlsListener, false)
|
||||
|
||||
if svr.rc.NatHoleController != nil {
|
||||
go svr.rc.NatHoleController.CleanWorker(svr.ctx)
|
||||
}
|
||||
svr.HandleListener(svr.listener)
|
||||
|
||||
if svr.sshTunnelGateway != nil {
|
||||
go svr.sshTunnelGateway.Run()
|
||||
}
|
||||
|
||||
svr.HandleListener(svr.listener, false)
|
||||
|
||||
<-svr.ctx.Done()
|
||||
// service context may not be canceled by svr.Close(), we should call it here to release resources
|
||||
@ -422,10 +371,6 @@ func (svr *Service) Run(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (svr *Service) Close() error {
|
||||
if svr.sshListener != nil {
|
||||
svr.sshListener.Close()
|
||||
svr.sshListener = nil
|
||||
}
|
||||
if svr.kcpListener != nil {
|
||||
svr.kcpListener.Close()
|
||||
svr.kcpListener = nil
|
||||
@ -516,7 +461,7 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (svr *Service) HandleListener(l net.Listener) {
|
||||
func (svr *Service) HandleListener(l net.Listener, internal bool) {
|
||||
// Listen for incoming connections from client.
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
@ -532,8 +477,9 @@ func (svr *Service) HandleListener(l net.Listener) {
|
||||
|
||||
log.Trace("start check TLS connection...")
|
||||
originConn := c
|
||||
forceTLS := svr.cfg.Transport.TLS.Force && !internal
|
||||
var isTLS, custom bool
|
||||
c, isTLS, custom, err = utilnet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, svr.cfg.Transport.TLS.Force, connReadTimeout)
|
||||
c, isTLS, custom, err = utilnet.CheckAndEnableTLSServerConnWithTimeout(c, svr.tlsConfig, forceTLS, connReadTimeout)
|
||||
if err != nil {
|
||||
log.Warn("CheckAndEnableTLSServerConnWithTimeout error: %v", err)
|
||||
originConn.Close()
|
||||
@ -543,7 +489,7 @@ func (svr *Service) HandleListener(l net.Listener) {
|
||||
|
||||
// Start a new goroutine to handle connection.
|
||||
go func(ctx context.Context, frpConn net.Conn) {
|
||||
if lo.FromPtr(svr.cfg.Transport.TCPMux) {
|
||||
if lo.FromPtr(svr.cfg.Transport.TCPMux) && !internal {
|
||||
fmuxCfg := fmux.DefaultConfig()
|
||||
fmuxCfg.KeepAliveInterval = time.Duration(svr.cfg.Transport.TCPMuxKeepaliveInterval) * time.Second
|
||||
fmuxCfg.LogOutput = io.Discard
|
||||
@ -571,52 +517,6 @@ func (svr *Service) HandleListener(l net.Listener) {
|
||||
}
|
||||
}
|
||||
|
||||
func (svr *Service) HandleSSHListener(listener net.Listener) {
|
||||
for {
|
||||
tcpConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Error("failed to accept incoming ssh connection (%s)", err)
|
||||
return
|
||||
}
|
||||
log.Info("new tcp conn connected: %v", tcpConn.RemoteAddr().String())
|
||||
|
||||
pxyPayloadCh := make(chan v1.ProxyConfigurer)
|
||||
replyCh := make(chan interface{})
|
||||
|
||||
ss, err := frpssh.NewSSHService(tcpConn, svr.sshConfig, pxyPayloadCh, replyCh)
|
||||
if err != nil {
|
||||
log.Error("new ssh service error: %v", err)
|
||||
continue
|
||||
}
|
||||
ss.Run()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
pxyCfg := <-pxyPayloadCh
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// TODO fill client common config and login msg
|
||||
vs, err := frpssh.NewVirtualService(ctx, v1.ClientCommonConfig{}, *svr.cfg,
|
||||
msg.Login{User: v1.SSHClientLoginUserPrefix + tcpConn.RemoteAddr().String()},
|
||||
svr.rc, pxyCfg, ss, replyCh)
|
||||
if err != nil {
|
||||
log.Error("new virtual service error: %v", err)
|
||||
ss.Close()
|
||||
return
|
||||
}
|
||||
|
||||
err = vs.Run(ctx)
|
||||
if err != nil {
|
||||
log.Error("proxy run error: %v", err)
|
||||
vs.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (svr *Service) HandleQUICListener(l *quic.Listener) {
|
||||
// Listen for incoming connections from client.
|
||||
for {
|
||||
|
Reference in New Issue
Block a user