package org.springframework.ai.qianfan;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.qianfan.api.QianFanApi;
import org.springframework.ai.qianfan.api.QianFanConstants;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/* loaded from: input_file:org/springframework/ai/qianfan/QianFanChatModel.class */
/**
 * {@link ChatModel} and {@link StreamingChatModel} implementation backed by the QianFan chat
 * completion API ({@link QianFanApi}).
 *
 * <p>Each request is built from the messages of the {@link Prompt}. The model's default
 * {@link QianFanChatOptions} are then merged in, followed by any runtime options carried by
 * the prompt. Blocking calls are executed through the configured {@link RetryTemplate}.
 * Both blocking and streaming calls are recorded as chat-model observations on the
 * configured {@link ObservationRegistry}.
 */
public class QianFanChatModel implements ChatModel, StreamingChatModel {

    private static final Logger logger = LoggerFactory.getLogger(QianFanChatModel.class);

    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION =
            new DefaultChatModelObservationConvention();

    /**
     * Reactor context key under which the current Micrometer observation is propagated
     * (the value of Micrometer's {@code ObservationThreadLocalAccessor.KEY}).
     */
    private static final String OBSERVATION_CONTEXT_KEY = "micrometer.observation";

    /** Retry policy applied to blocking {@link #call(Prompt)} requests. */
    public final RetryTemplate retryTemplate;

    /** Options merged into every request before any prompt-level options. */
    private final QianFanChatOptions defaultOptions;

    private final QianFanApi qianFanApi;

