/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.stabilityai;

import java.util.List;
import java.util.stream.Collectors;
import org.springframework.ai.image.Image;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImageGenerationMetadata;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.stabilityai.StabilityAiImageGenerationMetadata;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.util.Assert;

public class StabilityAiImageModel
implements ImageModel {
    private final StabilityAiImageOptions defaultOptions;
    private final StabilityAiApi stabilityAiApi;

    public StabilityAiImageModel(StabilityAiApi stabilityAiApi) {
        this(stabilityAiApi, StabilityAiImageOptions.builder().build());
    }

    public StabilityAiImageModel(StabilityAiApi stabilityAiApi, StabilityAiImageOptions defaultOptions) {
        Assert.notNull((Object)stabilityAiApi, (String)"StabilityAiApi must not be null");
        Assert.notNull((Object)defaultOptions, (String)"StabilityAiImageOptions must not be null");
        this.stabilityAiApi = stabilityAiApi;
        this.defaultOptions = defaultOptions;
    }

    private static StabilityAiApi.GenerateImageRequest getGenerateImageRequest(ImagePrompt stabilityAiImagePrompt, StabilityAiImageOptions optionsToUse) {
        return new StabilityAiApi.GenerateImageRequest.Builder().textPrompts(stabilityAiImagePrompt.getInstructions().stream().map(message -> new StabilityAiApi.GenerateImageRequest.TextPrompts(message.getText(), message.getWeight())).collect(Collectors.toList())).height(optionsToUse.getHeight()).width(optionsToUse.getWidth()).cfgScale(optionsToUse.getCfgScale()).clipGuidancePreset(optionsToUse.getClipGuidancePreset()).sampler(optionsToUse.getSampler()).samples(optionsToUse.getN()).seed(optionsToUse.getSeed()).steps(optionsToUse.getSteps()).stylePreset(optionsToUse.getStylePreset()).build();
    }

    public StabilityAiImageOptions getOptions() {
        return this.defaultOptions;
    }

    public ImageResponse call(ImagePrompt imagePrompt) {
        StabilityAiImageOptions requestImageOptions = this.mergeOptions(imagePrompt.getOptions(), this.defaultOptions);
        StabilityAiApi.GenerateImageRequest generateImageRequest = StabilityAiImageModel.getGenerateImageRequest(imagePrompt, requestImageOptions);
        StabilityAiApi.GenerateImageResponse generateImageResponse = this.stabilityAiApi.generateImage(generateImageRequest);
        return this.convertResponse(generateImageResponse);
    }

    private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse generateImageResponse) {
        List<ImageGeneration> imageGenerationList = generateImageResponse.artifacts().stream().map(entry -> new ImageGeneration(new Image(null, entry.base64()), (ImageGenerationMetadata)new StabilityAiImageGenerationMetadata(entry.finishReason(), entry.seed()))).toList();
        return new ImageResponse(imageGenerationList, new ImageResponseMetadata());
    }

    StabilityAiImageOptions mergeOptions(ImageOptions runtimeOptions, StabilityAiImageOptions defaultOptions) {
        if (runtimeOptions == null) {
            return defaultOptions;
        }
        StabilityAiImageOptions.Builder builder = StabilityAiImageOptions.builder().model((String)ModelOptionsUtils.mergeOption((Object)runtimeOptions.getModel(), (Object)defaultOptions.getModel())).N((Integer)ModelOptionsUtils.mergeOption((Object)runtimeOptions.getN(), (Object)defaultOptions.getN())).responseFormat((String)ModelOptionsUtils.mergeOption((Object)runtimeOptions.getResponseFormat(), (Object)defaultOptions.getResponseFormat())).width((Integer)ModelOptionsUtils.mergeOption((Object)runtimeOptions.getWidth(), (Object)defaultOptions.getWidth())).height((Integer)ModelOptionsUtils.mergeOption((Object)runtimeOptions.getHeight(), (Object)defaultOptions.getHeight())).stylePreset((String)ModelOptionsUtils.mergeOption((Object)runtimeOptions.getStyle(), (Object)defaultOptions.getStyle())).cfgScale(defaultOptions.getCfgScale()).clipGuidancePreset(defaultOptions.getClipGuidancePreset()).sampler(defaultOptions.getSampler()).seed(defaultOptions.getSeed()).steps(defaultOptions.getSteps()).stylePreset(defaultOptions.getStylePreset());
        if (runtimeOptions instanceof StabilityAiImageOptions) {
            StabilityAiImageOptions stabilityOptions = (StabilityAiImageOptions)runtimeOptions;
            builder.cfgScale((Float)ModelOptionsUtils.mergeOption((Object)stabilityOptions.getCfgScale(), (Object)defaultOptions.getCfgScale())).clipGuidancePreset((String)ModelOptionsUtils.mergeOption((Object)stabilityOptions.getClipGuidancePreset(), (Object)defaultOptions.getClipGuidancePreset())).sampler((String)ModelOptionsUtils.mergeOption((Object)stabilityOptions.getSampler(), (Object)defaultOptions.getSampler())).seed((Long)ModelOptionsUtils.mergeOption((Object)stabilityOptions.getSeed(), (Object)defaultOptions.getSeed())).steps((Integer)ModelOptionsUtils.mergeOption((Object)stabilityOptions.getSteps(), (Object)defaultOptions.getSteps())).stylePreset((String)ModelOptionsUtils.mergeOption((Object)stabilityOptions.getStylePreset(), (Object)defaultOptions.getStylePreset()));
        }
        return builder.build();
    }
}

