/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.VisionLanguageInput;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.BaseImagePreProcessor;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class ZeroShotObjectDetectionTranslator
implements NoBatchifyTranslator<VisionLanguageInput, DetectedObjects> {
    private HuggingFaceTokenizer tokenizer;
    private BaseImageTranslator<?> imageProcessor;
    private boolean int32;
    private float threshold;

    ZeroShotObjectDetectionTranslator(HuggingFaceTokenizer tokenizer, BaseImageTranslator<?> imageProcessor, boolean int32, float threshold) {
        this.tokenizer = tokenizer;
        this.imageProcessor = imageProcessor;
        this.int32 = int32;
        this.threshold = threshold;
    }

    public NDList processInput(TranslatorContext ctx, VisionLanguageInput input) throws TranslateException {
        NDManager manager = ctx.getNDManager();
        String[] candidates = input.getCandidates();
        if (candidates == null || candidates.length == 0) {
            throw new TranslateException("Missing candidates in input");
        }
        Encoding[] encodings = this.tokenizer.batchEncode(candidates);
        NDList list = Encoding.toNDList(encodings, manager, false, this.int32);
        Image img = input.getImage();
        NDList imageFeatures = this.imageProcessor.processInput(ctx, img);
        NDArray array = ((NDArray)imageFeatures.get(0)).expandDims(0);
        list.add((Object)array);
        ctx.setAttachment("candidates", (Object)candidates);
        return list;
    }

    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) throws TranslateException {
        NDArray logits = list.get("logits");
        NDArray boxes = list.get("pred_boxes");
        NDArray labels = logits.argMax(-1);
        NDArray scores = logits.max(new int[]{-1}).getNDArrayInternal().sigmoid();
        NDArray selected = scores.gt((Number)Float.valueOf(this.threshold));
        scores = scores.get(selected);
        labels = labels.get(selected);
        boxes = boxes.get(selected);
        float[] prob = scores.toFloatArray();
        long[] labelsIndex = labels.toLongArray();
        float[] box = boxes.toFloatArray();
        String[] candidates = (String[])ctx.getAttachment("candidates");
        ArrayList<String> classes = new ArrayList<String>(labelsIndex.length);
        ArrayList<Double> probabilities = new ArrayList<Double>(labelsIndex.length);
        ArrayList<Rectangle> boundingBoxes = new ArrayList<Rectangle>(labelsIndex.length);
        int width = (Integer)ctx.getAttachment("width");
        int height = (Integer)ctx.getAttachment("height");
        for (int i = 0; i < labelsIndex.length; ++i) {
            classes.add(candidates[(int)labelsIndex[i]]);
            int pos = i * 4;
            float x = box[pos];
            float y = box[pos + 1];
            float w = box[pos + 2];
            float h = box[pos + 3];
            x -= w / 2.0f;
            y -= h / 2.0f;
            if (width > height) {
                y = y * (float)width / (float)height;
                h = h * (float)width / (float)height;
            } else if (width < height) {
                x = x * (float)height / (float)width;
                w = w * (float)height / (float)width;
            }
            Rectangle bbox = new Rectangle((double)x, (double)y, (double)w, (double)h);
            boundingBoxes.add(bbox);
            probabilities.add(Double.valueOf(prob[i]));
        }
        return new DetectedObjects(classes, probabilities, boundingBoxes);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer) {
        return new Builder(tokenizer);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder builder = ZeroShotObjectDetectionTranslator.builder(tokenizer);
        builder.configure(arguments);
        return builder;
    }

    public static final class Builder
    extends BaseImageTranslator.BaseBuilder<Builder> {
        private HuggingFaceTokenizer tokenizer;
        private boolean int32;
        private float threshold = 0.2f;

        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        protected Builder self() {
            return this;
        }

        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return this;
        }

        public Builder optInt32(boolean int32) {
            this.int32 = int32;
            return this;
        }

        public void configure(Map<String, ?> arguments) {
            this.configPreProcess(arguments);
            this.optInt32(ArgumentsUtil.booleanValue(arguments, (String)"int32"));
            this.optThreshold(ArgumentsUtil.floatValue(arguments, (String)"threshold", (float)0.2f));
        }

        public ZeroShotObjectDetectionTranslator build() throws IOException {
            BaseImagePreProcessor imageProcessor = new BaseImagePreProcessor((BaseImageTranslator.BaseBuilder)this);
            return new ZeroShotObjectDetectionTranslator(this.tokenizer, (BaseImageTranslator<?>)imageProcessor, this.int32, this.threshold);
        }
    }
}

