"""任务管理器"""
import uuid
from datetime import datetime, timezone
from typing import Optional

from src.models.task import Task
from src.models.base import get_session_maker


class TaskManager:
    """Manage the lifecycle of ``Task`` records.

    Each method opens its own session from ``get_session_maker()`` and commits
    before returning, so every call is an independent unit of work.

    The state-transition methods (``start_task``, ``complete_task``,
    ``fail_task``, ``cancel_task``) silently do nothing when ``task_id`` does
    not exist. Only ``cancel_task`` checks the current status before changing it.
    """

    async def create_task(
        self,
        query: str,
        complexity_score: float = 0.0,
        user_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
    ) -> Task:
        """Create and persist a new task in ``Task.STATUS_PENDING``.

        Args:
            query: Query text stored on the task.
            complexity_score: Score stored on the task (default ``0.0``).
            user_id: Optional user id stored on the task.
            conversation_id: Optional conversation id stored on the task.

        Returns:
            The committed ``Task``, reloaded from the database after commit.
            The session is closed before returning, so the object is detached.
        """
        task = Task(
            # 12 hex chars (48 bits) of a random UUID4 keep ids short.
            id=f"task_{uuid.uuid4().hex[:12]}",
            status=Task.STATUS_PENDING,
            query=query,
            complexity_score=complexity_score,
            user_id=user_id,
            conversation_id=conversation_id,
        )

        session_maker = get_session_maker()
        async with session_maker() as session:
            session.add(task)
            # Commit first (this also flushes), THEN refresh. Under
            # SQLAlchemy's default expire_on_commit=True, commit expires every
            # loaded attribute. Refreshing before commit (as this code used to
            # do) would leave `task` expired. Touching it after the session
            # closes would then need an implicit lazy load, which fails for
            # async/detached objects. Refreshing after commit reloads the row
            # (including any DB-side defaults) whatever that setting is.
            # NOTE(review): the session maker's expire_on_commit setting is
            # configured outside this file.
            await session.commit()
            await session.refresh(task)

        return task

    async def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or ``None`` if it does not exist.

        The returned object is detached, because its session is closed on return.
        """
        session_maker = get_session_maker()
        async with session_maker() as session:
            return await session.get(Task, task_id)

    async def start_task(self, task_id: str) -> None:
        """Mark a task as running and stamp ``started_at`` with the current UTC time.

        Does nothing if the task does not exist. The current status is not checked.
        """
        session_maker = get_session_maker()
        async with session_maker() as session:
            task = await session.get(Task, task_id)
            if task:
                task.status = Task.STATUS_RUNNING
                task.started_at = datetime.now(timezone.utc)
                await session.commit()

    async def complete_task(
        self, task_id: str, skills_invoked: Optional[list] = None
    ) -> None:
        """Mark a task as completed and stamp ``completed_at`` with the current UTC time.

        Args:
            task_id: Id of the task to update. Does nothing if it does not exist.
            skills_invoked: If given and non-empty, replaces
                ``task.skills_invoked``. ``None`` or an empty list leaves the
                stored value unchanged.
        """
        session_maker = get_session_maker()
        async with session_maker() as session:
            task = await session.get(Task, task_id)
            if task:
                task.status = Task.STATUS_COMPLETED
                task.completed_at = datetime.now(timezone.utc)
                # Truthiness check: an empty list does not overwrite.
                if skills_invoked:
                    task.skills_invoked = skills_invoked
                await session.commit()

    async def fail_task(self, task_id: str, error: dict) -> None:
        """Mark a task as failed, record ``error``, and stamp ``completed_at``.

        Args:
            task_id: Id of the task to update. Does nothing if it does not exist.
            error: Error payload stored as-is on ``task.error``.
        """
        session_maker = get_session_maker()
        async with session_maker() as session:
            task = await session.get(Task, task_id)
            if task:
                task.status = Task.STATUS_FAILED
                task.completed_at = datetime.now(timezone.utc)
                task.error = error
                await session.commit()

    async def cancel_task(self, task_id: str) -> None:
        """Cancel a task that is still pending or running.

        Sets the status to ``Task.STATUS_CANCELLED`` and stamps ``completed_at``.
        Does nothing if the task does not exist or is already in another status.
        """
        session_maker = get_session_maker()
        async with session_maker() as session:
            task = await session.get(Task, task_id)
            if task and task.status in (Task.STATUS_PENDING, Task.STATUS_RUNNING):
                task.status = Task.STATUS_CANCELLED
                task.completed_at = datetime.now(timezone.utc)
                await session.commit()
