package ai.djl.nn.core;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;

/* loaded from: input_file:ai/djl/nn/core/Prelu.class */
/**
 * A block that applies a PReLU activation with a single learnable {@code alpha} weight.
 *
 * <p>The computation itself is delegated to the engine via {@link NDArray#getNDArrayInternal()}.
 * NOTE(review): presumably the engine computes {@code x} for positive inputs and {@code alpha * x}
 * otherwise — confirm against the engine's {@code prelu} implementation.
 */
public class Prelu extends AbstractBlock {

    /** Current metadata version; when loading this version the block's input shapes are read. */
    private static final byte VERSION = 2;

    /** Older metadata version still accepted by {@link #loadMetadata}; no input shapes are read. */
    private static final byte LEGACY_VERSION = 1;

    /** The learnable weight named {@code "alpha"}, declared with a scalar (rank-0) shape. */
    private final Parameter alpha;

    /** Creates a PReLU block and registers its scalar {@code alpha} weight parameter. */
    public Prelu() {
        super(VERSION);
        alpha =
                addParameter(
                        Parameter.builder()
                                .setName("alpha")
                                .setType(Parameter.Type.WEIGHT)
                                .optShape(new Shape())
                                .build());
    }

    /**
     * Applies PReLU to the single input array.
     *
     * @param parameterStore store used to fetch the value of {@code alpha}
     * @param inputs the block inputs; expected to hold exactly one array
     * @param training whether this forward pass is part of training (forwarded to the store)
     * @param params extra parameters; not used by this block
     * @return the result of {@link #prelu(NDArray, NDArray)}
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        // Fetch alpha on the same device as the input so both operands live together.
        NDArray alphaValue = parameterStore.getValue(alpha, input.getDevice(), training);
        return prelu(input, alphaValue);
    }

    /**
     * Returns the first input shape unchanged as the single output shape.
     *
     * @param inputShapes the input shapes; only the first element is consulted
     * @return a one-element array containing {@code inputShapes[0]}
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[] {inputShapes[0]};
    }

    /**
     * Loads block metadata written with the given version.
     *
     * <p>For the current version the input shapes are read from the stream. For the legacy version
     * nothing is read. Any other version is rejected.
     *
     * @param loadVersion the version the metadata was written with
     * @param is the stream to read from
     * @throws IOException if reading from the stream fails
     * @throws MalformedModelException if {@code loadVersion} is not supported
     */
    @Override
    public void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedModelException {
        if (loadVersion == version) {
            readInputShapes(is);
        } else if (loadVersion != LEGACY_VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
        }
    }

    /**
     * Applies PReLU to {@code input} with the given {@code alpha}, delegating to the engine.
     *
     * @param input the array to activate
     * @param alpha the PReLU weight array
     * @return the list produced by the engine's {@code prelu} operation
     */
    public static NDList prelu(NDArray input, NDArray alpha) {
        return input.getNDArrayInternal().prelu(input, alpha);
    }
}
