aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
authorThomas Kriechbaumer <Kriechi@users.noreply.github.com>2016-02-08 09:52:29 +0100
committerThomas Kriechbaumer <Kriechi@users.noreply.github.com>2016-02-08 09:52:29 +0100
commit4ee1ad88fc440164985c8efc50c0be133a0053bc (patch)
tree210635d4aa964873f8054c80fcbeb2e9f94ce6ab /netlib
parent4873547de3c65ba7c14cace4bca7b17368b2900d (diff)
parent655b521749efd5a600d342a1d95b67d32da280a8 (diff)
downloadmitmproxy-4ee1ad88fc440164985c8efc50c0be133a0053bc.tar.gz
mitmproxy-4ee1ad88fc440164985c8efc50c0be133a0053bc.tar.bz2
mitmproxy-4ee1ad88fc440164985c8efc50c0be133a0053bc.zip
Merge pull request #120 from mitmproxy/model-cleanup
Model Cleanup
Diffstat (limited to 'netlib')
-rw-r--r--netlib/certutils.py23
-rw-r--r--netlib/http/headers.py7
-rw-r--r--netlib/http/message.py32
-rw-r--r--netlib/http/request.py5
-rw-r--r--netlib/http/response.py5
-rw-r--r--netlib/odict.py10
-rw-r--r--netlib/tcp.py18
-rw-r--r--netlib/utils.py33
8 files changed, 108 insertions, 25 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py
index a0111381..616a778e 100644
--- a/netlib/certutils.py
+++ b/netlib/certutils.py
@@ -12,7 +12,10 @@ from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error
import OpenSSL
+from .utils import Serializable
+
# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
+
DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3
# Generated with "openssl dhparam". It's too slow to generate this on startup.
DEFAULT_DHPARAM = b"""
@@ -361,7 +364,7 @@ class _GeneralNames(univ.SequenceOf):
constraint.ValueSizeConstraint(1, 1024)
-class SSLCert(object):
+class SSLCert(Serializable):
def __init__(self, cert):
"""
@@ -375,15 +378,25 @@ class SSLCert(object):
def __ne__(self, other):
return not self.__eq__(other)
+ def get_state(self):
+ return self.to_pem()
+
+ def set_state(self, state):
+ self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state)
+
+ @classmethod
+ def from_state(cls, state):
+ cls.from_pem(state)
+
@classmethod
- def from_pem(klass, txt):
+ def from_pem(cls, txt):
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
- return klass(x509)
+ return cls(x509)
@classmethod
- def from_der(klass, der):
+ def from_der(cls, der):
pem = ssl.DER_cert_to_PEM_cert(der)
- return klass.from_pem(pem)
+ return cls.from_pem(pem)
def to_pem(self):
return OpenSSL.crypto.dump_certificate(
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
index 6eb9db92..78404796 100644
--- a/netlib/http/headers.py
+++ b/netlib/http/headers.py
@@ -14,7 +14,7 @@ except ImportError: # pragma: nocover
import six
-from netlib.utils import always_byte_args, always_bytes
+from netlib.utils import always_byte_args, always_bytes, Serializable
if six.PY2: # pragma: nocover
_native = lambda x: x
@@ -27,7 +27,7 @@ else:
_always_byte_args = always_byte_args("utf-8", "surrogateescape")
-class Headers(MutableMapping):
+class Headers(MutableMapping, Serializable):
"""
Header class which allows both convenient access to individual headers as well as
direct access to the underlying raw data. Provides a full dictionary interface.
@@ -193,11 +193,10 @@ class Headers(MutableMapping):
def copy(self):
return Headers(copy.copy(self.fields))
- # Implement the StateObject protocol from mitmproxy
def get_state(self):
return tuple(tuple(field) for field in self.fields)
- def load_state(self, state):
+ def set_state(self, state):
self.fields = [list(field) for field in state]
@classmethod
diff --git a/netlib/http/message.py b/netlib/http/message.py
index 28f55fa2..e3d8ce37 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -4,9 +4,9 @@ import warnings
import six
+from .headers import Headers
from .. import encoding, utils
-
CONTENT_MISSING = 0
if six.PY2: # pragma: nocover
@@ -18,7 +18,7 @@ else:
_always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape")
-class MessageData(object):
+class MessageData(utils.Serializable):
def __eq__(self, other):
if isinstance(other, MessageData):
return self.__dict__ == other.__dict__
@@ -27,8 +27,24 @@ class MessageData(object):
def __ne__(self, other):
return not self.__eq__(other)
+ def set_state(self, state):
+ for k, v in state.items():
+ if k == "headers":
+ v = Headers.from_state(v)
+ setattr(self, k, v)
+
+ def get_state(self):
+ state = vars(self).copy()
+ state["headers"] = state["headers"].get_state()
+ return state
+
+ @classmethod
+ def from_state(cls, state):
+ state["headers"] = Headers.from_state(state["headers"])
+ return cls(**state)
+
-class Message(object):
+class Message(utils.Serializable):
def __init__(self, data):
self.data = data
@@ -40,6 +56,16 @@ class Message(object):
def __ne__(self, other):
return not self.__eq__(other)
+ def get_state(self):
+ return self.data.get_state()
+
+ def set_state(self, state):
+ self.data.set_state(state)
+
+ @classmethod
+ def from_state(cls, state):
+ return cls(**state)
+
@property
def headers(self):
"""
diff --git a/netlib/http/request.py b/netlib/http/request.py
index 6dabb189..0e0f88ce 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -16,9 +16,8 @@ from .message import Message, _native, _always_bytes, MessageData
class RequestData(MessageData):
def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None,
timestamp_start=None, timestamp_end=None):
- if not headers:
- headers = Headers()
- assert isinstance(headers, Headers)
+ if not isinstance(headers, Headers):
+ headers = Headers(headers)
self.first_line_format = first_line_format
self.method = method
diff --git a/netlib/http/response.py b/netlib/http/response.py
index 66e5ded6..8f4d6215 100644
--- a/netlib/http/response.py
+++ b/netlib/http/response.py
@@ -12,9 +12,8 @@ from ..odict import ODict
class ResponseData(MessageData):
def __init__(self, http_version, status_code, reason=None, headers=None, content=None,
timestamp_start=None, timestamp_end=None):
- if not headers:
- headers = Headers()
- assert isinstance(headers, Headers)
+ if not isinstance(headers, Headers):
+ headers = Headers(headers)
self.http_version = http_version
self.status_code = status_code
diff --git a/netlib/odict.py b/netlib/odict.py
index 90317e5e..1e6e381a 100644
--- a/netlib/odict.py
+++ b/netlib/odict.py
@@ -3,6 +3,8 @@ import re
import copy
import six
+from .utils import Serializable
+
def safe_subn(pattern, repl, target, *args, **kwargs):
"""
@@ -13,7 +15,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
-class ODict(object):
+class ODict(Serializable):
"""
A dictionary-like object for managing ordered (key, value) data. Think
@@ -172,12 +174,12 @@ class ODict(object):
def get_state(self):
return [tuple(i) for i in self.lst]
- def load_state(self, state):
+ def set_state(self, state):
self.lst = [list(i) for i in state]
@classmethod
- def from_state(klass, state):
- return klass([list(i) for i in state])
+ def from_state(cls, state):
+ return cls([list(i) for i in state])
class ODictCaseless(ODict):
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 85b4b0e2..c8548aea 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -16,7 +16,7 @@ import six
import OpenSSL
from OpenSSL import SSL
-from . import certutils, version_check
+from . import certutils, version_check, utils
# This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed.
@@ -298,7 +298,7 @@ class Reader(_FileLike):
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
-class Address(object):
+class Address(utils.Serializable):
"""
This class wraps an IPv4/IPv6 tuple to provide named attributes and
@@ -309,6 +309,20 @@ class Address(object):
self.address = tuple(address)
self.use_ipv6 = use_ipv6
+ def get_state(self):
+ return {
+ "address": self.address,
+ "use_ipv6": self.use_ipv6
+ }
+
+ def set_state(self, state):
+ self.address = state["address"]
+ self.use_ipv6 = state["use_ipv6"]
+
+ @classmethod
+ def from_state(cls, state):
+ return Address(**state)
+
@classmethod
def wrap(cls, t):
if isinstance(t, cls):
diff --git a/netlib/utils.py b/netlib/utils.py
index 1c1b617a..d2fc7195 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -1,14 +1,45 @@
from __future__ import absolute_import, print_function, division
import os.path
import re
-import string
import codecs
import unicodedata
+from abc import ABCMeta, abstractmethod
+
import six
from six.moves import urllib
import hyperframe
+
+@six.add_metaclass(ABCMeta)
+class Serializable(object):
+ """
+ Abstract Base Class that defines an API to save an object's state and restore it later on.
+ """
+
+ @classmethod
+ @abstractmethod
+ def from_state(cls, state):
+ """
+ Create a new object from the given state.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_state(self):
+ """
+ Retrieve object state.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def set_state(self, state):
+ """
+ Set object state to the given state.
+ """
+ raise NotImplementedError()
+
+
def always_bytes(unicode_or_bytes, *encode_args):
if isinstance(unicode_or_bytes, six.text_type):
return unicode_or_bytes.encode(*encode_args)