# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parses an LLM command line."""
from __future__ import annotations

import argparse
import shlex
import sys
from typing import AbstractSet, Any, Callable, MutableMapping, Sequence

from google.generativeai.notebook import argument_parser
from google.generativeai.notebook import flag_def
from google.generativeai.notebook import input_utils
from google.generativeai.notebook import model_registry
from google.generativeai.notebook import output_utils
from google.generativeai.notebook import parsed_args_lib
from google.generativeai.notebook import post_process_utils
from google.generativeai.notebook import py_utils
from google.generativeai.notebook import sheets_utils
from google.generativeai.notebook.lib import llm_function
from google.generativeai.notebook.lib import llmfn_inputs_source
from google.generativeai.notebook.lib import llmfn_outputs
from google.generativeai.notebook.lib import model as model_lib


_MIN_CANDIDATE_COUNT = 1
_MAX_CANDIDATE_COUNT = 8


def _validate_input_source_against_placeholders(
    source: llmfn_inputs_source.LLMFnInputsSource,
    placeholders: AbstractSet[str],
) -> None:
    for inputs in source.to_normalized_inputs():
        for keyword in placeholders:
            if keyword not in inputs:
                raise ValueError('Placeholder "{}" not found in input'.format(keyword))


def _get_resolve_input_from_py_var_fn(
    placeholders: AbstractSet[str] | None,
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
    def _fn(var_name: str) -> llmfn_inputs_source.LLMFnInputsSource:
        source = input_utils.get_inputs_source_from_py_var(var_name)
        if placeholders:
            _validate_input_source_against_placeholders(source, placeholders)
        return source

    return _fn


def _resolve_compare_fn_var(
    name: str,
) -> tuple[str, parsed_args_lib.TextResultCompareFn]:
    """Resolves a value passed into --compare_fn."""
    fn = py_utils.get_py_var(name)
    if not isinstance(fn, Callable):
        raise ValueError('Variable "{}" does not contain a Callable object'.format(name))

    return name, fn


def _resolve_ground_truth_var(name: str) -> Sequence[str]:
    """Resolves a value passed into --ground_truth."""
    value = py_utils.get_py_var(name)

    # "str" and "bytes" are also Sequences but we want an actual Sequence of
    # strings, like a list.
    if not isinstance(value, Sequence) or isinstance(value, str) or isinstance(value, bytes):
        raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
    for x in value:
        if not isinstance(x, str):
            raise ValueError('Variable "{}" does not contain a Sequence of strings'.format(name))
    return value


def _get_resolve_sheets_inputs_fn(
    placeholders: AbstractSet[str] | None,
) -> Callable[[str], llmfn_inputs_source.LLMFnInputsSource]:
    def _fn(value: str) -> llmfn_inputs_source.LLMFnInputsSource:
        sheets_id = sheets_utils.get_sheets_id_from_str(value)
        source = sheets_utils.SheetsInputs(sheets_id)
        if placeholders:
            _validate_input_source_against_placeholders(source, placeholders)
        return source

    return _fn


def _resolve_sheets_outputs(value: str) -> llmfn_outputs.LLMFnOutputsSink:
    sheets_id = sheets_utils.get_sheets_id_from_str(value)
    return sheets_utils.SheetsOutputs(sheets_id)


def _add_model_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags that are related to model selection and config."""
    flag_def.EnumFlagDef(
        name="model_type",
        short_name="mt",
        enum_type=model_registry.ModelName,
        default_value=model_registry.ModelRegistry.DEFAULT_MODEL,
        help_msg="The type of model to use.",
    ).add_argument_to_parser(parser)

    def _check_is_greater_than_or_equal_to_zero(x: float) -> float:
        if x < 0:
            raise ValueError("Value should be greater than or equal to zero, got {}".format(x))
        return x

    flag_def.SingleValueFlagDef(
        name="temperature",
        short_name="t",
        parse_type=float,
        # Use None for default value to indicate that this will use the default
        # value in Text service.
        default_value=None,
        parse_to_dest_type_fn=_check_is_greater_than_or_equal_to_zero,
        help_msg=(
            "Controls the randomness of the output. Must be positive. Typical"
            " values are in the range: [0.0, 1.0]. Higher values produce a more"
            " random and varied response. A temperature of zero will be"
            " deterministic."
        ),
    ).add_argument_to_parser(parser)

    flag_def.SingleValueFlagDef(
        name="model",
        short_name="m",
        default_value=None,
        help_msg=(
            "The name of the model to use. If not provided, a default model will" " be used."
        ),
    ).add_argument_to_parser(parser)

    def _check_candidate_count_range(x: Any) -> int:
        if x < _MIN_CANDIDATE_COUNT or x > _MAX_CANDIDATE_COUNT:
            raise ValueError(
                "Value should be in the range [{}, {}], got {}".format(
                    _MIN_CANDIDATE_COUNT, _MAX_CANDIDATE_COUNT, x
                )
            )
        return int(x)

    flag_def.SingleValueFlagDef(
        name="candidate_count",
        short_name="cc",
        parse_type=int,
        # Use None for default value to indicate that this will use the default
        # value in Text service.
        default_value=None,
        parse_to_dest_type_fn=_check_candidate_count_range,
        help_msg="The number of candidates to produce.",
    ).add_argument_to_parser(parser)

    flag_def.BooleanFlagDef(
        name="unique",
        help_msg="Whether to dedupe candidates returned by the model.",
    ).add_argument_to_parser(parser)


def _add_input_flags(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags to read inputs from a Python variable or Sheets."""
    flag_def.MultiValuesFlagDef(
        name="inputs",
        short_name="i",
        dest_type=llmfn_inputs_source.LLMFnInputsSource,
        parse_to_dest_type_fn=_get_resolve_input_from_py_var_fn(placeholders),
        help_msg=(
            "Optional names of Python variables containing inputs to use to"
            " instantiate a prompt. The variable must be either: a dictionary"
            " {'key1': ['val1', 'val2'] ...}, or an instance of LLMFnInputsSource"
            " such as SheetsInput."
        ),
    ).add_argument_to_parser(parser)

    flag_def.MultiValuesFlagDef(
        name="sheets_input_names",
        short_name="si",
        dest_type=llmfn_inputs_source.LLMFnInputsSource,
        parse_to_dest_type_fn=_get_resolve_sheets_inputs_fn(placeholders),
        help_msg=(
            "Optional names of Google Sheets to read inputs from. This is"
            " equivalent to using --inputs with the names of variables that are"
            " instances of SheetsInputs, just more convenient to use."
        ),
    ).add_argument_to_parser(parser)


def _add_output_flags(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags to write outputs to a Python variable."""
    flag_def.MultiValuesFlagDef(
        name="outputs",
        short_name="o",
        dest_type=llmfn_outputs.LLMFnOutputsSink,
        parse_to_dest_type_fn=output_utils.get_outputs_sink_from_py_var,
        help_msg=(
            "Optional names of Python variables to output to. If the Python"
            " variable has not already been defined, it will be created. If the"
            " variable is defined and is an instance of LLMFnOutputsSink, the"
            " outputs will be written through the sink's write_outputs() method."
        ),
    ).add_argument_to_parser(parser)

    flag_def.MultiValuesFlagDef(
        name="sheets_output_names",
        short_name="so",
        dest_type=llmfn_outputs.LLMFnOutputsSink,
        parse_to_dest_type_fn=_resolve_sheets_outputs,
        help_msg=(
            "Optional names of Google Sheets to write inputs to. This is"
            " equivalent to using --outputs with the names of variables that are"
            " instances of SheetsOutputs, just more convenient to use."
        ),
    ).add_argument_to_parser(parser)


def _add_compare_flags(
    parser: argparse.ArgumentParser,
) -> None:
    flag_def.MultiValuesFlagDef(
        name="compare_fn",
        dest_type=tuple,
        parse_to_dest_type_fn=_resolve_compare_fn_var,
        help_msg=(
            "An optional function that takes two inputs: (lhs_result, rhs_result)"
            " which are the results of the left- and right-hand side functions. "
            "Multiple comparison functions can be provided."
        ),
    ).add_argument_to_parser(parser)


def _add_eval_flags(
    parser: argparse.ArgumentParser,
) -> None:
    flag_def.SingleValueFlagDef(
        name="ground_truth",
        required=True,
        dest_type=Sequence,
        parse_to_dest_type_fn=_resolve_ground_truth_var,
        help_msg=(
            "A variable containing a Sequence of strings representing the ground"
            " truth that the output of this cell will be compared against. It"
            " should have the same number of entries as inputs."
        ),
    ).add_argument_to_parser(parser)


def _create_run_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the `run` command.

    `run` sends one or more prompts to a model.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the cell contents.
    """
    _add_model_flags(parser)
    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)


def _create_compile_parser(
    parser: argparse.ArgumentParser,
) -> None:
    """Adds flags for the compile command.

    `compile` "compiles" a prompt and model call into a callable function.

    Args:
      parser: The parser to which flags will be added.
    """

    # Add a positional argument for "compile_save_name".
    def _compile_save_name_fn(var_name: str) -> str:
        try:
            py_utils.validate_var_name(var_name)
        except ValueError as e:
            # Re-raise as ArgumentError to preserve the original error message.
            raise argparse.ArgumentError(None, "{}".format(e)) from e
        return var_name

    save_name_help = "The name of a Python variable to save the compiled function to."
    parser.add_argument("compile_save_name", help=save_name_help, type=_compile_save_name_fn)
    _add_model_flags(parser)


def _create_compare_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the compare command.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the compiled functions.
    """

    # Add positional arguments.
    def _resolve_llm_function_fn(
        var_name: str,
    ) -> tuple[str, llm_function.LLMFunction]:
        try:
            py_utils.validate_var_name(var_name)
        except ValueError as e:
            # Re-raise as ArgumentError to preserve the original error message.
            raise argparse.ArgumentError(None, "{}".format(e)) from e

        fn = py_utils.get_py_var(var_name)
        if not isinstance(fn, llm_function.LLMFunction):
            raise argparse.ArgumentError(
                None,
                '{} is not a function created with the "compile" command'.format(var_name),
            )
        return var_name, fn

    name_help = (
        "The name of a Python variable containing a function previously created"
        ' with the "compile" command.'
    )
    parser.add_argument("lhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn)
    parser.add_argument("rhs_name_and_fn", help=name_help, type=_resolve_llm_function_fn)

    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)
    _add_compare_flags(parser)


def _create_eval_parser(
    parser: argparse.ArgumentParser,
    placeholders: AbstractSet[str] | None,
) -> None:
    """Adds flags for the eval command.

    Args:
      parser: The parser to which flags will be added.
      placeholders: Placeholders from prompts in the cell contents.
    """
    _add_model_flags(parser)
    _add_input_flags(parser, placeholders)
    _add_output_flags(parser)
    _add_compare_flags(parser)
    _add_eval_flags(parser)


def _create_parser(
    placeholders: AbstractSet[str] | None,
) -> argparse.ArgumentParser:
    """Create the full parser."""
    system_name = "llm"
    description = "A system for interacting with LLMs."
    epilog = ""

    # Commands
    parser = argument_parser.ArgumentParser(
        prog=system_name,
        description=description,
        epilog=epilog,
    )
    subparsers = parser.add_subparsers(dest="cmd")
    _create_run_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.RUN_CMD.value),
        placeholders,
    )
    _create_compile_parser(subparsers.add_parser(parsed_args_lib.CommandName.COMPILE_CMD.value))
    _create_compare_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.COMPARE_CMD.value),
        placeholders,
    )
    _create_eval_parser(
        subparsers.add_parser(parsed_args_lib.CommandName.EVAL_CMD.value),
        placeholders,
    )
    return parser


def _validate_parsed_args(parsed_args: parsed_args_lib.ParsedArgs) -> None:
    # If candidate_count is not set (i.e. is None), assuming the default value
    # is 1.
    if parsed_args.unique and (
        parsed_args.model_args.candidate_count is None
        or parsed_args.model_args.candidate_count == 1
    ):
        print(
            '"--unique" works across candidates only: it should be used with'
            " --candidate_count set to a value greater-than one."
        )


class CmdLineParser:
    """Implementation of Magics command line parser."""

    # Commands
    DEFAULT_CMD = parsed_args_lib.CommandName.RUN_CMD

    # Post-processing operator.
    PIPE_OP = "|"

    @classmethod
    def _split_post_processing_tokens(
        cls,
        tokens: Sequence[str],
    ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
        """Splits inputs into the command and post processing tokens.

        The command is represented as a sequence of tokens.
        See comments on the PostProcessingTokens type alias.

        E.g. Given: "run --temperature 0.5 | add_score | to_lower_case"
        The command will be: ["run", "--temperature", "0.5"].
        The post processing tokens will be: [["add_score"], ["to_lower_case"]]

        Args:
          tokens: The command line tokens.

        Returns:
          A tuple of (command line, post processing tokens).
        """
        split_tokens = []
        start_idx: int | None = None
        for token_num, token in enumerate(tokens):
            if start_idx is None:
                start_idx = token_num
            if token == CmdLineParser.PIPE_OP:
                split_tokens.append(tokens[start_idx:token_num] if start_idx is not None else [])
                start_idx = None

        # Add the remaining tokens after the last PIPE_OP.
        split_tokens.append(tokens[start_idx:] if start_idx is not None else [])

        return split_tokens[0], split_tokens[1:]

    @classmethod
    def _tokenize_line(
        cls, line: str
    ) -> tuple[Sequence[str], parsed_args_lib.PostProcessingTokens]:
        """Parses `line` and returns command line and post processing tokens."""
        # Check to make sure there is a command at the start. If not, add the
        # default command to the list of tokens.
        tokens = shlex.split(line)
        if not tokens:
            tokens = [CmdLineParser.DEFAULT_CMD.value]
        first_token = tokens[0]
        # Add default command if the first token is not the help token.
        if not first_token[0].isalpha() and first_token not in ["-h", "--help"]:
            tokens = [CmdLineParser.DEFAULT_CMD.value] + tokens
        # Split line into tokens and post-processing
        return CmdLineParser._split_post_processing_tokens(tokens)

    @classmethod
    def _get_model_args(
        cls, parsed_results: MutableMapping[str, Any]
    ) -> tuple[MutableMapping[str, Any], model_lib.ModelArguments]:
        """Extracts fields for model args from `parsed_results`.

        Keys specific to model arguments will be removed from `parsed_results`.

        Args:
          parsed_results: A dictionary of parsed arguments (from ArgumentParser). It
            will be modified in place.

        Returns:
          A tuple of (updated parsed_results, model arguments).
        """
        model = parsed_results.pop("model", None)
        temperature = parsed_results.pop("temperature", None)
        candidate_count = parsed_results.pop("candidate_count", None)

        model_args = model_lib.ModelArguments(
            model=model,
            temperature=temperature,
            candidate_count=candidate_count,
        )
        return parsed_results, model_args

    def parse_line(
        self,
        line: str,
        placeholders: AbstractSet[str] | None = None,
    ) -> tuple[parsed_args_lib.ParsedArgs, parsed_args_lib.PostProcessingTokens]:
        """Parses the commandline and returns ParsedArgs and post-processing tokens.

        Args:
          line: The line to parse (usually contents from cell Magics).
          placeholders: Placeholders from prompts in the cell contents.

        Returns:
          A tuple of (parsed_args, post_processing_tokens).
        """
        tokens, post_processing_tokens = CmdLineParser._tokenize_line(line)

        parsed_args = self._get_parsed_args_from_cmd_line_tokens(
            tokens=tokens, placeholders=placeholders
        )

        # Special-case for "compare" command: because the prompts are compiled into
        # the left- and right-hand side functions rather than in the cell body, we
        # cannot examine the cell body to get the placeholders.
        #
        # Instead we parse the command line twice: once to get the left- and right-
        # functions, then we query the functions for their placeholders, then
        # parse the commandline again to validate the inputs.
        if parsed_args.cmd == parsed_args_lib.CommandName.COMPARE_CMD:
            assert parsed_args.lhs_name_and_fn is not None
            assert parsed_args.rhs_name_and_fn is not None
            _, lhs_fn = parsed_args.lhs_name_and_fn
            _, rhs_fn = parsed_args.rhs_name_and_fn
            parsed_args = self._get_parsed_args_from_cmd_line_tokens(
                tokens=tokens,
                placeholders=frozenset(lhs_fn.get_placeholders()).union(rhs_fn.get_placeholders()),
            )

        _validate_parsed_args(parsed_args)

        for expr in post_processing_tokens:
            post_process_utils.validate_one_post_processing_expression(expr)

        return parsed_args, post_processing_tokens

    def _get_parsed_args_from_cmd_line_tokens(
        self,
        tokens: Sequence[str],
        placeholders: AbstractSet[str] | None,
    ) -> parsed_args_lib.ParsedArgs:
        """Returns ParsedArgs from a tokenized command line."""
        # Create a new parser to avoid reusing the temporary argparse.Namespace
        # object.
        results = _create_parser(placeholders).parse_args(tokens)

        results_dict = vars(results)
        results_dict["cmd"] = parsed_args_lib.CommandName(results_dict["cmd"])

        results_dict, model_args = CmdLineParser._get_model_args(results_dict)
        results_dict["model_args"] = model_args

        return parsed_args_lib.ParsedArgs(**results_dict)
