package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Arrays;

/* loaded from: input_file:ai/djl/nn/transformer/BertNextSentenceBlock.class */
/**
 * Output head that maps a pooled encoder representation to log-probabilities over two classes.
 *
 * <p>The block is a single {@link Linear} layer with {@value #NUM_CLASSES} output units and a
 * bias, followed by a log-softmax along axis {@value #CLASS_AXIS}. Judging by its name it is the
 * BERT next-sentence-prediction head; the meaning of each of the two classes is not established
 * here.
 */
public class BertNextSentenceBlock extends AbstractBlock {

    /** Serialization version of this block; unrelated to any tensor axis. */
    private static final byte VERSION = 1;

    /** Number of output classes produced by the classifier. */
    private static final int NUM_CLASSES = 2;

    /**
     * Axis along which the log-softmax normalizes. This assumes the classifier output is laid out
     * as (batch, NUM_CLASSES). That layout is consistent with {@link #getOutputShapes}; confirm it
     * against the upstream encoder.
     */
    private static final int CLASS_AXIS = 1;

    private final Linear binaryClassifier;

    /** Creates the block and registers its two-unit linear classifier as a child block. */
    public BertNextSentenceBlock() {
        super(VERSION);
        this.binaryClassifier =
                (Linear)
                        addChildBlock(
                                "binaryClassifier",
                                Linear.builder().setUnits(NUM_CLASSES).optBias(true).build());
    }

    /**
     * Declares the single input name and initializes the child classifier with the given shapes.
     *
     * @param manager the manager used to allocate parameters
     * @param dataType the data type of the parameters
     * @param inputShapes the input shapes, forwarded unchanged to the classifier
     */
    @Override // ai.djl.nn.AbstractBlock
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        this.inputNames = Arrays.asList("pooledOutput");
        this.binaryClassifier.initialize(manager, dataType, inputShapes);
    }

    /**
     * Delegates to {@link #forward(ParameterStore, NDList, boolean)}.
     *
     * <p>{@code params} is ignored by this block.
     */
    @Override // ai.djl.nn.AbstractBlock
    protected NDList forwardInternal(
            ParameterStore ps,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        return forward(ps, inputs, training);
    }

    /**
     * Runs the head on a list that must contain exactly one array.
     *
     * @param ps the parameter store
     * @param inputs a list holding a single pooled-output array
     * @param training whether the forward pass is for training
     * @return a list holding the log-probabilities
     */
    @Override // ai.djl.nn.Block
    public NDList forward(ParameterStore ps, NDList inputs, boolean training) {
        // singletonOrThrow rejects lists that do not contain exactly one array.
        NDArray pooledOutput = inputs.singletonOrThrow();
        return new NDList(forward(ps, pooledOutput, training));
    }

    /**
     * Applies the linear classifier and then a log-softmax along {@link #CLASS_AXIS}.
     *
     * @param ps the parameter store
     * @param pooledOutput the pooled encoder output; presumably (batch, hiddenSize), to be confirmed
     * @param training whether the forward pass is for training
     * @return the log-probabilities over the {@value #NUM_CLASSES} classes
     */
    public NDArray forward(ParameterStore ps, NDArray pooledOutput, boolean training) {
        NDArray logits = this.binaryClassifier.forward(ps, new NDList(pooledOutput), training).head();
        // Normalize over the class axis. This used to pass VERSION, which only matched by accident.
        return logits.logSoftmax(CLASS_AXIS);
    }

    /**
     * Reports the output shape as (batch, NUM_CLASSES), where the batch size is taken from the
     * first dimension of the first input shape.
     *
     * @param manager unused
     * @param inputShapes the input shapes; only {@code inputShapes[0]} is read
     * @return a single-element array holding the output shape
     */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[] {new Shape(inputShapes[0].get(0), NUM_CLASSES)};
    }
}
