/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.ObjectDetectionTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.PriorityQueue;

/**
 * An {@link ObjectDetectionTranslator} for YOLOv5-style model output.
 *
 * <p>Box-format output is decoded into per-class candidates, filtered by {@code threshold}, and
 * then reduced with greedy per-class non-maximum suppression (NMS) controlled by
 * {@code nmsThreshold}. Detect-layer output is not supported yet.
 */
public class YoloV5Translator extends ObjectDetectionTranslator {
    /** Selects which decoder {@link #processOutput} uses. */
    private final YoloOutputType yoloOutputLayerType;

    /**
     * IoU limit for NMS: a lower-confidence box of the same class whose IoU with an already kept
     * box is not strictly below this value is suppressed.
     */
    private final float nmsThreshold;

    /**
     * Creates a translator from the given builder.
     *
     * @param builder the builder holding the configured output type and NMS threshold
     */
    protected YoloV5Translator(Builder builder) {
        super(builder);
        this.yoloOutputLayerType = builder.outputType;
        this.nmsThreshold = builder.nmsThreshold;
    }

    /**
     * Creates a builder with default settings ({@code AUTO} output type, NMS threshold 0.4).
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder configured from an argument map.
     *
     * <p>Pre- and post-processing options are read from {@code arguments}; see {@link
     * Builder#configPostProcess(Map)} for the {@code outputType} and {@code nmsThreshold} keys.
     *
     * @param arguments configuration arguments
     * @return a new, configured builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    /**
     * Returns the area of the intersection of two axis-aligned rectangles.
     *
     * @param a the first rectangle
     * @param b the second rectangle
     * @return the intersection area, or 0 when the rectangles do not overlap
     */
    protected double boxIntersection(Rectangle a, Rectangle b) {
        // overlap() works on centres, so convert each top-left corner to its centre first.
        double w = this.overlap((a.getX() * 2.0 + a.getWidth()) / 2.0, a.getWidth(), (b.getX() * 2.0 + b.getWidth()) / 2.0, b.getWidth());
        double h = this.overlap((a.getY() * 2.0 + a.getHeight()) / 2.0, a.getHeight(), (b.getY() * 2.0 + b.getHeight()) / 2.0, b.getHeight());
        if (w < 0.0 || h < 0.0) {
            return 0.0;
        }
        return w * h;
    }

    /**
     * Returns the intersection-over-union of two rectangles.
     *
     * <p>If the union area is 0 (two degenerate boxes), the result is {@code NaN}.
     *
     * @param a the first rectangle
     * @param b the second rectangle
     * @return intersection area divided by union area
     */
    protected double boxIou(Rectangle a, Rectangle b) {
        return this.boxIntersection(a, b) / this.boxUnion(a, b);
    }

    /**
     * Returns the area covered by either of two rectangles.
     *
     * @param a the first rectangle
     * @param b the second rectangle
     * @return {@code area(a) + area(b) - intersection(a, b)}
     */
    protected double boxUnion(Rectangle a, Rectangle b) {
        double i = this.boxIntersection(a, b);
        return a.getWidth() * a.getHeight() + b.getWidth() * b.getHeight() - i;
    }

    /**
     * Applies greedy per-class non-maximum suppression to the candidate detections.
     *
     * <p>Classes are processed in index order. Within a class, candidates are visited from highest
     * to lowest confidence; a candidate is kept when its IoU with every previously kept box of the
     * same class is strictly below {@code nmsThreshold} (a {@code NaN} IoU therefore suppresses).
     * When {@code applyRatio} is set, kept boxes are divided by the image width/height.
     *
     * @param list candidate detections; not modified
     * @return the surviving detections
     */
    protected DetectedObjects nms(List<IntermediateResult> list) {
        List<String> retClasses = new ArrayList<>();
        List<Double> retProbs = new ArrayList<>();
        List<BoundingBox> retBB = new ArrayList<>();
        for (int k = 0; k < this.classes.size(); ++k) {
            List<IntermediateResult> candidates = new ArrayList<>();
            for (IntermediateResult result : list) {
                if (result.getDetectedClass() == k) {
                    candidates.add(result);
                }
            }
            // Highest confidence first. List.sort is stable; unlike PriorityQueue.toArray, the
            // resulting order is actually guaranteed.
            candidates.sort((lhs, rhs) -> Double.compare(rhs.getConfidence(), lhs.getConfidence()));

            List<Rectangle> kept = new ArrayList<>();
            for (IntermediateResult candidate : candidates) {
                Rectangle location = candidate.getLocation();
                if (this.isSuppressed(location, kept)) {
                    continue;
                }
                kept.add(location);
                retClasses.add(candidate.getId());
                retProbs.add(candidate.getConfidence());
                retBB.add(this.toOutputRectangle(location));
            }
        }
        return new DetectedObjects(retClasses, retProbs, retBB);
    }

