aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/__init__.py0
-rw-r--r--netlib/odict.py160
-rw-r--r--netlib/protocol.py218
-rw-r--r--netlib/tcp.py182
4 files changed, 560 insertions, 0 deletions
diff --git a/netlib/__init__.py b/netlib/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/netlib/__init__.py
diff --git a/netlib/odict.py b/netlib/odict.py
new file mode 100644
index 00000000..afc33caa
--- /dev/null
+++ b/netlib/odict.py
@@ -0,0 +1,160 @@
+import re, copy
+
def safe_subn(pattern, repl, target, *args, **kwargs):
    """
    A Unicode-tolerant wrapper around re.subn.

    re.subn has Unicode conversion problems, so the pattern and the
    replacement are both coerced to strings before substitution. A real
    solution would be aware of the actual content encoding.
    """
    return re.subn(str(pattern), str(repl), target, *args, **kwargs)
+
+
class ODict:
    """
    A dictionary-like object for managing ordered (key, value) data.

    Keys may occur multiple times and insertion order is preserved.
    Subclasses can override _kconv to normalise keys for comparison
    (see ODictCaseless).
    """
    def __init__(self, lst=None):
        # Backing store: a list of [key, value] pairs.
        self.lst = lst or []

    def _kconv(self, s):
        # Key canonicalisation hook; the identity by default.
        return s

    def __eq__(self, other):
        # Guard against comparison with arbitrary objects - previously
        # anything without a .lst attribute raised AttributeError.
        if not isinstance(other, ODict):
            return NotImplemented
        return self.lst == other.lst

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__, so supply it.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __getitem__(self, k):
        """
        Returns a list of values matching key.
        """
        ret = []
        k = self._kconv(k)
        for i in self.lst:
            if self._kconv(i[0]) == k:
                ret.append(i[1])
        return ret

    def _filter_lst(self, k, lst):
        # Return lst without any pairs whose key matches k.
        k = self._kconv(k)
        new = []
        for i in lst:
            if self._kconv(i[0]) != k:
                new.append(i)
        return new

    def __len__(self):
        """
        Total number of (key, value) pairs.
        """
        return len(self.lst)

    def __setitem__(self, k, valuelist):
        """
        Sets the values for key k. If there are existing values for this
        key, they are cleared.
        """
        # A bare string is almost certainly a caller mistake - each key
        # maps to a *list* of values.
        if isinstance(valuelist, basestring):
            raise ValueError("ODict valuelist should be lists.")
        new = self._filter_lst(k, self.lst)
        for i in valuelist:
            new.append([k, i])
        self.lst = new

    def __delitem__(self, k):
        """
        Delete all items matching k.
        """
        self.lst = self._filter_lst(k, self.lst)

    def __contains__(self, k):
        for i in self.lst:
            if self._kconv(i[0]) == self._kconv(k):
                return True
        return False

    def add(self, key, value):
        # Append a single pair; the value is coerced to a string.
        self.lst.append([key, str(value)])

    def get(self, k, d=None):
        # Like dict.get, but returns the *list* of matching values.
        if k in self:
            return self[k]
        else:
            return d

    def items(self):
        # Shallow copy of the pair list.
        return self.lst[:]

    def _get_state(self):
        # Tuple snapshot of the pairs, suitable for serialisation.
        return [tuple(i) for i in self.lst]

    @classmethod
    def _from_state(klass, state):
        return klass([list(i) for i in state])

    def copy(self):
        """
        Returns a copy of this object.
        """
        lst = copy.deepcopy(self.lst)
        return self.__class__(lst)

    def __repr__(self):
        # Render as header-style "key: value" lines, CRLF-terminated.
        elements = []
        for itm in self.lst:
            elements.append(itm[0] + ": " + itm[1])
        elements.append("")
        return "\r\n".join(elements)

    def in_any(self, key, value, caseless=False):
        """
        Do any of the values matching key contain value?

        If caseless is true, value comparison is case-insensitive.
        """
        if caseless:
            value = value.lower()
        for i in self[key]:
            if caseless:
                i = i.lower()
            if value in i:
                return True
        return False

    def match_re(self, expr):
        """
        Match the regular expression against each (key, value) pair. For
        each pair a string of the following format is matched against:

            "key: value"
        """
        for k, v in self.lst:
            s = "%s: %s"%(k, v)
            if re.search(expr, s):
                return True
        return False

    def replace(self, pattern, repl, *args, **kwargs):
        """
        Replaces a regular expression pattern with repl in both keys and
        values.

        Returns the number of replacements made.
        """
        # NOTE(review): no decoding/encoding of the content happens
        # here, despite what the previous docstring claimed.
        nlst, count = [], 0
        for i in self.lst:
            k, c = safe_subn(pattern, repl, i[0], *args, **kwargs)
            count += c
            v, c = safe_subn(pattern, repl, i[1], *args, **kwargs)
            count += c
            nlst.append([k, v])
        self.lst = nlst
        return count
