# SPDX-FileCopyrightText: 2015 Sebastian Wagner
#
# SPDX-License-Identifier: AGPL-3.0-or-later

# -*- encoding: utf-8 -*-
"""
Testing the utility functions of intelmq.

Decoding and Encoding, Logging functionality (file and stream), and log
parsing.
base64 de-/encoding is not tested yet, as we fully rely on the module.
"""
import contextlib
import datetime
import io
import json
import os
import pprint
import tempfile
import unittest
import unittest.mock

import cerberus
import dns.resolver
import pkg_resources
import requests
import termstyle
from ruamel.yaml.scanner import ScannerError

import intelmq.lib.utils as utils
from intelmq.lib.test import skip_internet
from intelmq.tests.test_conf import CerberusTests

try:
    from importlib.metadata import EntryPoint
except ImportError:
    from importlib_metadata import EntryPoint


# Log message fixtures for the logger tests:
#   'spare': the raw messages handed to the logger.
#   'short': str.format templates for the expected stream output; '{}' is
#            filled with the logger name.
#   'long':  str.format templates that yield regular expressions for the
#            expected log file lines; '{}' is filled with the logger name and
#            the doubled braces survive formatting as regex quantifiers
#            (e.g. '{{10}}' becomes '{10}').
LINES = {'spare': ['Lorem', 'ipsum', 'dolor'],
         'short': ['{}: Lorem', '{}: ipsum',
                   '{}: dolor'],
         'long': [r'\A[-0-9]{{10}} [0-9:]{{8}},\d{{3}} - {} - INFO - Lorem\Z',
                  r'\A[-0-9]{{10}} [0-9:]{{8}},\d{{3}} - {} - ERROR - ipsum\Z',
                  r'\A[-0-9]{{10}} [0-9:]{{8}},\d{{3}} - {} - CRITICAL - dolor\Z'],
         }
# Encoding fixtures: every entry is a [bytes, str] pair. Index 0 is the input
# of the decode tests / expected output of the encode tests, index 1 is the
# matching text. The 'unicode' bytes are the UTF-8 form of the text.
SAMPLES = {'normal': [b'Lorem ipsum dolor sit amet',
                      'Lorem ipsum dolor sit amet'],
           'unicode': [b'\xc2\xa9\xc2\xab\xc2\xbb \xc2\xa4\xc2\xbc',
                       '©«» ¤¼']}


def new_get_runtime() -> dict:
    """
    Build a runtime configuration with known proxy settings.

    Loads the ``etc/runtime.yaml`` packaged with intelmq, sets both global
    proxies to ``http://localhost:8080`` and overrides ``http_proxy`` of the
    ``cymru-whois-expert`` bot with ``http://localhost:8081``.
    Used in the tests below as replacement for ``utils.get_runtime``.
    """
    runtime_path = pkg_resources.resource_filename('intelmq', 'etc/runtime.yaml')
    config = utils.load_configuration(runtime_path)

    # Create the global section only if the packaged file does not have one.
    global_settings = config.setdefault('global', {})
    global_settings['http_proxy'] = 'http://localhost:8080'
    global_settings['https_proxy'] = 'http://localhost:8080'

    # Bot-specific value which differs from the global one.
    config['cymru-whois-expert']['parameters']['http_proxy'] = 'http://localhost:8081'
    return config


