/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.dashscope.rag;

import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

public class DashScopeDocumentRetrievalAdvisor
implements CallAroundAdvisor,
StreamAroundAdvisor {
    private static final Pattern RAG_REFERENCE_PATTERN = Pattern.compile("<ref>(.*?)</ref>");
    private static final Pattern RAG_REFERENCE_INNER_PATTERN = Pattern.compile("\\[([0-9]+)(?:[,\uff0c]?([0-9]+))*]");
    private static final String DEFAULT_USER_TEXT_ADVISE = "# \u77e5\u8bc6\u5e93\n\u8bf7\u8bb0\u4f4f\u4ee5\u4e0b\u6750\u6599\uff0c\u4ed6\u4eec\u53ef\u80fd\u5bf9\u56de\u7b54\u95ee\u9898\u6709\u5e2e\u52a9\u3002\n\u6307\u4ee4\uff1a\u60a8\u9700\u8981\u4ec5\u4f7f\u7528\u63d0\u4f9b\u7684\u641c\u7d22\u6587\u6863\u4e3a\u7ed9\u5b9a\u95ee\u9898\u5199\u51fa\u9ad8\u8d28\u91cf\u7684\u7b54\u6848\uff0c\u5e76\u6b63\u786e\u5f15\u7528\u5b83\u4eec\u3002 \u5f15\u7528\u591a\u4e2a\u641c\u7d22\u7ed3\u679c\u65f6\uff0c\u8bf7\u4f7f\u7528<ref>[\u7f16\u53f7]</ref>\u683c\u5f0f\uff0c\u6ce8\u610f\u786e\u4fdd\u8fd9\u4e9b\u5f15\u7528\u76f4\u63a5\u6709\u52a9\u4e8e\u89e3\u7b54\u95ee\u9898\uff0c\u7f16\u53f7\u9700\u4e0e\u6750\u6599\u539f\u59cb\u7f16\u53f7\u4e00\u81f4\u4e14\u552f\u4e00\u3002\u8bf7\u6ce8\u610f\uff0c\u6bcf\u4e2a\u53e5\u5b50\u4e2d\u5fc5\u987b\u81f3\u5c11\u5f15\u7528\u4e00\u4e2a\u6587\u6863\u3002\u6362\u53e5\u8bdd\u8bf4\uff0c\u4f60\u7981\u6b62\u5728\u6ca1\u6709\u5f15\u7528\u4efb\u4f55\u6587\u732e\u7684\u60c5\u51b5\u4e0b\u5199\u53e5\u5b50\u3002\u6b64\u5916\uff0c\u60a8\u5e94\u8be5\u5728\u6bcf\u4e2a\u53e5\u5b50\u4e2d\u6dfb\u52a0\u5f15\u7528\u7b26\u53f7\uff0c\u6ce8\u610f\u5728\u53e5\u53f7\u4e4b\u524d\u3002\n\n\u5bf9\u4e8e\u6bcf\u4e2a\u95ee\u9898\u6309\u7167\u4e0b\u9762\u7684\u63a8\u7406\u6b65\u9aa4\u5f97\u5230\u5e26\u5f15\u7528\u7684\u7b54\u6848\uff1a\n\n\u6b65\u9aa41\uff1a\u6211\u5224\u65ad\u6587\u68631\u548c\u6587\u68632\u4e0e\u95ee\u9898\u76f8\u5173\u3002\n\n\u6b65\u9aa42\uff1a\u6839\u636e\u6587\u68631\uff0c\u6211\u5199\u4e86\u4e00\u4e2a\u56de\u7b54\u9648\u8ff0\u5e76\u5f15\u7528\u4e86\u8be5\u6587\u6863\u3002\n\n\u6b65\u9aa43\uff1a\u6839\u636e\u6587\u68632\uff0c\u6211\u5199\u4e00\u4e2a\u7b54\u6848\u58f0\u660e\u5e76\u5f15\u7528\u8be5\u6587\u6863\u3002\n\n\u6b65\u9aa44\uff1a\u6211\u5c06\u4ee5\u4e0a\u4e24\u4e2a\u7b54\u6848\u8bed\u53e5\u8fdb\u884c\u5408\u5e76\u3001\u6392\u5e8f\u548c\u8fde\u63a5\uff0c\u4ee5\u83b7\u5f97\u6d41\u7545\u8fde\u8d2f\u7684\u7b54\u6848\u3002\n\n$$\u6750\u6599\uff1a\n[1] 
\u3010\u6587\u6863\u540d\u3011\u690d\u7269\u4e2d\u7684\u5149\u5408\u4f5c\u7528.pdf\n\u3010\u6807\u9898\u3011\u5149\u5408\u4f5c\u7528\u4f4d\u7f6e\n\u3010\u6b63\u6587\u3011\u5149\u5408\u4f5c\u7528\u4e3b\u8981\u5728\u53f6\u7eff\u4f53\u4e2d\u8fdb\u884c\uff0c\u6d89\u53ca\u5149\u80fd\u5230\u5316\u5b66\u80fd\u7684\u8f6c\u5316\u3002\n[2] \u3010\u6587\u6863\u540d\u3011\u5149\u5408\u4f5c\u7528.pdf\n\u3010\u6807\u9898\u3011\u5149\u5408\u4f5c\u7528\u8f6c\u5316\n\u3010\u6b63\u6587\u3011\u5149\u5408\u4f5c\u7528\u662f\u5229\u7528\u9633\u5149\u5c06CO2\u548cH2O\u8f6c\u5316\u4e3a\u6c27\u6c14\u548c\u8461\u8404\u7cd6\u7684\u8fc7\u7a0b\u3002\n\n$$\u6750\u6599:\n{question_answer_context}\n";
    private static final int DEFAULT_ORDER = 0;
    public static String RETRIEVED_DOCUMENTS = "question_answer_context";
    private final DocumentRetriever retriever;
    private final String userTextAdvise;
    private final boolean enableReference;
    private final boolean protectFromBlocking;
    private final int order;

    public DashScopeDocumentRetrievalAdvisor(DocumentRetriever retriever, boolean enableReference) {
        this(retriever, DEFAULT_USER_TEXT_ADVISE, enableReference);
    }

    public DashScopeDocumentRetrievalAdvisor(DocumentRetriever retriever, String userTextAdvise, boolean enableReference) {
        this(retriever, userTextAdvise, enableReference, true);
    }

    public DashScopeDocumentRetrievalAdvisor(DocumentRetriever retriever, String userTextAdvise, boolean enableReference, boolean protectFromBlocking) {
        this(retriever, userTextAdvise, enableReference, protectFromBlocking, 0);
    }

    public DashScopeDocumentRetrievalAdvisor(DocumentRetriever retriever, String userTextAdvise, boolean enableReference, boolean protectFromBlocking, int order) {
        this.retriever = retriever;
        this.userTextAdvise = userTextAdvise;
        this.enableReference = enableReference;
        this.protectFromBlocking = protectFromBlocking;
        this.order = order;
    }

    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        advisedRequest = this.before(advisedRequest);
        AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);
        return this.after(advisedResponse);
    }

    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        Flux advisedResponses = this.protectFromBlocking ? Mono.just((Object)advisedRequest).publishOn(Schedulers.boundedElastic()).map(this::before).flatMapMany(arg_0 -> ((StreamAroundAdvisorChain)chain).nextAroundStream(arg_0)) : chain.nextAroundStream(this.before(advisedRequest));
        return advisedResponses.map(ar -> {
            if (this.onFinishReason().test((AdvisedResponse)ar)) {
                ar = this.after((AdvisedResponse)ar);
            }
            return ar;
        });
    }

    public String getName() {
        return this.getClass().getSimpleName();
    }

    public int getOrder() {
        return this.order;
    }

    private AdvisedRequest before(AdvisedRequest request) {
        HashMap context = new HashMap(request.adviseContext());
        List documents = this.retriever.retrieve(new Query(request.userText()));
        HashMap<String, Document> documentMap = new HashMap<String, Document>();
        StringBuffer documentContext = new StringBuffer();
        for (int i = 0; i < documents.size(); ++i) {
            Document document = (Document)documents.get(i);
            String indexId = String.format("[%d]", i + 1);
            String docInfo = String.format("%s \u3010\u6587\u6863\u540d\u3011%s\n\u3010\u6807\u9898\u3011%s\n\u3010\u6b63\u6587\u3011%s\n", indexId, document.getMetadata().get("doc_name"), document.getMetadata().get("title"), document.getText());
            documentContext.append(docInfo);
            documentContext.append(System.lineSeparator());
            document.getMetadata().put("index_id", i);
            documentMap.put(indexId, document);
        }
        context.put(RETRIEVED_DOCUMENTS, documentMap);
        HashMap<String, StringBuffer> advisedUserParams = new HashMap<String, StringBuffer>(request.userParams());
        advisedUserParams.put(RETRIEVED_DOCUMENTS, documentContext);
        return AdvisedRequest.from((AdvisedRequest)request).userText(request.userText() + System.lineSeparator() + this.userTextAdvise).userParams(advisedUserParams).adviseContext(context).build();
    }

    private AdvisedResponse after(AdvisedResponse advisedResponse) {
        if (!this.enableReference) {
            return advisedResponse;
        }
        ChatResponse response = advisedResponse.response();
        Map context = advisedResponse.adviseContext();
        DashScopeApi.ChatCompletionFinishReason finishReason = DashScopeApi.ChatCompletionFinishReason.valueOf(response.getResult().getMetadata().getFinishReason());
        if (finishReason == DashScopeApi.ChatCompletionFinishReason.NULL) {
            String fullContent = context.getOrDefault("full_content", "").toString() + response.getResult().getOutput().getText();
            context.put("full_content", fullContent);
            return advisedResponse;
        }
        String content = context.getOrDefault("full_content", "").toString();
        if ("".equalsIgnoreCase(content)) {
            content = response.getResult().getOutput().getText();
        }
        Map documentMap = (Map)context.get(RETRIEVED_DOCUMENTS);
        ArrayList<Document> referencedDocuments = new ArrayList<Document>();
        Matcher refMatcher = RAG_REFERENCE_PATTERN.matcher(content);
        while (refMatcher.find()) {
            String refContent = refMatcher.group();
            Matcher numberMatcher = RAG_REFERENCE_INNER_PATTERN.matcher(refContent);
            while (numberMatcher.find()) {
                for (int i = 1; i <= numberMatcher.groupCount(); ++i) {
                    if (numberMatcher.group(i) == null) continue;
                    String index = numberMatcher.group(i - 1);
                    Document document = (Document)documentMap.get(index);
                    referencedDocuments.add(document);
                }
            }
        }
        ChatResponseMetadata.Builder metadataBuilder = ChatResponseMetadata.builder();
        metadataBuilder.keyValue(RETRIEVED_DOCUMENTS, advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS));
        ChatResponseMetadata metadata = advisedResponse.response().getMetadata();
        if (metadata != null) {
            metadataBuilder.id(metadata.getId());
            metadataBuilder.model(metadata.getModel());
            metadataBuilder.usage(metadata.getUsage());
            metadataBuilder.promptMetadata(metadata.getPromptMetadata());
            metadataBuilder.rateLimit(metadata.getRateLimit());
            Set entries = metadata.entrySet();
            for (Map.Entry entry : entries) {
                metadataBuilder.keyValue((String)entry.getKey(), entry.getValue());
            }
        }
        ChatResponse chatResponse = new ChatResponse(advisedResponse.response().getResults(), metadataBuilder.build());
        return new AdvisedResponse(chatResponse, context);
    }

    private Predicate<AdvisedResponse> onFinishReason() {
        return advisedResponse -> advisedResponse.response().getResults().stream().anyMatch(result -> result != null && result.getMetadata() != null && StringUtils.hasText((String)result.getMetadata().getFinishReason()));
    }
}