+
+
class ODictCaseless(ODict):
    """
    An ODict variant whose keys compare case-insensitively.

    Key case is _preserved_ as stored; it is only ignored when keys are
    matched for lookup, deletion and containment.
    """
    def _kconv(self, key):
        # Normalise keys to lower case for comparison purposes only.
        return key.lower()
diff --git a/netlib/protocol.py b/netlib/protocol.py
new file mode 100644
index 00000000..55bcf440
--- /dev/null
+++ b/netlib/protocol.py
@@ -0,0 +1,218 @@
+import string, urlparse
+
class ProtocolError(Exception):
    """
    Raised on malformed HTTP traffic. Carries an HTTP status code and a
    human-readable message.
    """
    def __init__(self, code, msg):
        self.code = code
        self.msg = msg

    def __str__(self):
        return "ProtocolError(%s, %s)"%(self.code, self.msg)
+
+
def parse_url(url):
    """
    Parse a URL into its components.

    Returns a (scheme, host, port, path) tuple, or None if the URL has
    no scheme or carries an unparseable port. When no explicit port is
    given, the scheme default is used: 443 for https, 80 otherwise.
    """
    scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
    if not scheme:
        return None
    if ':' in netloc:
        # Split on the last colon so only the trailing component is
        # treated as the port. str.rsplit replaces the deprecated
        # string.rsplit module function.
        host, port = netloc.rsplit(':', 1)
        try:
            port = int(port)
        except ValueError:
            return None
    else:
        host = netloc
        port = 443 if scheme == "https" else 80
    # Re-assemble everything after the netloc into the request path.
    path = urlparse.urlunparse(('', '', path, params, query, fragment))
    if not path.startswith("/"):
        path = "/" + path
    return scheme, host, port, path
+
+
def read_headers(fp):
    """
    Read a set of headers from a file pointer, stopping at the first
    blank line (or EOF).

    Returns a list of [name, value] pairs, with folded (continued)
    lines merged into the preceding header's value.
    """
    ret = []
    name = ''
    while 1:
        line = fp.readline()
        if not line or line == '\r\n' or line == '\n':
            break
        if line[0] in ' \t':
            # Continuation of the previous header's value. A
            # continuation line with no preceding header is malformed;
            # ignore it rather than crashing on ret[-1].
            if ret:
                ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
        else:
            i = line.find(':')
            # We're being liberal in what we accept, here.
            if i > 0:
                name = line[:i]
                value = line[i+1:].strip()
                ret.append([name, value])
    return ret
+
+
def read_chunked(fp, limit):
    """
    Read an HTTP chunked-encoded body from fp and return it decoded.

    limit bounds the total decoded size (None means unlimited). Raises
    ProtocolError on a malformed chunk length or when the limit is
    exceeded, and IOError if the stream ends mid-body or a chunk is not
    CRLF-terminated.
    """
    # Accumulate chunks and join once at the end - avoids quadratic
    # string concatenation on bodies with many chunks.
    chunks = []
    total = 0
    while 1:
        line = fp.readline(128)
        if line == "":
            raise IOError("Connection closed")
        if line == '\r\n' or line == '\n':
            continue
        try:
            length = int(line, 16)
        except ValueError:
            # FIXME: Not strictly correct - this could be from the server, in which
            # case we should send a 502.
            raise ProtocolError(400, "Invalid chunked encoding length: %s"%line)
        if not length:
            # Zero-length chunk terminates the body.
            break
        total += length
        if limit is not None and total > limit:
            msg = "HTTP Body too large."\
                " Limit is %s, chunked content length was at least %s"%(limit, total)
            raise ProtocolError(509, msg)
        chunks.append(fp.read(length))
        line = fp.readline(5)
        if line != '\r\n':
            raise IOError("Malformed chunked body")
    # Consume trailer lines until the terminating blank line.
    while 1:
        line = fp.readline()
        if line == "":
            raise IOError("Connection closed")
        if line == '\r\n' or line == '\n':
            break
    return "".join(chunks)
