diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/tcp.py | 26 | ||||
-rw-r--r-- | netlib/test.py | 3 |
2 files changed, 16 insertions, 13 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py index c6e0075e..7f98b4f9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -39,6 +39,9 @@ class SSLKeyLogger(object): if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: with self.lock: if not self.f: + d = os.path.dirname(self.filename) + if not os.path.isdir(d): + os.makedirs(d) self.f = open(self.filename, "ab") self.f.write("\r\n") client_random = connection.client_random().encode("hex") @@ -51,11 +54,13 @@ class SSLKeyLogger(object): if self.f: self.f.close() -_logfile = os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE") -if _logfile: - log_ssl_key = SSLKeyLogger(_logfile) -else: - log_ssl_key = False + @staticmethod + def create_logfun(filename): + if filename: + return SSLKeyLogger(filename) + return False + +log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) class _FileLike: @@ -161,9 +166,9 @@ class Reader(_FileLike): except SSL.SysCallError as e: if e.args == (-1, 'Unexpected EOF'): break - raise NetLibDisconnect - except SSL.Error, v: - raise NetLibSSLError(v.message) + raise NetLibSSLError(e.message) + except SSL.Error as e: + raise NetLibSSLError(e.message) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break @@ -179,10 +184,7 @@ class Reader(_FileLike): while True: if size is not None and bytes_read >= size: break - try: - ch = self.read(1) - except NetLibDisconnect: - break + ch = self.read(1) bytes_read += 1 if not ch: break diff --git a/netlib/test.py b/netlib/test.py index fb468907..3a23ba8f 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -75,7 +75,8 @@ class TServer(tcp.TCPServer): handle_sni = getattr(h, "handle_sni", None), request_client_cert = self.ssl["request_client_cert"], cipher_list = self.ssl.get("cipher_list", None), - dhparams = self.ssl.get("dhparams", None) + dhparams = self.ssl.get("dhparams", None), + chain_file = self.ssl.get("chain_file", None) ) h.handle() h.finish() |