package ai.djl.nn.transformer;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.ParameterStore;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;

/* loaded from: input_file:ai/djl/nn/transformer/PointwiseFeedForwardBlock.class */
/**
 * A feed forward block made of a stack of {@link Linear} layers (all with bias).
 *
 * <p>Each hidden layer is followed by the supplied activation function. The final
 * {@code output_layer} is a {@link Linear} layer with no activation. Children are named
 * {@code linear_<n>}, {@code activation_<n>} and {@code output_layer}.
 */
public class PointwiseFeedForwardBlock extends AbstractBlock {

    /** Block version passed to the {@link AbstractBlock} constructor. */
    private static final byte VERSION = 1;

    /** Output shape computed by {@link #initializeChildBlocks}; {@code null} until then. */
    private Shape outputShape;

    /**
     * Creates a feed forward block.
     *
     * @param hiddenSizes number of units of each hidden layer, in order. May be empty, in which
     *     case the block consists of the output layer only. Elements must not be {@code null}.
     * @param outputSize number of units of the output layer
     * @param activationFunction activation applied after every hidden layer (but not after the
     *     output layer)
     */
    public PointwiseFeedForwardBlock(
            List<Integer> hiddenSizes, int outputSize, Function<NDList, NDList> activationFunction) {
        super(VERSION);
        // Layer index used in child names; deliberately independent of VERSION.
        int index = 0;
        for (int hiddenSize : hiddenSizes) {
            addChildBlock(
                    "linear_" + index, Linear.builder().optBias(true).setUnits(hiddenSize).build());
            addChildBlock("activation_" + index, new LambdaBlock(activationFunction));
            index++;
        }
        // The output layer is added last and intentionally has no activation.
        addChildBlock("output_layer", Linear.builder().optBias(true).setUnits(outputSize).build());
    }

    /**
     * Returns the output shape cached by {@link #initializeChildBlocks}.
     *
     * <p>The {@code inputShapes} argument is not consulted. If the block has not been initialized
     * yet, the returned array contains a single {@code null}.
     *
     * @param manager unused
     * @param inputShapes unused
     * @return a one-element array holding the cached output shape
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[] {outputShape};
    }

    /**
     * Initializes every child block in sequence.
     *
     * <p>Each child's first output shape is fed forward as the next child's input shape. The
     * final shape is cached for {@link #getOutputShapes}.
     *
     * @param manager manager passed through to each child's {@code initialize}
     * @param dataType data type passed through to each child's {@code initialize}
     * @param inputShapes exactly one input shape, of dimension at least 2
     * @throws IllegalArgumentException if there is not exactly one input shape, or if it has
     *     fewer than two dimensions
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        inputNames = Collections.singletonList("input");
        // Compare against a literal: the input count has nothing to do with the block VERSION.
        if (inputShapes.length != 1) {
            throw new IllegalArgumentException(
                    "Pointwise feed forward blocks can only have one input.");
        }
        Shape currentShape = inputShapes[0];
        if (currentShape.dimension() < 2) {
            throw new IllegalArgumentException(
                    "Pointwise feed forward blocks need an input of at least dimension 2.");
        }
        // Chain the children: each child's first output shape is the next child's input shape.
        for (Block child : children.values()) {
            currentShape = child.initialize(manager, dataType, currentShape)[0];
        }
        outputShape = currentShape;
    }

    /**
     * Runs the input through every child block in turn.
     *
     * <p>Children are visited in the order returned by {@code getChildren()}. Each child's output
     * becomes the next child's input, and the final child's output is returned.
     *
     * @param parameterStore parameter store passed to each child's {@code forward}
     * @param inputs the block input
     * @param training whether the forward pass is for training
     * @param params unused
     * @return the output of the last child block
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDList current = inputs;
        for (Pair<String, Block> child : getChildren()) {
            current = child.getValue().forward(parameterStore, current, training);
        }
        return current;
    }
}
