/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class CreateView
extends DynamicCustomOp {
    public static int POINT_TYPE = 0;
    public static int INTERVAL_TYPE = 1;
    public static int ALL_TYPE = 2;
    public static int NEW_AXIS = 3;
    public static int DEFAULT_INCLUSIVE = 1;

    public CreateView() {
    }

    public CreateView(INDArray[] inputs) {
        super(inputs, null);
    }

    public CreateView(SameDiff sameDiff, SDVariable[] args) {
        super(sameDiff, args);
    }

    public CreateView(SameDiff sd, SDVariable input, SDVariable[] indices) {
        this(sd, (SDVariable[])ArrayUtil.combine((Object[][])new SDVariable[][]{{input}, indices}));
    }

    public CreateView(INDArray input, INDArray[] indices) {
        this((INDArray[])ArrayUtil.combine((Object[][])new INDArray[][]{{input}, indices}));
    }

    public static SDVariable createInterval(SameDiff sameDiff, SDVariable intervalInputBegin, SDVariable intervalInputEnd, SDVariable intervalStrideInput, SDVariable inclusive) {
        return CreateView.createInterval(sameDiff, null, intervalInputBegin, intervalInputEnd, intervalStrideInput, inclusive);
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public int getNumOutputs() {
        return 1;
    }

    @Override
    public String opName() {
        return "create_view";
    }

    public static SDVariable createPoint(SameDiff sameDiff, long offset) {
        return CreateView.createPoint(sameDiff, null, offset);
    }

    public static SDVariable createPoint(SameDiff sameDiff, SDVariable offset) {
        return CreateView.createPoint(sameDiff, null, offset);
    }

    public static SDVariable createPoint(SameDiff sameDiff, String name, long offset) {
        INDArray arr = Nd4j.createFromArray(new long[]{POINT_TYPE, 1L, 1L, offset, DEFAULT_INCLUSIVE});
        return sameDiff.var(name, arr);
    }

    public static SDVariable createPoint(SameDiff sameDiff, String name, SDVariable offset) {
        return sameDiff.concat(name, 0, sameDiff.constant(POINT_TYPE).reshape(1).castTo(DataType.INT64), sameDiff.constant(1).reshape(1).castTo(DataType.INT64), sameDiff.constant(1).reshape(1).castTo(DataType.INT64), offset.reshape(1).castTo(DataType.INT64), sameDiff.constant(DEFAULT_INCLUSIVE).reshape(1).castTo(DataType.INT64));
    }

    public static SDVariable createAll(SameDiff sameDiff) {
        return CreateView.createAll(sameDiff, null);
    }

    public static SDVariable createAll(SameDiff sameDiff, String name) {
        INDArray arr = Nd4j.createFromArray(new long[]{ALL_TYPE, 0L, 1L, DEFAULT_INCLUSIVE});
        return sameDiff.var(name, arr);
    }

    public static SDVariable createNewAxis(SameDiff sameDiff, String name) {
        INDArray arr = Nd4j.createFromArray(new long[]{NEW_AXIS, 1L, 10L, DEFAULT_INCLUSIVE});
        return sameDiff.var(name, arr);
    }

    public static SDVariable createNewAxis(SameDiff sameDiff) {
        return CreateView.createNewAxis(sameDiff, null);
    }

    public static SDVariable createInterval(SameDiff sameDiff, String name, long start, long end, long stride, long inclusive) {
        INDArray arr = Nd4j.createFromArray(new long[]{INTERVAL_TYPE, 2L, 1L, start, end, stride, inclusive});
        return sameDiff.var(name, arr);
    }

    public static SDVariable createInterval(SameDiff sameDiff, String name, SDVariable start, SDVariable end, SDVariable stride, SDVariable inclusive) {
        if (stride == null) {
            stride = sameDiff.constant(1).castTo(DataType.INT64).reshape(1);
        }
        if (inclusive == null) {
            inclusive = sameDiff.constant(0).castTo(DataType.INT64).reshape(1);
        }
        return sameDiff.concat(name, 0, sameDiff.constant(INTERVAL_TYPE).reshape(1).castTo(DataType.INT64), sameDiff.constant(2).reshape(1).castTo(DataType.INT64), sameDiff.constant(1).reshape(1).castTo(DataType.INT64), start.reshape(1).castTo(DataType.INT64), end.reshape(1).castTo(DataType.INT64), stride.reshape(1).castTo(DataType.INT64), inclusive.castTo(DataType.INT64).reshape(1));
    }

    public static SDVariable createInterval(SameDiff sameDiff, long start, long end, long stride, long inclusive) {
        return CreateView.createInterval(sameDiff, null, start, end, stride, inclusive);
    }

    public static INDArray createFrom(INDArray input, INDArray ... indices) {
        return input.get(CreateView.indices(indices));
    }

    public static INDArrayIndex[] indices(INDArray ... indexArrs) {
        return Arrays.stream(indexArrs).map(CreateView::fromIndexArr).collect(Collectors.toList()).toArray(new INDArrayIndex[indexArrs.length]);
    }

    public static INDArrayIndex fromIndexArr(INDArray index) {
        int idx = index.getInt(0);
        if (idx == POINT_TYPE) {
            int getPoint = index.getInt(3);
            return NDArrayIndex.point(getPoint);
        }
        if (idx == INTERVAL_TYPE) {
            int start = index.getInt(3);
            int end = index.getInt(4);
            int stride = index.getInt(5);
            boolean inclusive = index.getInt(6) > 0;
            return NDArrayIndex.interval(start, stride, end, inclusive);
        }
        if (idx == NEW_AXIS) {
            return NDArrayIndex.newAxis();
        }
        if (idx == ALL_TYPE) {
            return NDArrayIndex.all();
        }
        throw new IllegalArgumentException("Invalid type. Must be 1 of: " + POINT_TYPE + " (point type) " + INTERVAL_TYPE + " (interval type)" + NEW_AXIS + " (new axis) " + ALL_TYPE + " (all) ");
    }
}

