/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.modeling;

import ai.h2o.automl.AutoML;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import ai.h2o.automl.ModelingStepsProvider;
import hex.Model;
import hex.grid.Grid;
import hex.grid.HyperSpaceSearchCriteria;
import hex.leaderboard.Leaderboard;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import water.Job;
import water.Key;

public class CompletionStepsProvider
implements ModelingStepsProvider<CompletionSteps> {
    @Override
    public String getName() {
        return "completion";
    }

    @Override
    public CompletionSteps newInstance(AutoML aml) {
        return new CompletionSteps(aml);
    }

    public static class CompletionSteps
    extends ModelingSteps {
        static final String NAME = "completion";
        private final ModelingStep[] optionals = new ModelingStep[]{new ResumeBestNGridsStep("resume_best_grids", 2, this.aml())};

        public CompletionSteps(AutoML autoML) {
            super(autoML);
        }

        @Override
        public String getProvider() {
            return NAME;
        }

        @Override
        protected ModelingStep[] getOptionals() {
            return this.optionals;
        }

        static class ResumeBestNGridsStep
        extends ModelingStep.DynamicStep<Model> {
            private final int _nGrids;

            public ResumeBestNGridsStep(String id, int nGrids, AutoML autoML) {
                super(CompletionSteps.NAME, id, autoML);
                this._nGrids = nGrids;
            }

            private List<ModelingStep> sortModelingStepByPerf() {
                HashMap scoresBySource = new HashMap();
                Model[] models = this.getTrainedModels();
                double[] metrics = this.aml().leaderboard().getSortMetricValues();
                if (metrics == null) {
                    return Collections.emptyList();
                }
                for (int i = 0; i < models.length; ++i) {
                    ModelingStep source = this.aml().session().getModelingStep(models[i]._key);
                    if (!scoresBySource.containsKey((Object)source)) {
                        scoresBySource.put(source, new ArrayList());
                    }
                    ((List)scoresBySource.get((Object)source)).add(metrics[i]);
                }
                Comparator metricsComparator = Map.Entry.comparingByValue();
                if (!Leaderboard.isLossFunction((String)this.aml().leaderboard().getSortMetric())) {
                    metricsComparator = metricsComparator.reversed();
                }
                return scoresBySource.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> ((List)e.getValue()).stream().mapToDouble(Double::doubleValue).average().orElse(-1.0))).entrySet().stream().sorted(metricsComparator).filter(e -> (Double)e.getValue() >= 0.0).map(Map.Entry::getKey).collect(Collectors.toList());
            }

            @Override
            protected Collection<ModelingStep> prepareModelingSteps() {
                List<ModelingStep> bestStep = this.sortModelingStepByPerf();
                return bestStep.stream().filter(ModelingStep::isResumable).filter(ModelingStep.GridStep.class::isInstance).limit(this._nGrids).map(s -> new ResumingGridStep((ModelingStep.GridStep)((Object)s), this._priorityGroup, this._weight / this._nGrids, this.aml())).collect(Collectors.toList());
            }
        }

        static class ResumingGridStep
        extends ModelingStep.GridStep {
            private transient ModelingStep.GridStep _step;

            public ResumingGridStep(ModelingStep.GridStep step, int priorityGroup, int weight, AutoML aml) {
                super(CompletionSteps.NAME, step.getAlgo(), step.getProvider() + "_" + step.getId(), priorityGroup, weight, aml);
                this._work = this.makeWork();
                this._step = step;
            }

            @Override
            public boolean canRun() {
                return this._step != null && this._weight > 0;
            }

            @Override
            public Model.Parameters prepareModelParameters() {
                return this._step.prepareModelParameters();
            }

            @Override
            public Map<String, Object[]> prepareSearchParameters() {
                return this._step.prepareSearchParameters();
            }

            @Override
            protected void setSearchCriteria(HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria searchCriteria, Model.Parameters baseParms) {
                super.setSearchCriteria(searchCriteria, baseParms);
                searchCriteria.set_stopping_rounds(0);
            }

            @Override
            protected Job<Grid> startJob() {
                Key[] resumedGrid = this.aml().session().getResumableKeys(this._step.getProvider(), this._step.getId());
                if (resumedGrid.length == 0) {
                    return null;
                }
                return this.hyperparameterSearch((Key<Grid>)resumedGrid[0], this.prepareModelParameters(), this.prepareSearchParameters());
            }
        }
    }
}

