/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.commons.model.expressions;

import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.text.similarity.LevenshteinDistance;
import org.kie.pmml.api.enums.COUNT_HITS;
import org.kie.pmml.api.enums.LOCAL_TERM_WEIGHTS;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.commons.model.KiePMMLExtension;
import org.kie.pmml.commons.model.ProcessingDTO;
import org.kie.pmml.commons.model.abstracts.AbstractKiePMMLComponent;
import org.kie.pmml.commons.model.expressions.ExpressionsUtils;
import org.kie.pmml.commons.model.expressions.KiePMMLExpression;
import org.kie.pmml.commons.model.expressions.KiePMMLTextIndexNormalization;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Expression that counts how often a term occurs inside a text.
 * <p>
 * The term is produced by evaluating the wrapped {@link KiePMMLExpression}; the text is looked up
 * under this component's name via {@link ExpressionsUtils#getFromPossibleSources}. Both are split
 * into words, the term is slid over the text one word at a time, and every window accepted by the
 * configured {@link LevenshteinDistance} counts as a hit. The hit count is then weighted according
 * to the configured {@link LOCAL_TERM_WEIGHTS}.
 * <p>
 * Instances are created through {@link #builder(String, List, KiePMMLExpression)}.
 */
public class KiePMMLTextIndex
extends AbstractKiePMMLComponent
implements KiePMMLExpression {
    private static final long serialVersionUID = -1946996874918753317L;
    private static final Logger logger = LoggerFactory.getLogger(KiePMMLTextIndex.class);
    /** Word separator used when none is configured, or when tokenization is disabled. */
    public static final String DEFAULT_TOKENIZER = "\\s+";
    private final KiePMMLExpression expression;
    private LOCAL_TERM_WEIGHTS localTermWeights = LOCAL_TERM_WEIGHTS.TERM_FREQUENCY;
    private boolean isCaseSensitive = false;
    private int maxLevenshteinDistance = 0;
    private COUNT_HITS countHits = COUNT_HITS.ALL_HITS;
    private String wordSeparatorCharacterRE = "\\s+";
    private boolean tokenize = true;
    private LevenshteinDistance levenshteinDistance;
    private List<KiePMMLTextIndexNormalization> textIndexNormalizations;

    private KiePMMLTextIndex(String name, List<KiePMMLExtension> extensions, KiePMMLExpression expression) {
        super(name, extensions);
        this.expression = expression;
        this.levenshteinDistance = new LevenshteinDistance(this.maxLevenshteinDistance);
    }

    /**
     * Creates a builder for a new text index.
     *
     * @param name       name of the component; also the key under which the text to scan is looked up
     * @param extensions extensions forwarded to the parent component
     * @param expression expression whose evaluation yields the term to search for
     * @return a new builder
     */
    public static Builder builder(String name, List<KiePMMLExtension> extensions, KiePMMLExpression expression) {
        return new Builder(name, extensions, expression);
    }

    /**
     * Counts the occurrences of {@code term} in {@code text} and applies the requested weighting.
     *
     * @param isCaseSensitive          when {@code false}, term and text are lower-cased before comparison
     * @param tokenize                 when {@code true} words are split on {@code wordSeparatorCharacterRE},
     *                                 otherwise on {@link #DEFAULT_TOKENIZER}
     * @param term                     the term to look for
     * @param text                     the text to scan
     * @param wordSeparatorCharacterRE regular expression separating words
     * @param localTermWeights         weighting applied to the hit count
     * @param countHits                strategy used to count hits
     * @param levenshteinDistance      distance function deciding whether a window matches
     * @return the weighted hit count
     * @throws IllegalArgumentException if {@code countHits} or {@code localTermWeights} is not handled
     */
    static double evaluateRaw(boolean isCaseSensitive, boolean tokenize, String term, String text, String wordSeparatorCharacterRE, LOCAL_TERM_WEIGHTS localTermWeights, COUNT_HITS countHits, LevenshteinDistance levenshteinDistance) {
        if (!isCaseSensitive) {
            // Locale.ROOT keeps case folding independent of the JVM default locale
            // (e.g. Turkish would turn "I" into a dotless "ı", which splitText then strips).
            term = term.toLowerCase(Locale.ROOT);
            text = text.toLowerCase(Locale.ROOT);
        }
        Pattern pattern = Pattern.compile(tokenize ? wordSeparatorCharacterRE : DEFAULT_TOKENIZER);
        List<String> terms = KiePMMLTextIndex.splitText(term, pattern);
        List<String> texts = KiePMMLTextIndex.splitText(text, pattern);
        int hits;
        switch (countHits) {
            case ALL_HITS:
                hits = KiePMMLTextIndex.evaluateLevenshteinDistanceAllHits(levenshteinDistance, terms, texts);
                break;
            case BEST_HITS:
                hits = KiePMMLTextIndex.evaluateLevenshteinDistanceBestHits(levenshteinDistance, terms, texts);
                break;
            default:
                throw new IllegalArgumentException("Unknown COUNT_HITS " + countHits);
        }
        switch (localTermWeights) {
            case TERM_FREQUENCY:
                return hits;
            case BINARY:
                return KiePMMLTextIndex.evaluateBinary(hits);
            case LOGARITHMIC:
                return KiePMMLTextIndex.evaluateLogarithmic(hits);
            case AUGMENTED_NORMALIZED_TERM_FREQUENCY:
                return KiePMMLTextIndex.evaluateAugmentedNormalizedTermFrequency(hits, texts);
            default:
                throw new IllegalArgumentException("Unknown LOCAL_TERM_WEIGHTS " + localTermWeights);
        }
    }

    /**
     * Binary weighting: tells whether the term was found at all.
     *
     * @param calculatedLevenshteinDistance the number of hits (despite its name, {@code evaluateRaw}
     *                                      passes a hit count, not a distance)
     * @return {@code 1} if there is at least one hit, {@code 0} otherwise
     */
    static int evaluateBinary(int calculatedLevenshteinDistance) {
        // A hit count is never negative, so the comparison must be strict to ever yield 0.
        return calculatedLevenshteinDistance > 0 ? 1 : 0;
    }

    /**
     * Logarithmic weighting.
     *
     * @param calculatedLevenshteinDistance the number of hits
     * @return {@code log10(1 + hits)}
     */
    static double evaluateLogarithmic(int calculatedLevenshteinDistance) {
        return Math.log10(1.0 + (double) calculatedLevenshteinDistance);
    }

    /**
     * Augmented normalized term frequency weighting:
     * {@code 0.5 * (binary + hits / maxFrequency)}, where {@code maxFrequency} is the number of
     * occurrences of the most frequent word in {@code texts}.
     *
     * @param calculatedLevenshteinDistance the number of hits
     * @param texts                         the words of the scanned text
     * @return the weighted value
     * @throws KiePMMLException if {@code texts} is empty, so no most frequent word exists
     */
    static double evaluateAugmentedNormalizedTermFrequency(int calculatedLevenshteinDistance, List<String> texts) {
        Map<String, Long> wordFrequencies = texts.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        int maxFrequency = wordFrequencies.values().stream()
                .max(Comparator.comparingLong(Long::longValue))
                .map(Long::intValue)
                .orElseThrow(() -> new KiePMMLException("Failed to find most frequent word!"));
        int binaryEvaluation = KiePMMLTextIndex.evaluateBinary(calculatedLevenshteinDistance);
        return 0.5 * ((double) binaryEvaluation + (double) calculatedLevenshteinDistance / (double) maxFrequency);
    }

    /**
     * Counts every window of {@code texts} (as many consecutive words as {@code terms} has) that
     * the distance function accepts as a match for the term.
     *
     * @param levenshteinDistance distance function; a negative result means "no match"
     * @param terms               the words of the term
     * @param texts               the words of the text
     * @return number of matching windows; {@code 0} when {@code terms} is empty or longer than {@code texts}
     */
    static int evaluateLevenshteinDistanceAllHits(LevenshteinDistance levenshteinDistance, List<String> terms, List<String> texts) {
        logger.debug("evaluateLevenshteinDistanceAllHits {} {}", terms, texts);
        if (terms.isEmpty()) {
            // An empty term would compare "" against "" in every window and report spurious hits.
            return 0;
        }
        int batchSize = terms.size();
        int limit = texts.size() - batchSize + 1;
        String toSearch = String.join(" ", terms);
        int toReturn = 0;
        for (int i = 0; i < limit; ++i) {
            String subText = String.join(" ", texts.subList(i, i + batchSize));
            int distance = KiePMMLTextIndex.evaluateLevenshteinDistance(levenshteinDistance, toSearch, subText);
            if (distance > -1) {
                ++toReturn;
            }
        }
        return toReturn;
    }

    /**
     * Counts only the windows of {@code texts} that match the term with the smallest distance
     * found among all matching windows.
     *
     * @param levenshteinDistance distance function; a negative result means "no match"
     * @param terms               the words of the term
     * @param texts               the words of the text
     * @return number of windows sharing the lowest distance; {@code 0} when nothing matches,
     *         when {@code terms} is empty or when it is longer than {@code texts}
     */
    static int evaluateLevenshteinDistanceBestHits(LevenshteinDistance levenshteinDistance, List<String> terms, List<String> texts) {
        logger.debug("evaluateLevenshteinDistanceBestHits {} {}", terms, texts);
        if (terms.isEmpty()) {
            // An empty term would compare "" against "" in every window and report spurious hits.
            return 0;
        }
        int batchSize = terms.size();
        int limit = texts.size() - batchSize + 1;
        String toSearch = String.join(" ", terms);
        // Sorted by distance so that the first entry holds the best (lowest) distance.
        TreeMap<Integer, Integer> hitsByDistance = new TreeMap<>();
        for (int i = 0; i < limit; ++i) {
            String subText = String.join(" ", texts.subList(i, i + batchSize));
            int distance = KiePMMLTextIndex.evaluateLevenshteinDistance(levenshteinDistance, toSearch, subText);
            if (distance > -1) {
                hitsByDistance.merge(distance, 1, Integer::sum);
            }
        }
        // No matching window at all means zero hits; firstKey()/firstEntry() must not be dereferenced blindly.
        return hitsByDistance.isEmpty() ? 0 : hitsByDistance.firstEntry().getValue();
    }

    /**
     * Applies the distance function to the given pair of strings.
     *
     * @param levenshteinDistance distance function to apply
     * @param term                the term
     * @param text                the text fragment to compare against
     * @return the value returned by {@code levenshteinDistance.apply(term, text)}
     */
    static int evaluateLevenshteinDistance(LevenshteinDistance levenshteinDistance, String term, String text) {
        logger.debug("evaluateLevenshteinDistance {} {}", term, text);
        return levenshteinDistance.apply(term, text);
    }

    /**
     * Splits a string into words, removes from each word every character that is not an ASCII
     * letter, a digit or a space, and drops words left empty.
     *
     * @param toSplit the string to split
     * @param pattern the word separator
     * @return the cleaned, non-empty words in their original order
     */
    static List<String> splitText(String toSplit, Pattern pattern) {
        return pattern.splitAsStream(toSplit)
                .map(trm -> trm.replaceAll("[^a-zA-Z0-9 ]", ""))
                .filter(trm -> !trm.isEmpty())
                .collect(Collectors.toList());
    }

    public KiePMMLExpression getExpression() {
        return this.expression;
    }

    public LOCAL_TERM_WEIGHTS getLocalTermWeights() {
        return this.localTermWeights;
    }

    public boolean isCaseSensitive() {
        return this.isCaseSensitive;
    }

    public int getMaxLevenshteinDistance() {
        return this.maxLevenshteinDistance;
    }

    public COUNT_HITS getCountHits() {
        return this.countHits;
    }

    public String getWordSeparatorCharacterRE() {
        return this.wordSeparatorCharacterRE;
    }

    public boolean isTokenize() {
        return this.tokenize;
    }

    public LevenshteinDistance getLevenshteinDistance() {
        return this.levenshteinDistance;
    }

    /**
     * @return an unmodifiable view of the configured normalizations; an empty list when none were set
     */
    public List<KiePMMLTextIndexNormalization> getTextIndexNormalizations() {
        // The field stays null unless the builder sets it; Collections.unmodifiableList(null) would throw.
        if (this.textIndexNormalizations == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.textIndexNormalizations);
    }

    /**
     * Evaluates the wrapped expression to obtain the term, retrieves the text stored under this
     * component's name, applies the configured normalizations (if any) to the text and returns
     * the weighted hit count.
     *
     * @param processingDTO the data used for the evaluation
     * @return the weighted hit count, as a {@code Double}
     * @throws KiePMMLException if no text is available under this component's name
     */
    @Override
    public Object evaluate(ProcessingDTO processingDTO) {
        String term = (String)this.expression.evaluate(processingDTO);
        String text = (String)ExpressionsUtils.getFromPossibleSources(this.name, processingDTO).orElseThrow(() -> new KiePMMLException("No text to scan in " + this));
        if (this.textIndexNormalizations != null) {
            for (KiePMMLTextIndexNormalization textIndexNormalization : this.textIndexNormalizations) {
                text = textIndexNormalization.replace(text, this.isCaseSensitive, this.maxLevenshteinDistance, false, DEFAULT_TOKENIZER);
            }
        }
        return KiePMMLTextIndex.evaluateRaw(this.isCaseSensitive, this.tokenize, term, text, this.wordSeparatorCharacterRE, this.localTermWeights, this.countHits, this.levenshteinDistance);
    }

    @Override
    public String toString() {
        return new StringJoiner(", ", KiePMMLTextIndex.class.getSimpleName() + "[", "]")
                .add("name='" + this.name + "'")
                .add("localTermWeights=" + this.localTermWeights)
                .add("isCaseSensitive=" + this.isCaseSensitive)
                .add("maxLevenshteinDistance=" + this.maxLevenshteinDistance)
                .add("countHits=" + this.countHits)
                .add("wordSeparatorCharacterRE='" + this.wordSeparatorCharacterRE + "'")
                .add("tokenize=" + this.tokenize)
                .toString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        KiePMMLTextIndex that = (KiePMMLTextIndex)o;
        return this.isCaseSensitive == that.isCaseSensitive
                && this.maxLevenshteinDistance == that.maxLevenshteinDistance
                && this.tokenize == that.tokenize
                && this.localTermWeights == that.localTermWeights
                && this.countHits == that.countHits
                && this.wordSeparatorCharacterRE.equals(that.wordSeparatorCharacterRE);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.localTermWeights, this.isCaseSensitive, this.maxLevenshteinDistance, this.countHits, this.wordSeparatorCharacterRE, this.tokenize);
    }

    /**
     * Builder of {@link KiePMMLTextIndex}. Setters receiving {@code null} for an object-typed
     * option leave the current value untouched.
     */
    public static class Builder
    extends AbstractKiePMMLComponent.Builder<KiePMMLTextIndex> {
        private Builder(String name, List<KiePMMLExtension> extensions, KiePMMLExpression expression) {
            super("TextIndex-", () -> new KiePMMLTextIndex(name, extensions, expression));
        }

        public Builder withLocalTermWeights(LOCAL_TERM_WEIGHTS localTermWeights) {
            if (localTermWeights != null) {
                ((KiePMMLTextIndex)this.toBuild).localTermWeights = localTermWeights;
            }
            return this;
        }

        public Builder withIsCaseSensitive(boolean isCaseSensitive) {
            ((KiePMMLTextIndex)this.toBuild).isCaseSensitive = isCaseSensitive;
            return this;
        }

        /**
         * Sets the maximum accepted distance and rebuilds the distance function with it as threshold.
         */
        public Builder withMaxLevenshteinDistance(int maxLevenshteinDistance) {
            ((KiePMMLTextIndex)this.toBuild).maxLevenshteinDistance = maxLevenshteinDistance;
            ((KiePMMLTextIndex)this.toBuild).levenshteinDistance = new LevenshteinDistance(maxLevenshteinDistance);
            return this;
        }

        public Builder withCountHits(COUNT_HITS countHits) {
            if (countHits != null) {
                ((KiePMMLTextIndex)this.toBuild).countHits = countHits;
            }
            return this;
        }

        public Builder withWordSeparatorCharacterRE(String wordSeparatorCharacterRE) {
            if (wordSeparatorCharacterRE != null) {
                ((KiePMMLTextIndex)this.toBuild).wordSeparatorCharacterRE = wordSeparatorCharacterRE;
            }
            return this;
        }

        public Builder withTokenize(boolean tokenize) {
            ((KiePMMLTextIndex)this.toBuild).tokenize = tokenize;
            return this;
        }

        public Builder withTextIndexNormalizations(List<KiePMMLTextIndexNormalization> textIndexNormalizations) {
            if (textIndexNormalizations != null) {
                ((KiePMMLTextIndex)this.toBuild).textIndexNormalizations = textIndexNormalizations;
            }
            return this;
        }
    }
}

