package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.initializer.XavierInitializer;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Objects;
import java.util.UUID;

/* loaded from: input_file:ai/djl/nn/Parameter.class */
/**
 * A named, typed holder for a single {@link NDArray} used as a block parameter.
 *
 * <p>The array is either supplied directly (through {@link Builder#optArray}, {@link #setArray}
 * or {@link #load}) or created lazily by {@link #initialize} from the configured
 * {@link Initializer} and {@link Shape}. Closing the parameter closes and releases the array.
 */
public class Parameter implements AutoCloseable {
    /** Encoding version written by {@link #save} and required by {@link #load}. */
    private static final byte VERSION = 1;

    private final String id = UUID.randomUUID().toString();
    private final String name;
    private Shape shape;
    private final Type type;
    private Initializer initializer;
    private NDArray array;
    private boolean requiresGrad;

    /** Fluent builder for {@link Parameter}; {@code requiresGrad} defaults to {@code true}. */
    public static final class Builder {
        String name;
        Shape shape;
        Type type;
        Initializer initializer;
        NDArray array;
        boolean requiresGrad = true;

        /**
         * Sets the parameter name.
         *
         * @param str the name
         * @return this builder
         */
        public Builder setName(String str) {
            this.name = str;
            return this;
        }

        /**
         * Sets the parameter type, whose default initializer is used when none is given.
         *
         * @param type the parameter type
         * @return this builder
         */
        public Builder setType(Type type) {
            this.type = type;
            return this;
        }

        /**
         * Sets the shape used by {@link Parameter#initialize}.
         *
         * @param shape the shape
         * @return this builder
         */
        public Builder optShape(Shape shape) {
            this.shape = shape;
            return this;
        }

        /**
         * Sets an explicit initializer, overriding the type's default.
         *
         * @param initializer the initializer
         * @return this builder
         */
        public Builder optInitializer(Initializer initializer) {
            this.initializer = initializer;
            return this;
        }

        /**
         * Supplies a pre-built array, which marks the parameter as already initialized.
         *
         * @param nDArray the array
         * @return this builder
         */
        public Builder optArray(NDArray nDArray) {
            this.array = nDArray;
            return this;
        }

        /**
         * Sets whether the parameter's array should require a gradient.
         *
         * @param z {@code true} to require a gradient
         * @return this builder
         */
        public Builder optRequiresGrad(boolean z) {
            this.requiresGrad = z;
            return this;
        }

        /**
         * Builds the parameter.
         *
         * @return a new {@link Parameter}
         */
        public Parameter build() {
            return new Parameter(this);
        }
    }

    /** Kinds of parameters, each carrying a default {@link Initializer} (null for OTHER). */
    public enum Type {
        WEIGHT(new XavierInitializer(XavierInitializer.RandomType.GAUSSIAN, XavierInitializer.FactorType.IN, 2.0f)),
        BIAS(Initializer.ZEROS),
        GAMMA(Initializer.ONES),
        BETA(Initializer.ZEROS),
        RUNNING_MEAN(Initializer.ZEROS),
        RUNNING_VAR(Initializer.ONES),
        OTHER(null);

        private final transient Initializer initializer;

        Type(Initializer initializer) {
            this.initializer = initializer;
        }

        /**
         * Returns the default initializer for this type.
         *
         * @return the default initializer, or {@code null} for {@link #OTHER}
         */
        public Initializer getInitializer() {
            return this.initializer;
        }
    }

    /**
     * Creates a parameter from the builder's settings.
     *
     * <p>An explicit initializer wins; otherwise the type's default is used. If neither is
     * available the initializer stays {@code null}, and {@link #initialize} will report it.
     */
    Parameter(Builder builder) {
        this.name = builder.name;
        this.shape = builder.shape;
        this.type = builder.type;
        this.array = builder.array;
        this.requiresGrad = builder.requiresGrad;
        if (builder.initializer != null) {
            this.initializer = builder.initializer;
        } else if (this.type != null) {
            this.initializer = this.type.getInitializer();
        } else {
            // No type to fall back on; this used to fail with a bare NPE even when an array was supplied.
            this.initializer = null;
        }
    }

    /**
     * Returns the randomly generated unique id of this parameter.
     *
     * @return the id
     */
    public String getId() {
        return this.id;
    }

    /**
     * Returns the parameter name.
     *
     * @return the name, or an empty string when no name was set
     */
    public String getName() {
        return this.name == null ? "" : this.name;
    }

    /**
     * Returns the parameter type.
     *
     * @return the type (may be {@code null} if the builder did not set one)
     */
    public Type getType() {
        return this.type;
    }

    /**
     * Sets the array directly, adopting its shape and naming it after this parameter.
     *
     * @param nDArray the array to use
     * @throws IllegalStateException if a shape is already present (from the builder,
     *     {@link #setShape}, or a previous {@code setArray})
     */
    public void setArray(NDArray nDArray) {
        if (this.shape != null) {
            throw new IllegalStateException("array has been set! Use either setArray or setShape");
        }
        this.array = nDArray;
        this.shape = nDArray.getShape();
        nDArray.setName(this.name);
    }

