diff options
Diffstat (limited to 'libmproxy/protocol2/layer.py')
-rw-r--r-- | libmproxy/protocol2/layer.py | 47 |
1 files changed, 20 insertions, 27 deletions
diff --git a/libmproxy/protocol2/layer.py b/libmproxy/protocol2/layer.py index c1648a62..67f3d549 100644 --- a/libmproxy/protocol2/layer.py +++ b/libmproxy/protocol2/layer.py @@ -44,8 +44,8 @@ class _LayerCodeCompletion(object): Dummy class that provides type hinting in PyCharm, which simplifies development a lot. """ - def __init__(self): - super(_LayerCodeCompletion, self).__init__() + def __init__(self, *args, **kwargs): + super(_LayerCodeCompletion, self).__init__(*args, **kwargs) if True: return self.config = None @@ -57,12 +57,12 @@ class _LayerCodeCompletion(object): class Layer(_LayerCodeCompletion): - def __init__(self, ctx): + def __init__(self, ctx, *args, **kwargs): """ Args: ctx: The (read-only) higher layer. """ - super(Layer, self).__init__() + super(Layer, self).__init__(*args, **kwargs) self.ctx = ctx def __call__(self): @@ -103,10 +103,9 @@ class ServerConnectionMixin(object): Mixin that provides a layer with the capabilities to manage a server connection. """ - def __init__(self): + def __init__(self, server_address=None): super(ServerConnectionMixin, self).__init__() - self._server_address = None - self.server_conn = None + self.server_conn = ServerConnection(server_address) def _handle_server_message(self, message): if message == Reconnect: @@ -116,44 +115,38 @@ class ServerConnectionMixin(object): elif message == Connect: self._connect() return True - elif message == SetServer and message.depth == 1: - if self.server_conn: - self._disconnect() - self.server_address = message.address - return True + elif message == SetServer: + if message.depth == 1: + if self.server_conn: + self._disconnect() + self.log("Set new server address: " + repr(message.address), "debug") + self.server_conn.address = message.address + return True + else: + message.depth -= 1 elif message == Kill: self._disconnect() return False - @property - def server_address(self): - return self._server_address - - @server_address.setter - def server_address(self, address): - self._server_address = tcp.Address.wrap(address) - self.log("Set new server address: " + repr(self.server_address), "debug") - def _disconnect(self): """ Deletes (and closes) an existing server connection. """ - self.log("serverdisconnect", "debug", [repr(self.server_address)]) + self.log("serverdisconnect", "debug", [repr(self.server_conn.address)]) self.server_conn.finish() self.server_conn.close() # self.channel.tell("serverdisconnect", self) - self.server_conn = None + self.server_conn = ServerConnection(None) def _connect(self): - if not self.server_address: + if not self.server_conn.address: raise ProtocolException("Cannot connect to server, no server address given.") - self.log("serverconnect", "debug", [repr(self.server_address)]) - self.server_conn = ServerConnection(self.server_address) + self.log("serverconnect", "debug", [repr(self.server_conn.address)]) try: self.server_conn.connect() except tcp.NetLibError as e: - raise ProtocolException("Server connection to '%s' failed: %s" % (self.server_address, e), e) + raise ProtocolException("Server connection to '%s' failed: %s" % (self.server_conn.address, e), e) def yield_from_callback(fun): |