class TestUtils(unittest.TestCase):
    """Unit tests for the helper functions of ``intelmq.lib.utils``."""

    def test_decode_byte(self):
        """Tests if decode() converts ASCII-only bytes to str."""
        self.assertEqual(SAMPLES['normal'][1],
                         utils.decode(SAMPLES['normal'][0]))

    def test_decode_bytes_unicode(self):
        """Tests if decode() converts bytes with non-ASCII characters to str."""
        self.assertEqual(SAMPLES['unicode'][1],
                         utils.decode(SAMPLES['unicode'][0]))

    def test_encode_byte(self):
        """Tests if encode() converts an ASCII-only str to bytes."""
        self.assertEqual(SAMPLES['normal'][0],
                         utils.encode(SAMPLES['normal'][1]))

    def test_encode_unicode(self):
        """Tests if encode() converts a str with non-ASCII characters to bytes."""
        self.assertEqual(SAMPLES['unicode'][0],
                         utils.encode(SAMPLES['unicode'][1]))

    def test_decode_ascii(self):
        """ Test ASCII decoding enforcement. """
        # b'\xe4' is not valid ASCII; the expected result no longer contains it.
        self.assertEqual('fobar',
                         utils.decode(b'fo\xe4bar', encodings=('ascii', ),
                                      force=True))

    def test_decode_unicode(self):
        """ Test decoding with unicode string. """
        self.assertEqual('foobar', utils.decode('foobar'))

    def test_encode_bytes(self):
        """ Test encoding with bytes string. """
        # Counterpart of test_decode_unicode: input which already has the
        # target type is expected to come back unchanged.
        self.assertEqual(b'foobar', utils.encode(b'foobar'))

    def test_encode_force(self):
        """ Test ASCII encoding enforcement. """
        # '\xe4' cannot be represented in ASCII; the expected result no longer
        # contains it.
        self.assertEqual(b'fobar',
                         utils.encode('fo\xe4bar', encodings=('ascii', ),
                                      force=True))

    def test_file_logger(self):
        """Tests if a logger for a file can be generated with log()."""

        with tempfile.NamedTemporaryFile(suffix=".log", mode='w+') as handle:
            filename = handle.name
            # Logger name is the file name without directory and '.log' suffix.
            # NOTE(review): this relies on utils.log() writing to
            # '<log_path>/<name>.log', i.e. to this very temporary file.
            name = os.path.splitext(os.path.split(filename)[-1])[0]
            logger = utils.log(name, log_path=tempfile.tempdir,
                               stream=io.StringIO())

            logger.info(termstyle.green(LINES['spare'][0]))
            logger.error(LINES['spare'][1])
            logger.critical(LINES['spare'][2])
            handle.seek(0)
            file_lines = handle.readlines()

            expected_patterns = [line.format(name) for line in LINES['long']]
            # Without this check an empty log file would pass vacuously.
            self.assertEqual(len(expected_patterns), len(file_lines))
            for pattern, line in zip(expected_patterns, file_lines):
                self.assertRegex(line.strip(), pattern)

    def test_stream_logger_given(self):
        """
        Tests if a logger for a stream can be generated with log()
        if the stream is explicitly given.
        """

        stream = io.StringIO()
        with tempfile.NamedTemporaryFile() as handle:
            filename = handle.name
            name = os.path.split(filename)[-1]
            logger = utils.log(name, log_path=tempfile.tempdir, stream=stream)

            logger.info(LINES['spare'][0])
            logger.error(LINES['spare'][1])
            logger.critical(LINES['spare'][2])

            stream_lines = stream.getvalue().splitlines()

            line_format = [line.format(name) for line in LINES['short']]
            self.assertSequenceEqual(line_format, stream_lines)

    def test_stream_logger(self):
        """
        Tests the logger created by log() without a stream and without a
        log path: the INFO message is expected on stdout, the ERROR and
        CRITICAL messages on stderr, wrapped by termstyle.red().
        """
        stdout = io.StringIO()
        stderr = io.StringIO()
        with contextlib.redirect_stdout(stdout), contextlib.redirect_stderr(stderr):
            logger = utils.log('test-bot', log_path=None)
            logger.info(LINES['spare'][0])
            logger.error(LINES['spare'][1])
            logger.critical(LINES['spare'][2])
        line_format = [line.format('test-bot') for line in LINES['short']]
        self.assertEqual(stdout.getvalue(), line_format[0] + '\n')
        self.assertEqual(stderr.getvalue(),
                         '\n'.join((termstyle.red(line_format[1]),
                                    termstyle.red(line_format[2]))) + '\n')

    def test_parse_logline(self):
        """Tests if the parse_logline() function works as expected"""
        line = ("2015-05-29 21:00:24,379 - malware-domain-list-collector - "
                "ERROR - Something went wrong")
        # Same line, but the bot ID carries a numeric thread suffix ('.4').
        thread = ("2015-05-29 21:00:24,379 - malware-domain-list-collector.4 - "
                  "ERROR - Something went wrong")

        fields = utils.parse_logline(line)
        self.assertDictEqual({'date': '2015-05-29T21:00:24.379000',
                              'bot_id': 'malware-domain-list-collector',
                              'thread_id': None,
                              'log_level': 'ERROR',
                              'message': 'Something went wrong'},
                             fields)
        fields = utils.parse_logline(thread)
        self.assertDictEqual({'date': '2015-05-29T21:00:24.379000',
                              'bot_id': 'malware-domain-list-collector',
                              'thread_id': 4,
                              'log_level': 'ERROR',
                              'message': 'Something went wrong'},
                             fields)

    def test_parse_logline_invalid(self):
        """Tests if the parse_logline() function returns the line. """
        # A traceback fragment, which does not have the log line format.
        line = ("    report = self.receive_message()\n  File"
                " \"/usr/local/lib/python3.4/dist-packages/intelmq-1.0.0"
                "-py3.4.egg/intelmq/lib/bot.py\", line 259, in"
                " receive_message")

        actual = utils.parse_logline(line)
        self.assertEqual(line, actual)

    def test_parse_logline_syslog(self):
        """Tests if the parse_logline() function parses syslog correctly. """
        line = ("Feb 22 10:17:10 host malware-domain-list-collector: ERROR "
                "Something went wrong")
        thread = ("Feb 22 10:17:10 host malware-domain-list-collector.4: ERROR "
                  "Something went wrong")
        # The syslog lines carry no year, the current one is expected in the
        # parsed date.
        expected_date = '%d-02-22T10:17:10' % datetime.datetime.now().year

        actual = utils.parse_logline(line, regex=utils.SYSLOG_REGEX)
        self.assertEqual({'bot_id': 'malware-domain-list-collector',
                          'date': expected_date,
                          'thread_id': None,
                          'log_level': 'ERROR',
                          'message': 'Something went wrong'}, actual)
        actual = utils.parse_logline(thread, regex=utils.SYSLOG_REGEX)
        self.assertEqual({'bot_id': 'malware-domain-list-collector',
                          'date': expected_date,
                          'thread_id': 4,
                          'log_level': 'ERROR',
                          'message': 'Something went wrong'}, actual)

    def test_error_message_from_exc(self):
        """Tests if error_message_from_exc correctly returns the error message."""
        exc = IndexError('This is a test')
        self.assertEqual(utils.error_message_from_exc(exc), 'This is a test')

    def test_parse_relative(self):
        """Tests if parse_relative returns the correct timespan."""
        # The expected values are expressed in minutes.
        self.assertEqual(utils.parse_relative('1 hour'), 60)
        self.assertEqual(utils.parse_relative('2\tyears'), 1051200)
        self.assertEqual(utils.parse_relative('5 minutes'), 5)
        self.assertEqual(utils.parse_relative('10 seconds'), 1 / 60 * 10)

    def test_parse_relative_raises(self):
        """Tests if parse_relative correctly raises ValueError."""
        with self.assertRaises(ValueError):
            utils.parse_relative('1 hou')
        with self.assertRaises(ValueError):
            utils.parse_relative('1 µs')

    def test_seconds_to_human(self):
        """ Test seconds_to_human """
        self.assertEqual(utils.seconds_to_human(60), '1m')
        self.assertEqual(utils.seconds_to_human(3600), '1h')
        self.assertEqual(utils.seconds_to_human(86401), '1d 1s')
        self.assertEqual(utils.seconds_to_human(64.2), '1m 4s')
        self.assertEqual(utils.seconds_to_human(64.2, precision=1),
                         '1.0m 4.2s')

    def test_version_smaller(self):
        """ Test version_smaller """
        self.assertTrue(utils.version_smaller((1, 0, 0), (1, 1, 0)))
        self.assertTrue(utils.version_smaller((1, 0, 0), (1, 0, 1, 'alpha')))
        self.assertFalse(utils.version_smaller((1, 0, 0, 'beta', 3), (1, 0, 0, 'alpha', 0)))
        self.assertFalse(utils.version_smaller((1, 0, 0), (1, 0, 0, 'alpha', 99)))
        self.assertFalse(utils.version_smaller((1, 0, 0), (1, 0, 0, 'beta')))

    def test_unzip_tar_gz(self):
        """ Test the unzip function with a tar gz file. """
        filename = os.path.join(os.path.dirname(__file__), '../assets/two_files.tar.gz')
        with open(filename, 'rb') as fh:
            result = utils.unzip(fh.read(), extract_files=True)
        self.assertEqual(tuple(result), (b'bar text\n', b'foo text\n'))

    def test_unzip_tar_gz_return_names(self):
        """ Test the unzip function with a tar gz file and return_names. """
        filename = os.path.join(os.path.dirname(__file__), '../assets/two_files.tar.gz')
        with open(filename, 'rb') as fh:
            result = utils.unzip(fh.read(), extract_files=True, return_names=True)
        self.assertEqual(tuple(result), (('bar', b'bar text\n'),
                                         ('foo', b'foo text\n')))

    def test_unzip_tar_gz_with_subdir(self):
        """
        Test the unzip function with a tar gz file containing a subdirectory
        and return_names. Test that the directories themselves are ignored.
        """
        filename = os.path.join(os.path.dirname(__file__), '../assets/subdir.tar.gz')
        with open(filename, 'rb') as fh:
            result = utils.unzip(fh.read(), extract_files=True, return_names=True)
        self.assertEqual(tuple(result), (('subdir/foo', b'foo text\n'),
                                         ('subdir/bar', b'bar text\n')))

    def test_unzip_gz(self):
        """ Test the unzip function with a gz file. """
        filename = os.path.join(os.path.dirname(__file__), '../assets/foobar.gz')
        with open(filename, 'rb') as fh:
            result = utils.unzip(fh.read(), extract_files=True)
        self.assertEqual(result, (b'bar text\n', ))

    def test_unzip_gz_name(self):
        """ Test the unzip function with a gz file and return_names. """
        filename = os.path.join(os.path.dirname(__file__), '../assets/foobar.gz')
        with open(filename, 'rb') as fh:
            result = utils.unzip(fh.read(), extract_files=True, return_names=True)
        # No file name is expected for the plain gz file.
        self.assertEqual(result, ((None, b'bar text\n'), ))

    def test_unzip_zip(self):
        """ Test the unzip function with a zip file. """
        filename = os.path.join(os.path.dirname(__file__), '../assets/two_files.zip')
        with open(filename, 'rb') as fh:
            result = utils.unzip(fh.read(), extract_files=True)
        self.assertEqual(tuple(result), (b'bar text\n', b'foo text\n'))

    def test_unzip_zip_return_names(self):
        """ Test the unzip function with a zip file and return_names. """
        filename = os.path.join(os.path.dirname(__file__), '../assets/two_files.zip')
        with open(filename, 'rb') as fh:
            result = utils.unzip(fh.read(), extract_files=True, return_names=True)
        self.assertEqual(tuple(result), (('bar', b'bar text\n'),
                                         ('foo', b'foo text\n')))

    def test_unzip_zip_with_subdir(self):
        """
        Test the unzip function with a zip containing a subdirectory and
        returning names. Test that the directories themselves are ignored.
        """
        filename = os.path.join(os.path.dirname(__file__), '../assets/subdir.zip')
        with open(filename, 'rb') as fh:
            result = utils.unzip(fh.read(), extract_files=True, return_names=True)
        self.assertEqual(tuple(result), (('subdir/bar', b'bar text\n'),
                                         ('subdir/foo', b'foo text\n')))

    def test_file_name_from_response(self):
        """ test file_name_from_response """
        response = requests.Response()
        response.headers['Content-Disposition'] = 'attachment; filename=2019-09-09-drone_brute_force-austria-geo.csv'
        self.assertEqual(utils.file_name_from_response(response),
                         '2019-09-09-drone_brute_force-austria-geo.csv')

    def test_list_all_bots(self):
        """ Test that the result of list_all_bots validates against the bots schema. """
        bots_list = utils.list_all_bots()
        test = CerberusTests()
        with open(os.path.join(os.path.dirname(__file__), '../assets/bots.schema.json')) as handle:
            schema = json.loads(test.convert_cerberus_schema(handle.read()))

        v = cerberus.Validator(schema)

        self.assertTrue(v.validate(bots_list),
                        msg='Invalid BOTS list:\n%s' % pprint.pformat(v.errors))

    def test_list_all_bots_ignores_bots_with_syntax_error(self):
        """
        Test that list_all_bots returns exactly one bot if only one import
        is delegated to the real import function.

        The replacement importer calls SyntaxError(module) on the first call,
        the real import_module on the second call and SyntaxError(module)
        again on every further call.
        """
        original_import = utils.importlib.import_module
        # Consumed from the end; the last remaining element is reused forever.
        effects = [SyntaxError, original_import, SyntaxError]

        def _mock_importing(module):
            if len(effects) == 1:
                return effects[0](module)
            return effects.pop()(module)

        with unittest.mock.patch.object(utils.importlib, "import_module") as import_mock:
            import_mock.side_effect = _mock_importing
            bots = utils.list_all_bots()

        bot_count = sum(len(val) for val in bots.values())
        self.assertEqual(1, bot_count)

    def test_list_all_bots_filters_entrypoints(self):
        """
        Test that list_all_bots only imports the entry points named
        'intelmq.bots.*' and that the imported module path is taken from the
        entry point's value, not from its name.
        """
        entries = [
            EntryPoint("intelmq.bots.collector.api.collector_api",
                       "intelmq.bots.collector.api.collector_api:BOT.run", group="console_scripts"),
            EntryPoint("intelmq.bots.collector.awesome.my_bot",
                       "awesome.extension.package.collector:BOT.run", group="console_scripts"),
            EntryPoint("not.a.bot", "not.a.bot:run", group="console_scripts")
        ]

        with unittest.mock.patch.object(utils, "_get_console_entry_points", return_value=entries), \
                unittest.mock.patch.object(utils.importlib, "import_module") as import_mock:
            import_mock.side_effect = SyntaxError()  # stop processing after import try
            utils.list_all_bots()

        import_mock.assert_has_calls(
            [
                unittest.mock.call("intelmq.bots.collector.api.collector_api"),
                unittest.mock.call("awesome.extension.package.collector"),
            ]
        )
        self.assertEqual(2, import_mock.call_count)

    def test_get_bot_module_name_builtin_bot(self):
        """
        Test get_bot_module_name with the module name of an intelmq bot and
        with a name for which None is expected.
        """
        found_name = utils.get_bot_module_name("intelmq.bots.collectors.api.collector_api")
        self.assertEqual("intelmq.bots.collectors.api.collector_api", found_name)

        self.assertIsNone(utils.get_bot_module_name("intelmq.not-existing-bot"))

    def test_get_bots_settings(self):
        """
        Test get_bots_settings for all bots and for single bots.

        The runtime configuration from new_get_runtime() sets the global
        http_proxy to port 8080 and the one of cymru-whois-expert to port 8081.
        The bot's own value has to win, the other bot gets the global one.
        """
        # NOTE(review): assumes deduplicator-expert defines no http_proxy of
        # its own in the packaged runtime.yaml.
        with unittest.mock.patch.object(utils, "get_runtime", new_get_runtime):
            runtime = utils.get_bots_settings()
            cymru = utils.get_bots_settings('cymru-whois-expert')
            deduplicator = utils.get_bots_settings('deduplicator-expert')

        self.assertEqual(runtime['cymru-whois-expert']['parameters']['http_proxy'], 'http://localhost:8081')
        self.assertEqual(runtime['deduplicator-expert']['parameters']['http_proxy'], 'http://localhost:8080')
        self.assertEqual(cymru['parameters']['http_proxy'], 'http://localhost:8081')
        self.assertEqual(deduplicator['parameters']['http_proxy'], 'http://localhost:8080')

    def test_get_global_settings(self):
        """ Test that get_global_settings returns the proxies of the global section. """
        with unittest.mock.patch.object(utils, "get_runtime", new_get_runtime):
            defaults = utils.get_global_settings()
        self.assertEqual(defaults['http_proxy'], 'http://localhost:8080')
        self.assertEqual(defaults['https_proxy'], 'http://localhost:8080')

    def test_load_configuration_json(self):
        """ Test load_configuration with a JSON file containing space whitespace """
        filename = os.path.join(os.path.dirname(__file__), '../assets/foobar.json')
        self.assertEqual(utils.load_configuration(filename), {'foo': 'bar'})

    def test_load_configuration_json_tabs(self):
        """ Test load_configuration with a JSON file containing tab whitespace """
        filename = os.path.join(os.path.dirname(__file__), '../assets/tab-whitespace.json')
        self.assertEqual(utils.load_configuration(filename), {'foo': 'bar'})

    def test_load_configuration_yaml(self):
        """ Test load_configuration with a YAML file """
        filename = os.path.join(os.path.dirname(__file__), '../assets/example.yaml')
        self.assertEqual(
            utils.load_configuration(filename),
            {
                'some_string': 'Hello World!',
                'other_string': 'with a : in it',
                'now more': ['values', 'in', 'a', 'list'],
                'types': -4,
                'other': True,
                'final': 0.5,
            }
        )

    def test_load_configuration_yaml_invalid(self):
        """ Test load_configuration with an invalid YAML file """
        filename = os.path.join(os.path.dirname(__file__), '../assets/example-invalid.yaml')
        with self.assertRaises(ScannerError):
            utils.load_configuration(filename)

    @skip_internet()
    def test_resolve_dns_returns_answer(self):
        """ Test that resolve_dns returns a dns.resolver.Answer object. """
        answer = utils.resolve_dns("example.com")
        self.assertIsInstance(answer, dns.resolver.Answer)


# Allow running this test module directly as a script.
if __name__ == '__main__':  # pragma: no cover
    unittest.main()
