aboutsummaryrefslogtreecommitdiffstats
path: root/libmproxy/protocol.py
blob: c6223164480bea08f5afa93911f2e053038dffda (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
from libmproxy import flow
import libmproxy.utils
from netlib import http, tcp
import netlib.utils
from netlib.odict import ODictCaseless

# Reply value from channel.ask() that makes HTTPHandler.handle_request
# stop processing (it returns False, which ends the message loop).
KILL = 0 # FIXME: Remove duplication with proxy module
# Selects the channel.ask() protocol: when True the bare request/response
# is sent under "request"/"response"; otherwise the whole flow is sent
# under "httprequest"/"httpresponse".
LEGACY = True

class ProtocolError(Exception):
    """
        Error raised while handling a protocol exchange.

        code:    status code describing the failure (e.g. 400, 407, 502)
        msg:     human-readable reason
        headers: optional headers associated with the error, e.g. the
                 proxy-authentication challenge headers for a 407
    """
    def __init__(self, code, msg, headers=None):
        # Hand the arguments to Exception so e.args is populated; without
        # this the exception cannot be unpickled (it would be rebuilt with
        # no arguments) and its repr() carries no information.
        super(ProtocolError, self).__init__(code, msg, headers)
        self.code, self.msg, self.headers = code, msg, headers

    def __str__(self):
        return "ProtocolError(%s, %s)"%(self.code, self.msg)


def _handle(msg, conntype, connection_handler, *args, **kwargs):
    handler = None
    if conntype == "http":
        handler = HTTPHandler(connection_handler)
    else:
        raise NotImplementedError

    f = getattr(handler, "handle_" + msg)
    return f(*args, **kwargs)


def handle_messages(conntype, connection_handler):
    """
        Run the message loop for a connection of type `conntype`.

        Delegates to the protocol handler's handle_messages() and returns
        its result. Raises NotImplementedError for unsupported conntypes.
    """
    return _handle("messages", conntype, connection_handler)


class ConnectionTypeChange(Exception):
    """
        Raised by a protocol handler to abandon its work after the underlying
        connection has been switched to a different type (see
        HTTPHandler.ssl_upgrade).
    """


class ProtocolHandler(object):
    """
        Base class for protocol handlers.

        c: the connection handler this protocol handler works on; subclasses
           reach the client/server connections, config and channel via self.c.
    """

    def __init__(self, c):
        self.c = c


class Flow(object):
    """
        Base record for a proxied exchange: the client and server connection
        objects plus the start and end timestamps of the exchange.
    """

    def __init__(self, client_conn, server_conn, timestamp_start, timestamp_end):
        self.client_conn = client_conn
        self.server_conn = server_conn
        self.timestamp_start = timestamp_start
        self.timestamp_end = timestamp_end


class HTTPFlow(Flow):
    """
        A Flow carrying a single HTTP request and its response.

        request and response may be None while the flow is still being
        processed (HTTPHandler creates the flow before reading either).
    """

    def __init__(self, client_conn, server_conn, timestamp_start, timestamp_end, request, response):
        Flow.__init__(
            self,
            client_conn,
            server_conn,
            timestamp_start,
            timestamp_end,
        )
        self.request = request
        self.response = response


class HTTPMessage(object):
    """
        Serialization helpers shared by HTTPRequest and HTTPResponse.

        Subclasses must provide self.headers (an ODictCaseless) and
        self.content (the body string, or None).
    """

    def _assemble_headers(self):
        """
            Return the message headers as a string, ready for the wire.

            Proxy-Connection and Transfer-Encoding are dropped from a copy of
            the headers (the body is always written out un-chunked), and
            Content-Length is set to the body length. A message that was
            received chunked with an empty body gets Content-Length: 0.
            The result is str() of the header ODictCaseless.
        """
        assembled = self.headers.copy()
        libmproxy.utils.del_all(assembled, ["proxy-connection", "transfer-encoding"])

        length = None
        if self.content:
            length = str(len(self.content))
        elif 'Transfer-Encoding' in self.headers:
            # content-length for e.g. chunked transfer-encoding with no content
            length = "0"
        if length is not None:
            assembled["Content-Length"] = [length]

        return str(assembled)

class HTTPResponse(HTTPMessage):
    """
        An HTTP response.

        http_version:    indexable (major, minor) pair
        code, msg:       status code and reason phrase
        headers:         ODictCaseless
        content:         body string; None is serialized as an empty body
        timestamp_start: when reading the response started (from_stream)
        timestamp_end:   when reading the response finished (from_stream)
    """

    def __init__(self, http_version, code, msg, headers, content, timestamp_start, timestamp_end):
        self.http_version = http_version
        self.code = code
        self.msg = msg
        self.headers = headers
        self.content = content
        self.timestamp_start = timestamp_start
        self.timestamp_end = timestamp_end

        assert isinstance(headers, ODictCaseless)

    #FIXME: Legacy
    @property
    def request(self):
        # Always False for responses; presumably lets legacy code tell
        # responses apart from requests - confirm before removing.
        return False

    def _assemble(self):
        """
            Serialize the response (status line, headers, body) for the wire.
        """
        response_line = 'HTTP/%s.%s %s %s'%(self.http_version[0], self.http_version[1], self.code, self.msg)
        # A missing body must not be rendered as the literal text "None".
        content = self.content if self.content is not None else ""
        return '%s\r\n%s\r\n%s' % (response_line, self._assemble_headers(), content)

    @classmethod
    def from_stream(cls, rfile, request_method, include_content=True, body_size_limit=None):
        """
            Parse an HTTP response from a file stream.

            rfile:           stream positioned at the start of the response
            request_method:  method of the request being answered; passed to
                             http.read_response (presumably to decide whether
                             a body follows)
            include_content: must be True; streaming bodies is not supported
                             and raises NotImplementedError
            body_size_limit: passed through to http.read_response

            Returns an instance of cls (HTTPResponse or a subclass).
        """
        if not include_content:
            raise NotImplementedError

        timestamp_start = libmproxy.utils.timestamp()
        http_version, code, msg, headers, content = http.read_response(
            rfile,
            request_method,
            body_size_limit)
        timestamp_end = libmproxy.utils.timestamp()
        # Use cls rather than a hard-coded class so subclasses round-trip.
        return cls(http_version, code, msg, headers, content, timestamp_start, timestamp_end)


class HTTPRequest(HTTPMessage):
    """
        An HTTP request.

        form_in:      form of the request target as received: "asterisk",
                      "origin", "authority" or "absolute"
        form_out:     form used when re-serializing; defaults to form_in
        method, path: from the request line
        scheme, host, port: request-target components; from_stream leaves
                      them None for forms that do not carry them
        http_version: indexable (major, minor) pair
        headers:      ODictCaseless
        content:      body string, or None if the body was not read
        timestamp_start, timestamp_end: timing of reading the request
        ip:           resolved ip address, if known
    """

    def __init__(self, form_in, method, scheme, host, port, path, http_version, headers, content,
                 timestamp_start, timestamp_end, form_out=None, ip=None):
        self.form_in = form_in
        self.method = method
        self.scheme = scheme
        self.host = host
        self.port = port
        self.path = path
        self.http_version = http_version
        self.headers = headers
        self.content = content
        self.timestamp_start = timestamp_start
        self.timestamp_end = timestamp_end

        self.form_out = form_out or self.form_in
        self.ip = ip  # resolved ip address
        assert isinstance(headers, ODictCaseless)

    #FIXME: Remove, legacy
    def is_live(self):
        return True

    def _assemble(self):
        """
            Serialize the request for the wire, using form_out to choose the
            shape of the request line.

            Raises http.HttpError(400) for an unknown form_out.
        """
        if self.form_out == "asterisk" or self.form_out == "origin":
            request_line = '%s %s HTTP/%s.%s' % (self.method, self.path, self.http_version[0], self.http_version[1])
        elif self.form_out == "authority":
            request_line = '%s %s:%s HTTP/%s.%s' % (self.method, self.host, self.port,
                                                    self.http_version[0], self.http_version[1])
        elif self.form_out == "absolute":
            request_line = '%s %s://%s:%s%s HTTP/%s.%s' % \
                           (self.method, self.scheme, self.host, self.port, self.path,
                            self.http_version[0], self.http_version[1])
        else:
            raise http.HttpError(400, "Invalid request form")

        # A body that was never read (None) must not be sent as the text "None".
        content = self.content if self.content is not None else ""
        return '%s\r\n%s\r\n%s' % (request_line, self._assemble_headers(), content)

    @classmethod
    def from_stream(cls, rfile, include_content=True, body_size_limit=None):
        """
            Parse an HTTP request from a file stream.

            The request target is classified into one of the four request
            forms ("asterisk", "origin", "authority", "absolute"), stored as
            form_in. If include_content is false the body is not read and
            content and timestamp_end are left as None.

            Raises ProtocolError(400, ...) for a malformed request line or
            unreadable headers. HTTPHandler.get_line raises
            tcp.NetLibDisconnect when the stream is at EOF.
        """
        # Components that only some request forms carry; the rest stay None.
        scheme = host = port = None
        content = timestamp_end = None

        timestamp_start = libmproxy.utils.timestamp()
        request_line = HTTPHandler.get_line(rfile)

        def bad_request_line():
            return ProtocolError(400, "Bad HTTP request line: %s"%repr(request_line))

        request_line_parts = http.parse_init(request_line)
        if not request_line_parts:
            raise bad_request_line()
        method, path, http_version = request_line_parts

        if path == '*':
            form_in = "asterisk"
        elif path.startswith("/"):
            form_in = "origin"
            if not netlib.utils.isascii(path):
                raise bad_request_line()
        elif method.upper() == 'CONNECT':
            form_in = "authority"
            r = http.parse_init_connect(request_line)
            if not r:
                raise bad_request_line()
            host, port, _ = r
        else:
            form_in = "absolute"
            r = http.parse_init_proxy(request_line)
            if not r:
                raise bad_request_line()
            _, scheme, host, port, path, _ = r

        headers = http.read_headers(rfile)
        if headers is None:
            raise ProtocolError(400, "Invalid headers")

        if include_content:
            content = http.read_http_body(rfile, headers, body_size_limit, True)
            timestamp_end = libmproxy.utils.timestamp()

        # Use cls rather than a hard-coded class so subclasses round-trip.
        return cls(form_in, method, scheme, host, port, path, http_version, headers, content,
                   timestamp_start, timestamp_end)


class HTTPHandler(ProtocolHandler):
    """
        Protocol handler that relays HTTP request/response pairs between the
        client and server connections of the connection handler self.c.
    """

    def handle_messages(self):
        """
            Handle requests until handle_request() returns a falsy value,
            then flag the connection for closing.
        """
        while self.handle_request():
            pass
        self.c.close = True

    def handle_error(self, e):
        raise e  # FIXME: Proper error handling

    def handle_request(self):
        """
            Read one request from the client, forward it to the server (unless
            the channel replied with an HTTPResponse) and relay the response.

            Returns the HTTPFlow when the connection should keep going, or
            False when the channel killed the flow or either side signalled
            Connection: close. Raises ConnectionTypeChange after an SSL upgrade
            following a CONNECT.
        """
        try:
            flow = HTTPFlow(self.c.client_conn, self.c.server_conn, libmproxy.utils.timestamp(), None, None, None)
            flow.request = self.read_request()

            request_reply = self.c.channel.ask("request" if LEGACY else "httprequest", flow.request if LEGACY else flow)
            if request_reply is None or request_reply == KILL:
                return False

            if isinstance(request_reply, HTTPResponse):
                # The channel supplied a response itself; skip the server.
                flow.response = request_reply
            else:
                raw = flow.request._assemble()
                self.c.server_conn.wfile.write(raw)
                self.c.server_conn.wfile.flush()
                flow.response = self.read_response(flow)

            response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse",
                                                flow.response if LEGACY else flow)
            if response_reply is None or response_reply == KILL:
                return False

            raw = flow.response._assemble()
            self.c.client_conn.wfile.write(raw)
            self.c.client_conn.wfile.flush()
            flow.timestamp_end = libmproxy.utils.timestamp()

            if (http.connection_close(flow.request.http_version, flow.request.headers) or
                    http.connection_close(flow.response.http_version, flow.response.headers)):
                return False

            if flow.request.form_in == "authority":
                self.ssl_upgrade()
            return flow
        # The parentheses are essential: the Python 2 form
        # "except ProtocolError, http.HttpError:" catches only ProtocolError
        # and assigns the caught instance to http.HttpError.
        except (ProtocolError, http.HttpError) as e:
            # FIXME: Implement error handling
            raise NotImplementedError("Unhandled protocol error: %s" % e)

    def ssl_upgrade(self):
        """
            Switch the connection to transparent mode, re-detect its type and
            establish SSL on both sides. Always raises ConnectionTypeChange so
            the current handler stops.
        """
        self.c.mode = "transparent"
        self.c.determine_conntype()
        self.c.establish_ssl(server=True, client=True)
        raise ConnectionTypeChange

    def read_request(self):
        """
            Read a request from the client and prepare the server side for it.

            In regular mode the request is authenticated, a CONNECT addressed
            to the proxy itself is answered with "200 Connection established"
            followed by an SSL upgrade (which raises ConnectionTypeChange), and
            absolute-form requests are rewritten to origin form with a server
            connection opened to the target when no forward proxy is used.

            Raises ProtocolError: 502 for CONNECT on an already encrypted client
            connection; in regular mode, 501 for asterisk-form and 400 for
            origin-form requests.
        """
        request = HTTPRequest.from_stream(self.c.client_conn.rfile, body_size_limit=self.c.config.body_size_limit)

        if self.c.mode == "regular":
            self.authenticate(request)
        if request.form_in == "authority" and self.c.client_conn.ssl_established:
            raise ProtocolError(502, "Must not CONNECT on already encrypted connection")

        # If we have a CONNECT request, we might need to intercept
        if request.form_in == "authority":
            directly_addressed_at_mitmproxy = (self.c.mode == "regular") and not self.c.config.forward_proxy
            if directly_addressed_at_mitmproxy:
                self.c.establish_server_connection(request.host, request.port)
                self.c.client_conn.wfile.write(
                    'HTTP/1.1 200 Connection established\r\n' +
                    ('Proxy-agent: %s\r\n'%self.c.server_version) +
                    '\r\n'
                )
                self.c.client_conn.wfile.flush()
                self.ssl_upgrade()

        if self.c.mode == "regular":
            if request.form_in == "authority":
                # Only reached when a forward proxy is configured; the CONNECT
                # is passed through unchanged.
                pass
            elif request.form_in == "absolute":
                if not self.c.config.forward_proxy:
                    request.form_out = "origin"
                    if ((not self.c.server_conn) or
                            (self.c.server_conn.address != (request.host, request.port))):
                        self.c.establish_server_connection(request.host, request.port)
            elif request.form_in == "asterisk":
                raise ProtocolError(501, "Not Implemented")
            else:
                raise ProtocolError(400, "Invalid Request")
        return request

    def read_response(self, flow):
        """
            Read the server's response to flow.request.
        """
        return HTTPResponse.from_stream(self.c.server_conn.rfile, flow.request.method,
                                        body_size_limit=self.c.config.body_size_limit)

    def authenticate(self, request):
        """
            Check the request against the configured authenticator, if any.

            On success the authenticator's clean() is applied to the request
            headers (presumably stripping the credentials - confirm); on failure
            ProtocolError(407) is raised with the authenticator's challenge
            headers. Returns request.headers.
        """
        if self.c.config.authenticator:
            if self.c.config.authenticator.authenticate(request.headers):
                self.c.config.authenticator.clean(request.headers)
            else:
                raise ProtocolError(
                    407,
                    "Proxy Authentication Required",
                    self.c.config.authenticator.auth_challenge_headers()
                )
        return request.headers

    @staticmethod
    def get_line(fp):
        """
            Get a line, possibly preceded by a blank.

            Raises tcp.NetLibDisconnect if the stream is at EOF.
        """
        line = fp.readline()
        if line == "\r\n" or line == "\n":  # Possible leftover from previous message
            line = fp.readline()
        if line == "":
            raise tcp.NetLibDisconnect
        return line