/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.tracker;

import ai.djl.training.tracker.Tracker;
import ai.djl.util.Preconditions;

/**
 * {@code CyclicalTracker} is a {@link Tracker} that cycles the tracked value between a base value
 * and a maximum value.
 *
 * <p>Each cycle is {@code stepSizeUp + stepSizeDown} updates long. Within a cycle the value rises
 * linearly from {@code baseValue} towards {@code maxValue} over the first {@code stepSizeUp}
 * updates, then falls linearly back towards {@code baseValue} over the remaining {@code
 * stepSizeDown} updates. The height above {@code baseValue} is additionally multiplied by a {@link
 * ScaleFunction}. That function is evaluated either at the 1-based cycle number or at the raw
 * update count.
 */
public class CyclicalTracker implements Tracker {

    private final float baseValue;
    private final float maxValue;
    private final int stepSizeUp;
    private final int stepSizeDown;
    private final int totalSize;
    // Fraction of a cycle spent rising; always < 1 because stepSizeDown is at least 1 here.
    private final float stepRatio;
    private final ScaleFunction scaleFunction;
    // true: scaleFunction receives the cycle number; false: it receives numUpdate.
    private final boolean scaleModeCycle;

    /**
     * Creates a new {@code CyclicalTracker} from the given builder.
     *
     * <p>A {@code stepSizeDown} of 0 means "same as {@code stepSizeUp}". If the builder carries a
     * custom {@link ScaleFunction}, it and the builder's {@code scaleModeCycle} flag are used.
     * Otherwise the scale function and mode are derived from the builder's {@link CyclicalMode}.
     *
     * @param builder the builder holding the configuration
     * @throws UnsupportedOperationException if no custom scale function is set and the mode is not
     *     one of the supported {@link CyclicalMode} values
     */
    public CyclicalTracker(Builder builder) {
        this.baseValue = builder.baseValue;
        this.maxValue = builder.maxValue;
        this.stepSizeUp = builder.stepSizeUp;
        this.stepSizeDown = builder.stepSizeDown > 0 ? builder.stepSizeDown : builder.stepSizeUp;
        this.totalSize = this.stepSizeUp + this.stepSizeDown;
        this.stepRatio = (float) this.stepSizeUp / (float) this.totalSize;
        if (builder.scaleFunction != null) {
            this.scaleFunction = builder.scaleFunction;
            this.scaleModeCycle = builder.scaleModeCycle;
        } else {
            switch (builder.mode) {
                case TRIANGULAR:
                    this.scaleFunction = new TriangularScaleFunction();
                    this.scaleModeCycle = true;
                    break;
                case TRIANGULAR2:
                    this.scaleFunction = new Triangular2ScaleFunction();
                    this.scaleModeCycle = true;
                    break;
                case EXP_RANGE:
                    // Exponential decay is driven by the update count, not the cycle number.
                    this.scaleFunction = new ExpRangeScaleFunction(builder.gamma);
                    this.scaleModeCycle = false;
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported Cyclical mode.");
            }
        }
    }

    /**
     * Creates a new {@link Builder} with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Computes the tracked value after {@code numUpdate} updates.
     *
     * @param numUpdate the number of updates performed so far
     * @return {@code baseValue} plus the scaled triangular offset for this update
     */
    @Override
    public float getNewValue(int numUpdate) {
        // Integer arithmetic keeps the cycle index and the position within the cycle exact even
        // for very large update counts, where float division would lose the fractional part.
        int cycle = Math.floorDiv(numUpdate, totalSize) + 1;
        float x = (float) Math.floorMod(numUpdate, totalSize) / (float) totalSize;
        // Rising edge maps x in [0, stepRatio) to [0, 1); falling edge maps [stepRatio, 1) to (1, 0].
        float scaleFactor =
                x < stepRatio ? x / stepRatio : (x - 1.0f) / (stepRatio - 1.0f);
        float baseHeight = (maxValue - baseValue) * scaleFactor;
        float scale = scaleModeCycle ? scaleFunction.func(cycle) : scaleFunction.func(numUpdate);
        return baseValue + baseHeight * scale;
    }

    /** The Builder to construct a {@link CyclicalTracker}. */
    public static final class Builder {

        private float baseValue = 0.001f;
        private float maxValue = 0.006f;
        private int stepSizeUp = 2000;
        // 0 means "use stepSizeUp".
        private int stepSizeDown;
        private CyclicalMode mode = CyclicalMode.TRIANGULAR;
        private ScaleFunction scaleFunction;
        private boolean scaleModeCycle = true;
        private float gamma = 1.0f;

        private Builder() {}

        /**
         * Sets the lower bound of the cycle. Defaults to {@code 0.001f}.
         *
         * @param baseValue the base value; must be positive and not greater than the max value
         * @return this {@code Builder}
         */
        public Builder optBaseValue(float baseValue) {
            this.baseValue = baseValue;
            return this;
        }

        /**
         * Sets the upper bound of the cycle (before scaling). Defaults to {@code 0.006f}.
         *
         * @param maxValue the max value; must be positive
         * @return this {@code Builder}
         */
        public Builder optMaxValue(float maxValue) {
            this.maxValue = maxValue;
            return this;
        }

