optimize some code (#3801)

This commit is contained in:
fatedier
2023-11-27 15:47:49 +08:00
committed by GitHub
parent d5b41f1e14
commit 69ae2b0b69
52 changed files with 880 additions and 600 deletions

View File

@ -26,21 +26,21 @@ import (
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/transport"
"github.com/fatedier/frp/pkg/util/log"
utilnet "github.com/fatedier/frp/pkg/util/net"
netpkg "github.com/fatedier/frp/pkg/util/net"
)
type Gateway struct {
bindPort int
ln net.Listener
serverPeerListener *utilnet.InternalListener
peerServerListener *netpkg.InternalListener
sshConfig *ssh.ServerConfig
}
func NewGateway(
cfg v1.SSHTunnelGateway, bindAddr string,
serverPeerListener *utilnet.InternalListener,
peerServerListener *netpkg.InternalListener,
) (*Gateway, error) {
sshConfig := &ssh.ServerConfig{}
@ -71,15 +71,8 @@ func NewGateway(
}
sshConfig.AddHostKey(privateKey)
sshConfig.NoClientAuth = cfg.AuthorizedKeysFile == ""
sshConfig.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if cfg.AuthorizedKeysFile == "" {
return &ssh.Permissions{
Extensions: map[string]string{
"user": "",
},
}, nil
}
authorizedKeysMap, err := loadAuthorizedKeysFromFile(cfg.AuthorizedKeysFile)
if err != nil {
return nil, fmt.Errorf("internal error")
@ -103,7 +96,7 @@ func NewGateway(
return &Gateway{
bindPort: cfg.BindPort,
ln: ln,
serverPeerListener: serverPeerListener,
peerServerListener: peerServerListener,
sshConfig: sshConfig,
}, nil
}
@ -121,7 +114,7 @@ func (g *Gateway) Run() {
func (g *Gateway) handleConn(conn net.Conn) {
defer conn.Close()
ts, err := NewTunnelServer(conn, g.sshConfig, g.serverPeerListener)
ts, err := NewTunnelServer(conn, g.sshConfig, g.peerServerListener)
if err != nil {
return
}

View File

@ -17,9 +17,11 @@ package ssh
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
libio "github.com/fatedier/golib/io"
@ -27,10 +29,12 @@ import (
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"github.com/fatedier/frp/client/proxy"
"github.com/fatedier/frp/pkg/config"
v1 "github.com/fatedier/frp/pkg/config/v1"
"github.com/fatedier/frp/pkg/msg"
utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/log"
netpkg "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/util"
"github.com/fatedier/frp/pkg/util/xlog"
"github.com/fatedier/frp/pkg/virtual"
@ -64,15 +68,16 @@ type TunnelServer struct {
sc *ssh.ServerConfig
vc *virtual.Client
serverPeerListener *utilnet.InternalListener
peerServerListener *netpkg.InternalListener
doneCh chan struct{}
closeDoneChOnce sync.Once
}
func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, serverPeerListener *utilnet.InternalListener) (*TunnelServer, error) {
func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, peerServerListener *netpkg.InternalListener) (*TunnelServer, error) {
s := &TunnelServer{
underlyingConn: conn,
sc: sc,
serverPeerListener: serverPeerListener,
peerServerListener: peerServerListener,
doneCh: make(chan struct{}),
}
return s, nil
@ -94,19 +99,35 @@ func (s *TunnelServer) Run() error {
if err != nil {
return err
}
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User)
clientCfg.Complete()
if sshConn.Permissions != nil {
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User)
}
pc.Complete(clientCfg.User)
s.vc = virtual.NewClient(clientCfg)
// join workConn and ssh channel
s.vc.SetInWorkConnCallback(func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool {
c, err := s.openConn(addr)
if err != nil {
vc, err := virtual.NewClient(virtual.ClientOptions{
Common: clientCfg,
Spec: &msg.ClientSpec{
Type: "ssh-tunnel",
// If ssh does not require authentication, then the virtual client needs to authenticate through a token.
// Otherwise, once ssh authentication is passed, the virtual client does not need to authenticate again.
AlwaysAuthPass: !s.sc.NoClientAuth,
},
HandleWorkConnCb: func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool {
// join workConn and ssh channel
c, err := s.openConn(addr)
if err != nil {
return false
}
libio.Join(c, workConn)
return false
}
libio.Join(c, workConn)
return false
},
})
if err != nil {
return err
}
s.vc = vc
// transfer connection from virtual client to server peer listener
go func() {
l := s.vc.PeerListener()
@ -115,21 +136,35 @@ func (s *TunnelServer) Run() error {
if err != nil {
return
}
_ = s.serverPeerListener.PutConn(conn)
_ = s.peerServerListener.PutConn(conn)
}
}()
xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100})
ctx := xlog.NewContext(context.Background(), xl)
go func() {
_ = s.vc.Run(ctx)
// If vc.Run returns, it means that the virtual client has been closed, and the ssh tunnel connection should be closed.
// One scenario is that the virtual client exits due to login failure.
s.closeDoneChOnce.Do(func() {
_ = sshConn.Close()
close(s.doneCh)
})
}()
s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc})
_ = sshConn.Wait()
_ = sshConn.Close()
if err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil {
log.Warn("wait proxy status ready error: %v", err)
} else {
_ = sshConn.Wait()
}
s.vc.Close()
close(s.doneCh)
log.Trace("ssh tunnel connection from %v closed", sshConn.RemoteAddr())
s.closeDoneChOnce.Do(func() {
_ = sshConn.Close()
close(s.doneCh)
})
return nil
}
@ -217,6 +252,14 @@ func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPaylo
if err := cmd.ParseFlags(args); err != nil {
return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err)
}
// if name is not set, generate a random one
if pc.GetBaseConfig().Name == "" {
id, err := util.RandIDWithLen(8)
if err != nil {
return nil, nil, fmt.Errorf("generate random id error: %v", err)
}
pc.GetBaseConfig().Name = fmt.Sprintf("sshtunnel-%s-%s", proxyType, id)
}
return &clientCfg, pc, nil
}
@ -274,6 +317,34 @@ func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
}
go ssh.DiscardRequests(reqs)
conn := utilnet.WrapReadWriteCloserToConn(channel, s.underlyingConn)
conn := netpkg.WrapReadWriteCloserToConn(channel, s.underlyingConn)
return conn, nil
}
func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) error {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case <-ticker.C:
ps, err := s.vc.Service().GetProxyStatus(name)
if err != nil {
continue
}
switch ps.Phase {
case proxy.ProxyPhaseRunning:
return nil
case proxy.ProxyPhaseStartErr, proxy.ProxyPhaseClosed:
return errors.New(ps.Err)
}
case <-timer.C:
return fmt.Errorf("wait proxy status ready timeout")
case <-s.doneCh:
return fmt.Errorf("ssh tunnel server closed")
}
}
}