from __future__ import annotations

import functools
import logging
from collections.abc import AsyncIterator, Mapping, Sequence
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
)

from typing_extensions import Self, TypedDict

from langsmith import client as ls_client
from langsmith import run_helpers
from langsmith._internal._orjson import dumps as _dumps
from langsmith.schemas import InputTokenDetails, UsageMetadata

if TYPE_CHECKING:
    import httpx
    from anthropic import Anthropic, AsyncAnthropic
    from anthropic.lib.streaming import AsyncMessageStream, MessageStream
    from anthropic.types import Completion, Message, MessageStreamEvent

C = TypeVar("C", bound=Union["Anthropic", "AsyncAnthropic", Any])
logger = logging.getLogger(__name__)


@functools.lru_cache
def _get_not_given() -> Optional[tuple[type, ...]]:
    try:
        from anthropic._types import NotGiven, Omit

        return (NotGiven, Omit)
    except ImportError:
        return None


def _strip_not_given(d: dict) -> dict:
    """Return a sanitized copy of call kwargs for tracing.

    - Drops values that are Anthropic ``NotGiven``/``Omit`` sentinels (when
      the SDK is importable).
    - Folds a top-level ``system`` prompt into ``messages`` as a leading
      ``{"role": "system"}`` entry so traced inputs read as one transcript.
    - Drops ``None`` values.

    Args:
        d: Keyword arguments passed to the Anthropic API call.

    Returns:
        A new dict; the input dict is never mutated.
    """
    # Copy first: ``d`` may be the caller's live kwargs dict. Without this,
    # the ``d.pop("system")`` / ``d["messages"] = ...`` below mutated the
    # original whenever the sentinel-stripping branch (which rebuilds ``d``
    # as a new dict) was skipped — e.g. when the anthropic sentinel types
    # could not be imported.
    d = dict(d)
    try:
        if not_given := _get_not_given():
            d = {
                k: v
                for k, v in d.items()
                if not any(isinstance(v, t) for t in not_given)
            }
    except Exception as e:
        logger.error(f"Error stripping NotGiven: {e}")

    if "system" in d:
        d["messages"] = [{"role": "system", "content": d["system"]}] + d.get(
            "messages", []
        )
        d.pop("system")
    return {k: v for k, v in d.items() if v is not None}


def _infer_ls_params(prepopulated_invocation_params: dict, kwargs: dict):
    """Derive the LangSmith ``ls_*`` metadata fields from call kwargs.

    Args:
        prepopulated_invocation_params: Invocation params supplied up front by
            the wrapper; merged under any params inferred from ``kwargs``.
        kwargs: The keyword arguments of the Anthropic API call.

    Returns:
        A dict of ``ls_provider`` / ``ls_model_*`` / ``ls_stop`` /
        ``ls_invocation_params`` fields consumed by LangSmith.
    """
    cleaned = _strip_not_given(kwargs)

    # Normalize a single stop string to a one-element list.
    stop = cleaned.get("stop")
    if stop and isinstance(stop, str):
        stop = [stop]

    # Only known, non-sensitive parameters are recorded — everything else
    # (messages, api keys, etc.) is deliberately excluded.
    allowed = frozenset(
        {
            "mcp_servers",
            "service_tier",
            "tool_choice",
            "top_k",
            "top_p",
            "stream",
            "thinking",
        }
    )
    safe_params = {k: v for k, v in cleaned.items() if k in allowed}

    return {
        "ls_provider": "anthropic",
        "ls_model_type": "chat",
        "ls_model_name": cleaned.get("model", None),
        "ls_temperature": cleaned.get("temperature", None),
        "ls_max_tokens": cleaned.get("max_tokens", None),
        "ls_stop": stop,
        "ls_invocation_params": {**prepopulated_invocation_params, **safe_params},
    }


@functools.lru_cache
def _get_sdk_accumulate_event() -> Optional[Callable]:
    try:
        from anthropic.lib.streaming._messages import accumulate_event

        return accumulate_event
    except ImportError:
        return None


