diff options
-rw-r--r-- | libmproxy/protocol.py | 100 | ||||
-rw-r--r-- | libmproxy/proxy.py | 10 |
2 files changed, 69 insertions, 41 deletions
diff --git a/libmproxy/protocol.py b/libmproxy/protocol.py index 402caef5..c6223164 100644 --- a/libmproxy/protocol.py +++ b/libmproxy/protocol.py @@ -1,12 +1,12 @@ from libmproxy import flow -from libmproxy.utils import timestamp -from netlib import http, utils, tcp +import libmproxy.utils +from netlib import http, tcp +import netlib.utils from netlib.odict import ODictCaseless KILL = 0 # FIXME: Remove duplication with proxy module LEGACY = True -#FIXME: Combine with ProxyError? class ProtocolError(Exception): def __init__(self, code, msg, headers=None): self.code, self.msg, self.headers = code, msg, headers @@ -30,10 +30,6 @@ def handle_messages(conntype, connection_handler): _handle("messages", conntype, connection_handler) -def handle_error(conntype, connection_handler, e): - _handle("error", conntype, connection_handler, e) - - class ConnectionTypeChange(Exception): pass @@ -56,7 +52,20 @@ class HTTPFlow(Flow): self.request, self.response = request, response -class HTTPResponse(object): +class HTTPMessage(object): + def _assemble_headers(self): + headers = self.headers.copy() + libmproxy.utils.del_all(headers, + ["proxy-connection", + "transfer-encoding"]) + if self.content: + headers["Content-Length"] = [str(len(self.content))] + elif 'Transfer-Encoding' in self.headers: # content-length for e.g. chuncked transfer-encoding with no content + headers["Content-Length"] = ["0"] + + return str(headers) + +class HTTPResponse(HTTPMessage): def __init__(self, http_version, code, msg, headers, content, timestamp_start, timestamp_end): self.http_version = http_version self.code = code @@ -75,7 +84,7 @@ class HTTPResponse(object): def _assemble(self): response_line = 'HTTP/%s.%s %s %s'%(self.http_version[0], self.http_version[1], self.code, self.msg) - return '%s\r\n%s\r\n%s' % (response_line, str(self.headers), self.content) + return '%s\r\n%s\r\n%s' % (response_line, self._assemble_headers(), self.content) @classmethod def from_stream(cls, rfile, request_method, include_content=True, body_size_limit=None): @@ -85,15 +94,16 @@ class HTTPResponse(object): if not include_content: raise NotImplementedError - timestamp_start = timestamp() + timestamp_start = libmproxy.utils.timestamp() http_version, code, msg, headers, content = http.read_response( rfile, request_method, body_size_limit) - timestamp_end = timestamp() + timestamp_end = libmproxy.utils.timestamp() return HTTPResponse(http_version, code, msg, headers, content, timestamp_start, timestamp_end) -class HTTPRequest(object): + +class HTTPRequest(HTTPMessage): def __init__(self, form_in, method, scheme, host, port, path, http_version, headers, content, timestamp_start, timestamp_end, form_out=None, ip=None): self.form_in = form_in @@ -109,7 +119,7 @@ class HTTPRequest(object): self.timestamp_end = timestamp_end self.form_out = form_out or self.form_in - self.ip = ip # resolved ip address + self.ip = ip # resolved ip address assert isinstance(headers, ODictCaseless) #FIXME: Remove, legacy @@ -123,9 +133,14 @@ class HTTPRequest(object): elif self.form_out == "authority": request_line = '%s %s:%s HTTP/%s.%s' % (self.method, self.host, self.port, self.http_version[0], self.http_version[1]) + elif self.form_out == "absolute": + request_line = '%s %s://%s:%s%s HTTP/%s.%s' % \ + (self.method, self.scheme, self.host, self.port, self.path, + self.http_version[0], self.http_version[1]) else: - raise NotImplementedError - return '%s\r\n%s\r\n%s' % (request_line, str(self.headers), self.content) + raise http.HttpError(400, "Invalid request form") + + return '%s\r\n%s\r\n%s' % (request_line, self._assemble_headers(), self.content) @classmethod def from_stream(cls, rfile, include_content=True, body_size_limit=None): @@ -135,7 +150,7 @@ class HTTPRequest(object): http_version, host, port, scheme, method, path, headers, content, timestamp_start, timestamp_end \ = None, None, None, None, None, None, None, None, None, None - timestamp_start = timestamp() + timestamp_start = libmproxy.utils.timestamp() request_line = HTTPHandler.get_line(rfile) request_line_parts = http.parse_init(request_line) @@ -147,7 +162,7 @@ class HTTPRequest(object): form_in = "asterisk" elif path.startswith("/"): form_in = "origin" - if not utils.isascii(path): + if not netlib.utils.isascii(path): raise ProtocolError(400, "Bad HTTP request line: %s"%repr(request_line)) elif method.upper() == 'CONNECT': form_in = "authority" @@ -168,7 +183,7 @@ class HTTPRequest(object): if include_content: content = http.read_http_body(rfile, headers, body_size_limit, True) - timestamp_end = timestamp() + timestamp_end = libmproxy.utils.timestamp() return HTTPRequest(form_in, method, scheme, host, port, path, http_version, headers, content, timestamp_start, timestamp_end) @@ -182,48 +197,60 @@ class HTTPHandler(ProtocolHandler): self.c.close = True def handle_error(self, e): - raise e # FIXME: Proper error handling + raise e # FIXME: Proper error handling def handle_request(self): try: - flow = HTTPFlow(self.c.client_conn, self.c.server_conn, timestamp(), None, None, None) + flow = HTTPFlow(self.c.client_conn, self.c.server_conn, libmproxy.utils.timestamp(), None, None, None) flow.request = self.read_request() - request_reply = self.c.channel.ask("request" if LEGACY else "httprequest", flow.request) + request_reply = self.c.channel.ask("request" if LEGACY else "httprequest", flow.request if LEGACY else flow) if request_reply is None or request_reply == KILL: - return False + return False + if isinstance(request_reply, HTTPResponse): flow.response = request_reply else: - flow.request = request_reply raw = flow.request._assemble() self.c.server_conn.wfile.write(raw) self.c.server_conn.wfile.flush() flow.response = self.read_response(flow) - response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse", flow.response) + + response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse", + flow.response if LEGACY else flow) if response_reply is None or response_reply == KILL: return False - else: - raw = flow.response._assemble() - self.c.client_conn.wfile.write(raw) - self.c.client_conn.wfile.flush() + + raw = flow.response._assemble() + self.c.client_conn.wfile.write(raw) + self.c.client_conn.wfile.flush() + flow.timestamp_end = libmproxy.utils.timestamp() if (http.connection_close(flow.request.http_version, flow.request.headers) or http.connection_close(flow.response.http_version, flow.response.headers)): return False - flow.timestamp_end = timestamp() + if flow.request.form_in == "authority": + self.ssl_upgrade() return flow - except tcp.NetLibDisconnect, e: + except ProtocolError, http.HttpError: + raise NotImplementedError + # FIXME: Implement error handling return False + def ssl_upgrade(self): + self.c.mode = "transparent" + self.c.determine_conntype() + self.c.establish_ssl(server=True, client=True) + raise ConnectionTypeChange + def read_request(self): request = HTTPRequest.from_stream(self.c.client_conn.rfile, body_size_limit=self.c.config.body_size_limit) if self.c.mode == "regular": self.authenticate(request) if request.form_in == "authority" and self.c.client_conn.ssl_established: - raise ProtocolError(502, "Must not CONNECT on SSL connection") + raise ProtocolError(502, "Must not CONNECT on already encrypted connection") # If we have a CONNECT request, we might need to intercept if request.form_in == "authority": @@ -236,11 +263,7 @@ class HTTPHandler(ProtocolHandler): '\r\n' ) self.c.client_conn.wfile.flush() - - self.c.establish_ssl(server=True, client=True) - self.c.mode = "transparent" - self.c.determine_conntype() - raise ConnectionTypeChange + self.ssl_upgrade() if self.c.mode == "regular": if request.form_in == "authority": @@ -258,7 +281,8 @@ class HTTPHandler(ProtocolHandler): return request def read_response(self, flow): - return HTTPResponse.from_stream(self.c.server_conn.rfile, flow.request.method, body_size_limit=self.c.config.body_size_limit) + return HTTPResponse.from_stream(self.c.server_conn.rfile, flow.request.method, + body_size_limit=self.c.config.body_size_limit) def authenticate(self, request): if self.c.config.authenticator: @@ -278,7 +302,7 @@ class HTTPHandler(ProtocolHandler): Get a line, possibly preceded by a blank. """ line = fp.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message + if line == "\r\n" or line == "\n": # Possible leftover from previous message line = fp.readline() if line == "": raise tcp.NetLibDisconnect diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 4ce14491..6a2da7d2 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -184,9 +184,8 @@ class ConnectionHandler: continue self.del_server_connection() - except (ProxyError, protocol.ProtocolError), e: + except ProxyError, e: self.log(str(e)) - protocol.handle_error(self.conntype, self, e) # FIXME: We need to persist errors self.log("disconnect") @@ -223,8 +222,13 @@ class ConnectionHandler: self.channel.tell("serverconnect", self) def establish_ssl(self, client, server): + """ + Establishes SSL on the existing connection(s) to the server or the client, + as specified by the parameters. If the target server is on the pass-through list, + the conntype attribute will be changed and no the SSL connection won't be wrapped. + A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening + """ # TODO: Implement SSL pass-through handling and change conntype - if self.server_conn.host == "ycombinator.com": self.conntype = "tcp" |