// https://developer.fastly.com/solutions/examples/validate-domain-on-query-string

package main

import (
	"context"
	"fmt"
	"io"
  "net/url"
  "regexp"
  "strings"

	"github.com/fastly/compute-sdk-go/fsthttp"
)

// BackendName is the name of the backend that validated requests are
// forwarded to. It must match a backend configured on the Fastly service.
const (
	BackendName = "origin_0"
)

// ValidDomain returns the allow-list of domains consulted by IsValidDomain.
// Each call builds a new slice, so callers may modify the result without
// affecting later calls.
func ValidDomain() []string {
	allowed := []string{
		"example.com",
		"fastly.com",
	}
	return allowed
}

func main() {
	fsthttp.ServeFunc(func(ctx context.Context, w fsthttp.ResponseWriter, r *fsthttp.Request) {
		// This requires your service to be configured with a backend
		// named "origin_0" and pointing to "https://http-me.glitch.me".

		// Parse the query string into a map of values; percent-encoded
		// strings are decoded. ParseQuery still returns the parameters it
		// could decode alongside its error, so a malformed pair is logged
		// rather than silently ignored, and validation continues with what
		// was parsed.
		q, err := url.ParseQuery(r.URL.RawQuery)
		if err != nil {
			fmt.Println("Unable to fully parse query string:", err)
		}

		// A missing "url" parameter is a client error.
		u, exist := q["url"]
		if !exist {
			fmt.Println("Unable to extract domain from query string.")
			// Generate a synthetic 400 status response.
			flush(&fsthttp.Response{
				StatusCode: fsthttp.StatusBadRequest,
				Body:       io.NopCloser(strings.NewReader("")),
			}, w)
			return
		}

		// Only the first "url" value is validated.
		if !IsValidDomain(u[0]) {
			fmt.Println("Invalid domain.")
			// Generate a synthetic 403 status response.
			flush(&fsthttp.Response{
				StatusCode: fsthttp.StatusForbidden,
				Body:       io.NopCloser(strings.NewReader("")),
			}, w)
			return
		}
		fmt.Println("Valid domain.")

		// Forward the validated request to the backend.
		resp, err := r.Send(ctx, BackendName)
		if err != nil {
			w.WriteHeader(fsthttp.StatusBadGateway)
			fmt.Fprintln(w, err.Error())
			return
		}

		// Write response to client.
		flush(resp, w)
	})
}

// validDomainURL matches an absolute http or https URL (scheme matched
// case-insensitively) and captures its authority: everything between "://"
// and the first "/", "?", "#", or the end of the string. It is compiled once
// at package scope instead of on every request.
var validDomainURL = regexp.MustCompile(`^(?i:https?)://([^/?#]*)`)

// IsValidDomain reports whether d is an absolute http or https URL whose
// authority equals, case-insensitively, one of the entries returned by
// ValidDomain. A trailing path is not required, so "https://example.com"
// and "https://example.com?x=1" are both accepted.
//
// Any string that is not such a URL, including the empty string, is reported
// as invalid rather than causing a panic. The authority is compared as a
// whole, so URLs that carry a port ("https://example.com:8080/") or userinfo
// ("https://example.com@evil.com/") are rejected.
func IsValidDomain(d string) bool {
	m := validDomainURL.FindStringSubmatch(d)
	if m == nil {
		// Not an absolute http(s) URL; the original indexed a nil slice here.
		return false
	}
	host := m[1]
	for _, valid := range ValidDomain() {
		// Host names are case-insensitive (RFC 3986 §3.2.2).
		if strings.EqualFold(host, valid) {
			return true
		}
	}
	return false
}

// flush writes resp to the client. It replaces the outgoing headers with
// resp.Header, writes resp.StatusCode, and then streams resp.Body.
//
// resp.Body is closed once it has been copied, whether it came from the
// backend or was constructed synthetically. If the copy fails, the status
// line has already been sent and cannot be changed, so the error is only
// logged.
func flush(resp *fsthttp.Response, w fsthttp.ResponseWriter) {
	defer resp.Body.Close()

	w.Header().Reset(resp.Header)
	w.WriteHeader(resp.StatusCode)
	if _, err := io.Copy(w, resp.Body); err != nil {
		fmt.Println("Error copying response body:", err)
	}
}