aboutsummaryrefslogtreecommitdiffstats
path: root/mitmproxy/command.py
blob: 1c943cefece225d21a30f5760495d0c8ca2b8383 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import inspect
import typing
import shlex
import textwrap

from mitmproxy.utils import typecheck
from mitmproxy import exceptions
from mitmproxy import flow


def typename(t: type, ret: bool) -> str:
    """
        Render a command type annotation as a short display string.

        If ret is True the type is being shown as a return type, otherwise
        it is being shown as a parameter type. Only str, int, bool and
        typing.Sequence[flow.Flow] are handled; any other type raises
        NotImplementedError.
    """
    primitives = (str, int, bool)
    if t in primitives:
        return t.__name__
    if t == typing.Sequence[flow.Flow]:
        # A flow sequence is rendered differently depending on direction.
        if ret:
            return "[flow]"
        return "flowspec"
    raise NotImplementedError(t)  # pragma: no cover


class Command:
    """
        A single named command wrapping a callable.

        Parameter and return types are read from the callable's annotations.
        They are used to convert string arguments before the call and to
        check the value the callable returns.
    """
    def __init__(self, manager, path, func) -> None:
        self.path = path
        self.manager = manager
        self.func = func
        sig = inspect.signature(self.func)
        # Help text is the stripped docstring re-wrapped by textwrap, or None
        # when the callable has no docstring.
        self.help = None
        if func.__doc__:
            txt = func.__doc__.strip()
            self.help = "\n".join(textwrap.wrap(txt))
        self.paramtypes = [v.annotation for v in sig.parameters.values()]
        self.returntype = sig.return_annotation

    def paramnames(self) -> typing.Sequence[str]:
        """
            Return the display name of each parameter type, in order.
        """
        return [typename(i, False) for i in self.paramtypes]

    def retname(self) -> str:
        """
            Return the display name of the return type, or an empty string
            if the command has no return type.
        """
        rtype = self.returntype
        # inspect reports a missing return annotation as Signature.empty,
        # which is a truthy class; treat it the same as "no return type"
        # instead of handing it to typename().
        if not rtype or rtype is inspect.Signature.empty:
            return ""
        return typename(rtype, True)

    def signature_help(self) -> str:
        """
            Return a one-line usage string: path, parameter type names and,
            if there is a return type, " -> " followed by its name.
        """
        params = " ".join(self.paramnames())
        ret = self.retname()
        if ret:
            ret = " -> " + ret
        return "%s %s%s" % (self.path, params, ret)

    def call(self, args: typing.Sequence[str]):
        """
            Call the command with a set of arguments. At this point, all
            arguments are strings.

            Raises CommandError if the number of arguments does not match the
            number of parameters, or if the return value fails the return
            type check.
        """
        if len(self.paramtypes) != len(args):
            raise exceptions.CommandError("Usage: %s" % self.signature_help())

        # Lengths are equal here, so zip pairs every argument with its type.
        pargs = [
            parsearg(self.manager, arg, paramtype)
            for arg, paramtype in zip(args, self.paramtypes)
        ]

        with self.manager.master.handlecontext():
            ret = self.func(*pargs)

        if not typecheck.check_command_return_type(ret, self.returntype):
            raise exceptions.CommandError("Command returned unexpected data")

        return ret


class CommandManager:
    """
        Registry mapping command paths to Command objects, with helpers to
        invoke a command from an argument list or from a command string.
    """
    def __init__(self, master):
        self.master = master
        self.commands = {}

    def add(self, path: str, func: typing.Callable):
        """
            Register func as the command at path, replacing any command
            previously stored under the same path.
        """
        self.commands[path] = Command(self, path, func)

    def call_args(self, path, args):
        """
            Call a command using a list of string arguments. May raise CommandError.
        """
        if path not in self.commands:
            raise exceptions.CommandError("Unknown command: %s" % path)
        return self.commands[path].call(args)

    def call(self, cmdstr: str):
        """
            Call a command using a string. May raise CommandError.

            The string is split with shell-like quoting rules; the first
            token is the command path and the rest are its arguments.
        """
        try:
            parts = shlex.split(cmdstr)
        except ValueError as e:
            # shlex signals malformed input (e.g. an unclosed quote) with
            # ValueError; surface it as the documented CommandError.
            raise exceptions.CommandError(
                "Invalid command: %s (%s)" % (cmdstr, e)
            ) from e
        if not parts:
            raise exceptions.CommandError("Invalid command: %s" % cmdstr)
        return self.call_args(parts[0], parts[1:])


def parsearg(manager: CommandManager, spec: str, argtype: type) -> typing.Any:
    """
        Convert a string argument to the appropriate type.

        str is passed through unchanged, bool accepts exactly "true" or
        "false", int is parsed with int(), and typing.Sequence[flow.Flow] is
        resolved by calling the "console.resolve" command with spec.

        Raises CommandError if spec cannot be converted or argtype is not
        supported.
    """
    if argtype == str:
        return spec
    elif argtype == bool:
        if spec == "true":
            return True
        elif spec == "false":
            return False
        raise exceptions.CommandError(
            "Booleans are 'true' or 'false', got %s" % spec
        )
    elif argtype == int:
        try:
            return int(spec)
        except ValueError as e:
            raise exceptions.CommandError(
                "Expected an integer, got %s" % spec
            ) from e
    elif argtype == typing.Sequence[flow.Flow]:
        return manager.call_args("console.resolve", [spec])
    else:
        raise exceptions.CommandError("Unsupported argument type: %s" % argtype)