        /**
         * Sets the number of updates in the rising half of a cycle. Defaults to {@code 2000}.
         *
         * @param stepSizeUp the rising step count; must be at least 1
         * @return this {@code Builder}
         */
        public Builder optStepSizeUp(int stepSizeUp) {
            this.stepSizeUp = stepSizeUp;
            return this;
        }

        /**
         * Sets the number of updates in the falling half of a cycle. Defaults to {@code 0}, which
         * means the same as {@code stepSizeUp}.
         *
         * @param stepSizeDown the falling step count; must not be negative
         * @return this {@code Builder}
         */
        public Builder optStepSizeDown(int stepSizeDown) {
            this.stepSizeDown = stepSizeDown;
            return this;
        }

        /**
         * Sets the built-in cycling mode. Defaults to {@link CyclicalMode#TRIANGULAR}. Ignored
         * when a custom scale function is set.
         *
         * @param mode the mode
         * @return this {@code Builder}
         */
        public Builder optMode(CyclicalMode mode) {
            this.mode = mode;
            return this;
        }

        /**
         * Sets gamma for {@link CyclicalMode#EXP_RANGE}, where the amplitude is scaled by {@code
         * gamma^numUpdate}. Defaults to {@code 1.0f}.
         *
         * @param gamma the decay factor; must be between 0 and 1 inclusive
         * @return this {@code Builder}
         */
        public Builder optGamma(float gamma) {
            this.gamma = gamma;
            return this;
        }

        /**
         * Sets a custom scale function. When set, it takes precedence over the mode and gamma.
         *
         * @param scaleFunction the custom scale function
         * @return this {@code Builder}
         */
        public Builder optScaleFunction(ScaleFunction scaleFunction) {
            this.scaleFunction = scaleFunction;
            return this;
        }

        /**
         * Sets whether a custom scale function receives the cycle number ({@code true}) or the
         * update count ({@code false}). Defaults to {@code true}. Only used together with {@link
         * #optScaleFunction(ScaleFunction)}; built-in modes choose their own setting.
         *
         * @param scaleModeCycle whether to scale per cycle
         * @return this {@code Builder}
         */
        public Builder optScaleModeCycle(boolean scaleModeCycle) {
            this.scaleModeCycle = scaleModeCycle;
            return this;
        }

        /**
         * Validates the configuration and builds a {@link CyclicalTracker}.
         *
         * <p>Each invalid setting is reported through {@link Preconditions#checkArgument}.
         *
         * @return the configured tracker
         */
        public CyclicalTracker build() {
            Preconditions.checkArgument(baseValue > 0.0f, "baseValue has to be positive!");
            Preconditions.checkArgument(maxValue > 0.0f, "maxValue has to be positive!");
            Preconditions.checkArgument(
                    baseValue <= maxValue,
                    "baseValue has to be lower than or equal to maxValue!");
            Preconditions.checkArgument(stepSizeUp >= 1, "stepSizeUp has to be positive!");
            Preconditions.checkArgument(stepSizeDown >= 0, "stepSizeDown cannot be negative!");
            int effectiveStepSizeDown = stepSizeDown > 0 ? stepSizeDown : stepSizeUp;
            // The cycle length is stored as an int; reject configurations that would overflow it.
            Preconditions.checkArgument(
                    (long) stepSizeUp + effectiveStepSizeDown <= Integer.MAX_VALUE,
                    "stepSizeUp + stepSizeDown must not exceed Integer.MAX_VALUE!");
            Preconditions.checkArgument(
                    gamma >= 0.0f && gamma <= 1.0f, "gamma has to be between 0 and 1!");
            Preconditions.checkArgument(
                    scaleFunction != null || mode != null,
                    "mode cannot be null when no scaleFunction is set!");
            return new CyclicalTracker(this);
        }
    }

    /** Computes the multiplier applied to the height of the triangular wave. */
    @FunctionalInterface
    public interface ScaleFunction {

        /**
         * Computes the amplitude multiplier.
         *
         * @param steps the 1-based cycle number when scaling per cycle, otherwise the update count
         * @return the multiplier applied to {@code maxValue - baseValue}
         */
        float func(int steps);
    }

    /** Built-in cycling modes. */
    public enum CyclicalMode {
        /** Constant amplitude in every cycle. */
        TRIANGULAR,
        /** Amplitude multiplied by {@code 1 / 2^(cycle - 1)}, halving each cycle. */
        TRIANGULAR2,
        /** Amplitude multiplied by {@code gamma^numUpdate}. */
        EXP_RANGE
    }

    /** Scale function that always returns 1. */
    private static final class TriangularScaleFunction implements ScaleFunction {

        private TriangularScaleFunction() {}

        @Override
        public float func(int steps) {
            return 1.0f;
        }
    }

    /** Scale function returning {@code 1 / 2^(steps - 1)}. */
    private static final class Triangular2ScaleFunction implements ScaleFunction {

        private Triangular2ScaleFunction() {}

        @Override
        public float func(int steps) {
            return (float) (1.0 / Math.pow(2.0, steps - 1));
        }
    }

    /** Scale function returning {@code gamma^steps}. */
    private static final class ExpRangeScaleFunction implements ScaleFunction {

        private final float gamma;

        ExpRangeScaleFunction(float gamma) {
            this.gamma = gamma;
        }

        @Override
        public float func(int steps) {
            return (float) Math.pow(this.gamma, steps);
        }
    }
}

