package main

import (
	"context"
	"crypto/rsa"
	"errors"
	"fmt"
	"io"
	"regexp"

	"github.com/fastly/compute-sdk-go/fsthttp"
	"github.com/golang-jwt/jwt/v5"
	"net/url"
	"strings"
)

// config holds the tunable settings of this fiddle. Edit the values below to
// change how JWT verification failures are handled.
var config = map[string]string{
	// pathInvalid controls the check of the request path against the token's
	// path claim. 'block' fails verification on a mismatch; any other value
	// (e.g. 'anonymous') skips the check.
	// "pathInvalid": "anonymous",
	"pathInvalid": "block",

	// timeInvalid controls the check for a missing exp claim. 'block' fails
	// verification when it is absent; any other value (e.g. 'anonymous')
	// skips the check.
	"timeInvalid": "block",

	// anonAccess selects what happens when JWT verification fails:
	//   deny  - redirect to /login, with the return_to query parameter set to
	//           the path and query string of the original request
	//   allow - set the request header auth-state: anonymous and fetch through
	"anonAccess": "allow",
}

// publicKeyString is the PEM-encoded public key whose private counterpart
// signs the tokens used in the tests. It is an RSA key intended for RS512
// signatures and can be generated with ssh-keygen and openssl like:
// ssh-keygen -t rsa -b 4096 -m PEM -E SHA512 -f jwtRS512.key -N ""
// openssl rsa -in jwtRS512.key -pubout -outform PEM -out jwtRS512.key.pub
// cat jwtRS512.key # private key (use this for signing, don't share this!)
// cat jwtRS512.key.pub # public key (pasted below)
//
// The literal starts directly with the BEGIN marker: backslashes are not
// escapes inside a Go raw string, so a leading `\` would become part of the
// PEM data.
const publicKeyString = `-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAsxCH4oVozTOx56a0MX8w
wIKqXP3LahmCsMYGdoimoFfaU7+820Ww2zNHuo4dJHLq+zwBOzaCxC9rtCNKiCNR
+n4Oy0zEdV+4nbSMIb1tGcIJOQCXtKrM/y+Dd0dXZdYrdLBkqI7VRHuIdSKQWIAu
jq3W58U4obXWSsTWj40PPN5tgd6yh97qP0sqqVZvmBqhxFmfmpREn0dXhUSKLsUR
3ut84fhsHI1LHyB5I+nh8OSMRuWFwm48+xaDrA2ZDvWQFX1/A9zY0amDUeEGzqbF
MyJB/9TG8OIDdHGf7QsW0W5sa6LwLtzna0yTxs5T3HcL4QBG9ro8w0nJCHGrqtA6
D1uiiK3h8iHgYISYRbVSQwjRZHYg/x7j9glf3xzpdmDzgenms1zH+o3tWiUKMj+m
dv0V71r1lN1KE7l19kLchi0+Cmf0maMqborWseOjZSI3wK9aZ0lOVQOfIrO2Y5bY
whd77Q5STV0KqXsCD11KTKcHUrzYndP/4RYfLlaskN7J9ZGAvdDZ3ZIQb4BngEOb
hpzdeiIa2cn/rfyw2K5dzgiglyGOUDDlfiY+5rbS6J2IIHibX+/N+g+cdA6oMFGV
RSNbx72cVQ0viiMAlremsrkqPIBIw1r+XM6PtR3CWDUegHtuYd+/6IQC4Q+JO4jE
NhE9JVjIx0OSclteIP2SnJ0CAwEAAQ==
-----END PUBLIC KEY-----`

// publicKey is the parsed form of the public key defined above
var publicKey, _ = jwt.ParseRSAPublicKeyFromPEM([]byte(publicKeyString))

// getJwtTokenFromRequest returns the URL-decoded value of the "auth" cookie
// found in the request's Cookie header. It returns "" when the header has no
// "auth" cookie or when the cookie value cannot be URL-decoded.
func getJwtTokenFromRequest(r *fsthttp.Request) string {
	for _, cookie := range strings.Split(r.Header.Get("Cookie"), ";") {
		// Cookie pairs are conventionally separated by "; ", so every pair
		// after the first carries leading whitespace that must be trimmed
		// before the name is compared.
		name, value, found := strings.Cut(strings.TrimSpace(cookie), "=")
		if !found || name != "auth" {
			// Pairs without "=" have no value to read; skip them instead of
			// indexing past the end of the split result.
			continue
		}

		token, err := url.QueryUnescape(value)
		if err != nil {
			return ""
		}
		return token
	}

	return ""
}

