/*
 * Decompiled with CFR 0.152.
 */
package org.hibernate.search.backend.lucene.lowlevel.query.impl;

import java.io.IOException;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilterWeight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.VectorUtil;
import org.hibernate.search.util.common.AssertionFailure;

public class VectorSimilarityFilterQuery
extends Query {
    private final Query query;
    private final float similarityAsScore;

    public static VectorSimilarityFilterQuery create(KnnByteVectorQuery query, float similarityLimit, int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
        return new VectorSimilarityFilterQuery((Query)query, VectorSimilarityFilterQuery.byteSimilarityDistanceToScore(similarityLimit, dimension, vectorSimilarityFunction));
    }

    public static VectorSimilarityFilterQuery create(KnnFloatVectorQuery query, float similarityLimit, VectorSimilarityFunction vectorSimilarityFunction) {
        return new VectorSimilarityFilterQuery((Query)query, VectorSimilarityFilterQuery.floatSimilarityDistanceToScore(similarityLimit, vectorSimilarityFunction));
    }

    private VectorSimilarityFilterQuery(Query query, float similarityAsScore) {
        this.query = query;
        this.similarityAsScore = similarityAsScore;
    }

    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        Query rewritten = this.query.rewrite(indexSearcher);
        if (rewritten == this.query) {
            return this;
        }
        return new VectorSimilarityFilterQuery(rewritten, this.similarityAsScore);
    }

    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        return new SimilarityWeight(this.query.createWeight(searcher, scoreMode, boost), this.similarityAsScore * boost);
    }

    public void visit(QueryVisitor visitor) {
        visitor.visitLeaf((Query)this);
    }

    public String toString(String field) {
        return ((Object)((Object)this)).getClass().getName() + "{query=" + this.query + ", similarityLimit=" + this.similarityAsScore + "}";
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || ((Object)((Object)this)).getClass() != o.getClass()) {
            return false;
        }
        VectorSimilarityFilterQuery that = (VectorSimilarityFilterQuery)((Object)o);
        return Float.compare(this.similarityAsScore, that.similarityAsScore) == 0 && Objects.equals(this.query, that.query);
    }

    public int hashCode() {
        return Objects.hash(this.query, Float.valueOf(this.similarityAsScore));
    }

    private static float floatSimilarityDistanceToScore(float distance, VectorSimilarityFunction similarityFunction) {
        switch (similarityFunction) {
            case EUCLIDEAN: {
                return 1.0f / (1.0f + distance * distance);
            }
            case DOT_PRODUCT: 
            case COSINE: {
                return (1.0f + distance) / 2.0f;
            }
            case MAXIMUM_INNER_PRODUCT: {
                return VectorUtil.scaleMaxInnerProductScore((float)distance);
            }
        }
        throw new AssertionFailure("Unknown similarity function: " + similarityFunction);
    }

    private static float byteSimilarityDistanceToScore(float distance, int dimension, VectorSimilarityFunction similarityFunction) {
        switch (similarityFunction) {
            case EUCLIDEAN: {
                return 1.0f / (1.0f + distance * distance);
            }
            case DOT_PRODUCT: {
                return 0.5f + distance / (float)(dimension * 32768);
            }
            case COSINE: {
                return (1.0f + distance) / 2.0f;
            }
            case MAXIMUM_INNER_PRODUCT: {
                return VectorUtil.scaleMaxInnerProductScore((float)distance);
            }
        }
        throw new AssertionFailure("Unknown similarity function: " + similarityFunction);
    }

    private static class SimilarityWeight
    extends FilterWeight {
        private final float similarityAsScore;

        protected SimilarityWeight(Weight weight, float similarityAsScore) {
            super(weight);
            this.similarityAsScore = similarityAsScore;
        }

        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            Explanation explanation = super.explain(context, doc);
            if (explanation.isMatch() && this.similarityAsScore > explanation.getValue().floatValue()) {
                return Explanation.noMatch((String)"Similarity limit is greater than the vector similarity.", (Explanation[])new Explanation[]{explanation});
            }
            return explanation;
        }

        public Scorer scorer(LeafReaderContext context) throws IOException {
            Scorer scorer = super.scorer(context);
            if (scorer == null) {
                return null;
            }
            return new MinScoreScorer((Weight)this, scorer, this.similarityAsScore);
        }
    }

    private static class MinScoreScorer
    extends Scorer {
        private final Scorer in;
        private final float minScore;
        private float curScore;

        MinScoreScorer(Weight weight, Scorer scorer, float minScore) {
            super(weight);
            this.in = scorer;
            this.minScore = minScore;
        }

        public int docID() {
            return this.in.docID();
        }

        public float score() {
            return this.curScore;
        }

        public int advanceShallow(int target) throws IOException {
            return this.in.advanceShallow(target);
        }

        public float getMaxScore(int upTo) throws IOException {
            return this.in.getMaxScore(upTo);
        }

        public DocIdSetIterator iterator() {
            return TwoPhaseIterator.asDocIdSetIterator((TwoPhaseIterator)this.twoPhaseIterator());
        }

        public TwoPhaseIterator twoPhaseIterator() {
            final TwoPhaseIterator inTwoPhase = this.in.twoPhaseIterator();
            DocIdSetIterator approximation = inTwoPhase == null ? this.in.iterator() : inTwoPhase.approximation();
            return new TwoPhaseIterator(approximation){

                public boolean matches() throws IOException {
                    if (inTwoPhase != null && !inTwoPhase.matches()) {
                        return false;
                    }
                    curScore = in.score();
                    return curScore >= minScore;
                }

                public float matchCost() {
                    return 1000.0f + (inTwoPhase == null ? 0.0f : inTwoPhase.matchCost());
                }
            };
        }
    }
}

