from __future__ import (absolute_import, print_function, division)

import struct
import threading
import time
import Queue

from netlib.tcp import ssl_read_select
from netlib.exceptions import HttpException
from netlib.http import Headers

import h2
from h2.connection import H2Connection
from h2.events import *

from .base import Layer
from .http import _HttpTransmissionLayer, HttpLayer
from .. import utils
from ..models import HTTPRequest, HTTPResponse


class SafeH2Connection(H2Connection):
    """
    H2Connection whose ``safe_*`` helpers serialize access through a shared
    lock and immediately flush the produced frames to ``conn``.

    The Http2Layer main loop and every per-stream worker thread use the same
    connection object, so each mutation of the h2 state machine and the
    matching ``data_to_send()`` / ``conn.send()`` pair happens while holding
    ``self.lock``.
    """

    def __init__(self, conn, *args, **kwargs):
        super(SafeH2Connection, self).__init__(*args, **kwargs)
        # Transport the serialized frames are written to via conn.send().
        self.conn = conn
        # Re-entrant: callers may already hold the lock (e.g. while allocating
        # a stream id) when they invoke one of the safe_* helpers.
        self.lock = threading.RLock()

    def safe_close_connection(self, error_code):
        """Send GOAWAY with ``error_code`` and flush it."""
        with self.lock:
            self.close_connection(error_code)
            self.conn.send(self.data_to_send())

    def safe_increment_flow_control(self, stream_id, length):
        """
        Return ``length`` bytes of flow-control credit to the peer.

        The connection-level window is always incremented. The stream-level
        window is only incremented while the stream is still known and open.
        A zero length is a no-op.
        """
        if length == 0:
            return

        with self.lock:
            self.increment_flow_control_window(length)
            self.conn.send(self.data_to_send())
        with self.lock:
            if stream_id in self.streams and not self.streams[stream_id].closed:
                self.increment_flow_control_window(length, stream_id=stream_id)
                self.conn.send(self.data_to_send())

    def safe_reset_stream(self, stream_id, error_code):
        """Send RST_STREAM for ``stream_id``, tolerating already-closed streams."""
        with self.lock:
            try:
                self.reset_stream(stream_id, error_code)
            except h2.exceptions.StreamClosedError:
                # stream is already closed - good
                pass
            self.conn.send(self.data_to_send())

    def safe_update_settings(self, new_settings):
        """Send a SETTINGS frame with ``new_settings`` and flush it."""
        with self.lock:
            self.update_settings(new_settings)
            self.conn.send(self.data_to_send())

    def safe_send_headers(self, is_zombie, stream_id, headers):
        """
        Send a HEADERS frame on ``stream_id`` unless the stream is a zombie.

        ``is_zombie(connection, stream_id)`` is evaluated under the lock.
        """
        with self.lock:
            if is_zombie(self, stream_id):
                return
            self.send_headers(stream_id, headers)
            self.conn.send(self.data_to_send())

    def safe_send_body(self, is_zombie, stream_id, chunks):
        """
        Send every chunk in ``chunks`` as DATA frames, then end the stream.

        Each chunk is split into frames of at most ``max_outbound_frame_size``
        bytes. When the flow-control window is too small, the lock is released
        and the thread yields so that WINDOW_UPDATE frames can be processed.
        Sending stops early, without ending the stream, once
        ``is_zombie(connection, stream_id)`` reports the stream as dead.
        """
        for chunk in chunks:
            position = 0
            while position < len(chunk):
                # Using ``with`` guarantees the lock is released on the early
                # zombie return and on any exception raised by h2 or send().
                with self.lock:
                    if is_zombie(self, stream_id):
                        return
                    max_outbound_frame_size = self.max_outbound_frame_size
                    frame_chunk = chunk[position:position + max_outbound_frame_size]
                    window_open = self.local_flow_control_window(stream_id) >= len(frame_chunk)
                    if window_open:
                        self.send_data(stream_id, frame_chunk)
                        self.conn.send(self.data_to_send())
                        position += len(frame_chunk)
                if not window_open:
                    # Window exhausted: yield (outside the lock) and retry.
                    time.sleep(0)
        with self.lock:
            if is_zombie(self, stream_id):
                return
            self.end_stream(stream_id)
            self.conn.send(self.data_to_send())


