diff options
Diffstat (limited to 'libmproxy/proxy.py')
-rw-r--r-- | libmproxy/proxy.py | 20 |
1 files changed, 17 insertions, 3 deletions
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 3098aff4..609ffb62 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -133,18 +133,25 @@ class ProxyHandler(tcp.BaseHandler): self.server_conn = None tcp.BaseHandler.__init__(self, connection, client_address, server) - def get_server_connection(self, cc, scheme, host, port, sni): + def get_server_connection(self, cc, scheme, host, port, sni, request=None): """ When SNI is in play, this means we have an SSL-encrypted connection, which means that the entire handler is dedicated to a single server connection - no multiplexing. If this assumption ever breaks, we'll have to do something different with the SNI host variable on the handler object. + + `conn_info` holds the initial connection's parameters, as the + hook might change them. Also, the hook might require an initial + request to figure out connection settings; in this case it can + set require_request, which will cause the connection to be + re-opened after the client's request arrives. """ sc = self.server_conn if not sni: sni = host - if sc and (scheme, host, port, sni) != (sc.scheme, sc.host, sc.port, sc.sni): + conn_info = (scheme, host, port, sni) + if sc and (conn_info != sc.conn_info or (request and sc.require_request)): sc.terminate() self.server_conn = None self.log( @@ -159,6 +166,13 @@ class ProxyHandler(tcp.BaseHandler): if not self.server_conn: try: self.server_conn = ServerConnection(self.config, scheme, host, port, sni) + + # Additional attributes, used if the server_connect hook + # needs to change parameters + self.server_conn.request = request + self.server_conn.require_request = False + + self.server_conn.conn_info = conn_info self.channel.ask(self.server_conn) self.server_conn.connect() except tcp.NetLibError, v: @@ -223,7 +237,7 @@ class ProxyHandler(tcp.BaseHandler): # the case, we want to reconnect without sending an error # to the client. while 1: - sc = self.get_server_connection(cc, scheme, host, port, self.sni) + sc = self.get_server_connection(cc, scheme, host, port, self.sni, request=request) sc.send(request) if sc.requestcount == 1: # add timestamps only for first request (others are not directly affected) request.tcp_setup_timestamp = sc.tcp_setup_timestamp |