diff --git a/config.go b/config.go index c317a11..5aa93f5 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "io" + "net" "os" "path" "reflect" @@ -12,11 +13,6 @@ import ( "strings" ) -var ( - selfURL127 string // 127.0.0.1:listenAddr - selfURLLH string // localhost:listenAddr -) - const ( version = "0.3.5" defaultListenAddr = "127.0.0.1:7777" @@ -24,7 +20,7 @@ const ( type Config struct { RcFile string // config file - ListenAddr string + ListenAddr []string SocksAddr string Core int SshServer string @@ -75,8 +71,9 @@ func init() { func parseCmdLineConfig() *Config { var c Config + var listenAddr string flag.StringVar(&c.RcFile, "rc", path.Join(dsFile.dir, rcFname), "configuration file") - flag.StringVar(&c.ListenAddr, "listen", "", "proxy server listen address, default to "+defaultListenAddr) + flag.StringVar(&listenAddr, "listen", "", "proxy server listen address, default to "+defaultListenAddr) flag.StringVar(&c.SocksAddr, "socks", "", "socks proxy address") flag.IntVar(&c.Core, "core", 2, "number of cores to use") flag.StringVar(&c.SshServer, "sshServer", "", "remote server which will ssh to and provide sock server") @@ -97,6 +94,9 @@ func parseCmdLineConfig() *Config { // flag.BoolVar(&c.AlwaysProxy, "alwaysProxy", false, "always use parent proxy") flag.Parse() + if listenAddr != "" { + c.ListenAddr = []string{listenAddr} + } return &c } @@ -116,7 +116,30 @@ func parseBool(v string, msg string) bool { type configParser struct{} func (p configParser) ParseListen(val string) { - config.ListenAddr = val + arr := strings.Split(val, ",") + config.ListenAddr = make([]string, 0, len(arr)) + for _, s := range arr { + s = strings.TrimSpace(s) + h, port := splitHostPort(s) + if port == "" { + fmt.Printf("listen address %s has no port\n", s) + os.Exit(1) + } + if h == "" || h == "0.0.0.0" { + addrs, err := hostIP() + if err != nil { + fmt.Println("Can't determine host IP address") + fmt.Println("Specify specific host IP address or correct network settings.") + os.Exit(1) + } + for _, ad := range addrs { + // fmt.Println("host addr:", ad) + config.ListenAddr = append(config.ListenAddr, net.JoinHostPort(ad, port)) + } + } else { + config.ListenAddr = append(config.ListenAddr, s) + } + } } func isServerAddrValid(val string) bool { @@ -271,13 +294,10 @@ func updateConfig(new *Config) { } } } - if config.ListenAddr == "" { - config.ListenAddr = defaultListenAddr + if new.ListenAddr != nil { + config.ListenAddr = new.ListenAddr + } + if config.ListenAddr == nil { + config.ListenAddr = []string{defaultListenAddr} } } - -func setSelfURL() { - _, port := splitHostPort(config.ListenAddr) - selfURL127 = "127.0.0.1:" + port - selfURLLH = "localhost:" + port -} diff --git a/error.go b/error.go index 98df5b8..9412dba 100644 --- a/error.go +++ b/error.go @@ -24,14 +24,14 @@ var errPageRawTmpl = ` var blockedFormRawTmpl = `

Refresh to retry or add {{.Domain}} to -
+ blocked sites
` -var directFormRawTmpl = `
+var directFormRawTmpl = ` direct accessible sites @@ -122,21 +122,21 @@ func sendRedirectPage(w io.Writer, location string) { "", fmt.Sprintf("Location: %s\r\n", location)) } -func sendBlockedErrorPage(w io.Writer, codeReason, h1, msg string, r *Request) { +func sendBlockedErrorPage(c *clientConn, codeReason, h1, msg string, r *Request) { // If host is IP or in always DS, we can't add it to blocked or direct domain list. Just // return ordinary error page. h, _ := splitHostPort(r.URL.Host) if hostIsIP(r.URL.Host) || domainSet.isHostInAlwaysDs(h) { - sendErrorPage(w, codeReason, h1, msg) + sendErrorPage(c, codeReason, h1, msg) return } data := struct { - ListenAddr string - Host string - Domain string + ProxyAddr string + Host string + Domain string }{ - config.ListenAddr, + c.proxy.addr, h, host2Domain(r.URL.Host), } @@ -151,5 +151,5 @@ func sendBlockedErrorPage(w io.Writer, codeReason, h1, msg string, r *Request) { return } } - sendPageGeneric(w, codeReason, "[Error] "+h1, msg, buf.String(), "") + sendPageGeneric(c, codeReason, "[Error] "+h1, msg, buf.String(), "") } diff --git a/http.go b/http.go index 26b433d..55c8d8e 100644 --- a/http.go +++ b/http.go @@ -113,7 +113,7 @@ func splitHostPort(s string) (host, port string) { return s, "" } // Scan back, make sure we find ':' - for i := len(s) - 2; i > 0; i-- { + for i := len(s) - 2; i >= 0; i-- { c := s[i] switch { case c == ':': @@ -130,9 +130,7 @@ func splitHostPort(s string) (host, port string) { // will check the correctness of the host. func ParseRequestURI(rawurl string) (*URL, error) { if rawurl[0] == '/' { - // OS X seems to send only path to the server if the url is 127.0.0.1 return &URL{Host: "", Path: rawurl}, nil - // return nil, errors.New("Invalid proxy request URI: " + rawurl) } var f []string diff --git a/http_test.go b/http_test.go index 13dd17b..26292dd 100644 --- a/http_test.go +++ b/http_test.go @@ -13,12 +13,13 @@ func TestSplitHostPort(t *testing.T) { {"google.com", "google.com", ""}, {"google.com:80", "google.com", "80"}, {"google.com80", "google.com80", ""}, + {":7777", "", "7777"}, } for _, td := range testData { h, p := splitHostPort(td.host) if h != td.hostNoPort || p != td.port { - t.Errorf("%s returns %v %v", td.host, td.hostNoPort, td.port) + t.Errorf("%s returns %v:%v", td.host, td.hostNoPort, td.port) } } } diff --git a/main.go b/main.go index 83178a6..82fd37b 100644 --- a/main.go +++ b/main.go @@ -40,7 +40,6 @@ func main() { initLog() - initProxyServerAddr() initSocksServer() initShadowSocks() @@ -50,8 +49,6 @@ func main() { hasParentProxy = true } - setSelfURL() - domainSet.load() /* if *cpuprofile != "" { @@ -77,6 +74,10 @@ func main() { go sigHandler() go runSSH() - py := NewProxy(config.ListenAddr) - py.Serve() + if len(config.ListenAddr) > 1 { + for _, addr := range config.ListenAddr[1:] { + go NewProxy(addr).Serve() + } + } + NewProxy(config.ListenAddr[0]).Serve() } diff --git a/pac.go b/pac.go index 2dd586f..276ab70 100644 --- a/pac.go +++ b/pac.go @@ -3,7 +3,6 @@ package main import ( "bytes" "fmt" - "io" "os" "strings" "text/template" @@ -12,12 +11,11 @@ import ( var pac struct { template *template.Template topLevelDomain string - proxyServerAddr string } func init() { const pacRawTmpl = `var direct = 'DIRECT'; -var httpProxy = '{{.ProxyAddr}}'; +var httpProxy = 'PROXY {{.ProxyAddr}}; DIRECT'; var directList = [ "localhost", @@ -71,31 +69,11 @@ function FindProxyForURL(url, host) { pac.topLevelDomain = buf.String()[:buf.Len()-2] // remove the final comma } -func initProxyServerAddr() { - listen, port := splitHostPort(config.ListenAddr) - if listen == "0.0.0.0" { - addrs, err := hostIP() - if err != nil { - errl.Println("Either change listen address to specific IP, or correct your host network settings.") - os.Exit(1) - } - - for _, ip := range addrs { - pac.proxyServerAddr += fmt.Sprintf("PROXY %s:%s; ", ip, port) - } - pac.proxyServerAddr += "DIRECT" - info.Printf("proxy listen address is %s, PAC will have proxy address: %s\n", - config.ListenAddr, pac.proxyServerAddr) - } else { - pac.proxyServerAddr = fmt.Sprintf("PROXY %s; DIRECT", config.ListenAddr) - } -} - // No need for content-length as we are closing connection var pacHeader = []byte("HTTP/1.1 200 OK\r\nServer: cow-proxy\r\n" + "Content-Type: application/x-ns-proxy-autoconfig\r\nConnection: close\r\n\r\n") -func sendPAC(w io.Writer) { +func sendPAC(c *clientConn) { // domains in PAC file needs double quote ds1 := strings.Join(domainSet.alwaysDirect.toSlice(), "\",\n\"") ds2 := strings.Join(domainSet.direct.toSlice(), "\",\n\"") @@ -109,10 +87,10 @@ func sendPAC(w io.Writer) { } if ds == "" { // Empty direct domain list - w.Write(pacHeader) - pacproxy := fmt.Sprintf("function FindProxyForURL(url, host) { return '%s'; };", - pac.proxyServerAddr) - w.Write([]byte(pacproxy)) + c.Write(pacHeader) + pacproxy := fmt.Sprintf("function FindProxyForURL(url, host) { return 'PROXY %s; DIRECT'; };", + c.proxy.addr) + c.Write([]byte(pacproxy)) return } @@ -121,12 +99,12 @@ func sendPAC(w io.Writer) { DirectDomains string TopLevel string }{ - pac.proxyServerAddr, + c.proxy.addr, ",\n\"" + ds + "\"", pac.topLevelDomain, } - if _, err := w.Write(pacHeader); err != nil { + if _, err := c.Write(pacHeader); err != nil { debug.Println("Error writing pac header") return } @@ -135,7 +113,7 @@ func sendPAC(w io.Writer) { if err := pac.template.Execute(buf, data); err != nil { errl.Println("Error generating pac file:", err) } - if _, err := w.Write(buf.Bytes()); err != nil { + if _, err := c.Write(buf.Bytes()); err != nil { debug.Println("Error writing pac content:", err) } } diff --git a/proxy.go b/proxy.go index 7373b04..9022d9f 100644 --- a/proxy.go +++ b/proxy.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net" - "os" // "reflect" "strconv" "strings" @@ -86,6 +85,7 @@ type clientConn struct { bufRd *bufio.Reader serverConn map[string]*serverConn // request serverConn, host:port as key buf []byte // buffer for reading request, avoids repeatedly allocating buffer + proxy *Proxy } var ( @@ -108,7 +108,7 @@ func (py *Proxy) Serve() { ln, err := net.Listen("tcp", py.addr) if err != nil { fmt.Println("Server creation failed:", err) - os.Exit(1) + return } info.Printf("COW proxy address %s, PAC url %s\n", py.addr, "http://"+py.addr+"/pac") @@ -121,7 +121,7 @@ func (py *Proxy) Serve() { if debug { debug.Println("New Client:", conn.RemoteAddr()) } - c := newClientConn(conn) + c := newClientConn(conn, py) go c.serve() } } @@ -130,11 +130,12 @@ func (py *Proxy) Serve() { // bufio.Reader's Read const bufSize = 4096 -func newClientConn(rwc net.Conn) *clientConn { +func newClientConn(rwc net.Conn, proxy *Proxy) *clientConn { c := &clientConn{ Conn: rwc, serverConn: map[string]*serverConn{}, bufRd: bufio.NewReaderSize(rwc, bufSize), + proxy: proxy, } return c } @@ -151,7 +152,7 @@ func (c *clientConn) close() { } func isSelfURL(url string) bool { - return url == "" || url == selfURLLH || url == selfURL127 + return url == "" } func (c *clientConn) getRequest() (r *Request) { diff --git a/util.go b/util.go index 7b8a7d9..f75f3e4 100644 --- a/util.go +++ b/util.go @@ -2,8 +2,8 @@ package main import ( "bufio" - // "fmt" "errors" + "fmt" "io" "log" "net" @@ -93,13 +93,13 @@ func isDirExists(path string) (bool, error) { func hostIP() (addrs []string, err error) { name, err := os.Hostname() if err != nil { - errl.Printf("Error get host name: %v\n", err) + fmt.Printf("Error get host name: %v\n", err) return } addrs, err = net.LookupHost(name) if err != nil { - errl.Printf("Error getting host IP address: %v\n", err) + fmt.Printf("Error getting host IP address: %v\n", err) return } return