"""`LLMResult` class."""

from __future__ import annotations

from copy import deepcopy
from typing import Literal

from pydantic import BaseModel

from langchain_core.outputs.chat_generation import ChatGeneration, ChatGenerationChunk
from langchain_core.outputs.generation import Generation, GenerationChunk
from langchain_core.outputs.run_info import RunInfo


class LLMResult(BaseModel):
    """Container holding everything produced by a single LLM invocation.

    Chat models and plain LLMs both return an `LLMResult`. It bundles the generated
    outputs together with whatever extra, provider-specific information came back
    with them.
    """

    generations: list[
        list[Generation | ChatGeneration | GenerationChunk | ChatGenerationChunk]
    ]
    """The generated outputs, as a two-level list.

    The outer list has one entry per input prompt.

    Each inner list holds the candidate generations produced for that prompt.

    - **LLMs** return `list[list[Generation]]`.
    - **Chat models** return `list[list[ChatGeneration]]`.

    `ChatGeneration` subclasses `Generation` and adds a field carrying a structured
    chat message.
    """

    llm_output: dict | None = None
    """Free-form, provider-specific output.

    Providers may put any information they like in this dictionary. Its contents are
    not standardized: keys can differ between providers and change over time.

    Prefer the standardized fields available on AIMessage over relying on this
    field.
    """

    run: list[RunInfo] | None = None
    """Per-input metadata about the model call.

    See `langchain_core.outputs.run_info.RunInfo` for details.
    """

    type: Literal["LLMResult"] = "LLMResult"
    """Discriminator used exclusively for serialization purposes."""

    def flatten(self) -> list[LLMResult]:
        """Split this result into one `LLMResult` per input prompt.

        Every entry of `generations` (the candidate list for one prompt) is wrapped
        in its own `LLMResult`. Only the first of those results keeps `llm_output`
        untouched; each later one receives a deep copy whose `"token_usage"` entry
        is replaced by an empty dict, so token usage is not over-counted
        downstream. When `llm_output` is `None` it stays `None` for every result.

        Returns:
            List of `LLMResult` objects, one per entry of `generations`, each
                holding that entry as its only element of `generations`.
        """
        flattened: list[LLMResult] = []
        for index, candidates in enumerate(self.generations):
            if index == 0 or self.llm_output is None:
                # The first result carries the original output unchanged; a
                # missing output is passed through as-is for all results.
                output = self.llm_output
            else:
                # Avoid double counting tokens in OpenAICallback
                output = deepcopy(self.llm_output)
                output["token_usage"] = {}
            flattened.append(LLMResult(generations=[candidates], llm_output=output))
        return flattened

    def __eq__(self, other: object) -> bool:
        """Compare two `LLMResult` objects while ignoring run metadata.

        Args:
            other: The object to compare against.

        Returns:
            `True` when both `generations` and `llm_output` match, `False` when
                they do not. `NotImplemented` is returned when `other` is not an
                `LLMResult`.
        """
        if not isinstance(other, LLMResult):
            return NotImplemented
        # `run` is deliberately not part of the comparison.
        same_generations = self.generations == other.generations
        return same_generations and self.llm_output == other.llm_output

    __hash__ = None  # type: ignore[assignment]
