from __future__ import print_function, unicode_literals

import os
import random
import re
import select
import socket
import ssl
import string
import subprocess
import sys
import time
from itertools import count
from threading import Event, Lock, Thread

import paho.mqtt.client as mqtt
import ttfw_idf
from common_test_methods import get_host_ip4_by_dest_ip

# Length of the random sample string generated by MqttPublisher; each published
# payload is that sample string repeated `repeat` times.
DEFAULT_MSG_SIZE = 16


def _path(f):
    return os.path.join(os.path.dirname(os.path.realpath(__file__)),f)


def set_server_cert_cn(ip):
    """Re-issue the test server certificate ``srv.crt`` with CN set to ``ip``.

    A CSR is created from ``server.key`` with subject ``/CN=<ip>`` and signed by
    the test CA (``ca.crt``/``ca.key``), valid for 360 days. All files live next
    to this script.

    Raises:
        RuntimeError: if an openssl invocation exits with a non-zero status.
    """
    arg_list = [
        ['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'), '-subj', '/CN={}'.format(ip), '-new'],
        ['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'),
         '-CAkey', _path('ca.key'), '-CAcreateserial', '-out', _path('srv.crt'), '-days', '360']]
    for args in arg_list:
        try:
            # check_call never returns a non-zero status: it raises instead, so
            # the failure has to be caught here to report it as RuntimeError
            subprocess.check_call(args)
        except subprocess.CalledProcessError:
            raise RuntimeError('openssl command {} failed'.format(args))


