/**
Author: Mark George <mark.george@otago.ac.nz>
License: Zero-Clause BSD License
*/
package main
import (
"fmt"
"net/http"
"os"
"strconv"
)
// ServerCommand holds the command-line options of the HTTP test-server
// command. The struct tags (short/long/description/default) are consumed by
// the command-line parser — NOTE(review): presumably github.com/jessevdk/go-flags,
// given the Execute(args []string) error method; confirm.
type ServerCommand struct {
	// Port is the TCP port the server listens on.
	Port int `short:"p" long:"port" description:"Port to listen on." default:"8080"`
	// Routes holds "path|body|status" entries; the flag may be repeated.
	Routes []string `short:"r" long:"route" description:"Route which is made up of a path, a response body, and a response status, all separated by the pipe '|' character. Repeat for additional routes." default:"/|testing|200"`
	// Cors turns on handling of CORS preflight (OPTIONS) requests.
	Cors bool `short:"c" long:"cors" description:"Enable Cross Origin Resource Sharing (CORS)."`
	// Headers holds "name:value" response headers; the flag may be repeated.
	Headers []string `short:"H" long:"header" description:"A header to add to response in the form name:value. Repeat for additional headers."`
}
// serverOptions points at the options of the running server command. Execute
// assigns it before registering any handler, so request handlers can read the
// active configuration through it.
var serverOptions *ServerCommand
// Execute runs the server command once its options have been parsed.
//
// It stores opts in serverOptions, checks every --header value with
// parseHeader, and registers serverRequestHandler on http.DefaultServeMux
// for the path of every --route value. It then blocks in http.ListenAndServe
// on the configured port. args is not used.
//
// A malformed header or route, a route path given more than once, or a
// failure to start listening is reported and ends the process with exit
// status 1. http.ListenAndServe always returns a non-nil error, so the
// final return is not reached in practice.
func (opts *ServerCommand) Execute(args []string) error {
	serverOptions = opts

	// Validate custom headers up front so a typo is reported before the
	// server starts accepting requests.
	for _, header := range opts.Headers {
		if _, _, err := parseHeader(header); err != nil {
			fmt.Println("The following header is not in the correct format: " + header + "\nCorrect format is: 'name:value'")
			os.Exit(1)
		}
	}

	// Same output as before (message followed by a blank line), without the
	// redundant trailing newline in a Println argument that go vet reports.
	fmt.Printf("HTTP server listening on port %d\n\n", opts.Port)

	if len(opts.Routes) > 0 {
		fmt.Println("Routes:")
		// ServeMux panics when the same pattern is registered twice, so
		// duplicate paths are rejected here with a readable message instead.
		seen := make(map[string]bool, len(opts.Routes))
		for _, route := range opts.Routes {
			path, _, _, err := parseRoute(route)
			if err != nil {
				fmt.Println("The following route is not in the correct format: " + route + "\nCorrect format is: \"/path|response body|status code\"")
				os.Exit(1)
			}
			if seen[path] {
				fmt.Println("The following route path is defined more than once: " + path)
				os.Exit(1)
			}
			seen[path] = true
			fmt.Println(" " + path)
			http.HandleFunc(path, serverRequestHandler)
		}
	}

	if err := http.ListenAndServe(":"+strconv.Itoa(opts.Port), nil); err != nil {
		// Report the underlying cause too; a busy port is only one possibility.
		fmt.Fprintf(os.Stderr, "Could not start server. Is port %d available?\n%v\n", opts.Port, err)
		os.Exit(1)
	}
	return nil
}
func serverRequestHandler(rsp http.ResponseWriter, req *http.Request) {
path := req.RequestURI
body := bodies[path]
code := statuses[path]
showRequest(req)
// CORS
if serverOptions.Cors && req.Method == http.MethodOptions {
if globalOptions.Verbose {
fmt.Println("CORS preflight")
}
if origin := req.Header.Get("Origin"); origin != "" {
rsp.Header().Set("Access-Control-Allow-Origin", origin)
if headers := req.Header.Get("Access-Control-Request-Headers"); headers != "" {
rsp.Header().Set("Access-Control-Allow-Headers", headers)
}
if method := req.Header.Get("Access-Control-Request-Method"); method != "" {
rsp.Header().Set("Access-Control-Allow-Method", method)
}
rsp.WriteHeader(200)
return
} else {
fmt.Println("WARNING - CORS preflight did not contain 'Origin' header!")
}
}
// add custom headers if provided
if len(serverOptions.Headers) > 0 {
for name, value := range headers {
if globalOptions.Verbose {
fmt.Println("Adding response header - " + name + ": " + value)
}
rsp.Header().Set(name, value)
}
}
// add status line
rsp.WriteHeader(code)
// add body
fmt.Fprintf(rsp, body)
}