diff options
Diffstat (limited to 'libmproxy/proxy.py')
-rw-r--r-- | libmproxy/proxy.py | 57 |
1 files changed, 39 insertions, 18 deletions
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 7c229064..c9ceb8de 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -80,8 +80,7 @@ class ServerConnection(tcp.TCPClient): def terminate(self): try: - if not self.wfile.closed: - self.wfile.flush() + self.wfile.flush() self.connection.close() except IOError: pass @@ -110,6 +109,27 @@ class RequestReplayThread(threading.Thread): self.channel.ask(err) +class HandleSNI: + def __init__(self, handler, client_conn, host, port, cert, key): + self.handler, self.client_conn, self.host, self.port = handler, client_conn, host, port + self.cert, self.key = cert, key + + def __call__(self, connection): + try: + sn = connection.get_servername() + if sn: + self.handler.get_server_connection(self.client_conn, "https", self.host, self.port, sn) + new_context = SSL.Context(SSL.TLSv1_METHOD) + new_context.use_privatekey_file(self.key) + new_context.use_certificate_file(self.cert) + connection.set_context(new_context) + self.handler.sni = sn.decode("utf8").encode("idna") + # An unhandled exception in this method will core dump PyOpenSSL, so + # make dang sure it doesn't happen. + except Exception, e: + pass + + class ProxyHandler(tcp.BaseHandler): def __init__(self, config, connection, client_address, server, channel, server_version): self.channel, self.server_version = channel, server_version @@ -266,18 +286,15 @@ class ProxyHandler(tcp.BaseHandler): l = Log(msg) self.channel.tell(l) - def find_cert(self, host, port, sni): + def find_cert(self, cc, host, port, sni): if self.config.certfile: return self.config.certfile else: sans = [] if not self.config.no_upstream_cert: - try: - cert = certutils.get_remote_cert(host, port, sni) - except tcp.NetLibError, v: - raise ProxyError(502, "Unable to get remote cert: %s"%str(v)) - sans = cert.altnames - host = cert.cn.decode("utf8").encode("idna") + conn = self.get_server_connection(cc, "https", host, port, sni) + sans = conn.cert.altnames + host = conn.cert.cn.decode("utf8").encode("idna") ret = self.config.certstore.get_cert(host, sans, self.config.cacert) if not ret: raise ProxyError(502, "mitmproxy: Unable to generate dummy cert.") @@ -292,11 +309,6 @@ class ProxyHandler(tcp.BaseHandler): line = fp.readline() return line - def handle_sni(self, conn): - sn = conn.get_servername() - if sn: - self.sni = sn.decode("utf8").encode("idna") - def read_request_transparent(self, client_conn): orig = self.config.transparent_proxy["resolver"].original_addr(self.connection) if not orig: @@ -304,9 +316,13 @@ class ProxyHandler(tcp.BaseHandler): host, port = orig if not self.ssl_established and (port in self.config.transparent_proxy["sslports"]): scheme = "https" - certfile = self.find_cert(host, port, None) + dummycert = self.find_cert(client_conn, host, port, host) try: - self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert) + sni = HandleSNI( + self, client_conn, host, port, + dummycert, self.config.certfile or self.config.cacert + ) + self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni) except tcp.NetLibError, v: raise ProxyError(400, str(v)) else: @@ -346,9 +362,14 @@ class ProxyHandler(tcp.BaseHandler): '\r\n' ) self.wfile.flush() - certfile = self.find_cert(host, port, None) + certfile = self.find_cert(client_conn, host, port, host) + + sni = HandleSNI( + self, client_conn, host, port, + dummycert, self.config.certfile or self.config.cacert + ) try: - self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert) + self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert, handle_sni=sni) except tcp.NetLibError, v: raise ProxyError(400, str(v)) self.proxy_connect_state = (host, port, httpversion) |