diff --git a/ssh/server.go b/ssh/server.go index 70045bdfd8..5e9d019952 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -66,10 +66,25 @@ type ServerConfig struct { hostKeys []Signer + // ImplictAuthMethod is sent to the client in the list of acceptable + // authentication methods. To make an authentication decision based on + // connection metadata use NoClientAuthCallback. If NoClientAuthCallback is + // nil, the value is unused. + ImplictAuthMethod string + // NoClientAuth is true if clients are allowed to connect without // authenticating. + // To determine NoClientAuth at runtime, set NoClientAuth to true + // and the optional NoClientAuthCallback to a non-nil value. NoClientAuth bool + // NoClientAuthCallback, if non-nil, is called when a user + // attempts to authenticate with auth method "none". + // NoClientAuth must also be set to true for this be used, or + // this func is unused. + // If the function returns ErrDenied, the connection is terminated. + NoClientAuthCallback func(ConnMetadata) (*Permissions, error) + // MaxAuthTries specifies the maximum number of authentication attempts // permitted per connection. If set to a negative number, the number of // attempts are unlimited. If set to zero, the number of attempts are limited @@ -78,6 +93,7 @@ type ServerConfig struct { // PasswordCallback, if non-nil, is called when a user // attempts to authenticate using a password. + // If the function returns ErrDenied, the connection is terminated. PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) // PublicKeyCallback, if non-nil, is called when a client @@ -88,6 +104,7 @@ type ServerConfig struct { // offered is in fact used to authenticate. To record any data // depending on the public key, store it inside a // Permissions.Extensions entry. + // If the function returns ErrDenied, the connection is terminated. PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) // KeyboardInteractiveCallback, if non-nil, is called when @@ -97,6 +114,7 @@ type ServerConfig struct { // Challenge rounds. To avoid information leaks, the client // should be presented a challenge even if the user is // unknown. + // If the function returns ErrDenied, the connection is terminated. KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) // AuthLogCallback, if non-nil, is called to log all authentication @@ -292,6 +310,19 @@ func isAcceptableAlgo(algo string) bool { return false } +// WithBannerError is an error wrapper type that can be returned from an authentication +// function to additionally write out a banner error message. +type WithBannerError struct { + Err error + Message string +} + +func (e WithBannerError) Unwrap() error { + return e.Err +} + +func (e WithBannerError) Error() string { return e.Err.Error() } + func checkSourceAddress(addr net.Addr, sourceAddrs string) error { if addr == nil { return errors.New("ssh: no address known for client, but source-address match required") @@ -389,12 +420,19 @@ func (l ServerAuthError) Error() string { return "[" + strings.Join(errs, ", ") + "]" } -// ErrNoAuth is the error value returned if no -// authentication method has been passed yet. This happens as a normal -// part of the authentication loop, since the client first tries -// 'none' authentication to discover available methods. -// It is returned in ServerAuthError.Errors from NewServerConn. -var ErrNoAuth = errors.New("ssh: no auth passed yet") +var ( + // ErrDenied can be returned from an authentication callback to inform the + // client that access is denied and that no further attempt will be accepted + // on the connection. + ErrDenied = errors.New("ssh: access denied") + + // ErrNoAuth is the error value returned if no + // authentication method has been passed yet. This happens as a normal + // part of the authentication loop, since the client first tries + // 'none' authentication to discover available methods. + // It is returned in ServerAuthError.Errors from NewServerConn. + ErrNoAuth = errors.New("ssh: no auth passed yet") +) func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { sessionID := s.transport.getSessionID() @@ -455,7 +493,11 @@ userAuthLoop: switch userAuthReq.Method { case "none": if config.NoClientAuth { - authErr = nil + if config.NoClientAuthCallback != nil { + perms, authErr = config.NoClientAuthCallback(s) + } else { + authErr = nil + } } // allow initial attempt of 'none' without penalty @@ -639,6 +681,25 @@ userAuthLoop: break userAuthLoop } + var w WithBannerError + if errors.As(authErr, &w) && w.Message != "" { + bannerMsg := &userAuthBannerMsg{Message: w.Message} + if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil { + return nil, err + } + } + if errors.Is(authErr, ErrDenied) { + var failureMsg userAuthFailureMsg + if config.ImplictAuthMethod != "" { + failureMsg.Methods = []string{config.ImplictAuthMethod} + } + if err := s.transport.writePacket(Marshal(failureMsg)); err != nil { + return nil, err + } + + return nil, authErr + } + authFailures++ if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries { // If we have hit the max attempts, don't bother sending the @@ -666,6 +727,9 @@ userAuthLoop: } var failureMsg userAuthFailureMsg + if config.NoClientAuthCallback != nil && config.ImplictAuthMethod != "" { + failureMsg.Methods = append(failureMsg.Methods, config.ImplictAuthMethod) + } if config.PasswordCallback != nil { failureMsg.Methods = append(failureMsg.Methods, "password") } diff --git a/ssh/session_test.go b/ssh/session_test.go index 1009affddd..c421adfa46 100644 --- a/ssh/session_test.go +++ b/ssh/session_test.go @@ -780,3 +780,54 @@ func TestHostKeyAlgorithms(t *testing.T) { t.Fatal("succeeded connecting with unknown hostkey algorithm") } } + +func TestServerClientAuthCallback(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + userCh := make(chan string, 1) + + serverConf := &ServerConfig{ + NoClientAuth: true, + NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) { + userCh <- conn.User() + return nil, nil + }, + } + const someUsername = "some-username" + + serverConf.AddHostKey(testSigners["ecdsa"]) + clientConf := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + User: someUsername, + } + + go func() { + _, chans, reqs, err := NewServerConn(c1, serverConf) + if err != nil { + t.Errorf("server handshake: %v", err) + userCh <- "error" + return + } + go DiscardRequests(reqs) + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + conn, _, _, err := NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("client handshake: %v", err) + return + } + conn.Close() + + got := <-userCh + if got != someUsername { + t.Errorf("username = %q; want %q", got, someUsername) + } +}