/*
 * Decompiled with CFR 0.152.
 */
package com.didalgo.gpt3;

import com.didalgo.gpt3.ByteSequence;
import com.didalgo.gpt3.Encoding;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Tokenizer that converts text to lists of integer token ids and back, using the
 * rank table, special tokens and splitting pattern supplied by an {@link Encoding}.
 *
 * <p>Encoding splits the text with the encoding's pattern, looks each piece up in
 * the rank table, and falls back to byte-pair merging for pieces that are not a
 * single token. Special tokens are only emitted when the caller allows them.
 */
public class GPT3Tokenizer {
    /** Byte sequence to token id (rank). */
    private final Map<ByteSequence, Integer> encoder;
    /** Inverse of {@link #encoder}. */
    private final Map<Integer, ByteSequence> decoder;
    /** Special token text to token id. */
    private final Map<String, Integer> specialTokensEncoder;
    /** Inverse of {@link #specialTokensEncoder}. */
    private final Map<Integer, String> specialTokensDecoder;
    /** Pattern used to split ordinary text into pieces. */
    private final Pattern pattern;
    /** Pattern matching any of the special token strings. */
    private final Pattern specialPattern;

    /**
     * Creates a tokenizer from the given encoding.
     *
     * <p>Note: this constructor calls the overridable {@link #createSpecialRegex(Map)}.
     *
     * @param encoding source of the mergeable ranks, special tokens and split pattern
     * @throws IllegalStateException if two entries of a rank or special-token map share
     *         the same id (raised by {@code Collectors.toMap} while inverting the map)
     */
    public GPT3Tokenizer(Encoding encoding) {
        this.encoder = encoding.mergeableRanks();
        this.decoder = this.encoder.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
        this.specialTokensEncoder = encoding.specialTokens();
        this.specialTokensDecoder = this.specialTokensEncoder.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
        this.pattern = encoding.pattern();
        this.specialPattern = this.createSpecialRegex(encoding.specialTokens());
    }

    /**
     * Builds a pattern that matches any key of the given map literally.
     *
     * @param specialTokensEncoder map whose keys are the special token strings
     * @return alternation of the quoted keys; an empty map yields the empty pattern
     */
    protected Pattern createSpecialRegex(Map<String, ?> specialTokensEncoder) {
        String joinedPattern = specialTokensEncoder.keySet().stream()
                .map(Pattern::quote)
                .collect(Collectors.joining("|"));
        return Pattern.compile(joinedPattern);
    }

    /**
     * Converts token ids back to text.
     *
     * @param tokens token ids, ordinary or special
     * @return the concatenated token bytes decoded as UTF-8
     * @throws IllegalArgumentException if an id is neither an ordinary nor a special token
     */
    public String decode(List<Integer> tokens) {
        return this.decodeImpl(tokens);
    }

    /**
     * Implementation of {@link #decode(List)}.
     *
     * @param tokens token ids, ordinary or special
     * @return the concatenated token bytes decoded as UTF-8
     * @throws IllegalArgumentException if an id is neither an ordinary nor a special token
     */
    protected String decodeImpl(List<Integer> tokens) {
        ByteArrayOutputStream result = new ByteArrayOutputStream();
        for (Integer token : tokens) {
            ByteSequence bytes = this.decoder.get(token);
            if (bytes != null) {
                result.writeBytes(bytes.toByteArray());
                continue;
            }
            String special = this.specialTokensDecoder.get(token);
            if (special == null) {
                throw new IllegalArgumentException("Unknown token id: " + token);
            }
            // The buffer is decoded as UTF-8 below, so special-token text must be
            // written as UTF-8 too (identical bytes for ASCII-only tokens).
            result.writeBytes(special.getBytes(StandardCharsets.UTF_8));
        }
        return result.toString(StandardCharsets.UTF_8);
    }

    /**
     * Returns the pattern that matches special tokens.
     *
     * @return the special-token pattern built in the constructor
     */
    protected Pattern getTlSpecialRegex() {
        return this.specialPattern;
    }

    /**
     * Returns the pattern that splits ordinary text into pieces.
     *
     * @return the encoding's split pattern
     */
    protected Pattern getTlRegex() {
        return this.pattern;
    }

    /**
     * Encodes text with no special tokens allowed; special-token strings occurring
     * in the text are tokenized as ordinary text.
     *
     * @param text text to encode
     * @return token ids
     */
    public List<Integer> encode(String text) {
        return this.encode(text, false);
    }

    /**
     * Encodes text, allowing either all special tokens or none.
     *
     * @param text text to encode
     * @param allowedSpecial {@code true} to allow every special token of the encoding,
     *        {@code false} to allow none
     * @return token ids
     */
    public List<Integer> encode(String text, boolean allowedSpecial) {
        return this.encode(text, allowedSpecial ? this.specialTokensEncoder.keySet() : Set.of());
    }

    /**
     * Encodes text, emitting special-token ids only for the given special tokens.
     *
     * @param text text to encode
     * @param allowedSpecial special token strings that may be emitted as special tokens
     * @return token ids
     */
    public List<Integer> encode(String text, Set<String> allowedSpecial) {
        return this.encodeImpl(text, allowedSpecial);
    }

