aboutsummaryrefslogtreecommitdiffstats
path: root/libpathod/language.py
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2015-04-22 15:49:17 +1200
committerAldo Cortesi <aldo@nullcube.com>2015-04-22 15:49:17 +1200
commit99cb0808abfa3bd5bbc8d19c10756641c032dc48 (patch)
treef93b7a76f878eb17d5ef97787c01202559d70dc2 /libpathod/language.py
parent65f04bf4d103ea7ab9307946ccc11a17de51d93f (diff)
downloadmitmproxy-99cb0808abfa3bd5bbc8d19c10756641c032dc48.tar.gz
mitmproxy-99cb0808abfa3bd5bbc8d19c10756641c032dc48.tar.bz2
mitmproxy-99cb0808abfa3bd5bbc8d19c10756641c032dc48.zip
websockets: server handshake scheme
Also refactor settings and resolution interfaces
Diffstat (limited to 'libpathod/language.py')
-rw-r--r--libpathod/language.py154
1 files changed, 120 insertions, 34 deletions
diff --git a/libpathod/language.py b/libpathod/language.py
index 29d2ade8..28976a29 100644
--- a/libpathod/language.py
+++ b/libpathod/language.py
@@ -15,6 +15,20 @@ BLOCKSIZE = 1024
TRUNCATE = 1024
+class Settings:
+ def __init__(
+ self,
+ staticdir = None,
+ unconstrained_file_access = False,
+ request_host = None,
+ websocket_key = None
+ ):
+ self.staticdir = staticdir
+ self.unconstrained_file_access = unconstrained_file_access
+ self.request_host = request_host
+ self.websocket_key = websocket_key
+
+
def quote(s):
quotechar = s[0]
s = s[1:-1]
@@ -22,7 +36,11 @@ def quote(s):
return quotechar + s + quotechar
-class FileAccessDenied(Exception):
+class RenderError(Exception):
+ pass
+
+
+class FileAccessDenied(RenderError):
pass
@@ -97,7 +115,7 @@ def write_values(fp, vals, actions, sofar=0, blocksize=BLOCKSIZE):
return True
-def serve(msg, fp, settings, **kwargs):
+def serve(msg, fp, settings):
"""
fp: The file pointer to write to.
@@ -107,7 +125,7 @@ def serve(msg, fp, settings, **kwargs):
Calling this function may modify the object.
"""
- msg = msg.resolve(settings, **kwargs)
+ msg = msg.resolve(settings)
started = time.time()
vals = msg.values(settings)
@@ -351,15 +369,16 @@ class ValueFile(_Token):
return self
def get_generator(self, settings):
- uf = settings.get("unconstrained_file_access")
- sd = settings.get("staticdir")
- if not sd:
+ if not settings.staticdir:
raise FileAccessDenied("File access disabled.")
- sd = os.path.normpath(os.path.abspath(sd))
+ sd = os.path.normpath(os.path.abspath(settings.staticdir))
s = os.path.expanduser(self.path)
- s = os.path.normpath(os.path.abspath(os.path.join(sd, s)))
- if not uf and not s.startswith(sd):
+ s = os.path.normpath(
+ os.path.abspath(os.path.join(settings.staticdir, s))
+ )
+ uf = settings.unconstrained_file_access
+ if not uf and not s.startswith(settings.staticdir):
raise FileAccessDenied(
"File access outside of configured directory"
)
@@ -594,9 +613,28 @@ class Path(_Component):
return Path(self.value.freeze(settings))
+class WS(_Component):
+ def __init__(self, value):
+ self.value = value
+
+ @classmethod
+ def expr(klass):
+ spec = pp.Literal("ws")
+ spec = spec.setParseAction(lambda x: klass(*x))
+ return spec
+
+ def values(self, settings):
+ return "ws"
+
+ def spec(self):
+ return "ws"
+
+ def freeze(self, settings):
+ return self
+
+
class Method(_Component):
methods = [
- "ws",
"get",
"head",
"post",
@@ -797,29 +835,35 @@ class _Message(object):
def __init__(self, tokens):
self.tokens = tokens
- def _get_tokens(self, klass):
+ def toks(self, klass):
+ """
+ Fetch all tokens that are instances of klass
+ """
return [i for i in self.tokens if isinstance(i, klass)]
- def _get_token(self, klass):
- l = self._get_tokens(klass)
+ def tok(self, klass):
+ """
+ Fetch first token that is an instance of klass
+ """
+ l = self.toks(klass)
if l:
return l[0]
@property
def raw(self):
- return bool(self._get_token(Raw))
+ return bool(self.tok(Raw))
@property
def actions(self):
- return self._get_tokens(_Action)
+ return self.toks(_Action)
@property
def body(self):
- return self._get_token(Body)
+ return self.tok(Body)
@property
def headers(self):
- return self._get_tokens(_Header)
+ return self.toks(_Header)
def length(self, settings):
"""
@@ -883,8 +927,8 @@ class _Message(object):
vals.append(self.body.value.get_generator(settings))
return vals
- def freeze(self, settings, **kwargs):
- r = self.resolve(settings, **kwargs)
+ def freeze(self, settings):
+ r = self.resolve(settings)
return self.__class__([i.freeze(settings) for i in r.tokens])
def __repr__(self):
@@ -909,16 +953,25 @@ class Response(_Message):
logattrs = ["code", "reason", "version", "body"]
@property
+ def ws(self):
+ return self.tok(WS)
+
+ @property
def code(self):
- return self._get_token(Code)
+ return self.tok(Code)
@property
def reason(self):
- return self._get_token(Reason)
+ return self.tok(Reason)
def preamble(self, settings):
l = [self.version, " "]
- l.extend(self.code.values(settings))
+ if self.code:
+ l.extend(self.code.values(settings))
+ code = int(self.code.code)
+ elif self.ws:
+ l.extend(Code(101).values(settings))
+ code = 101
l.append(" ")
if self.reason:
l.extend(self.reason.values(settings))
@@ -926,7 +979,7 @@ class Response(_Message):
l.append(
LiteralGenerator(
http_status.RESPONSES.get(
- int(self.code.code),
+ code,
"Unknown code"
)
)
@@ -935,6 +988,22 @@ class Response(_Message):
def resolve(self, settings):
tokens = self.tokens[:]
+ if self.ws:
+ if not settings.websocket_key:
+ raise RenderError(
+ "No websocket key - have we seen a client handshake?"
+ )
+ if not self.code:
+ tokens.insert(
+ 1,
+ Code(101)
+ )
+ hdrs = websockets.server_handshake_headers(settings.websocket_key)
+ for i in hdrs.lst:
+ if not utils.get_header(i[0], self.headers):
+ tokens.append(
+ Header(ValueLiteral(i[0]), ValueLiteral(i[1]))
+ )
if not self.raw:
if not utils.get_header("Content-Length", self.headers):
if not self.body:
@@ -958,7 +1027,12 @@ class Response(_Message):
atom = pp.MatchFirst(parts)
resp = pp.And(
[
- Code.expr(),
+ pp.MatchFirst(
+ [
+ WS.expr() + pp.Optional(Sep + Code.expr()),
+ Code.expr(),
+ ]
+ ),
pp.ZeroOrMore(Sep + atom)
]
)
@@ -983,16 +1057,20 @@ class Request(_Message):
logattrs = ["method", "path", "body"]
@property
+ def ws(self):
+ return self.tok(WS)
+
+ @property
def method(self):
- return self._get_token(Method)
+ return self.tok(Method)
@property
def path(self):
- return self._get_token(Path)
+ return self.tok(Path)
@property
def pathodspec(self):
- return self._get_token(PathodSpec)
+ return self.tok(PathodSpec)
def preamble(self, settings):
v = self.method.values(settings)
@@ -1004,10 +1082,14 @@ class Request(_Message):
v.append(self.version)
return v
- def resolve(self, settings, **kwargs):
+ def resolve(self, settings):
tokens = self.tokens[:]
- if self.method.string().lower() == "ws":
- tokens[0] = Method("get")
+ if self.ws:
+ if not self.method:
+ tokens.insert(
+ 1,
+ Method("get")
+ )
for i in websockets.client_handshake_headers().lst:
if not utils.get_header(i[0], self.headers):
tokens.append(
@@ -1023,13 +1105,12 @@ class Request(_Message):
ValueLiteral(str(length)),
)
)
- request_host = kwargs.get("request_host")
- if request_host:
+ if settings.request_host:
if not utils.get_header("Host", self.headers):
tokens.append(
Header(
ValueLiteral("Host"),
- ValueLiteral(request_host)
+ ValueLiteral(settings.request_host)
)
)
intermediate = self.__class__(tokens)
@@ -1043,7 +1124,12 @@ class Request(_Message):
atom = pp.MatchFirst(parts)
resp = pp.And(
[
- Method.expr(),
+ pp.MatchFirst(
+ [
+ WS.expr() + pp.Optional(Sep + Method.expr()),
+ Method.expr(),
+ ]
+ ),
Sep,
Path.expr(),
pp.ZeroOrMore(Sep + atom)