diff --git a/client/service.go b/client/service.go index 5db1bd28..c43f8f60 100644 --- a/client/service.go +++ b/client/service.go @@ -42,6 +42,14 @@ func init() { crypto.DefaultSalt = "frp" } +type cancelErr struct { + Err error +} + +func (e cancelErr) Error() string { + return e.Err.Error() +} + // ServiceOptions contains options for creating a new client service. type ServiceOptions struct { Common *v1.ClientCommonConfig @@ -108,7 +116,7 @@ type Service struct { // service context ctx context.Context // call cancel to stop service - cancel context.CancelFunc + cancel context.CancelCauseFunc gracefulShutdownDuration time.Duration connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector @@ -145,7 +153,7 @@ func NewService(options ServiceOptions) (*Service, error) { } func (svr *Service) Run(ctx context.Context) error { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancelCause(ctx) svr.ctx = xlog.NewContext(ctx, xlog.FromContextSafe(ctx)) svr.cancel = cancel @@ -157,7 +165,9 @@ func (svr *Service) Run(ctx context.Context) error { // first login to frps svr.loopLoginUntilSuccess(10*time.Second, lo.FromPtr(svr.common.LoginFailExit)) if svr.ctl == nil { - return fmt.Errorf("the process exited because the first login to the server failed, and the loginFailExit feature is enabled") + cancelCause := cancelErr{} + _ = errors.As(context.Cause(svr.ctx), &cancelCause) + return fmt.Errorf("login to the server failed: %v. With loginFailExit enabled, no additional retries will be attempted", cancelCause.Err) } go svr.keepControllerWorking() @@ -280,7 +290,7 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE if err != nil { xl.Warn("connect to server error: %v", err) if firstLoginExit { - svr.cancel() + svr.cancel(cancelErr{Err: err}) } return err } @@ -356,7 +366,7 @@ func (svr *Service) Close() { func (svr *Service) GracefulClose(d time.Duration) { svr.gracefulShutdownDuration = d - svr.cancel() + svr.cancel(nil) } func (svr *Service) stop() { diff --git a/pkg/config/flags.go b/pkg/config/flags.go index c0e87164..712e3d3f 100644 --- a/pkg/config/flags.go +++ b/pkg/config/flags.go @@ -25,6 +25,18 @@ import ( "github.com/fatedier/frp/pkg/config/v1/validation" ) +type RegisterFlagOption func(*registerFlagOptions) + +type registerFlagOptions struct { + sshMode bool +} + +func WithSSHMode() RegisterFlagOption { + return func(o *registerFlagOptions) { + o.sshMode = true + } +} + type BandwidthQuantityFlag struct { V *types.BandwidthQuantity } @@ -41,8 +53,9 @@ func (f *BandwidthQuantityFlag) Type() string { return "string" } -func RegisterProxyFlags(cmd *cobra.Command, c v1.ProxyConfigurer) { - registerProxyBaseConfigFlags(cmd, c.GetBaseConfig()) +func RegisterProxyFlags(cmd *cobra.Command, c v1.ProxyConfigurer, opts ...RegisterFlagOption) { + registerProxyBaseConfigFlags(cmd, c.GetBaseConfig(), opts...) + switch cc := c.(type) { case *v1.TCPProxyConfig: cmd.Flags().IntVarP(&cc.RemotePort, "remote_port", "r", 0, "remote port") @@ -73,17 +86,25 @@ func RegisterProxyFlags(cmd *cobra.Command, c v1.ProxyConfigurer) { } } -func registerProxyBaseConfigFlags(cmd *cobra.Command, c *v1.ProxyBaseConfig) { +func registerProxyBaseConfigFlags(cmd *cobra.Command, c *v1.ProxyBaseConfig, opts ...RegisterFlagOption) { if c == nil { return } + options := ®isterFlagOptions{} + for _, opt := range opts { + opt(options) + } + cmd.Flags().StringVarP(&c.Name, "proxy_name", "n", "", "proxy name") - cmd.Flags().StringVarP(&c.LocalIP, "local_ip", "i", "127.0.0.1", "local ip") - cmd.Flags().IntVarP(&c.LocalPort, "local_port", "l", 0, "local port") - cmd.Flags().BoolVarP(&c.Transport.UseEncryption, "ue", "", false, "use encryption") - cmd.Flags().BoolVarP(&c.Transport.UseCompression, "uc", "", false, "use compression") - cmd.Flags().StringVarP(&c.Transport.BandwidthLimitMode, "bandwidth_limit_mode", "", types.BandwidthLimitModeClient, "bandwidth limit mode") - cmd.Flags().VarP(&BandwidthQuantityFlag{V: &c.Transport.BandwidthLimit}, "bandwidth_limit", "", "bandwidth limit (e.g. 100KB or 1MB)") + + if !options.sshMode { + cmd.Flags().StringVarP(&c.LocalIP, "local_ip", "i", "127.0.0.1", "local ip") + cmd.Flags().IntVarP(&c.LocalPort, "local_port", "l", 0, "local port") + cmd.Flags().BoolVarP(&c.Transport.UseEncryption, "ue", "", false, "use encryption") + cmd.Flags().BoolVarP(&c.Transport.UseCompression, "uc", "", false, "use compression") + cmd.Flags().StringVarP(&c.Transport.BandwidthLimitMode, "bandwidth_limit_mode", "", types.BandwidthLimitModeClient, "bandwidth limit mode") + cmd.Flags().VarP(&BandwidthQuantityFlag{V: &c.Transport.BandwidthLimit}, "bandwidth_limit", "", "bandwidth limit (e.g. 100KB or 1MB)") + } } func registerProxyDomainConfigFlags(cmd *cobra.Command, c *v1.DomainConfig) { @@ -94,13 +115,13 @@ func registerProxyDomainConfigFlags(cmd *cobra.Command, c *v1.DomainConfig) { cmd.Flags().StringVarP(&c.SubDomain, "sd", "", "", "sub domain") } -func RegisterVisitorFlags(cmd *cobra.Command, c v1.VisitorConfigurer) { - registerVisitorBaseConfigFlags(cmd, c.GetBaseConfig()) +func RegisterVisitorFlags(cmd *cobra.Command, c v1.VisitorConfigurer, opts ...RegisterFlagOption) { + registerVisitorBaseConfigFlags(cmd, c.GetBaseConfig(), opts...) // add visitor flags if exist } -func registerVisitorBaseConfigFlags(cmd *cobra.Command, c *v1.VisitorBaseConfig) { +func registerVisitorBaseConfigFlags(cmd *cobra.Command, c *v1.VisitorBaseConfig, _ ...RegisterFlagOption) { if c == nil { return } @@ -113,21 +134,27 @@ func registerVisitorBaseConfigFlags(cmd *cobra.Command, c *v1.VisitorBaseConfig) cmd.Flags().IntVarP(&c.BindPort, "bind_port", "", 0, "bind port") } -func RegisterClientCommonConfigFlags(cmd *cobra.Command, c *v1.ClientCommonConfig) { - cmd.PersistentFlags().StringVarP(&c.ServerAddr, "server_addr", "s", "127.0.0.1", "frp server's address") - cmd.PersistentFlags().IntVarP(&c.ServerPort, "server_port", "P", 7000, "frp server's port") - cmd.PersistentFlags().StringVarP(&c.User, "user", "u", "", "user") - cmd.PersistentFlags().StringVarP(&c.Transport.Protocol, "protocol", "p", "tcp", - fmt.Sprintf("optional values are %v", validation.SupportedTransportProtocols)) - cmd.PersistentFlags().StringVarP(&c.Auth.Token, "token", "t", "", "auth token") - cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level") - cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "console or file path") - cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log file reversed days") - cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console") - cmd.PersistentFlags().StringVarP(&c.Transport.TLS.ServerName, "tls_server_name", "", "", "specify the custom server name of tls certificate") - cmd.PersistentFlags().StringVarP(&c.DNSServer, "dns_server", "", "", "specify dns server instead of using system default one") +func RegisterClientCommonConfigFlags(cmd *cobra.Command, c *v1.ClientCommonConfig, opts ...RegisterFlagOption) { + options := ®isterFlagOptions{} + for _, opt := range opts { + opt(options) + } - c.Transport.TLS.Enable = cmd.PersistentFlags().BoolP("tls_enable", "", true, "enable frpc tls") + if !options.sshMode { + cmd.PersistentFlags().StringVarP(&c.ServerAddr, "server_addr", "s", "127.0.0.1", "frp server's address") + cmd.PersistentFlags().IntVarP(&c.ServerPort, "server_port", "P", 7000, "frp server's port") + cmd.PersistentFlags().StringVarP(&c.Transport.Protocol, "protocol", "p", "tcp", + fmt.Sprintf("optional values are %v", validation.SupportedTransportProtocols)) + cmd.PersistentFlags().StringVarP(&c.Log.Level, "log_level", "", "info", "log level") + cmd.PersistentFlags().StringVarP(&c.Log.To, "log_file", "", "console", "console or file path") + cmd.PersistentFlags().Int64VarP(&c.Log.MaxDays, "log_max_days", "", 3, "log file reversed days") + cmd.PersistentFlags().BoolVarP(&c.Log.DisablePrintColor, "disable_log_color", "", false, "disable log color in console") + cmd.PersistentFlags().StringVarP(&c.Transport.TLS.ServerName, "tls_server_name", "", "", "specify the custom server name of tls certificate") + cmd.PersistentFlags().StringVarP(&c.DNSServer, "dns_server", "", "", "specify dns server instead of using system default one") + c.Transport.TLS.Enable = cmd.PersistentFlags().BoolP("tls_enable", "", true, "enable frpc tls") + } + cmd.PersistentFlags().StringVarP(&c.User, "user", "u", "", "user") + cmd.PersistentFlags().StringVarP(&c.Auth.Token, "token", "t", "", "auth token") } type PortsRangeSliceFlag struct { @@ -185,7 +212,7 @@ func (f *BoolFuncFlag) Type() string { return "bool" } -func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig) { +func RegisterServerConfigFlags(cmd *cobra.Command, c *v1.ServerConfig, opts ...RegisterFlagOption) { cmd.PersistentFlags().StringVarP(&c.BindAddr, "bind_addr", "", "0.0.0.0", "bind address") cmd.PersistentFlags().IntVarP(&c.BindPort, "bind_port", "p", 7000, "bind port") cmd.PersistentFlags().IntVarP(&c.KCPBindPort, "kcp_bind_port", "", 0, "kcp bind udp port") diff --git a/pkg/ssh/server.go b/pkg/ssh/server.go index 30e79c64..264669a3 100644 --- a/pkg/ssh/server.go +++ b/pkg/ssh/server.go @@ -27,6 +27,7 @@ import ( libio "github.com/fatedier/golib/io" "github.com/samber/lo" "github.com/spf13/cobra" + flag "github.com/spf13/pflag" "golang.org/x/crypto/ssh" "github.com/fatedier/frp/client/proxy" @@ -64,6 +65,7 @@ type TunnelServer struct { underlyingConn net.Conn sshConn *ssh.ServerConn sc *ssh.ServerConfig + firstChannel ssh.Channel vc *virtual.Client peerServerListener *netpkg.InternalListener @@ -86,6 +88,7 @@ func (s *TunnelServer) Run() error { if err != nil { return err } + s.sshConn = sshConn addr, extraPayload, err := s.waitForwardAddrAndExtraPayload(channels, requests, 3*time.Second) @@ -93,9 +96,14 @@ func (s *TunnelServer) Run() error { return err } - clientCfg, pc, err := s.parseClientAndProxyConfigurer(addr, extraPayload) + clientCfg, pc, helpMessage, err := s.parseClientAndProxyConfigurer(addr, extraPayload) if err != nil { - return err + if errors.Is(err, flag.ErrHelp) { + s.writeToClient(helpMessage) + return nil + } + s.writeToClient(err.Error()) + return fmt.Errorf("parse flags from ssh client error: %v", err) } clientCfg.Complete() if sshConn.Permissions != nil { @@ -142,7 +150,11 @@ func (s *TunnelServer) Run() error { xl := xlog.New().AddPrefix(xlog.LogPrefix{Name: "sshVirtualClient", Value: "sshVirtualClient", Priority: 100}) ctx := xlog.NewContext(context.Background(), xl) go func() { - _ = s.vc.Run(ctx) + vcErr := s.vc.Run(ctx) + if vcErr != nil { + s.writeToClient(vcErr.Error()) + } + // If vc.Run returns, it means that the virtual client has been closed, and the ssh tunnel connection should be closed. // One scenario is that the virtual client exits due to login failure. s.closeDoneChOnce.Do(func() { @@ -153,9 +165,12 @@ func (s *TunnelServer) Run() error { s.vc.UpdateProxyConfigurer([]v1.ProxyConfigurer{pc}) - if err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil { + if ps, err := s.waitProxyStatusReady(pc.GetBaseConfig().Name, time.Second); err != nil { + s.writeToClient(err.Error()) log.Warn("wait proxy status ready error: %v", err) } else { + // success + s.writeToClient(createSuccessInfo(clientCfg.User, pc, ps)) _ = sshConn.Wait() } @@ -168,6 +183,13 @@ func (s *TunnelServer) Run() error { return nil } +func (s *TunnelServer) writeToClient(data string) { + if s.firstChannel == nil { + return + } + _, _ = s.firstChannel.Write([]byte(data + "\n")) +} + func (s *TunnelServer) waitForwardAddrAndExtraPayload( channels <-chan ssh.NewChannel, requests <-chan *ssh.Request, @@ -225,38 +247,47 @@ func (s *TunnelServer) waitForwardAddrAndExtraPayload( return addr, extraPayload, nil } -func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, error) { - cmd := &cobra.Command{} +func (s *TunnelServer) parseClientAndProxyConfigurer(_ *tcpipForward, extraPayload string) (*v1.ClientCommonConfig, v1.ProxyConfigurer, string, error) { + helpMessage := "" + cmd := &cobra.Command{ + Use: "ssh v0@{address} [command]", + Short: "ssh v0@{address} [command]", + Run: func(*cobra.Command, []string) {}, + } args := strings.Split(extraPayload, " ") if len(args) < 1 { - return nil, nil, fmt.Errorf("invalid extra payload") + return nil, nil, helpMessage, fmt.Errorf("invalid extra payload") } proxyType := strings.TrimSpace(args[0]) supportTypes := []string{"tcp", "http", "https", "tcpmux", "stcp"} if !lo.Contains(supportTypes, proxyType) { - return nil, nil, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes) + return nil, nil, helpMessage, fmt.Errorf("invalid proxy type: %s, support types: %v", proxyType, supportTypes) } pc := v1.NewProxyConfigurerByType(v1.ProxyType(proxyType)) if pc == nil { - return nil, nil, fmt.Errorf("new proxy configurer error") + return nil, nil, helpMessage, fmt.Errorf("new proxy configurer error") } - config.RegisterProxyFlags(cmd, pc) + config.RegisterProxyFlags(cmd, pc, config.WithSSHMode()) clientCfg := v1.ClientCommonConfig{} - config.RegisterClientCommonConfigFlags(cmd, &clientCfg) + config.RegisterClientCommonConfigFlags(cmd, &clientCfg, config.WithSSHMode()) + cmd.InitDefaultHelpCmd() if err := cmd.ParseFlags(args); err != nil { - return nil, nil, fmt.Errorf("parse flags from ssh client error: %v", err) + if errors.Is(err, flag.ErrHelp) { + helpMessage = cmd.UsageString() + } + return nil, nil, helpMessage, err } // if name is not set, generate a random one if pc.GetBaseConfig().Name == "" { id, err := util.RandIDWithLen(8) if err != nil { - return nil, nil, fmt.Errorf("generate random id error: %v", err) + return nil, nil, helpMessage, fmt.Errorf("generate random id error: %v", err) } pc.GetBaseConfig().Name = fmt.Sprintf("sshtunnel-%s-%s", proxyType, id) } - return &clientCfg, pc, nil + return &clientCfg, pc, helpMessage, nil } func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh chan string) { @@ -264,6 +295,9 @@ func (s *TunnelServer) handleNewChannel(channel ssh.NewChannel, extraPayloadCh c if err != nil { return } + if s.firstChannel == nil { + s.firstChannel = ch + } go s.keepAlive(ch) for req := range reqs { @@ -320,7 +354,7 @@ func (s *TunnelServer) openConn(addr *tcpipForward) (net.Conn, error) { return conn, nil } -func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) error { +func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) (*proxy.WorkingStatus, error) { ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() @@ -336,14 +370,14 @@ func (s *TunnelServer) waitProxyStatusReady(name string, timeout time.Duration) } switch ps.Phase { case proxy.ProxyPhaseRunning: - return nil + return ps, nil case proxy.ProxyPhaseStartErr, proxy.ProxyPhaseClosed: - return errors.New(ps.Err) + return ps, errors.New(ps.Err) } case <-timer.C: - return fmt.Errorf("wait proxy status ready timeout") + return nil, fmt.Errorf("wait proxy status ready timeout") case <-s.doneCh: - return fmt.Errorf("ssh tunnel server closed") + return nil, fmt.Errorf("ssh tunnel server closed") } } } diff --git a/pkg/ssh/terminal.go b/pkg/ssh/terminal.go new file mode 100644 index 00000000..a2e9a6ff --- /dev/null +++ b/pkg/ssh/terminal.go @@ -0,0 +1,31 @@ +// Copyright 2023 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "github.com/fatedier/frp/client/proxy" + v1 "github.com/fatedier/frp/pkg/config/v1" +) + +func createSuccessInfo(user string, pc v1.ProxyConfigurer, ps *proxy.WorkingStatus) string { + base := pc.GetBaseConfig() + out := "\n" + out += "frp (via SSH) (Ctrl+C to quit)\n\n" + out += "User: " + user + "\n" + out += "ProxyName: " + base.Name + "\n" + out += "Type: " + base.Type + "\n" + out += "RemoteAddress: " + ps.RemoteAddr + "\n" + return out +} diff --git a/test/e2e/pkg/ssh/client.go b/test/e2e/pkg/ssh/client.go index 1a923e9c..b45e39da 100644 --- a/test/e2e/pkg/ssh/client.go +++ b/test/e2e/pkg/ssh/client.go @@ -41,7 +41,7 @@ func (c *TunnelClient) Start() error { return err } c.ln = l - ch, req, err := conn.OpenChannel("direct", []byte("")) + ch, req, err := conn.OpenChannel("session", []byte("")) if err != nil { return err }