Newer
Older
httpcat / src / server.go
/**
  Author: Mark George <mark.george@otago.ac.nz>
  License: Zero-Clause BSD License: https://opensource.org/license/0bsd/
*/

package main

import (
	"fmt"
	"net/http"
	"os"
	"strconv"
)

// ServerCommand holds the options of the HTTP server command. The struct tags
// (short/long/description/default) describe each option to the command-line
// parser — presumably github.com/jessevdk/go-flags; confirm against main.
//
// Execute stores a pointer to the parsed value in serverOptions so that
// serverRequestHandler can consult Cors and Headers on every request.
type ServerCommand struct {
	// Port is the TCP port passed to http.ListenAndServe (tag default: 8080).
	Port    int      `short:"p" long:"port" description:"Port to listen on." default:"8080"`
	// Routes are raw "path|body|content type|status" specs; Execute validates
	// each with parseRoute and registers its path. When empty, the server runs
	// in wildcard mode.
	Routes  []string `short:"r" long:"route" description:"Add a route which is made up of a path, a response body, a content type, and a response status code, all separated by the pipe '|' character.  Repeat for additional routes.  Example '/|Testing 123|text/plain|200'."`
	// Cors enables the CORS preflight responder and Origin reflection in
	// serverRequestHandler.
	Cors    bool     `short:"c" long:"cors" description:"Enable Cross Origin Resource Sharing (CORS).  Allows all requested methods and headers."`
	// Headers are raw "name|value" specs; Execute validates each with parseHeader.
	Headers []string `short:"H" long:"header" description:"A header to add to response in the form name|value. Use multiple times for multiple headers."`
}

// serverOptions points at the ServerCommand being executed. Execute assigns it
// before registering any handler, so serverRequestHandler can read the CORS and
// custom-header settings. It is nil until Execute runs.
var serverOptions *ServerCommand

// Execute runs the server command. NOTE(review): presumably invoked by the
// command-line parser through its Commander interface; args (remaining
// positional arguments) are ignored.
//
// It performs these steps in order:
//   - publishes opts via serverOptions;
//   - validates every custom header;
//   - registers one handler per configured route, or a single catch-all
//     handler on "/" (wildcard mode) when no routes were given;
//   - serves HTTP on opts.Port.
//
// Malformed headers or routes, and a listener that cannot start, terminate the
// process with exit status 1 rather than returning an error.
func (opts *ServerCommand) Execute(args []string) error {

	serverOptions = opts

	// Validate custom headers up front so a malformed one fails fast.
	// NOTE(review): parseHeader's results are discarded here; the handler reads
	// the package-level headers map, presumably filled by parseHeader — confirm.
	for _, header := range opts.Headers {
		if _, _, err := parseHeader(header); err != nil {
			fmt.Println("The following header is not in the correct format: " + header + "\nCorrect format is: 'name|value'")
			os.Exit(1)
		}
	}

	fmt.Println("HTTP server listening on port " + strconv.Itoa(opts.Port) + "\n")

	if len(opts.Routes) == 0 {
		// wildcard mode: one catch-all handler; serverRequestHandler picks 200/204
		http.HandleFunc("/", serverRequestHandler)
	} else {
		fmt.Println("Routes:")
		for _, route := range opts.Routes {
			path, err := parseRoute(route)
			if err != nil {
				fmt.Println("The following route is not in the correct format: " + route + "\nCorrect format is: \"/path|response body|content type|status code\"")
				os.Exit(1)
			}
			fmt.Println("  " + path)
			http.HandleFunc(path, serverRequestHandler)
		}
		fmt.Println()
	}

	// ListenAndServe blocks while serving and only returns on failure. Report
	// the actual cause: "port in use" is only one possibility (permission
	// denied on a privileged port, invalid port number, ...).
	if err := http.ListenAndServe(":"+strconv.Itoa(opts.Port), nil); err != nil {
		fmt.Fprintf(os.Stderr, "Could not start server.  Is port %d available?\nError: %v\n", opts.Port, err)
		os.Exit(1)
	}

	return nil
}