    /**
     * Sets the shape used by {@link #initialize}.
     *
     * @param shape the shape
     * @throws IllegalStateException if an array is already present
     */
    public void setShape(Shape shape) {
        if (this.array != null) {
            throw new IllegalStateException("array has been set! Use either setArray or setShape");
        }
        this.shape = shape;
    }

    /**
     * Returns the parameter shape.
     *
     * @return the shape, or {@code null} if not yet known
     */
    public Shape getShape() {
        return this.shape;
    }

    /**
     * Returns the parameter array.
     *
     * @return the array
     * @throws UninitializedParameterException if no array has been set, initialized or loaded
     */
    public NDArray getArray() {
        if (isInitialized()) {
            return this.array;
        }
        throw new UninitializedParameterException("The array for parameter \"" + getName() + "\" has not been initialized");
    }

    /**
     * Returns whether the parameter requires a gradient.
     *
     * @return {@code true} if a gradient is required
     */
    public boolean requiresGradient() {
        return this.requiresGrad;
    }

    /**
     * Freezes or unfreezes the parameter and propagates the new flag to an existing array.
     *
     * @param z {@code true} to freeze (no gradient), {@code false} to unfreeze
     */
    public void freeze(boolean z) {
        this.requiresGrad = !z;
        if (this.array != null) {
            this.array.setRequiresGradient(this.requiresGrad);
        }
    }

    /**
     * Returns whether an array is present.
     *
     * @return {@code true} if the array is non-null
     */
    public boolean isInitialized() {
        return this.array != null;
    }

    /**
     * Replaces the initializer used by {@link #initialize}.
     *
     * @param initializer the initializer
     */
    public void setInitializer(Initializer initializer) {
        this.initializer = initializer;
    }

    /**
     * Returns the configured initializer.
     *
     * @return the initializer, possibly {@code null}
     */
    public Initializer getInitializer() {
        return this.initializer;
    }

    /**
     * Creates the array from the initializer and shape if it is not present yet.
     *
     * <p>In both cases the array is then marked as requiring a gradient when
     * {@link #requiresGradient()} is true.
     *
     * @param nDManager the manager that creates the array
     * @param dataType the data type of the created array
     * @throws NullPointerException if the array is absent and no initializer or shape is set
     */
    public void initialize(NDManager nDManager, DataType dataType) {
        if (!isInitialized()) {
            Objects.requireNonNull(this.initializer, "No initializer has been set");
            Objects.requireNonNull(this.shape, "No parameter shape has been set");
            this.array = this.initializer.initialize(nDManager, this.shape, dataType);
            this.array.setName(this.name);
        }
        if (requiresGradient()) {
            this.array.setRequiresGradient(true);
        }
    }

    /**
     * Writes the parameter to a stream.
     *
     * <p>Format: the char {@code 'N'} alone if uninitialized; otherwise {@code 'P'}, the version
     * byte, the name (via {@code writeUTF}) and the bytes of {@code NDArray.encode()}.
     *
     * @param dataOutputStream the destination stream
     * @throws IOException if writing fails
     */
    public void save(DataOutputStream dataOutputStream) throws IOException {
        if (!isInitialized()) {
            dataOutputStream.writeChar('N');
            return;
        }
        dataOutputStream.writeChar('P');
        dataOutputStream.writeByte(VERSION);
        dataOutputStream.writeUTF(getName());
        dataOutputStream.write(this.array.encode());
    }

    /**
     * Reads a parameter written by {@link #save}, replacing the array and shape.
     *
     * <p>A leading {@code 'N'} leaves the parameter unchanged.
     *
     * @param nDManager the manager used to decode the array
     * @param dataInputStream the source stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the marker, version or name does not match
     */
    public void load(NDManager nDManager, DataInputStream dataInputStream) throws IOException, MalformedModelException {
        char readChar = dataInputStream.readChar();
        if (readChar == 'N') {
            return;
        }
        if (readChar != 'P') {
            throw new MalformedModelException("Invalid input data.");
        }
        byte readByte = dataInputStream.readByte();
        if (readByte != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) readByte));
        }
        String readUTF = dataInputStream.readUTF();
        String expected = getName();
        if (!readUTF.equals(expected)) {
            // Report the same normalized name that was compared (never "null").
            throw new MalformedModelException("Unexpected parameter name: " + readUTF + ", expected: " + expected);
        }
        this.array = nDManager.decode(dataInputStream);
        // Adopt the decoded shape so it need not be set separately.
        this.shape = this.array.getShape();
    }

    /** Closes the array, if any, and clears the reference. */
    @Override // java.lang.AutoCloseable
    public void close() {
        if (this.array != null) {
            this.array.close();
            this.array = null;
        }
    }

    /**
     * Creates a new builder.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
