from __future__ import print_function import copy import logging import os import sys import threading from netlib import tcp from netlib import certutils from netlib import websockets from netlib import version import urllib from netlib.exceptions import HttpException, HttpReadDisconnect, TcpTimeout, TcpDisconnect, \ TlsException from . import language, utils, log, protocols DEFAULT_CERT_DOMAIN = b"pathod.net" CONFDIR = "~/.mitmproxy" CERTSTORE_BASENAME = "mitmproxy" CA_CERT_NAME = "mitmproxy-ca.pem" DEFAULT_CRAFT_ANCHOR = "/p/" logger = logging.getLogger('pathod') class PathodError(Exception): pass class SSLOptions(object): def __init__( self, confdir=CONFDIR, cn=None, sans=(), not_after_connect=None, request_client_cert=False, ssl_version=tcp.SSL_DEFAULT_METHOD, ssl_options=tcp.SSL_DEFAULT_OPTIONS, ciphers=None, certs=None, alpn_select=b'h2', ): self.confdir = confdir self.cn = cn self.sans = sans self.not_after_connect = not_after_connect self.request_client_cert = request_client_cert self.ssl_version = ssl_version self.ssl_options = ssl_options self.ciphers = ciphers self.alpn_select = alpn_select self.certstore = certutils.CertStore.from_store( os.path.expanduser(confdir), CERTSTORE_BASENAME ) for i in certs or []: self.certstore.add_cert_file(*i) def get_cert(self, name): if self.cn: name = self.cn elif not name: name = DEFAULT_CERT_DOMAIN return self.certstore.get_cert(name, self.sans) class PathodHandler(tcp.BaseHandler): wbufsize = 0 sni = None def __init__( self, connection, address, server, logfp, settings, http2_framedump=False ): tcp.BaseHandler.__init__(self, connection, address, server) self.logfp = logfp self.settings = copy.copy(settings) self.protocol = None self.use_http2 = False self.http2_framedump = http2_framedump def handle_sni(self, connection): sni = connection.get_servername() if sni: sni = sni.decode("idna") self.sni = sni def http_serve_crafted(self, crafted, logctx): error, crafted = self.server.check_policy( crafted, self.settings ) if error: err = self.make_http_error_response(error) language.serve(err, self.wfile, self.settings) return None, dict( type="error", msg=error ) if self.server.explain and not hasattr(crafted, 'is_error_response'): crafted = crafted.freeze(self.settings) logctx(">> Spec: %s" % crafted.spec()) response_log = language.serve( crafted, self.wfile, self.settings ) if response_log["disconnect"]: return None, response_log return self.handle_http_request, response_log def handle_http_request(self, logger): """ Returns a (handler, log) tuple. handler: Handler for the next request, or None to disconnect log: A dictionary, or None """ with logger.ctx() as lg: try: req = self.protocol.read_request(self.rfile) except HttpReadDisconnect: return None, None except HttpException as s: s = str(s) lg(s) return None, dict(type="error", msg=s) if req.method == 'CONNECT': return self.protocol.handle_http_connect([req.host, req.port, req.http_version], lg) method = req.method path = req.path http_version = req.http_version headers = req.headers clientcert = None if self.clientcert: clientcert = dict( cn=self.clientcert.cn, subject=self.clientcert.subject, serial=self.clientcert.serial, notbefore=self.clientcert.notbefore.isoformat(), notafter=self.clientcert.notafter.isoformat(), keyinfo=self.clientcert.keyinfo, ) retlog = dict( type="crafted", protocol="http", request=dict( path=path, method=method, headers=headers.fields, http_version=http_version, sni=self.sni, remote_address=self.address(), clientcert=clientcert, ), cipher=None, ) if self.ssl_established: retlog["cipher"] = self.get_current_cipher() m = utils.MemBool() valid_websocket_handshake = websockets.check_handshake(headers) self.settings.websocket_key = websockets.get_client_key(headers) # If this is a websocket initiation, we respond with a proper # server response, unless over-ridden. if valid_websocket_handshake: anchor_gen = language.parse_pathod("ws") else: anchor_gen = None for regex, spec in self.server.anchors: if regex.match(path): anchor_gen = language.parse_pathod(spec, self.use_http2) break else: if m(path.startswith(self.server.craftanchor)): spec = urllib.parse.unquote(path)[len(self.server.craftanchor):] if spec: try: anchor_gen = language.parse_pathod(spec, self.use_http2) except language.ParseException as v: lg("Parse error: %s" % v.msg) anchor_gen = iter([self.make_http_error_response( "Parse Error", "Error parsing response spec: %s\n" % ( v.msg + v.marked() ) )]) else: if self.use_http2: anchor_gen = iter([self.make_http_error_response( "Spec Error", "HTTP/2 only supports request/response with the craft anchor point: %s" % self.server.craftanchor )]) if not anchor_gen: anchor_gen = iter([self.make_http_error_response( "Not found", "No valid craft request found" )]) spec = next(anchor_gen) if self.use_http2 and isinstance(spec, language.http2.Response): spec.stream_id = req.stream_id lg("crafting spec: %s" % spec) nexthandler, retlog["response"] = self.http_serve_crafted( spec, lg ) if nexthandler and valid_websocket_handshake: self.protocol = protocols.websockets.WebsocketsProtocol(self) return self.protocol.handle_websocket, retlog else: return nexthandler, retlog def make_http_error_response(self, reason, body=None): resp = self.protocol.make_error_response(reason, body) resp.is_error_response = True return resp def handle(self): self.settimeout(self.server.timeout) if self.server.ssl: try: cert, key, _ = self.server.ssloptions.get_cert(None) self.convert_to_ssl( cert, key, handle_sni=self.handle_sni, request_client_cert=self.server.ssloptions.request_client_cert, cipher_list=self.server.ssloptions.ciphers, method=self.server.ssloptions.ssl_version, options=self.server.ssloptions.ssl_options, alpn_select=self.server.ssloptions.alpn_select, ) except TlsException as v: s = str(v) self.server.add_log( dict( type="error", msg=s ) ) log.write_raw(self.logfp, s) return alp = self.get_alpn_proto_negotiated() if alp == b'h2': self.protocol = protocols.http2.HTTP2Protocol(self) self.use_http2 = True if not self.protocol: self.protocol = protocols.http.HTTPProtocol(self) lr = self.rfile if self.server.logreq else None lw = self.wfile if self.server.logresp else None logger = log.ConnectionLogger(self.logfp, self.server.hexdump, True, lr, lw) self.settings.protocol = self.protocol handler = self.handle_http_request while not self.finished: handler, l = handler(logger) if l: self.addlog(l) if not handler: return def addlog(self, log): if self.server.logreq: log["request_bytes"] = self.rfile.get_log() if self.server.logresp: log["response_bytes"] = self.wfile.get_log() self.server.add_log(log) class Pathod(tcp.TCPServer): LOGBUF = 500 def __init__( self, addr, ssl=False, ssloptions=None, craftanchor=DEFAULT_CRAFT_ANCHOR, staticdir=None, anchors=(), sizelimit=None, nocraft=False, nohang=False, timeout=None, logreq=False, logresp=False, explain=False, hexdump=False, http2_framedump=False, webdebug=False, logfp=sys.stdout, ): """ addr: (address, port) tuple. If port is 0, a free port will be automatically chosen. ssloptions: an SSLOptions object. craftanchor: URL prefix specifying the path under which to anchor response generation. staticdir: path to a directory of static resources, or None. anchors: List of (regex object, language.Request object) tuples, or None. sizelimit: Limit size of served data. nocraft: Disable response crafting. nohang: Disable pauses. """ tcp.TCPServer.__init__(self, addr) self.ssl = ssl self.ssloptions = ssloptions or SSLOptions() self.staticdir = staticdir self.craftanchor = craftanchor self.sizelimit = sizelimit self.nocraft = nocraft self.nohang = nohang self.timeout, self.logreq = timeout, logreq self.logresp, self.hexdump = logresp, hexdump self.http2_framedump = http2_framedump self.explain = explain self.logfp = logfp self.log = [] self.logid = 0 self.anchors = anchors self.settings = language.Settings( staticdir=self.staticdir ) self.loglock = threading.Lock() def check_policy(self, req, settings): """ A policy check that verifies the request size is within limits. """ if self.nocraft: return "Crafting disabled.", None try: req = req.resolve(settings) l = req.maximum_length(settings) except language.FileAccessDenied: return "File access denied.", None if self.sizelimit and l > self.sizelimit: return "Response too large.", None pauses = [isinstance(i, language.actions.PauseAt) for i in req.actions] if self.nohang and any(pauses): return "Pauses have been disabled.", None return None, req def handle_client_connection(self, request, client_address): h = PathodHandler( request, client_address, self, self.logfp, self.settings, self.http2_framedump, ) try: h.handle() h.finish() except TcpDisconnect: # pragma: no cover log.write_raw(self.logfp, "Disconnect") self.add_log( dict( type="error", msg="Disconnect" ) ) return except TcpTimeout: log.write_raw(self.logfp, "Timeout") self.add_log( dict( type="timeout", ) ) return def add_log(self, d): with self.loglock: d["id"] = self.logid self.log.insert(0, d) if len(self.log) > self.LOGBUF: self.log.pop() self.logid += 1 return d["id"] def clear_log(self): with self.loglock: self.log = [] def log_by_id(self, identifier): with self.loglock: for i in self.log: if i["id"] == identifier: return i def get_log(self): with self.loglock: return self.log def main(args): # pragma: no cover ssloptions = SSLOptions( cn=args.cn, confdir=args.confdir, not_after_connect=args.ssl_not_after_connect, ciphers=args.ciphers, ssl_version=args.ssl_version, ssl_options=args.ssl_options, certs=args.ssl_certs, sans=args.sans, ) root = logging.getLogger() if root.handlers: for handler in root.handlers: root.removeHandler(handler) log = logging.getLogger('pathod') log.setLevel(logging.DEBUG) fmt = logging.Formatter( '%(asctime)s: %(message)s', datefmt='%d-%m-%y %H:%M:%S', ) if args.logfile: fh = logging.handlers.WatchedFileHandler(args.logfile) fh.setFormatter(fmt) log.addHandler(fh) if not args.daemonize: sh = logging.StreamHandler() sh.setFormatter(fmt) log.addHandler(sh) try: pd = Pathod( (args.address, args.port), craftanchor=args.craftanchor, ssl=args.ssl, ssloptions=ssloptions, staticdir=args.staticdir, anchors=args.anchors, sizelimit=args.sizelimit, nocraft=args.nocraft, nohang=args.nohang, timeout=args.timeout, logreq=args.logreq, logresp=args.logresp, hexdump=args.hexdump, http2_framedump=args.http2_framedump, explain=args.explain, webdebug=args.webdebug ) except PathodError as v: print("Error: %s" % v, file=sys.stderr) sys.exit(1) except language.FileAccessDenied as v: print("Error: %s" % v, file=sys.stderr) if args.daemonize: utils.daemonize() try: print("%s listening on %s" % ( version.PATHOD, repr(pd.address) )) pd.serve_forever() except KeyboardInterrupt: pass