package io.github.lnyocly.ai4j.platform.openai.chat;

import com.alibaba.fastjson2.JSON;
import io.github.lnyocly.ai4j.config.OpenAiConfig;
import io.github.lnyocly.ai4j.constant.Constants;
import io.github.lnyocly.ai4j.listener.SseListener;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletion;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatCompletionResponse;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.ChatMessage;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.Choice;
import io.github.lnyocly.ai4j.platform.openai.chat.entity.StreamOptions;
import io.github.lnyocly.ai4j.platform.openai.tool.Tool;
import io.github.lnyocly.ai4j.platform.openai.tool.ToolCall;
import io.github.lnyocly.ai4j.platform.openai.usage.Usage;
import io.github.lnyocly.ai4j.service.Configuration;
import io.github.lnyocly.ai4j.service.IChatService;
import io.github.lnyocly.ai4j.utils.ToolUtil;
import io.github.lnyocly.ai4j.utils.ValidateUtil;
import java.util.ArrayList;
import java.util.List;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/github/lnyocly/ai4j/platform/openai/chat/OpenAiChatService.class */
/**
 * {@link IChatService} implementation for the OpenAI Chat Completions endpoint.
 *
 * <p>Supports blocking ({@link #chatCompletion}) and streaming ({@link #chatCompletionStream})
 * requests. When the request declares functions, they are resolved to tool definitions via
 * {@link ToolUtil}. Every tool call the model asks for is executed locally, and the results are
 * appended to the conversation. The request is then re-sent until the model finishes with a
 * reason other than {@code tool_calls}.
 *
 * <p>Note: the {@link ChatCompletion} passed in is mutated. Its stream flags, tools,
 * parallel-tool-calls flag and message list are overwritten.
 */
public class OpenAiChatService implements IChatService {
    private static final Logger log = LoggerFactory.getLogger(OpenAiChatService.class);

    /** Finish reason reported when the model asks for tools to be invoked. */
    private static final String FINISH_REASON_TOOL_CALLS = "tool_calls";

    private final OpenAiConfig openAiConfig;
    private final OkHttpClient okHttpClient;
    private final EventSource.Factory factory;

    public OpenAiChatService(Configuration configuration) {
        this.openAiConfig = configuration.getOpenAiConfig();
        this.okHttpClient = configuration.getOkHttpClient();
        this.factory = configuration.createRequestFactory();
    }

    /**
     * Sends a blocking chat completion request and runs the tool-call loop.
     *
     * @param baseUrl API host; when null or empty, the configured host is used
     * @param apiKey API key; when null or empty, the configured key is used
     * @param chatCompletion the request (mutated, see class docs)
     * @return the final response, with {@code usage} summed over every round trip, or
     *     {@code null} if a request failed or the API returned no choices
     * @throws Exception on I/O errors or errors raised while invoking tools
     */
    @Override // io.github.lnyocly.ai4j.service.IChatService
    public ChatCompletionResponse chatCompletion(String baseUrl, String apiKey, ChatCompletion chatCompletion) throws Exception {
        String host = resolveApiHost(baseUrl);
        String key = resolveApiKey(apiKey);
        chatCompletion.setStream(false);
        chatCompletion.setStreamOptions(null);
        prepareTools(chatCompletion);

        Usage totalUsage = new Usage();
        while (true) {
            ChatCompletionResponse response;
            // try-with-resources: the Response holds a pooled connection and must always be closed.
            try (Response httpResponse = this.okHttpClient
                    .newCall(buildRequest(host, key, Constants.JSON_CONTENT_TYPE, chatCompletion))
                    .execute()) {
                if (!httpResponse.isSuccessful() || httpResponse.body() == null) {
                    log.error("Chat completion request failed with HTTP status {}", httpResponse.code());
                    return null;
                }
                response = JSON.parseObject(httpResponse.body().string(), ChatCompletionResponse.class);
            }
            if (response == null || response.getChoices() == null || response.getChoices().isEmpty()) {
                log.warn("Chat completion response contained no choices");
                return null;
            }

            addUsage(totalUsage, response.getUsage());
            Choice choice = response.getChoices().get(0);
            ChatMessage message = choice.getMessage();
            List<ToolCall> toolCalls = message == null ? null : message.getToolCalls();
            // Stop unless the model asked for tools AND actually supplied calls; otherwise there is
            // nothing new to send and another round trip would just repeat the same request.
            if (!FINISH_REASON_TOOL_CALLS.equals(choice.getFinishReason()) || toolCalls == null || toolCalls.isEmpty()) {
                response.setUsage(totalUsage);
                return response;
            }
            chatCompletion.setMessages(appendToolResults(chatCompletion.getMessages(), message, toolCalls));
        }
    }

    /** Same as {@link #chatCompletion(String, String, ChatCompletion)} using the configured host and key. */
    @Override // io.github.lnyocly.ai4j.service.IChatService
    public ChatCompletionResponse chatCompletion(ChatCompletion chatCompletion) throws Exception {
        return chatCompletion(null, null, chatCompletion);
    }

