GitBucket
4.21.2
Toggle navigation
Snippets
Sign in
Files
Branches
1
Releases
Issues
1
Pull requests
Labels
Priorities
Milestones
Wiki
Forks
mark.george
/
httpcat
Browse code
Improve CORS handling
Closes
#1
1 parent
adaac98
commit
100c322f8ca391e0a1acb6a8c2a5d7fc0f29da1f
Mark George
authored
on 16 Apr 2022
Patch
Showing
2 changed files
src/parsing.go
src/server.go
Ignore Space
Show notes
View
src/parsing.go
/** Author: Mark George <mark.george@otago.ac.nz> License: Zero-Clause BSD License */ package main import ( "errors" "strconv" "strings" ) var ( headers map[string]string = make(map[string]string) bodies map[string]string = make(map[string]string) statuses map[string]int = make(map[string]int) ) func parseHeader(header string) (string, string, error) { parts := strings.Split(header, ":") if len(parts) < 2 { return "", "", errors.New("header is not in the correct format") } else { name := strings.TrimSpace(parts[0]) value := strings.TrimSpace(parts[1]) // value had a ':' in it (probably a URL) so append the remaining part if len(parts) == 3 { value = value + ":" + strings.TrimSpace(parts[2]) } headers[name] = value return name, value, nil } } func parseRoute(route string) (string, string, int, error) { parts := strings.Split(route, "|") if len(parts) != 3 { return "", "", 0, errors.New("route is not in the correct format") } else { path := parts[0] body := parts[1] status := parts[2] code, err := strconv.Atoi(status) if err != nil { return "", "", 0, errors.New("route is not in the correct format") } bodies[path] = body statuses[path] = code return path, body, code, nil } }
/** Author: Mark George <mark.george@otago.ac.nz> License: Zero-Clause BSD License */ package main import ( "errors" "strconv" "strings" ) var ( headers map[string]string = make(map[string]string) bodies map[string]string = make(map[string]string) statuses map[string]int = make(map[string]int) ) func parseHeader(header string) (string, string, error) { parts := strings.Split(header, ":") if len(parts) != 2 { return "", "", errors.New("header is not in the correct format") } else { name := parts[0] value := parts[1] headers[name] = value return name, value, nil } } func parseRoute(route string) (string, string, int, error) { parts := strings.Split(route, "|") if len(parts) != 3 { return "", "", 0, errors.New("route is not in the correct format") } else { path := parts[0] body := parts[1] status := parts[2] code, err := strconv.Atoi(status) if err != nil { return "", "", 0, errors.New("route is not in the correct format") } bodies[path] = body statuses[path] = code return path, body, code, nil } }
Ignore Space
Show notes
View
src/server.go
/** Author: Mark George <mark.george@otago.ac.nz> License: Zero-Clause BSD License */ package main import ( "fmt" "net/http" "os" "strconv" ) type ServerCommand struct { Port int `short:"p" long:"port" description:"Port to listen on." default:"8080"` Routes []string `short:"r" long:"route" description:"Route which is made up of a path, a response body, and a response status, all separated by the pipe '|' character. Repeat for additional routes." default:"/|testing|200"` Cors bool `short:"c" long:"cors" description:"Enable Cross Origin Resource Sharing (CORS)."` Headers []string `short:"H" long:"header" description:"A header to add to response in the form name:value. Repeat for additional headers."` } var ( serverOptions *ServerCommand ) func (opts *ServerCommand) Execute(args []string) error { serverOptions = opts // parse custom headers if provided if len(opts.Headers) > 0 { for _, header := range opts.Headers { if _, _, err := parseHeader(header); err != nil { fmt.Println("The following header is not in the correct format: " + header + "\nCorrect format is: 'name:value'") os.Exit(1) } } } fmt.Println("HTTP server listening on port " + strconv.Itoa(opts.Port) + "\n") if len(opts.Routes) > 0 { fmt.Println("Routes:") for _, route := range opts.Routes { if path, _, _, err := parseRoute(route); err == nil { fmt.Println(" " + path) http.HandleFunc(path, serverRequestHandler) } else { fmt.Println("The following route is not in the correct format: " + route + "\nCorrect format is: \"/path|response body|status code\"") os.Exit(1) } } } if err := http.ListenAndServe(":"+strconv.Itoa(opts.Port), nil); err != nil { fmt.Fprintf(os.Stderr, "Could not start server. Is port %d available?\n", opts.Port) os.Exit(1) } return nil } func serverRequestHandler(rsp http.ResponseWriter, req *http.Request) { path := req.RequestURI body := bodies[path] code := statuses[path] showRequest(req) // CORS preflight if serverOptions.Cors && req.Method == http.MethodOptions { // was the OPTIONS request actually a CORS request? if secFetchMode := req.Header.Get("Sec-Fetch-Mode"); secFetchMode == "cors" { if globalOptions.Verbose { fmt.Println("Responding to CORS preflight") } if origin := req.Header.Get("Origin"); origin != "" { if globalOptions.Verbose { fmt.Println("Adding CORS response header - Access-Control-Allow-Origin: " + origin) } // allow the origin rsp.Header().Set("Access-Control-Allow-Origin", origin) if headers := req.Header.Get("Access-Control-Request-Headers"); headers != "" { if globalOptions.Verbose { fmt.Println("Adding CORS response header - Access-Control-Allow-Headers: " + headers) } // allow any requested headers rsp.Header().Set("Access-Control-Allow-Headers", headers) } if method := req.Header.Get("Access-Control-Request-Method"); method != "" { if globalOptions.Verbose { fmt.Println("Adding CORS response header - Access-Control-Allow-Methods: " + method) } // allow any requested methods rsp.Header().Set("Access-Control-Allow-Methods", method) } if globalOptions.Verbose { fmt.Println("Adding CORS response header - Access-Control-Max-Age: " + "600") } // allow browser to reuse the permissions for 10 minutes rsp.Header().Set("Access-Control-Max-Age", "600") rsp.WriteHeader(200) return } else { if globalOptions.Verbose { fmt.Println("CORS preflight did not contain 'Origin' header!") } rsp.Header().Set("Content-Type", "text/plain") rsp.WriteHeader(400) fmt.Fprint(rsp, "Origin header was not provided") return } } } // add custom headers if provided if len(serverOptions.Headers) > 0 { for name, value := range headers { if globalOptions.Verbose { fmt.Println("Adding response header - " + name + ": " + value) } rsp.Header().Set(name, value) } } // add CORS allow-origin header if serverOptions.Cors && req.Method != http.MethodOptions { if origin := req.Header.Get("Origin"); origin != "" { if globalOptions.Verbose { fmt.Println("Adding CORS response header - Access-Control-Allow-Origin: " + origin) } rsp.Header().Set("Access-Control-Allow-Origin", origin) } } // add status line rsp.WriteHeader(code) // add body fmt.Fprint(rsp, body) }
/** Author: Mark George <mark.george@otago.ac.nz> License: Zero-Clause BSD License */ package main import ( "fmt" "net/http" "os" "strconv" ) type ServerCommand struct { Port int `short:"p" long:"port" description:"Port to listen on." default:"8080"` Routes []string `short:"r" long:"route" description:"Route which is made up of a path, a response body, and a response status, all separated by the pipe '|' character. Repeat for additional routes." default:"/|testing|200"` Cors bool `short:"c" long:"cors" description:"Enable Cross Origin Resource Sharing (CORS)."` Headers []string `short:"H" long:"header" description:"A header to add to response in the form name:value. Repeat for additional headers."` } var ( serverOptions *ServerCommand ) func (opts *ServerCommand) Execute(args []string) error { serverOptions = opts // parse custom headers if provided if len(opts.Headers) > 0 { for _, header := range opts.Headers { if _, _, err := parseHeader(header); err != nil { fmt.Println("The following header is not in the correct format: " + header + "\nCorrect format is: 'name:value'") os.Exit(1) } } } fmt.Println("HTTP server listening on port " + strconv.Itoa(opts.Port) + "\n") if len(opts.Routes) > 0 { fmt.Println("Routes:") for _, route := range opts.Routes { if path, _, _, err := parseRoute(route); err == nil { fmt.Println(" " + path) http.HandleFunc(path, serverRequestHandler) } else { fmt.Println("The following route is not in the correct format: " + route + "\nCorrect format is: \"/path|response body|status code\"") os.Exit(1) } } } if err := http.ListenAndServe(":"+strconv.Itoa(opts.Port), nil); err != nil { fmt.Fprintf(os.Stderr, "Could not start server. Is port %d available?\n", opts.Port) os.Exit(1) } return nil } func serverRequestHandler(rsp http.ResponseWriter, req *http.Request) { path := req.RequestURI body := bodies[path] code := statuses[path] showRequest(req) // CORS if serverOptions.Cors && req.Method == http.MethodOptions { if globalOptions.Verbose { fmt.Println("CORS preflight") } if origin := req.Header.Get("Origin"); origin != "" { rsp.Header().Set("Access-Control-Allow-Origin", origin) if headers := req.Header.Get("Access-Control-Request-Headers"); headers != "" { rsp.Header().Set("Access-Control-Allow-Headers", headers) } if method := req.Header.Get("Access-Control-Request-Method"); method != "" { rsp.Header().Set("Access-Control-Allow-Method", method) } rsp.WriteHeader(200) return } else { fmt.Println("WARNING - CORS preflight did not contain 'Origin' header!") } } // add custom headers if provided if len(serverOptions.Headers) > 0 { for name, value := range headers { if globalOptions.Verbose { fmt.Println("Adding response header - " + name + ": " + value) } rsp.Header().Set(name, value) } } // add status line rsp.WriteHeader(code) // add body fmt.Fprintf(rsp, body) }
Show line notes below