Support multiple listen address.

This commit is contained in:
Chen Yufei
2012-12-31 23:39:54 +08:00
parent 29f971d84e
commit fe86197942
8 changed files with 72 additions and 73 deletions

View File

@@ -5,6 +5,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"net"
"os" "os"
"path" "path"
"reflect" "reflect"
@@ -12,11 +13,6 @@ import (
"strings" "strings"
) )
var (
selfURL127 string // 127.0.0.1:listenAddr
selfURLLH string // localhost:listenAddr
)
const ( const (
version = "0.3.5" version = "0.3.5"
defaultListenAddr = "127.0.0.1:7777" defaultListenAddr = "127.0.0.1:7777"
@@ -24,7 +20,7 @@ const (
type Config struct { type Config struct {
RcFile string // config file RcFile string // config file
ListenAddr string ListenAddr []string
SocksAddr string SocksAddr string
Core int Core int
SshServer string SshServer string
@@ -75,8 +71,9 @@ func init() {
func parseCmdLineConfig() *Config { func parseCmdLineConfig() *Config {
var c Config var c Config
var listenAddr string
flag.StringVar(&c.RcFile, "rc", path.Join(dsFile.dir, rcFname), "configuration file") flag.StringVar(&c.RcFile, "rc", path.Join(dsFile.dir, rcFname), "configuration file")
flag.StringVar(&c.ListenAddr, "listen", "", "proxy server listen address, default to "+defaultListenAddr) flag.StringVar(&listenAddr, "listen", "", "proxy server listen address, default to "+defaultListenAddr)
flag.StringVar(&c.SocksAddr, "socks", "", "socks proxy address") flag.StringVar(&c.SocksAddr, "socks", "", "socks proxy address")
flag.IntVar(&c.Core, "core", 2, "number of cores to use") flag.IntVar(&c.Core, "core", 2, "number of cores to use")
flag.StringVar(&c.SshServer, "sshServer", "", "remote server which will ssh to and provide sock server") flag.StringVar(&c.SshServer, "sshServer", "", "remote server which will ssh to and provide sock server")
@@ -97,6 +94,9 @@ func parseCmdLineConfig() *Config {
// flag.BoolVar(&c.AlwaysProxy, "alwaysProxy", false, "always use parent proxy") // flag.BoolVar(&c.AlwaysProxy, "alwaysProxy", false, "always use parent proxy")
flag.Parse() flag.Parse()
if listenAddr != "" {
c.ListenAddr = []string{listenAddr}
}
return &c return &c
} }
@@ -116,7 +116,30 @@ func parseBool(v string, msg string) bool {
type configParser struct{} type configParser struct{}
func (p configParser) ParseListen(val string) { func (p configParser) ParseListen(val string) {
config.ListenAddr = val arr := strings.Split(val, ",")
config.ListenAddr = make([]string, 0, len(arr))
for _, s := range arr {
s = strings.TrimSpace(s)
h, port := splitHostPort(s)
if port == "" {
fmt.Printf("listen address %s has no port\n", s)
os.Exit(1)
}
if h == "" || h == "0.0.0.0" {
addrs, err := hostIP()
if err != nil {
fmt.Println("Can't determine host IP address")
fmt.Println("Specify specific host IP address or correct network settings.")
os.Exit(1)
}
for _, ad := range addrs {
// fmt.Println("host addr:", ad)
config.ListenAddr = append(config.ListenAddr, net.JoinHostPort(ad, port))
}
} else {
config.ListenAddr = append(config.ListenAddr, s)
}
}
} }
func isServerAddrValid(val string) bool { func isServerAddrValid(val string) bool {
@@ -271,13 +294,10 @@ func updateConfig(new *Config) {
} }
} }
} }
if config.ListenAddr == "" { if new.ListenAddr != nil {
config.ListenAddr = defaultListenAddr config.ListenAddr = new.ListenAddr
}
if config.ListenAddr == nil {
config.ListenAddr = []string{defaultListenAddr}
} }
} }
func setSelfURL() {
_, port := splitHostPort(config.ListenAddr)
selfURL127 = "127.0.0.1:" + port
selfURLLH = "localhost:" + port
}

View File

@@ -24,14 +24,14 @@ var errPageRawTmpl = `<!DOCTYPE html>
var blockedFormRawTmpl = `<p></p> var blockedFormRawTmpl = `<p></p>
<b>Refresh to retry</b> or add <b>{{.Domain}}</b> to <b>Refresh to retry</b> or add <b>{{.Domain}}</b> to
<form action="http://{{.ListenAddr}}/blocked" method="get"> <form action="http://{{.ProxyAddr}}/blocked" method="get">
<input type="hidden" name="host" value={{.Host}}> <input type="hidden" name="host" value={{.Host}}>
<b>blocked sites</b> <b>blocked sites</b>
<input type="submit" name="submit" value="blocked"> <input type="submit" name="submit" value="blocked">
</form> </form>
` `
var directFormRawTmpl = `<form action="http://{{.ListenAddr}}/direct" method="get"> var directFormRawTmpl = `<form action="http://{{.ProxyAddr}}/direct" method="get">
<input type="hidden" name="host" value={{.Host}}> <input type="hidden" name="host" value={{.Host}}>
<b>direct accessible sites</b> <b>direct accessible sites</b>
<input type="submit" name="submit" value="direct"> <input type="submit" name="submit" value="direct">
@@ -122,21 +122,21 @@ func sendRedirectPage(w io.Writer, location string) {
"", fmt.Sprintf("Location: %s\r\n", location)) "", fmt.Sprintf("Location: %s\r\n", location))
} }
func sendBlockedErrorPage(w io.Writer, codeReason, h1, msg string, r *Request) { func sendBlockedErrorPage(c *clientConn, codeReason, h1, msg string, r *Request) {
// If host is IP or in always DS, we can't add it to blocked or direct domain list. Just // If host is IP or in always DS, we can't add it to blocked or direct domain list. Just
// return ordinary error page. // return ordinary error page.
h, _ := splitHostPort(r.URL.Host) h, _ := splitHostPort(r.URL.Host)
if hostIsIP(r.URL.Host) || domainSet.isHostInAlwaysDs(h) { if hostIsIP(r.URL.Host) || domainSet.isHostInAlwaysDs(h) {
sendErrorPage(w, codeReason, h1, msg) sendErrorPage(c, codeReason, h1, msg)
return return
} }
data := struct { data := struct {
ListenAddr string ProxyAddr string
Host string Host string
Domain string Domain string
}{ }{
config.ListenAddr, c.proxy.addr,
h, h,
host2Domain(r.URL.Host), host2Domain(r.URL.Host),
} }
@@ -151,5 +151,5 @@ func sendBlockedErrorPage(w io.Writer, codeReason, h1, msg string, r *Request) {
return return
} }
} }
sendPageGeneric(w, codeReason, "[Error] "+h1, msg, buf.String(), "") sendPageGeneric(c, codeReason, "[Error] "+h1, msg, buf.String(), "")
} }

View File

@@ -113,7 +113,7 @@ func splitHostPort(s string) (host, port string) {
return s, "" return s, ""
} }
// Scan back, make sure we find ':' // Scan back, make sure we find ':'
for i := len(s) - 2; i > 0; i-- { for i := len(s) - 2; i >= 0; i-- {
c := s[i] c := s[i]
switch { switch {
case c == ':': case c == ':':
@@ -130,9 +130,7 @@ func splitHostPort(s string) (host, port string) {
// will check the correctness of the host. // will check the correctness of the host.
func ParseRequestURI(rawurl string) (*URL, error) { func ParseRequestURI(rawurl string) (*URL, error) {
if rawurl[0] == '/' { if rawurl[0] == '/' {
// OS X seems to send only path to the server if the url is 127.0.0.1
return &URL{Host: "", Path: rawurl}, nil return &URL{Host: "", Path: rawurl}, nil
// return nil, errors.New("Invalid proxy request URI: " + rawurl)
} }
var f []string var f []string

View File

@@ -13,12 +13,13 @@ func TestSplitHostPort(t *testing.T) {
{"google.com", "google.com", ""}, {"google.com", "google.com", ""},
{"google.com:80", "google.com", "80"}, {"google.com:80", "google.com", "80"},
{"google.com80", "google.com80", ""}, {"google.com80", "google.com80", ""},
{":7777", "", "7777"},
} }
for _, td := range testData { for _, td := range testData {
h, p := splitHostPort(td.host) h, p := splitHostPort(td.host)
if h != td.hostNoPort || p != td.port { if h != td.hostNoPort || p != td.port {
t.Errorf("%s returns %v %v", td.host, td.hostNoPort, td.port) t.Errorf("%s returns %v:%v", td.host, td.hostNoPort, td.port)
} }
} }
} }

11
main.go
View File

@@ -40,7 +40,6 @@ func main() {
initLog() initLog()
initProxyServerAddr()
initSocksServer() initSocksServer()
initShadowSocks() initShadowSocks()
@@ -50,8 +49,6 @@ func main() {
hasParentProxy = true hasParentProxy = true
} }
setSelfURL()
domainSet.load() domainSet.load()
/* /*
if *cpuprofile != "" { if *cpuprofile != "" {
@@ -77,6 +74,10 @@ func main() {
go sigHandler() go sigHandler()
go runSSH() go runSSH()
py := NewProxy(config.ListenAddr) if len(config.ListenAddr) > 1 {
py.Serve() for _, addr := range config.ListenAddr[1:] {
go NewProxy(addr).Serve()
}
}
NewProxy(config.ListenAddr[0]).Serve()
} }

40
pac.go
View File

@@ -3,7 +3,6 @@ package main
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"os" "os"
"strings" "strings"
"text/template" "text/template"
@@ -12,12 +11,11 @@ import (
var pac struct { var pac struct {
template *template.Template template *template.Template
topLevelDomain string topLevelDomain string
proxyServerAddr string
} }
func init() { func init() {
const pacRawTmpl = `var direct = 'DIRECT'; const pacRawTmpl = `var direct = 'DIRECT';
var httpProxy = '{{.ProxyAddr}}'; var httpProxy = 'PROXY {{.ProxyAddr}}; DIRECT';
var directList = [ var directList = [
"localhost", "localhost",
@@ -71,31 +69,11 @@ function FindProxyForURL(url, host) {
pac.topLevelDomain = buf.String()[:buf.Len()-2] // remove the final comma pac.topLevelDomain = buf.String()[:buf.Len()-2] // remove the final comma
} }
func initProxyServerAddr() {
listen, port := splitHostPort(config.ListenAddr)
if listen == "0.0.0.0" {
addrs, err := hostIP()
if err != nil {
errl.Println("Either change listen address to specific IP, or correct your host network settings.")
os.Exit(1)
}
for _, ip := range addrs {
pac.proxyServerAddr += fmt.Sprintf("PROXY %s:%s; ", ip, port)
}
pac.proxyServerAddr += "DIRECT"
info.Printf("proxy listen address is %s, PAC will have proxy address: %s\n",
config.ListenAddr, pac.proxyServerAddr)
} else {
pac.proxyServerAddr = fmt.Sprintf("PROXY %s; DIRECT", config.ListenAddr)
}
}
// No need for content-length as we are closing connection // No need for content-length as we are closing connection
var pacHeader = []byte("HTTP/1.1 200 OK\r\nServer: cow-proxy\r\n" + var pacHeader = []byte("HTTP/1.1 200 OK\r\nServer: cow-proxy\r\n" +
"Content-Type: application/x-ns-proxy-autoconfig\r\nConnection: close\r\n\r\n") "Content-Type: application/x-ns-proxy-autoconfig\r\nConnection: close\r\n\r\n")
func sendPAC(w io.Writer) { func sendPAC(c *clientConn) {
// domains in PAC file needs double quote // domains in PAC file needs double quote
ds1 := strings.Join(domainSet.alwaysDirect.toSlice(), "\",\n\"") ds1 := strings.Join(domainSet.alwaysDirect.toSlice(), "\",\n\"")
ds2 := strings.Join(domainSet.direct.toSlice(), "\",\n\"") ds2 := strings.Join(domainSet.direct.toSlice(), "\",\n\"")
@@ -109,10 +87,10 @@ func sendPAC(w io.Writer) {
} }
if ds == "" { if ds == "" {
// Empty direct domain list // Empty direct domain list
w.Write(pacHeader) c.Write(pacHeader)
pacproxy := fmt.Sprintf("function FindProxyForURL(url, host) { return '%s'; };", pacproxy := fmt.Sprintf("function FindProxyForURL(url, host) { return 'PROXY %s; DIRECT'; };",
pac.proxyServerAddr) c.proxy.addr)
w.Write([]byte(pacproxy)) c.Write([]byte(pacproxy))
return return
} }
@@ -121,12 +99,12 @@ func sendPAC(w io.Writer) {
DirectDomains string DirectDomains string
TopLevel string TopLevel string
}{ }{
pac.proxyServerAddr, c.proxy.addr,
",\n\"" + ds + "\"", ",\n\"" + ds + "\"",
pac.topLevelDomain, pac.topLevelDomain,
} }
if _, err := w.Write(pacHeader); err != nil { if _, err := c.Write(pacHeader); err != nil {
debug.Println("Error writing pac header") debug.Println("Error writing pac header")
return return
} }
@@ -135,7 +113,7 @@ func sendPAC(w io.Writer) {
if err := pac.template.Execute(buf, data); err != nil { if err := pac.template.Execute(buf, data); err != nil {
errl.Println("Error generating pac file:", err) errl.Println("Error generating pac file:", err)
} }
if _, err := w.Write(buf.Bytes()); err != nil { if _, err := c.Write(buf.Bytes()); err != nil {
debug.Println("Error writing pac content:", err) debug.Println("Error writing pac content:", err)
} }
} }

View File

@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"os"
// "reflect" // "reflect"
"strconv" "strconv"
"strings" "strings"
@@ -86,6 +85,7 @@ type clientConn struct {
bufRd *bufio.Reader bufRd *bufio.Reader
serverConn map[string]*serverConn // request serverConn, host:port as key serverConn map[string]*serverConn // request serverConn, host:port as key
buf []byte // buffer for reading request, avoids repeatedly allocating buffer buf []byte // buffer for reading request, avoids repeatedly allocating buffer
proxy *Proxy
} }
var ( var (
@@ -108,7 +108,7 @@ func (py *Proxy) Serve() {
ln, err := net.Listen("tcp", py.addr) ln, err := net.Listen("tcp", py.addr)
if err != nil { if err != nil {
fmt.Println("Server creation failed:", err) fmt.Println("Server creation failed:", err)
os.Exit(1) return
} }
info.Printf("COW proxy address %s, PAC url %s\n", py.addr, "http://"+py.addr+"/pac") info.Printf("COW proxy address %s, PAC url %s\n", py.addr, "http://"+py.addr+"/pac")
@@ -121,7 +121,7 @@ func (py *Proxy) Serve() {
if debug { if debug {
debug.Println("New Client:", conn.RemoteAddr()) debug.Println("New Client:", conn.RemoteAddr())
} }
c := newClientConn(conn) c := newClientConn(conn, py)
go c.serve() go c.serve()
} }
} }
@@ -130,11 +130,12 @@ func (py *Proxy) Serve() {
// bufio.Reader's Read // bufio.Reader's Read
const bufSize = 4096 const bufSize = 4096
func newClientConn(rwc net.Conn) *clientConn { func newClientConn(rwc net.Conn, proxy *Proxy) *clientConn {
c := &clientConn{ c := &clientConn{
Conn: rwc, Conn: rwc,
serverConn: map[string]*serverConn{}, serverConn: map[string]*serverConn{},
bufRd: bufio.NewReaderSize(rwc, bufSize), bufRd: bufio.NewReaderSize(rwc, bufSize),
proxy: proxy,
} }
return c return c
} }
@@ -151,7 +152,7 @@ func (c *clientConn) close() {
} }
func isSelfURL(url string) bool { func isSelfURL(url string) bool {
return url == "" || url == selfURLLH || url == selfURL127 return url == ""
} }
func (c *clientConn) getRequest() (r *Request) { func (c *clientConn) getRequest() (r *Request) {

View File

@@ -2,8 +2,8 @@ package main
import ( import (
"bufio" "bufio"
// "fmt"
"errors" "errors"
"fmt"
"io" "io"
"log" "log"
"net" "net"
@@ -93,13 +93,13 @@ func isDirExists(path string) (bool, error) {
func hostIP() (addrs []string, err error) { func hostIP() (addrs []string, err error) {
name, err := os.Hostname() name, err := os.Hostname()
if err != nil { if err != nil {
errl.Printf("Error get host name: %v\n", err) fmt.Printf("Error get host name: %v\n", err)
return return
} }
addrs, err = net.LookupHost(name) addrs, err = net.LookupHost(name)
if err != nil { if err != nil {
errl.Printf("Error getting host IP address: %v\n", err) fmt.Printf("Error getting host IP address: %v\n", err)
return return
} }
return return