// removeJwtTokenFromRequest strips the "auth" cookie from the request's
// Cookie header so the token is not forwarded to the backend. All other
// cookies are kept in their original order; when no cookie remains, the
// Cookie header is removed entirely.
func removeJwtTokenFromRequest(r *fsthttp.Request) {
	cookiesHeader := r.Header.Get("Cookie")
	if cookiesHeader == "" {
		return
	}

	var otherCookies []string
	for _, cookie := range strings.Split(cookiesHeader, ";") {
		// Trim the whitespace that follows each ";" separator; without this
		// an "auth" cookie in any position but the first would be missed.
		cookie = strings.TrimSpace(cookie)
		if cookie == "" {
			continue
		}
		if name, _, _ := strings.Cut(cookie, "="); name == "auth" {
			continue
		}
		otherCookies = append(otherCookies, cookie)
	}

	if len(otherCookies) == 0 {
		// Avoid sending an empty Cookie header to the backend.
		r.Header.Del("Cookie")
		return
	}
	r.Header.Set("Cookie", strings.Join(otherCookies, "; "))
}

// validateJwtSignature extracts the JWT from the request's auth cookie,
// verifies its RSA signature against key and applies the additional checks
// enabled in config.
//
// It returns the parsed token on success. It returns an error when:
//   - key is nil or the request carries no token,
//   - jwt.Parse rejects the token (bad signature, non-RSA algorithm, expired
//     or not-yet-valid timestamps),
//   - config["timeInvalid"] is "block" and the token has no exp claim,
//   - config["pathInvalid"] is "block" and the request path does not match
//     the token's path claim.
func validateJwtSignature(r *fsthttp.Request, key *rsa.PublicKey) (*jwt.Token, error) {
	if key == nil {
		return nil, errors.New("no public key available")
	}

	token := getJwtTokenFromRequest(r)
	if token == "" {
		return nil, errors.New("no JWT in request")
	}

	// This parses and validates the payload string.
	payload, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
		// Only accept RSA signatures so a token cannot pick another algorithm.
		if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}

		return key, nil
	})
	if err != nil {
		return nil, err
	}

	// jwt.Parse() already verifies the "expires" and "not before" timestamps
	// against the current time, but it accepts tokens without an exp claim.
	// GetExpirationTime reports a missing claim as a nil time with a nil
	// error, so both results have to be checked.
	if config["timeInvalid"] == "block" {
		exp, err := payload.Claims.GetExpirationTime()
		if err != nil || exp == nil {
			return nil, errors.New("exp claim not present")
		}
	}

	// Check path constraint.
	if config["pathInvalid"] == "block" {
		claims, _ := payload.Claims.(jwt.MapClaims)
		if path, _ := claims["path"].(string); path != "" {
			// Treat the claim as a glob: every character is literal except
			// "*", which matches any run of characters. The pattern is
			// anchored so the claim has to cover the whole request path;
			// otherwise a claim such as "/a" would also match "/x/a/y".
			pattern := "^" + strings.ReplaceAll(regexp.QuoteMeta(path), `\*`, `.*`) + "$"

			regex, err := regexp.Compile(pattern)
			if err != nil {
				return nil, fmt.Errorf("invalid path claim %q: %w", path, err)
			}
			if !regex.MatchString(r.URL.Path) {
				return nil, errors.New("path claim not matched")
			}
		}
	}

	return payload, nil
}

// buildRedirectResponse returns a 307 Temporary Redirect to /login, resolved
// against reqURL so scheme and host are inherited from the original request.
// The path and query string of reqURL are passed along in the return_to query
// parameter. A non-empty message is exposed in the fastly-jwt-error response
// header.
func buildRedirectResponse(reqURL *url.URL, message string) *fsthttp.Response {
	// The parameter is deliberately not named "url": that would shadow the
	// net/url package inside this function.
	location := reqURL.ResolveReference(&url.URL{Path: "/login"})

	// Use the escaped path so percent-encoded characters in the original
	// request survive the round trip through return_to.
	returnTo := reqURL.EscapedPath()
	if reqURL.RawQuery != "" {
		returnTo += "?" + reqURL.RawQuery
	}

	q := url.Values{}
	q.Set("return_to", returnTo)
	location.RawQuery = q.Encode()

	h := fsthttp.NewHeader()
	h.Set("Location", location.String())
	if message != "" {
		h.Set("fastly-jwt-error", message)
	}

	return &fsthttp.Response{
		Header:     h,
		StatusCode: fsthttp.StatusTemporaryRedirect,
	}
}

