/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.indexing.masking;

import java.util.ArrayList;
import java.util.Arrays;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
import org.nd4j.linalg.api.ops.impl.shape.Gather;
import org.nd4j.linalg.api.ops.impl.shape.Squeeze;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Longs;

/**
 * Boolean-mask selection helpers for SameDiff graphs and eager INDArrays.
 *
 * <p>Both overloads use the same scheme:
 * <ol>
 *   <li>The {@code mask.rank()} dimensions of {@code input} starting at {@code axis} are
 *       collapsed into a single dimension.</li>
 *   <li>The mask is flattened.</li>
 *   <li>The positions selected by the mask are found with the {@code Where} op.</li>
 *   <li>The slices at those positions are gathered along {@code axis}.</li>
 * </ol>
 *
 * <p>NOTE(review): this appears to be modelled on TensorFlow's {@code tf.boolean_mask}.
 */
public class Masking {

    /**
     * Builds the graph ops that select the entries of {@code input} marked by {@code mask}.
     *
     * <p>The mask covers dimensions {@code [axis, axis + maskRank)} of {@code input}. Those
     * dimensions are merged into one, the mask is flattened, and the slices at the selected
     * positions are gathered along {@code axis}.
     *
     * <p>The ranks of {@code mask} and {@code input} are evaluated eagerly via
     * {@code rank().eval()}, so both must be evaluable when this method is called.
     *
     * @param ret   the SameDiff instance in which the new ops are created
     * @param input the variable to select from
     * @param mask  the mask; must not be a scalar
     * @param axis  first dimension of {@code input} covered by the mask; requires
     *              {@code 0 <= axis} and {@code axis + maskRank <= inputRank}
     * @return the gathered selection
     * @throws IllegalStateException    if the mask is a scalar
     * @throws IllegalArgumentException if {@code axis} is out of range for the given ranks
     */
    public static SDVariable applyMask(SameDiff ret, SDVariable input, SDVariable mask, int axis) {
        int maskRank = mask.rank().eval().getInt(0);
        int inputRank = input.rank().eval().getInt(0);
        Preconditions.checkState(maskRank > 0, "Mask shape must not be scalar");
        checkAxis(axis, maskRank, inputRank);

        // The shape pieces must come from the input, not the mask.
        SDVariable inputShape = input.shape();
        int maskEnd = axis + maskRank;
        // prod over a rank-1 slice reduces to a scalar. Reshape it to [1] so it can be
        // concatenated with the rank-1 head and tail shape slices.
        SDVariable leadingSize = ret.prod(inputShape.get(SDIndex.interval(axis, maskEnd)), 0).reshape(1);
        SDVariable newShape = ret.concat(0,
                inputShape.get(SDIndex.interval(0, axis)),
                leadingSize,
                inputShape.get(SDIndex.interval(maskEnd, inputRank)));
        SDVariable collapsed = input.reshape(newShape);

        SDVariable flatMask = mask.reshape(-1);
        // Where yields one coordinate row per selected element of the flattened mask.
        // Squeezing axis 1 (as the INDArray overload does) leaves a rank-1 index vector.
        SDVariable indices = ret.squeeze(ret.where(flatMask), 1);
        return ret.gather(collapsed, indices, axis);
    }

    /**
     * Eagerly selects the entries of {@code input} marked by {@code mask}.
     *
     * <p>The mask shape must equal dimensions {@code [axis, axis + mask.rank())} of
     * {@code input}. Those dimensions are merged into one, the mask is flattened, and the slices
     * at the selected positions are gathered along {@code axis}.
     *
     * @param input the array to select from
     * @param mask  the mask; must not be a scalar
     * @param axis  first dimension of {@code input} covered by the mask
     * @return the gathered selection
     * @throws IllegalStateException    if the mask is a scalar
     * @throws IllegalArgumentException if {@code axis} is out of range, or the mask shape does
     *                                  not match the input dimensions it covers
     */
    public static INDArray applyMask(INDArray input, INDArray mask, int axis) {
        long[] maskShape = mask.shape();
        Preconditions.checkState(maskShape.length > 0, "Mask shape must not be scalar");
        long[] inputShape = input.shape();

        INDArray collapsed = input.reshape(collapseMaskedDims(inputShape, maskShape, axis));
        INDArray flatMask = mask.reshape(-1L);
        INDArray coordinates = Nd4j.getExecutioner().exec(new Where(flatMask))[0];
        // Drop the unit coordinate axis so the indices form a rank-1 vector.
        INDArray indices = Nd4j.getExecutioner().exec(new Squeeze(coordinates, 1))[0];
        return Nd4j.getExecutioner().exec(new Gather(collapsed, indices, axis))[0];
    }

    /**
     * Returns {@code inputShape} with dims {@code [axis, axis + maskShape.length)} replaced by
     * their product. For example, input {@code [2, 3, 4]} with mask {@code [3, 4]} at axis 1
     * gives {@code [2, 12]}.
     */
    private static long[] collapseMaskedDims(long[] inputShape, long[] maskShape, int axis) {
        int inputRank = inputShape.length;
        int maskRank = maskShape.length;
        checkAxis(axis, maskRank, inputRank);

        int maskEnd = axis + maskRank;
        long leadingSize = 1L;
        for (int d = axis; d < maskEnd; d++) {
            if (inputShape[d] != maskShape[d - axis]) {
                throw new IllegalArgumentException("Mask shape " + Arrays.toString(maskShape)
                        + " does not match dimensions [" + axis + ", " + maskEnd + ") of input shape "
                        + Arrays.toString(inputShape));
            }
            leadingSize *= inputShape[d];
        }

        long[] collapsed = new long[inputRank - maskRank + 1];
        System.arraycopy(inputShape, 0, collapsed, 0, axis);
        collapsed[axis] = leadingSize;
        System.arraycopy(inputShape, maskEnd, collapsed, axis + 1, inputRank - maskEnd);
        return collapsed;
    }

    /** Rejects an axis at which a mask of {@code maskRank} dims would not fit in the input. */
    private static void checkAxis(int axis, int maskRank, int inputRank) {
        if (axis < 0 || axis + maskRank > inputRank) {
            throw new IllegalArgumentException("Invalid axis " + axis + " for mask of rank " + maskRank
                    + " and input of rank " + inputRank
                    + ": require 0 <= axis and axis + maskRank <= inputRank");
        }
    }
}

