package main
import (
"context"
"crypto/rsa"
"errors"
"fmt"
"io"
"regexp"
"github.com/fastly/compute-sdk-go/fsthttp"
"github.com/golang-jwt/jwt/v5"
"net/url"
"strings"
)
// config holds the access-policy switches read by the request pipeline:
//   - "anonAccess": when "deny", requests without a valid JWT are redirected
//     to /login; any other value forwards them to the origin as anonymous.
//   - "pathInvalid": when "block", validateJwtSignature checks the token's
//     path claim against the request path.
//   - "timeInvalid": when "block", validateJwtSignature performs its exp
//     claim presence check.
var config = map[string]string{
	"anonAccess":  "allow",
	"pathInvalid": "block",
	"timeInvalid": "block",
}
// publicKeyString is the PEM-encoded ("BEGIN PUBLIC KEY") RSA public key used
// to verify JWT signatures. The PEM text starts on the first character of the
// literal; Go raw strings have no line-continuation escape, so a leading
// backslash would become part of the data.
const publicKeyString = `-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAsxCH4oVozTOx56a0MX8w
wIKqXP3LahmCsMYGdoimoFfaU7+820Ww2zNHuo4dJHLq+zwBOzaCxC9rtCNKiCNR
+n4Oy0zEdV+4nbSMIb1tGcIJOQCXtKrM/y+Dd0dXZdYrdLBkqI7VRHuIdSKQWIAu
jq3W58U4obXWSsTWj40PPN5tgd6yh97qP0sqqVZvmBqhxFmfmpREn0dXhUSKLsUR
3ut84fhsHI1LHyB5I+nh8OSMRuWFwm48+xaDrA2ZDvWQFX1/A9zY0amDUeEGzqbF
MyJB/9TG8OIDdHGf7QsW0W5sa6LwLtzna0yTxs5T3HcL4QBG9ro8w0nJCHGrqtA6
D1uiiK3h8iHgYISYRbVSQwjRZHYg/x7j9glf3xzpdmDzgenms1zH+o3tWiUKMj+m
dv0V71r1lN1KE7l19kLchi0+Cmf0maMqborWseOjZSI3wK9aZ0lOVQOfIrO2Y5bY
whd77Q5STV0KqXsCD11KTKcHUrzYndP/4RYfLlaskN7J9ZGAvdDZ3ZIQb4BngEOb
hpzdeiIa2cn/rfyw2K5dzgiglyGOUDDlfiY+5rbS6J2IIHibX+/N+g+cdA6oMFGV
RSNbx72cVQ0viiMAlremsrkqPIBIw1r+XM6PtR3CWDUegHtuYd+/6IQC4Q+JO4jE
NhE9JVjIx0OSclteIP2SnJ0CAwEAAQ==
-----END PUBLIC KEY-----`
// publicKey is the RSA key used to verify JWT signatures, parsed once at
// startup from publicKeyString. A malformed embedded key is a build-time
// mistake, so startup fails loudly rather than continuing with a nil key.
var publicKey = mustParseRSAPublicKey(publicKeyString)

