package ai.djl.modality.nlp;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/* loaded from: input_file:ai/djl/modality/nlp/Encoder.class */
/**
 * Abstract base class for an encoder that wraps exactly one child {@link Block}.
 *
 * <p>The forward pass, child initialization, output-shape inference, and parameter
 * saving/loading are all delegated unchanged to the wrapped block; this class adds no
 * computation or serialized data of its own. Subclasses must implement
 * {@link #getStates(NDList)}.
 */
public abstract class Encoder extends AbstractBlock {

    /** The wrapped block doing the actual work; registered as a child block named "Block". */
    protected Block block;

    /**
     * Creates an encoder around the given block.
     *
     * @param version byte forwarded to the {@link AbstractBlock} constructor (presumably the
     *     block version — confirm against {@code AbstractBlock})
     * @param block the block to wrap; it is registered as a child block under the name "Block",
     *     and the reference returned by that registration is what this encoder keeps
     */
    public Encoder(byte version, Block block) {
        super(version);
        Block registered = addChildBlock("Block", block);
        this.block = registered;
    }

    /**
     * Extracts states from the given list; the extraction logic is defined by subclasses.
     *
     * @param encoderOutput the list to extract from (presumably this encoder's forward output,
     *     e.g. for seeding a decoder — confirm against implementations)
     * @return the extracted states
     */
    public abstract NDList getStates(NDList encoderOutput);

    /**
     * Runs the forward pass by handing every argument, unchanged, to the wrapped block.
     *
     * @param parameterStore parameter store forwarded to the wrapped block
     * @param inputs input list forwarded to the wrapped block
     * @param training flag forwarded to the wrapped block (presumably training mode)
     * @param params extra key/value arguments forwarded to the wrapped block
     * @return whatever the wrapped block's {@code forward} returns
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        Block delegate = block;
        return delegate.forward(parameterStore, inputs, training, params);
    }

    /**
     * Initializes the wrapped block with the same manager, data type, and input shapes.
     *
     * @param manager manager forwarded to the wrapped block
     * @param dataType data type forwarded to the wrapped block
     * @param inputShapes input shapes forwarded to the wrapped block
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        Block delegate = block;
        delegate.initialize(manager, dataType, inputShapes);
    }

    /**
     * Reports output shapes by asking the wrapped block.
     *
     * @param manager manager forwarded to the wrapped block
     * @param inputShapes input shapes forwarded to the wrapped block
     * @return the wrapped block's output shapes for the given inputs
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        Block delegate = block;
        return delegate.getOutputShapes(manager, inputShapes);
    }

    /**
     * Saves parameters by delegating entirely to the wrapped block; nothing belonging to this
     * encoder itself is written to the stream.
     *
     * @param os stream the wrapped block writes to
     * @throws IOException if the wrapped block fails to write
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        Block delegate = block;
        delegate.saveParameters(os);
    }

    /**
     * Loads parameters by delegating entirely to the wrapped block; nothing belonging to this
     * encoder itself is read from the stream.
     *
     * @param manager manager forwarded to the wrapped block
     * @param is stream the wrapped block reads from
     * @throws IOException if the wrapped block fails to read
     * @throws MalformedModelException if the wrapped block rejects the stored data
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        Block delegate = block;
        delegate.loadParameters(manager, is);
    }
}