def _create_usage_metadata(anthropic_token_usage: dict) -> UsageMetadata:
    """Translate an Anthropic ``usage`` payload into LangSmith ``UsageMetadata``.

    Anthropic reports cache tokens separately from ``input_tokens`` (they are
    additive, unlike OpenAI where cached tokens are a subset), so all cache
    token counts are folded back into the input and total counts to keep
    backend cost calculation correct.
    """
    details: dict = {}

    if cache_read := anthropic_token_usage.get("cache_read_input_tokens") or 0:
        details["cache_read"] = cache_read

    # Newer responses carry a per-TTL breakdown under ``cache_creation``;
    # older ones only have the aggregate ``cache_creation_input_tokens``.
    breakdown = anthropic_token_usage.get("cache_creation") or {}
    if breakdown:
        for key in ("ephemeral_5m_input_tokens", "ephemeral_1h_input_tokens"):
            if count := breakdown.get(key) or 0:
                details[key] = count
    else:
        if created := anthropic_token_usage.get("cache_creation_input_tokens") or 0:
            details["cache_creation"] = created

    input_tokens = (anthropic_token_usage.get("input_tokens") or 0) + sum(
        details.values()
    )
    output_tokens = anthropic_token_usage.get("output_tokens") or 0

    usage = UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=input_tokens + output_tokens,
    )
    if details:
        usage["input_token_details"] = InputTokenDetails(**details)
    return usage


def _message_to_outputs(message: Any) -> dict:
    """Convert an Anthropic Message into a flat outputs dict for tracing.

    Moves ``usage`` into a normalized ``usage_metadata`` field, drops the
    ``type`` discriminator, and — when tool-use blocks are present — flattens
    text blocks into a single string and re-expresses tool-use blocks in the
    OpenAI-style ``tool_calls`` shape.
    """
    outputs = message.model_dump()
    usage = outputs.pop("usage", None)
    if usage:
        outputs["usage_metadata"] = _create_usage_metadata(usage)
    outputs.pop("type", None)

    blocks = outputs.get("content") or []
    tool_uses = [
        b for b in blocks if isinstance(b, dict) and b.get("type") == "tool_use"
    ]
    if not tool_uses:
        return outputs

    text = "".join(
        b.get("text", "")
        for b in blocks
        if isinstance(b, dict) and b.get("type") == "text"
    )
    # Empty text collapses to None so the field doesn't show as "".
    outputs["content"] = text or None
    outputs["tool_calls"] = [
        {
            "id": b.get("id", f"call_{i}"),
            "type": "function",
            "index": i,
            "function": {
                "name": b.get("name", ""),
                "arguments": _dumps(b.get("input", {})).decode(),
            },
        }
        for i, b in enumerate(tool_uses)
    ]
    return outputs


def _reduce_chat_chunks(all_chunks: Sequence) -> dict:
    """Fold a sequence of streamed message events into one outputs dict.

    Uses the Anthropic SDK's ``accumulate_event`` helper to rebuild the final
    Message snapshot. Falls back to ``{"output": all_chunks}`` when the helper
    is unavailable, accumulation fails, or no snapshot was produced.
    """
    accumulate = _get_sdk_accumulate_event()
    if accumulate is None:
        return {"output": all_chunks}

    snapshot = None
    for event in all_chunks:
        try:
            snapshot = accumulate(event=event, current_snapshot=snapshot)
        except RuntimeError as e:
            logger.debug(f"Error accumulating event in Anthropic Wrapper: {e}")
            return {"output": all_chunks}

    if snapshot is None:
        return {"output": all_chunks}
    return _message_to_outputs(snapshot)


def _reduce_completions(all_chunks: list[Completion]) -> dict:
    all_content = []
    for chunk in all_chunks:
        content = chunk.completion
        if content is not None:
            all_content.append(content)
    content = "".join(all_content)
    if all_chunks:
        d = all_chunks[-1].model_dump()
        d["choices"] = [{"text": content}]
    else:
        d = {"choices": [{"text": content}]}

    return d


