diff options
Diffstat (limited to 'libmproxy/protocol/http.py')
-rw-r--r-- | libmproxy/protocol/http.py | 304 |
1 files changed, 167 insertions, 137 deletions
diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 49310ec3..852ce393 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -6,15 +6,16 @@ import time import copy from email.utils import parsedate_tz, formatdate, mktime_tz import threading -from netlib import http, tcp, http_status +from netlib import http, tcp, http_status, http_cookies import netlib.utils -from netlib.odict import ODict, ODictCaseless +from netlib import odict from .tcp import TCPHandler from .primitives import KILL, ProtocolHandler, Flow, Error from ..proxy.connection import ServerConnection from .. import encoding, utils, controller, stateobject, proxy HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 @@ -22,19 +23,6 @@ class KillSignal(Exception): pass -def get_line(fp): - """ - Get a line, possibly preceded by a blank. - """ - line = fp.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = fp.readline() - if line == "": - raise tcp.NetLibDisconnect() - return line - - def send_connect_request(conn, host, port, update_state=True): upstream_request = HTTPRequest( "authority", @@ -44,7 +32,7 @@ def send_connect_request(conn, host, port, update_state=True): port, None, (1, 1), - ODictCaseless(), + odict.ODictCaseless(), "" ) conn.send(upstream_request.assemble()) @@ -99,7 +87,7 @@ class HTTPMessage(stateobject.StateObject): timestamp_end=None): self.httpversion = httpversion self.headers = headers - """@type: ODictCaseless""" + """@type: odict.ODictCaseless""" self.content = content self.timestamp_start = timestamp_start @@ -107,7 +95,7 @@ class HTTPMessage(stateobject.StateObject): _stateobject_attributes = dict( httpversion=tuple, - headers=ODictCaseless, + headers=odict.ODictCaseless, content=str, timestamp_start=float, timestamp_end=float @@ -119,6 +107,8 @@ class HTTPMessage(stateobject.StateObject): if short: if self.content: ret["contentLength"] = len(self.content) + elif self.content == CONTENT_MISSING: + ret["contentLength"] = None else: ret["contentLength"] = 0 return ret @@ -239,7 +229,7 @@ class HTTPRequest(HTTPMessage): httpversion: HTTP version tuple, e.g. (1,1) - headers: ODictCaseless object + headers: odict.ODictCaseless object content: Content of the request, None, or CONTENT_MISSING if there is content associated, but not present. CONTENT_MISSING evaluates @@ -277,7 +267,7 @@ class HTTPRequest(HTTPMessage): timestamp_end=None, form_out=None ): - assert isinstance(headers, ODictCaseless) or not headers + assert isinstance(headers, odict.ODictCaseless) or not headers HTTPMessage.__init__( self, httpversion, @@ -325,78 +315,46 @@ class HTTPRequest(HTTPMessage): ) @classmethod - def from_stream(cls, rfile, include_body=True, body_size_limit=None): + def from_stream(cls, rfile, include_body=True, body_size_limit=None, wfile=None): """ Parse an HTTP request from a file stream + + Args: + rfile (file): Input file to read from + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled automatically. + by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + HTTPRequest: The HTTP request + + Raises: + HttpError: If the input is invalid. """ - httpversion, host, port, scheme, method, path, headers, content, timestamp_start, timestamp_end = ( - None, None, None, None, None, None, None, None, None, None) + timestamp_start, timestamp_end = None, None timestamp_start = utils.timestamp() - if hasattr(rfile, "reset_timestamps"): rfile.reset_timestamps() - request_line = get_line(rfile) - - if hasattr(rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = rfile.first_byte_timestamp - - request_line_parts = http.parse_init(request_line) - if not request_line_parts: - raise http.HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not netlib.utils.isascii(path): - raise http.HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = http.parse_init_connect(request_line) - if not r: - raise http.HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, _ = r - path = None - else: - form_in = "absolute" - r = http.parse_init_proxy(request_line) - if not r: - raise http.HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = http.read_headers(rfile) - if headers is None: - raise http.HttpError(400, "Invalid headers") - - if include_body: - content = http.read_http_body(rfile, headers, body_size_limit, - method, None, True) - timestamp_end = utils.timestamp() - + req = http.read_request( + rfile, + include_body = include_body, + body_size_limit = body_size_limit, + wfile = wfile + ) + timestamp_end = utils.timestamp() return HTTPRequest( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content, + req.form_in, + req.method, + req.scheme, + req.host, + req.port, + req.path, + req.httpversion, + req.headers, + req.content, timestamp_start, timestamp_end ) @@ -444,7 +402,7 @@ class HTTPRequest(HTTPMessage): if self.content or self.content == "": headers["Content-Length"] = [str(len(self.content))] - return str(headers) + return headers.format() def _assemble_head(self, form=None): return "%s\r\n%s\r\n" % ( @@ -507,6 +465,19 @@ class HTTPRequest(HTTPMessage): """ self.headers["Host"] = [self.host] + def get_form(self): + """ + Retrieves the URL-encoded or multipart form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.content: + if self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): + return self.get_form_urlencoded() + elif self.headers.in_any("content-type", HDR_FORM_MULTIPART, True): + return self.get_form_multipart() + return odict.ODict([]) + def get_form_urlencoded(self): """ Retrieves the URL-encoded form data, returning an ODict object. @@ -514,8 +485,13 @@ class HTTPRequest(HTTPMessage): indicates non-form data. """ if self.content and self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): - return ODict(utils.urldecode(self.content)) - return ODict([]) + return odict.ODict(utils.urldecode(self.content)) + return odict.ODict([]) + + def get_form_multipart(self): + if self.content and self.headers.in_any("content-type", HDR_FORM_MULTIPART, True): + return odict.ODict(utils.multipartdecode(self.headers, self.content)) + return odict.ODict([]) def set_form_urlencoded(self, odict): """ @@ -556,8 +532,8 @@ class HTTPRequest(HTTPMessage): """ _, _, _, _, query, _ = urlparse.urlparse(self.url) if query: - return ODict(utils.urldecode(query)) - return ODict([]) + return odict.ODict(utils.urldecode(query)) + return odict.ODict([]) def set_query(self, odict): """ @@ -588,8 +564,10 @@ class HTTPRequest(HTTPMessage): host = self.headers.get_first("host") if not host: host = self.host - host = host.encode("idna") - return host + if host: + return host.encode("idna") + else: + return None def pretty_url(self, hostheader): if self.form_out == "authority": # upstream proxy mode @@ -625,15 +603,22 @@ class HTTPRequest(HTTPMessage): self.scheme, self.host, self.port, self.path = parts def get_cookies(self): - cookie_headers = self.headers.get("cookie") - if not cookie_headers: - return None + """ - cookies = [] - for header in cookie_headers: - pairs = [pair.partition("=") for pair in header.split(';')] - cookies.extend((pair[0], (pair[2], {})) for pair in pairs) - return dict(cookies) + Returns a possibly empty netlib.odict.ODict object. + """ + ret = odict.ODict() + for i in self.headers["cookie"]: + ret.extend(http_cookies.parse_cookie_header(i)) + return ret + + def set_cookies(self, odict): + """ + Takes an netlib.odict.ODict object. Over-writes any existing Cookie + headers. + """ + v = http_cookies.format_cookie_header(odict) + self.headers["Cookie"] = [v] def replace(self, pattern, repl, *args, **kwargs): """ @@ -676,7 +661,7 @@ class HTTPResponse(HTTPMessage): def __init__(self, httpversion, code, msg, headers, content, timestamp_start=None, timestamp_end=None): - assert isinstance(headers, ODictCaseless) or headers is None + assert isinstance(headers, odict.ODictCaseless) or headers is None HTTPMessage.__init__( self, httpversion, @@ -771,7 +756,7 @@ class HTTPResponse(HTTPMessage): if self.content or self.content == "": headers["Content-Length"] = [str(len(self.content))] - return str(headers) + return headers.format() def _assemble_head(self, preserve_transfer_encoding=False): return '%s\r\n%s\r\n' % ( @@ -850,20 +835,39 @@ class HTTPResponse(HTTPMessage): self.headers["set-cookie"] = c def get_cookies(self): - cookie_headers = self.headers.get("set-cookie") - if not cookie_headers: - return None + """ + Get the contents of all Set-Cookie headers. + + Returns a possibly empty ODict, where keys are cookie name strings, + and values are [value, attr] lists. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary + attributes (e.g. HTTPOnly) are indicated by a Null value. + """ + ret = [] + for header in self.headers["set-cookie"]: + v = http_cookies.parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append([name, [value, attrs]]) + return odict.ODict(ret) + + def set_cookies(self, odict): + """ + Set the Set-Cookie headers on this response, over-writing existing + headers. - cookies = [] - for header in cookie_headers: - pairs = [pair.partition("=") for pair in header.split(';')] - cookie_name = pairs[0][0] # the key of the first key/value pairs - cookie_value = pairs[0][2] # the value of the first key/value pairs - cookie_parameters = { - key.strip().lower(): value.strip() for key, sep, value in pairs[1:] - } - cookies.append((cookie_name, (cookie_value, cookie_parameters))) - return dict(cookies) + Accepts an ODict of the same format as that returned by get_cookies. + """ + values = [] + for i in odict.lst: + values.append( + http_cookies.format_set_cookie_header( + i[0], + i[1][0], + i[1][1] + ) + ) + self.headers["Set-Cookie"] = values class HTTPFlow(Flow): @@ -1041,7 +1045,8 @@ class HTTPHandler(ProtocolHandler): try: req = HTTPRequest.from_stream( self.c.client_conn.rfile, - body_size_limit=self.c.config.body_size_limit + body_size_limit=self.c.config.body_size_limit, + wfile=self.c.client_conn.wfile ) except tcp.NetLibError: # don't throw an error for disconnects that happen @@ -1304,7 +1309,8 @@ class HTTPHandler(ProtocolHandler): ) if needs_server_change: - # force create new connection to the proxy server to reset state + # force create new connection to the proxy server to reset + # state self.live.change_server(self.c.server_conn.address, force=True) if ssl: send_connect_request( @@ -1314,8 +1320,9 @@ class HTTPHandler(ProtocolHandler): ) self.c.establish_ssl(server=True) else: - # If we're not in upstream mode, we just want to update the host and - # possibly establish TLS. This is a no op if the addresses match. + # If we're not in upstream mode, we just want to update the host + # and possibly establish TLS. This is a no op if the addresses + # match. self.live.change_server(address, ssl=ssl) flow.server_conn = self.c.server_conn @@ -1323,8 +1330,8 @@ class HTTPHandler(ProtocolHandler): def send_response_to_client(self, flow): if not flow.response.stream: # no streaming: - # we already received the full response from the server and can send - # it to the client straight away. + # we already received the full response from the server and can + # send it to the client straight away. self.c.client_conn.send(flow.response.assemble()) else: # streaming: @@ -1362,8 +1369,10 @@ class HTTPHandler(ProtocolHandler): flow.response.code) == -1) if close_connection: if flow.request.form_in == "authority" and flow.response.code == 200: - # Workaround for https://github.com/mitmproxy/mitmproxy/issues/313: - # Some proxies (e.g. Charles) send a CONNECT response with HTTP/1.0 and no Content-Length header + # Workaround for + # https://github.com/mitmproxy/mitmproxy/issues/313: Some + # proxies (e.g. Charles) send a CONNECT response with HTTP/1.0 + # and no Content-Length header pass else: return True @@ -1385,14 +1394,16 @@ class HTTPHandler(ProtocolHandler): self.expected_form_out = "relative" self.skip_authentication = True - # In practice, nobody issues a CONNECT request to send unencrypted HTTP requests afterwards. - # If we don't delegate to TCP mode, we should always negotiate a SSL connection. + # In practice, nobody issues a CONNECT request to send unencrypted + # HTTP requests afterwards. If we don't delegate to TCP mode, we + # should always negotiate a SSL connection. # - # FIXME: - # Turns out the previous statement isn't entirely true. Chrome on Windows CONNECTs to :80 - # if an explicit proxy is configured and a websocket connection should be established. - # We don't support websocket at the moment, so it fails anyway, but we should come up with - # a better solution to this if we start to support WebSockets. + # FIXME: Turns out the previous statement isn't entirely true. + # Chrome on Windows CONNECTs to :80 if an explicit proxy is + # configured and a websocket connection should be established. We + # don't support websocket at the moment, so it fails anyway, but we + # should come up with a better solution to this if we start to + # support WebSockets. should_establish_ssl = ( address.port in self.c.config.ssl_ports or @@ -1400,12 +1411,18 @@ class HTTPHandler(ProtocolHandler): ) if should_establish_ssl: - self.c.log("Received CONNECT request to SSL port. Upgrading to SSL...", "debug") + self.c.log( + "Received CONNECT request to SSL port. " + "Upgrading to SSL...", "debug" + ) self.c.establish_ssl(server=True, client=True) self.c.log("Upgrade to SSL completed.", "debug") if self.c.config.check_tcp(address): - self.c.log("Generic TCP mode for host: %s:%s" % address(), "info") + self.c.log( + "Generic TCP mode for host: %s:%s" % address(), + "info" + ) TCPHandler(self.c).handle_messages() return False @@ -1426,7 +1443,8 @@ class RequestReplayThread(threading.Thread): def __init__(self, config, flow, masterq, should_exit): """ - masterqueue can be a queue or None, if no scripthooks should be processed. + masterqueue can be a queue or None, if no scripthooks should be + processed. """ self.config, self.flow = config, flow if masterq: @@ -1452,12 +1470,17 @@ class RequestReplayThread(threading.Thread): if not self.flow.response: # In all modes, we directly connect to the server displayed if self.config.mode == "upstream": - server_address = self.config.mode.get_upstream_server(self.flow.client_conn)[2:] + server_address = self.config.mode.get_upstream_server( + self.flow.client_conn + )[2:] server = ServerConnection(server_address) server.connect() if r.scheme == "https": send_connect_request(server, r.host, r.port) - server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni) + server.establish_ssl( + self.config.clientcerts, + sni=self.flow.server_conn.sni + ) r.form_out = "relative" else: r.form_out = "absolute" @@ -1466,12 +1489,18 @@ class RequestReplayThread(threading.Thread): server = ServerConnection(server_address) server.connect() if r.scheme == "https": - server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni) + server.establish_ssl( + self.config.clientcerts, + sni=self.flow.server_conn.sni + ) r.form_out = "relative" server.send(r.assemble()) self.flow.server_conn = server - self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, - body_size_limit=self.config.body_size_limit) + self.flow.response = HTTPResponse.from_stream( + server.rfile, + r.method, + body_size_limit=self.config.body_size_limit + ) if self.channel: response_reply = self.channel.ask("response", self.flow) if response_reply is None or response_reply == KILL: @@ -1481,7 +1510,8 @@ class RequestReplayThread(threading.Thread): if self.channel: self.channel.ask("error", self.flow) except KillSignal: - # KillSignal should only be raised if there's a channel in the first place. + # KillSignal should only be raised if there's a channel in the + # first place. self.channel.tell("log", proxy.Log("Connection killed", "info")) finally: r.form_out = form_out_backup |