class Http2Layer(Layer):
    """
    Proxy layer that relays an HTTP/2 client connection to an HTTP/2 server
    connection.

    ``__call__`` runs the main loop: it reads frames from whichever side is
    readable, feeds them into that side's SafeH2Connection and dispatches the
    resulting h2 events. Each request stream is handed to its own
    Http2SingleStreamLayer thread.
    """

    def __init__(self, ctx, mode):
        super(Http2Layer, self).__init__(ctx)
        self.mode = mode
        # Stream layers keyed by the client-side stream id.
        self.streams = dict()
        # Server-side stream id -> client-side stream id (0 = connection itself).
        self.server_to_client_stream_ids = {0: 0}
        self.client_conn.h2 = SafeH2Connection(self.client_conn, client_side=False)

        # make sure that we only pass actual SSL.Connection objects in here,
        # because otherwise ssl_read_select fails!
        self.active_conns = [self.client_conn.connection]

    def _initiate_server_conn(self):
        """Create the server-side h2 state machine, send its preamble and start polling it."""
        self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True)
        self.server_conn.h2.initiate_connection()
        self.server_conn.send(self.server_conn.h2.data_to_send())
        self.active_conns.append(self.server_conn.connection)

    def connect(self):
        self.ctx.connect()
        self.server_conn.connect()
        self._initiate_server_conn()

    def set_server(self):
        raise NotImplementedError("Cannot change server for HTTP2 connections.")

    def disconnect(self):
        raise NotImplementedError("Cannot dis- or reconnect in HTTP2 connections.")

    def next_layer(self):
        # WebSockets over HTTP/2?
        # CONNECT for proxying?
        raise NotImplementedError()

    def _handle_event(self, event, source_conn, other_conn, is_server):
        """
        Dispatch one h2 event that was received on ``source_conn``.

        :return: False if the connection was terminated and the main loop must
            stop, True otherwise.
        :raises HttpException: if a body exceeds ``config.body_size_limit``.
        """
        if hasattr(event, 'stream_id'):
            # Odd ids on the server connection were allocated by send_request
            # and map back to a client id; even ids (server push) are used as-is.
            if is_server and event.stream_id % 2 == 1:
                eid = self.server_to_client_stream_ids[event.stream_id]
            else:
                eid = event.stream_id

        if isinstance(event, RequestReceived):
            headers = Headers([[str(k), str(v)] for k, v in event.headers])
            self.streams[eid] = Http2SingleStreamLayer(self, eid, headers)
            self.streams[eid].timestamp_start = time.time()
            self.streams[eid].start()
        elif isinstance(event, ResponseReceived):
            headers = Headers([[str(k), str(v)] for k, v in event.headers])
            # The data counter is reused for the response body.
            self.streams[eid].queued_data_length = 0
            self.streams[eid].timestamp_start = time.time()
            self.streams[eid].response_headers = headers
            self.streams[eid].response_arrived.set()
        elif isinstance(event, DataReceived):
            stream = self.streams[eid]
            limit = self.config.body_size_limit
            # Count the incoming frame *before* checking; otherwise a body that
            # crosses the limit in its final frame is never rejected.
            stream.queued_data_length += len(event.data)
            if limit and stream.queued_data_length > limit:
                raise HttpException("HTTP body too large. Limit is {}.".format(limit))
            stream.data_queue.put(event.data)
            source_conn.h2.safe_increment_flow_control(event.stream_id, len(event.data))
        elif isinstance(event, StreamEnded):
            self.streams[eid].timestamp_end = time.time()
            self.streams[eid].data_finished.set()
        elif isinstance(event, StreamReset):
            # The stream may be unknown or already cleaned up; never KeyError here.
            stream = self.streams.get(eid)
            if stream is not None:
                stream.zombie = time.time()
                if event.error_code == 0x8:
                    # Propagate CANCEL to the peer stream on the other connection.
                    if is_server:
                        other_stream_id = stream.client_stream_id
                    else:
                        other_stream_id = stream.server_stream_id
                    # server_stream_id is None while no upstream request was sent.
                    if other_stream_id is not None:
                        other_conn.h2.safe_reset_stream(other_stream_id, event.error_code)
        elif isinstance(event, RemoteSettingsChanged):
            new_settings = {
                setting: changed.new_value
                for setting, changed in event.changed_settings.items()
            }
            other_conn.h2.safe_update_settings(new_settings)
        elif isinstance(event, ConnectionTerminated):
            other_conn.h2.safe_close_connection(event.error_code)
            return False
        elif isinstance(event, PushedStreamReceived):
            # pushed stream ids should be uniq and not dependent on race conditions
            # only the parent stream id must be looked up first
            parent_eid = self.server_to_client_stream_ids[event.parent_stream_id]
            with self.client_conn.h2.lock:
                self.client_conn.h2.push_stream(parent_eid, event.pushed_stream_id, event.headers)
                # Flush the PUSH_PROMISE now instead of leaving it buffered
                # until the next unrelated write on the client connection.
                self.client_conn.send(self.client_conn.h2.data_to_send())

            headers = Headers([[str(k), str(v)] for k, v in event.headers])
            headers['x-mitmproxy-pushed'] = 'true'
            self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, event.pushed_stream_id, headers)
            self.streams[event.pushed_stream_id].timestamp_start = time.time()
            self.streams[event.pushed_stream_id].pushed = True
            self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
            self.streams[event.pushed_stream_id].timestamp_end = time.time()
            # A pushed "request" has no body, so it is complete immediately.
            self.streams[event.pushed_stream_id].data_finished.set()
            self.streams[event.pushed_stream_id].start()
        elif isinstance(event, TrailersReceived):
            raise NotImplementedError()

        return True

    def _cleanup_streams(self):
        """Forget streams that have been zombies for at least 10 seconds."""
        death_time = time.time() - 10
        # Iterate over a snapshot because entries are popped during the loop.
        for stream_id, stream in list(self.streams.items()):
            zombie = stream.zombie
            if zombie and zombie <= death_time:
                self.streams.pop(stream_id, None)

    def __call__(self):
        if self.server_conn:
            self._initiate_server_conn()

        # 24 bytes = the fixed HTTP/2 client connection preface.
        preamble = self.client_conn.rfile.read(24)
        self.client_conn.h2.initiate_connection()
        self.client_conn.h2.receive_data(preamble)
        self.client_conn.send(self.client_conn.h2.data_to_send())

        while True:
            r = ssl_read_select(self.active_conns, 1)
            for conn in r:
                source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
                other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
                is_server = (conn == self.server_conn.connection)

                with source_conn.h2.lock:
                    events = source_conn.h2.receive_data(utils.http2_read_frame(source_conn.rfile))
                    source_conn.send(source_conn.h2.data_to_send())

                    for event in events:
                        if not self._handle_event(event, source_conn, other_conn, is_server):
                            return

            self._cleanup_streams()


