Newer
Older
httpcat / src / proxy.go
/**
  Author: Mark George <mark.george@otago.ac.nz>
  License: Zero-Clause BSD License
*/

package main

import (
	"fmt"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strconv"
	"strings"
)

// proxyOpts is the option set of the currently running proxy command.
// Execute stores its receiver here; the redirect rewriting installed by
// CreateProxy reads the listening port from it.
var proxyOpts *ProxyCommand

// ProxyCommand holds the options of the proxy command: Port is the local
// port the proxy listens on, and Target is the URL of the web server that
// requests are forwarded to.
//
// The struct tags are consumed by the command-line parser (presumably
// go-flags, judging by the short/long/description/required keys — confirm);
// both options are marked as required.
type ProxyCommand struct {
	Port   int    `short:"p" long:"port" description:"The port that the proxy listens on." required:"true"`
	Target string `short:"t" long:"target" description:"The URL for the target web server that the proxy forwards requests to." required:"true"`
}

// Execute runs the proxy command. It builds a reverse proxy for opts.Target
// and serves it on opts.Port until the listener fails.
//
// Before serving, it publishes opts through the package-level proxyOpts so that
// the redirect rewriting installed by CreateProxy can read the listening port.
// args (presumably positional arguments left over by the command-line parser —
// confirm) are ignored.
//
// Failures are returned rather than panicking, so the caller (presumably the
// go-flags command dispatcher — confirm) can report them cleanly. Because
// http.ListenAndServe only returns with a non-nil error, any return from this
// method means the proxy has stopped.
func (opts *ProxyCommand) Execute(args []string) error {
	fmt.Println("Proxying port " + strconv.Itoa(opts.Port) + " to " + opts.Target)

	proxyOpts = opts

	proxy, err := CreateProxy(opts.Target)
	if err != nil {
		return fmt.Errorf("creating proxy for %q: %w", opts.Target, err)
	}

	// A dedicated mux avoids registering on http.DefaultServeMux, which would
	// panic on the duplicate "/" pattern if Execute were ever run twice.
	mux := http.NewServeMux()
	mux.HandleFunc("/", RequestHandler(proxy))

	addr := ":" + strconv.Itoa(opts.Port)
	if err := http.ListenAndServe(addr, mux); err != nil {
		return fmt.Errorf("serving proxy on %s: %w", addr, err)
	}
	return nil
}

// CreateProxy builds a single-host reverse proxy that forwards every request
// to target. Each outgoing request is passed to FixHeaders and showRequest,
// and each response to showResponse (presumably for display — confirm).
//
// Redirect responses (301, 302, 303, 307, 308) whose Location refers to the
// target host have their port rewritten to the proxy's listening port (read
// from proxyOpts). This keeps clients that follow the redirect on the proxy.
// Relative Locations are resolved against target first. Redirects to other
// hosts, and responses with a missing or unparseable Location, are passed
// through unchanged.
//
// If the target server cannot be reached, the error is printed and the client
// receives 502 Bad Gateway.
//
// An error is returned if target cannot be parsed or lacks a scheme or host.
func CreateProxy(target string) (*httputil.ReverseProxy, error) {
	targetURL, err := url.Parse(target)
	if err != nil {
		return nil, fmt.Errorf("parsing target URL: %w", err)
	}
	// url.Parse accepts inputs such as "localhost:8080" (read as scheme
	// "localhost" with no host). Every proxied request would then fail, so
	// reject such targets up front with a clear message.
	if targetURL.Scheme == "" || targetURL.Host == "" {
		return nil, fmt.Errorf("target URL %q must include a scheme and host, e.g. http://localhost:8080", target)
	}

	proxy := httputil.NewSingleHostReverseProxy(targetURL)

	originalDirector := proxy.Director
	proxy.Director = func(req *http.Request) {
		FixHeaders(targetURL, req)

		showRequest(req)

		// send request to target
		originalDirector(req)
	}

	proxy.ModifyResponse = func(rsp *http.Response) error {
		rewriteRedirectLocation(targetURL, rsp)

		showResponse(rsp)
		return nil
	}

	proxy.ErrorHandler = func(rsp http.ResponseWriter, req *http.Request, err error) {
		fmt.Println("Error sending request to target server:")
		fmt.Println("   " + err.Error())
		// Without an explicit status the server would answer 200 OK with an
		// empty body, hiding the upstream failure from the client.
		rsp.WriteHeader(http.StatusBadGateway)
	}

	return proxy, nil
}

// rewriteRedirectLocation points a redirect's Location header at the proxy's
// port when the Location refers to the target host; otherwise it leaves rsp untouched.
func rewriteRedirectLocation(target *url.URL, rsp *http.Response) {
	switch rsp.StatusCode {
	case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther,
		http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
		// A redirect: continue with the rewrite below.
	default:
		return
	}

	location := rsp.Header.Get("Location")
	if location == "" || proxyOpts == nil {
		// Nothing to rewrite, or no known listening port to rewrite to.
		return
	}

	// Resolve relative references (e.g. "/login") against the target so they
	// gain an explicit host that can be compared and rewritten.
	locURL, err := target.Parse(location)
	if err != nil {
		fmt.Println("Leaving unparseable redirect Location unchanged: " + location)
		return
	}
	if !strings.EqualFold(locURL.Hostname(), target.Hostname()) {
		// Redirects to third-party hosts must not get the proxy's port.
		return
	}

	rewritten := RewritePort(locURL, strconv.Itoa(proxyOpts.Port))
	rsp.Header.Set("Location", rewritten.String())
}

// FixHeaders points req at target. It copies target's scheme and host into
// req.URL and sets req.Host, which net/http sends as the outgoing Host header,
// so the target server is addressed by its own host name rather than the
// proxy's. The request path and query are left untouched.
//
// TODO this may be necessary for proxying HTTPS (need to test)
func FixHeaders(target *url.URL, req *http.Request) {
	req.URL.Scheme = target.Scheme
	req.URL.Host = target.Host
	req.Host = target.Host
}

// RewritePort returns a copy of u whose port is replaced by newPort, or has
// newPort added when u has no explicit port. Scheme, user info, path, query
// and fragment are preserved, and u itself is not modified.
//
// The host is rebuilt structurally rather than by string replacement, so
// user info that happens to contain the host name or port is never touched.
// IPv6 literals keep their brackets: http://[::1]/ becomes http://[::1]:newPort/.
// A URL without a host (for example the relative reference "/login") has no
// port to rewrite and is returned unchanged. A nil u yields the zero URL.
func RewritePort(u *url.URL, newPort string) url.URL {
	if u == nil {
		return url.URL{}
	}

	rewritten := *u
	if u.Host == "" {
		return rewritten
	}

	host := u.Hostname()
	// Hostname strips the brackets from IPv6 literals; restore them so the
	// appended port is not mistaken for part of the address.
	if strings.Contains(host, ":") {
		host = "[" + host + "]"
	}
	rewritten.Host = host + ":" + newPort
	return rewritten
}

// RequestHandler exposes proxy as a plain handler function suitable for
// HandleFunc registration. Each invocation delegates directly to
// proxy.ServeHTTP with the same writer and request.
func RequestHandler(proxy *httputil.ReverseProxy) func(http.ResponseWriter, *http.Request) {
	// The method value is bound to proxy here; calling it forwards w and r unchanged.
	serve := proxy.ServeHTTP
	return serve
}