aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/tcp.py8
-rw-r--r--pathod/pathoc.py3
-rw-r--r--test/pathod/tutils.py36
3 files changed, 28 insertions, 19 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index bb0c93a9..61209d64 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -6,6 +6,7 @@ import sys
import threading
import time
import traceback
+import contextlib
import binascii
from six.moves import range
@@ -577,6 +578,12 @@ class _Connection(object):
return context
+@contextlib.contextmanager
+def _closer(client):
+ yield
+ client.close()
+
+
class TCPClient(_Connection):
def __init__(self, address, source_address=None):
@@ -708,6 +715,7 @@ class TCPClient(_Connection):
self.connection = connection
self.ip_address = Address(connection.getpeername())
self._makefile()
+ return _closer(self)
def settimeout(self, n):
self.connection.settimeout(n)
diff --git a/pathod/pathoc.py b/pathod/pathoc.py
index 2b7d053c..5cfb4591 100644
--- a/pathod/pathoc.py
+++ b/pathod/pathoc.py
@@ -286,7 +286,7 @@ class Pathoc(tcp.TCPClient):
if self.use_http2 and not self.ssl:
raise NotImplementedError("HTTP2 without SSL is not supported.")
- tcp.TCPClient.connect(self)
+ ret = tcp.TCPClient.connect(self)
if connect_to:
self.http_connect(connect_to)
@@ -324,6 +324,7 @@ class Pathoc(tcp.TCPClient):
if self.timeout:
self.settimeout(self.timeout)
+ return ret
def stop(self):
if self.ws_framereader:
diff --git a/test/pathod/tutils.py b/test/pathod/tutils.py
index e674812b..b9f38d86 100644
--- a/test/pathod/tutils.py
+++ b/test/pathod/tutils.py
@@ -88,11 +88,11 @@ class DaemonTests(object):
ssl=self.ssl,
fp=logfp,
)
- c.connect()
- if params:
- path = path + "?" + urllib.urlencode(params)
- resp = c.request("get:%s" % path)
- return resp
+ with c.connect():
+ if params:
+ path = path + "?" + urllib.urlencode(params)
+ resp = c.request("get:%s" % path)
+ return resp
def get(self, spec):
logfp = StringIO()
@@ -101,9 +101,9 @@ class DaemonTests(object):
ssl=self.ssl,
fp=logfp,
)
- c.connect()
- resp = c.request("get:/p/%s" % urllib.quote(spec).encode("string_escape"))
- return resp
+ with c.connect():
+ resp = c.request("get:/p/%s" % urllib.quote(spec).encode("string_escape"))
+ return resp
def pathoc(
self,
@@ -128,16 +128,16 @@ class DaemonTests(object):
fp=logfp,
use_http2=use_http2,
)
- c.connect(connect_to)
- ret = []
- for i in specs:
- resp = c.request(i)
- if resp:
- ret.append(resp)
- for frm in c.wait():
- ret.append(frm)
- c.stop()
- return ret, logfp.getvalue()
+ with c.connect(connect_to):
+ ret = []
+ for i in specs:
+ resp = c.request(i)
+ if resp:
+ ret.append(resp)
+ for frm in c.wait():
+ ret.append(frm)
+ c.stop()
+ return ret, logfp.getvalue()
tmpdir = tutils.tmpdir