+
+
def has_chunked_encoding(headers):
    """
    True if the Transfer-Encoding headers include the "chunked" token.

    Tokens are compared case-insensitively and stripped of surrounding
    whitespace, so "gzip, chunked" is recognised - the previous exact
    comparison missed tokens with a space after the comma.
    """
    for value in headers["transfer-encoding"]:
        for token in value.split(","):
            if token.strip().lower() == "chunked":
                return True
    return False
+
+
def read_http_body(rfile, headers, all, limit):
    """
    Read a message body from rfile according to the headers.

    Chunked transfer-encoding takes precedence over Content-Length.
    Otherwise, when "all" is true, read to EOF (bounded by limit).
    Returns the body as a string; raises ProtocolError on an invalid
    Content-Length or when limit is exceeded. Note: the "all" parameter
    shadows the builtin of the same name but is part of the public
    signature.
    """
    if has_chunked_encoding(headers):
        return read_chunked(rfile, limit)
    if "content-length" in headers:
        try:
            length = int(headers["content-length"][0])
        except ValueError:
            # FIXME: Not strictly correct - this could be from the server, in which
            # case we should send a 502.
            raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"])
        if limit is not None and length > limit:
            raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, length))
        return rfile.read(length)
    if all:
        # A falsy limit (0/None) means "no cap" on the EOF read.
        return rfile.read(limit if limit else None)
    return ""
+
+
def parse_http_protocol(s):
    """
    Parse an HTTP protocol string such as "HTTP/1.1".

    Returns a (major, minor) integer tuple, or None if the string is
    malformed. All callers treat None as a parse failure, so malformed
    version numbers must return None rather than raise (the previous
    version raised ValueError on e.g. "HTTP/junk").
    """
    if not s.startswith("HTTP/"):
        return None
    try:
        major, minor = s.split('/')[1].split('.')
        return int(major), int(minor)
    except ValueError:
        # Wrong number of "." components, or non-numeric parts.
        return None
+
+
def parse_init_connect(line):
    """
    Parse a CONNECT request line.

    Returns (host, port, httpversion), or None if the line is not a
    well-formed CONNECT request.
    """
    # str.split replaces the deprecated string.split module function.
    try:
        method, url, protocol = line.split()
    except ValueError:
        return None
    if method != 'CONNECT':
        return None
    try:
        host, port = url.split(":")
        # A non-numeric port previously escaped as an uncaught
        # ValueError; treat it as a parse failure like everything else.
        port = int(port)
    except ValueError:
        return None
    httpversion = parse_http_protocol(protocol)
    if not httpversion:
        return None
    return host, port, httpversion
+
+
def parse_init_proxy(line):
    """
    Parse an absolute-form proxy request line.

    Returns (method, scheme, host, port, path, httpversion), or None
    if the line or its URL cannot be parsed.
    """
    # str.split replaces the deprecated string.split module function.
    try:
        method, url, protocol = line.split()
    except ValueError:
        return None
    parts = parse_url(url)
    if not parts:
        return None
    scheme, host, port, path = parts
    httpversion = parse_http_protocol(protocol)
    if not httpversion:
        return None
    return method, scheme, host, port, path, httpversion
+
+
def parse_init_http(line):
    """
    Parse an origin-form HTTP request line.

    Returns (method, url, httpversion), or None on a malformed line.
    The URL must be a path starting with "/" or the asterisk form "*".
    """
    # str.split replaces the deprecated string.split module function.
    try:
        method, url, protocol = line.split()
    except ValueError:
        return None
    if not (url.startswith("/") or url == "*"):
        return None
    httpversion = parse_http_protocol(protocol)
    if not httpversion:
        return None
    return method, url, httpversion
+
+
def request_connection_close(httpversion, headers):
    """
    Checks the request to see if the client connection should be closed.

    Connection header tokens are compared case-insensitively, since
    values such as "Close" and "Keep-Alive" are commonly capitalised;
    the previous exact comparison missed them.
    """
    if "connection" in headers:
        for value in ",".join(headers['connection']).split(","):
            value = value.strip().lower()
            if value == "close":
                return True
            elif value == "keep-alive":
                return False
    # HTTP 1.1 connections are assumed to be persistent
    if httpversion == (1, 1):
        return False
    return True
+
+
def response_connection_close(httpversion, headers):
    """
    Checks the response to see if the client connection should be closed.

    The connection may stay open only when the body length is
    self-delimiting: either chunked transfer-encoding or an explicit
    Content-Length. Otherwise the body is terminated by connection
    close, so the connection must be closed.
    """
    if request_connection_close(httpversion, headers):
        return True
    # The original logic was inverted: it returned True (close) exactly
    # when a Content-Length was present, and False when the body length
    # was indeterminate.
    if has_chunked_encoding(headers) or "content-length" in headers:
        return False
    return True
