package ai.djl.ndarray.index.full;

import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexFixed;
import ai.djl.ndarray.index.dim.NDIndexSlice;
import ai.djl.ndarray.types.Shape;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:ai/djl/ndarray/index/full/NDIndexFullSlice.class */
/**
 * A basic (all / fixed / slice) {@link NDIndex} resolved against a concrete {@link Shape}.
 *
 * <p>Instances are created through {@link #fromIndex(NDIndex, Shape)}. Every dimension of the
 * target shape receives an explicit {@code min}, {@code max} and {@code step}; dimensions not
 * mentioned by the index (including those covered by an ellipsis) select their full range.
 */
public final class NDIndexFullSlice {
    private final long[] min;
    private final long[] max;
    private final long[] step;
    private final int[] toSqueeze;
    private final Shape shape;
    private final Shape squeezedShape;

    private NDIndexFullSlice(
            long[] min,
            long[] max,
            long[] step,
            int[] toSqueeze,
            Shape shape,
            Shape squeezedShape) {
        this.min = min;
        this.max = max;
        this.step = step;
        this.toSqueeze = toSqueeze;
        this.shape = shape;
        this.squeezedShape = squeezedShape;
    }

    /**
     * Resolves {@code index} against {@code target} into a full slice.
     *
     * @param index the index to resolve
     * @param target the shape of the array being indexed
     * @return the resolved slice, or {@link Optional#empty()} if {@code index} contains any element
     *     that is not an {@link NDIndexAll}, {@link NDIndexFixed} or {@link NDIndexSlice}
     * @throws IllegalArgumentException if the index has more dimensions than {@code target}, or if
     *     a slice has a step of zero
     */
    public static Optional<NDIndexFullSlice> fromIndex(NDIndex index, Shape target) {
        boolean basicOnly =
                index.stream()
                        .allMatch(
                                element ->
                                        element instanceof NDIndexAll
                                                || element instanceof NDIndexFixed
                                                || element instanceof NDIndexSlice);
        if (!basicOnly) {
            return Optional.empty();
        }
        int rank = index.getRank();
        int dimension = target.dimension();
        if (rank > dimension) {
            throw new IllegalArgumentException(
                    "The index has too many dimensions - "
                            + rank
                            + " dimensions for array with "
                            + dimension
                            + " dimensions");
        }

        long[] min = new long[dimension];
        long[] max = new long[dimension];
        long[] step = new long[dimension];
        List<Integer> toSqueeze = new ArrayList<>(dimension);
        long[] shape = new long[dimension];
        List<Long> squeezedShape = new ArrayList<>(dimension);

        // Layout: [explicit prefix][full-range padding][explicit suffix]. A missing ellipsis
        // (-1) behaves exactly like a trailing one, so all cases share a single loop.
        int ellipsisIndex = index.getEllipsisIndex();
        int prefixLength = ellipsisIndex == -1 ? rank : ellipsisIndex;
        int padding = dimension - rank;
        for (int dim = 0; dim < dimension; dim++) {
            if (dim < prefixLength) {
                addSliceInfo(
                        index.get(dim), dim, target, min, max, step, toSqueeze, shape,
                        squeezedShape);
            } else if (dim < prefixLength + padding) {
                padIndexAll(dim, target, min, max, step, shape, squeezedShape);
            } else {
                addSliceInfo(
                        index.get(dim - padding), dim, target, min, max, step, toSqueeze, shape,
                        squeezedShape);
            }
        }

        int[] squeezeAxes = toSqueeze.stream().mapToInt(Integer::intValue).toArray();
        return Optional.of(
                new NDIndexFullSlice(
                        min, max, step, squeezeAxes, new Shape(shape), new Shape(squeezedShape)));
    }