    /** Returns true when {@code location} overlaps any kept box at or above the NMS threshold. */
    private boolean isSuppressed(Rectangle location, List<Rectangle> kept) {
        for (Rectangle keptBox : kept) {
            if (!(this.boxIou(keptBox, location) < (double) this.nmsThreshold)) {
                return true;
            }
        }
        return false;
    }

    /** Scales a kept box by the image size when {@code applyRatio} is set. */
    private Rectangle toOutputRectangle(Rectangle rec) {
        if (this.applyRatio) {
            return new Rectangle(rec.getX() / this.imageWidth, rec.getY() / this.imageHeight, rec.getWidth() / this.imageWidth, rec.getHeight() / this.imageHeight);
        }
        return new Rectangle(rec.getX(), rec.getY(), rec.getWidth(), rec.getHeight());
    }

    /**
     * Returns the signed overlap of two 1-D intervals given by centre and extent.
     *
     * @param x1 centre of the first interval
     * @param w1 extent of the first interval
     * @param x2 centre of the second interval
     * @param w2 extent of the second interval
     * @return overlap length; negative when the intervals are disjoint
     */
    protected double overlap(double x1, double w1, double x2, double w2) {
        double l1 = x1 - w1 / 2.0;
        double l2 = x2 - w2 / 2.0;
        double left = Math.max(l1, l2);
        double r1 = x1 + w1 / 2.0;
        double r2 = x2 + w2 / 2.0;
        double right = Math.min(r1, r2);
        return right - left;
    }

    /**
     * Decodes box-format output and runs NMS on it.
     *
     * <p>The first output array is flattened and read in rows of {@code 5 + classes.size()}
     * floats: centre x, centre y, width, height, a score multiplied into the class score
     * (presumably objectness), then one score per class. A row becomes a candidate when
     * {@code maxClassScore * row[4] > threshold}. Trailing floats that do not fill a row are
     * ignored.
     *
     * @param list model output
     * @return detections after NMS
     */
    protected DetectedObjects processFromBoxOutput(NDList list) {
        float[] flattened = list.get(0).toFloatArray();
        List<IntermediateResult> intermediateResults = new ArrayList<>();
        int sizeClasses = this.classes.size();
        int stride = 5 + sizeClasses;
        int size = flattened.length / stride;
        for (int i = 0; i < size; ++i) {
            int indexBase = i * stride;
            float maxClass = 0.0f;
            int maxIndex = 0;
            for (int c = 0; c < sizeClasses; ++c) {
                float classScore = flattened[indexBase + c + 5];
                if (classScore > maxClass) {
                    maxClass = classScore;
                    maxIndex = c;
                }
            }
            float score = maxClass * flattened[indexBase + 4];
            if (!(score > this.threshold)) {
                continue;
            }
            float xPos = flattened[indexBase];
            float yPos = flattened[indexBase + 1];
            float w = flattened[indexBase + 2];
            float h = flattened[indexBase + 3];
            float left = xPos - w / 2.0f;
            float top = yPos - h / 2.0f;
            float width = w;
            float height = h;
            // Clip at the origin by moving only the left/top edge, so the right/bottom edge keeps
            // its predicted position instead of the whole box being shifted.
            if (left < 0.0f) {
                width = Math.max(0.0f, width + left);
                left = 0.0f;
            }
            if (top < 0.0f) {
                height = Math.max(0.0f, height + top);
                top = 0.0f;
            }
            Rectangle rect = new Rectangle(left, top, width, height);
            intermediateResults.add(new IntermediateResult(this.classes.get(maxIndex), score, maxIndex, rect));
        }
        return this.nms(intermediateResults);
    }

