aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--libmproxy/proxy.py37
-rw-r--r--libmproxy/tcpserver.py88
2 files changed, 99 insertions, 26 deletions
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index bcbc8ea5..89493e79 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -21,7 +21,7 @@
import sys, os, string, socket, time
import shutil, tempfile, threading
import optparse, SocketServer
-import utils, flow, certutils, version, wsgi
+import utils, flow, certutils, version, wsgi, tcpserver
from OpenSSL import SSL
@@ -50,7 +50,7 @@ class ProxyConfig:
def read_headers(fp):
"""
Read a set of headers from a file pointer. Stop once a blank line
- is reached. Return a ODict object.
+ is reached. Return a ODictCaseless object.
"""
ret = []
name = ''
@@ -374,13 +374,13 @@ class ServerConnection:
pass
-class ProxyHandler(SocketServer.StreamRequestHandler):
- def __init__(self, config, request, client_address, server, q):
+class ProxyHandler(tcpserver.BaseHandler):
+ def __init__(self, config, connection, client_address, server, q):
self.mqueue = q
self.config = config
self.server_conn = None
self.proxy_connect_state = None
- SocketServer.StreamRequestHandler.__init__(self, request, client_address, server)
+ tcpserver.BaseHandler.__init__(self, connection, client_address, server)
def handle(self):
cc = flow.ClientConnect(self.client_address)
@@ -390,7 +390,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
cc.close = True
cd = flow.ClientDisconnect(cc)
cd._send(self.mqueue)
- self.finish()
def server_connect(self, scheme, host, port):
sc = self.server_conn
@@ -554,18 +553,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
self.wfile.write(d)
self.wfile.flush()
- def terminate(self, connection, wfile, rfile):
- self.request.close()
- try:
- if not getattr(wfile, "closed", False):
- wfile.flush()
- connection.close()
- except IOError:
- pass
-
- def finish(self):
- self.terminate(self.connection, self.wfile, self.rfile)
-
def send_error(self, code, body):
try:
import BaseHTTPServer
@@ -584,10 +571,8 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
class ProxyServerError(Exception): pass
-ServerBase = SocketServer.ThreadingTCPServer
-ServerBase.daemon_threads = True # Terminate workers when main thread terminates
-class ProxyServer(ServerBase):
- request_queue_size = 20
+
+class ProxyServer(tcpserver.TCPServer):
allow_reuse_address = True
bound = True
def __init__(self, config, port, address=''):
@@ -596,7 +581,7 @@ class ProxyServer(ServerBase):
"""
self.config, self.port, self.address = config, port, address
try:
- ServerBase.__init__(self, (address, port), ProxyHandler)
+ tcpserver.TCPServer.__init__(self, (address, port))
except socket.error, v:
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
self.masterq = None
@@ -611,11 +596,11 @@ class ProxyServer(ServerBase):
def set_mqueue(self, q):
self.masterq = q
- def finish_request(self, request, client_address):
- self.RequestHandlerClass(self.config, request, client_address, self, self.masterq)
+ def handle_connection(self, request, client_address):
+ ProxyHandler(self.config, request, client_address, self, self.masterq)
def shutdown(self):
- ServerBase.shutdown(self)
+ tcpserver.TCPServer.shutdown(self)
try:
shutil.rmtree(self.certdir)
except OSError:
diff --git a/libmproxy/tcpserver.py b/libmproxy/tcpserver.py
new file mode 100644
index 00000000..bf7ed0b4
--- /dev/null
+++ b/libmproxy/tcpserver.py
@@ -0,0 +1,88 @@
+import select, socket, threading
+
+class BaseHandler:
+ rbufsize = -1
+ wbufsize = 0
+ def __init__(self, connection, client_address, server):
+ self.connection = connection
+ self.rfile = self.connection.makefile('rb', self.rbufsize)
+ self.wfile = self.connection.makefile('wb', self.wbufsize)
+
+ self.client_address = client_address
+ self.server = server
+ self.handle()
+ self.finish()
+
+ def finish(self):
+ try:
+ if not getattr(self.wfile, "closed", False):
+ self.wfile.flush()
+ self.connection.close()
+ self.wfile.close()
+ self.rfile.close()
+ except IOError:
+ pass
+
+ def handle(self):
+ raise NotImplementedError
+
+
+class TCPServer:
+ request_queue_size = 20
+ def __init__(self, server_address):
+ self.server_address = server_address
+ self.__is_shut_down = threading.Event()
+ self.__shutdown_request = False
+ self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ self.socket.bind(self.server_address)
+ self.server_address = self.socket.getsockname()
+ self.socket.listen(self.request_queue_size)
+
+ def fileno(self):
+ return self.socket.fileno()
+
+ def request_thread(self, request, client_address):
+ try:
+ self.handle_connection(request, client_address)
+ request.close()
+ except:
+ self.handle_error(request, client_address)
+ request.close()
+
+ def serve_forever(self, poll_interval=0.5):
+ self.__is_shut_down.clear()
+ try:
+ while not self.__shutdown_request:
+ r, w, e = select.select([self], [], [], poll_interval)
+ if self in r:
+ try:
+ request, client_address = self.socket.accept()
+ except socket.error:
+ return
+ try:
+ t = threading.Thread(target = self.request_thread,
+ args = (request, client_address))
+ t.setDaemon (1)
+ t.start()
+ except:
+ self.handle_error(request, client_address)
+ request.close()
+ finally:
+ self.__shutdown_request = False
+ self.__is_shut_down.set()
+
+ def shutdown(self):
+ self.__shutdown_request = True
+ self.__is_shut_down.wait()
+
+ def handle_error(self, request, client_address):
+ print '-'*40
+ print 'Exception happened during processing of request from',
+ print client_address
+ import traceback
+ traceback.print_exc() # XXX But this goes to stderr!
+ print '-'*40
+
+ def handle_connection(self, request, client_address):
+ raise NotImplementedError