diff options
Diffstat (limited to 'libpathod/language.py')
-rw-r--r-- | libpathod/language.py | 154 |
1 files changed, 120 insertions, 34 deletions
diff --git a/libpathod/language.py b/libpathod/language.py index 29d2ade8..28976a29 100644 --- a/libpathod/language.py +++ b/libpathod/language.py @@ -15,6 +15,20 @@ BLOCKSIZE = 1024 TRUNCATE = 1024 +class Settings: + def __init__( + self, + staticdir = None, + unconstrained_file_access = False, + request_host = None, + websocket_key = None + ): + self.staticdir = staticdir + self.unconstrained_file_access = unconstrained_file_access + self.request_host = request_host + self.websocket_key = websocket_key + + def quote(s): quotechar = s[0] s = s[1:-1] @@ -22,7 +36,11 @@ def quote(s): return quotechar + s + quotechar -class FileAccessDenied(Exception): +class RenderError(Exception): + pass + + +class FileAccessDenied(RenderError): pass @@ -97,7 +115,7 @@ def write_values(fp, vals, actions, sofar=0, blocksize=BLOCKSIZE): return True -def serve(msg, fp, settings, **kwargs): +def serve(msg, fp, settings): """ fp: The file pointer to write to. @@ -107,7 +125,7 @@ def serve(msg, fp, settings, **kwargs): Calling this function may modify the object. """ - msg = msg.resolve(settings, **kwargs) + msg = msg.resolve(settings) started = time.time() vals = msg.values(settings) @@ -351,15 +369,16 @@ class ValueFile(_Token): return self def get_generator(self, settings): - uf = settings.get("unconstrained_file_access") - sd = settings.get("staticdir") - if not sd: + if not settings.staticdir: raise FileAccessDenied("File access disabled.") - sd = os.path.normpath(os.path.abspath(sd)) + sd = os.path.normpath(os.path.abspath(settings.staticdir)) s = os.path.expanduser(self.path) - s = os.path.normpath(os.path.abspath(os.path.join(sd, s))) - if not uf and not s.startswith(sd): + s = os.path.normpath( + os.path.abspath(os.path.join(settings.staticdir, s)) + ) + uf = settings.unconstrained_file_access + if not uf and not s.startswith(settings.staticdir): raise FileAccessDenied( "File access outside of configured directory" ) @@ -594,9 +613,28 @@ class Path(_Component): return Path(self.value.freeze(settings)) +class WS(_Component): + def __init__(self, value): + self.value = value + + @classmethod + def expr(klass): + spec = pp.Literal("ws") + spec = spec.setParseAction(lambda x: klass(*x)) + return spec + + def values(self, settings): + return "ws" + + def spec(self): + return "ws" + + def freeze(self, settings): + return self + + class Method(_Component): methods = [ - "ws", "get", "head", "post", @@ -797,29 +835,35 @@ class _Message(object): def __init__(self, tokens): self.tokens = tokens - def _get_tokens(self, klass): + def toks(self, klass): + """ + Fetch all tokens that are instances of klass + """ return [i for i in self.tokens if isinstance(i, klass)] - def _get_token(self, klass): - l = self._get_tokens(klass) + def tok(self, klass): + """ + Fetch first token that is an instance of klass + """ + l = self.toks(klass) if l: return l[0] @property def raw(self): - return bool(self._get_token(Raw)) + return bool(self.tok(Raw)) @property def actions(self): - return self._get_tokens(_Action) + return self.toks(_Action) @property def body(self): - return self._get_token(Body) + return self.tok(Body) @property def headers(self): - return self._get_tokens(_Header) + return self.toks(_Header) def length(self, settings): """ @@ -883,8 +927,8 @@ class _Message(object): vals.append(self.body.value.get_generator(settings)) return vals - def freeze(self, settings, **kwargs): - r = self.resolve(settings, **kwargs) + def freeze(self, settings): + r = self.resolve(settings) return self.__class__([i.freeze(settings) for i in r.tokens]) def __repr__(self): @@ -909,16 +953,25 @@ class Response(_Message): logattrs = ["code", "reason", "version", "body"] @property + def ws(self): + return self.tok(WS) + + @property def code(self): - return self._get_token(Code) + return self.tok(Code) @property def reason(self): - return self._get_token(Reason) + return self.tok(Reason) def preamble(self, settings): l = [self.version, " "] - l.extend(self.code.values(settings)) + if self.code: + l.extend(self.code.values(settings)) + code = int(self.code.code) + elif self.ws: + l.extend(Code(101).values(settings)) + code = 101 l.append(" ") if self.reason: l.extend(self.reason.values(settings)) @@ -926,7 +979,7 @@ class Response(_Message): l.append( LiteralGenerator( http_status.RESPONSES.get( - int(self.code.code), + code, "Unknown code" ) ) @@ -935,6 +988,22 @@ class Response(_Message): def resolve(self, settings): tokens = self.tokens[:] + if self.ws: + if not settings.websocket_key: + raise RenderError( + "No websocket key - have we seen a client handshake?" + ) + if not self.code: + tokens.insert( + 1, + Code(101) + ) + hdrs = websockets.server_handshake_headers(settings.websocket_key) + for i in hdrs.lst: + if not utils.get_header(i[0], self.headers): + tokens.append( + Header(ValueLiteral(i[0]), ValueLiteral(i[1])) + ) if not self.raw: if not utils.get_header("Content-Length", self.headers): if not self.body: @@ -958,7 +1027,12 @@ class Response(_Message): atom = pp.MatchFirst(parts) resp = pp.And( [ - Code.expr(), + pp.MatchFirst( + [ + WS.expr() + pp.Optional(Sep + Code.expr()), + Code.expr(), + ] + ), pp.ZeroOrMore(Sep + atom) ] ) @@ -983,16 +1057,20 @@ class Request(_Message): logattrs = ["method", "path", "body"] @property + def ws(self): + return self.tok(WS) + + @property def method(self): - return self._get_token(Method) + return self.tok(Method) @property def path(self): - return self._get_token(Path) + return self.tok(Path) @property def pathodspec(self): - return self._get_token(PathodSpec) + return self.tok(PathodSpec) def preamble(self, settings): v = self.method.values(settings) @@ -1004,10 +1082,14 @@ class Request(_Message): v.append(self.version) return v - def resolve(self, settings, **kwargs): + def resolve(self, settings): tokens = self.tokens[:] - if self.method.string().lower() == "ws": - tokens[0] = Method("get") + if self.ws: + if not self.method: + tokens.insert( + 1, + Method("get") + ) for i in websockets.client_handshake_headers().lst: if not utils.get_header(i[0], self.headers): tokens.append( @@ -1023,13 +1105,12 @@ class Request(_Message): ValueLiteral(str(length)), ) ) - request_host = kwargs.get("request_host") - if request_host: + if settings.request_host: if not utils.get_header("Host", self.headers): tokens.append( Header( ValueLiteral("Host"), - ValueLiteral(request_host) + ValueLiteral(settings.request_host) ) ) intermediate = self.__class__(tokens) @@ -1043,7 +1124,12 @@ class Request(_Message): atom = pp.MatchFirst(parts) resp = pp.And( [ - Method.expr(), + pp.MatchFirst( + [ + WS.expr() + pp.Optional(Sep + Method.expr()), + Method.expr(), + ] + ), Sep, Path.expr(), pp.ZeroOrMore(Sep + atom) |