from __future__ import annotations

from enum import IntEnum
from typing import TYPE_CHECKING, Any
from uuid import UUID

from dissect.util.sid import read_sid

from dissect.etl.c_etl import c_etl
from dissect.etl.exceptions import ExtendedDataItemException
from dissect.etl.headers.headers import Header

if TYPE_CHECKING:
    from dissect.cstruct import CharArray


def read_uuid(data: bytes) -> UUID:
    """Decode a GUID from ``data`` using the little-endian (``bytes_le``) field layout.

    Args:
        data: Buffer whose leading 16 bytes hold the GUID.

    Returns:
        The decoded :class:`~uuid.UUID`.
    """
    return UUID(bytes_le=c_etl.char[16](data))


def read_instance_info(data: bytes) -> dict[str, Any]:
    """Decode an ``EVENT_HEADER_EXT_TYPE_ITEM_INSTANCE`` extended data item.

    Args:
        data: Raw data of the extended data item.

    Returns:
        The field values of the parsed structure, with ``ParentGuid`` replaced by its string form.
    """
    fields = c_etl.EVENT_HEADER_EXT_TYPE_ITEM_INSTANCE(data).__values__
    fields["ParentGuid"] = str(read_uuid(fields.get("ParentGuid")))
    return fields


def read_stack_trace(data: bytes) -> dict[str, Any]:
    """Decode an ``EVENT_HEADER_EXT_TYPE_STACK_TRACE32`` extended data item.

    Returns:
        The field values of the parsed structure.
    """
    stack_trace = c_etl.EVENT_HEADER_EXT_TYPE_STACK_TRACE32(data)
    return stack_trace.__values__


def read_stack_trace64(data: bytes) -> dict[str, Any]:
    """Decode an ``EVENT_HEADER_EXT_TYPE_STACK_TRACE64`` extended data item.

    Returns:
        The field values of the parsed structure.
    """
    stack_trace = c_etl.EVENT_HEADER_EXT_TYPE_STACK_TRACE64(data)
    return stack_trace.__values__


def read_provider_traits(data: bytes) -> dict[str, Any]:
    """Decode a provider traits (``PROV_TRAITS``) extended data item.

    The data starts with an ``EVENT_HEADER_EXT_TYPE_PROVIDER_TRAIT`` structure, followed by
    ``TRAIT`` structures. Traits are read until the offset reaches that leading structure's
    ``TraitSize`` value, or until a trait reports a size that would not advance the offset.

    Args:
        data: Raw data of the extended data item.

    Returns:
        The leading structure's field values plus a ``"Traits"`` key holding a list with the field
        values of every parsed trait.
    """
    provider_traits = c_etl.EVENT_HEADER_EXT_TYPE_PROVIDER_TRAIT(data)
    output_dict = provider_traits.__values__
    total_size = provider_traits.TraitSize

    # The first trait starts directly after the leading structure.
    trait_offset = len(provider_traits)
    traits = []
    while trait_offset < total_size:
        trait = c_etl.TRAIT(data[trait_offset:])
        if trait.TraitSize <= 0:
            # A non-positive size would never move the offset forward; stop instead of looping
            # forever on corrupt or malicious input.
            break
        traits.append(trait.__values__)
        trait_offset += trait.TraitSize
    return {**output_dict, "Traits": traits}


class EventDescriptor:
    """A representation of the event descriptor fields of an event header.

    Copies the header's ``Id``, ``Version``, ``Channel``, ``Level``, ``OpCode``, ``Task`` and
    ``Keywords`` fields into lower-case attributes.
    """

    __slots__ = [
        "channel",
        "id",
        "keywords",
        "level",
        "opcode",
        "task",
        "version",
    ]

    def __init__(self, header: Header):
        # The parsed header uses PascalCase field names; expose them in snake_case.
        self.id, self.version, self.channel = header.Id, header.Version, header.Channel
        self.level, self.opcode, self.task = header.Level, header.OpCode, header.Task
        self.keywords = header.Keywords


