from __future__ import absolute_import
import os
import struct
import io

import six

from netlib import tcp
from netlib import strutils
from netlib import utils
from netlib import human
from .masker import Masker


MAX_16_BIT_INT = (1 << 16)
MAX_64_BIT_INT = (1 << 64)

DEFAULT = object()

# RFC 6455, Section 5.2 - Base Framing Protocol
OPCODE = utils.BiDi(
    CONTINUE=0x00,
    TEXT=0x01,
    BINARY=0x02,
    CLOSE=0x08,
    PING=0x09,
    PONG=0x0a
)

# RFC 6455, Section 7.4.1 - Defined Status Codes
CLOSE_REASON = utils.BiDi(
    NORMAL_CLOSURE=1000,
    GOING_AWAY=1001,
    PROTOCOL_ERROR=1002,
    UNSUPPORTED_DATA=1003,
    RESERVED=1004,
    RESERVED_NO_STATUS=1005,
    RESERVED_ABNORMAL_CLOSURE=1006,
    INVALID_PAYLOAD_DATA=1007,
    POLICY_VIOLATION=1008,
    MESSAGE_TOO_BIG=1009,
    MANDATORY_EXTENSION=1010,
    INTERNAL_ERROR=1011,
    RESERVED_TLS_HANDHSAKE_FAILED=1015,
)


class FrameHeader(object):

    def __init__(
        self,
        opcode=OPCODE.TEXT,
        payload_length=0,
        fin=False,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        masking_key=DEFAULT,
        mask=DEFAULT,
        length_code=DEFAULT
    ):
        if not 0 <= opcode < 2 ** 4:
            raise ValueError("opcode must be 0-16")
        self.opcode = opcode
        self.payload_length = payload_length
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3

        if length_code is DEFAULT:
            self.length_code = self._make_length_code(self.payload_length)
        else:
            self.length_code = length_code

        if mask is DEFAULT and masking_key is DEFAULT:
            self.mask = False
            self.masking_key = b""
        elif mask is DEFAULT:
            self.mask = 1
            self.masking_key = masking_key
        elif masking_key is DEFAULT:
            self.mask = mask
            self.masking_key = os.urandom(4)
        else:
            self.mask = mask
            self.masking_key = masking_key

        if self.masking_key and len(self.masking_key) != 4:
            raise ValueError("Masking key must be 4 bytes.")

    @classmethod
    def _make_length_code(self, length):
        """
         A websockets frame contains an initial length_code, and an optional
         extended length code to represent the actual length if length code is
         larger than 125
        """
        if length <= 125:
            return length
        elif length >= 126 and length <= 65535:
            return 126
        else:
            return 127

    def __repr__(self):
        vals = [
            "ws frame:",
            OPCODE.get_name(self.opcode, hex(self.opcode)).lower()
        ]
        flags = []
        for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]:
            if getattr(self, i):
                flags.append(i)
        if flags:
            vals.extend([":", "|".join(flags)])
        if self.masking_key:
            vals.append(":key=%s" % repr(self.masking_key))
        if self.payload_length:
            vals.append(" %s" % human.pretty_size(self.payload_length))
        return "".join(vals)

    def __bytes__(self):
        first_byte = utils.setbit(0, 7, self.fin)
        first_byte = utils.setbit(first_byte, 6, self.rsv1)
        first_byte = utils.setbit(first_byte, 5, self.rsv2)
        first_byte = utils.setbit(first_byte, 4, self.rsv3)
        first_byte = first_byte | self.opcode

        second_byte = utils.setbit(self.length_code, 7, self.mask)

        b = six.int2byte(first_byte) + six.int2byte(second_byte)

        if self.payload_length < 126:
            pass
        elif self.payload_length < MAX_16_BIT_INT:
            # '!H' pack as 16 bit unsigned short
            # add 2 byte extended payload length
            b += struct.pack('!H', self.payload_length)
        elif self.payload_length < MAX_64_BIT_INT:
            # '!Q' = pack as 64 bit unsigned long long
            # add 8 bytes extended payload length
            b += struct.pack('!Q', self.payload_length)
        else:
            raise ValueError("Payload length exceeds 64bit integer")

        if self.masking_key:
            b += self.masking_key
        return b

    if six.PY2:
        __str__ = __bytes__

    @classmethod
    def from_file(cls, fp):
        """
          read a websockets frame header
        """
        first_byte = six.byte2int(fp.safe_read(1))
        second_byte = six.byte2int(fp.safe_read(1))

        fin = utils.getbit(first_byte, 7)
        rsv1 = utils.getbit(first_byte, 6)
        rsv2 = utils.getbit(first_byte, 5)
        rsv3 = utils.getbit(first_byte, 4)
        opcode = first_byte & 0xF
        mask_bit = utils.getbit(second_byte, 7)
        length_code = second_byte & 0x7F

        # payload_length > 125 indicates you need to read more bytes
        # to get the actual payload length
        if length_code <= 125:
            payload_length = length_code
        elif length_code == 126:
            payload_length, = struct.unpack("!H", fp.safe_read(2))
        else:  # length_code == 127:
            payload_length, = struct.unpack("!Q", fp.safe_read(8))

        # masking key only present if mask bit set
        if mask_bit == 1:
            masking_key = fp.safe_read(4)
        else:
            masking_key = None

        return cls(
            fin=fin,
            rsv1=rsv1,
            rsv2=rsv2,
            rsv3=rsv3,
            opcode=opcode,
            mask=mask_bit,
            length_code=length_code,
            payload_length=payload_length,
            masking_key=masking_key,
        )

    def __eq__(self, other):
        if isinstance(other, FrameHeader):
            return bytes(self) == bytes(other)
        return False


