From 471664c75de5f4168fe3e8ac8a16a9b22e79610d Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 13 Jun 2013 16:34:34 +0800 Subject: [PATCH] Check port in authentication. --- auth.go | 26 ++++++++++++++++++++------ auth_test.go | 8 ++++---- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/auth.go b/auth.go index 3907fa6..e9e36b4 100644 --- a/auth.go +++ b/auth.go @@ -35,8 +35,8 @@ type netAddr struct { type authUser struct { // user name is the key to auth.user, no need to store here passwd string - port uint16 // 0 means any port ha1 string // used in request digest, initialized ondemand + port uint16 // 0 means any port } var auth struct { @@ -79,7 +79,7 @@ func parseUserPasswd(userPasswd string) (user string, au *authUser, err error) { return "", nil, err } } - au = &authUser{passwd, uint16(port), ""} + au = &authUser{passwd, "", uint16(port)} return user, au, nil } @@ -87,7 +87,6 @@ func parseAllowedClient(val string) { if val == "" { return } - auth.required = true arr := strings.Split(val, ",") auth.allowedClient = make([]netAddr, len(arr)) for i, v := range arr { @@ -122,6 +121,7 @@ func addUserPasswd(val string) { return } user, au, err := parseUserPasswd(val) + debug.Println("user:", user, "port:", au.port) if err != nil { Fatal(err) } @@ -239,7 +239,7 @@ func calcRequestDigest(kv map[string]string, ha1, method string) string { return md5sum(buf.String()) } -func checkProxyAuthorization(r *Request) error { +func checkProxyAuthorization(conn *clientConn, r *Request) error { debug.Println("authorization:", r.ProxyAuthorization) arr := strings.SplitN(r.ProxyAuthorization, " ", 2) if len(arr) != 2 { @@ -264,18 +264,30 @@ func checkProxyAuthorization(r *Request) error { if time.Now().Sub(time.Unix(nonceTime, 0)) > time.Minute { return errAuthRequired } + user := authHeader["username"] au, ok := auth.user[user] if !ok { errl.Println("auth: no such user:", authHeader["username"]) return errAuthRequired } - au.initHA1(user) + + if au.port != 0 { + // check port + _, portStr := splitHostPort(conn.LocalAddr().String()) + port, _ := strconv.Atoi(portStr) + if uint16(port) != au.port { + errl.Println("auth: user", user, "port not match") + return errAuthRequired + } + } + if authHeader["qop"] != "auth" { msg := "auth: qop wrong: " + authHeader["qop"] errl.Println(msg) return errors.New(msg) } + response, ok := authHeader["response"] if !ok { msg := "auth: no request-digest" @@ -283,6 +295,7 @@ func checkProxyAuthorization(r *Request) error { return errors.New(msg) } + au.initHA1(user) digest := calcRequestDigest(authHeader, au.ha1, r.Method) if response == digest { return nil @@ -294,13 +307,14 @@ func checkProxyAuthorization(r *Request) error { func authUserPasswd(conn *clientConn, r *Request) (err error) { if r.ProxyAuthorization != "" { // client has sent authorization header - err = checkProxyAuthorization(r) + err = checkProxyAuthorization(conn, r) if err == nil { return } else if err != errAuthRequired { sendErrorPage(conn, errCodeBadReq, "Bad authorization request", err.Error()) return } + // auth required to through the following } nonce := genNonce() diff --git a/auth_test.go b/auth_test.go index d4684a5..4655fda 100644 --- a/auth_test.go +++ b/auth_test.go @@ -11,12 +11,12 @@ func TestParseUserPasswd(t *testing.T) { user string au *authUser }{ - {"foo:bar", "foo", &authUser{"bar", 0, ""}}, + {"foo:bar", "foo", &authUser{"bar", "", 0}}, {"foo:bar:-1", "", nil}, - {"hello:world:", "hello", &authUser{"world", 0, ""}}, + {"hello:world:", "hello", &authUser{"world", "", 0}}, {"hello:world:0", "", nil}, - {"hello:world:1024", "hello", &authUser{"world", 1024, ""}}, - {"hello:world:65535", "hello", &authUser{"world", 65535, ""}}, + {"hello:world:1024", "hello", &authUser{"world", "", 1024}}, + {"hello:world:65535", "hello", &authUser{"world", "", 65535}}, } for _, td := range testData {