aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/certutils.py24
-rw-r--r--netlib/tcp.py7
-rw-r--r--netlib/test.py11
3 files changed, 31 insertions, 11 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py
index d544cfa6..19148382 100644
--- a/netlib/certutils.py
+++ b/netlib/certutils.py
@@ -115,11 +115,23 @@ class CertStore:
"""
Implements an in-memory certificate store.
"""
- def __init__(self, privkey, cacert):
+ def __init__(self, privkey, cacert, dhparams=None):
self.privkey, self.cacert = privkey, cacert
+ self.dhparams = dhparams
self.certs = DNTree()
@classmethod
+ def load_dhparam(klass, path):
+ bio = OpenSSL.SSL._lib.BIO_new_file(path, b"r")
+ if bio != OpenSSL.SSL._ffi.NULL:
+ bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free)
+ dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams(
+ bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL
+ )
+ dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free)
+ return dh
+
+ @classmethod
def from_store(klass, path, basename):
p = os.path.join(path, basename + "-ca.pem")
if not os.path.exists(p):
@@ -129,7 +141,9 @@ class CertStore:
raw = file(p, "rb").read()
ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
- return klass(key, ca)
+ dhp = os.path.join(path, basename + "-dhparam.pem")
+ dh = klass.load_dhparam(dhp)
+ return klass(key, ca, dh)
@classmethod
def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
@@ -147,17 +161,17 @@ class CertStore:
f.close()
# Dump the certificate in PEM format
- f = open(os.path.join(path, basename + "-cert.pem"), "wb")
+ f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb")
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Create a .cer file with the same contents for Android
- f = open(os.path.join(path, basename + "-cert.cer"), "wb")
+ f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb")
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Dump the certificate in PKCS12 format for Windows devices
- f = open(os.path.join(path, basename + "-cert.p12"), "wb")
+ f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb")
p12 = OpenSSL.crypto.PKCS12()
p12.set_certificate(ca)
p12.set_privatekey(key)
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 83059bc2..078ac497 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -339,7 +339,10 @@ class BaseHandler(_Connection):
self.ssl_established = False
self.clientcert = None
- def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None):
+ def convert_to_ssl(self, cert, key,
+ method=SSLv23_METHOD, options=None, handle_sni=None,
+ request_client_cert=False, cipher_list=None, dhparams=None
+ ):
"""
cert: A certutils.SSLCert object.
method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD
@@ -377,6 +380,8 @@ class BaseHandler(_Connection):
ctx.set_tlsext_servername_callback(handle_sni)
ctx.use_privatekey(key)
ctx.use_certificate(cert.x509)
+ if dhparams:
+ SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams)
if request_client_cert:
def ver(*args):
self.clientcert = certutils.SSLCert(args[1])
diff --git a/netlib/test.py b/netlib/test.py
index b88b3586..bb0012ad 100644
--- a/netlib/test.py
+++ b/netlib/test.py
@@ -18,7 +18,6 @@ class ServerTestBase:
ssl = None
handler = None
addr = ("localhost", 0)
-
@classmethod
def setupAll(cls):
cls.q = Queue.Queue()
@@ -43,15 +42,16 @@ class ServerTestBase:
class TServer(tcp.TCPServer):
def __init__(self, ssl, q, handler_klass, addr):
"""
- ssl: A {cert, key, v3_only} dict.
+ ssl: A dictionary of SSL parameters:
+
+ cert, key, request_client_cert, cipher_list,
+ dhparams, v3_only
"""
tcp.TCPServer.__init__(self, addr)
self.ssl, self.q = ssl, q
self.handler_klass = handler_klass
self.last_handler = None
-
-
def handle_client_connection(self, request, client_address):
h = self.handler_klass(request, client_address, self)
self.last_handler = h
@@ -73,7 +73,8 @@ class TServer(tcp.TCPServer):
options = options,
handle_sni = getattr(h, "handle_sni", None),
request_client_cert = self.ssl["request_client_cert"],
- cipher_list = self.ssl.get("cipher_list", None)
+ cipher_list = self.ssl.get("cipher_list", None),
+ dhparams = self.ssl.get("dhparams", None)
)
h.handle()
h.finish()