def _process_chat_completion(outputs: Any):
    """Normalize a traced messages.create result for LangSmith.

    Handles ``with_raw_response`` results: the SDK's LegacyAPIResponse
    wrapper exposes ``parse()`` which yields the underlying Message
    (see anthropics/anthropic-sdk-python _legacy_response.py#L102).
    Falls back to ``{"output": outputs}`` on any failure so tracing never
    breaks the user's call.
    """
    try:
        parse = getattr(outputs, "parse", None)
        if callable(parse):
            try:
                outputs = parse()
            except Exception:
                # Keep the unparsed object if parsing fails.
                pass
        return _message_to_outputs(outputs)
    except BaseException as e:
        # Deliberately broad: output processing must never propagate.
        logger.debug(f"Error processing chat completion: {e}")
        return {"output": outputs}


def _get_wrapper(
    original_create: Callable,
    name: str,
    reduce_fn: Callable,
    prepopulated_invocation_params: dict,
    tracing_extra: TracingExtra,
) -> Callable:
    """Wrap a sync or async ``create`` method with LangSmith tracing.

    Args:
        original_create: The SDK method being wrapped.
        name: Run name for the traced call.
        reduce_fn: Chunk reducer applied only when ``stream=True``.
        prepopulated_invocation_params: Params merged into each run's
            invocation params.
        tracing_extra: Extra tracing config forwarded to ``traceable``.

    Returns:
        An async wrapper when ``original_create`` is async, else a sync one.
    """

    def _build_decorator(call_kwargs: dict) -> Callable:
        # The decorator must be built per call: whether a reducer applies
        # depends on the caller's ``stream`` flag.
        return run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if call_kwargs.get("stream") else None,
            process_inputs=_strip_not_given,
            process_outputs=_process_chat_completion,
            _invocation_params_fn=functools.partial(
                _infer_ls_params, prepopulated_invocation_params
            ),
            **tracing_extra,
        )

    @functools.wraps(original_create)
    def create(*args, **kwargs):
        return _build_decorator(kwargs)(original_create)(*args, **kwargs)

    @functools.wraps(original_create)
    async def acreate(*args, **kwargs):
        return await _build_decorator(kwargs)(original_create)(*args, **kwargs)

    if run_helpers.is_async(original_create):
        return acreate
    return create