// handleRequest authenticates the incoming request from its JWT cookie,
// forwards a clone of it to the "origin_0" backend and returns the response
// to send to the client.
//
// When the token verifies, the backend request carries auth-state:
// authenticated plus auth-userid, auth-groups, auth-name and auth-is-admin
// headers taken from the token claims, and the auth cookie is removed.
// When verification fails, the client is redirected to /login if
// config["anonAccess"] is "deny"; otherwise the request is forwarded with
// auth-state: anonymous.
//
// If the token has a "tag" claim, the backend response must list that tag in
// its surrogate-key header, otherwise the client is redirected to /login.
// A failed backend fetch yields a 502 Bad Gateway response.
func handleRequest(ctx context.Context, r *fsthttp.Request) *fsthttp.Response {
	// Build the backend request from a clone, reusing the client's body.
	req := r.Clone()
	req.Body = r.Body

	fmt.Println("Checking JWT...")

	var requiredTag string
	token, err := validateJwtSignature(r, publicKey)
	if err != nil {
		if config["anonAccess"] == "deny" {
			fmt.Printf("Response redirect, %v\n", err.Error())
			return buildRedirectResponse(r.URL, err.Error())
		}

		fmt.Printf("Allow anonymous access, %v\n", err.Error())
		req.Header.Set("auth-state", "anonymous")
		// Drop identity headers supplied by the client so an anonymous
		// request cannot present forged identity details to the backend.
		for _, name := range []string{"auth-userid", "auth-groups", "auth-name", "auth-is-admin"} {
			req.Header.Del(name)
		}
	} else {
		// We passed all the verification.
		fmt.Println("JWT Token verified successfully!")

		// Reading from a nil map is safe, so a failed assertion simply
		// yields empty claim values.
		claims, _ := token.Claims.(jwt.MapClaims)

		req.Header.Set("auth-state", "authenticated")
		userid, _ := claims["uid"].(string)
		req.Header.Set("auth-userid", userid)
		groups, _ := claims["groups"].(string)
		req.Header.Set("auth-groups", groups)
		name, _ := claims["name"].(string)
		req.Header.Set("auth-name", name)

		adminValue := "0"
		if admin, _ := claims["admin"].(bool); admin {
			adminValue = "1"
		}
		req.Header.Set("auth-is-admin", adminValue)
		requiredTag, _ = claims["tag"].(string)

		// If the token was valid, we don't want to pass it through to backend.
		removeJwtTokenFromRequest(req)
	}

	// Perform fetch.
	resp, err := req.Send(ctx, "origin_0")
	if err != nil {
		fmt.Printf("Backend fetch failed, %v\n", err)
		return &fsthttp.Response{
			Header:     fsthttp.NewHeader(),
			StatusCode: fsthttp.StatusBadGateway,
		}
	}

	// Check tag availability in response.
	if requiredTag != "" && !hasSurrogateKey(resp.Header.Get("surrogate-key"), requiredTag) {
		// The backend response is discarded, so release its body.
		if resp.Body != nil {
			resp.Body.Close()
		}
		return buildRedirectResponse(req.URL, "Required tag missing")
	}

	return resp
}

// hasSurrogateKey reports whether tag is one of the whitespace-separated keys
// in the given surrogate-key header value.
func hasSurrogateKey(header, tag string) bool {
	for _, key := range strings.Fields(header) {
		if key == tag {
			return true
		}
	}
	return false
}

// main registers the request handler with the Fastly Compute runtime. Each
// incoming request is passed to handleRequest and the resulting response
// (headers, status code and body) is written back to the client.
func main() {
	fsthttp.ServeFunc(func(ctx context.Context, w fsthttp.ResponseWriter, r *fsthttp.Request) {
		resp := handleRequest(ctx, r)
		if resp == nil {
			// No response to relay; answer with a gateway error rather than
			// dereferencing a nil response.
			w.WriteHeader(fsthttp.StatusBadGateway)
			return
		}

		w.Header().Reset(resp.Header)
		w.WriteHeader(resp.StatusCode)
		if resp.Body == nil {
			return
		}
		defer resp.Body.Close()

		// The status line is already sent at this point, so a failed copy
		// can only be logged.
		if _, err := io.Copy(w, resp.Body); err != nil {
			fmt.Printf("Error streaming response body, %v\n", err)
		}
	})
}