diff --git a/auth.go b/auth.go index 9ce2007..2da891d 100644 --- a/auth.go +++ b/auth.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "net" - "os" "strconv" "strings" "text/template" @@ -55,24 +54,20 @@ func parseAllowedClient(val string) { s := strings.TrimSpace(v) ipAndMask := strings.Split(s, "/") if len(ipAndMask) > 2 { - fmt.Println("allowedClient syntax error: client should be the form ip/nbitmask") - os.Exit(1) + Fatal("allowedClient syntax error: client should be the form ip/nbitmask") } ip := net.ParseIP(ipAndMask[0]) if ip == nil { - fmt.Printf("allowedClient syntax error %s: ip address not valid\n", s) - os.Exit(1) + Fatal("allowedClient syntax error %s: ip address not valid\n", s) } var mask net.IPMask if len(ipAndMask) == 2 { nbit, err := strconv.Atoi(ipAndMask[1]) if err != nil { - fmt.Printf("allowedClient syntax error %s: %v\n", s, err) - os.Exit(1) + Fatal("allowedClient syntax error %s: %v\n", s, err) } if nbit > 32 { - fmt.Println("allowedClient error: mask number should <= 32") - os.Exit(1) + Fatal("allowedClient error: mask number should <= 32") } mask = NewNbitIPv4Mask(nbit) } else { @@ -113,8 +108,7 @@ func initAuth() { "Content-Length: " + fmt.Sprintf("%d", len(authRawBodyTmpl)) + "\r\n\r\n" + authRawBodyTmpl var err error if auth.template, err = template.New("auth").Parse(rawTemplate); err != nil { - errl.Println("Internal error generating auth template:", err) - os.Exit(1) + Fatal("Internal error generating auth template:", err) } } diff --git a/config.go b/config.go index 2af178f..b9e3f79 100644 --- a/config.go +++ b/config.go @@ -20,27 +20,38 @@ const ( ) type Config struct { - RcFile string // config file - ListenAddr []string - AddrInPAC []string - SocksParent string + RcFile string // config file + ListenAddr []string + LogFile string + AlwaysProxy bool + + // socks parent proxy + SocksParent string + SshServer string + + // http parent proxy HttpParent string hasHttpParent bool HttpUserPasswd string httpAuthHeader []byte // basic authentication header constructed from HttpUserPasswd - Core int - SshServer string - DetectSSLErr bool - LogFile string - AlwaysProxy bool - ShadowSocks []string - ShadowPasswd []string - ShadowMethod []string // shadowsocks encryption method - UserPasswd string - AllowedClient string - AuthTimeout time.Duration - DialTimeout time.Duration - ReadTimeout time.Duration + + // shadowsocks proxy + ShadowSocks []string + ShadowPasswd []string + ShadowMethod []string // shadowsocks encryption method + + // authenticate client + UserPasswd string + AllowedClient string + AuthTimeout time.Duration + + // advanced options + DialTimeout time.Duration + ReadTimeout time.Duration + + Core int + AddrInPAC []string + DetectSSLErr bool // not configurable in config file PrintVer bool @@ -99,8 +110,7 @@ func parseBool(v, msg string) bool { case "false": return false default: - fmt.Printf("Config error: %s should be true or false\n", msg) - os.Exit(1) + Fatalf("Config error: %s should be true or false\n", msg) } return false } @@ -108,8 +118,7 @@ func parseBool(v, msg string) bool { func parseInt(val, msg string) (i int) { var err error if i, err = strconv.Atoi(val); err != nil { - fmt.Printf("Config error: %s should be an integer\n", msg) - os.Exit(1) + Fatalf("Config error: %s should be an integer\n", msg) } return } @@ -117,64 +126,12 @@ func parseInt(val, msg string) (i int) { func parseDuration(val, msg string) (d time.Duration) { var err error if d, err = time.ParseDuration(val); err != nil { - fmt.Printf("Config error: %s %v\n", msg, err) - os.Exit(1) + Fatalf("Config error: %s %v\n", msg, err) } return } -type configParser struct{} - -func (p configParser) ParseListen(val string) { - // Has specified command line options - if config.ListenAddr != nil { - return - } - arr := strings.Split(val, ",") - config.ListenAddr = make([]string, len(arr)) - for i, s := range arr { - s = strings.TrimSpace(s) - host, port := splitHostPort(s) - if port == "" { - fmt.Printf("listen address %s has no port\n", s) - os.Exit(1) - } - if host == "" || host == "0.0.0.0" { - if len(arr) > 1 { - fmt.Printf("Too much listen addresses: "+ - "%s represents all ip addresses on this host.\n", s) - os.Exit(1) - } - } - config.ListenAddr[i] = s - } -} - -func (p configParser) ParseAddrInPAC(val string) { - arr := strings.Split(val, ",") - config.AddrInPAC = make([]string, len(arr)) - for i, s := range arr { - if s == "" { - continue - } - s = strings.TrimSpace(s) - host, port := splitHostPort(s) - if port == "" { - fmt.Printf("proxy address in PAC %s has no port\n", s) - os.Exit(1) - } - if host == "0.0.0.0" { - fmt.Println("Can't use 0.0.0.0 as proxy address in PAC") - os.Exit(1) - } - config.AddrInPAC[i] = s - } -} - func hasPort(val string) bool { - if val == "" { // default value is empty - return true - } _, port := splitHostPort(val) if port == "" { return false @@ -193,6 +150,54 @@ func isUserPasswdValid(val string) bool { return true } +type configParser struct{} + +func (p configParser) ParseLogFile(val string) { + config.LogFile = val +} + +func (p configParser) ParseListen(val string) { + // Has specified command line options + if config.ListenAddr != nil { + return + } + arr := strings.Split(val, ",") + config.ListenAddr = make([]string, len(arr)) + for i, s := range arr { + s = strings.TrimSpace(s) + host, port := splitHostPort(s) + if port == "" { + Fatalf("listen address %s has no port\n", s) + } + if host == "" || host == "0.0.0.0" { + if len(arr) > 1 { + Fatalf("Too much listen addresses: "+ + "%s represents all ip addresses on this host.\n", s) + } + } + config.ListenAddr[i] = s + } +} + +func (p configParser) ParseAddrInPAC(val string) { + arr := strings.Split(val, ",") + config.AddrInPAC = make([]string, len(arr)) + for i, s := range arr { + if s == "" { + continue + } + s = strings.TrimSpace(s) + host, port := splitHostPort(s) + if port == "" { + Fatalf("proxy address in PAC %s has no port\n", s) + } + if host == "0.0.0.0" { + Fatal("Can't use 0.0.0.0 as proxy address in PAC") + } + config.AddrInPAC[i] = s + } +} + func (p configParser) ParseSocks(val string) { fmt.Println("socks option is going to be renamed to socksParent in the future, please change it") p.ParseSocksParent(val) @@ -202,17 +207,19 @@ func (p configParser) ParseSocks(val string) { func (p configParser) ParseSocksParent(val string) { if val == "" { - fmt.Println("empty socks parent") - os.Exit(1) + Fatal("empty socks parent") } config.SocksParent = val parentProxyCreator = append(parentProxyCreator, createctSocksConnection) } +func (p configParser) ParseSshServer(val string) { + config.SshServer = val +} + func (p configParser) ParseHttpParent(val string) { if val == "" { - fmt.Println("empty http parent") - os.Exit(1) + Fatal("empty http parent") } config.HttpParent = val parentProxyCreator = append(parentProxyCreator, createHttpProxyConnection) @@ -220,20 +227,11 @@ func (p configParser) ParseHttpParent(val string) { func (p configParser) ParseHttpUserPasswd(val string) { if val == "" { - fmt.Println("empty http user passwd") - os.Exit(1) + Fatal("empty http user passwd") } config.HttpUserPasswd = val } -func (p configParser) ParseCore(val string) { - config.Core = parseInt(val, "core") -} - -func (p configParser) ParseSshServer(val string) { - config.SshServer = val -} - func (p configParser) ParseUpdateBlocked(val string) { // config.UpdateBlocked = parseBool(val, "updateBlocked") fmt.Println("updateBlocked option will be removed in future, please remove it") @@ -249,26 +247,16 @@ func (p configParser) ParseAutoRetry(val string) { fmt.Println("autoRetry option will be removed in future, please remove it") } -func (p configParser) ParseDetectSSLErr(val string) { - config.DetectSSLErr = parseBool(val, "detectSSLErr") -} - -func (p configParser) ParseLogFile(val string) { - config.LogFile = val -} - func (p configParser) ParseAlwaysProxy(val string) { config.AlwaysProxy = parseBool(val, "alwaysProxy") } func (p configParser) ParseShadowSocks(val string) { if val == "" { - fmt.Println("empty shadowsocks server") - os.Exit(1) + Fatal("empty shadowsocks server") } if !hasPort(val) { - fmt.Println("shadowsocks server must have port specified") - os.Exit(1) + Fatal("shadowsocks server must have port specified") } parentProxyCreator = append(parentProxyCreator, createShadowSocksConnecter(len(config.ShadowSocks))) config.ShadowSocks = append(config.ShadowSocks, val) @@ -276,24 +264,20 @@ func (p configParser) ParseShadowSocks(val string) { func (p configParser) ParseShadowPasswd(val string) { if val == "" { - fmt.Println("empty shadowsocks password") - os.Exit(1) + Fatal("empty shadowsocks password") } if len(config.ShadowPasswd)+1 > len(config.ShadowSocks) { - fmt.Println("must specify shadowsocks server before it's password") - os.Exit(1) + Fatal("must specify shadowsocks server before it's password") } if len(config.ShadowPasswd)+1 < len(config.ShadowSocks) { - fmt.Println("must specify shadowsocks password for each server") - os.Exit(1) + Fatal("must specify shadowsocks password for each server") } config.ShadowPasswd = append(config.ShadowPasswd, val) } func (p configParser) ParseShadowMethod(val string) { if len(config.ShadowMethod)+1 > len(config.ShadowSocks) { - fmt.Println("must specify shadowsocks server before it's encryption method") - os.Exit(1) + Fatal("must specify shadowsocks server before it's encryption method") } for len(config.ShadowMethod)+1 < len(config.ShadowSocks) { // use empty string for unspecified encryption method @@ -306,6 +290,9 @@ func (p configParser) ParseShadowMethod(val string) { // doesn't need to know the details of authentication implementation. func (p configParser) ParseUserPasswd(val string) { + if val == "" { + Fatal("empty userPasswd") + } config.UserPasswd = val } @@ -317,6 +304,10 @@ func (p configParser) ParseAuthTimeout(val string) { config.AuthTimeout = parseDuration(val, "authTimeout") } +func (p configParser) ParseCore(val string) { + config.Core = parseInt(val, "core") +} + func (p configParser) ParseReadTimeout(val string) { config.ReadTimeout = parseDuration(val, "readTimeout") } @@ -325,6 +316,10 @@ func (p configParser) ParseDialTimeout(val string) { config.DialTimeout = parseDuration(val, "dialTimeout") } +func (p configParser) ParseDetectSSLErr(val string) { + config.DetectSSLErr = parseBool(val, "detectSSLErr") +} + func parseConfig(path string) { // fmt.Println("rcFile:", path) f, err := os.Open(expandTilde(path)) @@ -351,8 +346,7 @@ func parseConfig(path string) { if err == io.EOF { return } else if err != nil { - errl.Printf("Error reading rc file: %v\n", err) - os.Exit(1) + Fatalf("Error reading rc file: %v\n", err) } line = strings.TrimSpace(line) @@ -362,8 +356,7 @@ func parseConfig(path string) { v := strings.Split(line, "=") if len(v) != 2 { - fmt.Println("Config error: syntax error on line", n) - os.Exit(1) + Fatal("Config error: syntax error on line", n) } key, val := strings.TrimSpace(v[0]), strings.TrimSpace(v[1]) if val == "" { @@ -373,8 +366,7 @@ func parseConfig(path string) { methodName := "Parse" + strings.ToUpper(key[0:1]) + key[1:] method := parser.MethodByName(methodName) if method == zeroMethod { - fmt.Printf("Config error: no such option \"%s\"\n", key) - os.Exit(1) + Fatalf("Config error: no such option \"%s\"\n", key) } args := []reflect.Value{reflect.ValueOf(val)} method.Call(args) @@ -409,8 +401,7 @@ func updateConfig(nc *Config) { } if config.AddrInPAC != nil { if len(config.AddrInPAC) != len(config.ListenAddr) { - fmt.Println("Number of listen addresses and addr in PAC not match.") - os.Exit(1) + Fatal("Number of listen addresses and addr in PAC not match.") } } else { // empty string in addrInPac means same as listenAddr @@ -429,24 +420,19 @@ func updateConfig(nc *Config) { // call it after updateConfig. func checkConfig() { if !hasPort(config.HttpParent) { - fmt.Println("parent http server must have port specified") - os.Exit(1) + Fatal("parent http server must have port specified") } if !hasPort(config.SocksParent) { - fmt.Println("parent socks server must have port specified") - os.Exit(1) + Fatal("parent socks server must have port specified") } if !isUserPasswdValid(config.UserPasswd) { - fmt.Println("user password syntax wrong, should be in the form of user:passwd") - os.Exit(1) + Fatal("user password syntax wrong, should be in the form of user:passwd") } if !isUserPasswdValid(config.HttpUserPasswd) { - fmt.Println("http parent user password syntax wrong, should be in the form of user:passwd") - os.Exit(1) + Fatal("http parent user password syntax wrong, should be in the form of user:passwd") } if len(config.ShadowSocks) != len(config.ShadowPasswd) { - fmt.Println("number of shadowsocks server and password does not match") - os.Exit(1) + Fatal("number of shadowsocks server and password does not match") } for len(config.ShadowMethod) < len(config.ShadowSocks) { config.ShadowMethod = append(config.ShadowMethod, "") diff --git a/error.go b/error.go index 7241e8c..cc7b914 100644 --- a/error.go +++ b/error.go @@ -2,9 +2,7 @@ package main import ( "bytes" - "fmt" "io" - "os" "text/template" "time" ) @@ -35,12 +33,10 @@ var errPageTmpl, headTmpl, blockedFormTmpl, directFormTmpl *template.Template func init() { var err error if headTmpl, err = template.New("errorHead").Parse(headRawTmpl); err != nil { - fmt.Println("Internal error on generating error head template") - os.Exit(1) + Fatal("Internal error on generating error head template") } if errPageTmpl, err = template.New("errorPage").Parse(errPageRawTmpl); err != nil { - fmt.Println("Internal error on generating error page template") - os.Exit(1) + Fatalf("Internal error on generating error page template") } } diff --git a/http.go b/http.go index 5d6da35..4cd0bec 100644 --- a/http.go +++ b/http.go @@ -201,6 +201,9 @@ func (url *URL) HostIsIP() bool { // For port, return empty string if no port specified. // This also works for IPv6 address. func splitHostPort(s string) (host, port string) { + if len(s) == 0 { + return "", "" + } // Common case should has no port, check the last char first if !IsDigit(s[len(s)-1]) { return s, "" diff --git a/log.go b/log.go index 194621e..8f8aba5 100644 --- a/log.go +++ b/log.go @@ -122,3 +122,13 @@ func (d responseLogging) Printf(format string, args ...interface{}) { responseLog.Printf(format, args...) } } + +func Fatal(args ...interface{}) { + fmt.Println(args...) + os.Exit(1) +} + +func Fatalf(format string, args ...interface{}) { + fmt.Printf(format, args...) + os.Exit(1) +} diff --git a/main.go b/main.go index 24a9c77..42719bd 100644 --- a/main.go +++ b/main.go @@ -62,8 +62,7 @@ func main() { if *cpuprofile != "" { f, err := os.Create(*cpuprofile) if err != nil { - info.Println(err) - os.Exit(1) + Fatal(err) } pprof.StartCPUProfile(f) } diff --git a/pac.go b/pac.go index 17d5454..3d8ec94 100644 --- a/pac.go +++ b/pac.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "net" - "os" "strings" "text/template" "time" @@ -80,8 +79,7 @@ function FindProxyForURL(url, host) { var err error pac.template, err = template.New("pac").Parse(pacRawTmpl) if err != nil { - fmt.Println("Internal error on generating pac file template:", err) - os.Exit(1) + Fatal("Internal error on generating pac file template:", err) } var buf bytes.Buffer diff --git a/shadowsocks.go b/shadowsocks.go index 52f9284..7db7997 100644 --- a/shadowsocks.go +++ b/shadowsocks.go @@ -2,9 +2,7 @@ package main import ( "errors" - "fmt" ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" - "os" ) var noShadowSocksErr = errors.New("No shadowsocks configuration") @@ -19,8 +17,7 @@ func initShadowSocks() { for i, _ := range config.ShadowSocks { // initialize cipher for each shadowsocks connection if c, err := ss.NewCipher(config.ShadowMethod[i], config.ShadowPasswd[i]); err != nil { - fmt.Println("creating shadowsocks cipher:", err) - os.Exit(1) + Fatal("creating shadowsocks cipher:", err) } else { cipher = append(cipher, c) } diff --git a/util.go b/util.go index 24db4d7..c7c656c 100644 --- a/util.go +++ b/util.go @@ -181,7 +181,7 @@ var digitTbl = [256]int8{ // No prefix (e.g. 0xdeadbeef) should given. // base can only be 10 or 16. func ParseIntFromBytes(b []byte, base int) (n int64, err error) { - // Currently, one have to convert []byte to string to use strconv + // Currently, we have to convert []byte to string to use strconv // Refer to: http://code.google.com/p/go/issues/detail?id=2632 // That's why I created this function. if base != 10 && base != 16 {