/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning;

import java.util.HashMap;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.AdaBelief;

/**
 * {@link GradientUpdater} driven by an {@link AdaBelief} configuration.
 *
 * <p>The updater keeps two state arrays, {@code m} and {@code s}, which are exposed through the
 * state map under the keys {@link #M_STATE} and {@link #S_STATE}. The update itself is not
 * computed here: {@link #applyUpdater(INDArray, int, int)} hands the gradient, both state arrays
 * and the hyper-parameters to the op
 * {@code org.nd4j.linalg.api.ops.impl.updaters.AdaBeliefUpdater} via {@code Nd4j.exec}.
 */
public class AdaBeliefUpdater implements GradientUpdater<AdaBelief> {
    /** State-map key under which {@code m} is stored. */
    public static final String M_STATE = "M";
    /** State-map key under which {@code s} is stored. */
    public static final String S_STATE = "S";

    private AdaBelief config;
    private INDArray m;
    private INDArray s;
    /** Order character most recently passed to {@link #setStateViewArray}. */
    private char gradientReshapeOrder;

    /**
     * Creates an updater with the given configuration; the state arrays stay unset until
     * {@link #setState} or {@link #setStateViewArray} is called.
     *
     * @param config configuration that supplies the hyper-parameters
     */
    public AdaBeliefUpdater(AdaBelief config) {
        this.config = config;
    }

    /**
     * Replaces both state arrays with the entries of {@code stateMap}.
     *
     * @param stateMap   map that must hold exactly the keys {@code "M"} and {@code "S"}
     * @param initialize not read by this implementation
     * @throws NullPointerException  if {@code stateMap} is {@code null}
     * @throws IllegalStateException if a key is missing or the map has a size other than 2
     */
    @Override
    public void setState(@NonNull Map<String, INDArray> stateMap, boolean initialize) {
        if (stateMap == null) {
            throw new NullPointerException("stateMap is marked non-null but is null");
        }
        boolean holdsExactlyBothKeys =
                stateMap.containsKey(M_STATE) && stateMap.containsKey(S_STATE) && stateMap.size() == 2;
        if (!holdsExactlyBothKeys) {
            throw new IllegalStateException("State map should contain only keys [M,S] but has keys " + stateMap.keySet());
        }
        this.m = stateMap.get(M_STATE);
        this.s = stateMap.get(S_STATE);
    }

    /**
     * Builds a fresh map holding the current state arrays.
     *
     * @return a new map with {@code m} under {@code "M"} and {@code s} under {@code "S"};
     *         the values are whatever the fields currently hold, including {@code null}
     */
    @Override
    public Map<String, INDArray> getState() {
        Map<String, INDArray> state = new HashMap<>();
        state.put(M_STATE, this.m);
        state.put(S_STATE, this.s);
        return state;
    }

    /**
     * Derives both state arrays from a single row-vector view.
     *
     * <p>Columns {@code [0, length / 2)} of row 0 become {@code m}, columns
     * {@code [length / 2, length)} become {@code s}; each slice is then passed through
     * {@code Shape.newShapeNoCopy} with {@code gradientShape}.
     *
     * @param viewArray     row vector that backs both state arrays
     * @param gradientShape shape handed to {@code Shape.newShapeNoCopy} for both slices
     * @param gradientOrder order character; {@code 'f'} sets the boolean flag passed to
     *                      {@code Shape.newShapeNoCopy}, any other value clears it
     * @param initialize    when {@code true}, {@code viewArray.assign(0)} is applied before slicing
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector
     * @throws IllegalStateException    if either reshape returns {@code null}
     */
    @Override
    public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (initialize) {
            viewArray.assign(0);
        }

        final long total = viewArray.length();
        final long half = total / 2L;
        final boolean fortranOrder = gradientOrder == 'f';

        // Fields are assigned in the same sequence as before: slice m, slice s, reshape m, reshape s.
        this.m = viewArray.get(NDArrayIndex.point(0L), NDArrayIndex.interval(0L, half));
        this.s = viewArray.get(NDArrayIndex.point(0L), NDArrayIndex.interval(half, total));
        this.m = Shape.newShapeNoCopy(this.m, gradientShape, fortranOrder);
        this.s = Shape.newShapeNoCopy(this.s, gradientShape, fortranOrder);