# Publisher class creating a python client to send/receive published data from esp-mqtt client
class MqttPublisher:
    """Context manager that runs one publish round-trip test against the DUT.

    Entering the context does the following:
      - connects a paho-mqtt client to the broker selected by ``transport``;
      - subscribes to ``publish_cfg['subscribe_topic']``;
      - writes the test parameters to the DUT console;
      - publishes ``published`` messages of ``sample_string * repeat`` to
        ``publish_cfg['publish_topic']``;
      - waits until the DUT reports the expected pattern and the python client
        has received ``published`` matching messages.

    Leaving the context disconnects the client.

    The paho callbacks are static, so the state they share (events, expected
    payload, expected message count) lives in class attributes that every new
    instance resets.
    """

    def __init__(self, dut, transport, qos, repeat, published, queue, publish_cfg, log_details=False):
        # instance variables used as parameters of the publish test
        self.event_stop_client = Event()
        self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE))
        self.client = None
        self.dut = dut
        self.log_details = log_details
        self.repeat = repeat
        # NOTE: publish_cfg is the caller's dict and is updated in place
        self.publish_cfg = publish_cfg
        self.publish_cfg['qos'] = qos
        self.publish_cfg['queue'] = queue
        self.publish_cfg['transport'] = transport
        # serializes calls on the paho client between the loop thread and the test thread
        self.lock = Lock()
        # static variables used to pass options to and from static callbacks of paho-mqtt client
        # (fresh Event objects start cleared, so no explicit clear() is needed)
        MqttPublisher.event_client_connected = Event()
        MqttPublisher.event_client_got_all = Event()
        MqttPublisher.published = published
        MqttPublisher.expected_data = self.sample_string * self.repeat

    def print_details(self, text):
        """Print ``text`` only when detailed logging was requested."""
        if self.log_details:
            print(text)

    def mqtt_client_task(self, client, lock):
        """Drive the paho network loop until ``event_stop_client`` is set.

        Runs in a separate thread. Each ``client.loop()`` call is made while
        holding ``lock``, so it does not interleave with subscribe/publish calls
        made from the test thread.
        """
        while not self.event_stop_client.is_set():
            with lock:
                client.loop()
            time.sleep(0.001)   # yield to other threads

    # The callback for when the client receives a CONNACK response from the server (needs to be static)
    @staticmethod
    def on_connect(_client, _userdata, _flags, _rc):
        # NOTE: the event is set whatever the return code _rc is
        MqttPublisher.event_client_connected.set()

    # The callback for when a PUBLISH message is received from the server (needs to be static)
    @staticmethod
    def on_message(client, userdata, msg):
        # userdata holds the number of payloads equal to expected_data seen so far;
        # event_client_got_all is set when that count reaches `published`
        payload = msg.payload.decode()
        if payload == MqttPublisher.expected_data:
            userdata += 1
            client.user_data_set(userdata)
            if userdata == MqttPublisher.published:
                MqttPublisher.event_client_got_all.set()

    def __enter__(self):
        """Run the publish test; the ``with`` body executes only if it passed.

        Returns:
            self

        Raises:
            ValueError: if the python client gets no CONNACK within 30 s, or
                does not receive all expected messages within 60 s.
            Errors raised by paho-mqtt while connecting, and by ``dut.expect``
            when the DUT output does not appear in time, are propagated.
        """
        qos = self.publish_cfg['qos']
        queue = self.publish_cfg['queue']
        transport = self.publish_cfg['transport']
        broker_host = self.publish_cfg['broker_host_' + transport]
        broker_port = self.publish_cfg['broker_port_' + transport]

        # Start the test
        self.print_details("PUBLISH TEST: transport:{}, qos:{}, sequence:{}, enqueue:{}, sample msg:'{}'"
                           .format(transport, qos, MqttPublisher.published, queue, MqttPublisher.expected_data))

        try:
            if transport in ['ws', 'wss']:
                self.client = mqtt.Client(transport='websockets')
            else:
                self.client = mqtt.Client()
            self.client.on_connect = MqttPublisher.on_connect
            self.client.on_message = MqttPublisher.on_message
            # userdata counts matching messages received back (see on_message)
            self.client.user_data_set(0)

            if transport in ['ssl', 'wss']:
                # the broker certificate is deliberately not verified
                self.client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
                self.client.tls_insecure_set(True)
            self.print_details('Connecting...')
            self.client.connect(broker_host, broker_port, 60)
        except Exception:
            self.print_details('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}'.format(broker_host))
            raise

        # Starting a py-client in a separate thread
        client_thread = Thread(target=self.mqtt_client_task, args=(self.client, self.lock))
        client_thread.start()
        try:
            try:
                self.print_details('Connecting py-client to broker {}:{}...'.format(broker_host, broker_port))
                if not MqttPublisher.event_client_connected.wait(timeout=30):
                    raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_host))
                with self.lock:
                    self.client.subscribe(self.publish_cfg['subscribe_topic'], qos)
                self.dut.write(' '.join(str(x) for x in (transport, self.sample_string, self.repeat, MqttPublisher.published, qos, queue)), eol='\n')
                # waiting till subscribed to defined topic
                self.dut.expect(re.compile(r'MQTT_EVENT_SUBSCRIBED'), timeout=30)
                for _ in range(MqttPublisher.published):
                    with self.lock:
                        self.client.publish(self.publish_cfg['publish_topic'], self.sample_string * self.repeat, qos)
                    self.print_details('Publishing...')
                self.print_details('Checking esp-client received msg published from py-client...')
                self.dut.expect(re.compile(r'Correct pattern received exactly x times'), timeout=60)
                if not MqttPublisher.event_client_got_all.wait(timeout=60):
                    raise ValueError('Not all data received from ESP32')
                print(' - all data received from ESP32')
            finally:
                # Always stop the loop thread. Otherwise a failure anywhere above
                # (including the CONNACK timeout) leaves this non-daemon thread
                # looping forever.
                self.event_stop_client.set()
                client_thread.join()
        except Exception:
            # __exit__ is not called when __enter__ raises, so drop the broker
            # connection here before propagating the failure
            self.client.disconnect()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Disconnect from the broker and re-arm the stop event; never suppresses exceptions."""
        self.client.disconnect()
        self.event_stop_client.clear()


# Simple server for mqtt over TLS connection
class TlsServer:
    """Single-shot TLS server that answers one MQTT CONNECT, for connection tests.

    Used as a context manager. Entering binds and listens on ``port`` and starts
    a thread that:
      - accepts one TLS connection;
      - answers the MQTT CONNECT with a CONNACK, which is "accepted", or
        "not authorized" when ``refuse_connection`` is set;
      - then stops.

    A TLS error raised while accepting is stored in ``ssl_error``. The ALPN
    protocol selected (when ``use_alpn``) is stored in ``negotiated_protocol``.
    """

    def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False):
        self.port = port
        self.socket = socket.socket()
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.settimeout(10.0)
        self.shutdown = Event()
        self.client_cert = client_cert
        self.refuse_connection = refuse_connection
        self.ssl_error = None
        self.use_alpn = use_alpn
        self.negotiated_protocol = None
        # Accepted client connection. It stays None until accept() succeeds, so
        # __exit__ works even when no client ever connected.
        self.conn = None
        self.server_thread = None

    def __enter__(self):
        try:
            self.socket.bind(('', self.port))
        except socket.error as e:
            print('Bind failed:{}'.format(e))
            raise

        self.socket.listen(1)
        self.server_thread = Thread(target=self.run_server)
        self.server_thread.start()

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop the server thread, wait for it, then close both sockets."""
        self.shutdown.set()
        if self.server_thread is not None:
            self.server_thread.join()
        self.socket.close()
        if self.conn is not None:
            self.conn.close()

    def get_last_ssl_error(self):
        """Return the text of the SSLError raised while accepting, or None."""
        return self.ssl_error

    def get_negotiated_protocol(self):
        """Return the ALPN protocol selected for the connection, or None."""
        return self.negotiated_protocol

    def run_server(self):
        """Server thread body: TLS-accept one client and serve its MQTT CONNECT.

        An ``ssl.SSLError`` raised while accepting is recorded in ``ssl_error``
        (and ``conn`` reset to None) instead of being raised.
        """
        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        if self.client_cert:
            # require a client certificate signed by the test CA
            context.verify_mode = ssl.CERT_REQUIRED
            context.load_verify_locations(cafile=_path('ca.crt'))
        context.load_cert_chain(certfile=_path('srv.crt'), keyfile=_path('server.key'))
        if self.use_alpn:
            context.set_alpn_protocols(['mymqtt', 'http/1.1'])
        self.socket = context.wrap_socket(self.socket, server_side=True)
        try:
            self.conn, address = self.socket.accept()  # accept new connection
            # CPython's accept() hands back a blocking socket when no default
            # timeout is set, so bound reads/writes on the client connection
            # explicitly (the listening socket already has its timeout).
            self.conn.settimeout(10.0)
            print(' - connection from: {}'.format(address))
            if self.use_alpn:
                self.negotiated_protocol = self.conn.selected_alpn_protocol()
                print(' - negotiated_protocol: {}'.format(self.negotiated_protocol))
            self.handle_conn()
        except ssl.SSLError as e:
            self.conn = None
            self.ssl_error = str(e)
            print(' - SSLError: {}'.format(str(e)))

    def handle_conn(self):
        """Poll the client connection (1 s select timeout) until shutdown is requested."""
        while not self.shutdown.is_set():
            readable, _, _ = select.select([self.conn], [], [], 1)
            try:
                if self.conn in readable:
                    self.process_mqtt_connect()

            except socket.error as err:
                print(' - error: {}'.format(err))
                raise

    def process_mqtt_connect(self):
        """Read one MQTT CONNECT and reply with CONNACK; always requests shutdown.

        The received data must start with bytes ``10 18 00 04 'MQTT'``. The reply
        is CONNACK return code 0 (accepted), or 5 (not authorized) when
        ``refuse_connection`` is set.

        Raises:
            Exception: if the received data is not the expected CONNECT header.
        """
        try:
            data = bytearray(self.conn.recv(1024))
            message = ''.join(format(x, '02x') for x in data)
            if message[0:16] == '101800044d515454':
                if self.refuse_connection is False:
                    print(' - received mqtt connect, sending ACK')
                    self.conn.send(bytearray.fromhex('20020000'))
                else:
                    # injecting connection not authorized error
                    print(' - received mqtt connect, sending NAK')
                    self.conn.send(bytearray.fromhex('20020005'))
            else:
                raise Exception(' - error process_mqtt_connect unexpected connect received: {}'.format(message))
        finally:
            # stop the server after the connect message in happy flow, or if any exception occur
            self.shutdown.set()