    /**
     * Implementation of {@link #encode(String, Set)}.
     *
     * <p>Repeatedly finds the next allowed special token, encodes the ordinary text
     * before it, then appends the special token's id, until the text is exhausted.
     *
     * @param text text to encode
     * @param allowedSpecial special token strings that may be emitted as special tokens
     * @return token ids
     */
    protected List<Integer> encodeImpl(String text, Set<String> allowedSpecial) {
        Pattern regex = this.getTlRegex();
        // A single matcher over the whole text: every offset it reports is absolute,
        // so skipping a disallowed special token cannot desynchronize the positions.
        Matcher specialMatcher = this.getTlSpecialRegex().matcher(text);
        List<Integer> ret = new ArrayList<>(text.length() / 4);
        int start = 0;
        while (true) {
            boolean foundSpecial = findNextAllowedSpecial(specialMatcher, start, text.length(), allowedSpecial);
            int end = foundSpecial ? specialMatcher.start() : text.length();

            // Ordinary text between the previous special token and the next one.
            Matcher matcher = regex.matcher(text.substring(start, end));
            while (matcher.find()) {
                ByteSequence piece = ByteSequence.from(matcher.group());
                Integer token = this.encoder.get(piece);
                if (token != null) {
                    ret.add(token);
                } else {
                    this.bytePairMerge(piece, ret);
                }
            }

            if (!foundSpecial) {
                break;
            }
            ret.add(this.specialTokensEncoder.get(specialMatcher.group()));
            start = specialMatcher.end();
        }
        return ret;
    }

    /**
     * Positions {@code matcher} on the first non-empty match at or after {@code from}
     * whose text is contained in {@code allowedSpecial}.
     *
     * @param matcher special-token matcher over the full text
     * @param from index to start searching at
     * @param limit length of the text; searching never starts beyond it
     * @param allowedSpecial special token strings that count as a hit
     * @return {@code true} if such a match was found (the matcher then holds it),
     *         {@code false} otherwise
     */
    private static boolean findNextAllowedSpecial(Matcher matcher, int from, int limit, Set<String> allowedSpecial) {
        int searchFrom = from;
        // Matcher.find(int) throws when the index exceeds the input length, hence the bound.
        while (searchFrom <= limit && matcher.find(searchFrom)) {
            // Zero-length matches (empty special pattern) can never be a special token.
            if (matcher.end() > matcher.start() && allowedSpecial.contains(matcher.group())) {
                return true;
            }
            // Disallowed occurrence: resume one character later, as overlapping
            // candidates may start inside it.
            searchFrom = matcher.start() + 1;
        }
        return false;
    }

    /**
     * Looks up the rank of merging the two adjacent parts that begin at
     * {@code partsList[startIdx]} and {@code partsList[startIdx + 1]}.
     *
     * @param piece bytes being tokenized
     * @param partsList current part boundaries; each entry's {@code start} is a byte offset
     * @param startIdx index of the first of the two parts
     * @return the rank of the merged byte range, or {@link Integer#MAX_VALUE} if the
     *         range is not in the rank table or there are not two parts at that index
     */
    protected int getRank(ByteSequence piece, List<IntPair> partsList, int startIdx) {
        if (startIdx + 2 >= partsList.size()) {
            return Integer.MAX_VALUE;
        }
        ByteSequence bytes = piece.subSequence(partsList.get(startIdx).start, partsList.get(startIdx + 2).start);
        Integer rank = this.encoder.get(bytes);
        return rank != null ? rank : Integer.MAX_VALUE;
    }

    /**
     * Splits {@code piece} into tokens by repeatedly merging the adjacent pair with the
     * lowest rank, and appends the resulting token ids to {@code result}.
     *
     * <p>If a final part is not present in the rank table, {@code null} is appended for it.
     *
     * @param piece bytes to tokenize
     * @param result collection receiving the token ids in order
     * @return number of ids appended
     */
    protected int bytePairMerge(ByteSequence piece, Collection<Integer> result) {
        int length = piece.length();
        // One boundary per byte offset plus a terminating boundary at piece.length().
        // For each entry, 'start' is the offset and 'end' caches the rank of merging
        // the part starting there with the following part.
        List<IntPair> parts = new ArrayList<>(length + 1);
        for (int i = 0; i <= length; ++i) {
            parts.add(new IntPair(i, Integer.MAX_VALUE));
        }
        for (int i = 0; i < parts.size() - 2; ++i) {
            int rank = this.getRank(piece, parts, i);
            if (rank != Integer.MAX_VALUE) {
                parts.get(i).end = rank;
            }
        }

        while (parts.size() > 1) {
            // Find the leftmost pair with the lowest rank.
            int minRank = Integer.MAX_VALUE;
            int minIndex = -1;
            for (int i = 0; i < parts.size() - 1; ++i) {
                int rank = parts.get(i).end;
                if (rank < minRank) {
                    minRank = rank;
                    minIndex = i;
                }
            }
            if (minRank == Integer.MAX_VALUE) {
                break; // nothing left to merge
            }
            // Merge by dropping the boundary between the two parts, then refresh the
            // cached ranks of the merged part and of its left neighbour.
            parts.remove(minIndex + 1);
            parts.get(minIndex).end = this.getRank(piece, parts, minIndex);
            if (minIndex > 0) {
                parts.get(minIndex - 1).end = this.getRank(piece, parts, minIndex - 1);
            }
        }

        int resultCount = 0;
        for (int i = 0; i < parts.size() - 1; ++i) {
            int from = parts.get(i).start;
            int to = parts.get(i + 1).start;
            result.add(this.encoder.get(piece.subSequence(from, to)));
            ++resultCount;
        }
        return resultCount;
    }

    /**
     * Mutable pair of ints. In {@link #bytePairMerge}, {@code start} is a byte offset
     * and {@code end} holds a merge rank rather than an end offset.
     */
    private static class IntPair {
        int start;
        int end;

        IntPair(int start, int end) {
            this.start = start;
            this.end = end;
        }
    }
}

