diff options
Diffstat (limited to 'libmproxy/protocol/__init__.py')
-rw-r--r-- | libmproxy/protocol/__init__.py | 46 |
1 files changed, 40 insertions, 6 deletions
diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py index 78930e05..123c31e0 100644 --- a/libmproxy/protocol/__init__.py +++ b/libmproxy/protocol/__init__.py @@ -1,5 +1,6 @@ -KILL = 0 # const for killed requests +from ..proxy import ServerConnection, AddressPriority +KILL = 0 # const for killed requests class ConnectionTypeChange(Exception): """ @@ -12,7 +13,7 @@ class ConnectionTypeChange(Exception): class ProtocolHandler(object): def __init__(self, c): self.c = c - """@type : libmproxy.proxy.ConnectionHandler""" + """@type: libmproxy.proxy.ConnectionHandler""" def handle_messages(self): """ @@ -36,13 +37,46 @@ class TemporaryServerChangeMixin(object): """ def change_server(self, address, ssl): - self._backup_server = True - raise NotImplementedError("You must not change host port port.") + if address == self.c.server_conn.address(): + return + priority = AddressPriority.MANUALLY_CHANGED + + if self.c.server_conn.priority > priority: + self.log("Attempt to change server address, " + "but priority is too low (is: %s, got: %s)" % (self.server_conn.priority, priority)) + return + + self.log("Temporarily change server connection: %s:%s -> %s:%s" % ( + self.c.server_conn.address.host, + self.c.server_conn.address.port, + address.host, + address.port + )) + + if not hasattr(self, "_backup_server_conn"): + self._backup_server_conn = self.c.server_conn + self.c.server_conn = None + else: # This is at least the second temporary change. We can kill the current connection. + self.c.del_server_connection() + + self.c.set_server_address(address, priority) + if ssl: + self.establish_ssl(server=True) def restore_server(self): - if not hasattr(self,"_backup_server"): + if not hasattr(self, "_backup_server_conn"): return - raise NotImplementedError + + self.log("Restore original server connection: %s:%s -> %s:%s" % ( + self.c.server_conn.address.host, + self.c.server_conn.address.port, + self._backup_server_conn.host, + self._backup_server_conn.port + )) + + self.c.del_server_connection() + self.c.server_conn = self._backup_server_conn + del self._backup_server_conn from . import http, tcp |