diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/certutils.py | 45 | ||||
-rw-r--r-- | netlib/tcp.py | 3 | ||||
-rw-r--r-- | netlib/test.py | 7 |
3 files changed, 20 insertions, 35 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py index 4c06eb8f..7dcb5450 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -73,7 +73,7 @@ def dummy_ca(path): return True -def dummy_cert(fp, ca, commonname, sans): +def dummy_cert(ca, commonname, sans): """ Generates and writes a certificate to fp. @@ -111,27 +111,15 @@ def dummy_cert(fp, ca, commonname, sans): cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert.set_pubkey(req.get_pubkey()) cert.sign(key, "sha1") - - fp.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)) - fp.close() + return SSLCert(cert) class CertStore: """ - Implements an on-disk certificate store. + Implements an in-memory certificate store. """ - def __init__(self, certdir=None): - """ - certdir: The certificate store directory. If None, a temporary - directory will be created, and destroyed when the .cleanup() method - is called. - """ - if certdir: - self.remove = False - self.certdir = certdir - else: - self.remove = True - self.certdir = tempfile.mkdtemp(prefix="certstore") + def __init__(self): + self.certs = {} def check_domain(self, commonname): try: @@ -145,33 +133,26 @@ class CertStore: return False return True - def get_cert(self, commonname, sans, cacert=False): + def get_cert(self, commonname, sans, cacert): """ - Returns the path to a certificate. + Returns an SSLCert object. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. sans: A list of Subject Alternate Names. - cacert: An optional path to a CA certificate. If specified, the - cert is created if it does not exist, else return None. + cacert: The path to a CA certificate. Return None if the certificate could not be found or generated. """ if not self.check_domain(commonname): return None - certpath = os.path.join(self.certdir, commonname + ".pem") - if os.path.exists(certpath): - return certpath - elif cacert: - f = open(certpath, "wb") - dummy_cert(f, cacert, commonname, sans) - return certpath - - def cleanup(self): - if self.remove: - shutil.rmtree(self.certdir) + if commonname in self.certs: + return self.certs[commonname] + c = dummy_cert(cacert, commonname, sans) + self.certs[commonname] = c + return c class _GeneralName(univ.Choice): diff --git a/netlib/tcp.py b/netlib/tcp.py index f4a8acf9..31e9a398 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -268,6 +268,7 @@ class BaseHandler: def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False): """ + cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: @@ -297,7 +298,7 @@ class BaseHandler: # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(handle_sni) ctx.use_privatekey_file(key) - ctx.use_certificate_file(cert) + ctx.use_certificate(cert.x509) if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) diff --git a/netlib/test.py b/netlib/test.py index deaef64e..661395c5 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,5 +1,5 @@ import threading, Queue, cStringIO -import tcp +import tcp, certutils class ServerThread(threading.Thread): def __init__(self, server): @@ -51,6 +51,9 @@ class TServer(tcp.TCPServer): h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl: + cert = certutils.SSLCert.from_pem( + file(self.ssl["cert"], "r").read() + ) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 @@ -58,7 +61,7 @@ class TServer(tcp.TCPServer): method = tcp.SSLv23_METHOD options = None h.convert_to_ssl( - self.ssl["cert"], + cert, self.ssl["key"], method = method, options = options, |