    /**
     * Placeholder for decoding detect-layer output.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    private DetectedObjects processFromDetectOutput() {
        throw new UnsupportedOperationException("detect layer output is not supported yet, check correct YoloV5 export format");
    }

    /**
     * Converts model output to detections according to the configured output type.
     *
     * <p>{@code DETECT} always throws. {@code AUTO} treats a first output with more than two
     * dimensions as detect-layer output (and throws), and otherwise decodes it as box output.
     * {@code BOX} always decodes box output.
     *
     * @param ctx translator context (unused)
     * @param list model output
     * @return detected objects
     * @throws UnsupportedOperationException when detect-layer output is selected
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        switch (this.yoloOutputLayerType) {
            case DETECT:
                return this.processFromDetectOutput();
            case AUTO:
                if (list.get(0).getShape().dimension() > 2) {
                    return this.processFromDetectOutput();
                }
                return this.processFromBoxOutput(list);
            case BOX:
            default:
                return this.processFromBoxOutput(list);
        }
    }

    /** A single candidate detection before non-maximum suppression. */
    protected static final class IntermediateResult {
        private final double confidence;
        private final int detectedClass;
        private final String id;
        private final Rectangle location;

        /**
         * Creates a candidate.
         *
         * @param id class name
         * @param confidence detection score
         * @param detectedClass class index
         * @param location box in the decoded coordinate space
         */
        IntermediateResult(String id, double confidence, int detectedClass, Rectangle location) {
            this.confidence = confidence;
            this.id = id;
            this.detectedClass = detectedClass;
            this.location = location;
        }

        /** @return the detection score */
        public double getConfidence() {
            return this.confidence;
        }

        /** @return the class index */
        public int getDetectedClass() {
            return this.detectedClass;
        }

        /** @return the class name */
        public String getId() {
            return this.id;
        }

        /** @return a fresh copy of the box, so callers cannot share this candidate's instance */
        public Rectangle getLocation() {
            return new Rectangle(this.location.getX(), this.location.getY(), this.location.getWidth(), this.location.getHeight());
        }
    }

    /** Builder for {@link YoloV5Translator}. */
    public static class Builder extends ObjectDetectionTranslator.ObjectDetectionBuilder<Builder> {
        YoloOutputType outputType = YoloOutputType.AUTO;
        float nmsThreshold = 0.4f;

        /**
         * Sets how model output is interpreted.
         *
         * @param outputType the output type
         * @return this builder
         */
        public Builder optOutputType(YoloOutputType outputType) {
            this.outputType = outputType;
            return this;
        }

        /**
         * Sets the IoU threshold used by non-maximum suppression.
         *
         * @param nmsThreshold IoU threshold
         * @return this builder
         */
        public Builder optNmsThreshold(float nmsThreshold) {
            this.nmsThreshold = nmsThreshold;
            return this;
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Reads {@code outputType} (case-insensitive enum name, default {@code AUTO}) and
         * {@code nmsThreshold} (default 0.4) in addition to the parent's options.
         *
         * @param arguments configuration arguments
         * @throws IllegalArgumentException if {@code outputType} is not a {@link YoloOutputType}
         */
        @Override
        protected void configPostProcess(Map<String, ?> arguments) {
            super.configPostProcess(arguments);
            String type = ArgumentsUtil.stringValue(arguments, "outputType", "AUTO");
            this.outputType = YoloOutputType.valueOf(type.toUpperCase(Locale.ENGLISH));
            this.nmsThreshold = ArgumentsUtil.floatValue(arguments, "nmsThreshold", 0.4f);
        }

        /**
         * Builds the translator.
         *
         * <p>If no pipeline was configured, a default transform is added. It moves axis 2 to the
         * front (presumably HWC to CHW), converts to float32 and divides by 255.
         *
         * @return a new translator
         */
        public YoloV5Translator build() {
            if (this.pipeline == null) {
                this.addTransform(array -> array.transpose(2, 0, 1).toType(DataType.FLOAT32, false).div(255));
            }
            this.validate();
            return new YoloV5Translator(this);
        }
    }

    /** How {@link #processOutput} interprets the model output. */
    public static enum YoloOutputType {
        /** Decoded box rows of {@code [x, y, w, h, score, class scores...]}. */
        BOX,
        /** Raw detect-layer output (not supported yet). */
        DETECT,
        /** Chosen from the rank of the first output array. */
        AUTO;
    }
}