// mustParseRSAPublicKey parses a PEM-encoded RSA public key and panics if it
// cannot be parsed. It is intended only for package-level initialization.
func mustParseRSAPublicKey(pemString string) *rsa.PublicKey {
	key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(pemString))
	if err != nil {
		panic(fmt.Sprintf("parsing embedded RSA public key: %v", err))
	}
	return key
}
// getJwtTokenFromRequest returns the URL-decoded value of the first "auth"
// cookie in r's Cookie header. It returns "" when there is no such cookie or
// when that cookie's value cannot be URL-decoded.
func getJwtTokenFromRequest(r *fsthttp.Request) string {
	for _, pair := range strings.Split(r.Header.Get("Cookie"), ";") {
		// Cookie pairs are separated by "; ", so every pair after the first
		// carries leading whitespace that must be trimmed before the name
		// comparison.
		name, value, found := strings.Cut(strings.TrimSpace(pair), "=")
		if !found || name != "auth" {
			// A valueless cookie cannot carry a token; skip it instead of
			// indexing past the end of the split result.
			continue
		}
		auth, err := url.QueryUnescape(value)
		if err != nil {
			return ""
		}
		return auth
	}
	return ""
}
// removeJwtTokenFromRequest strips every "auth" cookie from r's Cookie header
// so the JWT is not forwarded upstream. The remaining cookies keep their
// order and are re-joined with "; ". If no cookies remain, the Cookie header
// is deleted rather than left empty.
func removeJwtTokenFromRequest(r *fsthttp.Request) {
	header := r.Header.Get("Cookie")
	if header == "" {
		return
	}
	var kept []string
	for _, pair := range strings.Split(header, ";") {
		// Trim first: pairs after the first start with a space, which would
		// otherwise hide an "auth" cookie from the name check.
		pair = strings.TrimSpace(pair)
		if pair == "" {
			continue
		}
		if name, _, _ := strings.Cut(pair, "="); name == "auth" {
			continue
		}
		kept = append(kept, pair)
	}
	if len(kept) == 0 {
		r.Header.Del("Cookie")
		return
	}
	r.Header.Set("Cookie", strings.Join(kept, "; "))
}
// validateJwtSignature extracts the JWT from r's "auth" cookie, verifies its
// RSA signature with key, and applies the policy checks enabled in config.
//
// On success it returns the parsed token (with jwt.MapClaims claims). It
// returns an error when:
//   - no token is present, or key is nil;
//   - jwt.Parse rejects the token (bad signature, non-RSA algorithm, or a
//     failed standard-claim check);
//   - config["timeInvalid"] == "block" and the token has no exp claim;
//   - config["pathInvalid"] == "block" and the token's non-empty path claim
//     does not match the whole request path (see pathClaimRegexp).
func validateJwtSignature(r *fsthttp.Request, key *rsa.PublicKey) (*jwt.Token, error) {
	raw := getJwtTokenFromRequest(r)
	if raw == "" {
		return nil, errors.New("no JWT in request")
	}
	if key == nil {
		// Fail closed instead of handing a nil key to the RSA verifier.
		return nil, errors.New("no JWT verification key configured")
	}
	payload, err := jwt.Parse(raw, func(token *jwt.Token) (interface{}, error) {
		// key is an RSA public key; refuse any token not signed with an RSA
		// method.
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return key, nil
	})
	if err != nil {
		return nil, err
	}
	claims, ok := payload.Claims.(jwt.MapClaims)
	if !ok {
		return nil, fmt.Errorf("unexpected JWT claims type %T", payload.Claims)
	}
	if config["timeInvalid"] == "block" {
		// An absent exp claim is reported as (nil, nil) rather than as an
		// error, so presence must be checked on the returned value.
		exp, err := claims.GetExpirationTime()
		if err != nil || exp == nil {
			return nil, errors.New("exp claim not present")
		}
	}
	if config["pathInvalid"] == "block" {
		if path, _ := claims["path"].(string); path != "" {
			matcher, err := pathClaimRegexp(path)
			if err != nil {
				return nil, fmt.Errorf("invalid path claim: %w", err)
			}
			if !matcher.MatchString(r.URL.Path) {
				return nil, errors.New("path claim not matched")
			}
		}
	}
	return payload, nil
}

