# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import functools
from typing import Iterator

from google.generativeai import protos

from google.generativeai import client as client_lib
from google.generativeai.types import model_types
from google.api_core import operation as operation_lib

import tqdm.auto as tqdm


def list_operations(*, client=None) -> Iterator[CreateTunedModelOperation]:
    """List every operation known to the operations client.

    The client's `list_operations` call is issued immediately. Wrapping each
    returned proto in a `CreateTunedModelOperation` happens lazily, as the
    returned iterator is consumed.

    Args:
        client: Operations client to query. When `None`, the default one from
            `client_lib.get_default_operations_client()` is used.

    Returns:
        An iterator of `CreateTunedModelOperation` objects.
    """
    ops_client = client_lib.get_default_operations_client() if client is None else client

    # The client yields raw `google.longrunning.operations_pb2.Operation`
    # protos, not `google.api_core.operation.Operation` wrappers, so each
    # one must be converted here.
    raw_protos = ops_client.list_operations(name="", filter_="")
    return (
        CreateTunedModelOperation.from_proto(raw, ops_client) for raw in raw_protos
    )


def get_operation(name: str, *, client=None) -> CreateTunedModelOperation:
    """Fetch one operation by name and wrap it.

    Args:
        name: Resource name of the operation to fetch.
        client: Operations client to query. When `None`, the default one from
            `client_lib.get_default_operations_client()` is used.

    Returns:
        The fetched operation as a `CreateTunedModelOperation`.
    """
    ops_client = client_lib.get_default_operations_client() if client is None else client
    return CreateTunedModelOperation.from_proto(
        ops_client.get_operation(name=name), ops_client
    )


def delete_operation(name: str, *, client=None):
    """Ask the operations client to delete the operation called `name`.

    Args:
        name: Resource name of the operation to delete.
        client: Operations client to use. When `None`, the default one from
            `client_lib.get_default_operations_client()` is used.

    Returns:
        Whatever the client's `delete_operation` call returns.

    Note:
        The service may not implement this call. It has been noted to raise
        `google.api_core.exceptions.MethodNotImplemented`.
    """
    ops_client = client if client is not None else client_lib.get_default_operations_client()
    return ops_client.delete_operation(name=name)


class CreateTunedModelOperation(operation_lib.Operation):
    """Long-running operation for creating a tuned model.

    A `google.api_core.operation.Operation` subclass. Instances built with
    `from_proto` parse their result as `protos.TunedModel` and their metadata
    as `protos.CreateTunedModelMetadata`. Results are passed through
    `model_types.decode_tuned_model` before being stored (see `set_result`).
    """

    @classmethod
    def from_proto(cls, proto, client):
        """Wrap a raw operation proto in an instance of this class.

        Args:
            proto: The raw operation proto, e.g. one returned by an operations
                client's `get_operation` or `list_operations`.
            client: The operations client whose `get_operation` and
                `cancel_operation` methods back the refresh and cancel
                callables.

        Returns:
            An instance of `cls` with result type `protos.TunedModel` and
            metadata type `protos.CreateTunedModelMetadata`.
        """
        # Pass `cls` (not the hard-coded class name) so that calling this on a
        # subclass produces an instance of that subclass.
        #
        # NOTE(review): an earlier version carried, as an inert string, code
        # that deleted `proto.result` when its value was b''. It never ran
        # and is intentionally not reproduced here.
        return from_gapic(
            cls=cls,
            operation=proto,
            operations_client=client,
            result_type=protos.TunedModel,
            metadata_type=protos.CreateTunedModelMetadata,
        )

    @classmethod
    def from_core_operation(
        cls,
        operation: operation_lib.Operation,
    ):
        """Re-create a plain `google.api_core` Operation as an instance of `cls`.

        Copies the operation proto, refresh/cancel callables, and result and
        metadata types from `operation`'s private attributes.

        Args:
            operation: The `google.api_core.operation.Operation` to convert.

        Returns:
            A new `cls` instance sharing `operation`'s state.
        """
        # The polling configuration lives under a different private attribute
        # depending on the installed google.api_core version, so carry over
        # whichever one exists.
        polling = getattr(operation, "_polling", None)
        retry = getattr(operation, "_retry", None)
        if polling is not None:
            # google.api_core v 2.11
            kwargs = {"polling": polling}
        elif retry is not None:
            # google.api_core v 2.10
            kwargs = {"retry": retry}
        else:
            kwargs = {}
        return cls(
            operation=operation._operation,
            refresh=operation._refresh,
            cancel=operation._cancel,
            result_type=operation._result_type,
            metadata_type=operation._metadata_type,
            **kwargs,
        )

    @property
    def name(self) -> str:
        """The resource name of the underlying operation proto."""
        return self._operation.name

    def update(self):
        """Refresh the current statuses in metadata/result/error."""
        self._refresh_and_update()

    def wait_bar(self, **kwargs) -> Iterator[protos.CreateTunedModelMetadata]:
        """A tqdm wait bar, yields `Operation` statuses until complete.

        The progress bar is closed when the operation finishes or when the
        generator is closed early.

        Args:
            **kwargs: passed through to `tqdm.auto.tqdm(..., **kwargs)`

        Yields:
            Operation statuses as `protos.CreateTunedModelMetadata` objects.

        Returns:
            `self.result()`, delivered as the generator's return value
            (`StopIteration.value`, or the value of `yield from`).
        """
        # `with` guarantees the bar is closed, including when the caller
        # abandons the generator before the operation completes.
        with tqdm.tqdm(total=self.metadata.total_steps, initial=0, **kwargs) as bar:
            # done() includes a `_refresh_and_update`
            while not self.done():
                metadata = self.metadata
                bar.update(metadata.completed_steps - bar.n)
                yield metadata
            # Advance the bar to the final step count once the operation is done.
            bar.update(self.metadata.completed_steps - bar.n)
        return self.result()

    def set_result(self, result: protos.TunedModel):
        """Decode `result` with `model_types.decode_tuned_model`, then store it
        via the parent class's `set_result`."""
        result = model_types.decode_tuned_model(result)
        super().set_result(result)


def from_gapic(
    cls,
    *,
    operation,
    operations_client,
    result_type,
    metadata_type,
    grpc_metadata=None,
    **kwargs,
):
    """Construct an instance of `cls` around a raw operation proto.

    A version of `google.api_core.operation.from_gapic` that takes the class
    to instantiate, so subclasses of `google.api_core.operation.Operation`
    can be built.

    Args:
        cls: The class to instantiate.
        operation: The raw operation proto. Only its `.name` is read here.
        operations_client: Client whose `get_operation` and `cancel_operation`
            methods are bound (with the operation name) into the refresh and
            cancel callables.
        result_type: Forwarded to `cls` as the result type.
        metadata_type: Forwarded to `cls` as the metadata type.
        grpc_metadata: Forwarded as `metadata=` to both bound client calls.
        **kwargs: Extra keyword arguments forwarded to `cls`.

    Returns:
        A new `cls` instance.
    """
    op_name = operation.name
    bind = functools.partial
    refresh_fn = bind(operations_client.get_operation, op_name, metadata=grpc_metadata)
    cancel_fn = bind(operations_client.cancel_operation, op_name, metadata=grpc_metadata)
    positional = (operation, refresh_fn, cancel_fn, result_type, metadata_type)
    return cls(*positional, **kwargs)
