# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
from collections.abc import Iterable
import itertools
from typing import Any, Iterable, Union, Mapping, Optional
from typing_extensions import TypedDict

import google.ai.generativelanguage as glm
from google.generativeai import protos

from google.generativeai.client import (
    get_default_generative_client,
    get_default_generative_async_client,
)
from google.generativeai.types import model_types
from google.generativeai.types import helper_types
from google.generativeai.types import safety_types
from google.generativeai.types import content_types
from google.generativeai.types import retriever_types
from google.generativeai.types.retriever_types import MetadataFilter

# Model name used by the answer helpers in this module when the caller does not pass one.
DEFAULT_ANSWER_MODEL = "models/aqa"

# Short alias for the answer-style enum nested inside the request proto.
AnswerStyle = protos.GenerateAnswerRequest.AnswerStyle

# Input types accepted by `to_answer_style`: an integer, a string name, or an enum member.
AnswerStyleOptions = Union[int, str, AnswerStyle]

# Lookup table from every accepted spelling of an answer style to its enum member.
# String keys are all lowercase because `to_answer_style` lowercases string input
# before looking it up here.
_ANSWER_STYLES: dict[AnswerStyleOptions, AnswerStyle] = {
    AnswerStyle.ANSWER_STYLE_UNSPECIFIED: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    0: AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    "answer_style_unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    "unspecified": AnswerStyle.ANSWER_STYLE_UNSPECIFIED,
    AnswerStyle.ABSTRACTIVE: AnswerStyle.ABSTRACTIVE,
    1: AnswerStyle.ABSTRACTIVE,
    "answer_style_abstractive": AnswerStyle.ABSTRACTIVE,
    "abstractive": AnswerStyle.ABSTRACTIVE,
    AnswerStyle.EXTRACTIVE: AnswerStyle.EXTRACTIVE,
    2: AnswerStyle.EXTRACTIVE,
    "answer_style_extractive": AnswerStyle.EXTRACTIVE,
    "extractive": AnswerStyle.EXTRACTIVE,
    AnswerStyle.VERBOSE: AnswerStyle.VERBOSE,
    3: AnswerStyle.VERBOSE,
    "answer_style_verbose": AnswerStyle.VERBOSE,
    "verbose": AnswerStyle.VERBOSE,
}


def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle:
    """Resolve an answer-style option to its `AnswerStyle` enum member.

    Strings are matched case-insensitively; any other value is looked up as is.

    Args:
        x: An `AnswerStyle` member, an integer, or a string name.

    Returns:
        The matching `AnswerStyle` member.

    Raises:
        KeyError: If `x` is not one of the keys of `_ANSWER_STYLES`.
    """
    lookup_key = x.lower() if isinstance(x, str) else x
    return _ANSWER_STYLES[lookup_key]


# One grounding passage: a ready-made proto, an `(id, content)` pair, or bare content.
# NOTE: there must be no trailing comma after the `Union[...]` -- a parenthesised
# expression followed by a comma is a one-element tuple, not a type alias.
GroundingPassageOptions = Union[
    protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType
]

# A collection of grounding passages: a ready-made proto, an iterable of single
# passages, or a mapping from passage id to content.
GroundingPassagesOptions = Union[
    protos.GroundingPassages,
    Iterable[GroundingPassageOptions],
    Mapping[str, content_types.ContentType],
]


def _make_grounding_passages(source: GroundingPassagesOptions) -> protos.GroundingPassages:
    """
    Converts the `source` into a `protos.GroundingPassages`. A `GroundingPassages` contains a list of
    `protos.GroundingPassage` objects, which each contain a `protos.Content` and a string `id`.

    Items of `source` are handled as follows:

    * a `protos.GroundingPassage` is used unchanged;
    * a tuple is read as an `(id, content)` pair;
    * anything else is treated as content and gets its position in `source` as its id.

    A mapping is read as `{id: content}`.

    Args:
        source: A `protos.GroundingPassages` (returned as is), an iterable of passage-like items,
            or a mapping of id to content, that will be converted to `protos.GroundingPassages`.

    Return:
        `protos.GroundingPassages` to be passed into `protos.GenerateAnswerRequest`.

    Raises:
        TypeError: If `source` is neither a `protos.GroundingPassages` nor an iterable.
        ValueError: If an item of `source` is a tuple that does not have exactly 2 elements.
    """
    if isinstance(source, protos.GroundingPassages):
        return source

    if not isinstance(source, Iterable):
        raise TypeError(
            f"Invalid input: The 'source' argument must be an instance of 'GroundingPassagesOptions'. Received a '{type(source).__name__}' object instead."
        )

    # NOTE(review): a bare `str` is itself iterable, so it passes the check above and
    # is split into one passage per character -- confirm whether that is intended.

    if isinstance(source, Mapping):
        # Iterate `(id, content)` pairs so mappings share the tuple branch below.
        source = source.items()

    passages = []
    for position, item in enumerate(source):
        if isinstance(item, protos.GroundingPassage):
            passages.append(item)
        elif isinstance(item, tuple):
            if len(item) != 2:
                raise ValueError(
                    f"Invalid input: A tuple passage must be an '(id, content)' pair. "
                    f"Received a tuple of length {len(item)} at position {position}."
                )
            passage_id, content = item
            passages.append({"id": passage_id, "content": content_types.to_content(content)})
        else:
            # No explicit id was given: fall back to the item's position.
            passages.append({"id": str(position), "content": content_types.to_content(item)})

    return protos.GroundingPassages(passages=passages)