def _get_stream_wrapper(
    original_stream: Callable,
    name: str,
    prepopulated_invocation_params: dict,
    tracing_extra: TracingExtra,
) -> Callable:
    """Create a wrapper for Anthropic's streaming context manager.

    Returns a replacement class for ``client.messages.stream`` that proxies
    the SDK's (Async)MessageStreamManager / (Async)MessageStream objects
    while recording a traced LLM run: event iteration is traced with a chunk
    reducer, and ``text_stream`` is traced with outputs taken from the final
    accumulated message.

    Args:
        original_stream: The bound ``messages.stream`` method being wrapped.
        name: Run name for the traced call.
        prepopulated_invocation_params: Params merged into each run's
            invocation params.
        tracing_extra: Extra tracing config forwarded to ``traceable``.
    """
    # NOTE(review): async detection is a heuristic on the repr of the bound
    # method (e.g. "AsyncMessages.stream") — confirm it holds across SDK
    # versions.
    is_async = "async" in str(original_stream).lower()
    # Used for raw event iteration: streamed events are folded into a final
    # message via _reduce_chat_chunks.
    configured_traceable = run_helpers.traceable(
        name=name,
        reduce_fn=_reduce_chat_chunks,
        run_type="llm",
        process_inputs=_strip_not_given,
        _invocation_params_fn=functools.partial(
            _infer_ls_params, prepopulated_invocation_params
        ),
        **tracing_extra,
    )
    # Used for text_stream: no reducer — the run's outputs are assigned
    # explicitly from the final message inside the generator.
    configured_traceable_text = run_helpers.traceable(
        name=name,
        run_type="llm",
        process_inputs=_strip_not_given,
        process_outputs=_process_chat_completion,
        _invocation_params_fn=functools.partial(
            _infer_ls_params, prepopulated_invocation_params
        ),
        **tracing_extra,
    )

    if is_async:

        class AsyncMessageStreamWrapper:
            """Traced proxy around the SDK's AsyncMessageStream."""

            def __init__(
                self,
                wrapped: AsyncMessageStream,
                **kwargs,
            ) -> None:
                # ``kwargs`` are the original stream() call arguments; they
                # are replayed into the traceable so the run records inputs.
                self._wrapped = wrapped
                self._kwargs = kwargs

            @property
            def text_stream(self):
                # Traced generator: yields text chunks, then fetches the
                # final accumulated message and writes its outputs (and
                # usage metadata) onto the current run tree.
                @configured_traceable_text
                async def _text_stream(**_):
                    async for chunk in self._wrapped.text_stream:
                        yield chunk
                    run_tree = run_helpers.get_current_run_tree()
                    final_message = await self._wrapped.get_final_message()
                    outputs = _message_to_outputs(final_message)
                    run_tree.outputs = outputs
                    if usage := outputs.get("usage_metadata"):
                        run_tree.metadata["usage_metadata"] = usage

                return _text_stream(**self._kwargs)

            @property
            def response(self) -> httpx.Response:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id

            async def __anext__(self) -> MessageStreamEvent:
                # NOTE(review): each __anext__ call builds a fresh traced
                # iterator over the same underlying stream; repeated direct
                # __anext__ calls advance the stream but create separate
                # traced runs — confirm this is intended.
                aiter = self.__aiter__()
                return await aiter.__anext__()

            async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
                # The traceable wraps the wrapped stream's iterator so the
                # yielded events are reduced into the run's outputs.
                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__aiter__()

                async for chunk in traced_iter(**self._kwargs):
                    yield chunk

            async def __aenter__(self) -> Self:
                await self._wrapped.__aenter__()
                return self

            async def __aexit__(self, *exc) -> None:
                await self._wrapped.__aexit__(*exc)

            async def close(self) -> None:
                await self._wrapped.close()

            async def get_final_message(self) -> Message:
                return await self._wrapped.get_final_message()

            async def get_final_text(self) -> str:
                return await self._wrapped.get_final_text()

            async def until_done(self) -> None:
                await self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class AsyncMessagesStreamManagerWrapper:
            """Async context manager that defers the real stream() call to
            __aenter__ and yields a traced stream wrapper."""

            def __init__(self, **kwargs):
                self._kwargs = kwargs

            async def __aenter__(self):
                self._manager = original_stream(**self._kwargs)
                stream = await self._manager.__aenter__()
                return AsyncMessageStreamWrapper(stream, **self._kwargs)

            async def __aexit__(self, *exc):
                await self._manager.__aexit__(*exc)

        return AsyncMessagesStreamManagerWrapper
    else:

        class MessageStreamWrapper:
            """Traced proxy around the SDK's MessageStream."""

            def __init__(
                self,
                wrapped: MessageStream,
                **kwargs,
            ) -> None:
                # ``kwargs`` are the original stream() call arguments; they
                # are replayed into the traceable so the run records inputs.
                self._wrapped = wrapped
                self._kwargs = kwargs

            @property
            def response(self) -> Any:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id  # type: ignore[no-any-return]

            @property
            def text_stream(self):
                # Traced generator: yields text chunks, then fetches the
                # final accumulated message and writes its outputs (and
                # usage metadata) onto the current run tree.
                @configured_traceable_text
                def _text_stream(**_):
                    yield from self._wrapped.text_stream
                    run_tree = run_helpers.get_current_run_tree()
                    final_message = self._wrapped.get_final_message()
                    outputs = _message_to_outputs(final_message)
                    run_tree.outputs = outputs
                    if usage := outputs.get("usage_metadata"):
                        run_tree.metadata["usage_metadata"] = usage

                return _text_stream(**self._kwargs)

            def __next__(self) -> MessageStreamEvent:
                # NOTE(review): each __next__ call builds a fresh traced
                # iterator over the same underlying stream — confirm repeated
                # direct __next__ calls behave as intended.
                return self.__iter__().__next__()

            def __iter__(self):
                # The traceable wraps the wrapped stream's iterator so the
                # yielded events are reduced into the run's outputs.
                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__iter__()

                return traced_iter(**self._kwargs)

            def __enter__(self) -> Self:
                self._wrapped.__enter__()
                return self

            def __exit__(self, *exc) -> None:
                self._wrapped.__exit__(*exc)

            def close(self) -> None:
                self._wrapped.close()

            def get_final_message(self) -> Message:
                return self._wrapped.get_final_message()

            def get_final_text(self) -> str:
                return self._wrapped.get_final_text()

            def until_done(self) -> None:
                return self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class MessagesStreamManagerWrapper:
            """Context manager that defers the real stream() call to
            __enter__ and yields a traced stream wrapper."""

            def __init__(self, **kwargs):
                self._kwargs = kwargs

            def __enter__(self):
                self._manager = original_stream(**self._kwargs)
                return MessageStreamWrapper(self._manager.__enter__(), **self._kwargs)

            def __exit__(self, *exc):
                self._manager.__exit__(*exc)

        return MessagesStreamManagerWrapper


