From 3e6ccc8f45c64b40d4f273fcd6603456cefa2d6c Mon Sep 17 00:00:00 2001 From: Thom Seddon Date: Thu, 13 Jun 2019 15:13:52 +0100 Subject: [PATCH] Redirect to login on cookie expiry + simplify ValidateCookie function Possible fix for #31 --- internal/auth.go | 16 ++++++------- internal/auth_test.go | 22 +++++++---------- internal/config_test.go | 2 +- internal/server.go | 52 ++++++++++++++++++++++++----------------- internal/server_test.go | 42 +++++++++++++++++++++++++++++---- 5 files changed, 86 insertions(+), 48 deletions(-) diff --git a/internal/auth.go b/internal/auth.go index 2f569b0..b5723de 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -18,41 +18,41 @@ import ( // Request Validation // Cookie = hash(secret, cookie domain, email, expires)|expires|email -func ValidateCookie(r *http.Request, c *http.Cookie) (bool, string, error) { +func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) { parts := strings.Split(c.Value, "|") if len(parts) != 3 { - return false, "", errors.New("Invalid cookie format") + return "", errors.New("Invalid cookie format") } mac, err := base64.URLEncoding.DecodeString(parts[0]) if err != nil { - return false, "", errors.New("Unable to decode cookie mac") + return "", errors.New("Unable to decode cookie mac") } expectedSignature := cookieSignature(r, parts[2], parts[1]) expected, err := base64.URLEncoding.DecodeString(expectedSignature) if err != nil { - return false, "", errors.New("Unable to generate mac") + return "", errors.New("Unable to generate mac") } // Valid token? if !hmac.Equal(mac, expected) { - return false, "", errors.New("Invalid cookie mac") + return "", errors.New("Invalid cookie mac") } expires, err := strconv.ParseInt(parts[1], 10, 64) if err != nil { - return false, "", errors.New("Unable to parse cookie expiry") + return "", errors.New("Unable to parse cookie expiry") } // Has it expired? if time.Unix(expires, 0).Before(time.Now()) { - return false, "", errors.New("Cookie has expired") + return "", errors.New("Cookie has expired") } // Looks valid - return true, parts[2], nil + return parts[2], nil } // Validate email diff --git a/internal/auth_test.go b/internal/auth_test.go index baa1f22..9a91498 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -24,28 +24,24 @@ func TestAuthValidateCookie(t *testing.T) { // Should require 3 parts c.Value = "" - valid, _, err := ValidateCookie(r, c) - assert.False(valid) + _, err := ValidateCookie(r, c) if assert.Error(err) { assert.Equal("Invalid cookie format", err.Error()) } c.Value = "1|2" - valid, _, err = ValidateCookie(r, c) - assert.False(valid) + _, err = ValidateCookie(r, c) if assert.Error(err) { assert.Equal("Invalid cookie format", err.Error()) } c.Value = "1|2|3|4" - valid, _, err = ValidateCookie(r, c) - assert.False(valid) + _, err = ValidateCookie(r, c) if assert.Error(err) { assert.Equal("Invalid cookie format", err.Error()) } // Should catch invalid mac c.Value = "MQ==|2|3" - valid, _, err = ValidateCookie(r, c) - assert.False(valid) + _, err = ValidateCookie(r, c) if assert.Error(err) { assert.Equal("Invalid cookie mac", err.Error()) } @@ -53,8 +49,7 @@ func TestAuthValidateCookie(t *testing.T) { // Should catch expired config.Lifetime = time.Second * time.Duration(-1) c = MakeCookie(r, "test@test.com") - valid, _, err = ValidateCookie(r, c) - assert.False(valid) + _, err = ValidateCookie(r, c) if assert.Error(err) { assert.Equal("Cookie has expired", err.Error()) } @@ -62,8 +57,7 @@ func TestAuthValidateCookie(t *testing.T) { // Should accept valid cookie config.Lifetime = time.Second * time.Duration(10) c = MakeCookie(r, "test@test.com") - valid, email, err := ValidateCookie(r, c) - assert.True(valid, "valid request should return valid") + email, err := ValidateCookie(r, c) assert.Nil(err, "valid request should not return an error") assert.Equal("test@test.com", email, "valid request should return user email") } @@ -244,8 +238,8 @@ func TestAuthMakeCookie(t *testing.T) { assert.Equal("_forward_auth", c.Name) parts := strings.Split(c.Value, "|") assert.Len(parts, 3, "cookie should be 3 parts") - valid, _, _ := ValidateCookie(r, c) - assert.True(valid, "should generate valid cookie") + _, err := ValidateCookie(r, c) + assert.Nil(err, "should generate valid cookie") assert.Equal("/", c.Path) assert.Equal("app.example.com", c.Domain) assert.True(c.Secure) diff --git a/internal/config_test.go b/internal/config_test.go index 3184fd5..99f2a4c 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -237,7 +237,7 @@ func TestConfigParseEnvironmentBackwardsCompatability(t *testing.T) { "COOKIE_SECURE": "false", "COOKIE_DOMAINS": "test1.com,example.org", "COOKIE_DOMAIN": "another1.net", - "DOMAIN": "test2.com,example.org", + "DOMAIN": "test2.com,example.org", "WHITELIST": "test3.com,example.org", } for k, v := range vars { diff --git a/internal/server.go b/internal/server.go index c2ccb37..3119876 100644 --- a/internal/server.go +++ b/internal/server.go @@ -72,35 +72,25 @@ func (s *Server) AuthHandler(rule string) http.HandlerFunc { // Get auth cookie c, err := r.Cookie(config.CookieName) if err != nil { - // Error indicates no cookie, generate nonce - err, nonce := Nonce() - if err != nil { - logger.Errorf("Error generating nonce, %v", err) - http.Error(w, "Service unavailable", 503) - return - } - - // Set the CSRF cookie - http.SetCookie(w, MakeCSRFCookie(r, nonce)) - logger.Debug("Set CSRF cookie and redirecting to google login") - - // Forward them on - http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect) - - logger.Debug("Done") + s.authRedirect(logger, w, r) return } // Validate cookie - valid, email, err := ValidateCookie(r, c) - if !valid { - logger.Errorf("Invalid cookie: %v", err) - http.Error(w, "Not authorized", 401) + email, err := ValidateCookie(r, c) + if err != nil { + if err.Error() == "Cookie has expired" { + logger.Info("Cookie has expired") + s.authRedirect(logger, w, r) + } else { + logger.Errorf("Invalid cookie: %v", err) + http.Error(w, "Not authorized", 401) + } return } // Validate user - valid = ValidateEmail(email) + valid := ValidateEmail(email) if !valid { logger.WithFields(logrus.Fields{ "email": email, @@ -167,6 +157,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc { } } +func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request) { + // Error indicates no cookie, generate nonce + err, nonce := Nonce() + if err != nil { + logger.Errorf("Error generating nonce, %v", err) + http.Error(w, "Service unavailable", 503) + return + } + + // Set the CSRF cookie + http.SetCookie(w, MakeCSRFCookie(r, nonce)) + logger.Debug("Set CSRF cookie and redirecting to google login") + + // Forward them on + http.Redirect(w, r, GetLoginURL(r, nonce), http.StatusTemporaryRedirect) + + logger.Debug("Done") + return +} + func (s *Server) logger(r *http.Request, rule, msg string) *logrus.Entry { // Create logger logger := log.WithFields(logrus.Fields{ diff --git a/internal/server_test.go b/internal/server_test.go index 816cb4c..e62267c 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -8,6 +8,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -27,7 +28,7 @@ func init() { * Tests */ -func TestServerAuthHandler(t *testing.T) { +func TestServerAuthHandlerInvalid(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) @@ -57,13 +58,46 @@ func TestServerAuthHandler(t *testing.T) { res, _ = doHttpRequest(req, c) assert.Equal(401, res.StatusCode, "invalid email should not be authorised") +} + +func TestServerAuthHandlerExpired(t *testing.T) { + assert := assert.New(t) + config, _ = NewConfig([]string{}) + config.Lifetime = time.Second * time.Duration(-1) + config.Domains = []string{"test.com"} + + // Should redirect expired cookie + req := newDefaultHttpRequest("/foo") + c := MakeCookie(req, "test@example.com") + res, _ := doHttpRequest(req, c) + assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected") + + // Check for CSRF cookie + var cookie *http.Cookie + for _, c := range res.Cookies() { + if c.Name == config.CSRFCookieName { + cookie = c + } + } + assert.NotNil(cookie) + + // Check redirection location + fwd, _ := res.Location() + assert.Equal("https", fwd.Scheme, "request with expired cookie should be redirected to google") + assert.Equal("accounts.google.com", fwd.Host, "request with expired cookie should be redirected to google") + assert.Equal("/o/oauth2/auth", fwd.Path, "request with expired cookie should be redirected to google") +} + +func TestServerAuthHandlerValid(t *testing.T) { + assert := assert.New(t) + config, _ = NewConfig([]string{}) // Should allow valid request email - req = newDefaultHttpRequest("/foo") - c = MakeCookie(req, "test@example.com") + req := newDefaultHttpRequest("/foo") + c := MakeCookie(req, "test@example.com") config.Domains = []string{} - res, _ = doHttpRequest(req, c) + res, _ := doHttpRequest(req, c) assert.Equal(200, res.StatusCode, "valid request should be allowed") // Should pass through user