# Ways to identify a retrieval source: a name string, or a Corpus/Document object
# (from either `retriever_types` or `protos`) whose `.name` is used.
SourceNameType = Union[
    str, retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document
]


# Dict form of a semantic-retriever configuration. `_make_semantic_retriever_config`
# resolves "source" to a name string, fills in or converts "query", and passes the
# dict to `protos.SemanticRetrieverConfig`.
class SemanticRetrieverConfigDict(TypedDict):
    # Corpus or document to retrieve from (name string or object).
    source: SourceNameType
    # Query for retrieval; when it is None the caller-supplied query is used instead.
    query: content_types.ContentsType
    # The remaining keys are not read in this module.
    # NOTE(review): presumably they mirror the proto fields of the same name -- confirm.
    metadata_filter: Optional[Iterable[MetadataFilter]]
    max_chunks_count: Optional[int]
    minimum_relevance_score: Optional[float]


# Everything accepted as a `semantic_retriever` argument: a bare source (name string,
# Corpus or Document), a config dict, or a ready-made `protos.SemanticRetrieverConfig`.
SemanticRetrieverConfigOptions = Union[
    SourceNameType,
    SemanticRetrieverConfigDict,
    protos.SemanticRetrieverConfig,
]


def _maybe_get_source_name(source) -> str | None:
    if isinstance(source, str):
        return source
    elif isinstance(
        source, (retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document)
    ):
        return source.name
    else:
        return None


def _make_semantic_retriever_config(
    source: SemanticRetrieverConfigOptions,
    query: content_types.ContentsType,
) -> protos.SemanticRetrieverConfig:
    """Builds a `protos.SemanticRetrieverConfig` from any accepted `source`.

    Args:
        source: A ready-made `protos.SemanticRetrieverConfig` (returned as is), a source
            name string, a Corpus/Document object, or a config dict with a "source" key.
        query: Query to use when `source` does not provide a "query" of its own
            (missing or `None`).

    Returns:
        The resulting `protos.SemanticRetrieverConfig`.

    Raises:
        TypeError: If `source` is none of the accepted types.
        KeyError: If `source` is a dict without a "source" key.
    """
    if isinstance(source, protos.SemanticRetrieverConfig):
        return source

    name = _maybe_get_source_name(source)
    if name is not None:
        config = {"source": name}
    elif isinstance(source, dict):
        # Work on a shallow copy so the caller's dict is not modified.
        config = dict(source)
        config["source"] = _maybe_get_source_name(config["source"])
    else:
        raise TypeError(
            f"Invalid input: Failed to create a 'protos.SemanticRetrieverConfig' from the provided source. "
            f"Received type: {type(source).__name__}, "
            f"Received value: {source}"
        )

    # Use `.get`: a config built from a bare name (or a dict that omits "query") has
    # no "query" entry, and indexing it would raise `KeyError`.
    own_query = config.get("query")
    if own_query is None:
        config["query"] = query
    elif isinstance(own_query, str):
        config["query"] = content_types.to_content(own_query)

    return protos.SemanticRetrieverConfig(config)