class ExtType(IntEnum):
    """Type code of an extended data item attached to an event header.

    ``UNKNOWN`` (0) is used as the fallback for type values that are not listed here.
    """

    # NOTE(review): names and values 1-13 appear to mirror the EVENT_HEADER_EXT_TYPE_* constants of
    # the Windows SDK (evntcons.h). Newer SDKs also define types from 14 upward (14 is CONTROL_GUID
    # there) with a larger MAX; confirm that TYPE_MAX = 14 is still the intended sentinel.
    RELATED_ACTIVITY_ID = 1
    SID = 2
    TS_ID = 3
    INSTANCE_INFO = 4
    STACK_TRACE32 = 5
    STACK_TRACE64 = 6
    PEBS_INDEX = 7
    PMC_COUNTERS = 8
    PSM_KEY = 9
    EVENT_KEY = 10
    EVENT_SCHEMA_TL = 11
    PROV_TRAITS = 12
    PROCESS_START_KEY = 13
    TYPE_MAX = 14
    UNKNOWN = 0


# Decoder per extended data item type. Each decoder receives the item's raw data and returns a
# dictionary of named fields. ExtType.UNKNOWN deliberately has no entry.
EXTENDED_DATA_READERS = {
    ExtType.RELATED_ACTIVITY_ID: lambda raw: {"Guid": str(read_uuid(raw))},
    ExtType.SID: lambda raw: {"Sid": read_sid(raw)},
    ExtType.TS_ID: lambda raw: {"SessionId": c_etl.uint32(raw)},
    ExtType.INSTANCE_INFO: read_instance_info,
    ExtType.STACK_TRACE32: read_stack_trace,
    ExtType.STACK_TRACE64: read_stack_trace64,
    ExtType.PEBS_INDEX: lambda raw: {"PebsIndex": c_etl.uint32(raw)},
    # One uint64 counter for every full 8 bytes of data.
    ExtType.PMC_COUNTERS: lambda raw: {"PmcCounters": c_etl.uint64[len(raw) // 8](raw)},
    ExtType.PSM_KEY: lambda raw: {"PsmKey": c_etl.uint64(raw)},
    ExtType.EVENT_KEY: lambda raw: {"EventKey": c_etl.uint64(raw)},
    # The schema is kept as raw bytes spanning the whole item.
    ExtType.EVENT_SCHEMA_TL: lambda raw: {"EventSchema": c_etl.char[len(raw)](raw)},
    ExtType.PROV_TRAITS: read_provider_traits,
    ExtType.PROCESS_START_KEY: lambda raw: {"ProcessStartKey": c_etl.uint64(raw)},
    # NOTE(review): TYPE_MAX looks like a sentinel rather than a real item type; its data is kept raw.
    ExtType.TYPE_MAX: lambda raw: {"Max": c_etl.char[len(raw)](raw)},
}


class EventHeaderExtendedDataItem:
    """A single extended data item parsed from the start of ``payload``.

    The item header is parsed with ``c_etl.EventHeaderExtendedDataItemHeader`` and validated. The
    item data is then decoded with the reader registered for its type in ``EXTENDED_DATA_READERS``;
    types without a reader decode to an empty dictionary. Decoded fields are reachable as attributes
    (e.g. ``item.Sid``), and unknown field names resolve to ``None``.

    Raises:
        ExtendedDataItemException: If the item header fails :meth:`validate_header`.
        Exceptions raised by the structure parser (``EOFError`` presumably for a truncated payload)
        propagate unchanged.
    """

    __slots__ = [
        "data",
        "data_size",
        "ext_type",
        "linkage",
        "raw_data",
        "reserved1",
        "reserved2",
        "size",
    ]

    def __init__(self, payload: bytes):
        header = c_etl.EventHeaderExtendedDataItemHeader(payload)
        self.size = header.Size
        self.ext_type = self._extension_type(header.ExtType)
        self.reserved1 = header.Reserved1
        self.data_size = header.DataSize
        self.linkage = 0
        self.reserved2 = 0

        # Validate the sizes before decoding the data, so a corrupt header is reported as an
        # ExtendedDataItemException instead of whatever a type-specific reader raises on garbage.
        self.validate_header()

        self.raw_data = header.Data
        self.data = self._read_extension_type(self.ext_type, header.Data)

    def validate_header(self) -> None:
        """Check that the item sizes are consistent.

        Raises:
            ExtendedDataItemException: If ``size`` is below 8, is not a multiple of 8, or
                ``data_size`` exceeds ``size - 8`` (presumably the space left after an 8-byte item
                header).
        """
        if self.size < 8:
            raise ExtendedDataItemException("DataItem size smaller than 8 bytes.")

        if self.size % 8:
            raise ExtendedDataItemException("DataItem size not aligned with 8 bytes.")

        if self.data_size > self.size - 8:
            raise ExtendedDataItemException("Data size larger than DataItem size.")

    def _extension_type(self, item_type: int) -> ExtType:
        """Map a raw type value to :class:`ExtType`, using ``ExtType.UNKNOWN`` for unlisted values."""
        try:
            return ExtType(item_type)
        except ValueError:
            return ExtType.UNKNOWN

    def _read_extension_type(self, ext_type: ExtType, data: CharArray) -> dict[str, Any]:
        """Decode ``data`` with the reader registered for ``ext_type``; ``{}`` if there is none."""
        reader = EXTENDED_DATA_READERS.get(ext_type)
        return reader(data) if reader else {}

    def __getattr__(self, name: str) -> Any:
        """Expose decoded extension fields as attributes.

        Python only calls this after normal attribute lookup has failed. Returns
        ``self.data.get(name)``, so unknown field names resolve to ``None``.
        """
        # Dunder names (``__setstate__``, ``__deepcopy__``, ...) and ``data`` itself must raise
        # AttributeError. Otherwise copy/pickle protocols see ``None`` instead of "missing", and an
        # instance whose ``data`` slot is unset would re-enter this method until RecursionError.
        if name == "data" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return self.data.get(name)

    def __repr__(self) -> str:
        return (
            f"<EventHeaderExtendedDataItem Size={self.size} Reserved1={self.reserved1} ExtType={self.ext_type} "
            f"Linkage={self.linkage} Reserved2={self.reserved2} DataSize={self.data_size}>"
        )


class EventHeader(Header):
    """Event header parsed with ``c_etl.EventHeader``, with access to its extended data items."""

    @property
    def descriptor(self) -> EventDescriptor:
        """The event descriptor fields of this header."""
        return EventDescriptor(self.header)

    @property
    def header_extensions(self) -> list[EventHeaderExtendedDataItem]:
        """All extended data items of this event (parsed again on every access)."""
        return self._read_extensions()

    @property
    def minimal_size(self) -> int:
        """Minimum header size (0x50)."""
        return 0x50

    @property
    def _header_type(self) -> c_etl.EventHeader:
        """Structure definition used to parse this header."""
        return c_etl.EventHeader

    def _read_extensions(self) -> list[EventHeaderExtendedDataItem]:
        """Parse consecutive extended data items from the start of the payload.

        Parsing stops after 13 items, when fewer than 8 bytes remain, or at the first item that
        raises ``EOFError`` or ``ExtendedDataItemException``.
        """
        # NOTE(review): 13 equals the number of item types RELATED_ACTIVITY_ID..PROCESS_START_KEY;
        # presumably at most one item per type is expected. Items are also parsed without consulting
        # the header flags -- confirm callers only rely on this for events carrying extended data.
        max_items = 13
        end = len(self.payload)
        items: list[EventHeaderExtendedDataItem] = []
        offset = 0

        while len(items) < max_items and offset + 8 <= end:
            try:
                item = EventHeaderExtendedDataItem(self.payload[offset:])
            except (EOFError, ExtendedDataItemException):
                break
            items.append(item)
            # Each item's size covers its header and data, so this lands on the next item.
            offset += item.size

        return items

    @property
    def provider_id(self) -> UUID:
        """GUID of the provider that generated this event."""
        # NOTE(review): ActivityId is decoded with ``bytes_le`` but ProviderId with ``bytes``;
        # confirm which byte order the ProviderId field actually uses.
        return UUID(bytes=self.header.ProviderId)

    @property
    def activity_id(self) -> UUID:
        """Activity ID of this event, decoded from the header's ``ActivityId`` field."""
        return UUID(bytes_le=self.header.ActivityId)

    @property
    def opcode(self) -> int:
        """Opcode of this event (the header's ``OpCode`` field)."""
        return self.header.OpCode

    @property
    def thread_id(self) -> int:
        """ID of the thread that created this event."""
        return self.header.ThreadId

    @property
    def process_id(self) -> int:
        """ID of the process that created this event."""
        return self.header.ProcessId

    def additional_header_fields(self) -> dict[str, Any]:
        """Collect the thread, process and activity IDs plus all decoded extended data items.

        Returns:
            A dictionary with ``ThreadId``, ``ProcessId``, ``ActivityId`` (as a string) and
            ``Extensions``: one dictionary per extended data item, holding its ``ExtType``
            followed by the item's decoded fields.
        """
        return {
            "ThreadId": self.thread_id,
            "ProcessId": self.process_id,
            "ActivityId": str(self.activity_id),
            "Extensions": [
                {"ExtType": extension.ext_type, **extension.data} for extension in self.header_extensions
            ],
        }
