// Copyright 2023 The frp Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ssh import ( "context" "encoding/binary" "errors" "fmt" "net" "strings" "sync" "time" libio "github.com/fatedier/golib/io" "github.com/samber/lo" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" "github.com/fatedier/frp/client/proxy" "github.com/fatedier/frp/pkg/config" v1 "github.com/fatedier/frp/pkg/config/v1" "github.com/fatedier/frp/pkg/msg" "github.com/fatedier/frp/pkg/util/log" netpkg "github.com/fatedier/frp/pkg/util/net" "github.com/fatedier/frp/pkg/util/util" "github.com/fatedier/frp/pkg/util/xlog" "github.com/fatedier/frp/pkg/virtual" ) const ( // https://datatracker.ietf.org/doc/html/rfc4254#page-16 ChannelTypeServerOpenChannel = "forwarded-tcpip" RequestTypeForward = "tcpip-forward" ) type tcpipForward struct { Host string Port uint32 } // https://datatracker.ietf.org/doc/html/rfc4254#page-16 type forwardedTCPPayload struct { Addr string Port uint32 // can be default empty value but do not delete it // because ssh protocol shoule be reserved OriginAddr string OriginPort uint32 } type TunnelServer struct { underlyingConn net.Conn sshConn *ssh.ServerConn sc *ssh.ServerConfig vc *virtual.Client peerServerListener *netpkg.InternalListener doneCh chan struct{} closeDoneChOnce sync.Once } func NewTunnelServer(conn net.Conn, sc *ssh.ServerConfig, peerServerListener *netpkg.InternalListener) (*TunnelServer, error) { s := &TunnelServer{ underlyingConn: conn, sc: sc, peerServerListener: peerServerListener, doneCh: make(chan struct{}), } return s, nil } func (s *TunnelServer) Run() error { sshConn, channels, requests, err := ssh.NewServerConn(s.underlyingConn, s.sc) if err != nil { return err } s.sshConn = sshConn addr, extraPayload, err := s.waitForwardAddrAndExtraPayload(channels, requests, 3*time.Second) if err != nil { return err } clientCfg, pc, err := s.parseClientAndProxyConfigurer(addr, extraPayload) if err != nil { return err } clientCfg.Complete() if sshConn.Permissions != nil { clientCfg.User = util.EmptyOr(sshConn.Permissions.Extensions["user"], clientCfg.User) } pc.Complete(clientCfg.User) vc, err := virtual.NewClient(virtual.ClientOptions{ Common: clientCfg, Spec: &msg.ClientSpec{ Type: "ssh-tunnel", // If ssh does not require authentication, then the virtual client needs to authenticate through a token. // Otherwise, once ssh authentication is passed, the virtual client does not need to authenticate again. AlwaysAuthPass: !s.sc.NoClientAuth, }, HandleWorkConnCb: func(base *v1.ProxyBaseConfig, workConn net.Conn, m *msg.StartWorkConn) bool { // join workConn and ssh channel c, err := s.openConn(addr) if err != nil { return false } libio.Join(c, workConn) return false }, }) if err != nil { return err } s.vc = vc // transfer connection from virtual client to server peer listener go func() { l := s.vc.PeerListener() for { conn, err := l.Accept() if err != nil { return } _ = s.peerServerListener.PutConn(conn) } }() xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100}) ctx := xlog.NewContext(context.Background(), xl) go func() { _ = s.vc.Run(ctx) // If vc.Run returns, it means that the virtual client has been closed, and the ssh tunnel connection should be closed. // One scenario is that the virtual client exits due to login failure. s.closeDoneChOnce.Do(func() { _ = sshConn.Close() close(s.doneCh) }) }() s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc}) if err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil { log.Warn("wait proxy status ready error: %v", err) } else { _ = sshConn.Wait() } s.vc.Close() log.Trace("ssh tunnel connection from %v closed", sshConn.RemoteAddr()) s.closeDoneChOnce.Do(func() { _ = sshConn.Close() close(s.doneCh) }) return nil } func (s *TunnelServer) waitForwardAddrAndExtraPayload( channels <-chan ssh.NewChannel, requests <-chan *ssh.Request, timeout time.Duration, ) (*tcpipForward, string, error) { addrCh := make(chan *tcpipForward, 1) extraPayloadCh := make(chan string, 1) // get forward address go func() { addrGot := false for req := range requests { switch req.Type { case RequestTypeForward: if !addrGot { payload := tcpipForward{} if err := ssh.Unmarshal(req.Payload, &payload); err != nil { return } addrGot = true addrCh <- &payload } default: if req.WantReply { _ = req.Reply(true, nil) } } } }() // get extra payload go func() { for newChannel := range channels { // extraPayload will send to extraPayloadCh go s.handleNewChannel(newChannel, extraPayloadCh) } }() var ( addr *tcpipForward extraPayload string ) timer := time.NewTimer(timeout) defer timer.Stop() for { select { case v := <-addrCh: addr = v case extra := <-extraPayloadCh: extraPayload = extra case <-timer.C: return nil, "", fmt.Errorf("get addr and extra payload timeout") } if addr != nil && extraPayload != "" { break } } return addr, extraPayload, nil } func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, error) { cmd := &cobra.Command{} args := strings.Split(extraPayload, " ") if len(args) < 1 { return nil, nil, fmt.Errorf("invalid extra payload") } proxyType := strings.TrimSpace(args[0]) supportTypes := []string{"tcp", "http", "https", "tcpmux", "stcp"} if !lo.Contains(supportTypes, proxyType) { return nil, nil, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes) } pc := v1.NewProxyConfigurerByType(v1.ProxyType(proxyType)) if pc == nil { return nil, nil, fmt.Errorf("new proxy configurer error") } config.RegisterProxyFlags(cmd, pc) clientCfg := v1.ClientCommonConfig{} config.RegisterClientCommonConfigFlags(cmd, &clientCfg) if err := cmd.ParseFlags(args); err != nil { return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err) } // if name is not set, generate a random one if pc.GetBaseConfig().Name == "" { id, err := util.RandIDWithLen(8) if err != nil { return nil, nil, fmt.Errorf("generate random id error: %v", err) } pc.GetBaseConfig().Name = fmt.Sprintf("sshtunnel-%s-%s", proxyType, id) } return &clientCfg, pc, nil } func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh chan string) { ch, reqs, err := channel.Accept() if err != nil { return } go s.keepAlive(ch) for req := range reqs { if req.Type != "exec" { continue } if len(req.Payload) <= 4 { continue } end := 4 + binary.BigEndian.Uint32(req.Payload[:4]) if len(req.Payload) < int(end) { continue } extraPayload := string(req.Payload[4:end]) select { case extraPayloadCh <- extraPayload: default: } } } func (s *TunnelServer) keepAlive(ch ssh.Channel) { tk := time.NewTicker(time.Second * 30) defer tk.Stop() for { select { case <-tk.C: _, err := ch.SendRequest("heartbeat", false, nil) if err != nil { return } case <-s.doneCh: return } } } func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) { payload := forwardedTCPPayload{ Addr: addr.Host, Port: addr.Port, } channel, reqs, err := s.sshConn.OpenChannel(ChannelTypeServerOpenChannel, ssh.Marshal(&payload)) if err != nil { return nil, fmt.Errorf("open ssh channel error: %v", err) } go ssh.DiscardRequests(reqs) conn := netpkg.WrapReadWriteCloserToConn(channel, s.underlyingConn) return conn, nil } func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) error { ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() timer := time.NewTimer(timeout) defer timer.Stop() for { select { case <-ticker.C: ps, err := s.vc.Service().GetProxyStatus(name) if err != nil { continue } switch ps.Phase { case proxy.ProxyPhaseRunning: return nil case proxy.ProxyPhaseStartErr, proxy.ProxyPhaseClosed: return errors.New(ps.Err) } case <-timer.C: return fmt.Errorf("wait proxy status ready timeout") case <-s.doneCh: return fmt.Errorf("ssh tunnel server closed") } } }