mirror of
https://github.com/thomseddon/traefik-forward-auth.git
synced 2026-02-06 22:22:15 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
26e7995ea1 | ||
|
|
20cef15e7b | ||
|
|
c4317b7503 | ||
|
|
4ffb6593d5 | ||
|
|
6c6f75e80d | ||
|
|
8be8244b13 | ||
|
|
f96a3fb332 | ||
|
|
c19f622fbd | ||
|
|
04f5499f0b | ||
|
|
41560feaa7 | ||
|
|
1743537438 |
36
.github/workflows/ci.yml
vendored
Normal file
36
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ^1.13
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Get dependencies
|
||||
run: |
|
||||
go get -v -t -d ./...
|
||||
if [ -f Gopkg.toml ]; then
|
||||
curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh
|
||||
dep ensure
|
||||
fi
|
||||
|
||||
- name: Build
|
||||
run: go build -v ./...
|
||||
|
||||
- name: Test
|
||||
run: go test -v ./...
|
||||
71
.github/workflows/codeql-analysis.yml
vendored
Normal file
71
.github/workflows/codeql-analysis.yml
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
# For most projects, this workflow file will not need changing; you simply need
|
||||
# to commit it to your repository.
|
||||
#
|
||||
# You may wish to alter this file to override the set of languages analyzed,
|
||||
# or to provide custom queries or build logic.
|
||||
name: "CodeQL"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master]
|
||||
pull_request:
|
||||
# The branches below must be a subset of the branches above
|
||||
branches: [master]
|
||||
schedule:
|
||||
- cron: '0 10 * * 2'
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: Analyze
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# Override automatic language detection by changing the below list
|
||||
# Supported options are ['csharp', 'cpp', 'go', 'java', 'javascript', 'python']
|
||||
language: ['go']
|
||||
# Learn more...
|
||||
# https://docs.github.com/en/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#overriding-automatic-language-detection
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
# We must fetch at least the immediate parents so that if this is
|
||||
# a pull request then we can checkout the head.
|
||||
fetch-depth: 2
|
||||
|
||||
# If this run was triggered by a pull request event, then checkout
|
||||
# the head of the pull request instead of the merge commit.
|
||||
- run: git checkout HEAD^2
|
||||
if: ${{ github.event_name == 'pull_request' }}
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v1
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
# If you wish to specify custom queries, you can do so here or in a config file.
|
||||
# By default, queries listed here will override any specified in a config file.
|
||||
# Prefix the list here with "+" to use these queries and those in the config file.
|
||||
# queries: ./path/to/local/query, your-org/your-repo/queries@main
|
||||
|
||||
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
|
||||
# If this step fails, then you should remove it and run the build manually (see below)
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v1
|
||||
|
||||
# ℹ️ Command-line programs to run using the OS shell.
|
||||
# 📚 https://git.io/JvXDl
|
||||
|
||||
# ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
|
||||
# and modify them (or add more) to build your code if your project
|
||||
# uses a compiled language
|
||||
|
||||
#- run: |
|
||||
# make bootstrap
|
||||
# make release
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v1
|
||||
57
.github/workflows/release.yml
vendored
Normal file
57
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
|
||||
build:
|
||||
name: Build release binaries
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ^1.13
|
||||
id: go
|
||||
|
||||
- name: Build AMD64
|
||||
run: CGO_ENABLED=0 GOOS=linux GOARCH=amd64 GO111MODULE=on go build -a -installsuffix nocgo -v -o traefik-forward-auth_amd64 ./cmd
|
||||
|
||||
- name: Build ARM
|
||||
run: CGO_ENABLED=0 GOOS=linux GOARCH=arm GO111MODULE=on go build -a -installsuffix nocgo -v -o traefik-forward-auth_arm ./cmd
|
||||
|
||||
- name: Get tag name
|
||||
run: echo "TAG=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
|
||||
|
||||
- name: Get artifact details
|
||||
uses: octokit/request-action@v2.x
|
||||
id: get_release_details
|
||||
with:
|
||||
route: get /repos/${{ github.repository }}/releases/tags/${{ env.TAG }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Upload AMD64 release asset
|
||||
uses: actions/upload-release-asset@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
upload_url: ${{ fromJson(steps.get_release_details.outputs.data).upload_url }}
|
||||
asset_path: traefik-forward-auth_amd64
|
||||
asset_name: traefik-forward-auth_amd64
|
||||
asset_content_type: application/octet-stream
|
||||
|
||||
- name: Upload ARM release asset
|
||||
uses: actions/upload-release-asset@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
upload_url: ${{ fromJson(steps.get_release_details.outputs.data).upload_url }}
|
||||
asset_path: traefik-forward-auth_arm
|
||||
asset_name: traefik-forward-auth_arm
|
||||
asset_content_type: application/octet-stream
|
||||
@@ -1,5 +0,0 @@
|
||||
language: go
|
||||
sudo: false
|
||||
go:
|
||||
- "1.12"
|
||||
script: env GO111MODULE=on go test -v ./...
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.13-alpine as builder
|
||||
FROM golang:1.20-alpine as builder
|
||||
|
||||
# Setup
|
||||
RUN mkdir -p /go/src/github.com/thomseddon/traefik-forward-auth
|
||||
|
||||
24
README.md
24
README.md
@@ -1,5 +1,5 @@
|
||||
|
||||
# Traefik Forward Auth [](https://travis-ci.org/thomseddon/traefik-forward-auth) [](https://goreportcard.com/report/github.com/thomseddon/traefik-forward-auth)  [](https://GitHub.com/thomseddon/traefik-forward-auth/releases/)
|
||||
# Traefik Forward Auth  [](https://goreportcard.com/report/github.com/thomseddon/traefik-forward-auth)  [](https://GitHub.com/thomseddon/traefik-forward-auth/releases/)
|
||||
|
||||
|
||||
A minimal forward authentication service that provides OAuth/SSO login and authentication for the [traefik](https://github.com/containous/traefik) reverse proxy/load balancer.
|
||||
@@ -9,8 +9,8 @@ A minimal forward authentication service that provides OAuth/SSO login and authe
|
||||
- Seamlessly overlays any http service with a single endpoint (see: `url-path` in [Configuration](#configuration))
|
||||
- Supports multiple providers including Google and OpenID Connect (supported by Azure, Github, Salesforce etc.)
|
||||
- Supports multiple domains/subdomains by dynamically generating redirect_uri's
|
||||
- Allows authentication to be selectively applied/bypassed based on request parameters (see `rules` in [Configuration](#configuration)))
|
||||
- Supports use of centralised authentication host/redirect_uri (see `auth-host` in [Configuration](#configuration)))
|
||||
- Allows authentication to be selectively applied/bypassed based on request parameters (see `rules` in [Configuration](#configuration))
|
||||
- Supports use of centralised authentication host/redirect_uri (see `auth-host` in [Configuration](#configuration))
|
||||
- Allows authentication to persist across multiple domains (see [Cookie Domains](#cookie-domains))
|
||||
- Supports extended authentication beyond Google token lifetime (see: `lifetime` in [Configuration](#configuration))
|
||||
|
||||
@@ -47,6 +47,8 @@ You can also use the latest incremental releases found on [docker hub](https://h
|
||||
|
||||
ARM releases are also available on docker hub, just append `-arm` or `-arm64` to your desired released (e.g. `2-arm` or `2.1-arm64`).
|
||||
|
||||
We also build binary files for usage without docker starting with releases after 2.2.0 You can find these as assets of the specific GitHub release.
|
||||
|
||||
#### Upgrade Guide
|
||||
|
||||
v2 was released in June 2019, whilst this is fully backwards compatible, a number of configuration options were modified, please see the [upgrade guide](https://github.com/thomseddon/traefik-forward-auth/wiki/v2-Upgrade-Guide) to prevent warnings on startup and ensure you are using the current configuration.
|
||||
@@ -92,7 +94,7 @@ services:
|
||||
|
||||
#### Advanced:
|
||||
|
||||
Please see the examples directory for a more complete [docker-compose.yml](https://github.com/thomseddon/traefik-forward-auth/blob/master/examples/traefik-v2/swarm/docker-compose.yml) or [kubernetes/simple-separate-pod](https://github.com/thomseddon/traefik-forward-auth/blob/masterexamples/traefik-v2/kubernetes/simple-separate-pod/).
|
||||
Please see the examples directory for a more complete [docker-compose.yml](https://github.com/thomseddon/traefik-forward-auth/blob/master/examples/traefik-v2/swarm/docker-compose.yml) or [kubernetes/simple-separate-pod](https://github.com/thomseddon/traefik-forward-auth/blob/master/examples/traefik-v2/kubernetes/simple-separate-pod/).
|
||||
|
||||
Also in the examples directory is [docker-compose-auth-host.yml](https://github.com/thomseddon/traefik-forward-auth/blob/master/examples/traefik-v2/swarm/docker-compose-auth-host.yml) and [kubernetes/advanced-separate-pod](https://github.com/thomseddon/traefik-forward-auth/blob/master/examples/traefik-v2/kubernetes/advanced-separate-pod/) which shows how to configure a central auth host, along with some other options.
|
||||
|
||||
@@ -162,6 +164,7 @@ Application Options:
|
||||
--url-path= Callback URL Path (default: /_oauth) [$URL_PATH]
|
||||
--secret= Secret used for signing (required) [$SECRET]
|
||||
--whitelist= Only allow given email addresses, can be set multiple times [$WHITELIST]
|
||||
--port= Port to listen on (default: 4181) [$PORT]
|
||||
--rule.<name>.<param>= Rule definitions, param can be: "action", "rule" or "provider"
|
||||
|
||||
Google Provider:
|
||||
@@ -321,6 +324,7 @@ All options can be supplied in any of the following ways, in the following prece
|
||||
- `action` - same usage as [`default-action`](#default-action), supported values:
|
||||
- `auth` (default)
|
||||
- `allow`
|
||||
- `domains` - optional, same usage as [`domain`](#domain)
|
||||
- `provider` - same usage as [`default-provider`](#default-provider), supported values:
|
||||
- `google`
|
||||
- `oidc`
|
||||
@@ -333,6 +337,7 @@ All options can be supplied in any of the following ways, in the following prece
|
||||
- ``Path(`path`, `/articles/{category}/{id:[0-9]+}`, ...)``
|
||||
- ``PathPrefix(`/products/`, `/articles/{category}/{id:[0-9]+}`)``
|
||||
- ``Query(`foo=bar`, `bar=baz`)``
|
||||
- `whitelist` - optional, same usage as whitelist`](#whitelist)
|
||||
|
||||
For example:
|
||||
```
|
||||
@@ -348,6 +353,11 @@ All options can be supplied in any of the following ways, in the following prece
|
||||
rule.oidc.action = auth
|
||||
rule.oidc.provider = oidc
|
||||
rule.oidc.rule = PathPrefix(`/github`)
|
||||
|
||||
# Allow jane@example.com to `/janes-eyes-only`
|
||||
rule.two.action = allow
|
||||
rule.two.rule = Path(`/janes-eyes-only`)
|
||||
rule.two.whitelist = jane@example.com
|
||||
```
|
||||
|
||||
Note: It is possible to break your redirect flow with rules, please be careful not to create an `allow` rule that matches your redirect_uri unless you know what you're doing. This limitation is being tracked in in #101 and the behaviour will change in future releases.
|
||||
@@ -361,7 +371,7 @@ You can restrict who can login with the following parameters:
|
||||
* `domain` - Use this to limit logins to a specific domain, e.g. test.com only
|
||||
* `whitelist` - Use this to only allow specific users to login e.g. thom@test.com only
|
||||
|
||||
Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3).
|
||||
Note, if you pass both `whitelist` and `domain`, then the default behaviour is for only `whitelist` to be used and `domain` will be effectively ignored. You can allow users matching *either* `whitelist` or `domain` by passing the `match-whitelist-or-domain` parameter (this will be the default behaviour in v3). If you set `domains` or `whitelist` on a rule, the global configuration is ignored.
|
||||
|
||||
### Forwarded Headers
|
||||
|
||||
@@ -416,8 +426,6 @@ spec:
|
||||
- name: traefik-forward-auth
|
||||
```
|
||||
|
||||
Note: If using auth host mode, you must apply the middleware to your auth host ingress.
|
||||
|
||||
See the examples directory for more examples.
|
||||
|
||||
#### Selective Container Authentication in Swarm
|
||||
@@ -432,8 +440,6 @@ whoami:
|
||||
- "traefik.http.routers.whoami.middlewares=traefik-forward-auth"
|
||||
```
|
||||
|
||||
Note: If using auth host mode, you must apply the middleware to the traefik-forward-auth container.
|
||||
|
||||
See the examples directory for more examples.
|
||||
|
||||
#### Rules Based Authentication
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
internal "github.com/thomseddon/traefik-forward-auth/internal"
|
||||
@@ -25,6 +26,6 @@ func main() {
|
||||
|
||||
// Start
|
||||
log.WithField("config", config).Debug("Starting with config")
|
||||
log.Info("Listening on :4181")
|
||||
log.Info(http.ListenAndServe(":4181", nil))
|
||||
log.Infof("Listening on :%d", config.Port)
|
||||
log.Info(http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil))
|
||||
}
|
||||
|
||||
@@ -16,7 +16,5 @@ spec:
|
||||
services:
|
||||
- name: traefik-forward-auth
|
||||
port: 4181
|
||||
middlewares:
|
||||
- name: traefik-forward-auth
|
||||
tls:
|
||||
certresolver: default
|
||||
|
||||
59
go.mod
59
go.mod
@@ -1,26 +1,49 @@
|
||||
module github.com/thomseddon/traefik-forward-auth
|
||||
|
||||
go 1.13
|
||||
go 1.22
|
||||
|
||||
toolchain go1.22.2
|
||||
|
||||
require (
|
||||
github.com/containous/traefik/v2 v2.1.2
|
||||
github.com/coreos/go-oidc v2.1.0+incompatible
|
||||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/sirupsen/logrus v1.4.2
|
||||
github.com/stretchr/testify v1.4.0
|
||||
github.com/coreos/go-oidc v2.2.1+incompatible
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/thomseddon/go-flags v1.4.1-0.20190507184247-a3629c504486
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45
|
||||
gopkg.in/square/go-jose.v2 v2.3.1
|
||||
github.com/traefik/traefik/v2 v2.11.2
|
||||
golang.org/x/oauth2 v0.20.0
|
||||
gopkg.in/square/go-jose.v2 v2.6.0
|
||||
)
|
||||
|
||||
// From traefik
|
||||
replace (
|
||||
github.com/Azure/go-autorest => github.com/Azure/go-autorest v12.4.1+incompatible
|
||||
github.com/abbot/go-http-auth => github.com/containous/go-http-auth v0.4.1-0.20180112153951-65b0cdae8d7f
|
||||
github.com/docker/docker => github.com/docker/engine v1.4.2-0.20191113042239-ea84732a7725
|
||||
github.com/go-check/check => github.com/containous/check v0.0.0-20170915194414-ca0bf163426a
|
||||
github.com/gorilla/mux => github.com/containous/mux v0.0.0-20181024131434-c33f32e26898
|
||||
github.com/mailgun/minheap => github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595
|
||||
github.com/mailgun/multibuf => github.com/containous/multibuf v0.0.0-20190809014333-8b6c9a7e6bba
|
||||
github.com/rancher/go-rancher-metadata => github.com/containous/go-rancher-metadata v0.0.0-20190402144056-c6a65f8b7a28
|
||||
require (
|
||||
github.com/containous/alice v0.0.0-20181107144136-d83ebdd94cbd // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/gorilla/mux v1.8.1 // indirect
|
||||
github.com/gravitational/trace v1.4.0 // indirect
|
||||
github.com/jonboulle/clockwork v0.4.0 // indirect
|
||||
github.com/miekg/dns v1.1.59 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/pquerna/cachecontrol v0.2.0 // indirect
|
||||
github.com/traefik/paerser v0.2.0 // indirect
|
||||
github.com/vulcand/predicate v1.2.0 // indirect
|
||||
golang.org/x/crypto v0.23.0 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/sys v0.20.0 // indirect
|
||||
golang.org/x/term v0.20.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
golang.org/x/tools v0.21.0 // indirect
|
||||
google.golang.org/appengine v1.6.8 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
// Containous forks
|
||||
replace (
|
||||
github.com/abbot/go-http-auth => github.com/containous/go-http-auth v0.4.1-0.20200324110947-a37a7636d23e
|
||||
github.com/go-check/check => github.com/containous/check v0.0.0-20170915194414-ca0bf163426a
|
||||
github.com/gorilla/mux => github.com/containous/mux v0.0.0-20220627093034-b2dd784e613f
|
||||
github.com/mailgun/minheap => github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595
|
||||
)
|
||||
|
||||
119
internal/auth.go
119
internal/auth.go
@@ -59,18 +59,28 @@ func ValidateCookie(r *http.Request, c *http.Cookie) (string, error) {
|
||||
// ValidateEmail checks if the given email address matches either a whitelisted
|
||||
// email address, as defined by the "whitelist" config parameter. Or is part of
|
||||
// a permitted domain, as defined by the "domains" config parameter
|
||||
func ValidateEmail(email string) bool {
|
||||
func ValidateEmail(email, ruleName string) bool {
|
||||
// Use global config by default
|
||||
whitelist := config.Whitelist
|
||||
domains := config.Domains
|
||||
|
||||
if rule, ok := config.Rules[ruleName]; ok {
|
||||
// Override with rule config if found
|
||||
if len(rule.Whitelist) > 0 || len(rule.Domains) > 0 {
|
||||
whitelist = rule.Whitelist
|
||||
domains = rule.Domains
|
||||
}
|
||||
}
|
||||
|
||||
// Do we have any validation to perform?
|
||||
if len(config.Whitelist) == 0 && len(config.Domains) == 0 {
|
||||
if len(whitelist) == 0 && len(domains) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Email whitelist validation
|
||||
if len(config.Whitelist) > 0 {
|
||||
for _, whitelist := range config.Whitelist {
|
||||
if email == whitelist {
|
||||
return true
|
||||
}
|
||||
if len(whitelist) > 0 {
|
||||
if ValidateWhitelist(email, whitelist) {
|
||||
return true
|
||||
}
|
||||
|
||||
// If we're not matching *either*, stop here
|
||||
@@ -80,43 +90,54 @@ func ValidateEmail(email string) bool {
|
||||
}
|
||||
|
||||
// Domain validation
|
||||
if len(config.Domains) > 0 {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) < 2 {
|
||||
return false
|
||||
}
|
||||
for _, domain := range config.Domains {
|
||||
if domain == parts[1] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(domains) > 0 && ValidateDomains(email, domains) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateWhitelist checks if the email is in whitelist
|
||||
func ValidateWhitelist(email string, whitelist CommaSeparatedList) bool {
|
||||
for _, whitelist := range whitelist {
|
||||
if email == whitelist {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateDomains checks if the email matches a whitelisted domain
|
||||
func ValidateDomains(email string, domains CommaSeparatedList) bool {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) < 2 {
|
||||
return false
|
||||
}
|
||||
for _, domain := range domains {
|
||||
if domain == parts[1] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
|
||||
// Get the redirect base
|
||||
func redirectBase(r *http.Request) string {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
|
||||
return fmt.Sprintf("%s://%s", proto, host)
|
||||
return fmt.Sprintf("%s://%s", r.Header.Get("X-Forwarded-Proto"), r.Host)
|
||||
}
|
||||
|
||||
// Return url
|
||||
func returnUrl(r *http.Request) string {
|
||||
path := r.Header.Get("X-Forwarded-Uri")
|
||||
|
||||
return fmt.Sprintf("%s%s", redirectBase(r), path)
|
||||
return fmt.Sprintf("%s%s", redirectBase(r), r.URL.Path)
|
||||
}
|
||||
|
||||
// Get oauth redirect uri
|
||||
func redirectUri(r *http.Request) string {
|
||||
if use, _ := useAuthDomain(r); use {
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
|
||||
p := r.Header.Get("X-Forwarded-Proto")
|
||||
return fmt.Sprintf("%s://%s%s", p, config.AuthHost, config.Path)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", redirectBase(r), config.Path)
|
||||
@@ -129,7 +150,7 @@ func useAuthDomain(r *http.Request) (bool, string) {
|
||||
}
|
||||
|
||||
// Does the request match a given cookie domain?
|
||||
reqMatch, reqHost := matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
|
||||
reqMatch, reqHost := matchCookieDomains(r.Host)
|
||||
|
||||
// Do any of the auth hosts match a cookie domain?
|
||||
authMatch, authHost := matchCookieDomains(config.AuthHost)
|
||||
@@ -170,23 +191,31 @@ func ClearCookie(r *http.Request) *http.Cookie {
|
||||
}
|
||||
}
|
||||
|
||||
func buildCSRFCookieName(nonce string) string {
|
||||
return config.CSRFCookieName + "_" + nonce[:6]
|
||||
}
|
||||
|
||||
// MakeCSRFCookie makes a csrf cookie (used during login only)
|
||||
//
|
||||
// Note, CSRF cookies live shorter than auth cookies, a fixed 1h.
|
||||
// That's because some CSRF cookies may belong to auth flows that don't complete
|
||||
// and thus may not get cleared by ClearCookie.
|
||||
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: config.CSRFCookieName,
|
||||
Name: buildCSRFCookieName(nonce),
|
||||
Value: nonce,
|
||||
Path: "/",
|
||||
Domain: csrfCookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: !config.InsecureCookie,
|
||||
Expires: cookieExpiry(),
|
||||
Expires: time.Now().Local().Add(time.Hour * 1),
|
||||
}
|
||||
}
|
||||
|
||||
// ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie
|
||||
func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||
func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: config.CSRFCookieName,
|
||||
Name: c.Name,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: csrfCookieDomain(r),
|
||||
@@ -196,18 +225,18 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateCSRFCookie validates the csrf cookie against state
|
||||
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
|
||||
state := r.URL.Query().Get("state")
|
||||
// FindCSRFCookie extracts the CSRF cookie from the request based on state.
|
||||
func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) {
|
||||
// Check for CSRF cookie
|
||||
return r.Cookie(buildCSRFCookieName(state))
|
||||
}
|
||||
|
||||
// ValidateCSRFCookie validates the csrf cookie against state
|
||||
func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) {
|
||||
if len(c.Value) != 32 {
|
||||
return false, "", "", errors.New("Invalid CSRF cookie value")
|
||||
}
|
||||
|
||||
if len(state) < 34 {
|
||||
return false, "", "", errors.New("Invalid CSRF state value")
|
||||
}
|
||||
|
||||
// Check nonce match
|
||||
if c.Value != state[:32] {
|
||||
return false, "", "", errors.New("CSRF cookie does not match state")
|
||||
@@ -229,6 +258,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string {
|
||||
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
|
||||
}
|
||||
|
||||
// ValidateState checks whether the state is of right length.
|
||||
func ValidateState(state string) error {
|
||||
if len(state) < 34 {
|
||||
return errors.New("Invalid CSRF state value")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Nonce generates a random nonce
|
||||
func Nonce() (error, string) {
|
||||
nonce := make([]byte, 16)
|
||||
@@ -242,10 +279,8 @@ func Nonce() (error, string) {
|
||||
|
||||
// Cookie domain
|
||||
func cookieDomain(r *http.Request) string {
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
|
||||
// Check if any of the given cookie domains matches
|
||||
_, domain := matchCookieDomains(host)
|
||||
_, domain := matchCookieDomains(r.Host)
|
||||
return domain
|
||||
}
|
||||
|
||||
@@ -255,7 +290,7 @@ func csrfCookieDomain(r *http.Request) string {
|
||||
if use, domain := useAuthDomain(r); use {
|
||||
host = domain
|
||||
} else {
|
||||
host = r.Header.Get("X-Forwarded-Host")
|
||||
host = r.Host
|
||||
}
|
||||
|
||||
// Remove port
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package tfa
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -66,32 +66,25 @@ func TestAuthValidateEmail(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
|
||||
// Should allow any
|
||||
v := ValidateEmail("test@test.com")
|
||||
// Should allow any with no whitelist/domain is specified
|
||||
v := ValidateEmail("test@test.com", "default")
|
||||
assert.True(v, "should allow any domain if email domain is not defined")
|
||||
v = ValidateEmail("one@two.com")
|
||||
v = ValidateEmail("one@two.com", "default")
|
||||
assert.True(v, "should allow any domain if email domain is not defined")
|
||||
|
||||
// Should block non matching domain
|
||||
config.Domains = []string{"test.com"}
|
||||
v = ValidateEmail("one@two.com")
|
||||
assert.False(v, "should not allow user from another domain")
|
||||
|
||||
// Should allow matching domain
|
||||
config.Domains = []string{"test.com"}
|
||||
v = ValidateEmail("test@test.com")
|
||||
v = ValidateEmail("one@two.com", "default")
|
||||
assert.False(v, "should not allow user from another domain")
|
||||
v = ValidateEmail("test@test.com", "default")
|
||||
assert.True(v, "should allow user from allowed domain")
|
||||
|
||||
// Should block non whitelisted email address
|
||||
config.Domains = []string{}
|
||||
config.Whitelist = []string{"test@test.com"}
|
||||
v = ValidateEmail("one@two.com")
|
||||
assert.False(v, "should not allow user not in whitelist")
|
||||
|
||||
// Should allow matching whitelisted email address
|
||||
config.Domains = []string{}
|
||||
config.Whitelist = []string{"test@test.com"}
|
||||
v = ValidateEmail("test@test.com")
|
||||
v = ValidateEmail("one@two.com", "default")
|
||||
assert.False(v, "should not allow user not in whitelist")
|
||||
v = ValidateEmail("test@test.com", "default")
|
||||
assert.True(v, "should allow user in whitelist")
|
||||
|
||||
// Should allow only matching email address when
|
||||
@@ -99,33 +92,113 @@ func TestAuthValidateEmail(t *testing.T) {
|
||||
config.Domains = []string{"example.com"}
|
||||
config.Whitelist = []string{"test@test.com"}
|
||||
config.MatchWhitelistOrDomain = false
|
||||
v = ValidateEmail("test@test.com")
|
||||
assert.True(v, "should allow user in whitelist")
|
||||
v = ValidateEmail("test@example.com")
|
||||
assert.False(v, "should not allow user from valid domain")
|
||||
v = ValidateEmail("one@two.com")
|
||||
v = ValidateEmail("one@two.com", "default")
|
||||
assert.False(v, "should not allow user not in either")
|
||||
v = ValidateEmail("test@example.com", "default")
|
||||
assert.False(v, "should not allow user from allowed domain")
|
||||
v = ValidateEmail("test@test.com", "default")
|
||||
assert.True(v, "should allow user in whitelist")
|
||||
|
||||
// Should allow either matching domain or email address when
|
||||
// MatchWhitelistOrDomain is enabled
|
||||
config.Domains = []string{"example.com"}
|
||||
config.Whitelist = []string{"test@test.com"}
|
||||
config.MatchWhitelistOrDomain = true
|
||||
v = ValidateEmail("test@test.com")
|
||||
assert.True(v, "should allow user in whitelist")
|
||||
v = ValidateEmail("test@example.com")
|
||||
assert.True(v, "should allow user from valid domain")
|
||||
v = ValidateEmail("one@two.com")
|
||||
v = ValidateEmail("one@two.com", "default")
|
||||
assert.False(v, "should not allow user not in either")
|
||||
v = ValidateEmail("test@example.com", "default")
|
||||
assert.True(v, "should allow user from allowed domain")
|
||||
v = ValidateEmail("test@test.com", "default")
|
||||
assert.True(v, "should allow user in whitelist")
|
||||
|
||||
// Rule testing
|
||||
|
||||
// Should use global whitelist/domain when not specified on rule
|
||||
config.Domains = []string{"example.com"}
|
||||
config.Whitelist = []string{"test@test.com"}
|
||||
config.Rules = map[string]*Rule{"test": NewRule()}
|
||||
config.MatchWhitelistOrDomain = true
|
||||
v = ValidateEmail("one@two.com", "test")
|
||||
assert.False(v, "should not allow user not in either")
|
||||
v = ValidateEmail("test@example.com", "test")
|
||||
assert.True(v, "should allow user from allowed global domain")
|
||||
v = ValidateEmail("test@test.com", "test")
|
||||
assert.True(v, "should allow user in global whitelist")
|
||||
|
||||
// Should allow matching domain in rule
|
||||
config.Domains = []string{"testglobal.com"}
|
||||
config.Whitelist = []string{}
|
||||
rule := NewRule()
|
||||
config.Rules = map[string]*Rule{"test": rule}
|
||||
rule.Domains = []string{"testrule.com"}
|
||||
config.MatchWhitelistOrDomain = false
|
||||
v = ValidateEmail("one@two.com", "test")
|
||||
assert.False(v, "should not allow user from another domain")
|
||||
v = ValidateEmail("one@testglobal.com", "test")
|
||||
assert.False(v, "should not allow user from global domain")
|
||||
v = ValidateEmail("test@testrule.com", "test")
|
||||
assert.True(v, "should allow user from allowed domain")
|
||||
|
||||
// Should allow matching whitelist in rule
|
||||
config.Domains = []string{}
|
||||
config.Whitelist = []string{"test@testglobal.com"}
|
||||
rule = NewRule()
|
||||
config.Rules = map[string]*Rule{"test": rule}
|
||||
rule.Whitelist = []string{"test@testrule.com"}
|
||||
config.MatchWhitelistOrDomain = false
|
||||
v = ValidateEmail("one@two.com", "test")
|
||||
assert.False(v, "should not allow user from another domain")
|
||||
v = ValidateEmail("test@testglobal.com", "test")
|
||||
assert.False(v, "should not allow user from global domain")
|
||||
v = ValidateEmail("test@testrule.com", "test")
|
||||
assert.True(v, "should allow user from allowed domain")
|
||||
|
||||
// Should allow only matching email address when
|
||||
// MatchWhitelistOrDomain is disabled
|
||||
config.Domains = []string{"exampleglobal.com"}
|
||||
config.Whitelist = []string{"test@testglobal.com"}
|
||||
rule = NewRule()
|
||||
config.Rules = map[string]*Rule{"test": rule}
|
||||
rule.Domains = []string{"examplerule.com"}
|
||||
rule.Whitelist = []string{"test@testrule.com"}
|
||||
config.MatchWhitelistOrDomain = false
|
||||
v = ValidateEmail("one@two.com", "test")
|
||||
assert.False(v, "should not allow user not in either")
|
||||
v = ValidateEmail("test@testglobal.com", "test")
|
||||
assert.False(v, "should not allow user in global whitelist")
|
||||
v = ValidateEmail("test@exampleglobal.com", "test")
|
||||
assert.False(v, "should not allow user from global domain")
|
||||
v = ValidateEmail("test@examplerule.com", "test")
|
||||
assert.False(v, "should not allow user from allowed domain")
|
||||
v = ValidateEmail("test@testrule.com", "test")
|
||||
assert.True(v, "should allow user in whitelist")
|
||||
|
||||
// Should allow either matching domain or email address when
|
||||
// MatchWhitelistOrDomain is enabled
|
||||
config.Domains = []string{"exampleglobal.com"}
|
||||
config.Whitelist = []string{"test@testglobal.com"}
|
||||
rule = NewRule()
|
||||
config.Rules = map[string]*Rule{"test": rule}
|
||||
rule.Domains = []string{"examplerule.com"}
|
||||
rule.Whitelist = []string{"test@testrule.com"}
|
||||
config.MatchWhitelistOrDomain = true
|
||||
v = ValidateEmail("one@two.com", "test")
|
||||
assert.False(v, "should not allow user not in either")
|
||||
v = ValidateEmail("test@testglobal.com", "test")
|
||||
assert.False(v, "should not allow user in global whitelist")
|
||||
v = ValidateEmail("test@exampleglobal.com", "test")
|
||||
assert.False(v, "should not allow user from global domain")
|
||||
v = ValidateEmail("test@examplerule.com", "test")
|
||||
assert.True(v, "should allow user from allowed domain")
|
||||
v = ValidateEmail("test@testrule.com", "test")
|
||||
assert.True(v, "should allow user in whitelist")
|
||||
}
|
||||
|
||||
func TestRedirectUri(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
r := httptest.NewRequest("GET", "http://app.example.com/hello", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "app.example.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
//
|
||||
// No Auth Host
|
||||
@@ -167,10 +240,8 @@ func TestRedirectUri(t *testing.T) {
|
||||
// With Auth URL + cookie domain, but from different domain
|
||||
// - will not use auth host
|
||||
//
|
||||
r, _ = http.NewRequest("GET", "http://another.com", nil)
|
||||
r = httptest.NewRequest("GET", "https://another.com/hello", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "https")
|
||||
r.Header.Add("X-Forwarded-Host", "another.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
config.AuthHost = "auth.example.com"
|
||||
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
||||
@@ -217,29 +288,30 @@ func TestAuthMakeCSRFCookie(t *testing.T) {
|
||||
|
||||
// No cookie domain or auth url
|
||||
c := MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
assert.Equal("_forward_auth_csrf_123456", c.Name)
|
||||
assert.Equal("app.example.com", c.Domain)
|
||||
|
||||
// With cookie domain but no auth url
|
||||
config = &Config{
|
||||
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||
}
|
||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
||||
c = MakeCSRFCookie(r, "12222278901234567890123456789012")
|
||||
assert.Equal("_forward_auth_csrf_122222", c.Name)
|
||||
assert.Equal("app.example.com", c.Domain)
|
||||
|
||||
// With cookie domain and auth url
|
||||
config = &Config{
|
||||
AuthHost: "auth.example.com",
|
||||
CookieDomains: []CookieDomain{*NewCookieDomain("example.com")},
|
||||
}
|
||||
c = MakeCSRFCookie(r, "12345678901234567890123456789012")
|
||||
config.AuthHost = "auth.example.com"
|
||||
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
|
||||
c = MakeCSRFCookie(r, "12333378901234567890123456789012")
|
||||
assert.Equal("_forward_auth_csrf_123333", c.Name)
|
||||
assert.Equal("example.com", c.Domain)
|
||||
}
|
||||
|
||||
func TestAuthClearCSRFCookie(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
|
||||
c := ClearCSRFCookie(r)
|
||||
c := ClearCSRFCookie(r, &http.Cookie{Name: "someCsrfCookie"})
|
||||
assert.Equal("someCsrfCookie", c.Name)
|
||||
if c.Value != "" {
|
||||
t.Error("ClearCSRFCookie should create cookie with empty value")
|
||||
}
|
||||
@@ -249,63 +321,62 @@ func TestAuthValidateCSRFCookie(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config, _ = NewConfig([]string{})
|
||||
c := &http.Cookie{}
|
||||
|
||||
newCsrfRequest := func(state string) *http.Request {
|
||||
u := fmt.Sprintf("http://example.com?state=%s", state)
|
||||
r, _ := http.NewRequest("GET", u, nil)
|
||||
return r
|
||||
}
|
||||
state := ""
|
||||
|
||||
// Should require 32 char string
|
||||
r := newCsrfRequest("")
|
||||
state = ""
|
||||
c.Value = ""
|
||||
valid, _, _, err := ValidateCSRFCookie(r, c)
|
||||
valid, _, _, err := ValidateCSRFCookie(c, state)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF cookie value", err.Error())
|
||||
}
|
||||
c.Value = "123456789012345678901234567890123"
|
||||
valid, _, _, err = ValidateCSRFCookie(r, c)
|
||||
valid, _, _, err = ValidateCSRFCookie(c, state)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF cookie value", err.Error())
|
||||
}
|
||||
|
||||
// Should require valid state
|
||||
r = newCsrfRequest("12345678901234567890123456789012:")
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, _, err = ValidateCSRFCookie(r, c)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF state value", err.Error())
|
||||
}
|
||||
|
||||
// Should require provider
|
||||
r = newCsrfRequest("12345678901234567890123456789012:99")
|
||||
state = "12345678901234567890123456789012:99"
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, _, _, err = ValidateCSRFCookie(r, c)
|
||||
valid, _, _, err = ValidateCSRFCookie(c, state)
|
||||
assert.False(valid)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF state format", err.Error())
|
||||
}
|
||||
|
||||
// Should allow valid state
|
||||
r = newCsrfRequest("12345678901234567890123456789012:p99:url123")
|
||||
state = "12345678901234567890123456789012:p99:url123"
|
||||
c.Value = "12345678901234567890123456789012"
|
||||
valid, provider, redirect, err := ValidateCSRFCookie(r, c)
|
||||
valid, provider, redirect, err := ValidateCSRFCookie(c, state)
|
||||
assert.True(valid, "valid request should return valid")
|
||||
assert.Nil(err, "valid request should not return an error")
|
||||
assert.Equal("p99", provider, "valid request should return correct provider")
|
||||
assert.Equal("url123", redirect, "valid request should return correct redirect")
|
||||
}
|
||||
|
||||
func TestValidateState(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
// Should require valid state
|
||||
state := "12345678901234567890123456789012:"
|
||||
err := ValidateState(state)
|
||||
if assert.Error(err) {
|
||||
assert.Equal("Invalid CSRF state value", err.Error())
|
||||
}
|
||||
// Should pass this state
|
||||
state = "12345678901234567890123456789012:p99:url123"
|
||||
err = ValidateState(state)
|
||||
assert.Nil(err, "valid request should not return an error")
|
||||
}
|
||||
|
||||
func TestMakeState(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
r, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
r := httptest.NewRequest("GET", "http://example.com/hello", nil)
|
||||
r.Header.Add("X-Forwarded-Proto", "http")
|
||||
r.Header.Add("X-Forwarded-Host", "example.com")
|
||||
r.Header.Add("X-Forwarded-Uri", "/hello")
|
||||
|
||||
// Test with google
|
||||
p := provider.Google{}
|
||||
|
||||
@@ -39,6 +39,7 @@ type Config struct {
|
||||
Path string `long:"url-path" env:"URL_PATH" default:"/_oauth" description:"Callback URL Path"`
|
||||
SecretString string `long:"secret" env:"SECRET" description:"Secret used for signing (required)" json:"-"`
|
||||
Whitelist CommaSeparatedList `long:"whitelist" env:"WHITELIST" env-delim:"," description:"Only allow given email addresses, can be set multiple times"`
|
||||
Port int `long:"port" env:"PORT" default:"4181" description:"Port to listen on"`
|
||||
|
||||
Providers provider.Providers `group:"providers" namespace:"providers" env-namespace:"PROVIDERS"`
|
||||
Rules map[string]*Rule `long:"rule.<name>.<param>" description:"Rule definitions, param can be: \"action\", \"rule\" or \"provider\""`
|
||||
@@ -210,6 +211,14 @@ func (c *Config) parseUnknownFlag(option string, arg flags.SplitArgument, args [
|
||||
rule.Rule = val
|
||||
case "provider":
|
||||
rule.Provider = val
|
||||
case "whitelist":
|
||||
list := CommaSeparatedList{}
|
||||
list.UnmarshalFlag(val)
|
||||
rule.Whitelist = list
|
||||
case "domains":
|
||||
list := CommaSeparatedList{}
|
||||
list.UnmarshalFlag(val)
|
||||
rule.Domains = list
|
||||
default:
|
||||
return args, fmt.Errorf("invalid route param: %v", option)
|
||||
}
|
||||
@@ -327,9 +336,11 @@ func (c *Config) setupProvider(name string) error {
|
||||
|
||||
// Rule holds defined rules
|
||||
type Rule struct {
|
||||
Action string
|
||||
Rule string
|
||||
Provider string
|
||||
Action string
|
||||
Rule string
|
||||
Provider string
|
||||
Whitelist CommaSeparatedList
|
||||
Domains CommaSeparatedList
|
||||
}
|
||||
|
||||
// NewRule creates a new rule object
|
||||
|
||||
@@ -37,6 +37,7 @@ func TestConfigDefaults(t *testing.T) {
|
||||
assert.False(c.MatchWhitelistOrDomain)
|
||||
assert.Equal("/_oauth", c.Path)
|
||||
assert.Len(c.Whitelist, 0)
|
||||
assert.Equal(c.Port, 4181)
|
||||
|
||||
assert.Equal("select_account", c.Providers.Google.Prompt)
|
||||
}
|
||||
@@ -51,6 +52,7 @@ func TestConfigParseArgs(t *testing.T) {
|
||||
"--rule.1.rule=PathPrefix(`/one`)",
|
||||
"--rule.two.action=auth",
|
||||
"--rule.two.rule=\"Host(`two.com`) && Path(`/two`)\"",
|
||||
"--port=8000",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
|
||||
@@ -58,6 +60,7 @@ func TestConfigParseArgs(t *testing.T) {
|
||||
assert.Equal("cookiename", c.CookieName)
|
||||
assert.Equal("csrfcookiename", c.CSRFCookieName)
|
||||
assert.Equal("oidc", c.DefaultProvider)
|
||||
assert.Equal(8000, c.Port)
|
||||
|
||||
// Check rules
|
||||
assert.Equal(map[string]*Rule{
|
||||
|
||||
@@ -4,17 +4,17 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/rules"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/thomseddon/traefik-forward-auth/internal/provider"
|
||||
muxhttp "github.com/traefik/traefik/v2/pkg/muxer/http"
|
||||
)
|
||||
|
||||
// Server contains router and handler methods
|
||||
// Server contains muxer and handler methods
|
||||
type Server struct {
|
||||
router *rules.Router
|
||||
muxer *muxhttp.Muxer
|
||||
}
|
||||
|
||||
// NewServer creates a new server object and builds router
|
||||
// NewServer creates a new server object and builds muxer
|
||||
func NewServer() *Server {
|
||||
s := &Server{}
|
||||
s.buildRoutes()
|
||||
@@ -23,32 +23,32 @@ func NewServer() *Server {
|
||||
|
||||
func (s *Server) buildRoutes() {
|
||||
var err error
|
||||
s.router, err = rules.NewRouter()
|
||||
s.muxer, err = muxhttp.NewMuxer()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Let's build a router
|
||||
// Let's build a muxer
|
||||
for name, rule := range config.Rules {
|
||||
matchRule := rule.formattedRule()
|
||||
if rule.Action == "allow" {
|
||||
s.router.AddRoute(matchRule, 1, s.AllowHandler(name))
|
||||
_ = s.muxer.AddRoute(matchRule, 1, s.AllowHandler(name))
|
||||
} else {
|
||||
s.router.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, name))
|
||||
_ = s.muxer.AddRoute(matchRule, 1, s.AuthHandler(rule.Provider, name))
|
||||
}
|
||||
}
|
||||
|
||||
// Add callback handler
|
||||
s.router.Handle(config.Path, s.AuthCallbackHandler())
|
||||
s.muxer.Handle(config.Path, s.AuthCallbackHandler())
|
||||
|
||||
// Add logout handler
|
||||
s.router.Handle(config.Path+"/logout", s.LogoutHandler())
|
||||
s.muxer.Handle(config.Path+"/logout", s.LogoutHandler())
|
||||
|
||||
// Add a default handler
|
||||
if config.DefaultAction == "allow" {
|
||||
s.router.NewRoute().Handler(s.AllowHandler("default"))
|
||||
s.muxer.NewRoute().Handler(s.AllowHandler("default"))
|
||||
} else {
|
||||
s.router.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, "default"))
|
||||
s.muxer.NewRoute().Handler(s.AuthHandler(config.DefaultProvider, "default"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,10 +58,14 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Modify request
|
||||
r.Method = r.Header.Get("X-Forwarded-Method")
|
||||
r.Host = r.Header.Get("X-Forwarded-Host")
|
||||
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||
|
||||
// Read URI from header if we're acting as forward auth middleware
|
||||
if _, ok := r.Header["X-Forwarded-Uri"]; ok {
|
||||
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
|
||||
}
|
||||
|
||||
// Pass to mux
|
||||
s.router.ServeHTTP(w, r)
|
||||
s.muxer.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// AllowHandler Allows requests
|
||||
@@ -101,7 +105,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Validate user
|
||||
valid := ValidateEmail(email)
|
||||
valid := ValidateEmail(email, rule)
|
||||
if !valid {
|
||||
logger.WithField("email", email).Warn("Invalid email")
|
||||
http.Error(w, "Not authorized", 401)
|
||||
@@ -121,16 +125,26 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
// Logging setup
|
||||
logger := s.logger(r, "AuthCallback", "default", "Handling callback")
|
||||
|
||||
// Check state
|
||||
state := r.URL.Query().Get("state")
|
||||
if err := ValidateState(state); err != nil {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
}).Warn("Error validating state")
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for CSRF cookie
|
||||
c, err := r.Cookie(config.CSRFCookieName)
|
||||
c, err := FindCSRFCookie(r, state)
|
||||
if err != nil {
|
||||
logger.Info("Missing csrf cookie")
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state
|
||||
valid, providerName, redirect, err := ValidateCSRFCookie(r, c)
|
||||
// Validate CSRF cookie against state
|
||||
valid, providerName, redirect, err := ValidateCSRFCookie(c, state)
|
||||
if !valid {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
@@ -153,7 +167,7 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Clear CSRF cookie
|
||||
http.SetCookie(w, ClearCSRFCookie(r))
|
||||
http.SetCookie(w, ClearCSRFCookie(r, c))
|
||||
|
||||
// Exchange code for token
|
||||
token, err := p.ExchangeCode(redirectUri(r), r.URL.Query().Get("code"))
|
||||
|
||||
@@ -31,6 +31,37 @@ func init() {
|
||||
* Tests
|
||||
*/
|
||||
|
||||
func TestServerRootHandler(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config = newDefaultConfig()
|
||||
|
||||
// X-Forwarded headers should be read into request
|
||||
req := httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should?ignore=me", nil)
|
||||
req.Header.Add("X-Forwarded-Method", "GET")
|
||||
req.Header.Add("X-Forwarded-Proto", "https")
|
||||
req.Header.Add("X-Forwarded-Host", "example.com")
|
||||
req.Header.Add("X-Forwarded-Uri", "/foo?q=bar")
|
||||
NewServer().RootHandler(httptest.NewRecorder(), req)
|
||||
|
||||
assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
|
||||
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
|
||||
assert.Equal("/foo", req.URL.Path, "x-forwarded-uri should be read into request")
|
||||
assert.Equal("/foo?q=bar", req.URL.RequestURI(), "x-forwarded-uri should be read into request")
|
||||
|
||||
// Other X-Forwarded headers should be read in into request and original URL
|
||||
// should be preserved if X-Forwarded-Uri not present
|
||||
req = httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should-not?ignore=me", nil)
|
||||
req.Header.Add("X-Forwarded-Method", "GET")
|
||||
req.Header.Add("X-Forwarded-Proto", "https")
|
||||
req.Header.Add("X-Forwarded-Host", "example.com")
|
||||
NewServer().RootHandler(httptest.NewRecorder(), req)
|
||||
|
||||
assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
|
||||
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
|
||||
assert.Equal("/should-not", req.URL.Path, "request url should be preserved if x-forwarded-uri not present")
|
||||
assert.Equal("/should-not?ignore=me", req.URL.RequestURI(), "request url should be preserved if x-forwarded-uri not present")
|
||||
}
|
||||
|
||||
func TestServerAuthHandlerInvalid(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
config = newDefaultConfig()
|
||||
@@ -90,15 +121,15 @@ func TestServerAuthHandlerExpired(t *testing.T) {
|
||||
config.Domains = []string{"test.com"}
|
||||
|
||||
// Should redirect expired cookie
|
||||
req := newDefaultHttpRequest("/foo")
|
||||
req := newHTTPRequest("GET", "http://example.com/foo")
|
||||
c := MakeCookie(req, "test@example.com")
|
||||
res, _ := doHttpRequest(req, c)
|
||||
assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected")
|
||||
require.Equal(t, 307, res.StatusCode, "request with expired cookie should be redirected")
|
||||
|
||||
// Check for CSRF cookie
|
||||
var cookie *http.Cookie
|
||||
for _, c := range res.Cookies() {
|
||||
if c.Name == config.CSRFCookieName {
|
||||
if strings.HasPrefix(c.Name, config.CSRFCookieName) {
|
||||
cookie = c
|
||||
}
|
||||
}
|
||||
@@ -116,7 +147,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
|
||||
config = newDefaultConfig()
|
||||
|
||||
// Should allow valid request email
|
||||
req := newDefaultHttpRequest("/foo")
|
||||
req := newHTTPRequest("GET", "http://example.com/foo")
|
||||
c := MakeCookie(req, "test@example.com")
|
||||
config.Domains = []string{}
|
||||
|
||||
@@ -131,6 +162,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
|
||||
|
||||
func TestServerAuthCallback(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
config = newDefaultConfig()
|
||||
|
||||
// Setup OAuth server
|
||||
@@ -148,27 +180,28 @@ func TestServerAuthCallback(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should pass auth response request to callback
|
||||
req := newDefaultHttpRequest("/_oauth")
|
||||
req := newHTTPRequest("GET", "http://example.com/_oauth")
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised")
|
||||
|
||||
// Should catch invalid csrf cookie
|
||||
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
|
||||
nonce := "12345678901234567890123456789012"
|
||||
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":http://redirect")
|
||||
c := MakeCSRFCookie(req, "nononononononononononononononono")
|
||||
res, _ = doHttpRequest(req, c)
|
||||
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
|
||||
|
||||
// Should catch invalid provider cookie
|
||||
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:invalid:http://redirect")
|
||||
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":invalid:http://redirect")
|
||||
c = MakeCSRFCookie(req, nonce)
|
||||
res, _ = doHttpRequest(req, c)
|
||||
assert.Equal(401, res.StatusCode, "auth callback with invalid provider shouldn't be authorised")
|
||||
|
||||
// Should redirect valid request
|
||||
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
|
||||
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
|
||||
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":google:http://redirect")
|
||||
c = MakeCSRFCookie(req, nonce)
|
||||
res, _ = doHttpRequest(req, c)
|
||||
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
|
||||
require.Equal(307, res.StatusCode, "valid auth callback should be allowed")
|
||||
|
||||
fwd, _ := res.Location()
|
||||
assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url")
|
||||
@@ -360,17 +393,17 @@ func TestServerRouteHost(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should block any request
|
||||
req := newHttpRequest("GET", "https://example.com/", "/")
|
||||
req := newHTTPRequest("GET", "https://example.com/")
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
||||
|
||||
// Should allow matching request
|
||||
req = newHttpRequest("GET", "https://api.example.com/", "/")
|
||||
req = newHTTPRequest("GET", "https://api.example.com/")
|
||||
res, _ = doHttpRequest(req, nil)
|
||||
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||
|
||||
// Should allow matching request
|
||||
req = newHttpRequest("GET", "https://sub8.example.com/", "/")
|
||||
req = newHTTPRequest("GET", "https://sub8.example.com/")
|
||||
res, _ = doHttpRequest(req, nil)
|
||||
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||
}
|
||||
@@ -386,12 +419,12 @@ func TestServerRouteMethod(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should block any request
|
||||
req := newHttpRequest("GET", "https://example.com/", "/")
|
||||
req := newHTTPRequest("GET", "https://example.com/")
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
||||
|
||||
// Should allow matching request
|
||||
req = newHttpRequest("PUT", "https://example.com/", "/")
|
||||
req = newHTTPRequest("PUT", "https://example.com/")
|
||||
res, _ = doHttpRequest(req, nil)
|
||||
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||
}
|
||||
@@ -441,12 +474,12 @@ func TestServerRouteQuery(t *testing.T) {
|
||||
}
|
||||
|
||||
// Should block any request
|
||||
req := newHttpRequest("GET", "https://example.com/", "/?q=no")
|
||||
req := newHTTPRequest("GET", "https://example.com/?q=no")
|
||||
res, _ := doHttpRequest(req, nil)
|
||||
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
|
||||
|
||||
// Should allow matching request
|
||||
req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123")
|
||||
req = newHTTPRequest("GET", "https://api.example.com/?q=test123")
|
||||
res, _ = doHttpRequest(req, nil)
|
||||
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
|
||||
}
|
||||
@@ -531,16 +564,17 @@ func newDefaultConfig() *Config {
|
||||
return config
|
||||
}
|
||||
|
||||
// TODO: replace with newHTTPRequest("GET", "http://example.com/"+uri)
|
||||
func newDefaultHttpRequest(uri string) *http.Request {
|
||||
return newHttpRequest("", "http://example.com/", uri)
|
||||
return newHTTPRequest("GET", "http://example.com"+uri)
|
||||
}
|
||||
|
||||
func newHttpRequest(method, dest, uri string) *http.Request {
|
||||
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
|
||||
p, _ := url.Parse(dest)
|
||||
func newHTTPRequest(method, target string) *http.Request {
|
||||
u, _ := url.Parse(target)
|
||||
r := httptest.NewRequest(method, target, nil)
|
||||
r.Header.Add("X-Forwarded-Method", method)
|
||||
r.Header.Add("X-Forwarded-Proto", p.Scheme)
|
||||
r.Header.Add("X-Forwarded-Host", p.Host)
|
||||
r.Header.Add("X-Forwarded-Uri", uri)
|
||||
r.Header.Add("X-Forwarded-Proto", u.Scheme)
|
||||
r.Header.Add("X-Forwarded-Host", u.Host)
|
||||
r.Header.Add("X-Forwarded-Uri", u.RequestURI())
|
||||
return r
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user