diff options
Diffstat (limited to 'libmproxy/stateobject.py')
-rw-r--r-- | libmproxy/stateobject.py | 55 |
1 files changed, 27 insertions, 28 deletions
diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py index 52a8347f..9600ab09 100644 --- a/libmproxy/stateobject.py +++ b/libmproxy/stateobject.py @@ -1,52 +1,51 @@ from __future__ import absolute_import +from netlib.utils import Serializable -class StateObject(object): - +class StateObject(Serializable): """ - An object with serializable state. + An object with serializable state. - State attributes can either be serializable types(str, tuple, bool, ...) - or StateObject instances themselves. + State attributes can either be serializable types(str, tuple, bool, ...) + or StateObject instances themselves. """ - # An attribute-name -> class-or-type dict containing all attributes that - # should be serialized. If the attribute is a class, it must implement the - # StateObject protocol. - _stateobject_attributes = None - # A set() of attributes that should be ignored for short state - _stateobject_long_attributes = frozenset([]) - def from_state(self, state): - raise NotImplementedError() + _stateobject_attributes = None + """ + An attribute-name -> class-or-type dict containing all attributes that + should be serialized. If the attribute is a class, it must implement the + Serializable protocol. + """ - def get_state(self, short=False): + def get_state(self): """ - Retrieve object state. If short is true, return an abbreviated - format with long data elided. + Retrieve object state. """ state = {} for attr, cls in self._stateobject_attributes.iteritems(): - if short and attr in self._stateobject_long_attributes: - continue val = getattr(self, attr) if hasattr(val, "get_state"): - state[attr] = val.get_state(short) + state[attr] = val.get_state() else: state[attr] = val return state - def load_state(self, state): + def set_state(self, state): """ - Load object state from data returned by a get_state call. + Load object state from data returned by a get_state call. """ + state = state.copy() for attr, cls in self._stateobject_attributes.iteritems(): - if state.get(attr, None) is None: - setattr(self, attr, None) + if state.get(attr) is None: + setattr(self, attr, state.pop(attr)) else: curr = getattr(self, attr) - if hasattr(curr, "load_state"): - curr.load_state(state[attr]) + if hasattr(curr, "set_state"): + curr.set_state(state.pop(attr)) elif hasattr(cls, "from_state"): - setattr(self, attr, cls.from_state(state[attr])) - else: - setattr(self, attr, cls(state[attr])) + obj = cls.from_state(state.pop(attr)) + setattr(self, attr, obj) + else: # primitive types such as int, str, ... + setattr(self, attr, cls(state.pop(attr))) + if state: + raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state)) |