// serverRequestHandler serves every registered path. The response status, body
// and content type come from the package-level statuses, bodies and
// contentTypes maps, keyed by request path.
//
// Status selection:
//   - a configured route uses its stored status;
//   - wildcard mode (statuses is empty, i.e. no routes configured): 200 for
//     GET, 204 for any other method;
//   - otherwise the path has no route and gets 404.
//
// The lookup key is req.URL.Path, not req.RequestURI. RequestURI is the raw
// request target: it keeps the query string (and scheme/host for
// absolute-form targets). With it, "/path?x=1" missed its route and returned
// 404 even though the ServeMux had dispatched it here.
//
// With --cors enabled, CORS preflights are answered by handleCORSPreflight.
// Other requests that carry an Origin header get it reflected in
// Access-Control-Allow-Origin.
func serverRequestHandler(rsp http.ResponseWriter, req *http.Request) {
	path := req.URL.Path
	body := bodies[path]
	contentType := contentTypes[path]
	code := statuses[path]

	if code == 0 {
		if len(statuses) == 0 {
			// wildcard mode: no routes were configured, so accept any path
			if req.Method == http.MethodGet {
				code = http.StatusOK
			} else {
				code = http.StatusNoContent
			}
		} else {
			// not in wildcard mode so this is a bad path
			code = http.StatusNotFound
		}
	}

	showRequest(req)

	// CORS preflight; an OPTIONS request that is not a CORS preflight falls
	// through to the normal response below.
	if serverOptions.Cors && req.Method == http.MethodOptions && handleCORSPreflight(rsp, req) {
		return
	}

	// add custom headers if provided
	// NOTE(review): iterates the package-level headers map (populated outside
	// this function); map order is random, so verbose output order varies.
	if len(serverOptions.Headers) > 0 {
		for name, value := range headers {
			if globalOptions.Verbose {
				fmt.Println("Adding response header - " + name + ": " + value)
			}
			rsp.Header().Set(name, value)
		}
	}

	// reflect the request Origin for non-preflight CORS requests
	if serverOptions.Cors && req.Method != http.MethodOptions {
		if origin := req.Header.Get("Origin"); origin != "" {
			if globalOptions.Verbose {
				fmt.Println("Adding CORS response header - Access-Control-Allow-Origin: " + origin)
			}
			rsp.Header().Set("Access-Control-Allow-Origin", origin)
		}
	}

	// add content type if specified
	if contentType != "" {
		rsp.Header().Set("Content-Type", contentType)
	}

	// status line, then body
	rsp.WriteHeader(code)
	fmt.Fprint(rsp, body)
}

// handleCORSPreflight answers a CORS preflight (an OPTIONS request whose
// Sec-Fetch-Mode header is "cors") and reports whether it wrote a response.
// It returns false, writing nothing, when the request is not a preflight.
//
// For a preflight it reflects the Origin, and any requested headers and
// method, back as allowed, with a 600-second Max-Age and status 200. A
// preflight without an Origin header gets a 400 text/plain error.
func handleCORSPreflight(rsp http.ResponseWriter, req *http.Request) bool {
	// was the OPTIONS request actually a CORS request?
	if req.Header.Get("Sec-Fetch-Mode") != "cors" {
		return false
	}

	if globalOptions.Verbose {
		fmt.Println("Responding to CORS preflight")
	}

	origin := req.Header.Get("Origin")
	if origin == "" {
		if globalOptions.Verbose {
			fmt.Println("CORS preflight did not contain 'Origin' header!")
		}
		rsp.Header().Set("Content-Type", "text/plain")
		rsp.WriteHeader(http.StatusBadRequest)
		fmt.Fprint(rsp, "Origin header was not provided")
		return true
	}

	if globalOptions.Verbose {
		fmt.Println("Adding CORS response header - Access-Control-Allow-Origin: " + origin)
	}
	// allow the origin
	rsp.Header().Set("Access-Control-Allow-Origin", origin)

	// allow any requested headers (named to avoid shadowing the package-level headers map)
	if requested := req.Header.Get("Access-Control-Request-Headers"); requested != "" {
		if globalOptions.Verbose {
			fmt.Println("Adding CORS response header - Access-Control-Allow-Headers: " + requested)
		}
		rsp.Header().Set("Access-Control-Allow-Headers", requested)
	}

	// allow any requested methods
	if method := req.Header.Get("Access-Control-Request-Method"); method != "" {
		if globalOptions.Verbose {
			fmt.Println("Adding CORS response header - Access-Control-Allow-Methods: " + method)
		}
		rsp.Header().Set("Access-Control-Allow-Methods", method)
	}

	if globalOptions.Verbose {
		fmt.Println("Adding CORS response header - Access-Control-Max-Age: " + "600")
	}
	// allow browser to reuse the permissions for 10 minutes
	rsp.Header().Set("Access-Control-Max-Age", "600")

	rsp.WriteHeader(http.StatusOK)
	return true
}