        if (this.m == null || this.s == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view arrays");
        }
        this.gradientReshapeOrder = gradientOrder;
    }

    /**
     * Executes the AdaBelief update op for one step.
     *
     * <p>The gradient is reshaped to the shape of {@code s} and passed, together with {@code s},
     * {@code m} and the hyper-parameters read from the configuration, to the op.
     * NOTE(review): presumably the op writes the resulting update back into the gradient buffer —
     * confirm against the op implementation.
     *
     * @param gradient  gradient array for this step
     * @param iteration iteration counter, used for the learning-rate lookup and passed to the op
     * @param epoch     epoch counter, used for the learning-rate lookup
     * @throws IllegalStateException if either state array has not been set
     */
    @Override
    public void applyUpdater(INDArray gradient, int iteration, int epoch) {
        if (this.m == null || this.s == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }

        double beta1 = this.config.getBeta1();
        double beta2 = this.config.getBeta2();
        double learningRate = this.config.getLearningRate(iteration, epoch);
        double epsilon = this.config.getEpsilon();

        INDArray shapedGradient = gradient.reshape(this.s.shape());
        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaBeliefUpdater(
                shapedGradient, this.s, this.m, learningRate, beta1, beta2, epsilon, iteration));
    }

    /** @return the configuration currently in use */
    @Override
    public AdaBelief getConfig() {
        return this.config;
    }

    /** @return the {@code m} state array, or {@code null} if not set */
    public INDArray getM() {
        return this.m;
    }

    /** @return the {@code s} state array, or {@code null} if not set */
    public INDArray getS() {
        return this.s;
    }

    /** @return the stored gradient order character */
    public char getGradientReshapeOrder() {
        return this.gradientReshapeOrder;
    }

    /** @param config configuration to use from now on */
    public void setConfig(AdaBelief config) {
        this.config = config;
    }

    /** @param m replacement for the {@code m} state array */
    public void setM(INDArray m) {
        this.m = m;
    }

    /** @param s replacement for the {@code s} state array */
    public void setS(INDArray s) {
        this.s = s;
    }

    /** @param gradientReshapeOrder order character to store */
    public void setGradientReshapeOrder(char gradientReshapeOrder) {
        this.gradientReshapeOrder = gradientReshapeOrder;
    }

    /**
     * Compares this updater with another object, property by property (through the getters):
     * gradient reshape order, configuration, {@code m}, then {@code s}.
     *
     * @param o object to compare with
     * @return {@code true} if {@code o} is an {@code AdaBeliefUpdater} that accepts this instance
     *         via {@link #canEqual(Object)} and all four properties match
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AdaBeliefUpdater)) {
            return false;
        }
        AdaBeliefUpdater that = (AdaBeliefUpdater) o;
        return that.canEqual(this)
                && this.getGradientReshapeOrder() == that.getGradientReshapeOrder()
                && sameOrBothNull(this.getConfig(), that.getConfig())
                && sameOrBothNull(this.getM(), that.getM())
                && sameOrBothNull(this.getS(), that.getS());
    }

    /**
     * Hook used by {@link #equals(Object)}: it is invoked on the other object so that a subclass
     * can refuse equality with instances it does not consider comparable.
     *
     * @param other candidate object
     * @return {@code true} if {@code other} is an {@code AdaBeliefUpdater}
     */
    protected boolean canEqual(Object other) {
        return other instanceof AdaBeliefUpdater;
    }

    /**
     * Combines the same four properties used by {@link #equals(Object)} with multiplier 59;
     * a {@code null} property contributes 43.
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int hash = 1;
        hash = hash * prime + this.getGradientReshapeOrder();
        hash = hash * prime + hashOrNullMarker(this.getConfig());
        hash = hash * prime + hashOrNullMarker(this.getM());
        hash = hash * prime + hashOrNullMarker(this.getS());
        return hash;
    }

    /** @return a description of the form {@code AdaBeliefUpdater(config=..., m=..., s=..., gradientReshapeOrder=...)} */
    @Override
    public String toString() {
        return new StringBuilder("AdaBeliefUpdater(config=")
                .append(this.getConfig())
                .append(", m=")
                .append(this.getM())
                .append(", s=")
                .append(this.getS())
                .append(", gradientReshapeOrder=")
                .append(this.getGradientReshapeOrder())
                .append(")")
                .toString();
    }

    /** Null-tolerant equality: two nulls are equal, otherwise {@code left.equals(right)} decides. */
    private static boolean sameOrBothNull(Object left, Object right) {
        if (left == null) {
            return right == null;
        }
        return left.equals(right);
    }

    /** Hash contribution of one property: 43 for {@code null}, otherwise its own hash code. */
    private static int hashOrNullMarker(Object value) {
        return value == null ? 43 : value.hashCode();
    }
}

