/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.advisor;

import com.alibaba.cloud.ai.document.DocumentWithScore;
import com.alibaba.cloud.ai.model.RerankModel;
import com.alibaba.cloud.ai.model.RerankRequest;
import com.alibaba.cloud.ai.model.RerankResponse;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/**
 * Advisor that augments a chat request with documents retrieved from a
 * {@link VectorStore} and re-ordered by a {@link RerankModel}.
 *
 * <p>Before the request is passed down the advisor chain, the user text is used
 * as the query of a similarity search, the hits are reranked, results scoring
 * below the configured minimum are dropped, and the remaining document texts are
 * exposed to the prompt through the {@code question_answer_context} user
 * parameter. The selected documents are also stored in the advise context under
 * {@link #RETRIEVED_DOCUMENTS} and copied into the response metadata afterwards.
 */
public class RetrievalRerankAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

    /** Advise-context / response-metadata key holding the documents selected for the prompt. */
    public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";

    /** Advise-context key whose textual value, when present, overrides the search filter expression. */
    public static final String FILTER_EXPRESSION = "qa_filter_expression";

    /** Key name for a rerank score; not referenced inside this class. */
    public static final String RERANK_SCORE = "rerank_score";

    private static final Logger logger = LoggerFactory.getLogger(RetrievalRerankAdvisor.class);

    private static final String DEFAULT_USER_TEXT_ADVISE = "Context information is below.\n"
            + "---------------------\n"
            + "{question_answer_context}\n"
            + "---------------------\n"
            + "Given the context and provided history information and not prior knowledge,\n"
            + "reply to the user comment. If the answer is not in the context, inform\n"
            + "the user that you can't answer the question.\n";

    private static final Double DEFAULT_MIN_SCORE = 0.1;

    private static final int DEFAULT_ORDER = 0;

    private final VectorStore vectorStore;

    private final RerankModel rerankModel;

    private final String userTextAdvise;

    private final SearchRequest searchRequest;

    private final Double minScore;

    private final boolean protectFromBlocking;

    private final int order;

    /**
     * Creates an advisor with a default search request, the default advise text
     * and the default minimum score.
     *
     * @param vectorStore store queried for candidate documents; must not be null
     * @param rerankModel model used to rerank the candidates; must not be null
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel) {
        this(vectorStore, rerankModel, SearchRequest.builder().build(), DEFAULT_USER_TEXT_ADVISE, DEFAULT_MIN_SCORE);
    }

    /**
     * Creates an advisor with a default search request, the default advise text
     * and a custom minimum score.
     *
     * @param vectorStore store queried for candidate documents; must not be null
     * @param rerankModel model used to rerank the candidates; must not be null
     * @param score minimum rerank score a document needs to be kept; {@code null}
     * selects the default
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, Double score) {
        this(vectorStore, rerankModel, SearchRequest.builder().build(), DEFAULT_USER_TEXT_ADVISE, score);
    }

    /**
     * Creates an advisor with a custom search request, the default advise text
     * and the default minimum score.
     *
     * @param vectorStore store queried for candidate documents; must not be null
     * @param rerankModel model used to rerank the candidates; must not be null
     * @param searchRequest template for the similarity search; must not be null
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, SearchRequest searchRequest) {
        this(vectorStore, rerankModel, searchRequest, DEFAULT_USER_TEXT_ADVISE, DEFAULT_MIN_SCORE);
    }

    /**
     * Creates an advisor that offloads retrieval from the subscribing thread when streaming.
     *
     * @param vectorStore store queried for candidate documents; must not be null
     * @param rerankModel model used to rerank the candidates; must not be null
     * @param searchRequest template for the similarity search; must not be null
     * @param userTextAdvise text appended to the user text; must contain text
     * @param minScore minimum rerank score a document needs to be kept; {@code null}
     * selects the default
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, SearchRequest searchRequest,
            String userTextAdvise, Double minScore) {
        this(vectorStore, rerankModel, searchRequest, userTextAdvise, minScore, true);
    }

    /**
     * Creates an advisor with the default order.
     *
     * @param vectorStore store queried for candidate documents; must not be null
     * @param rerankModel model used to rerank the candidates; must not be null
     * @param searchRequest template for the similarity search; must not be null
     * @param userTextAdvise text appended to the user text; must contain text
     * @param minScore minimum rerank score a document needs to be kept; {@code null}
     * selects the default
     * @param protectFromBlocking when {@code true}, the streaming path runs the
     * retrieval step on {@code Schedulers.boundedElastic()}
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, SearchRequest searchRequest,
            String userTextAdvise, Double minScore, boolean protectFromBlocking) {
        this(vectorStore, rerankModel, searchRequest, userTextAdvise, minScore, protectFromBlocking, DEFAULT_ORDER);
    }

    /**
     * Creates a fully configured advisor.
     *
     * @param vectorStore store queried for candidate documents; must not be null
     * @param rerankModel model used to rerank the candidates; must not be null
     * @param searchRequest template for the similarity search; must not be null
     * @param userTextAdvise text appended to the user text; must contain text
     * @param minScore minimum rerank score a document needs to be kept; {@code null}
     * selects the default
     * @param protectFromBlocking when {@code true}, the streaming path runs the
     * retrieval step on {@code Schedulers.boundedElastic()}
     * @param order value returned by {@link #getOrder()}
     * @throws IllegalArgumentException if a required argument is null or blank
     */
    public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, SearchRequest searchRequest,
            String userTextAdvise, Double minScore, boolean protectFromBlocking, int order) {
        Assert.notNull(vectorStore, "The vectorStore must not be null!");
        Assert.notNull(rerankModel, "The rerankModel must not be null!");
        Assert.notNull(searchRequest, "The searchRequest must not be null!");
        Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");
        this.vectorStore = vectorStore;
        this.rerankModel = rerankModel;
        this.userTextAdvise = userTextAdvise;
        this.searchRequest = searchRequest;
        // A null threshold would fail on unboxing when scores are compared in doRerank.
        this.minScore = (minScore != null) ? minScore : DEFAULT_MIN_SCORE;
        this.protectFromBlocking = protectFromBlocking;
        this.order = order;
    }

    /**
     * Enriches the request with retrieved context, delegates to the chain and
     * attaches the retrieved documents to the response metadata.
     *
     * @param advisedRequest incoming request
     * @param chain remaining advisor chain
     * @return the chain's response with the retrieved documents added to its metadata
     */
    @Override
    public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
        AdvisedRequest enrichedRequest = this.before(advisedRequest);
        AdvisedResponse advisedResponse = chain.nextAroundCall(enrichedRequest);
        return this.after(advisedResponse);
    }

    /**
     * Streaming counterpart of {@link #aroundCall}. Only responses that carry a
     * finish reason get the retrieved documents attached to their metadata.
     *
     * @param advisedRequest incoming request
     * @param chain remaining advisor chain
     * @return the chain's response stream
     */
    @Override
    public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
        Flux<AdvisedResponse> advisedResponses;
        if (this.protectFromBlocking) {
            // Retrieval and reranking are synchronous calls; run them on the
            // bounded-elastic scheduler instead of the subscribing thread.
            advisedResponses = Mono.just(advisedRequest)
                .publishOn(Schedulers.boundedElastic())
                .map(this::before)
                .flatMapMany(request -> chain.nextAroundStream(request));
        }
        else {
            advisedResponses = chain.nextAroundStream(this.before(advisedRequest));
        }
        Predicate<AdvisedResponse> finished = this.onFinishReason();
        return advisedResponses.map(response -> finished.test(response) ? this.after(response) : response);
    }

    /**
     * @return the simple class name of this advisor
     */
    @Override
    public String getName() {
        return this.getClass().getSimpleName();
    }

    /**
     * @return the order value supplied at construction time
     */
    @Override
    public int getOrder() {
        return this.order;
    }

    /**
     * Resolves the filter expression for the similarity search.
     *
     * @param context advise context of the current request
     * @return the expression parsed from the {@link #FILTER_EXPRESSION} entry when
     * that entry has text, otherwise the filter expression of the configured
     * search request
     */
    protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {
        // Read the value once; a key mapped to null is treated like a missing key.
        Object rawExpression = (context != null) ? context.get(FILTER_EXPRESSION) : null;
        if (rawExpression == null || !StringUtils.hasText(rawExpression.toString())) {
            return this.searchRequest.getFilterExpression();
        }
        return new FilterExpressionTextParser().parse(rawExpression.toString());
    }

    /**
     * Reranks the retrieved documents against the user text.
     *
     * @param request request whose user text is the rerank query
     * @param documents candidates returned by the vector store
     * @return the input list unchanged when it is empty or the model yields no
     * results; otherwise the documents scoring at least the minimum score,
     * highest score first
     */
    protected List<Document> doRerank(AdvisedRequest request, List<Document> documents) {
        if (CollectionUtils.isEmpty(documents)) {
            return documents;
        }
        RerankRequest rerankRequest = new RerankRequest(request.userText(), documents);
        RerankResponse response = this.rerankModel.call(rerankRequest);
        logger.debug("reranked documents: {}", response);
        if (response == null || response.getResults() == null) {
            return documents;
        }
        return response.getResults()
            .stream()
            .filter(this::isRelevant)
            .sorted(Comparator.comparingDouble(DocumentWithScore::getScore).reversed())
            .map(DocumentWithScore::getOutput)
            .collect(Collectors.toList());
    }

    /** Tells whether a rerank result has a document and a score of at least the minimum score. */
    private boolean isRelevant(DocumentWithScore result) {
        if (result == null || result.getOutput() == null) {
            return false;
        }
        // Hold the score in a wrapper so a missing score is rejected instead of failing on unboxing.
        Double score = result.getScore();
        return score != null && score >= this.minScore;
    }

    /** Runs retrieval and reranking, then builds the request that carries the document context. */
    private AdvisedRequest before(AdvisedRequest request) {
        Map<String, Object> context = new HashMap<>(request.adviseContext());
        String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise;

        SearchRequest searchRequestToUse = SearchRequest.from(this.searchRequest)
            .query(request.userText())
            .filterExpression(this.doGetFilterExpression(context))
            .build();
        logger.debug("searchRequestToUse: {}", searchRequestToUse);

        List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);
        logger.debug("retrieved documents: {}", documents);

        documents = this.doRerank(request, documents);
        if (documents == null) {
            // doRerank hands a null list straight back; continue with "no context".
            documents = List.of();
        }
        context.put(RETRIEVED_DOCUMENTS, documents);

        String documentContext = documents.stream()
            .map(Document::getText)
            .collect(Collectors.joining(System.lineSeparator()));

        Map<String, Object> advisedUserParams = new HashMap<>(request.userParams());
        advisedUserParams.put("question_answer_context", documentContext);

        return AdvisedRequest.from(request)
            .userText(advisedUserText)
            .userParams(advisedUserParams)
            .adviseContext(context)
            .build();
    }

    /** Rebuilds the chat response with the retrieved documents added to its metadata. */
    private AdvisedResponse after(AdvisedResponse advisedResponse) {
        ChatResponse originalResponse = advisedResponse.response();
        if (originalResponse == null) {
            // Nothing to decorate.
            return advisedResponse;
        }

        ChatResponseMetadata.Builder metadataBuilder = ChatResponseMetadata.builder();
        Object retrievedDocuments = advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS);
        if (retrievedDocuments != null) {
            // Skip a missing value rather than handing null to the builder.
            metadataBuilder.keyValue(RETRIEVED_DOCUMENTS, retrievedDocuments);
        }

        ChatResponseMetadata metadata = originalResponse.getMetadata();
        if (metadata != null) {
            metadataBuilder.id(metadata.getId());
            metadataBuilder.model(metadata.getModel());
            metadataBuilder.usage(metadata.getUsage());
            metadataBuilder.promptMetadata(metadata.getPromptMetadata());
            metadataBuilder.rateLimit(metadata.getRateLimit());
            for (Map.Entry<String, Object> entry : metadata.entrySet()) {
                metadataBuilder.keyValue(entry.getKey(), entry.getValue());
            }
        }

        ChatResponse chatResponse = new ChatResponse(originalResponse.getResults(), metadataBuilder.build());
        return new AdvisedResponse(chatResponse, advisedResponse.adviseContext());
    }

    /** Matches responses in which at least one result reports a non-blank finish reason. */
    private Predicate<AdvisedResponse> onFinishReason() {
        return advisedResponse -> {
            ChatResponse chatResponse = advisedResponse.response();
            if (chatResponse == null || chatResponse.getResults() == null) {
                return false;
            }
            return chatResponse.getResults()
                .stream()
                .anyMatch(result -> result != null && result.getMetadata() != null
                        && StringUtils.hasText(result.getMetadata().getFinishReason()));
        };
    }
}