+
+
def read_http_body_request(rfile, wfile, headers, httpversion, limit):
    """
    Read a request body, honouring an Expect: 100-continue header by
    sending the interim response on wfile before reading.

    The Expect header is consumed (deleted) once acknowledged. Returns
    the body read by read_http_body.
    """
    if "expect" in headers:
        # FIXME: Should be forwarded upstream
        expect = ",".join(headers['expect'])
        if expect == "100-continue" and httpversion >= (1, 1):
            # The previous code emitted a Proxy-agent header built from
            # version.NAMEVERSION, but "version" is never imported in
            # this module, so that path raised NameError at runtime.
            # The header is optional; send a minimal interim response.
            wfile.write('HTTP/1.1 100 Continue\r\n')
            wfile.write('\r\n')
            del headers['expect']
    return read_http_body(rfile, headers, False, limit)
diff --git a/netlib/tcp.py b/netlib/tcp.py
new file mode 100644
index 00000000..08ccba09
--- /dev/null
+++ b/netlib/tcp.py
@@ -0,0 +1,182 @@
+import select, socket, threading, traceback, sys
+from OpenSSL import SSL
+
+
class NetLibError(Exception):
    """Base exception for netlib networking failures."""
+
+
class FileLike:
    """
    Wraps a socket or SSL.Connection in a file-like interface.

    Reads are retried until the requested number of bytes has been
    collected or the peer stops producing data; writes go through
    sendall. Unknown attributes are delegated to the wrapped object.
    """
    def __init__(self, o):
        self.o = o

    def __getattr__(self, attr):
        # Delegate anything we don't implement to the wrapped object.
        return getattr(self.o, attr)

    def flush(self):
        # Writes go straight to the wrapped object; nothing buffered.
        pass

    def read(self, length):
        """
        Read exactly length bytes, unless the stream ends first.
        """
        result = ''
        while len(result) < length:
            try:
                # Ask only for what is still missing. Requesting the
                # full length every iteration (as before) could return
                # MORE than the caller asked for when the wrapped
                # object delivers data in arbitrary-sized pieces.
                data = self.o.read(length - len(result))
            except SSL.ZeroReturnError:
                # Clean SSL shutdown - treat as end of stream.
                break
            if not data:
                break
            result += data
        return result

    def write(self, v):
        # sendall, not send: the whole buffer is always written.
        self.o.sendall(v)

    def readline(self, size = None):
        """
        Read up to and including the next newline, or at most size
        bytes if size is given.
        """
        result = ''
        bytes_read = 0
        while True:
            if size is not None and bytes_read >= size:
                break
            ch = self.read(1)
            bytes_read += 1
            if not ch:
                break
            else:
                result += ch
            if ch == '\n':
                break
        return result
+
+
class TCPClient:
    """
    A TCP (optionally SSL-wrapped) client connection.

    Connects immediately on construction. After connect(), rfile and
    wfile expose file-like access to the stream and, for SSL
    connections, cert holds the peer certificate.
    """
    def __init__(self, ssl, host, port, clientcert):
        self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert
        self.connection, self.rfile, self.wfile = None, None, None
        self.cert = None
        self.connect()

    def connect(self):
        """
        Resolve the host and establish the (possibly SSL) connection.
        Raises NetLibError on socket-level failure.
        """
        try:
            addr = socket.gethostbyname(self.host)
            server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            if self.ssl:
                # NOTE(review): SSLv23_METHOD with no verification
                # options - the peer certificate is fetched below but
                # never validated.
                context = SSL.Context(SSL.SSLv23_METHOD)
                if self.clientcert:
                    context.use_certificate_file(self.clientcert)
                server = SSL.Connection(context, server)
            server.connect((addr, self.port))
            if self.ssl:
                self.cert = server.get_peer_certificate()
                # SSL.Connection has no makefile(); use the FileLike shim.
                self.rfile, self.wfile = FileLike(server), FileLike(server)
            else:
                self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
        except socket.error as err:
            # "except E, err" is legacy syntax (invalid in Python 3);
            # "as" works on Python 2.6+ as well.
            raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
        self.connection = server