class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
    """
    Runs a regular HttpLayer for a single HTTP/2 stream in its own thread.

    Http2Layer feeds this object from the main loop by setting
    ``request_headers``/``response_headers``, putting body data into
    ``data_queue`` and signalling ``response_arrived`` / ``data_finished``.
    Attributes used here but not set in this class (client_conn, server_conn,
    mode, server_to_client_stream_ids, config) are presumably resolved through
    the ctx chain by Layer — NOTE(review): confirm.
    """

    def __init__(self, ctx, stream_id, request_headers):
        super(Http2SingleStreamLayer, self).__init__(ctx)
        # time.time() at which the stream died; None while it is alive.
        self.zombie = None
        self.client_stream_id = stream_id
        # Allocated in send_request; None until then.
        self.server_stream_id = None
        self.request_headers = request_headers
        self.response_headers = None
        self.data_queue = Queue.Queue()
        self.queued_data_length = 0
        # Set to True / the parent's client stream id for server-pushed streams.
        self.pushed = False
        self.parent_stream_id = None

        self.response_arrived = threading.Event()
        self.data_finished = threading.Event()

    def is_zombie(self, h2_conn, stream_id):
        """Callback for SafeH2Connection: True once this stream has died."""
        if self.zombie:
            return True
        return False

    def read_request(self):
        """Wait for the complete request and build an HTTPRequest from it."""
        # NOTE(review): for pushed streams data_finished is pre-set by the
        # creator; if the pushed response's StreamEnded arrives before this
        # clear(), the signal is lost. Pre-existing race — confirm.
        self.data_finished.wait()
        self.data_finished.clear()

        authority = self.request_headers.get(':authority', '')
        method = self.request_headers.get(':method', 'GET')
        scheme = self.request_headers.get(':scheme', 'https')
        path = self.request_headers.get(':path', '/')
        host = None
        port = None

        if path == '*' or path.startswith("/"):
            form_in = "relative"
        elif method == 'CONNECT':
            form_in = "authority"
            if ":" in authority:
                host, port = authority.split(":", 1)
            else:
                host = authority
        else:
            form_in = "absolute"
            # FIXME: verify if path or :host contains what we need
            scheme, host, port, _ = utils.parse_url(path)

        if host is None:
            host = 'localhost'
        if port is None:
            port = 80 if scheme == 'http' else 443
        port = int(port)

        data = []
        while self.data_queue.qsize() > 0:
            data.append(self.data_queue.get())
        data = b"".join(data)

        return HTTPRequest(
            form_in,
            method,
            scheme,
            host,
            port,
            path,
            (2, 0),
            self.request_headers,
            data,
            timestamp_start=self.timestamp_start,
            timestamp_end=self.timestamp_end,
        )

    def send_request(self, message):
        """Open a new server-side stream and forward the request on it."""
        if self.pushed:
            # The server opened this stream with a PUSH_PROMISE and sends the
            # response on it by itself. A new upstream request would produce a
            # second response that is routed into this same stream layer.
            # The promised id is reused unchanged on the client connection, so
            # it is also the server-side id (needed to propagate CANCEL).
            self.server_stream_id = self.client_stream_id
            return

        with self.server_conn.h2.lock:
            self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
            self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id

            self.server_conn.h2.safe_send_headers(
                self.is_zombie,
                self.server_stream_id,
                message.headers
            )
        self.server_conn.h2.safe_send_body(
            self.is_zombie,
            self.server_stream_id,
            message.body
        )

    def read_response_headers(self):
        """Block until response headers arrive; return a body-less HTTPResponse."""
        self.response_arrived.wait()

        status_code = int(self.response_headers.get(':status', 502))

        return HTTPResponse(
            http_version=(2, 0),
            status_code=status_code,
            reason='',
            headers=self.response_headers,
            content=None,
            timestamp_start=self.timestamp_start,
            timestamp_end=self.timestamp_end,
        )

    def read_response_body(self, request, response):
        """Yield response body chunks until the stream ends or dies."""
        while True:
            try:
                yield self.data_queue.get(timeout=1)
            except Queue.Empty:
                pass
            if self.data_finished.is_set():
                while self.data_queue.qsize() > 0:
                    yield self.data_queue.get()
                return
            if self.zombie:
                return

    def send_response_headers(self, response):
        self.client_conn.h2.safe_send_headers(
            self.is_zombie,
            self.client_stream_id,
            response.headers
        )

    def send_response_body(self, _response, chunks):
        self.client_conn.h2.safe_send_body(
            self.is_zombie,
            self.client_stream_id,
            chunks
        )

    def check_close_connection(self, flow):
        # This layer only handles a single stream.
        # RFC 7540 8.1: An HTTP request/response exchange fully consumes a single stream.
        return True

    def connect(self):
        raise ValueError("CONNECT inside an HTTP2 stream is not supported.")

    def set_server(self, *args, **kwargs):
        # do not mess with the server connection - all streams share it.
        pass

    def run(self):
        layer = HttpLayer(self, self.mode)
        try:
            layer()
        finally:
            # Mark the stream dead even if HttpLayer raised, so pending sends
            # stop and Http2Layer can eventually drop this stream.
            self.zombie = time.time()