import pytest
from src.core.query_router import QueryRouter, Mode


@pytest.fixture
def router():
    """Provide a default-constructed QueryRouter to each test that requests it."""
    query_router = QueryRouter()
    return query_router


def test_simple_query_is_sync(router):
    """A trivial greeting against a single data source routes to SYNC mode."""
    request_context = {"required_sources": ["single"], "user_id": "user_1"}
    routed_mode, complexity = router.assess_complexity("你好", request_context)

    assert routed_mode == Mode.SYNC
    # 0.3 is the sync/async cut-over this suite assumes; keep in sync with the
    # router's configured threshold.
    assert complexity < 0.3


def test_complex_query_is_async(router):
    """A multi-department comparison spanning several sources routes to ASYNC mode."""
    question = "对比技术部和产品部的员工数量，分析人员配置建议"
    request_context = {"required_sources": ["hr", "finance"], "user_id": "user_1"}
    routed_mode, complexity = router.assess_complexity(question, request_context)

    assert routed_mode == Mode.ASYNC
    # Mirror of the simple-query test: at or above 0.3 is expected to be async.
    assert complexity >= 0.3


def test_score_data_sources(router):
    """Check the data-source sub-score for one, two and three required sources.

    The expected scores (0.0, 0.5, 0.8) increase with the number of sources.
    Scores are floats, so they are compared with ``pytest.approx`` rather than
    exact ``==``; otherwise a harmless rounding difference in how the router
    computes them would fail the test.
    """
    cases = [
        (["single"], 0.0),
        (["a", "b"], 0.5),
        (["a", "b", "c"], 0.8),
    ]
    for sources, expected in cases:
        actual = router._score_data_sources({"required_sources": sources})
        # Include the input in the failure message so the failing case is obvious.
        assert actual == pytest.approx(expected), f"sources={sources!r}"
