mirror of
https://github.com/fatedier/frp.git
synced 2025-07-29 14:22:20 +07:00
optimize some code (#3801)
This commit is contained in:
@ -26,21 +26,21 @@ import (
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
type Gateway struct {
|
||||
bindPort int
|
||||
ln net.Listener
|
||||
|
||||
serverPeerListener *utilnet.InternalListener
|
||||
peerServerListener *netpkg.InternalListener
|
||||
|
||||
sshConfig *ssh.ServerConfig
|
||||
}
|
||||
|
||||
func NewGateway(
|
||||
cfg v1.SSHTunnelGateway, bindAddr string,
|
||||
serverPeerListener *utilnet.InternalListener,
|
||||
peerServerListener *netpkg.InternalListener,
|
||||
) (*Gateway, error) {
|
||||
sshConfig := &ssh.ServerConfig{}
|
||||
|
||||
@ -71,15 +71,8 @@ func NewGateway(
|
||||
}
|
||||
sshConfig.AddHostKey(privateKey)
|
||||
|
||||
sshConfig.NoClientAuth = cfg.AuthorizedKeysFile == ""
|
||||
sshConfig.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if cfg.AuthorizedKeysFile == "" {
|
||||
return &ssh.Permissions{
|
||||
Extensions: map[string]string{
|
||||
"user": "",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
authorizedKeysMap, err := loadAuthorizedKeysFromFile(cfg.AuthorizedKeysFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("internal error")
|
||||
@ -103,7 +96,7 @@ func NewGateway(
|
||||
return &Gateway{
|
||||
bindPort: cfg.BindPort,
|
||||
ln: ln,
|
||||
serverPeerListener: serverPeerListener,
|
||||
peerServerListener: peerServerListener,
|
||||
sshConfig: sshConfig,
|
||||
}, nil
|
||||
}
|
||||
@ -121,7 +114,7 @@ func (g *Gateway) Run() {
|
||||
func (g *Gateway) handleConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
ts, err := NewTunnelServer(conn, g.sshConfig, g.serverPeerListener)
|
||||
ts, err := NewTunnelServer(conn, g.sshConfig, g.peerServerListener)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -17,9 +17,11 @@ package ssh
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
libio "github.com/fatedier/golib/io"
|
||||
@ -27,10 +29,12 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/fatedier/frp/client/proxy"
|
||||
"github.com/fatedier/frp/pkg/config"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
utilnet "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/virtual"
|
||||
@ -64,15 +68,16 @@ type TunnelServer struct {
|
||||
sc *ssh.ServerConfig
|
||||
|
||||
vc *virtual.Client
|
||||
serverPeerListener *utilnet.InternalListener
|
||||
peerServerListener *netpkg.InternalListener
|
||||
doneCh chan struct{}
|
||||
closeDoneChOnce sync.Once
|
||||
}
|
||||
|
||||
func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, serverPeerListener *utilnet.InternalListener) (*TunnelServer, error) {
|
||||
func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, peerServerListener *netpkg.InternalListener) (*TunnelServer, error) {
|
||||
s := &TunnelServer{
|
||||
underlyingConn: conn,
|
||||
sc: sc,
|
||||
serverPeerListener: serverPeerListener,
|
||||
peerServerListener: peerServerListener,
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
return s, nil
|
||||
@ -94,19 +99,35 @@ func (s *TunnelServer) Run() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User)
|
||||
clientCfg.Complete()
|
||||
if sshConn.Permissions != nil {
|
||||
clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User)
|
||||
}
|
||||
pc.Complete(clientCfg.User)
|
||||
|
||||
s.vc = virtual.NewClient(clientCfg)
|
||||
// join workConn and ssh channel
|
||||
s.vc.SetInWorkConnCallback(func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool {
|
||||
c, err := s.openConn(addr)
|
||||
if err != nil {
|
||||
vc, err := virtual.NewClient(virtual.ClientOptions{
|
||||
Common: clientCfg,
|
||||
Spec: &msg.ClientSpec{
|
||||
Type: "ssh-tunnel",
|
||||
// If ssh does not require authentication, then the virtual client needs to authenticate through a token.
|
||||
// Otherwise, once ssh authentication is passed, the virtual client does not need to authenticate again.
|
||||
AlwaysAuthPass: !s.sc.NoClientAuth,
|
||||
},
|
||||
HandleWorkConnCb: func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool {
|
||||
// join workConn and ssh channel
|
||||
c, err := s.openConn(addr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
libio.Join(c, workConn)
|
||||
return false
|
||||
}
|
||||
libio.Join(c, workConn)
|
||||
return false
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.vc = vc
|
||||
|
||||
// transfer connection from virtual client to server peer listener
|
||||
go func() {
|
||||
l := s.vc.PeerListener()
|
||||
@ -115,21 +136,35 @@ func (s *TunnelServer) Run() error {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = s.serverPeerListener.PutConn(conn)
|
||||
_ = s.peerServerListener.PutConn(conn)
|
||||
}
|
||||
}()
|
||||
xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100})
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
go func() {
|
||||
_ = s.vc.Run(ctx)
|
||||
// If vc.Run returns, it means that the virtual client has been closed, and the ssh tunnel connection should be closed.
|
||||
// One scenario is that the virtual client exits due to login failure.
|
||||
s.closeDoneChOnce.Do(func() {
|
||||
_ = sshConn.Close()
|
||||
close(s.doneCh)
|
||||
})
|
||||
}()
|
||||
|
||||
s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc})
|
||||
|
||||
_ = sshConn.Wait()
|
||||
_ = sshConn.Close()
|
||||
if err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil {
|
||||
log.Warn("wait proxy status ready error: %v", err)
|
||||
} else {
|
||||
_ = sshConn.Wait()
|
||||
}
|
||||
|
||||
s.vc.Close()
|
||||
close(s.doneCh)
|
||||
log.Trace("ssh tunnel connection from %v closed", sshConn.RemoteAddr())
|
||||
s.closeDoneChOnce.Do(func() {
|
||||
_ = sshConn.Close()
|
||||
close(s.doneCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -217,6 +252,14 @@ func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPaylo
|
||||
if err := cmd.ParseFlags(args); err != nil {
|
||||
return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err)
|
||||
}
|
||||
// if name is not set, generate a random one
|
||||
if pc.GetBaseConfig().Name == "" {
|
||||
id, err := util.RandIDWithLen(8)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate random id error: %v", err)
|
||||
}
|
||||
pc.GetBaseConfig().Name = fmt.Sprintf("sshtunnel-%s-%s", proxyType, id)
|
||||
}
|
||||
return &clientCfg, pc, nil
|
||||
}
|
||||
|
||||
@ -274,6 +317,34 @@ func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) {
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
conn := utilnet.WrapReadWriteCloserToConn(channel, s.underlyingConn)
|
||||
conn := netpkg.WrapReadWriteCloserToConn(channel, s.underlyingConn)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) error {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ps, err := s.vc.Service().GetProxyStatus(name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
switch ps.Phase {
|
||||
case proxy.ProxyPhaseRunning:
|
||||
return nil
|
||||
case proxy.ProxyPhaseStartErr, proxy.ProxyPhaseClosed:
|
||||
return errors.New(ps.Err)
|
||||
}
|
||||
case <-timer.C:
|
||||
return fmt.Errorf("wait proxy status ready timeout")
|
||||
case <-s.doneCh:
|
||||
return fmt.Errorf("ssh tunnel server closed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user