diff options
author | Maximilian Hils <git@maximilianhils.com> | 2015-08-06 12:13:23 +0200 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2015-08-11 20:32:12 +0200 |
commit | aac0ab23ebb0e4d88306b12efee1dd31338f7664 (patch) | |
tree | 686c95d9c36fc28e69f10e11f3034562a6565fdc | |
parent | c46e3f90bbc38080a41a278340aaad27d8881fd9 (diff) | |
download | mitmproxy-aac0ab23ebb0e4d88306b12efee1dd31338f7664.tar.gz mitmproxy-aac0ab23ebb0e4d88306b12efee1dd31338f7664.tar.bz2 mitmproxy-aac0ab23ebb0e4d88306b12efee1dd31338f7664.zip |
simplify layer code, add yield_from_callback decorator
-rw-r--r-- | libmproxy/protocol2/layer.py | 58 | ||||
-rw-r--r-- | libmproxy/protocol2/ssl.py | 51 |
2 files changed, 69 insertions, 40 deletions
diff --git a/libmproxy/protocol2/layer.py b/libmproxy/protocol2/layer.py index 30aed350..aaa51baf 100644 --- a/libmproxy/protocol2/layer.py +++ b/libmproxy/protocol2/layer.py @@ -32,6 +32,8 @@ Further goals: inline scripts shall have a chance to handle everything locally. """ from __future__ import (absolute_import, print_function, division) +import Queue +import threading from netlib import tcp from ..proxy import ProxyError2, Log from ..proxy.connection import ServerConnection @@ -131,7 +133,6 @@ class ServerConnectionMixin(object): self._server_address = tcp.Address.wrap(address) self.log("Set new server address: " + repr(self.server_address), "debug") - def _disconnect(self): """ Deletes (and closes) an existing server connection. @@ -149,3 +150,58 @@ class ServerConnectionMixin(object): self.server_conn.connect() except tcp.NetLibError as e: raise ProxyError2("Server connection to '%s' failed: %s" % (self.server_address, e), e) + + +def yield_from_callback(fun): + """ + Decorator which makes it possible to yield from callbacks in the original thread. + As a use case, take the pyOpenSSL handle_sni callback: If we receive a new SNI from the client, + we need to reconnect to the server with the new SNI. Reconnecting would normally be done using "yield Reconnect()", + but we're in a pyOpenSSL callback here, outside of the main program flow. With this decorator, it looks as follows: + + def handle_sni(self): + # ... + self.yield_from_callback(Reconnect()) + + @yield_from_callback + def establish_ssl_with_client(): + self.client_conn.convert_to_ssl(...) + + for message in self.establish_ssl_with_client(): # will yield Reconnect at some point + yield message + + + Limitations: + - You cannot yield True. + """ + yield_queue = Queue.Queue() + + def do_yield(self, msg): + yield_queue.put(msg) + yield_queue.get() + + def wrapper(self, *args, **kwargs): + self.yield_from_callback = do_yield + + def run(): + try: + fun(self, *args, **kwargs) + yield_queue.put(True) + except Exception as e: + yield_queue.put(e) + + threading.Thread(target=run, name="YieldFromCallbackThread").start() + while True: + e = yield_queue.get() + if e is True: + break + elif isinstance(e, Exception): + # TODO: Include func name? + raise ProxyError2("Error from callback: " + repr(e), e) + else: + yield e + yield_queue.put(None) + + self.yield_from_callback = None + + return wrapper diff --git a/libmproxy/protocol2/ssl.py b/libmproxy/protocol2/ssl.py index c21956b7..32798e72 100644 --- a/libmproxy/protocol2/ssl.py +++ b/libmproxy/protocol2/ssl.py @@ -1,20 +1,13 @@ from __future__ import (absolute_import, print_function, division) -import Queue -import threading import traceback from netlib import tcp from ..proxy import ProxyError2 -from .layer import Layer +from .layer import Layer, yield_from_callback from .messages import Connect, Reconnect, ChangeServer from .auto import AutoLayer -class ReconnectRequest(object): - def __init__(self): - self.done = threading.Event() - - class SslLayer(Layer): def __init__(self, ctx, client_ssl, server_ssl): super(SslLayer, self).__init__(ctx) @@ -59,7 +52,8 @@ class SslLayer(Layer): for m in self._establish_ssl_with_client_and_server(): yield m elif self.client_ssl: - self._establish_ssl_with_client() + for m in self._establish_ssl_with_client(): + yield m layer = AutoLayer(self) for message in layer(): @@ -96,32 +90,12 @@ class SslLayer(Layer): except ProxyError2 as e: server_err = e - # The part below is a bit ugly as we cannot yield from the handle_sni callback. - # The workaround is to do that in a separate thread and yield from the main thread. - - # client_ssl_queue may contain the following elements - # - True, if ssl was successfully established - # - An Exception thrown by self._establish_ssl_with_client() - # - A threading.Event, which singnifies a request for a reconnect from the sni callback - self.__client_ssl_queue = Queue.Queue() - - def establish_client_ssl(): - try: - self._establish_ssl_with_client() - self.__client_ssl_queue.put(True) - except Exception as client_err: - self.__client_ssl_queue.put(client_err) - - threading.Thread(target=establish_client_ssl, name="ClientSSLThread").start() - e = self.__client_ssl_queue.get() - if isinstance(e, ReconnectRequest): - yield Reconnect() - self._establish_ssl_with_server() - e.done.set() - e = self.__client_ssl_queue.get() - if e is not True: - raise ProxyError2("Error when establish client SSL: " + repr(e), e) - self.__client_ssl_queue = None + for message in self._establish_ssl_with_client(): + if message == Reconnect: + yield message + self._establish_ssl_with_server() + else: + raise RuntimeError("Unexpected Message: %s" % message) if server_err and not self._sni_from_handshake: raise server_err @@ -142,9 +116,7 @@ class SslLayer(Layer): if old_upstream_sni != self.sni_for_upstream_connection: # Perform reconnect if self.server_ssl: - reconnect = ReconnectRequest() - self.__client_ssl_queue.put(reconnect) - reconnect.done.wait() + self.yield_from_callback(Reconnect()) if self._sni_from_handshake: # Now, change client context to reflect possibly changed certificate: @@ -163,6 +135,7 @@ class SslLayer(Layer): except: # pragma: no cover self.log("Error in handle_sni:\r\n" + traceback.format_exc(), "error") + @yield_from_callback def _establish_ssl_with_client(self): self.log("Establish SSL with client", "debug") cert, key, chain_file = self.find_cert() @@ -227,4 +200,4 @@ class SslLayer(Layer): if self._sni_from_server_change: sans.add(self._sni_from_server_change) - return self.config.certstore.get_cert(host, list(sans))
\ No newline at end of file + return self.config.certstore.get_cert(host, list(sans)) |