    /**
     * Sends a streaming (SSE) chat completion request and runs the tool-call loop.
     *
     * <p>Each round opens an event source with {@code sseListener}, then blocks on the listener's
     * count-down latch. Afterwards it reads the finish reason and collected tool calls from the
     * listener.
     *
     * @param baseUrl API host; when null or empty, the configured host is used
     * @param apiKey API key; when null or empty, the configured key is used
     * @param chatCompletion the request (mutated, see class docs)
     * @param sseListener receives stream events; its tool-call state is reset between rounds
     * @throws Exception on interruption while waiting or errors raised while invoking tools
     */
    @Override // io.github.lnyocly.ai4j.service.IChatService
    public void chatCompletionStream(String baseUrl, String apiKey, ChatCompletion chatCompletion, SseListener sseListener) throws Exception {
        String host = resolveApiHost(baseUrl);
        String key = resolveApiKey(apiKey);
        chatCompletion.setStream(true);
        if (chatCompletion.getStreamOptions() == null) {
            chatCompletion.setStreamOptions(new StreamOptions(true));
        }
        prepareTools(chatCompletion);

        while (true) {
            this.factory.newEventSource(buildRequest(host, key, Constants.APPLICATION_JSON, chatCompletion), sseListener);
            // NOTE(review): assumes the listener re-arms its latch for each round; if it does not,
            // later awaits return immediately with stale state. TODO confirm in SseListener.
            sseListener.getCountDownLatch().await();

            List<ToolCall> toolCalls = sseListener.getToolCalls();
            // Without collected tool calls there is nothing new to send; re-sending the identical
            // request could loop forever, so stop here.
            if (!FINISH_REASON_TOOL_CALLS.equals(sseListener.getFinishReason()) || toolCalls == null || toolCalls.isEmpty()) {
                return;
            }
            List<ChatMessage> messages =
                    appendToolResults(chatCompletion.getMessages(), ChatMessage.withAssistant(toolCalls), toolCalls);
            sseListener.setToolCalls(new ArrayList<>());
            sseListener.setToolCall(null);
            chatCompletion.setMessages(messages);
        }
    }

    /** Same as {@link #chatCompletionStream(String, String, ChatCompletion, SseListener)} using the configured host and key. */
    @Override // io.github.lnyocly.ai4j.service.IChatService
    public void chatCompletionStream(ChatCompletion chatCompletion, SseListener sseListener) throws Exception {
        chatCompletionStream(null, null, chatCompletion, sseListener);
    }

    /** Returns {@code apiHost}, or the configured host when it is null or empty. */
    private String resolveApiHost(String apiHost) {
        return (apiHost == null || apiHost.isEmpty()) ? this.openAiConfig.getApiHost() : apiHost;
    }

    /** Returns {@code apiKey}, or the configured key when it is null or empty. */
    private String resolveApiKey(String apiKey) {
        return (apiKey == null || apiKey.isEmpty()) ? this.openAiConfig.getApiKey() : apiKey;
    }

    /**
     * Resolves declared functions into tool definitions on the request.
     *
     * <p>It clears {@code parallelToolCalls} when there are no functions, or when no tools could be
     * resolved.
     */
    private static void prepareTools(ChatCompletion chatCompletion) {
        if (chatCompletion.getFunctions() == null || chatCompletion.getFunctions().isEmpty()) {
            chatCompletion.setParallelToolCalls(null);
            return;
        }
        List<Tool> tools = ToolUtil.getAllFunctionTools(chatCompletion.getFunctions());
        chatCompletion.setTools(tools);
        if (tools == null) {
            chatCompletion.setParallelToolCalls(null);
        }
    }

    /** Builds the authenticated POST request carrying {@code chatCompletion} serialized as JSON. */
    private Request buildRequest(String apiHost, String apiKey, String contentType, ChatCompletion chatCompletion) {
        return new Request.Builder()
                .header("Authorization", "Bearer " + apiKey)
                .url(ValidateUtil.concatUrl(apiHost, this.openAiConfig.getChatCompletionUrl()))
                .post(RequestBody.create(MediaType.parse(contentType), JSON.toJSONString(chatCompletion)))
                .build();
    }

    /** Adds {@code delta}'s token counts into {@code total}; a null {@code delta} is ignored. */
    private static void addUsage(Usage total, Usage delta) {
        if (delta == null) {
            return;
        }
        total.setCompletionTokens(total.getCompletionTokens() + delta.getCompletionTokens());
        total.setTotalTokens(total.getTotalTokens() + delta.getTotalTokens());
        total.setPromptTokens(total.getPromptTokens() + delta.getPromptTokens());
    }

    /**
     * Returns a new message list: a copy of {@code history}, then {@code assistantMessage}, then one
     * tool message per call.
     *
     * <p>Each tool message holds the result of {@link ToolUtil#invoke} for that call.
     *
     * @throws Exception propagated from tool invocation
     */
    private static List<ChatMessage> appendToolResults(List<ChatMessage> history, ChatMessage assistantMessage,
                                                       List<ToolCall> toolCalls) throws Exception {
        List<ChatMessage> messages = new ArrayList<>(history);
        messages.add(assistantMessage);
        for (ToolCall toolCall : toolCalls) {
            messages.add(ChatMessage.withTool(
                    ToolUtil.invoke(toolCall.getFunction().getName(), toolCall.getFunction().getArguments()),
                    toolCall.getId()));
        }
        return messages;
    }
}
