/*
 * Decompiled with CFR 0.152.
 */
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.ImmutableByteArray;
import com.knuddels.jtokkit.TokenEncoder;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.GptBytePairEncodingParams;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * {@link Encoding} implementation that splits the input with a regular
 * expression and maps every match to token ids through an encoder table.
 *
 * <p>A match that is present in the encoder table as a whole becomes a single
 * token. Any other match is broken into bytes and adjacent byte segments are
 * merged, lowest rank first, until no mergeable pair remains (see
 * {@link #bytePairMerge(ImmutableByteArray)}).
 *
 * <p>Encoding text that contains a special token is not supported and is
 * rejected; decoding special tokens is supported.
 */
final class GptBytePairEncoding implements Encoding {
    private final String name;
    private final Pattern pattern;
    private final TokenEncoder<ImmutableByteArray, Integer> encoder;
    private final TokenEncoder<String, Integer> specialTokensEncoder;

    /**
     * Creates the encoding from the given parameters.
     *
     * @param params supplies the name, the split pattern, the regular token
     *               table and the special token table
     */
    GptBytePairEncoding(GptBytePairEncodingParams params) {
        this.name = params.getName();
        this.pattern = params.getPattern();
        this.encoder = new TokenEncoder<>(params.getEncoder(), ImmutableByteArray::from);
        this.specialTokensEncoder = new TokenEncoder<>(params.getSpecialTokensEncoder());
    }

    /**
     * Encodes the text into a list of token ids.
     *
     * @param text the text to encode
     * @return the token ids, in the order their source matches appear in the text
     * @throws UnsupportedOperationException if the text contains any special token
     */
    @Override
    public List<Integer> encode(String text) {
        // Special tokens cannot be encoded yet, so reject the input up front
        // instead of silently encoding them as ordinary text.
        for (String specialToken : this.specialTokensEncoder.getDecodedTokens()) {
            if (text.contains(specialToken)) {
                throw new UnsupportedOperationException("Encoding special tokens is not supported yet.");
            }
        }

        List<Integer> tokens = new ArrayList<>();
        Matcher matcher = this.pattern.matcher(text);
        while (matcher.find()) {
            ImmutableByteArray match = ImmutableByteArray.from(matcher.group());
            if (this.encoder.containsDecodedToken(match)) {
                // The whole match is a known token: no merging required.
                tokens.add(this.encoder.encode(match));
            } else {
                tokens.addAll(this.bytePairMerge(match));
            }
        }
        return tokens;
    }

    /**
     * Counts the tokens the text encodes to. This performs a full
     * {@link #encode(String)} and therefore fails under the same conditions.
     *
     * @param text the text to measure
     * @return the number of tokens produced by encoding the text
     */
    @Override
    public int countTokens(String text) {
        return this.encode(text).size();
    }

    /**
     * Decodes the tokens and interprets the resulting bytes as UTF-8.
     *
     * @param tokens the token ids to decode
     * @return the decoded text
     * @throws IllegalArgumentException if a token id is unknown
     */
    @Override
    public String decode(List<Integer> tokens) {
        return new String(this.decodeBytes(tokens), StandardCharsets.UTF_8);
    }

    /**
     * Decodes the tokens into the concatenation of their byte sequences.
     *
     * @param tokens the token ids to decode
     * @return the decoded bytes, in token order
     * @throws IllegalArgumentException if a token id is unknown
     */
    @Override
    public byte[] decodeBytes(List<Integer> tokens) {
        // Gather the per-token byte arrays first so the result can be
        // allocated once at its final size, without boxing single bytes.
        List<byte[]> chunks = new ArrayList<>(tokens.size());
        int totalLength = 0;
        for (int token : tokens) {
            byte[] chunk = this.decodeToken(token);
            chunks.add(chunk);
            totalLength += chunk.length;
        }

        byte[] decoded = new byte[totalLength];
        int offset = 0;
        for (byte[] chunk : chunks) {
            System.arraycopy(chunk, 0, decoded, offset, chunk.length);
            offset += chunk.length;
        }
        return decoded;
    }

    /**
     * Returns the name this encoding was created with.
     *
     * @return the encoding name
     */
    @Override
    public String getName() {
        return this.name;
    }

    /**
     * Splits a piece into tokens by repeatedly merging the adjacent pair of
     * segments with the lowest rank.
     *
     * <p>{@code parts} holds the segment boundaries: entry {@code k} marks the
     * byte offset where segment {@code k} starts, and its {@code rank} is the
     * rank of segment {@code k} joined with segment {@code k + 1}
     * ({@code Integer.MAX_VALUE} when that pair is not in the encoder table).
     * The last entry is the end boundary and never starts a segment.
     *
     * @param piece the bytes of one pattern match
     * @return the token ids of the segments that remain after merging
     */
    private List<Integer> bytePairMerge(ImmutableByteArray piece) {
        int boundaryCount = piece.length() + 1;
        List<PieceIndexToRank> parts = new ArrayList<>(boundaryCount);
        for (int offset = 0; offset < boundaryCount; offset++) {
            parts.add(new PieceIndexToRank(offset, Integer.MAX_VALUE));
        }

        // Initial ranks of every adjacent byte pair.
        for (int i = 0; i < parts.size() - 2; i++) {
            parts.get(i).rank = this.getRank(piece, parts, i, 0).orElse(Integer.MAX_VALUE);
        }

        while (parts.size() > 1) {
            // Find the first pair with the lowest rank.
            int minRankIndex = 0;
            int minRank = Integer.MAX_VALUE;
            for (int i = 0; i < parts.size() - 1; i++) {
                int rank = parts.get(i).rank;
                if (rank < minRank) {
                    minRank = rank;
                    minRankIndex = i;
                }
            }
            if (minRank == Integer.MAX_VALUE) {
                // No remaining pair is mergeable.
                break;
            }

            // The boundary at minRankIndex + 1 is about to disappear. The ranks
            // of the merged segment and of its left neighbour are recomputed
            // BEFORE the removal, using skip = 1 to look past that boundary.
            parts.get(minRankIndex).rank =
                    this.getRank(piece, parts, minRankIndex, 1).orElse(Integer.MAX_VALUE);
            if (minRankIndex > 0) {
                parts.get(minRankIndex - 1).rank =
                        this.getRank(piece, parts, minRankIndex - 1, 1).orElse(Integer.MAX_VALUE);
            }
            parts.remove(minRankIndex + 1);
        }

        // Each pair of consecutive boundaries now delimits one token.
        List<Integer> tokens = new ArrayList<>(parts.size() - 1);
        for (int i = 0; i < parts.size() - 1; i++) {
            int from = parts.get(i).index;
            int to = parts.get(i + 1).index;
            tokens.add(this.encoder.encode(piece.getBytesBetween(from, to)));
        }
        return tokens;
    }

    /**
     * Looks up the rank of the bytes between the boundary at
     * {@code startIndex} and the boundary at {@code startIndex + skip + 2}.
     *
     * @param piece      the bytes being merged
     * @param parts      the current segment boundaries
     * @param startIndex index into {@code parts} of the first boundary
     * @param skip       number of additional boundaries to look past
     * @return the rank of that byte span, or empty if the end boundary would
     *         lie outside {@code parts} or the span is not in the encoder table
     */
    private Optional<Integer> getRank(ImmutableByteArray piece, List<PieceIndexToRank> parts, int startIndex, int skip) {
        int endPartIndex = startIndex + skip + 2;
        if (endPartIndex >= parts.size()) {
            return Optional.empty();
        }
        int pieceStartIndex = parts.get(startIndex).index;
        int pieceEndIndex = parts.get(endPartIndex).index;
        return this.encoder.encodeIfPresent(piece.getBytesBetween(pieceStartIndex, pieceEndIndex));
    }

    /**
     * Resolves a single token id to its bytes, consulting the regular token
     * table first and the special token table second.
     *
     * @param token the token id
     * @return the bytes of the token; special tokens are returned UTF-8 encoded
     * @throws IllegalArgumentException if neither table knows the token
     */
    private byte[] decodeToken(int token) {
        Optional<ImmutableByteArray> decodedToken = this.encoder.decodeIfPresent(token);
        if (decodedToken.isPresent()) {
            return decodedToken.get().getRawArray();
        }
        Optional<String> decodedSpecialToken = this.specialTokensEncoder.decodeIfPresent(token);
        if (decodedSpecialToken.isPresent()) {
            return decodedSpecialToken.get().getBytes(StandardCharsets.UTF_8);
        }
        throw new IllegalArgumentException("Unknown token for decoding: " + token);
    }

    /**
     * A segment boundary inside a piece: the byte offset where a segment
     * starts, together with the mutable rank of merging that segment with the
     * one following it.
     */
    private static class PieceIndexToRank {
        private final int index;
        private int rank;

        public PieceIndexToRank(int index, int rank) {
            this.index = index;
            this.rank = rank;
        }
    }
}

