// httpcat / src / proxy.go
/**
  Author: Mark George <mark.george@otago.ac.nz>
  License: Zero-Clause BSD License: https://opensource.org/license/0bsd/
*/

package main

import (
	"fmt"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strconv"
	"strings"
)

// proxyOpts holds the options of the running proxy. Execute assigns it before
// serving, and the Director and ModifyResponse hooks built by CreateProxy read
// it (Port and Filter) on every proxied request and response.
var proxyOpts *ProxyCommand

// ProxyCommand holds the command-line options for the proxy; its Execute
// method starts proxying. The struct tags describe each flag (short name,
// long name, help text, required) — NOTE(review): these look like
// jessevdk/go-flags tags; confirm against the parser used in main.
//
//	Port   - local port the proxy listens on (tagged required).
//	Target - URL of the web server that requests are forwarded to (tagged required).
//	Filter - optional Content-Type. When non-empty, a request/response body is
//	         shown only if the message's Content-Type matches it, or if the
//	         message has no Content-Type; other messages are passed to the
//	         "filtered" display functions instead.
type ProxyCommand struct {
	Port   int    `short:"p" long:"port"   description:"The port that the proxy listens on."                                    required:"true"`
	Target string `short:"t" long:"target" description:"The URL for the target web server that the proxy forwards requests to." required:"true"`
	Filter string `short:"f" long:"filter" description:"Only show bodies that match the given Content-Type."`
}

// Execute starts the proxy. It prints what is being proxied, builds a reverse
// proxy for opts.Target via CreateProxy, and serves it on opts.Port through
// http.DefaultServeMux. args is unused.
//
// Failures are returned as wrapped errors rather than panicking, so the caller
// can report them cleanly. http.ListenAndServe only returns on failure, so in
// practice this method either blocks forever or returns a non-nil error.
func (opts *ProxyCommand) Execute(args []string) error {
	port := strconv.Itoa(opts.Port)
	fmt.Println("Proxying port " + port + " to " + opts.Target)

	// The Director/ModifyResponse hooks installed by CreateProxy read the
	// options through this package-level variable, so it must be set before
	// the proxy handles any request.
	proxyOpts = opts

	proxy, err := CreateProxy(opts.Target)
	if err != nil {
		return fmt.Errorf("creating proxy for target %q: %w", opts.Target, err)
	}

	http.HandleFunc("/", RequestHandler(proxy))
	if err := http.ListenAndServe(":"+port, nil); err != nil {
		return fmt.Errorf("serving proxy on port %s: %w", port, err)
	}

	return nil
}

