mirror of
https://github.com/zhigang1992/mitmproxy.git
synced 2026-04-24 04:14:57 +08:00
clean up http message models
This commit is contained in:
@@ -50,14 +50,14 @@ def _assemble_request_line(request, form=None):
|
||||
return b"%s %s %s" % (
|
||||
request.method,
|
||||
request.path,
|
||||
request.httpversion
|
||||
request.http_version
|
||||
)
|
||||
elif form == "authority":
|
||||
return b"%s %s:%d %s" % (
|
||||
request.method,
|
||||
request.host,
|
||||
request.port,
|
||||
request.httpversion
|
||||
request.http_version
|
||||
)
|
||||
elif form == "absolute":
|
||||
return b"%s %s://%s:%d%s %s" % (
|
||||
@@ -66,7 +66,7 @@ def _assemble_request_line(request, form=None):
|
||||
request.host,
|
||||
request.port,
|
||||
request.path,
|
||||
request.httpversion
|
||||
request.http_version
|
||||
)
|
||||
else: # pragma: nocover
|
||||
raise RuntimeError("Invalid request form")
|
||||
@@ -93,7 +93,7 @@ def _assemble_request_headers(request):
|
||||
|
||||
def _assemble_response_line(response):
|
||||
return b"%s %d %s" % (
|
||||
response.httpversion,
|
||||
response.http_version,
|
||||
response.status_code,
|
||||
response.msg,
|
||||
)
|
||||
|
||||
@@ -193,15 +193,45 @@ class Headers(MutableMapping, object):
|
||||
return cls([list(field) for field in state])
|
||||
|
||||
|
||||
class Request(object):
|
||||
class Message(object):
|
||||
def __init__(self, http_version, headers, body, timestamp_start, timestamp_end):
|
||||
self.http_version = http_version
|
||||
if not headers:
|
||||
headers = Headers()
|
||||
assert isinstance(headers, Headers)
|
||||
self.headers = headers
|
||||
|
||||
self._body = body
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self._body
|
||||
|
||||
@body.setter
|
||||
def body(self, body):
|
||||
self._body = body
|
||||
if isinstance(body, bytes):
|
||||
self.headers[b"Content-Length"] = str(len(body)).encode()
|
||||
|
||||
content = body
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Message):
|
||||
return self.__dict__ == other.__dict__
|
||||
return False
|
||||
|
||||
|
||||
class Request(Message):
|
||||
# This list is adopted legacy code.
|
||||
# We probably don't need to strip off keep-alive.
|
||||
_headers_to_strip_off = [
|
||||
'Proxy-Connection',
|
||||
'Keep-Alive',
|
||||
'Connection',
|
||||
'Transfer-Encoding',
|
||||
'Upgrade',
|
||||
b'Proxy-Connection',
|
||||
b'Keep-Alive',
|
||||
b'Connection',
|
||||
b'Transfer-Encoding',
|
||||
b'Upgrade',
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -212,16 +242,14 @@ class Request(object):
|
||||
host,
|
||||
port,
|
||||
path,
|
||||
httpversion,
|
||||
http_version,
|
||||
headers=None,
|
||||
body=None,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
form_out=None
|
||||
):
|
||||
if not headers:
|
||||
headers = Headers()
|
||||
assert isinstance(headers, Headers)
|
||||
super(Request, self).__init__(http_version, headers, body, timestamp_start, timestamp_end)
|
||||
|
||||
self.form_in = form_in
|
||||
self.method = method
|
||||
@@ -229,23 +257,8 @@ class Request(object):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.path = path
|
||||
self.httpversion = httpversion
|
||||
self.headers = headers
|
||||
self._body = body
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
self.form_out = form_out or form_in
|
||||
|
||||
def __eq__(self, other):
|
||||
try:
|
||||
self_d = [self.__dict__[k] for k in self.__dict__ if
|
||||
k not in ('timestamp_start', 'timestamp_end')]
|
||||
other_d = [other.__dict__[k] for k in other.__dict__ if
|
||||
k not in ('timestamp_start', 'timestamp_end')]
|
||||
return self_d == other_d
|
||||
except:
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
if self.host and self.port:
|
||||
hostport = "{}:{}".format(self.host, self.port)
|
||||
@@ -262,8 +275,8 @@ class Request(object):
|
||||
response. That is, we remove ETags and If-Modified-Since headers.
|
||||
"""
|
||||
delheaders = [
|
||||
"if-modified-since",
|
||||
"if-none-match",
|
||||
b"if-modified-since",
|
||||
b"if-none-match",
|
||||
]
|
||||
for i in delheaders:
|
||||
self.headers.pop(i, None)
|
||||
@@ -273,16 +286,16 @@ class Request(object):
|
||||
Modifies this request to remove headers that will compress the
|
||||
resource's data.
|
||||
"""
|
||||
self.headers["accept-encoding"] = "identity"
|
||||
self.headers[b"accept-encoding"] = b"identity"
|
||||
|
||||
def constrain_encoding(self):
|
||||
"""
|
||||
Limits the permissible Accept-Encoding values, based on what we can
|
||||
decode appropriately.
|
||||
"""
|
||||
accept_encoding = self.headers.get("accept-encoding")
|
||||
accept_encoding = self.headers.get(b"accept-encoding")
|
||||
if accept_encoding:
|
||||
self.headers["accept-encoding"] = (
|
||||
self.headers[b"accept-encoding"] = (
|
||||
', '.join(
|
||||
e
|
||||
for e in encoding.ENCODINGS
|
||||
@@ -335,7 +348,7 @@ class Request(object):
|
||||
"""
|
||||
# FIXME: If there's an existing content-type header indicating a
|
||||
# url-encoded form, leave it alone.
|
||||
self.headers["Content-Type"] = HDR_FORM_URLENCODED
|
||||
self.headers[b"Content-Type"] = HDR_FORM_URLENCODED
|
||||
self.body = utils.urlencode(odict.lst)
|
||||
|
||||
def get_path_components(self):
|
||||
@@ -452,37 +465,17 @@ class Request(object):
|
||||
raise ValueError("Invalid URL: %s" % url)
|
||||
self.scheme, self.host, self.port, self.path = parts
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self._body
|
||||
|
||||
@body.setter
|
||||
def body(self, body):
|
||||
self._body = body
|
||||
if isinstance(body, bytes):
|
||||
self.headers["Content-Length"] = str(len(body)).encode()
|
||||
|
||||
@property
|
||||
def content(self): # pragma: no cover
|
||||
# TODO: remove deprecated getter
|
||||
return self.body
|
||||
|
||||
@content.setter
|
||||
def content(self, content): # pragma: no cover
|
||||
# TODO: remove deprecated setter
|
||||
self.body = content
|
||||
|
||||
|
||||
class Response(object):
|
||||
class Response(Message):
|
||||
_headers_to_strip_off = [
|
||||
'Proxy-Connection',
|
||||
'Alternate-Protocol',
|
||||
'Alt-Svc',
|
||||
b'Proxy-Connection',
|
||||
b'Alternate-Protocol',
|
||||
b'Alt-Svc',
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
httpversion,
|
||||
http_version,
|
||||
status_code,
|
||||
msg=None,
|
||||
headers=None,
|
||||
@@ -490,27 +483,9 @@ class Response(object):
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
):
|
||||
if not headers:
|
||||
headers = Headers()
|
||||
assert isinstance(headers, Headers)
|
||||
|
||||
self.httpversion = httpversion
|
||||
super(Response, self).__init__(http_version, headers, body, timestamp_start, timestamp_end)
|
||||
self.status_code = status_code
|
||||
self.msg = msg
|
||||
self.headers = headers
|
||||
self._body = body
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
def __eq__(self, other):
|
||||
try:
|
||||
self_d = [self.__dict__[k] for k in self.__dict__ if
|
||||
k not in ('timestamp_start', 'timestamp_end')]
|
||||
other_d = [other.__dict__[k] for k in other.__dict__ if
|
||||
k not in ('timestamp_start', 'timestamp_end')]
|
||||
return self_d == other_d
|
||||
except:
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
# return "Response(%s - %s)" % (self.status_code, self.msg)
|
||||
@@ -536,7 +511,7 @@ class Response(object):
|
||||
attributes (e.g. HTTPOnly) are indicated by a Null value.
|
||||
"""
|
||||
ret = []
|
||||
for header in self.headers.get_all("set-cookie"):
|
||||
for header in self.headers.get_all(b"set-cookie"):
|
||||
v = cookies.parse_set_cookie_header(header)
|
||||
if v:
|
||||
name, value, attrs = v
|
||||
@@ -559,34 +534,4 @@ class Response(object):
|
||||
i[1][1]
|
||||
)
|
||||
)
|
||||
self.headers.set_all("Set-Cookie", values)
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self._body
|
||||
|
||||
@body.setter
|
||||
def body(self, body):
|
||||
self._body = body
|
||||
if isinstance(body, bytes):
|
||||
self.headers["Content-Length"] = str(len(body)).encode()
|
||||
|
||||
@property
|
||||
def content(self): # pragma: no cover
|
||||
# TODO: remove deprecated getter
|
||||
return self.body
|
||||
|
||||
@content.setter
|
||||
def content(self, content): # pragma: no cover
|
||||
# TODO: remove deprecated setter
|
||||
self.body = content
|
||||
|
||||
@property
|
||||
def code(self): # pragma: no cover
|
||||
# TODO: remove deprecated getter
|
||||
return self.status_code
|
||||
|
||||
@code.setter
|
||||
def code(self, code): # pragma: no cover
|
||||
# TODO: remove deprecated setter
|
||||
self.status_code = code
|
||||
self.headers.set_all(b"Set-Cookie", values)
|
||||
|
||||
@@ -105,7 +105,7 @@ def treq(**kwargs):
|
||||
host=b"address",
|
||||
port=22,
|
||||
path=b"/path",
|
||||
httpversion=b"HTTP/1.1",
|
||||
http_version=b"HTTP/1.1",
|
||||
headers=Headers(header=b"qvalue"),
|
||||
body=b"content"
|
||||
)
|
||||
@@ -119,7 +119,7 @@ def tresp(**kwargs):
|
||||
netlib.http.Response
|
||||
"""
|
||||
default = dict(
|
||||
httpversion=b"HTTP/1.1",
|
||||
http_version=b"HTTP/1.1",
|
||||
status_code=200,
|
||||
msg=b"OK",
|
||||
headers=Headers(header_response=b"svalue"),
|
||||
|
||||
@@ -17,11 +17,6 @@ def isascii(bytes):
|
||||
return True
|
||||
|
||||
|
||||
# best way to do it in python 2.x
|
||||
def bytes_to_int(i):
|
||||
return int(i.encode('hex'), 16)
|
||||
|
||||
|
||||
def clean_bin(s, keep_spacing=True):
|
||||
"""
|
||||
Cleans binary data to make it safe to display.
|
||||
@@ -51,21 +46,15 @@ def clean_bin(s, keep_spacing=True):
|
||||
|
||||
def hexdump(s):
|
||||
"""
|
||||
Returns a set of tuples:
|
||||
(offset, hex, str)
|
||||
Returns:
|
||||
A generator of (offset, hex, str) tuples
|
||||
"""
|
||||
parts = []
|
||||
for i in range(0, len(s), 16):
|
||||
o = "%.10x" % i
|
||||
offset = b"%.10x" % i
|
||||
part = s[i:i + 16]
|
||||
x = " ".join("%.2x" % ord(i) for i in part)
|
||||
if len(part) < 16:
|
||||
x += " "
|
||||
x += " ".join(" " for i in range(16 - len(part)))
|
||||
parts.append(
|
||||
(o, x, clean_bin(part, False))
|
||||
)
|
||||
return parts
|
||||
x = b" ".join(b"%.2x" % i for i in six.iterbytes(part))
|
||||
x = x.ljust(47) # 16*2 + 15
|
||||
yield (offset, x, clean_bin(part, False))
|
||||
|
||||
|
||||
def setbit(byte, offset, value):
|
||||
@@ -80,8 +69,7 @@ def setbit(byte, offset, value):
|
||||
|
||||
def getbit(byte, offset):
|
||||
mask = 1 << offset
|
||||
if byte & mask:
|
||||
return True
|
||||
return bool(byte & mask)
|
||||
|
||||
|
||||
class BiDi(object):
|
||||
@@ -159,7 +147,7 @@ def is_valid_host(host):
|
||||
return False
|
||||
if len(host) > 255:
|
||||
return False
|
||||
if host[-1] == ".":
|
||||
if host[-1] == b".":
|
||||
host = host[:-1]
|
||||
return all(_label_valid.match(x) for x in host.split(b"."))
|
||||
|
||||
@@ -248,7 +236,7 @@ def hostport(scheme, host, port):
|
||||
"""
|
||||
Returns the host component, with a port specifcation if needed.
|
||||
"""
|
||||
if (port, scheme) in [(80, "http"), (443, "https")]:
|
||||
if (port, scheme) in [(80, b"http"), (443, b"https")]:
|
||||
return host
|
||||
else:
|
||||
return b"%s:%d" % (host, port)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
import os
|
||||
import struct
|
||||
import io
|
||||
import six
|
||||
|
||||
from .protocol import Masker
|
||||
from netlib import tcp
|
||||
@@ -127,8 +128,8 @@ class FrameHeader(object):
|
||||
"""
|
||||
read a websockets frame header
|
||||
"""
|
||||
first_byte = utils.bytes_to_int(fp.safe_read(1))
|
||||
second_byte = utils.bytes_to_int(fp.safe_read(1))
|
||||
first_byte = six.byte2int(fp.safe_read(1))
|
||||
second_byte = six.byte2int(fp.safe_read(1))
|
||||
|
||||
fin = utils.getbit(first_byte, 7)
|
||||
rsv1 = utils.getbit(first_byte, 6)
|
||||
@@ -145,9 +146,9 @@ class FrameHeader(object):
|
||||
if length_code <= 125:
|
||||
payload_length = length_code
|
||||
elif length_code == 126:
|
||||
payload_length = utils.bytes_to_int(fp.safe_read(2))
|
||||
payload_length, = struct.unpack("!H", fp.safe_read(2))
|
||||
elif length_code == 127:
|
||||
payload_length = utils.bytes_to_int(fp.safe_read(8))
|
||||
payload_length, = struct.unpack("!Q", fp.safe_read(8))
|
||||
|
||||
# masking key only present if mask bit set
|
||||
if mask_bit == 1:
|
||||
|
||||
@@ -17,6 +17,7 @@ from __future__ import absolute_import
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import six
|
||||
from ..http import Headers
|
||||
from .. import utils
|
||||
|
||||
@@ -40,7 +41,7 @@ class Masker(object):
|
||||
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
self.masks = [utils.bytes_to_int(byte) for byte in key]
|
||||
self.masks = [six.byte2int(byte) for byte in key]
|
||||
self.offset = 0
|
||||
|
||||
def mask(self, offset, data):
|
||||
|
||||
@@ -413,7 +413,7 @@ class TestReadResponse(tservers.ServerTestBase):
|
||||
|
||||
resp = protocol.read_response(NotImplemented, stream_id=42)
|
||||
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.http_version == (2, 0)
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']]
|
||||
@@ -440,7 +440,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
|
||||
resp = protocol.read_response(NotImplemented, stream_id=42)
|
||||
|
||||
assert resp.stream_id == 42
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.http_version == (2, 0)
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']]
|
||||
|
||||
@@ -103,11 +103,11 @@ def test_get_header_tokens():
|
||||
headers = Headers()
|
||||
assert utils.get_header_tokens(headers, "foo") == []
|
||||
headers["foo"] = "bar"
|
||||
assert utils.get_header_tokens(headers, "foo") == ["bar"]
|
||||
assert utils.get_header_tokens(headers, "foo") == [b"bar"]
|
||||
headers["foo"] = "bar, voing"
|
||||
assert utils.get_header_tokens(headers, "foo") == ["bar", "voing"]
|
||||
assert utils.get_header_tokens(headers, "foo") == [b"bar", b"voing"]
|
||||
headers.set_all("foo", ["bar, voing", "oink"])
|
||||
assert utils.get_header_tokens(headers, "foo") == ["bar", "voing", "oink"]
|
||||
assert utils.get_header_tokens(headers, "foo") == [b"bar", b"voing", b"oink"]
|
||||
|
||||
|
||||
def test_multipartdecode():
|
||||
@@ -134,8 +134,8 @@ def test_multipartdecode():
|
||||
|
||||
def test_parse_content_type():
|
||||
p = utils.parse_content_type
|
||||
assert p("text/html") == ("text", "html", {})
|
||||
assert p("text") is None
|
||||
assert p(b"text/html") == (b"text", b"html", {})
|
||||
assert p(b"text") is None
|
||||
|
||||
v = p("text/html; charset=UTF-8")
|
||||
assert v == ('text', 'html', {'charset': 'UTF-8'})
|
||||
v = p(b"text/html; charset=UTF-8")
|
||||
assert v == (b'text', b'html', {b'charset': b'UTF-8'})
|
||||
|
||||
Reference in New Issue
Block a user