# tests/integration/test_tasks_api.py
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient

import sys
from pathlib import Path

# Make the project root importable so ``src.main`` resolves regardless of the
# machine or working directory pytest is launched from. This file lives at
# tests/integration/, so the project root is two levels above its directory
# (parents[0] = integration, parents[1] = tests, parents[2] = project root).
# Previously this was a hard-coded absolute path to one developer's checkout.
_PROJECT_ROOT = str(Path(__file__).resolve().parents[2])
sys.path.insert(0, _PROJECT_ROOT)

from src.main import app  # noqa: E402  (must come after the sys.path tweak)

# Module-level client shared by every test in this file.
client = TestClient(app)


@pytest.fixture
def mock_session_maker():
    """Build a mocked async session factory for the integration tests.

    Returns a ``(maker, session)`` tuple. ``maker()`` produces an object that
    can be used as ``async with maker() as s``, where ``s`` is ``session``.

    ``session`` is an ``AsyncMock``, so any attribute not set here is itself
    awaitable. ``add`` is replaced with a plain (synchronous) ``MagicMock``.
    ``flush``, ``refresh`` and ``commit`` are given fresh ``AsyncMock``s.
    """
    session = AsyncMock()
    # ``add`` is called without ``await``, so it gets a synchronous mock.
    session.add = MagicMock()
    for awaited_name in ("flush", "refresh", "commit"):
        setattr(session, awaited_name, AsyncMock())

    factory = MagicMock()
    # ``factory()`` always returns the same child mock. Wire its async
    # context-manager protocol so entering it yields ``session``.
    ctx = factory.return_value
    ctx.__aenter__ = AsyncMock(return_value=session)
    ctx.__aexit__ = AsyncMock(return_value=None)
    return factory, session


def test_get_task_status(mock_session_maker):
    """A task created through /chat can be queried via /tasks/{task_id}.

    The message contains the "对比" keyword, which routes the chat request to
    async mode (answered with 202 and a ``task_id``). ``session.get`` is then
    stubbed to return a pending task before the status endpoint is queried.
    """
    maker, session = mock_session_maker
    payload = {"message": "对比技术部和产品部员工数量", "context": {}}

    with patch('src.core.task_manager.get_session_maker', return_value=maker):
        created = client.post("/api/v1/chat", json=payload)
        assert created.status_code == 202
        task_id = created.json()["task_id"]

        # Stub the task row the status endpoint is expected to load via
        # ``session.get``.
        fake_task = MagicMock()
        fake_task.id = task_id
        fake_task.status = "pending"
        fake_task.created_at = datetime.now(timezone.utc)
        fake_task.updated_at = datetime.now(timezone.utc)
        fake_task.completed_at = None
        session.get.return_value = fake_task

        status_resp = client.get(f"/api/v1/tasks/{task_id}")
        assert status_resp.status_code == 200
        body = status_resp.json()
        assert body["task_id"] == task_id
        assert "status" in body


def test_get_task_stream(mock_session_maker):
    """The SSE stream endpoint answers 200 for a freshly created async task.

    The "对比" keyword in the message routes /chat to async mode, which returns
    202 with a ``task_id``. That id is then used to open the task's stream.
    """
    maker = mock_session_maker[0]
    request_body = {"message": "对比技术部和产品部员工数量", "context": {}}

    with patch('src.core.task_manager.get_session_maker', return_value=maker):
        chat_resp = client.post("/api/v1/chat", json=request_body)
        assert chat_resp.status_code == 202
        new_task_id = chat_resp.json()["task_id"]

        # Request the server-sent-events stream for the new task.
        # NOTE(review): ``client.get`` reads the full response body. If the
        # stream does not terminate for a pending task, this call would block.
        # Confirm the endpoint ends the stream under these mocks.
        stream_url = f"/api/v1/tasks/{new_task_id}/stream"
        stream_resp = client.get(stream_url, headers={"Accept": "text/event-stream"})
        assert stream_resp.status_code == 200
