diff --git a/go.sum b/go.sum index 4535b11..81f698c 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,7 @@ github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoH github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/thomseddon/go-flags v1.4.0 h1:cHj56pbnQxlGo2lx2P8f0Dph4TRYKBJzoPuF2lqNvW4= github.com/thomseddon/go-flags v1.4.0/go.mod h1:NK9eZpNBmSKVxvyB/MExg6jW0Bo9hQyAuCP+b8MJFow= github.com/thomseddon/go-flags v1.4.1-0.20190507181358-ce437f05b7fb h1:L311/fJ7WXmFDDtuhf22PkVJqZpqLbEsmGSTEGv7ZQY= diff --git a/internal/server_test.go b/internal/server_test.go index e62267c..8b40a64 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TODO: @@ -300,6 +301,41 @@ func TestServerRouteQuery(t *testing.T) { assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") } +func TestServerAuthCallbackWithRules(t *testing.T) { + assert := assert.New(t) + config, _ = NewConfig([]string{}) + config.Rules = map[string]*Rule{ + "1": { + Action: "auth", + Rule: "Host(`example.com`) && Path(`/private`)", + }, + "2": { + Action: "allow", + Rule: "Host(`example.com`)", + }, + } + // Should allow /test request + req := newHttpRequest("GET", "https://example.com/", "/test") + res, _ := doHttpRequest(req, nil) + assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") + + // Should block /private request + req = newHttpRequest("GET", "https://example.com/", "/private") + res, _ = doHttpRequest(req, nil) + assert.Equal(307, res.StatusCode, "request matching auth rule should require auth") + + // Should allow callback request + req = newHttpRequest("GET", "https://example.com/", "/_oauth?state=12345678901234567890123456789012:https://example.com/private") + c := MakeCSRFCookie(req, "12345678901234567890123456789012") + res, _ = doHttpRequest(req, c) + require.Equal(t, 307, res.StatusCode, "callback request should be redirected") + + fwd, _ := res.Location() + assert.Equal("http", fwd.Scheme, "callback request should be correctly redirected") + assert.Equal("example.com", fwd.Host, "callback request should be correctly redirected") + assert.Equal("/private", fwd.Path, "callback request should be correctly redirected") +} + /** * Utilities */