def connection_tests(dut, cases, dut_ip):
    """Run the TLS/MQTT connection test cases against the DUT.

    Each case starts a local TlsServer on port 2222. The case is triggered by
    writing ``conn <host_ip> <port> <case_id>`` to the DUT console, and the
    outcome is checked with ``dut.expect``.

    Args:
        dut: the device under test.
        cases: mapping of sdkconfig case option name -> case id configured in the app.
        dut_ip: IPv4 address of the DUT, used to find the host address it can reach.

    Raises:
        RuntimeError: if the server-side TLS error or the negotiated ALPN
            protocol differs from what the case expects. Errors raised by
            ``dut.expect`` are propagated unchanged.
    """
    ip = get_host_ip4_by_dest_ip(dut_ip)
    set_server_cert_cn(ip)
    server_port = 2222

    def teardown_connection_suite():
        dut.write('conn teardown 0 0')

    def start_connection_case(case, desc):
        print('Starting {}: {}'.format(case, desc))
        case_id = cases[case]
        dut.write('conn {} {} {}'.format(ip, server_port, case_id))
        dut.expect('Test case:{} started'.format(case_id))
        return case_id

    def check_server_ssl_error(server, expected):
        # Call only after the server's `with` block: TlsServer.__exit__ joins the
        # server thread, so ssl_error is final instead of racing the thread that
        # sets it. ssl_error is None when no SSLError was recorded.
        error = server.get_last_ssl_error()
        if error is None or expected not in error:
            raise RuntimeError('Unexpected ssl error from the server {}'.format(error))

    for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']:
        # All these cases connect to the server with no server verification or with server only verification
        with TlsServer(server_port):
            test_nr = start_connection_case(case, 'default server - expect to connect normally')
            dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
        with TlsServer(server_port, refuse_connection=True):
            test_nr = start_connection_case(case, 'ssl shall connect, but mqtt sends connect refusal')
            dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
            dut.expect('MQTT ERROR: 0x5')  # expecting 0x5 ... connection not authorized error
        with TlsServer(server_port, client_cert=True) as s:
            test_nr = start_connection_case(case, 'server with client verification - handshake error since client presents no client certificate')
            dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
            dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED')  # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
        check_server_ssl_error(s, 'PEER_DID_NOT_RETURN_A_CERTIFICATE')

    for case in ['CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
        # These cases connect to server with both server and client verification (client key might be password protected)
        with TlsServer(server_port, client_cert=True):
            test_nr = start_connection_case(case, 'server with client verification - expect to connect normally')
            dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)

    case = 'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT'
    with TlsServer(server_port) as s:
        test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error')
        dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
        dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED')  # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
    check_server_ssl_error(s, 'alert unknown ca')

    case = 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
    with TlsServer(server_port, client_cert=True) as s:
        test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error')
        dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30)
        dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED')  # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
    check_server_ssl_error(s, 'CERTIFICATE_VERIFY_FAILED')

    for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
        with TlsServer(server_port, use_alpn=True) as s:
            test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol')
            dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30)
            if case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None:
                print(' - client with alpn off, no negotiated protocol: OK')
            elif case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt':
                print(' - client with alpn on, negotiated protocol resolved: OK')
            else:
                raise RuntimeError('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol()))

    teardown_connection_suite()


@ttfw_idf.idf_custom_test(env_tag='ethernet_router', group='test-apps')
def test_app_protocol_mqtt_publish_connect(env, extra_data):
    """
    steps:
      1. join AP
      2. connect to uri specified in the config
      3. send and receive data
    """
    # Connection tests run unless MQTT_SKIP_CONNECT_TEST is set; publish tests run
    # only when MQTT_PUBLISH_TEST is set (weekend runs).
    dut1 = env.get_dut('mqtt_publish_connect_test', 'tools/test_apps/protocols/mqtt/publish_connect_test')
    # check and log bin size
    binary_file = os.path.join(dut1.app.binary_path, 'mqtt_publish_connect_test.bin')
    bin_size = os.path.getsize(binary_file)
    ttfw_idf.log_performance('mqtt_publish_connect_test_bin_size', '{}KB'.format(bin_size // 1024))

    # Look for test case symbolic names and publish configs
    cases = {}
    publish_cfg = {}
    try:
        # fetch the sdkconfig mapping once; it is reused for the publish settings below
        sdkconfig = dut1.app.get_sdkconfig()
        # Get connection test cases configuration: symbolic names for test cases
        for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT',
                     'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT',
                     'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
                     'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
                     'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
                     'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
                     'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
                     'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']:
            cases[case] = sdkconfig[case]
    except Exception:
        print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
        raise

    dut1.start_app()
    esp_ip = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0]
    print('Got IP={}'.format(esp_ip))

    if not os.getenv('MQTT_SKIP_CONNECT_TEST'):
        connection_tests(dut1, cases, esp_ip)

    #
    # start publish tests only if enabled in the environment (for weekend tests only)
    if not os.getenv('MQTT_PUBLISH_TEST'):
        return

    def get_host_port_from_dut(config_option):
        # Extract (host, port) from a '<scheme>://host:port...' URI option.
        # Returns (None, None) if no explicit host:port is found.
        value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', sdkconfig[config_option])
        if value is None:
            return None, None
        return value.group(1), int(value.group(2))

    # Get publish test configuration
    try:
        # The python client publishes to the topic the DUT subscribes to, and vice versa
        publish_cfg['publish_topic'] = sdkconfig['CONFIG_EXAMPLE_SUBSCRIBE_TOPIC'].replace('"', '')
        publish_cfg['subscribe_topic'] = sdkconfig['CONFIG_EXAMPLE_PUBLISH_TOPIC'].replace('"', '')
        publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut('CONFIG_EXAMPLE_BROKER_SSL_URI')
        publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut('CONFIG_EXAMPLE_BROKER_TCP_URI')
        publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut('CONFIG_EXAMPLE_BROKER_WS_URI')
        publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut('CONFIG_EXAMPLE_BROKER_WSS_URI')
    except Exception:
        print('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
        raise

    def start_publish_case(transport, qos, repeat, published, queue):
        # `repeat` = copies of the sample string per message, `published` = number of messages
        print('Starting Publish test: transport:{}, qos:{}, nr_of_msgs:{}, msg_size:{}, enqueue:{}'
              .format(transport, qos, published, repeat * DEFAULT_MSG_SIZE, queue))
        with MqttPublisher(dut1, transport, qos, repeat, published, queue, publish_cfg):
            pass

    # Initialize message sizes and repeat counts from consecutive env variables
    # MQTT_PUBLISH_MSG_len_<i> / MQTT_PUBLISH_MSG_repeat_<i>, stopping at the first gap
    messages = []
    for i in count(0):
        msg_len = os.getenv('MQTT_PUBLISH_MSG_len_{}'.format(i))
        msg_repeat = os.getenv('MQTT_PUBLISH_MSG_repeat_{}'.format(i))
        if not (msg_len and msg_repeat):
            break
        messages.append({'len': int(msg_len), 'repeat': int(msg_repeat)})
    if not messages:    # No message sizes present in the env - set defaults
        messages = [{'len': 0, 'repeat': 5},     # zero-sized messages
                    {'len': 2, 'repeat': 10},    # short messages
                    {'len': 200, 'repeat': 3},   # long messages
                    {'len': 20, 'repeat': 50},   # many medium sized
                    ]

    # Iterate over all publish message properties
    for qos in [0, 1, 2]:
        for transport in ['tcp', 'ssl', 'ws', 'wss']:
            # the broker availability does not depend on the enqueue flag, so check it once
            if publish_cfg['broker_host_' + transport] is None:
                print('Skipping transport: {}...'.format(transport))
                continue
            for q in [0, 1]:
                for msg in messages:
                    start_publish_case(transport, qos, msg['len'], msg['repeat'], q)


if __name__ == '__main__':
    # Pass "qemu" as the only command-line argument to run on the QEMU-emulated DUT.
    if sys.argv[1:] == ['qemu']:
        dut_class = ttfw_idf.ESP32QEMUDUT
    else:
        dut_class = ttfw_idf.ESP32DUT
    test_app_protocol_mqtt_publish_connect(dut=dut_class)
