"""查询路由器 - 评估查询复杂度并决定处理模式"""
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Tuple

from src.config import get_settings


class Mode(Enum):
    """Processing mode selected by the query router.

    QueryRouter.assess_complexity picks ASYNC when the complexity score is at
    or above the configured threshold, and SYNC otherwise.
    """
    SYNC = "sync"
    ASYNC = "async"


@dataclass
class QueryRouterResult:
    """Routing result bundling a processing mode, a score and a time estimate.

    NOTE(review): nothing in the visible code constructs this class;
    QueryRouter.assess_complexity returns a plain (mode, score) tuple instead.
    """
    # Processing mode chosen for the query.
    mode: Mode
    # Complexity score; presumably the value produced by assess_complexity.
    score: float
    # Presumably the estimated processing time in seconds (inferred from the
    # field name only) -- TODO confirm with producers/consumers of this result.
    estimated_seconds: int


class QueryRouter:
    """Query router: scores a query's complexity and picks a processing mode.

    The complexity score is a weighted sum of four component scores
    (data sources, estimated API calls, reasoning depth, historical time).
    Weights and the SYNC/ASYNC threshold are read from application settings.
    """

    # Words that suggest the query fans out into several API calls.
    _API_CALL_INDICATORS = ('对比', '汇总', '多个', '所有', '分别', '各自')
    # Words that suggest the query needs multi-step reasoning.
    _REASONING_KEYWORDS = ('为什么', '建议', '分析', '预测', '原因', '趋势')

    def __init__(self):
        settings = get_settings()
        # Scores at or above this value are routed to Mode.ASYNC.
        self.threshold = settings.QUERY_ROUTER_COMPLEXITY_THRESHOLD
        # Mapping of component name -> weight. It must contain every component
        # name used in assess_complexity, otherwise a KeyError is raised there.
        self.weights = settings.QUERY_ROUTER_WEIGHTS

    def assess_complexity(self, query: str, context: dict) -> Tuple[Mode, float]:
        """Assess the complexity of a query and choose a processing mode.

        Args:
            query: User query text. ``None`` is treated as an empty string.
            context: Context information; only the ``'required_sources'`` key
                is read. ``None`` is treated as an empty dict.

        Returns:
            ``(mode, score)``. ``score`` is the weighted sum of the component
            scores. ``mode`` is ``Mode.ASYNC`` when ``score >= self.threshold``,
            otherwise ``Mode.SYNC``.

        Raises:
            KeyError: if ``self.weights`` lacks one of the component names.
        """
        # Normalize missing inputs so the helpers never see None.
        query = query or ''
        context = context or {}

        scores = {
            'data_sources': self._score_data_sources(context),
            'api_calls': self._estimate_api_calls(query),
            'reasoning_depth': self._estimate_reasoning(query),
            'historical_time': self._get_historical_avg(query),
        }

        # Weighted sum. Dict iteration follows insertion order, so terms are
        # accumulated in the same order as the explicit four-term formula.
        score = sum(value * self.weights[name] for name, value in scores.items())

        mode = Mode.ASYNC if score >= self.threshold else Mode.SYNC
        return mode, score

    def _score_data_sources(self, context: dict) -> float:
        """Score by the number of required data sources.

        0 or 1 source -> 0.0, 2 sources -> 0.5, 3 or more -> 0.8.
        """
        # `or []` also covers a key that is present but set to None.
        sources = (context or {}).get('required_sources') or []
        count = len(sources)
        # An empty/missing list means no multi-source fan-out, so it scores
        # like a single source instead of landing in the 3+ bucket.
        if count <= 1:
            return 0.0
        if count == 2:
            return 0.5
        return 0.8

    def _estimate_api_calls(self, query: str) -> float:
        """Estimate API-call fan-out from indicator words in the query.

        Counts how many *distinct* indicator words occur (each at most once):
        none -> 0.0, 1-2 -> 0.5, 3 or more -> 0.8.
        """
        matches = sum(1 for ind in self._API_CALL_INDICATORS if ind in query)
        if matches == 0:
            return 0.0
        if matches <= 2:
            return 0.5
        return 0.8

    def _estimate_reasoning(self, query: str) -> float:
        """Estimate reasoning depth.

        Returns 0.6 if any reasoning keyword occurs; otherwise 0.3 if the query
        contains an ASCII or full-width question mark; otherwise 0.0.
        """
        if any(kw in query for kw in self._REASONING_KEYWORDS):
            return 0.6
        if '?' in query or '？' in query:
            return 0.3
        return 0.0

    def _get_historical_avg(self, query: str) -> float:
        """Score based on historical average processing time.

        Placeholder: always returns a fixed medium score of 0.3 and ignores
        ``query``.
        """
        # TODO: load real statistics (planned source: Redis).
        return 0.3
