Support multiple listen address.

This commit is contained in:
Chen Yufei
2012-12-31 23:39:54 +08:00
parent 29f971d84e
commit fe86197942
8 changed files with 72 additions and 73 deletions

View File

@@ -5,6 +5,7 @@ import (
"flag"
"fmt"
"io"
"net"
"os"
"path"
"reflect"
@@ -12,11 +13,6 @@ import (
"strings"
)
var (
selfURL127 string // 127.0.0.1:listenAddr
selfURLLH string // localhost:listenAddr
)
const (
version = "0.3.5"
defaultListenAddr = "127.0.0.1:7777"
@@ -24,7 +20,7 @@ const (
type Config struct {
RcFile string // config file
ListenAddr string
ListenAddr []string
SocksAddr string
Core int
SshServer string
@@ -75,8 +71,9 @@ func init() {
func parseCmdLineConfig() *Config {
var c Config
var listenAddr string
flag.StringVar(&c.RcFile, "rc", path.Join(dsFile.dir, rcFname), "configuration file")
flag.StringVar(&c.ListenAddr, "listen", "", "proxy server listen address, default to "+defaultListenAddr)
flag.StringVar(&listenAddr, "listen", "", "proxy server listen address, default to "+defaultListenAddr)
flag.StringVar(&c.SocksAddr, "socks", "", "socks proxy address")
flag.IntVar(&c.Core, "core", 2, "number of cores to use")
flag.StringVar(&c.SshServer, "sshServer", "", "remote server which will ssh to and provide sock server")
@@ -97,6 +94,9 @@ func parseCmdLineConfig() *Config {
// flag.BoolVar(&c.AlwaysProxy, "alwaysProxy", false, "always use parent proxy")
flag.Parse()
if listenAddr != "" {
c.ListenAddr = []string{listenAddr}
}
return &c
}
@@ -116,7 +116,30 @@ func parseBool(v string, msg string) bool {
type configParser struct{}
func (p configParser) ParseListen(val string) {
config.ListenAddr = val
arr := strings.Split(val, ",")
config.ListenAddr = make([]string, 0, len(arr))
for _, s := range arr {
s = strings.TrimSpace(s)
h, port := splitHostPort(s)
if port == "" {
fmt.Printf("listen address %s has no port\n", s)
os.Exit(1)
}
if h == "" || h == "0.0.0.0" {
addrs, err := hostIP()
if err != nil {
fmt.Println("Can't determine host IP address")
fmt.Println("Specify specific host IP address or correct network settings.")
os.Exit(1)
}
for _, ad := range addrs {
// fmt.Println("host addr:", ad)
config.ListenAddr = append(config.ListenAddr, net.JoinHostPort(ad, port))
}
} else {
config.ListenAddr = append(config.ListenAddr, s)
}
}
}
func isServerAddrValid(val string) bool {
@@ -271,13 +294,10 @@ func updateConfig(new *Config) {
}
}
}
if config.ListenAddr == "" {
config.ListenAddr = defaultListenAddr
if new.ListenAddr != nil {
config.ListenAddr = new.ListenAddr
}
if config.ListenAddr == nil {
config.ListenAddr = []string{defaultListenAddr}
}
}
func setSelfURL() {
_, port := splitHostPort(config.ListenAddr)
selfURL127 = "127.0.0.1:" + port
selfURLLH = "localhost:" + port
}

View File

@@ -24,14 +24,14 @@ var errPageRawTmpl = `<!DOCTYPE html>
var blockedFormRawTmpl = `<p></p>
<b>Refresh to retry</b> or add <b>{{.Domain}}</b> to
<form action="http://{{.ListenAddr}}/blocked" method="get">
<form action="http://{{.ProxyAddr}}/blocked" method="get">
<input type="hidden" name="host" value={{.Host}}>
<b>blocked sites</b>
<input type="submit" name="submit" value="blocked">
</form>
`
var directFormRawTmpl = `<form action="http://{{.ListenAddr}}/direct" method="get">
var directFormRawTmpl = `<form action="http://{{.ProxyAddr}}/direct" method="get">
<input type="hidden" name="host" value={{.Host}}>
<b>direct accessible sites</b>
<input type="submit" name="submit" value="direct">
@@ -122,21 +122,21 @@ func sendRedirectPage(w io.Writer, location string) {
"", fmt.Sprintf("Location: %s\r\n", location))
}
func sendBlockedErrorPage(w io.Writer, codeReason, h1, msg string, r *Request) {
func sendBlockedErrorPage(c *clientConn, codeReason, h1, msg string, r *Request) {
// If host is IP or in always DS, we can't add it to blocked or direct domain list. Just
// return ordinary error page.
h, _ := splitHostPort(r.URL.Host)
if hostIsIP(r.URL.Host) || domainSet.isHostInAlwaysDs(h) {
sendErrorPage(w, codeReason, h1, msg)
sendErrorPage(c, codeReason, h1, msg)
return
}
data := struct {
ListenAddr string
Host string
Domain string
ProxyAddr string
Host string
Domain string
}{
config.ListenAddr,
c.proxy.addr,
h,
host2Domain(r.URL.Host),
}
@@ -151,5 +151,5 @@ func sendBlockedErrorPage(w io.Writer, codeReason, h1, msg string, r *Request) {
return
}
}
sendPageGeneric(w, codeReason, "[Error] "+h1, msg, buf.String(), "")
sendPageGeneric(c, codeReason, "[Error] "+h1, msg, buf.String(), "")
}

View File

@@ -113,7 +113,7 @@ func splitHostPort(s string) (host, port string) {
return s, ""
}
// Scan back, make sure we find ':'
for i := len(s) - 2; i > 0; i-- {
for i := len(s) - 2; i >= 0; i-- {
c := s[i]
switch {
case c == ':':
@@ -130,9 +130,7 @@ func splitHostPort(s string) (host, port string) {
// will check the correctness of the host.
func ParseRequestURI(rawurl string) (*URL, error) {
if rawurl[0] == '/' {
// OS X seems to send only path to the server if the url is 127.0.0.1
return &URL{Host: "", Path: rawurl}, nil
// return nil, errors.New("Invalid proxy request URI: " + rawurl)
}
var f []string

View File

@@ -13,12 +13,13 @@ func TestSplitHostPort(t *testing.T) {
{"google.com", "google.com", ""},
{"google.com:80", "google.com", "80"},
{"google.com80", "google.com80", ""},
{":7777", "", "7777"},
}
for _, td := range testData {
h, p := splitHostPort(td.host)
if h != td.hostNoPort || p != td.port {
t.Errorf("%s returns %v %v", td.host, td.hostNoPort, td.port)
t.Errorf("%s returns %v:%v", td.host, td.hostNoPort, td.port)
}
}
}

11
main.go
View File

@@ -40,7 +40,6 @@ func main() {
initLog()
initProxyServerAddr()
initSocksServer()
initShadowSocks()
@@ -50,8 +49,6 @@ func main() {
hasParentProxy = true
}
setSelfURL()
domainSet.load()
/*
if *cpuprofile != "" {
@@ -77,6 +74,10 @@ func main() {
go sigHandler()
go runSSH()
py := NewProxy(config.ListenAddr)
py.Serve()
if len(config.ListenAddr) > 1 {
for _, addr := range config.ListenAddr[1:] {
go NewProxy(addr).Serve()
}
}
NewProxy(config.ListenAddr[0]).Serve()
}

40
pac.go
View File

@@ -3,7 +3,6 @@ package main
import (
"bytes"
"fmt"
"io"
"os"
"strings"
"text/template"
@@ -12,12 +11,11 @@ import (
var pac struct {
template *template.Template
topLevelDomain string
proxyServerAddr string
}
func init() {
const pacRawTmpl = `var direct = 'DIRECT';
var httpProxy = '{{.ProxyAddr}}';
var httpProxy = 'PROXY {{.ProxyAddr}}; DIRECT';
var directList = [
"localhost",
@@ -71,31 +69,11 @@ function FindProxyForURL(url, host) {
pac.topLevelDomain = buf.String()[:buf.Len()-2] // remove the final comma
}
func initProxyServerAddr() {
listen, port := splitHostPort(config.ListenAddr)
if listen == "0.0.0.0" {
addrs, err := hostIP()
if err != nil {
errl.Println("Either change listen address to specific IP, or correct your host network settings.")
os.Exit(1)
}
for _, ip := range addrs {
pac.proxyServerAddr += fmt.Sprintf("PROXY %s:%s; ", ip, port)
}
pac.proxyServerAddr += "DIRECT"
info.Printf("proxy listen address is %s, PAC will have proxy address: %s\n",
config.ListenAddr, pac.proxyServerAddr)
} else {
pac.proxyServerAddr = fmt.Sprintf("PROXY %s; DIRECT", config.ListenAddr)
}
}
// No need for content-length as we are closing connection
var pacHeader = []byte("HTTP/1.1 200 OK\r\nServer: cow-proxy\r\n" +
"Content-Type: application/x-ns-proxy-autoconfig\r\nConnection: close\r\n\r\n")
func sendPAC(w io.Writer) {
func sendPAC(c *clientConn) {
// domains in PAC file needs double quote
ds1 := strings.Join(domainSet.alwaysDirect.toSlice(), "\",\n\"")
ds2 := strings.Join(domainSet.direct.toSlice(), "\",\n\"")
@@ -109,10 +87,10 @@ func sendPAC(w io.Writer) {
}
if ds == "" {
// Empty direct domain list
w.Write(pacHeader)
pacproxy := fmt.Sprintf("function FindProxyForURL(url, host) { return '%s'; };",
pac.proxyServerAddr)
w.Write([]byte(pacproxy))
c.Write(pacHeader)
pacproxy := fmt.Sprintf("function FindProxyForURL(url, host) { return 'PROXY %s; DIRECT'; };",
c.proxy.addr)
c.Write([]byte(pacproxy))
return
}
@@ -121,12 +99,12 @@ func sendPAC(w io.Writer) {
DirectDomains string
TopLevel string
}{
pac.proxyServerAddr,
c.proxy.addr,
",\n\"" + ds + "\"",
pac.topLevelDomain,
}
if _, err := w.Write(pacHeader); err != nil {
if _, err := c.Write(pacHeader); err != nil {
debug.Println("Error writing pac header")
return
}
@@ -135,7 +113,7 @@ func sendPAC(w io.Writer) {
if err := pac.template.Execute(buf, data); err != nil {
errl.Println("Error generating pac file:", err)
}
if _, err := w.Write(buf.Bytes()); err != nil {
if _, err := c.Write(buf.Bytes()); err != nil {
debug.Println("Error writing pac content:", err)
}
}

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"io"
"net"
"os"
// "reflect"
"strconv"
"strings"
@@ -86,6 +85,7 @@ type clientConn struct {
bufRd *bufio.Reader
serverConn map[string]*serverConn // request serverConn, host:port as key
buf []byte // buffer for reading request, avoids repeatedly allocating buffer
proxy *Proxy
}
var (
@@ -108,7 +108,7 @@ func (py *Proxy) Serve() {
ln, err := net.Listen("tcp", py.addr)
if err != nil {
fmt.Println("Server creation failed:", err)
os.Exit(1)
return
}
info.Printf("COW proxy address %s, PAC url %s\n", py.addr, "http://"+py.addr+"/pac")
@@ -121,7 +121,7 @@ func (py *Proxy) Serve() {
if debug {
debug.Println("New Client:", conn.RemoteAddr())
}
c := newClientConn(conn)
c := newClientConn(conn, py)
go c.serve()
}
}
@@ -130,11 +130,12 @@ func (py *Proxy) Serve() {
// bufio.Reader's Read
const bufSize = 4096
func newClientConn(rwc net.Conn) *clientConn {
func newClientConn(rwc net.Conn, proxy *Proxy) *clientConn {
c := &clientConn{
Conn: rwc,
serverConn: map[string]*serverConn{},
bufRd: bufio.NewReaderSize(rwc, bufSize),
proxy: proxy,
}
return c
}
@@ -151,7 +152,7 @@ func (c *clientConn) close() {
}
func isSelfURL(url string) bool {
return url == "" || url == selfURLLH || url == selfURL127
return url == ""
}
func (c *clientConn) getRequest() (r *Request) {

View File

@@ -2,8 +2,8 @@ package main
import (
"bufio"
// "fmt"
"errors"
"fmt"
"io"
"log"
"net"
@@ -93,13 +93,13 @@ func isDirExists(path string) (bool, error) {
func hostIP() (addrs []string, err error) {
name, err := os.Hostname()
if err != nil {
errl.Printf("Error get host name: %v\n", err)
fmt.Printf("Error get host name: %v\n", err)
return
}
addrs, err = net.LookupHost(name)
if err != nil {
errl.Printf("Error getting host IP address: %v\n", err)
fmt.Printf("Error getting host IP address: %v\n", err)
return
}
return