"""Session management for PentestGPT - persistence and state tracking."""

import json
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any


class SessionStatus(Enum):
    """Lifecycle state of a session.

    Each member's value is the exact string written to (and read back from)
    the session JSON file, so existing values must not be renamed.
    """

    # Declaration order is preserved on iteration; keep it stable.
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    ERROR = "error"


@dataclass
class SessionInfo:
    """Session state - framework agnostic.

    Round-trips through :meth:`to_dict` / :meth:`from_dict`, which produce and
    consume a JSON-compatible dictionary (datetimes as ISO-8601 strings, the
    status as its enum value).
    """

    session_id: str
    target: str
    created_at: datetime
    status: SessionStatus = SessionStatus.RUNNING
    backend_session_id: str | None = None  # Backend-specific ID (e.g., Claude session)
    updated_at: datetime | None = None
    task: str = ""
    user_instructions: list[str] = field(default_factory=list)
    flags_found: list[dict[str, str]] = field(default_factory=list)
    total_cost_usd: float = 0.0
    model: str = ""
    last_error: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize session to dictionary for JSON storage.

        List fields are copied, so mutating the returned dictionary never
        changes this session instance.
        """
        return {
            "session_id": self.session_id,
            "target": self.target,
            "created_at": self.created_at.isoformat(),
            "status": self.status.value,
            "backend_session_id": self.backend_session_id,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "task": self.task,
            "user_instructions": list(self.user_instructions),
            "flags_found": [dict(flag) for flag in self.flags_found],
            "total_cost_usd": self.total_cost_usd,
            "model": self.model,
            "last_error": self.last_error,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SessionInfo":
        """Deserialize session from dictionary.

        ``session_id``, ``target``, ``created_at`` and ``status`` are required.
        Every other key falls back to the field default when absent. List
        fields stored as ``null`` are treated as empty, and all list fields are
        copied so the new instance does not share them with *data*.

        Raises:
            KeyError: A required key is missing.
            ValueError: ``status`` is not a known value or a timestamp string
                is not valid ISO-8601.
            TypeError: A timestamp is not a string (e.g. ``null`` created_at).
        """
        updated_raw = data.get("updated_at")
        return cls(
            session_id=data["session_id"],
            target=data["target"],
            created_at=datetime.fromisoformat(data["created_at"]),
            status=SessionStatus(data["status"]),
            backend_session_id=data.get("backend_session_id"),
            updated_at=datetime.fromisoformat(updated_raw) if updated_raw else None,
            task=data.get("task", ""),
            # `or []` also covers an explicit JSON null, which would otherwise
            # leave None in a field the store later calls .append() on.
            user_instructions=list(data.get("user_instructions") or []),
            flags_found=list(data.get("flags_found") or []),
            total_cost_usd=data.get("total_cost_usd", 0.0),
            model=data.get("model", ""),
            last_error=data.get("last_error"),
        )


class SessionStore:
    """Simple file-based session persistence.

    Each session is stored as ``<session_id>.json`` inside the sessions
    directory. The store also tracks one "current" session. The convenience
    mutators update that session and immediately persist it via :meth:`save`.
    """

    SESSIONS_DIR = Path.home() / ".pentestgpt" / "sessions"

    # Errors that mean "this session file is missing or unusable". In addition
    # to JSON syntax errors, valid JSON of the wrong shape (e.g. a list instead
    # of an object, or null where a timestamp string is expected) surfaces as
    # TypeError from SessionInfo.from_dict. A file that vanished or cannot be
    # read surfaces as OSError. JSONDecodeError is a ValueError subclass; it is
    # listed for readability.
    _READ_ERRORS = (
        json.JSONDecodeError,
        KeyError,
        ValueError,
        TypeError,
        OSError,
    )

    def __init__(self, sessions_dir: Path | None = None):
        """Initialize session store.

        Args:
            sessions_dir: Optional custom sessions directory (a ``str`` path is
                also accepted). Defaults to ``SESSIONS_DIR``. The directory is
                created if it does not exist.
        """
        self._sessions_dir = self.SESSIONS_DIR if sessions_dir is None else Path(sessions_dir)
        self._sessions_dir.mkdir(parents=True, exist_ok=True)
        self._current: SessionInfo | None = None

    def _path(self, session_id: str) -> Path:
        """Return the JSON file path used to store *session_id*."""
        return self._sessions_dir / f"{session_id}.json"

    def _read(self, path: Path) -> SessionInfo | None:
        """Parse the session file at *path*; return None if missing or malformed."""
        try:
            return SessionInfo.from_dict(json.loads(path.read_text(encoding="utf-8")))
        except self._READ_ERRORS:
            return None

    def _new_session_id(self) -> str:
        """Return a short random session ID that no stored session is using."""
        # Eight hex digits carry only 32 bits of randomness. Re-draw on a
        # collision instead of silently overwriting an existing session file.
        while True:
            session_id = str(uuid.uuid4())[:8]
            if not self._path(session_id).exists():
                return session_id

    def create(self, target: str, task: str, model: str) -> SessionInfo:
        """Create a new session, make it current, and persist it.

        Args:
            target: Target URL/IP/domain
            task: Task description
            model: Model name

        Returns:
            New SessionInfo instance
        """
        session = SessionInfo(
            session_id=self._new_session_id(),
            target=target,
            created_at=datetime.now(),
            task=task,
            model=model,
        )
        self._current = session
        self.save()
        return session

    @property
    def current(self) -> SessionInfo | None:
        """Get current active session."""
        return self._current

    def save(self) -> None:
        """Save current session to disk (no-op without a current session).

        Stamps ``updated_at``, writes the JSON to a sibling temporary file,
        then renames it over the real file. The rename is atomic on POSIX, so
        a failed or interrupted write cannot leave a truncated session behind.
        """
        if self._current is None:
            return
        self._current.updated_at = datetime.now()
        path = self._path(self._current.session_id)
        payload = json.dumps(self._current.to_dict(), indent=2)
        # The ".json.tmp" suffix keeps the temp file out of the "*.json" glob
        # used by list_sessions.
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            tmp_path.replace(path)
        finally:
            # Only still present if the write or rename failed part-way.
            tmp_path.unlink(missing_ok=True)

    def load(self, session_id: str) -> SessionInfo | None:
        """Load a session by ID and make it the current session.

        Args:
            session_id: Session ID to load

        Returns:
            SessionInfo if found and parseable, None otherwise. On failure the
            current session is left unchanged.
        """
        session = self._read(self._path(session_id))
        if session is not None:
            self._current = session
        return session

    def list_sessions(self, target: str | None = None) -> list[SessionInfo]:
        """List all sessions, optionally filtered by target.

        Unreadable or malformed session files are skipped.

        Args:
            target: Optional target filter

        Returns:
            List of SessionInfo, sorted by creation date (newest first)
        """
        sessions: list[SessionInfo] = []
        for path in self._sessions_dir.glob("*.json"):
            session = self._read(path)
            if session is None:
                continue
            if target is None or session.target == target:
                sessions.append(session)
        sessions.sort(key=lambda s: s.created_at, reverse=True)
        return sessions

    def get_latest(self, target: str | None = None) -> SessionInfo | None:
        """Get the most recent session.

        Args:
            target: Optional target filter

        Returns:
            Most recent SessionInfo if any exist
        """
        sessions = self.list_sessions(target)
        return sessions[0] if sessions else None

    def delete(self, session_id: str) -> bool:
        """Delete a session by ID, clearing it as current if it was.

        Args:
            session_id: Session ID to delete

        Returns:
            True if deleted, False if not found
        """
        try:
            self._path(session_id).unlink()
        except FileNotFoundError:
            return False
        if self._current is not None and self._current.session_id == session_id:
            self._current = None
        return True

    # Convenience methods for updating current session. Each is a no-op when
    # there is no current session; otherwise it mutates and saves immediately.

    def update_status(self, status: SessionStatus) -> None:
        """Update current session status."""
        if self._current is None:
            return
        self._current.status = status
        self.save()

    def add_instruction(self, instruction: str) -> None:
        """Add a user instruction to current session."""
        if self._current is None:
            return
        self._current.user_instructions.append(instruction)
        self.save()

    def add_flag(self, flag: str, context: str) -> None:
        """Add a found flag to current session."""
        if self._current is None:
            return
        self._current.flags_found.append({"flag": flag, "context": context})
        self.save()

    def set_backend_session_id(self, backend_id: str) -> None:
        """Set the backend-specific session ID."""
        if self._current is None:
            return
        self._current.backend_session_id = backend_id
        self.save()

    def add_cost(self, cost: float) -> None:
        """Add to total cost."""
        if self._current is None:
            return
        self._current.total_cost_usd += cost
        self.save()

    def set_error(self, error: str) -> None:
        """Set last error."""
        if self._current is None:
            return
        self._current.last_error = error
        self.save()
