/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.huggingface.translator;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.translator.ZeroShotClassificationInput;
import ai.djl.modality.nlp.translator.ZeroShotClassificationOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.Map;
import java.util.UUID;

public class ZeroShotClassificationTranslator
implements NoBatchifyTranslator<ZeroShotClassificationInput, ZeroShotClassificationOutput> {
    private HuggingFaceTokenizer tokenizer;
    private boolean int32;
    private Predictor<NDList, NDList> predictor;

    ZeroShotClassificationTranslator(HuggingFaceTokenizer tokenizer, boolean int32) {
        this.tokenizer = tokenizer;
        this.int32 = int32;
    }

    public void prepare(TranslatorContext ctx) throws IOException, ModelException {
        Model model = ctx.getModel();
        this.predictor = model.newPredictor((Translator)new NoopTranslator(null));
        ctx.getPredictorManager().attachInternal(UUID.randomUUID().toString(), new AutoCloseable[]{this.predictor});
    }

    public NDList processInput(TranslatorContext ctx, ZeroShotClassificationInput input) {
        ctx.setAttachment("input", (Object)input);
        return new NDList();
    }

    public ZeroShotClassificationOutput processOutput(TranslatorContext ctx, NDList list) throws TranslateException {
        ZeroShotClassificationInput input = (ZeroShotClassificationInput)ctx.getAttachment("input");
        String template = input.getHypothesisTemplate();
        String[] candidates = input.getCandidates();
        if (candidates == null || candidates.length == 0) {
            throw new TranslateException("Missing candidates in input");
        }
        NDManager manager = ctx.getNDManager();
        NDList output = new NDList(candidates.length);
        for (String candidate : candidates) {
            String hypothesis = this.applyTemplate(template, candidate);
            Encoding encoding = this.tokenizer.encode(input.getText(), hypothesis);
            NDList in = encoding.toNDList(manager, false, this.int32);
            NDList batch = Batchifier.STACK.batchify(new NDList[]{in});
            output.add((Object)((NDArray)((NDList)this.predictor.predict((Object)batch)).get(0)));
        }
        NDArray logits = NDArrays.concat((NDList)output);
        if (input.isMultiLabel()) {
            logits = logits.get(":, -1", new Object[0]);
            logits = logits.softmax(-1);
        } else {
            logits = logits.get(new NDIndex(":, {}", new Object[]{manager.create(new int[]{0, 2})}));
            logits = logits.softmax(1);
            logits = logits.get(":, -1", new Object[0]);
        }
        long[] indices = logits.argSort(-1, false).toLongArray();
        float[] probabilities = logits.toFloatArray();
        String[] labels = new String[candidates.length];
        double[] scores = new double[candidates.length];
        for (int i = 0; i < labels.length; ++i) {
            int index = (int)indices[i];
            labels[i] = candidates[index];
            scores[i] = probabilities[index];
        }
        return new ZeroShotClassificationOutput(input.getText(), labels, scores);
    }

    private String applyTemplate(String template, String arg) {
        int pos = template.indexOf("{}");
        if (pos == -1) {
            return template + arg;
        }
        int len = template.length();
        return template.substring(0, pos) + arg + template.substring(pos + 2, len);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer) {
        return new Builder(tokenizer);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder builder = ZeroShotClassificationTranslator.builder(tokenizer);
        builder.configure(arguments);
        return builder;
    }

    public static final class Builder {
        private HuggingFaceTokenizer tokenizer;
        private boolean int32;

        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        public Builder optInt32(boolean int32) {
            this.int32 = int32;
            return this;
        }

        public void configure(Map<String, ?> arguments) {
            this.optInt32(ArgumentsUtil.booleanValue(arguments, (String)"int32"));
        }

        public ZeroShotClassificationTranslator build() {
            return new ZeroShotClassificationTranslator(this.tokenizer, this.int32);
        }
    }
}

