package ai.djl.nn.norm;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;

/* loaded from: input_file:ai/djl/nn/norm/Dropout.class */
/**
 * A block that applies dropout to its single input array.
 *
 * <p>The rate is configured through {@link Builder#optRate(float)} and defaults to {@code 0.5f}.
 * On each forward pass the block's training flag is forwarded to the engine's dropout
 * implementation. What the engine does when that flag is {@code false} is engine-specific
 * (conventionally a pass-through; confirm per engine). The output shape always equals the input
 * shape.
 */
public class Dropout extends AbstractBlock {
    /** Serialization version written by this block; version 1 metadata is still accepted. */
    private static final byte VERSION = 2;

    /** Dropout rate handed to the engine; the builder guarantees it lies in [0, 1]. */
    private final float rate;

    /** Builder for {@link Dropout}; obtain one via {@link Dropout#builder()}. */
    public static final class Builder {
        private float rate = 0.5f;

        Builder() {
        }

        /**
         * Sets the dropout rate passed to the engine.
         *
         * @param f the dropout rate, in the closed range [0, 1]
         * @return this builder
         * @throws IllegalArgumentException if {@code f} is NaN or outside [0, 1]
         */
        public Builder optRate(float f) {
            // Negated range check so that NaN (for which every comparison is false) is rejected.
            if (!(f >= 0f && f <= 1f)) {
                throw new IllegalArgumentException("Dropout rate must be in [0, 1], got: " + f);
            }
            this.rate = f;
            return this;
        }

        /**
         * Builds a {@link Dropout} block using the configured rate.
         *
         * @return a new {@link Dropout}
         */
        public Dropout build() {
            return new Dropout(this);
        }
    }

    Dropout(Builder builder) {
        super(VERSION);
        this.rate = builder.rate;
    }

    /**
     * Applies dropout to the single array in {@code inputs} (obtained via {@code
     * singletonOrThrow}), forwarding the training flag to the engine. {@code params} is ignored.
     */
    @Override // ai.djl.nn.AbstractBlock
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        return dropout(inputs.singletonOrThrow(), rate, training);
    }

    /** Returns the first input shape unchanged; dropout does not alter shape. */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[] {inputShapes[0]};
    }

    /**
     * Reads the block's metadata.
     *
     * <p>Version 2 reads the stored input shapes. Version 1 is accepted without reading anything
     * further. Any other version is rejected.
     *
     * @throws MalformedModelException if {@code loadVersion} is neither 1 nor 2
     */
    @Override // ai.djl.nn.AbstractBlock
    public void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedModelException {
        if (loadVersion == VERSION) {
            readInputShapes(is);
        } else if (loadVersion != 1) {
            throw new MalformedModelException(
                    "Unsupported encoding version: " + ((int) loadVersion));
        }
    }

    @Override // ai.djl.nn.AbstractBlock
    public String toString() {
        return "Dropout()";
    }

    /**
     * Applies dropout with rate {@code 0.5f} in training mode.
     *
     * @param input the array to apply dropout to
     * @return the engine's dropout result
     */
    public static NDList dropout(NDArray input) {
        return dropout(input, 0.5f, true);
    }

    /**
     * Applies dropout with the given rate in training mode.
     *
     * @param input the array to apply dropout to
     * @param rate the dropout rate passed to the engine
     * @return the engine's dropout result
     */
    public static NDList dropout(NDArray input, float rate) {
        return dropout(input, rate, true);
    }

    /**
     * Applies dropout through the engine's internal NDArray implementation.
     *
     * @param input the array to apply dropout to
     * @param rate the dropout rate passed to the engine
     * @param training the training flag forwarded to the engine
     * @return the engine's dropout result
     */
    public static NDList dropout(NDArray input, float rate, boolean training) {
        return input.getNDArrayInternal().dropout(input, rate, training);
    }

    /**
     * Creates a builder for {@link Dropout}.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
