diff options
Diffstat (limited to 'libmproxy/proxy.py')
-rw-r--r-- | libmproxy/proxy.py | 848 |
1 files changed, 393 insertions, 455 deletions
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 0d53aef8..b6480822 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -1,13 +1,26 @@ -import os, socket, time -import threading +import os, socket, time, threading, copy from OpenSSL import SSL -from netlib import tcp, http, certutils, http_status, http_auth -import utils, flow, version, platform, controller - +from netlib import tcp, http, certutils, http_auth +import utils, version, platform, controller, stateobject TRANSPARENT_SSL_PORTS = [443, 8443] -KILL = 0 + +class AddressPriority(object): + """ + Enum that signifies the priority of the given address when choosing the destination host. + Higher is better (None < i) + """ + FORCE = 5 + """forward mode""" + MANUALLY_CHANGED = 4 + """user changed the target address in the ui""" + FROM_SETTINGS = 3 + """reverse proxy mode""" + FROM_CONNECTION = 2 + """derived from transparent resolver""" + FROM_PROTOCOL = 1 + """derived from protocol (e.g. absolute-form http requests)""" class ProxyError(Exception): @@ -15,7 +28,7 @@ class ProxyError(Exception): self.code, self.msg, self.headers = code, msg, headers def __str__(self): - return "ProxyError(%s, %s)"%(self.code, self.msg) + return "ProxyError(%s, %s)" % (self.code, self.msg) class Log: @@ -24,7 +37,8 @@ class Log: class ProxyConfig: - def __init__(self, certfile = None, cacert = None, clientcerts = None, no_upstream_cert=False, body_size_limit = None, reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None): + def __init__(self, certfile=None, cacert=None, clientcerts=None, no_upstream_cert=False, body_size_limit=None, + reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None): self.certfile = certfile self.cacert = cacert self.clientcerts = clientcerts @@ -37,49 +51,146 @@ class ProxyConfig: self.certstore = certutils.CertStore() -class ServerConnection(tcp.TCPClient): - def __init__(self, config, scheme, host, port, sni): - tcp.TCPClient.__init__(self, host, port) - self.config = config - self.scheme, self.sni = scheme, sni - self.requestcount = 0 - self.tcp_setup_timestamp = None - self.ssl_setup_timestamp = None +class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): + def __init__(self, client_connection, address, server): + if client_connection: # Eventually, this object is restored from state. We don't have a connection then. + tcp.BaseHandler.__init__(self, client_connection, address, server) + else: + self.connection = None + self.server = None + self.wfile = None + self.rfile = None + self.address = None + self.clientcert = None + + self.timestamp_start = utils.timestamp() + self.timestamp_end = None + self.timestamp_ssl_setup = None + + _stateobject_attributes = dict( + timestamp_start=float, + timestamp_end=float, + timestamp_ssl_setup=float + ) + + def _get_state(self): + d = super(ClientConnection, self)._get_state() + d.update( + address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, + clientcert=self.cert.to_pem() if self.clientcert else None + ) + return d + + def _load_state(self, state): + super(ClientConnection, self)._load_state(state) + self.address = tcp.Address(**state["address"]) if state["address"] else None + self.clientcert = certutils.SSLCert.from_pem(state["clientcert"]) if state["clientcert"] else None + + def copy(self): + return copy.copy(self) + + def send(self, message): + self.wfile.write(message) + self.wfile.flush() + + @classmethod + def _from_state(cls, state): + f = cls(None, tuple(), None) + f._load_state(state) + return f + + def convert_to_ssl(self, *args, **kwargs): + tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs) + self.timestamp_ssl_setup = utils.timestamp() + + def finish(self): + tcp.BaseHandler.finish(self) + self.timestamp_end = utils.timestamp() + + +class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): + def __init__(self, address, priority): + tcp.TCPClient.__init__(self, address) + self.priority = priority + + self.peername = None + self.timestamp_start = None + self.timestamp_end = None + self.timestamp_tcp_setup = None + self.timestamp_ssl_setup = None + + _stateobject_attributes = dict( + peername=tuple, + timestamp_start=float, + timestamp_end=float, + timestamp_tcp_setup=float, + timestamp_ssl_setup=float, + address=tcp.Address, + source_address=tcp.Address, + cert=certutils.SSLCert, + ssl_established=bool, + sni=str + ) + + def _get_state(self): + d = super(ServerConnection, self)._get_state() + d.update( + address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, + source_address= {"address": self.source_address(), + "use_ipv6": self.source_address.use_ipv6} if self.source_address else None, + cert=self.cert.to_pem() if self.cert else None + ) + return d + + def _load_state(self, state): + super(ServerConnection, self)._load_state(state) + + self.address = tcp.Address(**state["address"]) if state["address"] else None + self.source_address = tcp.Address(**state["source_address"]) if state["source_address"] else None + self.cert = certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None + + @classmethod + def _from_state(cls, state): + f = cls(tuple(), None) + f._load_state(state) + return f + + def copy(self): + return copy.copy(self) def connect(self): + self.timestamp_start = utils.timestamp() tcp.TCPClient.connect(self) - self.tcp_setup_timestamp = time.time() - if self.scheme == "https": - clientcert = None - if self.config.clientcerts: - path = os.path.join(self.config.clientcerts, self.host.encode("idna")) + ".pem" - if os.path.exists(path): - clientcert = path - try: - self.convert_to_ssl(cert=clientcert, sni=self.sni) - self.ssl_setup_timestamp = time.time() - except tcp.NetLibError, v: - raise ProxyError(400, str(v)) - - def send(self, request): - self.requestcount += 1 - d = request._assemble() - if not d: - raise ProxyError(502, "Cannot transmit an incomplete request.") - self.wfile.write(d) + self.peername = self.connection.getpeername() + self.timestamp_tcp_setup = utils.timestamp() + + def send(self, message): + self.wfile.write(message) self.wfile.flush() - def terminate(self): - if self.connection: - try: - self.wfile.flush() - except tcp.NetLibDisconnect: # pragma: no cover - pass - self.connection.close() + def establish_ssl(self, clientcerts, sni): + clientcert = None + if clientcerts: + path = os.path.join(clientcerts, self.address.host.encode("idna")) + ".pem" + if os.path.exists(path): + clientcert = path + try: + self.convert_to_ssl(cert=clientcert, sni=sni) + self.timestamp_ssl_setup = utils.timestamp() + except tcp.NetLibError, v: + raise ProxyError(400, str(v)) + + def finish(self): + tcp.TCPClient.finish(self) + self.timestamp_end = utils.timestamp() +from . import protocol +from .protocol.http import HTTPResponse class RequestReplayThread(threading.Thread): + name="RequestReplayThread" + def __init__(self, config, flow, masterq): self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) threading.Thread.__init__(self) @@ -87,448 +198,275 @@ class RequestReplayThread(threading.Thread): def run(self): try: r = self.flow.request - server = ServerConnection(self.config, r.scheme, r.host, r.port, r.host) + server = ServerConnection(self.flow.server_conn.address(), None) server.connect() - server.send(r) - httpversion, code, msg, headers, content = http.read_response( - server.rfile, r.method, self.config.body_size_limit - ) - response = flow.Response( - self.flow.request, httpversion, code, msg, headers, content, server.cert, - server.rfile.first_byte_timestamp - ) - self.channel.ask("response", response) + if self.flow.server_conn.ssl_established: + server.establish_ssl(self.config.clientcerts, + self.flow.server_conn.sni) + server.send(r._assemble()) + self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, body_size_limit=self.config.body_size_limit) + self.channel.ask("response", self.flow.response) except (ProxyError, http.HttpError, tcp.NetLibError), v: - err = flow.Error(self.flow.request, str(v)) - self.channel.ask("error", err) + self.flow.error = protocol.primitives.Error(str(v)) + self.channel.ask("error", self.flow.error) + + +class ConnectionHandler: + def __init__(self, config, client_connection, client_address, server, channel, server_version): + self.config = config + self.client_conn = ClientConnection(client_connection, client_address, server) + self.server_conn = None + self.channel, self.server_version = channel, server_version + + self.close = False + self.conntype = None + self.sni = None + + self.mode = "regular" + if self.config.reverse_proxy: + self.mode = "reverse" + if self.config.transparent_proxy: + self.mode = "transparent" + def handle(self): + self.log("clientconnect") + self.channel.ask("clientconnect", self) -class HandleSNI: - def __init__(self, handler, client_conn, host, port, key): - self.handler, self.client_conn, self.host, self.port = handler, client_conn, host, port - self.key = key + self.determine_conntype() - def __call__(self, client_connection): try: - sn = client_connection.get_servername() - if sn: - self.handler.get_server_connection(self.client_conn, "https", self.host, self.port, sn) - dummycert = self.handler.find_cert(self.client_conn, self.host, self.port, sn) - new_context = SSL.Context(SSL.TLSv1_METHOD) - new_context.use_privatekey_file(self.key) - new_context.use_certificate(dummycert.x509) - client_connection.set_context(new_context) - self.handler.sni = sn.decode("utf8").encode("idna") - # An unhandled exception in this method will core dump PyOpenSSL, so - # make dang sure it doesn't happen. - except Exception: # pragma: no cover - pass + try: + # Can we already identify the target server and connect to it? + server_address = None + address_priority = None + if self.config.forward_proxy: + server_address = self.config.forward_proxy[1:] + address_priority = AddressPriority.FORCE + elif self.config.reverse_proxy: + server_address = self.config.reverse_proxy[1:] + address_priority = AddressPriority.FROM_SETTINGS + elif self.config.transparent_proxy: + server_address = self.config.transparent_proxy["resolver"].original_addr( + self.client_conn.connection) + if not server_address: + raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") + address_priority = AddressPriority.FROM_CONNECTION + self.log("transparent to %s:%s" % server_address) + + if server_address: + self.set_server_address(server_address, address_priority) + self._handle_ssl() + + while not self.close: + try: + protocol.handle_messages(self.conntype, self) + except protocol.ConnectionTypeChange: + self.log("Connection Type Changed: %s" % self.conntype) + continue + + # FIXME: Do we want to persist errors? + except (ProxyError, tcp.NetLibError), e: + protocol.handle_error(self.conntype, self, e) + except Exception, e: + self.log(e.__class__) + import traceback + self.log(traceback.format_exc()) + self.log(str(e)) + self.del_server_connection() + self.log("clientdisconnect") + self.channel.tell("clientdisconnect", self) -class ProxyHandler(tcp.BaseHandler): - def __init__(self, config, connection, client_address, server, channel, server_version): - self.channel, self.server_version = channel, server_version - self.config = config - self.proxy_connect_state = None - self.sni = None - self.server_conn = None - tcp.BaseHandler.__init__(self, connection, client_address, server) + def _handle_ssl(self): + """ + Helper function of .handle() + Check if we can already identify SSL connections. + If so, connect to the server and establish an SSL connection + """ + client_ssl = False + server_ssl = False - def get_server_connection(self, cc, scheme, host, port, sni, request=None): + if self.config.transparent_proxy: + client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"]) + elif self.config.reverse_proxy: + client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https") + # TODO: Make protocol generic (as with transparent proxies) + # TODO: Add SSL-terminating capatbility (SSL -> mitmproxy -> plain and vice versa) + if client_ssl or server_ssl: + self.establish_server_connection() + self.establish_ssl(client=client_ssl, server=server_ssl) + + def del_server_connection(self): """ - When SNI is in play, this means we have an SSL-encrypted - connection, which means that the entire handler is dedicated to a - single server connection - no multiplexing. If this assumption ever - breaks, we'll have to do something different with the SNI host - variable on the handler object. - - `conn_info` holds the initial connection's parameters, as the - hook might change them. Also, the hook might require an initial - request to figure out connection settings; in this case it can - set require_request, which will cause the connection to be - re-opened after the client's request arrives. + Deletes an existing server connection. """ - sc = self.server_conn - if not sni: - sni = host - conn_info = (scheme, host, port, sni) - if sc and (conn_info != sc.conn_info or (request and sc.require_request)): - sc.terminate() - self.server_conn = None - self.log( - cc, - "switching connection", [ - "%s://%s:%s (sni=%s) -> %s://%s:%s (sni=%s)"%( - scheme, host, port, sni, - sc.scheme, sc.host, sc.port, sc.sni - ) - ] - ) - if not self.server_conn: - try: - self.server_conn = ServerConnection(self.config, scheme, host, port, sni) + if self.server_conn and self.server_conn.connection: + self.server_conn.finish() + self.log("serverdisconnect", ["%s:%s" % (self.server_conn.address.host, self.server_conn.address.port)]) + self.channel.tell("serverdisconnect", self) + self.server_conn = None + self.sni = None - # Additional attributes, used if the server_connect hook - # needs to change parameters - self.server_conn.request = request - self.server_conn.require_request = False + def determine_conntype(self): + #TODO: Add ruleset to select correct protocol depending on mode/target port etc. + self.conntype = "http" - self.server_conn.conn_info = conn_info - self.channel.ask("serverconnect", self.server_conn) - self.server_conn.connect() - except tcp.NetLibError, v: - raise ProxyError(502, v) - return self.server_conn + def set_server_address(self, address, priority): + """ + Sets a new server address with the given priority. + Does not re-establish either connection or SSL handshake. + @type priority: AddressPriority + """ + address = tcp.Address.wrap(address) - def del_server_connection(self): if self.server_conn: - self.server_conn.terminate() - self.server_conn = None + if self.server_conn.priority > priority: + self.log("Attempt to change server address, " + "but priority is too low (is: %s, got: %s)" % (self.server_conn.priority, priority)) + return + if self.server_conn.address == address: + self.server_conn.priority = priority # Possibly increase priority + return - def handle(self): - cc = flow.ClientConnect(self.client_address) - self.log(cc, "connect") - self.channel.ask("clientconnect", cc) - while self.handle_request(cc) and not cc.close: - pass - cc.close = True - self.del_server_connection() + self.del_server_connection() - cd = flow.ClientDisconnect(cc) - self.log( - cc, "disconnect", - [ - "handled %s requests"%cc.requestcount] - ) - self.channel.tell("clientdisconnect", cd) + self.log("Set new server address: %s:%s" % (address.host, address.port)) + self.server_conn = ServerConnection(address, priority) - def handle_request(self, cc): + def establish_server_connection(self): + """ + Establishes a new server connection. + If there is already an existing server connection, the function returns immediately. + """ + if self.server_conn.connection: + return + self.log("serverconnect", ["%s:%s" % self.server_conn.address()[:2]]) + self.channel.tell("serverconnect", self) try: - request, err = None, None - request = self.read_request(cc) - if request is None: - return - cc.requestcount += 1 + self.server_conn.connect() + except tcp.NetLibError, v: + raise ProxyError(502, v) - request_reply = self.channel.ask("request", request) - if request_reply is None or request_reply == KILL: - return - elif isinstance(request_reply, flow.Response): - request = False - response = request_reply - response_reply = self.channel.ask("response", response) - else: - request = request_reply - if self.config.reverse_proxy: - scheme, host, port = self.config.reverse_proxy - elif self.config.forward_proxy: - scheme, host, port = self.config.forward_proxy - else: - scheme, host, port = request.scheme, request.host, request.port - - # If we've already pumped a request over this connection, - # it's possible that the server has timed out. If this is - # the case, we want to reconnect without sending an error - # to the client. - while 1: - sc = self.get_server_connection(cc, scheme, host, port, self.sni, request=request) - sc.send(request) - if sc.requestcount == 1: # add timestamps only for first request (others are not directly affected) - request.tcp_setup_timestamp = sc.tcp_setup_timestamp - request.ssl_setup_timestamp = sc.ssl_setup_timestamp - sc.rfile.reset_timestamps() - try: - peername = sc.connection.getpeername() - if peername: - request.ip = peername[0] - httpversion, code, msg, headers, content = http.read_response( - sc.rfile, - request.method, - self.config.body_size_limit - ) - except http.HttpErrorConnClosed: - self.del_server_connection() - if sc.requestcount > 1: - continue - else: - raise - except http.HttpError: - raise ProxyError(502, "Invalid server response.") - else: - break - - response = flow.Response( - request, httpversion, code, msg, headers, content, sc.cert, - sc.rfile.first_byte_timestamp - ) - response_reply = self.channel.ask("response", response) - # Not replying to the server invalidates the server - # connection, so we terminate. - if response_reply == KILL: - sc.terminate() - - if response_reply == KILL: - return - else: - response = response_reply - self.send_response(response) - if request and http.connection_close(request.httpversion, request.headers): - return - # We could keep the client connection when the server - # connection needs to go away. However, we want to mimic - # behaviour as closely as possible to the client, so we - # disconnect. - if http.connection_close(response.httpversion, response.headers): - return - except (IOError, ProxyError, http.HttpError, tcp.NetLibError), e: - if hasattr(e, "code"): - cc.error = "%s: %s"%(e.code, e.msg) - else: - cc.error = str(e) - - if request: - err = flow.Error(request, cc.error) - self.channel.ask("error", err) - self.log( - cc, cc.error, - ["url: %s"%request.get_url()] - ) - else: - self.log(cc, cc.error) - if isinstance(e, ProxyError): - self.send_error(e.code, e.msg, e.headers) - else: - return True + def establish_ssl(self, client=False, server=False): + """ + Establishes SSL on the existing connection(s) to the server or the client, + as specified by the parameters. If the target server is on the pass-through list, + the conntype attribute will be changed and the SSL connection won't be wrapped. + A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening + """ + # TODO: Implement SSL pass-through handling and change conntype + passthrough = [ + # "echo.websocket.org", + # "174.129.224.73" # echo.websocket.org, transparent mode + ] + if self.server_conn.address.host in passthrough or self.sni in passthrough: + self.conntype = "tcp" + return + + # Logging + if client or server: + subs = [] + if client: + subs.append("with client") + if server: + subs.append("with server (sni: %s)" % self.sni) + self.log("Establish SSL", subs) + + if server: + if self.server_conn.ssl_established: + raise ProxyError(502, "SSL to Server already established.") + self.establish_server_connection() # make sure there is a server connection. + self.server_conn.establish_ssl(self.config.clientcerts, self.sni) + if client: + if self.client_conn.ssl_established: + raise ProxyError(502, "SSL to Client already established.") + dummycert = self.find_cert() + self.client_conn.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, + handle_sni=self.handle_sni) + + def server_reconnect(self, no_ssl=False): + address = self.server_conn.address + had_ssl = self.server_conn.ssl_established + priority = self.server_conn.priority + sni = self.sni + self.log("(server reconnect follows)") + self.del_server_connection() + self.set_server_address(address, priority) + self.establish_server_connection() + if had_ssl and not no_ssl: + self.sni = sni + self.establish_ssl(server=True) + + def finish(self): + self.client_conn.finish() - def log(self, cc, msg, subs=()): + def log(self, msg, subs=()): msg = [ - "%s:%s: "%cc.address + msg + "%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg) ] for i in subs: - msg.append(" -> "+i) + msg.append(" -> " + i) msg = "\n".join(msg) - l = Log(msg) - self.channel.tell("log", l) + self.channel.tell("log", Log(msg)) - def find_cert(self, cc, host, port, sni): + def find_cert(self): if self.config.certfile: with open(self.config.certfile, "rb") as f: return certutils.SSLCert.from_pem(f.read()) else: + host = self.server_conn.address.host sans = [] - if not self.config.no_upstream_cert: - conn = self.get_server_connection(cc, "https", host, port, sni) - sans = conn.cert.altnames - if conn.cert.cn: - host = conn.cert.cn.decode("utf8").encode("idna") + if not self.config.no_upstream_cert or not self.server_conn.ssl_established: + upstream_cert = self.server_conn.cert + if upstream_cert.cn: + host = upstream_cert.cn.decode("utf8").encode("idna") + sans = upstream_cert.altnames + ret = self.config.certstore.get_cert(host, sans, self.config.cacert) if not ret: raise ProxyError(502, "Unable to generate dummy cert.") return ret - def establish_ssl(self, client_conn, host, port): - dummycert = self.find_cert(client_conn, host, port, host) - sni = HandleSNI( - self, client_conn, host, port, self.config.certfile or self.config.cacert - ) - try: - self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni) - except tcp.NetLibError, v: - raise ProxyError(400, str(v)) - - def get_line(self, fp): + def handle_sni(self, connection): """ - Get a line, possibly preceded by a blank. + This callback gets called during the SSL handshake with the client. + The client has just sent the Sever Name Indication (SNI). We now connect upstream to + figure out which certificate needs to be served. """ - line = fp.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message - line = fp.readline() - return line - - def read_request(self, client_conn): - self.rfile.reset_timestamps() - if self.config.transparent_proxy: - return self.read_request_transparent(client_conn) - elif self.config.reverse_proxy: - return self.read_request_reverse(client_conn) - else: - return self.read_request_proxy(client_conn) - - def read_request_transparent(self, client_conn): - orig = self.config.transparent_proxy["resolver"].original_addr(self.connection) - if not orig: - raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") - self.log(client_conn, "transparent to %s:%s"%orig) - - host, port = orig - if port in self.config.transparent_proxy["sslports"]: - scheme = "https" - else: - scheme = "http" - - return self._read_request_origin_form(client_conn, scheme, host, port) - - def read_request_reverse(self, client_conn): - scheme, host, port = self.config.reverse_proxy - return self._read_request_origin_form(client_conn, scheme, host, port) - - def read_request_proxy(self, client_conn): - # Check for a CONNECT command. - if not self.proxy_connect_state: - line = self.get_line(self.rfile) - if line == "": - return None - self.proxy_connect_state = self._read_request_authority_form(line) - - # Check for an actual request - if self.proxy_connect_state: - host, port, _ = self.proxy_connect_state - return self._read_request_origin_form(client_conn, "https", host, port) - else: - # noinspection PyUnboundLocalVariable - return self._read_request_absolute_form(client_conn, line) - - def _read_request_authority_form(self, line): - """ - The authority-form of request-target is only used for CONNECT requests. - The CONNECT method is used to request a tunnel to the destination server. - This function sends a "200 Connection established" response to the client - and returns the host information that can be used to process further requests in origin-form. - An example authority-form request line would be: - CONNECT www.example.com:80 HTTP/1.1 - """ - connparts = http.parse_init_connect(line) - if connparts: - self.read_headers(authenticate=True) - # respond according to http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.2 - self.wfile.write( - 'HTTP/1.1 200 Connection established\r\n' + - ('Proxy-agent: %s\r\n'%self.server_version) + - '\r\n' - ) - self.wfile.flush() - return connparts - - def _read_request_absolute_form(self, client_conn, line): - """ - When making a request to a proxy (other than CONNECT or OPTIONS), - a client must send the target uri in absolute-form. - An example absolute-form request line would be: - GET http://www.example.com/foo.html HTTP/1.1 - """ - r = http.parse_init_proxy(line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) - method, scheme, host, port, path, httpversion = r - headers = self.read_headers(authenticate=True) - self.handle_expect_header(headers, httpversion) - content = http.read_http_body( - self.rfile, headers, self.config.body_size_limit, True - ) - r = flow.Request( - client_conn, httpversion, host, port, scheme, method, path, headers, content, - self.rfile.first_byte_timestamp, utils.timestamp() - ) - r.set_live(self.rfile, self.wfile) - return r - - def _read_request_origin_form(self, client_conn, scheme, host, port): - """ - Read a HTTP request with regular (origin-form) request line. - An example origin-form request line would be: - GET /foo.html HTTP/1.1 - - The request destination is already known from one of the following sources: - 1) transparent proxy: destination provided by platform resolver - 2) reverse proxy: fixed destination - 3) regular proxy: known from CONNECT command. - """ - if scheme.lower() == "https" and not self.ssl_established: - self.establish_ssl(client_conn, host, port) - - line = self.get_line(self.rfile) - if line == "": - return None - - r = http.parse_init_http(line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) - method, path, httpversion = r - headers = self.read_headers(authenticate=False) - self.handle_expect_header(headers, httpversion) - content = http.read_http_body( - self.rfile, headers, self.config.body_size_limit, True - ) - r = flow.Request( - client_conn, httpversion, host, port, scheme, method, path, headers, content, - self.rfile.first_byte_timestamp, utils.timestamp() - ) - r.set_live(self.rfile, self.wfile) - return r - - def handle_expect_header(self, headers, httpversion): - if "expect" in headers: - if "100-continue" in headers['expect'] and httpversion >= (1, 1): - #FIXME: Check if content-length is over limit - self.wfile.write('HTTP/1.1 100 Continue\r\n' - '\r\n') - del headers['expect'] - - def read_headers(self, authenticate=False): - headers = http.read_headers(self.rfile) - if headers is None: - raise ProxyError(400, "Invalid headers") - if authenticate and self.config.authenticator: - if self.config.authenticator.authenticate(headers): - self.config.authenticator.clean(headers) - else: - raise ProxyError( - 407, - "Proxy Authentication Required", - self.config.authenticator.auth_challenge_headers() - ) - return headers - - def send_response(self, response): - d = response._assemble() - if not d: - raise ProxyError(502, "Cannot transmit an incomplete response.") - self.wfile.write(d) - self.wfile.flush() - - def send_error(self, code, body, headers): try: - response = http_status.RESPONSES.get(code, "Unknown") - html_content = '<html><head>\n<title>%d %s</title>\n</head>\n<body>\n%s\n</body>\n</html>'%(code, response, body) - self.wfile.write("HTTP/1.1 %s %s\r\n" % (code, response)) - self.wfile.write("Server: %s\r\n"%self.server_version) - self.wfile.write("Content-type: text/html\r\n") - self.wfile.write("Content-Length: %d\r\n"%len(html_content)) - if headers: - for key, value in headers.items(): - self.wfile.write("%s: %s\r\n"%(key, value)) - self.wfile.write("Connection: close\r\n") - self.wfile.write("\r\n") - self.wfile.write(html_content) - self.wfile.flush() - except: + sn = connection.get_servername() + if sn and sn != self.sni: + self.sni = sn.decode("utf8").encode("idna") + self.log("SNI received: %s" % self.sni) + self.server_reconnect() # reconnect to upstream server with SNI + # Now, change client context to reflect changed certificate: + new_context = SSL.Context(SSL.TLSv1_METHOD) + new_context.use_privatekey_file(self.config.certfile or self.config.cacert) + dummycert = self.find_cert() + new_context.use_certificate(dummycert.x509) + connection.set_context(new_context) + # An unhandled exception in this method will core dump PyOpenSSL, so + # make dang sure it doesn't happen. + except Exception, e: # pragma: no cover pass -class ProxyServerError(Exception): pass +class ProxyServerError(Exception): + pass class ProxyServer(tcp.TCPServer): allow_reuse_address = True bound = True - def __init__(self, config, port, address='', server_version=version.NAMEVERSION): + + def __init__(self, config, port, host='', server_version=version.NAMEVERSION): """ Raises ProxyServerError if there's a startup problem. """ - self.config, self.port, self.address = config, port, address + self.config = config self.server_version = server_version try: - tcp.TCPServer.__init__(self, (address, port)) + tcp.TCPServer.__init__(self, (host, port)) except socket.error, v: raise ProxyServerError('Error starting proxy server: ' + v.strerror) self.channel = None @@ -540,14 +478,15 @@ class ProxyServer(tcp.TCPServer): def set_channel(self, channel): self.channel = channel - def handle_connection(self, request, client_address): - h = ProxyHandler(self.config, request, client_address, self, self.channel, self.server_version) + def handle_client_connection(self, conn, client_address): + h = ConnectionHandler(self.config, conn, client_address, self, self.channel, self.server_version) h.handle() h.finish() class DummyServer: bound = False + def __init__(self, config): self.config = config @@ -563,22 +502,21 @@ def certificate_option_group(parser): group = parser.add_argument_group("SSL") group.add_argument( "--cert", action="store", - type = str, dest="cert", default=None, - help = "User-created SSL certificate file." + type=str, dest="cert", default=None, + help="User-created SSL certificate file." ) group.add_argument( "--client-certs", action="store", - type = str, dest = "clientcerts", default=None, - help = "Client certificate directory." + type=str, dest="clientcerts", default=None, + help="Client certificate directory." ) - def process_proxy_options(parser, options): if options.cert: options.cert = os.path.expanduser(options.cert) if not os.path.exists(options.cert): - return parser.error("Manually created certificate does not exist: %s"%options.cert) + return parser.error("Manually created certificate does not exist: %s" % options.cert) cacert = os.path.join(options.confdir, "mitmproxy-ca.pem") cacert = os.path.expanduser(cacert) @@ -592,8 +530,8 @@ def process_proxy_options(parser, options): if not platform.resolver: return parser.error("Transparent mode not supported on this platform.") trans = dict( - resolver = platform.resolver(), - sslports = TRANSPARENT_SSL_PORTS + resolver=platform.resolver(), + sslports=TRANSPARENT_SSL_PORTS ) else: trans = None @@ -601,14 +539,14 @@ def process_proxy_options(parser, options): if options.reverse_proxy: rp = utils.parse_proxy_spec(options.reverse_proxy) if not rp: - return parser.error("Invalid reverse proxy specification: %s"%options.reverse_proxy) + return parser.error("Invalid reverse proxy specification: %s" % options.reverse_proxy) else: rp = None if options.forward_proxy: fp = utils.parse_proxy_spec(options.forward_proxy) if not fp: - return parser.error("Invalid forward proxy specification: %s"%options.forward_proxy) + return parser.error("Invalid forward proxy specification: %s" % options.forward_proxy) else: fp = None @@ -616,8 +554,8 @@ def process_proxy_options(parser, options): options.clientcerts = os.path.expanduser(options.clientcerts) if not os.path.exists(options.clientcerts) or not os.path.isdir(options.clientcerts): return parser.error( - "Client certificate directory does not exist or is not a directory: %s"%options.clientcerts - ) + "Client certificate directory does not exist or is not a directory: %s" % options.clientcerts + ) if (options.auth_nonanonymous or options.auth_singleuser or options.auth_htpasswd): if options.auth_singleuser: @@ -637,13 +575,13 @@ def process_proxy_options(parser, options): authenticator = http_auth.NullProxyAuth(None) return ProxyConfig( - certfile = options.cert, - cacert = cacert, - clientcerts = options.clientcerts, - body_size_limit = body_size_limit, - no_upstream_cert = options.no_upstream_cert, - reverse_proxy = rp, - forward_proxy = fp, - transparent_proxy = trans, - authenticator = authenticator + certfile=options.cert, + cacert=cacert, + clientcerts=options.clientcerts, + body_size_limit=body_size_limit, + no_upstream_cert=options.no_upstream_cert, + reverse_proxy=rp, + forward_proxy=fp, + transparent_proxy=trans, + authenticator=authenticator ) |