// pathClaimRegexp converts a path claim into an anchored regular expression.
// Every character matches literally except "*", which matches any run of
// characters (including "/"). The match is anchored at both ends, so a claim
// of "/members/*" covers "/members/a" but not "/other/members/a".
func pathClaimRegexp(glob string) (*regexp.Regexp, error) {
	// QuoteMeta renders "*" as `\*`; turn exactly those back into wildcards.
	pattern := strings.ReplaceAll(regexp.QuoteMeta(glob), `\*`, ".*")
	return regexp.Compile("^" + pattern + "$")
}
// buildRedirectResponse returns a 307 Temporary Redirect to /login on the same
// scheme and host as reqURL. The original escaped path and query are passed
// in the return_to query parameter so the login page can send the user back.
// A non-empty message is exposed to the client in the fastly-jwt-error header.
func buildRedirectResponse(reqURL *url.URL, message string) *fsthttp.Response {
	// Resolving an absolute-path reference keeps reqURL's scheme, host and
	// userinfo but replaces its path and drops its query and fragment.
	location := reqURL.ResolveReference(&url.URL{Path: "/login"})

	// Use the escaped path so percent-encoded bytes (e.g. %2F) survive the
	// round trip through return_to.
	returnTo := reqURL.EscapedPath()
	if reqURL.RawQuery != "" {
		returnTo += "?" + reqURL.RawQuery
	}
	location.RawQuery = url.Values{"return_to": {returnTo}}.Encode()

	h := fsthttp.NewHeader()
	h.Set("Location", location.String())
	if message != "" {
		h.Set("fastly-jwt-error", message)
	}
	return &fsthttp.Response{
		Header:     h,
		StatusCode: fsthttp.StatusTemporaryRedirect,
	}
}
func handleRequest(ctx context.Context, r *fsthttp.Request) *fsthttp.Response {
req := r.Clone()
req.Body = r.Body
fmt.Println("Checking JWT...")
token, err := validateJwtSignature(r, publicKey)
var requiredTag = ""
if err == nil {
fmt.Println("JWT Token verified successfully!")
req.Header.Set("auth-state", "authenticated")
userid, _ := token.Claims.(jwt.MapClaims)["uid"].(string)
req.Header.Set("auth-userid", userid)
groups, _ := token.Claims.(jwt.MapClaims)["groups"].(string)
req.Header.Set("auth-groups", groups)
name, _ := token.Claims.(jwt.MapClaims)["name"].(string)
req.Header.Set("auth-name", name)
var adminValue string
admin, _ := token.Claims.(jwt.MapClaims)["admin"].(bool)
if admin {
adminValue = "1"
} else {
adminValue = "0"
}
req.Header.Set("auth-is-admin", adminValue)
requiredTag, _ = token.Claims.(jwt.MapClaims)["tag"].(string)
removeJwtTokenFromRequest(req)
} else {
if config["anonAccess"] == "deny" {
fmt.Printf("Response redirect, %v\n", err.Error())
return buildRedirectResponse(r.URL, err.Error())
}
fmt.Printf("Allow anonymous access, %v\n", err.Error())
req.Header.Set("auth-state", "anonymous")
}
resp, err := req.Send(ctx, "origin_0")
if requiredTag != "" {
surrogateKeyHeader := regexp.MustCompile(`\s`).Split(resp.Header.Get("surrogate-key"), -1)
requiredTagFound := false
for _, v := range surrogateKeyHeader {
if v == requiredTag {
requiredTagFound = true
break
}
}
if !requiredTagFound {
return buildRedirectResponse(req.URL, "Required tag missing")
}
}
return resp
}
// main registers the Compute request handler. It copies the status, headers
// and (if present) body of handleRequest's response to the client, then
// closes the body.
func main() {
	fsthttp.ServeFunc(func(ctx context.Context, w fsthttp.ResponseWriter, r *fsthttp.Request) {
		resp := handleRequest(ctx, r)
		w.Header().Reset(resp.Header)
		w.WriteHeader(resp.StatusCode)
		if resp.Body == nil {
			return
		}
		defer resp.Body.Close()
		// Headers are already sent, so a failed copy can only be logged.
		if _, err := io.Copy(w, resp.Body); err != nil {
			fmt.Printf("Copying response body failed, %v\n", err)
		}
	})
}