Newer
Older
httpcat / src / server.go
/**
  Author: Mark George <mark.george@otago.ac.nz>
  License: Zero-Clause BSD License
*/

package main

import (
	"fmt"
	"net/http"
	"os"
	"strconv"
)

// ServerCommand holds the command-line options for the HTTP server command.
// The struct tags (short/long/description/default) are consumed by the
// command-line parser; the Execute(args []string) error method suggests this
// is a go-flags command — NOTE(review): confirm the parser in use.
type ServerCommand struct {
	// Port is the TCP port to listen on (default 8080).
	Port    int      `short:"p" long:"port" description:"Port to listen on." default:"8080"`
	// Routes are "path|body|status" triples; the default serves "testing"
	// with status 200 at "/".
	Routes  []string `short:"r" long:"route" description:"Route which is made up of a path, a response body, and a response status, all separated by the pipe '|' character.  Repeat for additional routes." default:"/|testing|200"`
	// Cors enables answering CORS preflight requests.
	Cors    bool     `short:"c" long:"cors" description:"Enable Cross Origin Resource Sharing (CORS)."`
	// Headers are extra "name:value" response headers.
	Headers []string `short:"H" long:"header" description:"A header to add to response in the form name:value. Repeat for additional headers."`
}

// serverOptions is the ServerCommand currently being executed. Execute stores
// it before registering any route so that serverRequestHandler, whose
// signature is fixed by net/http, can read the active options.
var serverOptions *ServerCommand

// Execute runs the server command. It records opts for the request handler,
// checks every custom header and route, registers serverRequestHandler on
// http.DefaultServeMux for each route path, and then serves HTTP on opts.Port.
//
// A malformed header, a malformed route, a route path given more than once,
// or a failure to start listening is reported on stderr and terminates the
// process with exit status 1. Because http.ListenAndServe always returns a
// non-nil error, Execute does not return in practice.
//
// args is unused; it is part of the command interface's signature.
func (opts *ServerCommand) Execute(args []string) error {

	// The handler reads its configuration from this package-level variable
	// because http.HandleFunc gives it no other way to receive it.
	serverOptions = opts

	// parseHeader is called once per header here; whatever it records
	// (presumably the package-level headers map — TODO confirm) happens now,
	// and a malformed header aborts startup.
	for _, header := range opts.Headers {
		if _, _, err := parseHeader(header); err != nil {
			fmt.Fprintln(os.Stderr, "The following header is not in the correct format: "+header+"\nCorrect format is: 'name:value'")
			os.Exit(1)
		}
	}

	fmt.Println("HTTP server listening on port " + strconv.Itoa(opts.Port) + "\n")

	if len(opts.Routes) > 0 {
		fmt.Println("Routes:")

		// http.HandleFunc panics when a pattern is registered twice, so
		// duplicates are detected here and reported as a normal error.
		registered := make(map[string]bool, len(opts.Routes))

		for _, route := range opts.Routes {
			path, _, _, err := parseRoute(route)
			if err != nil {
				fmt.Fprintln(os.Stderr, "The following route is not in the correct format: "+route+"\nCorrect format is: \"/path|response body|status code\"")
				os.Exit(1)
			}

			if registered[path] {
				fmt.Fprintln(os.Stderr, "The following route path is defined more than once: "+path)
				os.Exit(1)
			}
			registered[path] = true

			fmt.Println("  " + path)
			http.HandleFunc(path, serverRequestHandler)
		}
	}

	if err := http.ListenAndServe(":"+strconv.Itoa(opts.Port), nil); err != nil {
		// Include the underlying error: the port being taken is only one of
		// several possible causes (e.g. permission denied on low ports).
		fmt.Fprintf(os.Stderr, "Could not start server.  Is port %d available?\n%v\n", opts.Port, err)
		os.Exit(1)
	}

	return nil
}

func serverRequestHandler(rsp http.ResponseWriter, req *http.Request) {

	path := req.RequestURI
	body := bodies[path]
	code := statuses[path]

	showRequest(req)

	// CORS
	if serverOptions.Cors && req.Method == http.MethodOptions {

		if globalOptions.Verbose {
			fmt.Println("CORS preflight")
		}

		if origin := req.Header.Get("Origin"); origin != "" {

			rsp.Header().Set("Access-Control-Allow-Origin", origin)

			if headers := req.Header.Get("Access-Control-Request-Headers"); headers != "" {
				rsp.Header().Set("Access-Control-Allow-Headers", headers)
			}

			if method := req.Header.Get("Access-Control-Request-Method"); method != "" {
				rsp.Header().Set("Access-Control-Allow-Method", method)
			}

			rsp.WriteHeader(200)

			return

		} else {
			fmt.Println("WARNING - CORS preflight did not contain 'Origin' header!")
		}
	}

	// add custom headers if provided
	if len(serverOptions.Headers) > 0 {

		for name, value := range headers {

			if globalOptions.Verbose {
				fmt.Println("Adding response header - " + name + ": " + value)
			}

			rsp.Header().Set(name, value)
		}
	}

	// add status line
	rsp.WriteHeader(code)

	// add body
	fmt.Fprintf(rsp, body)

}