    /**
     * Fills in the slice bounds of dimension {@code dim} for a single index element.
     *
     * <p>Fixed indices select one element and mark the dimension for squeezing. Slices resolve
     * negative bounds by wrapping them with the dimension size. Their result length is the number
     * of elements reached from {@code min} toward {@code max} in increments of {@code step},
     * rounded up, and 0 for an empty range.
     */
    private static void addSliceInfo(
            NDIndexElement element,
            int dim,
            Shape target,
            long[] min,
            long[] max,
            long[] step,
            List<Integer> toSqueeze,
            long[] shape,
            List<Long> squeezedShape) {
        if (element instanceof NDIndexFixed) {
            long fixed = ((NDIndexFixed) element).getIndex();
            min[dim] = fixed < 0 ? Math.floorMod(fixed, target.get(dim)) : fixed;
            max[dim] = min[dim] + 1;
            step[dim] = 1;
            toSqueeze.add(dim);
            shape[dim] = 1;
        } else if (element instanceof NDIndexSlice) {
            NDIndexSlice slice = (NDIndexSlice) element;
            long rawMin = Optional.ofNullable(slice.getMin()).orElse(0L);
            min[dim] = rawMin < 0 ? Math.floorMod(rawMin, target.get(dim)) : rawMin;
            long rawMax = Optional.ofNullable(slice.getMax()).orElse(target.size(dim));
            max[dim] = rawMax < 0 ? Math.floorMod(rawMax, target.get(dim)) : rawMax;
            long stride = Optional.ofNullable(slice.getStep()).orElse(1L);
            if (stride == 0) {
                throw new IllegalArgumentException(
                        "Slice step must not be zero (dimension " + dim + ")");
            }
            step[dim] = stride;
            shape[dim] = sliceLength(min[dim], max[dim], stride);
            squeezedShape.add(shape[dim]);
        } else if (element instanceof NDIndexAll) {
            padIndexAll(dim, target, min, max, step, shape, squeezedShape);
        }
    }

    /**
     * Returns how many indices {@code start, start + stride, ...} lie strictly before {@code end}
     * (for positive strides) or strictly after it (for negative strides); 0 for an empty range.
     */
    private static long sliceLength(long start, long end, long stride) {
        long extent = end - start;
        if (stride > 0 && extent > 0) {
            return ceilDivPositive(extent, stride);
        }
        if (stride < 0 && extent < 0) {
            return ceilDivPositive(-extent, -stride);
        }
        return 0;
    }

    /** Ceiling division for strictly positive operands, written to avoid overflow. */
    private static long ceilDivPositive(long dividend, long divisor) {
        return dividend / divisor + (dividend % divisor == 0 ? 0 : 1);
    }

    /** Selects the full range {@code [0, size)} with step 1 for dimension {@code dim}. */
    private static void padIndexAll(
            int dim,
            Shape target,
            long[] min,
            long[] max,
            long[] step,
            long[] shape,
            List<Long> squeezedShape) {
        min[dim] = 0;
        max[dim] = target.size(dim);
        step[dim] = 1;
        shape[dim] = target.size(dim);
        squeezedShape.add(target.size(dim));
    }

    /**
     * Returns the inclusive start of each dimension.
     *
     * @return the per-dimension start indices (the internal array, not a copy)
     */
    public long[] getMin() {
        return this.min;
    }

    /**
     * Returns the exclusive end of each dimension.
     *
     * @return the per-dimension end indices (the internal array, not a copy)
     */
    public long[] getMax() {
        return this.max;
    }

    /**
     * Returns the step of each dimension.
     *
     * @return the per-dimension steps (the internal array, not a copy)
     */
    public long[] getStep() {
        return this.step;
    }

    /**
     * Returns the dimensions that were selected by a fixed index and should be squeezed.
     *
     * @return the dimension numbers to squeeze, in increasing order (the internal array, not a
     *     copy)
     */
    public int[] getToSqueeze() {
        return this.toSqueeze;
    }

    /**
     * Returns the shape of the selection before squeezing; fixed dimensions have size 1.
     *
     * @return the unsqueezed result shape
     */
    public Shape getShape() {
        return this.shape;
    }

    /**
     * Returns the shape of the selection with fixed-index dimensions removed.
     *
     * @return the squeezed result shape
     */
    public Shape getSqueezedShape() {
        return this.squeezedShape;
    }
}