class Frame(object):
    """
    Represents a single WebSockets frame.
    Constructor takes human readable forms of the frame components.
    from_bytes() reads from a file-like object to create a new Frame.

    WebSockets Frame as defined in RFC6455

       0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
      +-+-+-+-+-------+-+-------------+-------------------------------+
      |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
      |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
      |N|V|V|V|       |S|             |   (if payload len==126/127)   |
      | |1|2|3|       |K|             |                               |
      +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
      |     Extended payload length continued, if payload len == 127  |
      + - - - - - - - - - - - - - - - +-------------------------------+
      |                               |Masking-key, if MASK set to 1  |
      +-------------------------------+-------------------------------+
      | Masking-key (continued)       |          Payload Data         |
      +-------------------------------- - - - - - - - - - - - - - - - +
      :                     Payload Data continued ...                :
      + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
      |                     Payload Data continued ...                |
      +---------------------------------------------------------------+
    """

    def __init__(self, payload=b"", **kwargs):
        self.payload = payload
        kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
        self.header = FrameHeader(**kwargs)

    @classmethod
    def from_bytes(cls, bytestring):
        """
          Construct a websocket frame from an in-memory bytestring
          to construct a frame from a stream of bytes, use from_file() directly
        """
        return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))

    def __repr__(self):
        ret = repr(self.header)
        if self.payload:
            ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload)
        return ret

    def __bytes__(self):
        """
            Serialize the frame to wire format. Returns a string.
        """
        b = bytes(self.header)
        if self.header.masking_key:
            b += Masker(self.header.masking_key)(self.payload)
        else:
            b += self.payload
        return b

    if six.PY2:
        __str__ = __bytes__

    @classmethod
    def from_file(cls, fp):
        """
          read a websockets frame sent by a server or client

          fp is a "file like" object that could be backed by a network
          stream or a disk or an in memory stream reader
        """
        header = FrameHeader.from_file(fp)
        payload = fp.safe_read(header.payload_length)

        if header.mask == 1 and header.masking_key:
            payload = Masker(header.masking_key)(payload)

        frame = cls(payload)
        frame.header = header
        return frame

    def __eq__(self, other):
        if isinstance(other, Frame):
            return bytes(self) == bytes(other)
        return False