    private final ObservationRegistry observationRegistry;

    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates a model that uses {@link QianFanApi#DEFAULT_CHAT_MODEL} with a temperature of 0.7,
     * the default retry template and a no-op observation registry.
     *
     * @param qianFanApi the low-level API client; must not be {@code null}
     */
    public QianFanChatModel(QianFanApi qianFanApi) {
        this(qianFanApi, QianFanChatOptions.builder()
                .model(QianFanApi.DEFAULT_CHAT_MODEL)
                .temperature(Double.valueOf(0.7d))
                .build());
    }

    /**
     * Creates a model with the given default options, the default retry template and a no-op
     * observation registry.
     *
     * @param qianFanApi the low-level API client; must not be {@code null}
     * @param qianFanChatOptions default request options; must not be {@code null}
     */
    public QianFanChatModel(QianFanApi qianFanApi, QianFanChatOptions qianFanChatOptions) {
        this(qianFanApi, qianFanChatOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    /**
     * Creates a model with the given default options and retry template, and a no-op
     * observation registry.
     *
     * @param qianFanApi the low-level API client; must not be {@code null}
     * @param qianFanChatOptions default request options; must not be {@code null}
     * @param retryTemplate retry policy for blocking calls; must not be {@code null}
     */
    public QianFanChatModel(QianFanApi qianFanApi, QianFanChatOptions qianFanChatOptions,
            RetryTemplate retryTemplate) {
        this(qianFanApi, qianFanChatOptions, retryTemplate, ObservationRegistry.NOOP);
    }

    /**
     * Creates a fully configured model.
     *
     * @param qianFanApi the low-level API client; must not be {@code null}
     * @param qianFanChatOptions default request options; must not be {@code null}
     * @param retryTemplate retry policy for blocking calls; must not be {@code null}
     * @param observationRegistry registry that receives chat-model observations; must not be
     *     {@code null}
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public QianFanChatModel(QianFanApi qianFanApi, QianFanChatOptions qianFanChatOptions,
            RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
        Assert.notNull(qianFanApi, "QianFanApi must not be null");
        Assert.notNull(qianFanChatOptions, "Options must not be null");
        Assert.notNull(retryTemplate, "RetryTemplate must not be null");
        Assert.notNull(observationRegistry, "ObservationRegistry must not be null");
        this.qianFanApi = qianFanApi;
        this.defaultOptions = qianFanChatOptions;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
    }

    /**
     * Sends the prompt as a single blocking request and returns the complete response.
     *
     * <p>The request runs through {@link #retryTemplate} inside a chat-model observation. If
     * the API returns an empty body, a warning is logged and a {@link ChatResponse} without
     * generations is returned.
     *
     * @param prompt the prompt to send
     * @return the model response
     * @throws IllegalArgumentException if the prompt contains more than one system message
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        QianFanApi.ChatCompletionRequest request = createRequest(prompt, false);
        ChatModelObservationContext observationContext = buildObservationContext(prompt, request);

        return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
                .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION,
                        () -> observationContext, this.observationRegistry)
                .observe(() -> {
                    ResponseEntity<QianFanApi.ChatCompletion> completionEntity =
                            this.retryTemplate.execute(ctx -> this.qianFanApi.chatCompletionEntity(request));
                    QianFanApi.ChatCompletion chatCompletion = completionEntity.getBody();
                    if (chatCompletion == null) {
                        logger.warn("No chat completion returned for prompt: {}", prompt);
                        return new ChatResponse(List.of());
                    }
                    ChatResponse chatResponse = toChatResponse(chatCompletion, request.model());
                    observationContext.setResponse(chatResponse);
                    return chatResponse;
                });
    }

    /**
     * Streams the model response chunk by chunk.
     *
     * <p>The request is built lazily on subscription. A chat-model observation is started and
     * linked to any parent observation found in the Reactor context. The observation is
     * stopped when the stream terminates or is cancelled. The chunks are passed through a
     * {@link MessageAggregator}, which hands the aggregated response to the observation
     * context.
     *
     * @param prompt the prompt to send
     * @return a stream of partial responses, one per received chunk
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        return Flux.deferContextual(contextView -> {
            QianFanApi.ChatCompletionRequest request = createRequest(prompt, true);
            Flux<QianFanApi.ChatCompletionChunk> completionChunks = this.qianFanApi.chatCompletionStream(request);

            ChatModelObservationContext observationContext = buildObservationContext(prompt, request);
            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry);
            // Attach to the caller's observation, if one was propagated through the Reactor context.
            observation.parentObservation(contextView.getOrDefault(OBSERVATION_CONTEXT_KEY, null)).start();

            Flux<ChatResponse> chatResponses = completionChunks
                    .map(this::toChatCompletion)
                    .map(chatCompletion -> toChatResponse(chatCompletion, request.model()))
                    .doOnError(observation::error)
                    .doFinally(signalType -> observation.stop())
                    .contextWrite(ctx -> ctx.put(OBSERVATION_CONTEXT_KEY, observation));

            return new MessageAggregator().aggregate(chatResponses, observationContext::setResponse);
        });
    }

    /** Converts a streamed chunk into the equivalent non-streaming completion record. */
    private QianFanApi.ChatCompletion toChatCompletion(QianFanApi.ChatCompletionChunk chatCompletionChunk) {
        return new QianFanApi.ChatCompletion(chatCompletionChunk.id(), chatCompletionChunk.object(),
                chatCompletionChunk.created(), chatCompletionChunk.result(), chatCompletionChunk.finishReason(),
                chatCompletionChunk.usage());
    }

    /**
     * Wraps a completion in a single-generation {@link ChatResponse} carrying the completion id,
     * the assistant role, and response metadata built by {@link #from}.
     */
    private ChatResponse toChatResponse(QianFanApi.ChatCompletion chatCompletion, String model) {
        // Map.of rejects null values; fall back to "" for a missing id, as from(...) does.
        String id = chatCompletion.id() != null ? chatCompletion.id() : "";
        Map<String, Object> metadata = Map.of("id", id, "role", QianFanApi.ChatCompletionMessage.Role.ASSISTANT);
        AssistantMessage assistantMessage = new AssistantMessage(chatCompletion.result(), metadata);
        return new ChatResponse(Collections.singletonList(new Generation(assistantMessage)),
                from(chatCompletion, model));
    }

    /** Builds the observation context shared by {@link #call} and {@link #stream}. */
    private ChatModelObservationContext buildObservationContext(Prompt prompt,
            QianFanApi.ChatCompletionRequest request) {
        return ChatModelObservationContext.builder()
                .prompt(prompt)
                .provider(QianFanConstants.PROVIDER_NAME)
                .requestOptions(buildRequestOptions(request))
                .build();
    }

    /**
     * Translates a {@link Prompt} into a QianFan request.
     *
     * <p>The system message, if present, is passed separately from the conversation messages.
     * The default options are merged into the request first, then the prompt's runtime
     * options.
     *
     * @param prompt the prompt to translate
     * @param stream whether the request asks for a streamed response
     * @return the request to send
     * @throws IllegalArgumentException if the prompt contains more than one system message
     */
    public QianFanApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
        // NOTE(review): assumes every MessageType name has a matching Role constant;
        // an unmatched type makes Role.valueOf throw IllegalArgumentException.
        List<QianFanApi.ChatCompletionMessage> messages = prompt.getInstructions()
                .stream()
                .map(message -> new QianFanApi.ChatCompletionMessage(message.getText(),
                        QianFanApi.ChatCompletionMessage.Role.valueOf(message.getMessageType().name())))
                .toList();

        List<QianFanApi.ChatCompletionMessage> systemMessages = messages.stream()
                .filter(m -> m.role() == QianFanApi.ChatCompletionMessage.Role.SYSTEM)
                .toList();
        if (systemMessages.size() > 1) {
            throw new IllegalArgumentException("Only one system message is allowed in the prompt");
        }
        List<QianFanApi.ChatCompletionMessage> conversation = messages.stream()
                .filter(m -> m.role() != QianFanApi.ChatCompletionMessage.Role.SYSTEM)
                .toList();
        String systemText = systemMessages.isEmpty() ? null : systemMessages.get(0).content();

        QianFanApi.ChatCompletionRequest request =
                new QianFanApi.ChatCompletionRequest(conversation, systemText, stream);

        // Merge order: defaults first, then prompt options (presumably overriding the defaults,
        // per ModelOptionsUtils.merge semantics — confirm).
        if (this.defaultOptions != null) {
            request = ModelOptionsUtils.merge(this.defaultOptions, request, QianFanApi.ChatCompletionRequest.class);
        }
        if (prompt.getOptions() != null) {
            QianFanChatOptions runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
                    ChatOptions.class, QianFanChatOptions.class);
            request = ModelOptionsUtils.merge(runtimeOptions, request, QianFanApi.ChatCompletionRequest.class);
        }
        return request;
    }

    /** Returns a copy of this model's default options. */
    @Override
    public ChatOptions getDefaultOptions() {
        return QianFanChatOptions.fromOptions(this.defaultOptions);
    }

    /** Projects the effective request settings into portable {@link ChatOptions} for observations. */
    private ChatOptions buildRequestOptions(QianFanApi.ChatCompletionRequest chatCompletionRequest) {
        return ChatOptions.builder()
                .model(chatCompletionRequest.model())
                .frequencyPenalty(chatCompletionRequest.frequencyPenalty())
                .maxTokens(chatCompletionRequest.maxTokens())
                .presencePenalty(chatCompletionRequest.presencePenalty())
                .stopSequences(chatCompletionRequest.stop())
                .temperature(chatCompletionRequest.temperature())
                .topP(chatCompletionRequest.topP())
                .build();
    }

    /**
     * Builds response metadata from a completion.
     *
     * <p>A missing id becomes {@code ""}, missing usage becomes {@link EmptyUsage}, and a
     * missing creation timestamp becomes {@code 0L}.
     */
    private ChatResponseMetadata from(QianFanApi.ChatCompletion chatCompletion, String model) {
        Assert.notNull(chatCompletion, "QianFan ChatCompletionResult must not be null");
        return ChatResponseMetadata.builder()
                .id(chatCompletion.id() != null ? chatCompletion.id() : "")
                .usage(chatCompletion.usage() != null ? getDefaultUsage(chatCompletion.usage()) : new EmptyUsage())
                .model(model)
                .keyValue("created", Long.valueOf(chatCompletion.created() != null ? chatCompletion.created().longValue() : 0L))
                .build();
    }

    /** Converts QianFan usage figures into a {@link DefaultUsage}, keeping the native object. */
    private DefaultUsage getDefaultUsage(QianFanApi.Usage usage) {
        return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
    }

    /**
     * Replaces the observation convention used for subsequent calls.
     *
     * @param chatModelObservationConvention the convention to use; must not be {@code null}
     * @throws IllegalArgumentException if the convention is {@code null}
     */
    public void setObservationConvention(ChatModelObservationConvention chatModelObservationConvention) {
        Assert.notNull(chatModelObservationConvention, "observationConvention cannot be null");
        this.observationConvention = chatModelObservationConvention;
    }
}
