diff options
Diffstat (limited to 'libmproxy/proxy.py')
-rw-r--r-- | libmproxy/proxy.py | 153 |
1 files changed, 62 insertions, 91 deletions
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 328f195e..5ac40e92 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -2,10 +2,12 @@ import sys, os, string, socket, time import shutil, tempfile, threading import SocketServer from OpenSSL import SSL -from netlib import odict, tcp, http, wsgi, certutils, http_status, http_auth +from netlib import odict, tcp, http, certutils, http_status, http_auth import utils, flow, version, platform, controller, protocol +TRANSPARENT_SSL_PORTS = [443, 8443] + KILL = 0 @@ -17,11 +19,6 @@ class ProxyError(Exception): return "ProxyError(%s, %s)"%(self.code, self.msg) -class Log: - def __init__(self, msg): - self.msg = msg - - class ProxyConfig: def __init__(self, certfile = None, cacert = None, clientcerts = None, no_upstream_cert=False, body_size_limit = None, reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None): self.certfile = certfile @@ -97,34 +94,10 @@ class RequestReplayThread(threading.Thread): self.flow.request, httpversion, code, msg, headers, content, server.cert, server.rfile.first_byte_timestamp ) - self.channel.ask(response) + self.channel.ask("response", response) except (ProxyError, http.HttpError, tcp.NetLibError), v: err = flow.Error(self.flow.request, str(v)) - self.channel.ask(err) - - -class HandleSNI: - def __init__(self, handler, cert, key): - self.handler = handler - self.cert, self.key = cert, key - - def __call__(self, connection): - try: - sn = connection.get_servername() - if sn: - self.handler.sni = sn.decode("utf8").encode("idna") - self.handler.establish_server_connection() - self.handler.handle_ssl() - new_context = SSL.Context(SSL.TLSv1_METHOD) - new_context.use_privatekey_file(self.key) - new_context.use_certificate(self.cert.x509) - connection.set_context(new_context) - # FIXME: How does that work? - # An unhandled exception in this method will core dump PyOpenSSL, so - # make dang sure it doesn't happen. - except Exception, e: # pragma: no cover - pass - + self.channel.ask("error", err) class ConnectionHandler: def __init__(self, config, client_connection, client_address, server, channel, server_version): @@ -145,46 +118,44 @@ class ConnectionHandler: def del_server_connection(self): if self.server_conn: self.server_conn.terminate() + self.channel.tell("serverdisconnect", self) self.server_conn = None + self.sni = None def handle(self): - cc = flow.ClientConnect(self.client_address) - self.log(cc, "connect") - self.channel.ask(cc) + self.log("connect") + self.channel.ask("clientconnect", self) # Can we already identify the target server and connect to it? if self.config.forward_proxy: - self.server_address = self.config.forward_proxy + self.server_address = self.config.forward_proxy[1:] else: if self.config.reverse_proxy: - self.server_address = self.config.reverse_proxy + self.server_address = self.config.reverse_proxy[1:] elif self.config.transparent_proxy: self.server_address = self.config.transparent_proxy["resolver"].original_addr(self.connection) if not self.server_address: raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") - self.log(cc, "transparent to %s:%s"%self.server_address) + self.log("transparent to %s:%s"%self.server_address) if self.server_address: self.establish_server_connection() self.handle_ssl() - self.determine_conntype(self.mode) + self.determine_conntype(self.mode, *self.server_address) - while not cc.close: - protocol.handle_messages(self.conntype, self) + while not self.close: + try: + protocol.handle_messages(self.conntype, self) + except protocol.ConnectionTypeChange: + continue - cc.close = True self.del_server_connection() - cd = flow.ClientDisconnect(cc) - self.log( - cc, "disconnect", - [ - "handled %s requests"%cc.requestcount] - ) - self.channel.tell(cd) + self.log("disconnect") + self.channel.tell("clientdisconnect", self) - def determine_conntype(self, mode): + def determine_conntype(self, mode, host, port): #TODO: Add ruleset to select correct protocol depending on mode/target port etc. self.conntype = "http" @@ -195,14 +166,15 @@ class ConnectionHandler: """ self.del_server_connection() self.server_conn = ServerConnection(self.config, *self.server_address, self.sni) + self.channel.tell("serverconnect", self) def handle_ssl(self): if self.config.transparent_proxy: client_ssl, server_ssl = (self.server_address[1] in self.config.transparent_proxy["sslports"]) elif self.config.reverse_proxy: - client_ssl, server_ssl = self.config.reverse_proxy[0] == "https" - # FIXME: Make protocol generic (as with transparent proxies) - # FIXME: Add SSL-terminating capatbility (SSL -> mitmproxy -> plain and vice versa) + client_ssl, server_ssl = (self.config.reverse_proxy[0] == "https") + # TODO: Make protocol generic (as with transparent proxies) + # TODO: Add SSL-terminating capatbility (SSL -> mitmproxy -> plain and vice versa) else: client_ssl, server_ssl = True # In regular mode, this function will only be called on HTTP CONNECT @@ -211,37 +183,58 @@ class ConnectionHandler: if server_ssl and not self.server_conn.ssl_established: self.server_conn.establish_ssl() if client_ssl and not self.client_conn.ssl_established: - dummycert = self.find_cert(self.client_conn, *self.server_address) - sni = HandleSNI( - self, dummycert, self.config.certfile or self.config.cacert - ) + dummycert = self.find_cert() + self.client_conn.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=self.handle_sni) def log(self, msg, subs=()): msg = [ - "%s:%s: "%self.client_address + msg + "%s:%s: "%(self.client_address, msg) ] for i in subs: msg.append(" -> "+i) msg = "\n".join(msg) - l = Log(msg) - self.channel.tell(l) + self.channel.tell("log", msg) - def find_cert(self, cc, host, port, sni=None): + def find_cert(self): if self.config.certfile: with open(self.config.certfile, "rb") as f: return certutils.SSLCert.from_pem(f.read()) else: + host = self.server_address[0] sans = [] - if not self.config.no_upstream_cert: - conn = self.get_server_connection(cc, "https", host, port, sni) - sans = conn.cert.altnames - if conn.cert.cn: - host = conn.cert.cn.decode("utf8").encode("idna") + if not self.config.no_upstream_cert or not self.server_conn.ssl_established: + upstream_cert = self.server_conn.cert + if upstream_cert.cn: + host = upstream_cert.cn.decode("utf8").encode("idna") + sans = upstream_cert.altnames + ret = self.config.certstore.get_cert(host, sans, self.config.cacert) if not ret: raise ProxyError(502, "Unable to generate dummy cert.") return ret + def handle_sni(self, connection): + """ + This callback gets called during the SSL handshake with the client. + The client has just sent the Sever Name Indication (SNI). We now connect upstream to + figure out which certificate needs to be served. + """ + try: + sn = connection.get_servername() + if sn and sn != self.sni: + self.sni = sn.decode("utf8").encode("idna") + self.establish_server_connection() # reconnect to upstream server with SNI + self.handle_ssl() # establish SSL with upstream + # Now, change client context to reflect changed certificate: + new_context = SSL.Context(SSL.TLSv1_METHOD) + new_context.use_privatekey_file(self.config.certfile or self.config.cacert) + dummycert = self.find_cert() + new_context.use_certificate(dummycert.x509) + connection.set_context(new_context) + # An unhandled exception in this method will core dump PyOpenSSL, so + # make dang sure it doesn't happen. + except Exception, e: # pragma: no cover + pass class ProxyServerError(Exception): pass @@ -260,7 +253,6 @@ class ProxyServer(tcp.TCPServer): except socket.error, v: raise ProxyServerError('Error starting proxy server: ' + v.strerror) self.channel = None - self.apps = AppRegistry() def start_slave(self, klass, channel): slave = klass(channel, self) @@ -275,28 +267,6 @@ class ProxyServer(tcp.TCPServer): h.finish() -class AppRegistry: - def __init__(self): - self.apps = {} - - def add(self, app, domain, port): - """ - Add a WSGI app to the registry, to be served for requests to the - specified domain, on the specified port. - """ - self.apps[(domain, port)] = wsgi.WSGIAdaptor(app, domain, port, version.NAMEVERSION) - - def get(self, request): - """ - Returns an WSGIAdaptor instance if request matches an app, or None. - """ - if (request.host, request.port) in self.apps: - return self.apps[(request.host, request.port)] - if "host" in request.headers: - host = request.headers["host"][0] - return self.apps.get((host, request.port), None) - - class DummyServer: bound = False def __init__(self, config): @@ -324,7 +294,6 @@ def certificate_option_group(parser): ) -TRANSPARENT_SSL_PORTS = [443, 8443] def process_proxy_options(parser, options): if options.cert: @@ -367,7 +336,9 @@ def process_proxy_options(parser, options): if options.clientcerts: options.clientcerts = os.path.expanduser(options.clientcerts) if not os.path.exists(options.clientcerts) or not os.path.isdir(options.clientcerts): - return parser.error("Client certificate directory does not exist or is not a directory: %s"%options.clientcerts) + return parser.error( + "Client certificate directory does not exist or is not a directory: %s"%options.clientcerts + ) if (options.auth_nonanonymous or options.auth_singleuser or options.auth_htpasswd): if options.auth_singleuser: |