package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.util.PairList;

/**
 * Static helpers that run engine operators by name through {@code NDManager.invoke}.
 *
 * <p>Both helpers rely on the engine registering the operators {@code "gather_nd"} and
 * {@code "_npx_one_hot"}. These names look MXNet-specific. NOTE(review): confirm before using
 * this class with another engine.
 */
public final class MissingOps {

    /** Not instantiable: this class only holds static helpers. */
    private MissingOps() {}

    /**
     * Invokes the engine operator {@code "gather_nd"} with inputs {@code (data, indices)}.
     *
     * <p>The operator is dispatched through the manager that owns {@code indices}. No operator
     * parameters are passed, and only the first array of the operator's output is returned.
     * Presumably this follows MXNet {@code gather_nd} semantics, where the leading dimension
     * of {@code indices} holds coordinates into {@code data}. TODO confirm against the engine.
     *
     * @param data the first operator input (the array gathered from)
     * @param indices the second operator input; its manager is used to run the operator
     * @return the first output of the {@code "gather_nd"} operator
     */
    public static NDArray gatherNd(NDArray data, NDArray indices) {
        return indices.getManager().invoke("gather_nd", new NDList(data, indices), null).head();
    }

    /**
     * Invokes the engine operator {@code "_npx_one_hot"} on {@code indices}.
     *
     * <p>The operator is called with these parameters:
     * <ul>
     *   <li>{@code depth} = the given depth
     *   <li>{@code on_value} = {@code 1.0f}
     *   <li>{@code off_value} = {@code 0.0f}
     *   <li>{@code dtype} = {@link DataType#FLOAT32}
     * </ul>
     * It is dispatched through the manager that owns {@code indices}, and only the first
     * output array is returned.
     *
     * @param depth the value passed as the operator's {@code depth} parameter (presumably the
     *     size of the one-hot dimension; verify against the engine docs)
     * @param indices the single operator input
     * @return the first output of the {@code "_npx_one_hot"} operator
     */
    public static NDArray oneHot(int depth, NDArray indices) {
        // The value type must be Object rather than a wildcard: with PairList<String, ?>,
        // add(...) accepts no values at all, because nothing converts to the captured type.
        // PairList<String, Object> is still assignable to invoke's PairList<String, ?> parameter.
        PairList<String, Object> params = new PairList<>();
        params.add("depth", depth);
        params.add("on_value", 1.0f);
        params.add("off_value", 0.0f);
        // Requests FLOAT32 output through the operator's dtype parameter.
        params.add("dtype", DataType.FLOAT32);
        return indices.getManager().invoke("_npx_one_hot", new NDList(indices), params).head();
    }
}