def _make_generate_answer_request(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyleOptions | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
) -> protos.GenerateAnswerRequest:
    """
    Constructs a `protos.GenerateAnswerRequest` by normalizing the input parameters for the API
    call that generates a grounded answer from the model.

    Args:
        model: Name of the model used to generate the grounded response.
        contents: Content of the current conversation with the model. For single-turn query, this is a
            single question to answer. For multi-turn queries, this is a repeated field that contains
            conversation history and the last `Content` in the list containing the question.
        inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs,
            or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retriever`,
            one must be set, but not both.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with
             `inline_passages`, one must be set, but not both.
        answer_style: Style for grounded answers, as an `AnswerStyle`, an integer or a string name.
        safety_settings: Safety settings for generated output.
        temperature: The temperature for randomness in the output.

    Returns:
        The populated `protos.GenerateAnswerRequest`.

    Raises:
        ValueError: If both `inline_passages` and `semantic_retriever` are set.
        TypeError: If neither `inline_passages` nor `semantic_retriever` is set.
    """
    model = model_types.make_model_name(model)

    contents = content_types.to_contents(contents)

    if safety_settings:
        safety_settings = safety_types.normalize_safety_settings(safety_settings)

    if inline_passages is not None and semantic_retriever is not None:
        raise ValueError(
            "Invalid configuration: Please set either 'inline_passages' or 'semantic_retriever', but not both. "
            f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}."
        )
    elif inline_passages is not None:
        inline_passages = _make_grounding_passages(inline_passages)
    elif semantic_retriever is not None:
        # The last content entry carries the question, so it is the fallback retrieval query.
        semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1])
    else:
        raise TypeError(
            "Invalid configuration: Either 'inline_passages' or 'semantic_retriever' must be provided, but currently both are 'None'. "
            f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}."
        )

    if answer_style:
        answer_style = to_answer_style(answer_style)

    return protos.GenerateAnswerRequest(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )


def generate_answer(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyleOptions | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
    client: glm.GenerativeServiceClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
):
    """Calls the GenerateAnswer API and returns the response.

    You can pass a literal list of text chunks:

    >>> from google.generativeai import answer
    >>> answer.generate_answer(
    ...     contents=question,
    ...     inline_passages=splitter.split(document)
    ... )

    Or pass a reference to a retriever Document or Corpus:

    >>> from google.generativeai import answer
    >>> from google.generativeai import retriever
    >>> my_corpus = retriever.get_corpus('my_corpus')
    >>> answer.generate_answer(
    ...     contents=question,
    ...     semantic_retriever=my_corpus
    ... )


    Args:
        model: Which model to call, as a string or a `types.Model`.
        contents: The question to be answered by the model, grounded in the
                provided source.
        inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs,
            or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retriever`,
            one must be set, but not both.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with
             `inline_passages`, one must be set, but not both.
        answer_style: Style in which the grounded answer should be returned, as an `AnswerStyle`,
            an integer or a string name.
        safety_settings: Safety settings for generated output. Defaults to None.
        temperature: Controls the randomness of the output.
        client: If you're not relying on a default client, you pass a `glm.GenerativeServiceClient` instead.
        request_options: Options for the request; they are passed to the client call as
            keyword arguments.

    Returns:
        The response returned by the client's `generate_answer` call, unmodified.

    Raises:
        ValueError: If both `inline_passages` and `semantic_retriever` are set.
        TypeError: If neither `inline_passages` nor `semantic_retriever` is set.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_generative_client()

    request = _make_generate_answer_request(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )

    response = client.generate_answer(request, **request_options)

    return response


async def generate_answer_async(
    *,
    model: model_types.AnyModelNameOptions = DEFAULT_ANSWER_MODEL,
    contents: content_types.ContentsType,
    inline_passages: GroundingPassagesOptions | None = None,
    semantic_retriever: SemanticRetrieverConfigOptions | None = None,
    answer_style: AnswerStyleOptions | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    temperature: float | None = None,
    client: glm.GenerativeServiceAsyncClient | None = None,
    request_options: helper_types.RequestOptionsType | None = None,
):
    """
    Calls the GenerateAnswer API asynchronously and returns the response.

    Args:
        model: Which model to call, as a string or a `types.Model`.
        contents: The question to be answered by the model, grounded in the
                provided source.
        inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs,
            or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retriever`,
            one must be set, but not both.
        semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with
             `inline_passages`, one must be set, but not both.
        answer_style: Style in which the grounded answer should be returned, as an `AnswerStyle`,
            an integer or a string name.
        safety_settings: Safety settings for generated output. Defaults to None.
        temperature: Controls the randomness of the output.
        client: If you're not relying on a default client, you pass a `glm.GenerativeServiceAsyncClient` instead.
            Its `generate_answer` call is awaited, so it must be an async client.
        request_options: Options for the request; they are passed to the client call as
            keyword arguments.

    Returns:
        The response returned by the awaited client `generate_answer` call, unmodified.

    Raises:
        ValueError: If both `inline_passages` and `semantic_retriever` are set.
        TypeError: If neither `inline_passages` nor `semantic_retriever` is set.
    """
    if request_options is None:
        request_options = {}

    if client is None:
        client = get_default_generative_async_client()

    request = _make_generate_answer_request(
        model=model,
        contents=contents,
        inline_passages=inline_passages,
        semantic_retriever=semantic_retriever,
        safety_settings=safety_settings,
        temperature=temperature,
        answer_style=answer_style,
    )

    response = await client.generate_answer(request, **request_options)

    return response