class TracingExtra(TypedDict, total=False):
    """Optional LangSmith tracing configuration forwarded to ``traceable``."""

    # Extra metadata attached to every traced run.
    metadata: Optional[Mapping[str, Any]]
    # Tags attached to every traced run.
    tags: Optional[list[str]]
    # LangSmith client used to post runs; defaults to the global client.
    client: Optional[ls_client.Client]


def wrap_anthropic(
    client: C,
    *,
    tracing_extra: Optional[TracingExtra] = None,
    chat_name: str = "ChatAnthropic",
    completions_name: str = "Anthropic",
) -> C:
    """Patch an Anthropic client so its API calls are traced in LangSmith.

    Wraps ``messages.create``, ``messages.stream``, ``completions.create``,
    and (when present) ``beta.messages.create`` in place.

    Args:
        client: The client to patch.
        tracing_extra: Extra tracing information.
        chat_name: The run name for the messages endpoint.
        completions_name: The run name for the completions endpoint.

    Returns:
        The patched client.

    Example:
        ```python
        import anthropic
        from langsmith import wrappers

        client = wrappers.wrap_anthropic(anthropic.Anthropic())

        # Use Anthropic client same as you normally would:
        system = "You are a helpful assistant."
        messages = [
            {
                "role": "user",
                "content": "What physics breakthroughs do you predict will happen by 2300?",
            }
        ]
        completion = client.messages.create(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        )
        print(completion.content)

        # With raw response to access headers:
        raw_response = client.messages.with_raw_response.create(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        )
        print(raw_response.headers)  # Access HTTP headers
        message = raw_response.parse()  # Get parsed response

        # You can also use the streaming context manager:
        with client.messages.stream(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        ) as stream:
            for text in stream.text_stream:
                print(text, end="", flush=True)
            message = stream.get_final_message()
        ```
    """  # noqa: E501
    extra: dict = dict(tracing_extra or {})

    # ``ls_invocation_params`` rides inside metadata; pull it out so it can
    # be merged into each run's invocation params rather than stored as
    # plain metadata.
    metadata = dict(extra.pop("metadata", None) or {})
    prepopulated_invocation_params = metadata.pop("ls_invocation_params", {})
    tracing_extra_rest: TracingExtra = extra  # type: ignore[assignment]
    if metadata:
        tracing_extra_rest["metadata"] = metadata  # type: ignore[typeddict-item]

    client.messages.create = _get_wrapper(  # type: ignore[method-assign]
        client.messages.create,
        chat_name,
        _reduce_chat_chunks,
        prepopulated_invocation_params,
        tracing_extra_rest,
    )
    client.messages.stream = _get_stream_wrapper(  # type: ignore[method-assign]
        client.messages.stream,
        chat_name,
        prepopulated_invocation_params,
        tracing_extra_rest,
    )
    client.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.completions.create,
        completions_name,
        _reduce_completions,
        prepopulated_invocation_params,
        tracing_extra_rest,
    )

    # The beta messages endpoint may be absent on older SDK versions.
    beta_messages = getattr(getattr(client, "beta", None), "messages", None)
    if beta_messages is not None and hasattr(beta_messages, "create"):
        beta_messages.create = _get_wrapper(  # type: ignore[method-assign]
            beta_messages.create,  # type: ignore
            chat_name,
            _reduce_chat_chunks,
            prepopulated_invocation_params,
            tracing_extra_rest,
        )
    return client