+
+
class BaseHandler:
    """
    Per-connection handler. Subclasses implement handle(); the
    constructor runs the whole lifecycle - it wraps the connection in
    buffered file objects, calls handle(), then finish().
    """
    # Buffer sizes for the connection's file objects: fully buffered
    # reads (-1), unbuffered writes (0).
    rbufsize = -1
    wbufsize = 0
    def __init__(self, connection, client_address, server):
        self.connection = connection
        self.rfile = self.connection.makefile('rb', self.rbufsize)
        self.wfile = self.connection.makefile('wb', self.wbufsize)

        self.client_address = client_address
        self.server = server
        # Service the connection and tear down immediately.
        self.handle()
        self.finish()

    def convert_to_ssl(self, cert, key):
        """
        Upgrade the established connection to server-side SSL using the
        given certificate and private key files. rfile/wfile are
        re-bound to FileLike wrappers, since SSL.Connection has no
        makefile().
        """
        # NOTE(review): SSLv23_METHOD with no verify/cipher options -
        # presumably acceptable for a local proxy; confirm before
        # exposing to untrusted peers.
        ctx = SSL.Context(SSL.SSLv23_METHOD)
        ctx.use_privatekey_file(key)
        ctx.use_certificate_file(cert)
        self.connection = SSL.Connection(ctx, self.connection)
        # Act as the SSL server side of the handshake.
        self.connection.set_accept_state()
        self.rfile = FileLike(self.connection)
        self.wfile = FileLike(self.connection)

    def finish(self):
        """
        Flush pending output, then close the connection and its file
        objects. IOErrors during teardown are deliberately ignored.
        """
        try:
            if not getattr(self.wfile, "closed", False):
                self.wfile.flush()
            self.connection.close()
            self.wfile.close()
            self.rfile.close()
        except IOError: # pragma: no cover
            pass

    def handle(self): # pragma: no cover
        # Subclass responsibility: service self.connection.
        raise NotImplementedError
+
+
class TCPServer:
    """
    A threaded TCP server: serve_forever() accepts connections and
    services each one on its own daemon thread via handle_connection().
    """
    # Maximum backlog of pending connections passed to listen().
    request_queue_size = 20

    def __init__(self, server_address):
        self.server_address = server_address
        self.__is_shut_down = threading.Event()
        self.__shutdown_request = False
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Allow quick restarts on the same address.
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind(self.server_address)
        # Re-read the address in case the OS chose the port (port 0).
        self.server_address = self.socket.getsockname()
        self.socket.listen(self.request_queue_size)
        self.port = self.socket.getsockname()[1]

    def request_thread(self, request, client_address):
        """
        Thread entry point: service one connection, then close it.
        """
        try:
            self.handle_connection(request, client_address)
        except Exception:
            # Narrowed from a bare "except:" so KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            self.handle_error(request, client_address)
        finally:
            # Close exactly once, on both success and failure paths
            # (previously duplicated in both branches).
            request.close()

    def serve_forever(self, poll_interval=0.5):
        """
        Accept connections until shutdown() is called. poll_interval
        bounds how long a shutdown request can go unnoticed.
        """
        self.__is_shut_down.clear()
        try:
            while not self.__shutdown_request:
                # select with a timeout so the shutdown flag is polled.
                r, w, e = select.select([self.socket], [], [], poll_interval)
                if self.socket in r:
                    try:
                        request, client_address = self.socket.accept()
                    except socket.error:
                        # The listening socket died; nothing to serve.
                        return
                    try:
                        t = threading.Thread(
                            target = self.request_thread,
                            args = (request, client_address)
                        )
                        # Daemon threads don't block interpreter exit.
                        t.daemon = True
                        t.start()
                    except Exception:
                        self.handle_error(request, client_address)
                        request.close()
        finally:
            self.__shutdown_request = False
            self.__is_shut_down.set()

    def shutdown(self):
        """
        Request serve_forever() to stop, wait for it to finish, then
        run handle_shutdown().
        """
        self.__shutdown_request = True
        self.__is_shut_down.wait()
        self.handle_shutdown()

    def handle_error(self, request, client_address, fp=sys.stderr):
        """
        Called when handle_connection raises an exception. Writes the
        traceback to fp (stderr by default).
        """
        # Explicit writes replace the Python2-only "print >> fp"
        # statements, keeping the module forward-compatible.
        fp.write('-'*40 + "\n")
        fp.write("Error processing request from %s:%s\n"%client_address)
        fp.write(traceback.format_exc() + "\n")
        fp.write('-'*40 + "\n")

    def handle_connection(self, request, client_address): # pragma: no cover
        """
        Called after client connection.
        """
        raise NotImplementedError

    def handle_shutdown(self):
        """
        Called after server shutdown.
        """
        pass