// CreateProxy builds a reverse proxy that forwards every request to target
// and prints the traffic with showRequest/showResponse. When proxyOpts.Filter
// is set and a message's Content-Type does not match it, the message goes to
// showFilteredRequest/showFilteredResponse instead.
//
// target must be an absolute URL with a scheme and host, such as
// "http://localhost:3000". Anything else is rejected with an error, because
// the resulting proxy could never reach a server.
//
// The returned proxy reads proxyOpts on every request and response, so
// proxyOpts must be set before the proxy starts serving.
func CreateProxy(target string) (*httputil.ReverseProxy, error) {
	targetURL, err := url.Parse(target)
	if err != nil {
		// bad target URL
		return nil, err
	}
	if targetURL.Scheme == "" || targetURL.Host == "" {
		return nil, fmt.Errorf("target URL %q must include a scheme and host (e.g. http://localhost:3000)", target)
	}

	proxy := httputil.NewSingleHostReverseProxy(targetURL)
	originalDirector := proxy.Director

	proxy.Director = func(req *http.Request) {
		FixHeaders(targetURL, req)

		// Show the body when no filter is set, or when the request's
		// Content-Type matches the filter (or is absent).
		if proxyOpts.Filter == "" || contentTypeMatchesFilter(req.Header, proxyOpts.Filter) {
			showRequest(req)
		} else {
			showFilteredRequest(req)
		}

		// send request to target
		originalDirector(req)
	}

	proxy.ModifyResponse = func(rsp *http.Response) error {
		// Point redirects back at the proxy's port so the client keeps going
		// through the proxy. Only redirects to the target's own host are
		// rewritten; redirects to other sites are passed through untouched.
		switch rsp.StatusCode {
		case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther,
			http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
			if location := rsp.Header.Get("Location"); location != "" {
				// Resolving against targetURL turns a relative Location
				// (e.g. "/login") into an absolute URL on the target host.
				locURL, parseErr := targetURL.Parse(location)
				if parseErr == nil && strings.EqualFold(locURL.Hostname(), targetURL.Hostname()) {
					rewritten := RewritePort(locURL, strconv.Itoa(proxyOpts.Port))
					rsp.Header.Set("Location", rewritten.String())
				}
			}
		}

		// Show the body when no filter is set, or when the response's
		// Content-Type matches the filter (or is absent).
		if proxyOpts.Filter == "" || contentTypeMatchesFilter(rsp.Header, proxyOpts.Filter) {
			showResponse(rsp)
		} else {
			showFilteredResponse(rsp)
		}

		return nil
	}

	proxy.ErrorHandler = func(w http.ResponseWriter, req *http.Request, err error) {
		fmt.Println("Error sending request to target server:")
		fmt.Println("   " + err.Error())
		// Without an explicit status the client would receive an empty
		// 200 OK; report the failure the same way httputil's default does.
		w.WriteHeader(http.StatusBadGateway)
	}

	return proxy, nil
}

// contentTypeMatchesFilter reports whether a message with headers h passes
// the Content-Type filter. A message with no Content-Type value always
// matches. Otherwise the first Content-Type value matches if it equals filter
// exactly, or if both have the same media type when compared case-insensitively
// with parameters stripped. So "application/json" matches
// "application/json; charset=utf-8".
func contentTypeMatchesFilter(h http.Header, filter string) bool {
	values := h["Content-Type"]
	if len(values) == 0 {
		return true
	}
	if values[0] == filter {
		return true
	}
	mediaType := func(v string) string {
		if i := strings.IndexByte(v, ';'); i >= 0 {
			v = v[:i]
		}
		return strings.ToLower(strings.TrimSpace(v))
	}
	return mediaType(values[0]) == mediaType(filter)
}

// FixHeaders retargets req at target: the request URL takes target's host and
// scheme, and req.Host (the Host header sent upstream) is set to target's
// host as well. Path and query are left as they are.
//
// TODO this may be necessary for proxying HTTPS (need to test)
func FixHeaders(target *url.URL, req *http.Request) {
	host := target.Host
	req.URL.Host, req.URL.Scheme = host, target.Scheme
	req.Host = host
}

// RewritePort returns a copy of u whose port is newPort, adding a port when u
// has none. Scheme, userinfo, path, query and fragment are preserved, and u
// itself is not modified.
//
// IPv6 literal hosts stay bracketed, so "http://[::1]/" becomes
// "http://[::1]:8080/". A URL without a host (e.g. a relative "/login") has no
// port to rewrite and is returned unchanged.
func RewritePort(u *url.URL, newPort string) url.URL {
	rewritten := *u

	host := u.Hostname()
	if host == "" {
		return rewritten
	}

	// Editing the Host field directly, rather than string-replacing the
	// serialized URL, avoids touching userinfo or paths that happen to
	// contain ":<port>" or the hostname.
	// An IPv6 literal contains colons and must be re-bracketed once a port is
	// attached (the same rule net.JoinHostPort applies).
	if strings.Contains(host, ":") {
		host = "[" + host + "]"
	}
	rewritten.Host = host + ":" + newPort
	return rewritten
}

// RequestHandler adapts proxy to the handler-function signature accepted by
// http.HandleFunc. Every request is handed directly to proxy.ServeHTTP.
func RequestHandler(proxy *httputil.ReverseProxy) func(http.ResponseWriter, *http.Request) {
	return proxy.ServeHTTP
}