aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2015-08-06 12:13:23 +0200
committerMaximilian Hils <git@maximilianhils.com>2015-08-11 20:32:12 +0200
commitaac0ab23ebb0e4d88306b12efee1dd31338f7664 (patch)
tree686c95d9c36fc28e69f10e11f3034562a6565fdc
parentc46e3f90bbc38080a41a278340aaad27d8881fd9 (diff)
downloadmitmproxy-aac0ab23ebb0e4d88306b12efee1dd31338f7664.tar.gz
mitmproxy-aac0ab23ebb0e4d88306b12efee1dd31338f7664.tar.bz2
mitmproxy-aac0ab23ebb0e4d88306b12efee1dd31338f7664.zip
simplify layer code, add yield_from_callback decorator
-rw-r--r--libmproxy/protocol2/layer.py58
-rw-r--r--libmproxy/protocol2/ssl.py51
2 files changed, 69 insertions, 40 deletions
diff --git a/libmproxy/protocol2/layer.py b/libmproxy/protocol2/layer.py
index 30aed350..aaa51baf 100644
--- a/libmproxy/protocol2/layer.py
+++ b/libmproxy/protocol2/layer.py
@@ -32,6 +32,8 @@ Further goals:
inline scripts shall have a chance to handle everything locally.
"""
from __future__ import (absolute_import, print_function, division)
+import Queue
+import threading
from netlib import tcp
from ..proxy import ProxyError2, Log
from ..proxy.connection import ServerConnection
@@ -131,7 +133,6 @@ class ServerConnectionMixin(object):
self._server_address = tcp.Address.wrap(address)
self.log("Set new server address: " + repr(self.server_address), "debug")
-
def _disconnect(self):
"""
Deletes (and closes) an existing server connection.
@@ -149,3 +150,58 @@ class ServerConnectionMixin(object):
self.server_conn.connect()
except tcp.NetLibError as e:
raise ProxyError2("Server connection to '%s' failed: %s" % (self.server_address, e), e)
+
+
+def yield_from_callback(fun):
+ """
+ Decorator which makes it possible to yield from callbacks in the original thread.
+ As a use case, take the pyOpenSSL handle_sni callback: If we receive a new SNI from the client,
+ we need to reconnect to the server with the new SNI. Reconnecting would normally be done using "yield Reconnect()",
+ but we're in a pyOpenSSL callback here, outside of the main program flow. With this decorator, it looks as follows:
+
+ def handle_sni(self):
+ # ...
+ self.yield_from_callback(Reconnect())
+
+ @yield_from_callback
+ def establish_ssl_with_client():
+ self.client_conn.convert_to_ssl(...)
+
+ for message in self.establish_ssl_with_client(): # will yield Reconnect at some point
+ yield message
+
+
+ Limitations:
+ - You cannot yield True.
+ """
+ yield_queue = Queue.Queue()
+
+ def do_yield(self, msg):
+ yield_queue.put(msg)
+ yield_queue.get()
+
+ def wrapper(self, *args, **kwargs):
+ self.yield_from_callback = do_yield
+
+ def run():
+ try:
+ fun(self, *args, **kwargs)
+ yield_queue.put(True)
+ except Exception as e:
+ yield_queue.put(e)
+
+ threading.Thread(target=run, name="YieldFromCallbackThread").start()
+ while True:
+ e = yield_queue.get()
+ if e is True:
+ break
+ elif isinstance(e, Exception):
+ # TODO: Include func name?
+ raise ProxyError2("Error from callback: " + repr(e), e)
+ else:
+ yield e
+ yield_queue.put(None)
+
+ self.yield_from_callback = None
+
+ return wrapper
diff --git a/libmproxy/protocol2/ssl.py b/libmproxy/protocol2/ssl.py
index c21956b7..32798e72 100644
--- a/libmproxy/protocol2/ssl.py
+++ b/libmproxy/protocol2/ssl.py
@@ -1,20 +1,13 @@
from __future__ import (absolute_import, print_function, division)
-import Queue
-import threading
import traceback
from netlib import tcp
from ..proxy import ProxyError2
-from .layer import Layer
+from .layer import Layer, yield_from_callback
from .messages import Connect, Reconnect, ChangeServer
from .auto import AutoLayer
-class ReconnectRequest(object):
- def __init__(self):
- self.done = threading.Event()
-
-
class SslLayer(Layer):
def __init__(self, ctx, client_ssl, server_ssl):
super(SslLayer, self).__init__(ctx)
@@ -59,7 +52,8 @@ class SslLayer(Layer):
for m in self._establish_ssl_with_client_and_server():
yield m
elif self.client_ssl:
- self._establish_ssl_with_client()
+ for m in self._establish_ssl_with_client():
+ yield m
layer = AutoLayer(self)
for message in layer():
@@ -96,32 +90,12 @@ class SslLayer(Layer):
except ProxyError2 as e:
server_err = e
- # The part below is a bit ugly as we cannot yield from the handle_sni callback.
- # The workaround is to do that in a separate thread and yield from the main thread.
-
- # client_ssl_queue may contain the following elements
- # - True, if ssl was successfully established
- # - An Exception thrown by self._establish_ssl_with_client()
- # - A threading.Event, which singnifies a request for a reconnect from the sni callback
- self.__client_ssl_queue = Queue.Queue()
-
- def establish_client_ssl():
- try:
- self._establish_ssl_with_client()
- self.__client_ssl_queue.put(True)
- except Exception as client_err:
- self.__client_ssl_queue.put(client_err)
-
- threading.Thread(target=establish_client_ssl, name="ClientSSLThread").start()
- e = self.__client_ssl_queue.get()
- if isinstance(e, ReconnectRequest):
- yield Reconnect()
- self._establish_ssl_with_server()
- e.done.set()
- e = self.__client_ssl_queue.get()
- if e is not True:
- raise ProxyError2("Error when establish client SSL: " + repr(e), e)
- self.__client_ssl_queue = None
+ for message in self._establish_ssl_with_client():
+ if message == Reconnect:
+ yield message
+ self._establish_ssl_with_server()
+ else:
+ raise RuntimeError("Unexpected Message: %s" % message)
if server_err and not self._sni_from_handshake:
raise server_err
@@ -142,9 +116,7 @@ class SslLayer(Layer):
if old_upstream_sni != self.sni_for_upstream_connection:
# Perform reconnect
if self.server_ssl:
- reconnect = ReconnectRequest()
- self.__client_ssl_queue.put(reconnect)
- reconnect.done.wait()
+ self.yield_from_callback(Reconnect())
if self._sni_from_handshake:
# Now, change client context to reflect possibly changed certificate:
@@ -163,6 +135,7 @@ class SslLayer(Layer):
except: # pragma: no cover
self.log("Error in handle_sni:\r\n" + traceback.format_exc(), "error")
+ @yield_from_callback
def _establish_ssl_with_client(self):
self.log("Establish SSL with client", "debug")
cert, key, chain_file = self.find_cert()
@@ -227,4 +200,4 @@ class SslLayer(Layer):
if self._sni_from_server_change:
sans.add(self._sni_from_server_change)
- return self.config.certstore.get_cert(host, list(sans)) \ No newline at end of file
+ return self.config.certstore.get_cert(host, list(sans))