/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.modeling;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.ModelParametersProvider;
import ai.h2o.automl.ModelSelectionStrategies;
import ai.h2o.automl.ModelSelectionStrategy;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import ai.h2o.automl.ModelingStepsProvider;
import ai.h2o.automl.Models;
import ai.h2o.automl.events.EventLogEntry;
import hex.Model;
import hex.tree.SharedTreeModel;
import hex.tree.gbm.GBMModel;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import water.Job;
import water.Key;

/**
 * Provider that exposes the GBM modeling steps and the default GBM parameters.
 *
 * <p>It creates {@link GBMSteps} instances bound to a given {@link AutoML} run and
 * hands out fresh, untouched {@link GBMModel.GBMParameters} objects.
 */
public class GBMStepsProvider
        implements ModelingStepsProvider<GBMStepsProvider.GBMSteps>,
        ModelParametersProvider<GBMModel.GBMParameters> {

    /**
     * @return the provider name, which is {@link GBMSteps#NAME}.
     */
    @Override
    public String getName() {
        return GBMSteps.NAME;
    }

    /**
     * Creates the GBM step collection for the given AutoML instance.
     *
     * @param aml the AutoML instance the steps are bound to.
     * @return a new {@link GBMSteps}.
     */
    @Override
    public GBMSteps newInstance(AutoML aml) {
        return new GBMSteps(aml);
    }

    /**
     * @return a new {@link GBMModel.GBMParameters} instance with no AutoML-specific overrides applied.
     */
    @Override
    public GBMModel.GBMParameters newDefaultParameters() {
        return new GBMModel.GBMParameters();
    }

    /**
     * The GBM modeling steps: five fixed "default" models, one hyper-parameter grid,
     * and one exploitation step that retrains the best GBM with learning rate annealing.
     */
    public static class GBMSteps
            extends ModelingSteps {

        /** Provider name, derived from {@link Algo#GBM}. */
        static final String NAME = Algo.GBM.name();

        /**
         * Fixed-parameter models. Each one starts from {@link GBMModelStep#prepareModelParameters()}
         * and only overrides {@code _max_depth} and {@code _min_rows}.
         */
        private final ModelingStep[] defaults = new GBMModelStep[]{
            new GBMModelStep("def_1", this.aml()) {
                @Override
                public GBMModel.GBMParameters prepareModelParameters() {
                    GBMModel.GBMParameters params = super.prepareModelParameters();
                    params._max_depth = 6;
                    params._min_rows = 1.0;
                    return params;
                }
            },
            new GBMModelStep("def_2", this.aml()) {
                @Override
                public GBMModel.GBMParameters prepareModelParameters() {
                    GBMModel.GBMParameters params = super.prepareModelParameters();
                    params._max_depth = 7;
                    params._min_rows = 10.0;
                    return params;
                }
            },
            new GBMModelStep("def_3", this.aml()) {
                @Override
                public GBMModel.GBMParameters prepareModelParameters() {
                    GBMModel.GBMParameters params = super.prepareModelParameters();
                    params._max_depth = 8;
                    params._min_rows = 10.0;
                    return params;
                }
            },
            new GBMModelStep("def_4", this.aml()) {
                @Override
                public GBMModel.GBMParameters prepareModelParameters() {
                    GBMModel.GBMParameters params = super.prepareModelParameters();
                    params._max_depth = 10;
                    params._min_rows = 10.0;
                    return params;
                }
            },
            new GBMModelStep("def_5", this.aml()) {
                @Override
                public GBMModel.GBMParameters prepareModelParameters() {
                    GBMModel.GBMParameters params = super.prepareModelParameters();
                    params._max_depth = 15;
                    params._min_rows = 100.0;
                    return params;
                }
            }
        };

        /** Hyper-parameter search steps; currently the single {@link DefaultGBMGridStep}. */
        private final ModelingStep[] grids = new GBMGridStep[]{
            new DefaultGBMGridStep("grid_1", this.aml())
        };

        /** Exploitation steps, returned by {@link #getOptionals()}. */
        private final ModelingStep[] exploitation = new ModelingStep[]{
            new GBMExploitationStep("lr_annealing", this.aml()) {

                /**
                 * Result key handed to the last {@code startTraining} call;
                 * {@code null} until training has been started.
                 */
                Key<Models> resultKey = null;

                /**
                 * Retrains the best GBM found so far with learning rate annealing.
                 *
                 * <p>The best model's input parameters are cloned, its max runtime is reset to 0
                 * and {@code _learn_rate_annealing} is set to 0.99 before the time constraints and
                 * stopping criteria are re-applied. {@link #canRun()} guarantees that a GBM exists
                 * when the step is allowed to run.
                 *
                 * @param result         key under which the resulting models are stored.
                 * @param maxRuntimeSecs time budget passed on to {@code initTimeConstraints}.
                 * @return the job wrapping the retraining of the model.
                 */
                @Override
                protected Job<Models> startTraining(Key result, double maxRuntimeSecs) {
                    this.resultKey = result;
                    GBMModel bestGBM = this.getBestGBM();
                    this.aml().eventLog().info(EventLogEntry.Stage.ModelSelection, "Retraining best GBM with learning rate annealing: " + bestGBM._key);
                    GBMModel.GBMParameters params = (GBMModel.GBMParameters) bestGBM._input_parms.clone();
                    // Drop the runtime limit inherited from the cloned parameters before the
                    // time constraints for this step are applied.
                    params._max_runtime_secs = 0.0;
                    params._learn_rate_annealing = 0.99;
                    this.initTimeConstraints(params, maxRuntimeSecs);
                    this.setStoppingCriteria(params, new GBMModel.GBMParameters());
                    return this.asModelsJob(this.startModel(Key.make(result + "_model"), params), (Key<Models>) result);
                }

                /**
                 * @return a strategy that keeps the best single model between the current best GBM
                 *         and the newly trained models, compared on a temporary leaderboard named
                 *         after the result key (or {@code <provider>_<id>} if no key was recorded).
                 */
                @Override
                protected ModelSelectionStrategy getSelectionStrategy() {
                    return (originalModels, newModels) -> new ModelSelectionStrategies.KeepBestN(1, () -> this.makeTmpLeaderboard(Objects.toString(this.resultKey, this._provider + "_" + this._id))).select(new Key[]{this.getBestGBM()._key}, newModels);
                }
            }
        };

        /**
         * Builds the base parameters shared by every GBM step.
         *
         * @return new parameters with {@code _score_tree_interval = 5} and
         *         {@code _histogram_type = AUTO}.
         */
        static GBMModel.GBMParameters prepareModelParameters() {
            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._score_tree_interval = 5;
            params._histogram_type = SharedTreeModel.SharedTreeParameters.HistogramType.AUTO;
            return params;
        }

        /**
         * @param autoML the AutoML instance these steps belong to.
         */
        public GBMSteps(AutoML autoML) {
            super(autoML);
        }

        /**
         * @return {@link #NAME}.
         */
        @Override
        public String getProvider() {
            return NAME;
        }

        /**
         * @return the five fixed-parameter GBM steps ({@code def_1} .. {@code def_5}).
         */
        @Override
        protected ModelingStep[] getDefaultModels() {
            return this.defaults;
        }

        /**
         * @return the GBM grid search steps.
         */
        @Override
        protected ModelingStep[] getGrids() {
            return this.grids;
        }

        /**
         * @return the exploitation steps (learning rate annealing).
         */
        @Override
        protected ModelingStep[] getOptionals() {
            return this.exploitation;
        }

        /**
         * Grid step defining the default GBM hyper-parameter search space.
         */
        static class DefaultGBMGridStep
                extends GBMGridStep {

            /**
             * @param id     step identifier.
             * @param autoML the AutoML instance the step belongs to.
             */
            public DefaultGBMGridStep(String id, AutoML autoML) {
                super(id, autoML);
            }

            /**
             * @return a new map from GBM parameter field name to the candidate values explored
             *         by the grid search.
             */
            @Override
            public Map<String, Object[]> prepareSearchParameters() {
                Map<String, Object[]> searchParams = new HashMap<>();
                searchParams.put("_max_depth", new Integer[]{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
                searchParams.put("_min_rows", new Integer[]{1, 5, 10, 15, 30, 100});
                searchParams.put("_sample_rate", new Double[]{0.5, 0.6, 0.7, 0.8, 0.9, 1.0});
                searchParams.put("_col_sample_rate", new Double[]{0.4, 0.7, 1.0});
                searchParams.put("_col_sample_rate_per_tree", new Double[]{0.4, 0.7, 1.0});
                searchParams.put("_min_split_improvement", new Double[]{1.0E-4, 1.0E-5});
                return searchParams;
            }
        }

        /**
         * Base class for steps that refine an already trained GBM.
         */
        static abstract class GBMExploitationStep
                extends ModelingStep.SelectionStep<GBMModel> {

            /**
             * Finds the first {@link GBMModel} among the trained models.
             *
             * <p>NOTE(review): "best" relies on {@code getTrainedModels()} returning models
             * ordered best-first — confirm against {@code ModelingStep}.
             *
             * @return the first GBM in {@code getTrainedModels()}, or {@code null} if there is none.
             */
            protected GBMModel getBestGBM() {
                for (Model model : this.getTrainedModels()) {
                    if (model instanceof GBMModel) {
                        return (GBMModel) model;
                    }
                }
                return null;
            }

            /**
             * @return {@code true} only if the parent allows the step to run and at least one
             *         GBM has already been trained.
             */
            @Override
            public boolean canRun() {
                return super.canRun() && this.getBestGBM() != null;
            }

            /**
             * Creates the step; when the build spec's {@code exploitation_ratio} is positive,
             * the step ignores the {@code MODEL_COUNT} constraint.
             *
             * @param id     step identifier.
             * @param autoML the AutoML instance the step belongs to.
             */
            public GBMExploitationStep(String id, AutoML autoML) {
                super(NAME, Algo.GBM, id, autoML);
                if (autoML.getBuildSpec().build_models.exploitation_ratio > 0.0) {
                    this._ignoredConstraints = new AutoML.Constraint[]{AutoML.Constraint.MODEL_COUNT};
                }
            }
        }

        /**
         * Base class for GBM grid search steps.
         */
        static abstract class GBMGridStep
                extends ModelingStep.GridStep<GBMModel> {

            /**
             * @param id     step identifier.
             * @param autoML the AutoML instance the step belongs to.
             */
            public GBMGridStep(String id, AutoML autoML) {
                super(NAME, Algo.GBM, id, autoML);
            }

            /**
             * @return the shared GBM base parameters with {@code _ntrees} set to 10000.
             */
            public GBMModel.GBMParameters prepareModelParameters() {
                GBMModel.GBMParameters params = GBMSteps.prepareModelParameters();
                params._ntrees = 10000;
                return params;
            }
        }

        /**
         * Base class for single-model GBM steps.
         */
        static abstract class GBMModelStep
                extends ModelingStep.ModelStep<GBMModel> {

            /**
             * @param id     step identifier.
             * @param autoML the AutoML instance the step belongs to.
             */
            GBMModelStep(String id, AutoML autoML) {
                super(NAME, Algo.GBM, id, autoML);
            }

            /**
             * @return the shared GBM base parameters with {@code _ntrees} set to 10000 and the
             *         row, column and per-tree column sample rates all set to 0.8.
             */
            public GBMModel.GBMParameters prepareModelParameters() {
                GBMModel.GBMParameters params = GBMSteps.prepareModelParameters();
                params._ntrees = 10000;
                params._sample_rate = 0.8;
                params._col_sample_rate = 0.8;
                params._col_sample_rate_per_tree = 0.8;
                return params;
            }
        }
    }
}

