/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp;

import ai.djl.modality.nlp.Vocabulary;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A {@link Vocabulary} that keeps its tokens in memory.
 *
 * <p>A vocabulary is built either from a {@link VocabularyBuilder}, which counts word
 * occurrences in a set of sentences and only assigns an index to words seen at least
 * {@code minFrequency} times, or from an explicit list of tokens, each of which is assigned an
 * index unconditionally. In both cases the unknown token is registered as a reserved token, and
 * reserved tokens receive their indices before any other token.
 */
public class SimpleVocabulary implements Vocabulary {
    /** Per-token bookkeeping; also holds words that have not (yet) reached {@code minFrequency}. */
    private final Map<String, TokenInfo> tokens = new ConcurrentHashMap<>();
    /** Tokens that have an assigned index; the position in this list is the token's index. */
    private final List<String> indexToToken = new ArrayList<>();
    private final Set<String> reservedTokens;
    private final int minFrequency;
    private final String unknownToken;

    /**
     * Creates a vocabulary from the settings and sentences collected by a builder.
     *
     * <p>Reserved tokens (including the unknown token) are indexed first. Every other word is
     * indexed once it has been seen {@code minFrequency} times.
     *
     * @param builder the builder holding reserved tokens, minimum frequency, unknown token and
     *     the sentences to count words from
     */
    public SimpleVocabulary(VocabularyBuilder builder) {
        // Copy instead of aliasing: the unknown token is added below, and that must not leak
        // back into the builder (which may be reused for another build()).
        this.reservedTokens = new HashSet<>(builder.reservedTokens);
        // A threshold below 1 could never be hit by the "count == threshold" check in addWord,
        // so treat it as "index every word on first sight".
        this.minFrequency = Math.max(1, builder.minFrequency);
        this.unknownToken = builder.unknownToken;
        this.reservedTokens.add(this.unknownToken);
        addTokens(this.reservedTokens);
        for (List<String> sentence : builder.sentences) {
            for (String word : sentence) {
                addWord(word);
            }
        }
    }

    /**
     * Creates a vocabulary from an explicit list of tokens.
     *
     * <p>The unknown token {@code "<unk>"} is indexed first, followed by the given tokens in
     * list order. No frequency filtering is applied to the given tokens.
     *
     * @param tokens the tokens to add to the vocabulary
     */
    public SimpleVocabulary(List<String> tokens) {
        this.reservedTokens = new HashSet<>();
        this.minFrequency = 10;
        this.unknownToken = "<unk>";
        this.reservedTokens.add(this.unknownToken);
        addTokens(this.reservedTokens);
        addTokens(tokens);
    }

    /**
     * Counts one occurrence of a word and assigns it an index when its count reaches
     * {@code minFrequency}. Reserved tokens are ignored because they are already indexed.
     */
    private void addWord(String token) {
        if (reservedTokens.contains(token)) {
            return;
        }
        TokenInfo tokenInfo = tokens.get(token);
        if (tokenInfo == null) {
            tokenInfo = new TokenInfo();
            tokens.put(token, tokenInfo);
        }
        if (++tokenInfo.frequency == minFrequency) {
            // The index must be the position in indexToToken. tokens.size() would also count
            // words that are still below the threshold and therefore have no slot in the list.
            tokenInfo.index = indexToToken.size();
            indexToToken.add(token);
        }
    }

    /** Assigns the next free index to each given token, bypassing the frequency threshold. */
    private void addTokens(Collection<String> tokens) {
        for (String token : tokens) {
            TokenInfo tokenInfo = new TokenInfo();
            tokenInfo.frequency = Integer.MAX_VALUE;
            tokenInfo.index = indexToToken.size();
            indexToToken.add(token);
            this.tokens.put(token, tokenInfo);
        }
    }

    /**
     * Returns the bookkeeping entry of a token that has an assigned index, or {@code null} if
     * the token was never seen or has not reached {@code minFrequency}.
     */
    private TokenInfo findIndexed(String token) {
        TokenInfo tokenInfo = tokens.get(token);
        return tokenInfo != null && tokenInfo.index >= 0L ? tokenInfo : null;
    }

    /**
     * Checks whether a token has an index in this vocabulary.
     *
     * @param token the token to look up
     * @return {@code true} if {@link #getIndex(String)} would return the token's own index
     *     rather than the index of the unknown token
     */
    @Override
    public boolean contains(String token) {
        return findIndexed(token) != null;
    }

    /**
     * Returns the token stored at an index.
     *
     * @param index the index to look up
     * @return the token at {@code index}, or the unknown token if {@code index} is negative or
     *     not smaller than {@link #size()}
     */
    @Override
    public String getToken(long index) {
        if (index < 0L || index >= (long) indexToToken.size()) {
            return unknownToken;
        }
        return indexToToken.get((int) index);
    }

    /**
     * Returns the index of a token.
     *
     * @param token the token to look up
     * @return the token's index, or the index of the unknown token if the token has no index
     */
    @Override
    public long getIndex(String token) {
        TokenInfo tokenInfo = findIndexed(token);
        if (tokenInfo != null) {
            return tokenInfo.index;
        }
        return tokens.get(unknownToken).index;
    }

    /**
     * Returns the number of assigned indices.
     *
     * @return the number of indexed entries; valid indices are {@code 0} to {@code size() - 1}
     */
    @Override
    public long size() {
        return indexToToken.size();
    }

    /** Occurrence count and assigned index of a single token. */
    private static final class TokenInfo {
        int frequency;
        /** {@code -1} until the token is assigned an index. */
        long index = -1L;
    }

    /** Collects the sentences and options used to construct a {@link SimpleVocabulary}. */
    public static class VocabularyBuilder {
        protected List<List<String>> sentences = new ArrayList<>();
        protected Set<String> reservedTokens = new HashSet<>();
        protected int minFrequency = 10;
        protected String unknownToken = "<unk>";

        /**
         * Sets how many times a word must occur before it is assigned an index.
         *
         * @param minFrequency the minimum number of occurrences
         * @return this builder
         */
        public VocabularyBuilder optMinFrequency(int minFrequency) {
            this.minFrequency = minFrequency;
            return this;
        }

        /**
         * Sets the token returned for unknown words and out-of-range indices.
         *
         * @param unknownToken the unknown token
         * @return this builder
         */
        public VocabularyBuilder optUnknownToken(String unknownToken) {
            this.unknownToken = unknownToken;
            return this;
        }

        /**
         * Adds tokens that are indexed regardless of how often they occur.
         *
         * @param reservedTokens the tokens to reserve
         * @return this builder
         */
        public VocabularyBuilder optReservedTokens(Collection<String> reservedTokens) {
            this.reservedTokens.addAll(reservedTokens);
            return this;
        }

        /**
         * Adds one sentence whose words are counted when the vocabulary is built.
         *
         * @param sentence the words of the sentence
         * @return this builder
         */
        public VocabularyBuilder add(List<String> sentence) {
            this.sentences.add(sentence);
            return this;
        }

        /**
         * Adds several sentences whose words are counted when the vocabulary is built.
         *
         * @param sentences the sentences to add
         * @return this builder
         */
        public VocabularyBuilder addAll(List<List<String>> sentences) {
            this.sentences.addAll(sentences);
            return this;
        }

        /**
         * Builds the vocabulary from the collected sentences and options.
         *
         * @return a new {@link SimpleVocabulary}
         */
        public SimpleVocabulary build() {
            return new SimpleVocabulary(this);
        }
    }
}

