/*
 * Decompiled with CFR 0.152.
 */
package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ensemble.Metalearner;
import hex.ensemble.Metalearners;
import hex.ensemble.StackedEnsembleModel;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Objects;
import java.util.stream.Stream;
import jsr166y.CountedCompleter;
import water.DKV;
import water.Job;
import water.Key;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

public class StackedEnsemble
extends ModelBuilder<StackedEnsembleModel, StackedEnsembleModel.StackedEnsembleParameters, StackedEnsembleModel.StackedEnsembleOutput> {
    StackedEnsembleDriver _driver;
    protected StackedEnsembleModel _model;

    /**
     * Creates a builder for the supplied parameters and immediately runs the
     * non-expensive {@link #init(boolean)} pass.
     *
     * @param parms ensemble parameters, handed to the superclass constructor
     */
    public StackedEnsemble(StackedEnsembleModel.StackedEnsembleParameters parms) {
        super(parms);
        // Cheap validation only; the expensive pass is triggered from the driver's computeImpl().
        init(false);
    }

    /**
     * Creates a builder backed by a freshly constructed
     * {@link StackedEnsembleModel.StackedEnsembleParameters} instance.
     * Unlike the parameter-taking constructor, this one does not call {@code init}.
     *
     * @param startup_once forwarded unchanged to the superclass constructor
     */
    public StackedEnsemble(boolean startup_once) {
        super(new StackedEnsembleModel.StackedEnsembleParameters(), startup_once);
    }

    /**
     * Lists the model categories this builder declares it can build.
     *
     * @return a new array holding Regression, Binomial and Multinomial, in that order
     */
    @Override
    public ModelCategory[] can_build() {
        ModelCategory[] supported = {ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
        return supported;
    }

    /**
     * Reports the visibility level of this builder.
     *
     * @return always {@code ModelBuilder.BuilderVisibility.Stable}
     */
    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    /**
     * Declares this builder as supervised.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isSupervised() {
        return true;
    }

    /**
     * Filters the training frame down to the columns that the ensemble or one of
     * its base models refers to.
     *
     * <p>The "used" set is built from, for every base model: its response column,
     * its non-predictor columns, and its original column names (falling back to
     * the output names when no original names are recorded); plus the ensemble's
     * own non-predictor columns. Every other column is handed to
     * {@code FilterCols} with the message "Dropping unused columns: ".
     *
     * <p>A {@code null} base-model array is treated as "no base models", matching
     * the tolerance of {@code expandBaseModels()} and {@code validateBaseModels()}.
     *
     * @param npredictors not used by this implementation
     * @param expensive   forwarded to {@code FilterCols.doIt}
     */
    @Override
    protected void ignoreBadColumns(int npredictors, boolean expensive) {
        final HashSet<String> usedColumns = new HashSet<String>();
        StackedEnsembleModel.StackedEnsembleParameters parms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        Key<Model>[] baseModelKeys = parms._base_models;
        // Guard against a missing base-model list instead of failing with an NPE.
        if (baseModelKeys != null) {
            for (Key<Model> k : baseModelKeys) {
                Model model = (Model)DKV.getGet(k);
                usedColumns.add(((Model.Parameters)model._parms)._response_column);
                usedColumns.addAll(Arrays.asList(((Model.Parameters)model._parms).getNonPredictors()));
                if (((Model.Output)model._output)._origNames != null) {
                    usedColumns.addAll(Arrays.asList(((Model.Output)model._output)._origNames));
                } else {
                    // No original names recorded: fall back to the names stored on the output.
                    usedColumns.addAll(Arrays.asList(((Model.Output)model._output)._names));
                }
            }
        }
        usedColumns.addAll(Arrays.asList(parms.getNonPredictors()));
        new ModelBuilder.FilterCols(0){

            @Override
            protected boolean filter(Vec v, String name) {
                // true for any column that nothing refers to
                return !usedColumns.contains(name);
            }
        }.doIt(this._train, "Dropping unused columns: ", expensive);
    }

    /**
     * Picks the driver for this training run and remembers it in {@code _driver}.
     *
     * @return a CV-stacking driver when the {@code _blending} parameter is
     *         {@code null}, otherwise a blending driver
     */
    @Override
    protected StackedEnsembleDriver trainModelImpl() {
        StackedEnsembleModel.StackedEnsembleParameters parms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        if (parms._blending == null) {
            this._driver = new StackedEnsembleCVStackingDriver();
        } else {
            this._driver = new StackedEnsembleBlendingDriver();
        }
        return this._driver;
    }

    /**
     * Always answers {@code true}.
     * NOTE(review): presumably signals that MOJO export is available for this
     * algorithm - confirm against the ModelBuilder contract.
     */
    @Override
    public boolean haveMojo() {
        return true;
    }

    /**
     * Determines the class count, preferring the distribution configured in the
     * metalearner parameters when such parameters are present.
     *
     * @return the superclass value when no metalearner parameters are set;
     *         otherwise {@code _nclass} for multinomial / ordinal / AUTO,
     *         {@code 2} for bernoulli / quasibinomial / fractionalbinomial,
     *         and {@code 1} for every other distribution
     */
    @Override
    public int nclasses() {
        StackedEnsembleModel.StackedEnsembleParameters parms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        if (parms._metalearner_parameters == null) {
            return super.nclasses();
        }
        DistributionFamily family = parms._metalearner_parameters.getDistributionFamily();
        boolean keepsBuilderClassCount = family == DistributionFamily.multinomial
                || family == DistributionFamily.ordinal
                || family == DistributionFamily.AUTO;
        if (keepsBuilderClassCount) {
            return this._nclass;
        }
        boolean twoClass = family == DistributionFamily.bernoulli
                || family == DistributionFamily.quasibinomial
                || family == DistributionFamily.fractionalbinomial;
        return twoClass ? 2 : 1;
    }

    /**
     * Validates the ensemble configuration.
     *
     * <p>Steps, in order: expand any grids among the base models, run the
     * superclass initialization, reject an explicitly set {@code distribution},
     * verify that the fold / weights / offset columns (when specified) exist in
     * the training, validation and blending frames, and finally validate the
     * base models against each other.
     *
     * @param expensive forwarded to the superclass {@code init}
     * @throws H2OIllegalArgumentException if {@code _distribution} is anything other than AUTO
     */
    @Override
    public void init(boolean expensive) {
        expandBaseModels();
        super.init(expensive);
        StackedEnsembleModel.StackedEnsembleParameters parms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        boolean distributionSetOnEnsemble = parms._distribution != DistributionFamily.AUTO;
        if (distributionSetOnEnsemble) {
            throw new H2OIllegalArgumentException("Setting \"distribution\" to StackedEnsemble is unsupported. Please set it in \"metalearner_parameters\".");
        }
        checkColumnPresent("fold", parms._metalearner_fold_column, train(), valid(), parms.blending());
        checkColumnPresent("weights", parms._weights_column, train(), valid(), parms.blending());
        checkColumnPresent("offset", parms._offset_column, train(), valid(), parms.blending());
        validateBaseModels();
    }

    /**
     * Rewrites {@code _base_models} so that it contains model keys only: keys
     * resolving to a {@code Model} are kept, keys resolving to a {@code Grid}
     * are replaced by the grid's model keys. Does nothing when the array is null.
     *
     * @throws IllegalArgumentException if a key resolves to nothing, or to an
     *         object that is neither a model nor a grid
     */
    private void expandBaseModels() {
        StackedEnsembleModel.StackedEnsembleParameters parms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        if (parms._base_models == null) {
            return;
        }
        ArrayList<Key<Model>> expanded = new ArrayList<Key<Model>>();
        for (Key<Model> candidateKey : parms._base_models) {
            Object resolved = DKV.getGet(candidateKey);
            if (resolved == null) {
                throw new IllegalArgumentException(String.format("Specified id \"%s\" does not exist.", candidateKey));
            }
            if (resolved instanceof Model) {
                expanded.add(candidateKey);
            } else if (resolved instanceof Grid) {
                // A grid stands in for all of its models.
                Grid grid = (Grid)resolved;
                Collections.addAll(expanded, grid.getModelKeys());
            } else {
                throw new IllegalArgumentException(String.format("Unsupported type \"%s\" as a base model.", resolved.getClass().toString()));
            }
        }
        parms._base_models = expanded.toArray(new Key[0]);
    }

    /**
     * Cross-checks the offset and weights columns of the base models.
     *
     * <p>When the ensemble has no offset column, it adopts the one of the first
     * base model; afterwards every base model must use the same offset column as
     * the ensemble. If all base models share one non-null weights column while
     * the ensemble has none, a warning is registered. Does nothing when
     * {@code _base_models} is null.
     *
     * @throws IllegalArgumentException if a base model's offset column differs
     */
    private void validateBaseModels() {
        StackedEnsembleModel.StackedEnsembleParameters parms = (StackedEnsembleModel.StackedEnsembleParameters)this._parms;
        if (parms._base_models == null) {
            return;
        }
        boolean allShareWeightsColumn = true;
        String sharedWeightsColumn = null;
        for (int idx = 0; idx < parms._base_models.length; ++idx) {
            Model current = (Model)DKV.getGet(parms._base_models[idx]);
            String currentWeights = ((Model.Parameters)current._parms)._weights_column;
            String currentOffset = ((Model.Parameters)current._parms)._offset_column;
            if (idx == 0) {
                // The first base model is the reference for both columns.
                if (parms._offset_column == null) {
                    parms._offset_column = currentOffset;
                }
                sharedWeightsColumn = currentWeights;
                allShareWeightsColumn = sharedWeightsColumn != null;
            }
            if (!Objects.equals(sharedWeightsColumn, currentWeights)) {
                allShareWeightsColumn = false;
            }
            if (!Objects.equals(parms._offset_column, currentOffset)) {
                throw new IllegalArgumentException("All base models must have the same offset_column!");
            }
        }
        boolean ensembleHasNoWeights = parms._weights_column == null;
        if (ensembleHasNoWeights && allShareWeightsColumn && parms._base_models.length > 0) {
            this.warn("_weights_column", "All base models use weights_column=\"" + sharedWeightsColumn + "\" but Stacked Ensemble does not. If you want to use the same weights_column for the meta learner, please specify it as an argument in the h2o.stackedEnsemble call.");
        }
    }

    /**
     * Verifies that a named column exists in every non-null frame supplied.
     *
     * @param columnName human-readable role of the column, used in the error message
     * @param columnId   name of the column to look up; {@code null} disables the check
     * @param frames     frames to inspect; {@code null} entries are skipped
     * @throws IllegalArgumentException if a non-null frame has no vec for {@code columnId}
     */
    private static void checkColumnPresent(String columnName, String columnId, Frame ... frames) {
        if (columnId == null) {
            return;
        }
        for (Frame candidate : frames) {
            boolean columnMissing = candidate != null && candidate.vec(columnId) == null;
            if (columnMissing) {
                throw new IllegalArgumentException(String.format("Specified %s column '%s' not found in one of the supplied data frames. Available column names are: %s", columnName, columnId, Arrays.toString(candidate.names())));
            }
        }
    }

    /**
     * Appends one base model's predictions to the level-one frame.
     *
     * <ul>
     *   <li>Binomial classifier: the prediction vec at index 2 is added under the model key's name.</li>
     *   <li>Multinomial classifier: every prediction column except "predict" is added,
     *       renamed to {@code <model key>/<column name>}.</li>
     *   <li>Anything else: the "predict" vec is added under the model key's name.</li>
     * </ul>
     *
     * @param aModel             base model that produced the predictions
     * @param aModelsPredictions prediction frame of that model
     * @param levelOneFrame      frame being assembled; modified in place
     * @throws H2OIllegalArgumentException for autoencoders and for unsupervised models
     */
    static void addModelPredictionsToLevelOneFrame(Model aModel, Frame aModelsPredictions, Frame levelOneFrame) {
        Model.Output output = (Model.Output)aModel._output;
        if (output.isBinomialClassifier()) {
            Vec positiveClass = aModelsPredictions.vec(2);
            levelOneFrame.add(aModel._key.toString(), positiveClass);
            return;
        }
        if (output.isMultinomialClassifier()) {
            Frame probabilities = aModelsPredictions.subframe(ArrayUtils.remove(aModelsPredictions.names(), "predict"));
            String[] plainNames = probabilities.names();
            String[] prefixedNames = new String[plainNames.length];
            for (int col = 0; col < plainNames.length; ++col) {
                prefixedNames[col] = aModel._key.toString().concat("/").concat(plainNames[col]);
            }
            probabilities.setNames(prefixedNames);
            levelOneFrame.add(probabilities);
            return;
        }
        if (output.isAutoencoder()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack autoencoders: " + aModel._key);
        }
        if (!output.isSupervised()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack unsupervised models: " + aModel._key);
        }
        Vec prediction = aModelsPredictions.vec("predict");
        levelOneFrame.add(aModel._key.toString(), prediction);
    }

    /**
     * Copies the non-predictor columns from {@code fr} into the level-one frame:
     * the metalearner fold column (only when {@code training} is true), the
     * weights column and the offset column, each only if configured, followed
     * unconditionally by the response column.
     *
     * @param parms         ensemble parameters naming the columns
     * @param fr            frame the vecs are looked up in
     * @param levelOneFrame frame being assembled; modified in place
     * @param training      whether the fold column should be included
     */
    static void addNonPredictorsToLevelOneFrame(StackedEnsembleModel.StackedEnsembleParameters parms, Frame fr, Frame levelOneFrame, boolean training) {
        String[] optionalColumns = {
            training ? parms._metalearner_fold_column : null,
            parms._weights_column,
            parms._offset_column
        };
        for (String column : optionalColumns) {
            if (column != null) {
                levelOneFrame.add(column, fr.vec(column));
            }
        }
        levelOneFrame.add(parms._response_column, fr.vec(parms._response_column));
    }

    /**
     * Driver used when a blending frame is configured: the level-one training
     * frame is built from base-model predictions on the blending frame.
     */
    private class StackedEnsembleBlendingDriver
    extends StackedEnsembleDriver {
        private StackedEnsembleBlendingDriver() {
        }

        /** @return {@code StackingStrategy.blending} */
        @Override
        protected StackedEnsembleModel.StackingStrategy strategy() {
            return StackedEnsembleModel.StackingStrategy.blending;
        }

        /** @return the blending frame taken from the model's parameters */
        @Override
        protected Frame getActualTrainingFrame() {
            StackedEnsembleModel.StackedEnsembleParameters modelParms = (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms;
            return modelParms.blending();
        }

        /**
         * Scores the base model on the given frame (or reuses stored predictions).
         *
         * @throws Job.JobCancelledException if a stop was requested while the training frame is being built
         */
        @Override
        protected Frame getPredictionsForBaseModel(Model model, Frame actualsFrame, boolean isTrainingFrame) {
            boolean stopRequested = StackedEnsemble.this.stop_requested();
            if (stopRequested && isTrainingFrame) {
                throw new Job.JobCancelledException();
            }
            return buildPredictionsForBaseModel(model, actualsFrame);
        }
    }

    /**
     * Driver used when no blending frame is configured: the level-one training
     * frame is built from the base models' stored cross-validation holdout
     * predictions.
     */
    private class StackedEnsembleCVStackingDriver
    extends StackedEnsembleDriver {
        private StackedEnsembleCVStackingDriver() {
        }

        /** @return {@code StackingStrategy.cross_validation} */
        @Override
        protected StackedEnsembleModel.StackingStrategy strategy() {
            return StackedEnsembleModel.StackingStrategy.cross_validation;
        }

        /** @return the training frame taken from the model's parameters */
        @Override
        protected Frame getActualTrainingFrame() {
            StackedEnsembleModel.StackedEnsembleParameters modelParms = (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms;
            return modelParms.train();
        }

        /**
         * For training, looks up the base model's cross-validation holdout
         * predictions frame; otherwise scores the model on {@code actualsFrame}.
         *
         * @throws H2OIllegalArgumentException if the holdout predictions frame id
         *         is not recorded on the model or the frame cannot be retrieved
         */
        @Override
        protected Frame getPredictionsForBaseModel(Model model, Frame actualsFrame, boolean isTraining) {
            if (!isTraining) {
                return buildPredictionsForBaseModel(model, actualsFrame);
            }
            Model.Output baseOutput = (Model.Output)model._output;
            if (baseOutput._cross_validation_holdout_predictions_frame_id == null) {
                throw new H2OIllegalArgumentException("Failed to find the xval predictions frame id. . .  Looks like keep_cross_validation_predictions wasn't set when building the models.");
            }
            Frame holdoutPredictions = (Frame)DKV.getGet(baseOutput._cross_validation_holdout_predictions_frame_id);
            if (holdoutPredictions == null) {
                throw new H2OIllegalArgumentException("Failed to find the xval predictions frame. . .  Looks like keep_cross_validation_predictions wasn't set when building the models, or the frame was deleted.");
            }
            return holdoutPredictions;
        }
    }

    private abstract class StackedEnsembleDriver
    extends ModelBuilder.Driver {
        /** Binds the driver to the enclosing {@code StackedEnsemble} builder via the superclass constructor. */
        private StackedEnsembleDriver() {
            super(StackedEnsemble.this);
        }

        /**
         * Assembles a level-one frame from base-model predictions and stores it in the DKV.
         *
         * <p>Each base model contributes its prediction column(s); entries whose model
         * or prediction frame is {@code null} are skipped with a warning. When a
         * metalearner transform other than {@code NONE} is configured, it is applied
         * to the collected prediction columns before the non-predictor columns
         * (fold, weights, offset, response) from {@code actuals} are appended.
         * A frame already stored under the target key is emptied first.
         *
         * @param levelOneKey          key name for the resulting frame; when {@code null} a
         *                             name is derived from the model key and the transform
         * @param baseModels           base models, parallel to {@code baseModelPredictions}
         * @param baseModelPredictions prediction frame of each base model
         * @param actuals              frame supplying the non-predictor columns
         * @return the level-one frame, already put into the DKV
         * @throws H2OIllegalArgumentException if either array is null or empty, the arrays
         *         differ in length, or a transform is requested for a non-classification model
         */
        private Frame prepareLevelOneFrame(String levelOneKey, Model[] baseModels, Frame[] baseModelPredictions, Frame actuals) {
            if (null == baseModels) {
                throw new H2OIllegalArgumentException("Base models array is null.");
            }
            if (null == baseModelPredictions) {
                throw new H2OIllegalArgumentException("Base model predictions array is null.");
            }
            if (baseModels.length == 0) {
                throw new H2OIllegalArgumentException("Base models array is empty.");
            }
            if (baseModelPredictions.length == 0) {
                throw new H2OIllegalArgumentException("Base model predictions array is empty.");
            }
            if (baseModels.length != baseModelPredictions.length) {
                throw new H2OIllegalArgumentException("Base models and prediction arrays are different lengths.");
            }

            StackedEnsembleModel.StackedEnsembleParameters parms = (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms;
            StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform requested = parms._metalearner_transform;
            // A null transform and NONE both mean "no transform".
            StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform transform = null;
            if (requested != null && requested != StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform.NONE) {
                StackedEnsembleModel.StackedEnsembleOutput output = (StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output;
                if (!output.isBinomialClassifier() && !output.isMultinomialClassifier()) {
                    throw new H2OIllegalArgumentException("Metalearner transform is supported only for classification!");
                }
                transform = requested;
            }

            if (null == levelOneKey) {
                // The transform may legitimately be null here; name the frame as for NONE
                // rather than dereferencing null.
                StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform suffix =
                        requested != null ? requested : StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform.NONE;
                levelOneKey = "levelone_" + StackedEnsemble.this._model._key.toString() + "_" + suffix.toString();
            }

            // Check the type before casting so a non-Frame value under this key is left alone.
            Object old = DKV.getGet(levelOneKey);
            if (old instanceof Frame) {
                Frame oldFrame = (Frame)old;
                oldFrame.write_lock(StackedEnsemble.this._job);
                oldFrame.removeAll();
                oldFrame.update(StackedEnsemble.this._job);
                oldFrame.unlock(StackedEnsemble.this._job);
            }

            // With a transform, predictions are collected in an unkeyed frame; the transform
            // call below receives the final key.
            Frame levelOneFrame = transform == null ? new Frame(Key.make(levelOneKey)) : new Frame(new Vec[0]);
            for (int i = 0; i < baseModels.length; ++i) {
                Model baseModel = baseModels[i];
                Frame baseModelPreds = baseModelPredictions[i];
                if (null == baseModel) {
                    Log.warn("Failed to find base model; skipping: " + baseModels[i]);
                    continue;
                }
                if (null == baseModelPreds) {
                    // baseModelPreds is null in this branch, so report the model key instead.
                    Log.warn("Failed to find base model " + baseModel._key + " predictions; skipping.");
                    continue;
                }
                StackedEnsemble.addModelPredictionsToLevelOneFrame(baseModel, baseModelPreds, levelOneFrame);
                Scope.untrack(baseModelPredictions);
            }
            if (transform != null) {
                levelOneFrame = transform.transform(StackedEnsemble.this._model, levelOneFrame, Key.make(levelOneKey));
            }
            StackedEnsemble.addNonPredictorsToLevelOneFrame((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._model._parms, actuals, levelOneFrame, true);
            Log.info("Finished creating \"level one\" frame for stacking: " + levelOneFrame.toString());
            DKV.put(levelOneFrame);
            return levelOneFrame;
        }

        /**
         * Resolves the base-model keys, gathers each model's predictions and
         * delegates to the array-based overload to assemble the level-one frame.
         *
         * <p>Once the model output holds a metalearner, base models for which
         * {@code isUsefulBaseModel} answers false are left out. When building the
         * training frame with {@code _keep_levelone_frame} enabled, the result is
         * replaced by a deep copy whose keys are untracked from the current scope.
         *
         * @param levelOneKey   key name for the level-one frame
         * @param baseModelKeys keys of the base models
         * @param actuals       frame passed on for predictions and non-predictor columns
         * @param isTraining    whether the training level-one frame is being built
         * @return the level-one frame (or its deep copy)
         * @throws H2OIllegalArgumentException if a base-model key cannot be resolved
         */
        private Frame prepareLevelOneFrame(String levelOneKey, Key<Model>[] baseModelKeys, Frame actuals, boolean isTraining) {
            ArrayList<Model> models = new ArrayList<Model>();
            ArrayList<Frame> predictionFrames = new ArrayList<Frame>();
            for (Key<Model> modelKey : baseModelKeys) {
                boolean hasMetalearner = ((StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output)._metalearner != null;
                if (hasMetalearner && !StackedEnsemble.this._model.isUsefulBaseModel(modelKey)) {
                    continue;
                }
                Model resolved = (Model)DKV.getGet(modelKey);
                if (resolved == null) {
                    throw new H2OIllegalArgumentException("Failed to find base model: " + modelKey);
                }
                Frame resolvedPredictions = getPredictionsForBaseModel(resolved, actuals, isTraining);
                models.add(resolved);
                predictionFrames.add(resolvedPredictions);
            }
            boolean keepCopy = isTraining && ((StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms)._keep_levelone_frame;
            Frame assembled = prepareLevelOneFrame(levelOneKey, models.toArray(new Model[0]), predictionFrames.toArray(new Frame[0]), actuals);
            if (!keepCopy) {
                return assembled;
            }
            Frame kept = assembled.deepCopy(assembled._key.toString());
            kept.write_lock(StackedEnsemble.this._job);
            kept.update(StackedEnsemble.this._job);
            kept.unlock(StackedEnsemble.this._job);
            Scope.untrack(kept.keysList());
            return kept;
        }

        /**
         * Deletes the ensemble model, if one has been created, before
         * delegating to the superclass handler.
         *
         * @return whatever the superclass handler returns
         */
        @Override
        public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
            StackedEnsembleModel createdModel = StackedEnsemble.this._model;
            if (createdModel != null) {
                createdModel.delete();
            }
            return super.onExceptionalCompletion(ex, caller);
        }

        /**
         * Returns the predictions of {@code model} on {@code frame}, reusing a
         * frame already stored under the derived predictions key and scoring
         * only when none is found. The predictions key is recorded in the
         * ensemble output's {@code _base_model_predictions_keys} if not already there.
         *
         * @param model base model to score
         * @param frame frame to score on
         * @return the prediction frame
         */
        protected Frame buildPredictionsForBaseModel(Model model, Frame frame) {
            Key<Frame> predsKey = buildPredsKey(model, frame);
            Frame predictions = (Frame)DKV.getGet(predsKey);
            if (predictions == null) {
                predictions = model.score(frame, predsKey.toString(), null, false);
                Scope.untrack(predictions.keysList());
            }
            StackedEnsembleModel.StackedEnsembleOutput output = (StackedEnsembleModel.StackedEnsembleOutput)StackedEnsemble.this._model._output;
            if (output._base_model_predictions_keys == null) {
                output._base_model_predictions_keys = new Key[0];
            }
            boolean alreadyRecorded = ArrayUtils.contains(output._base_model_predictions_keys, predsKey);
            if (!alreadyRecorded) {
                output._base_model_predictions_keys = ArrayUtils.append(output._base_model_predictions_keys, predsKey);
            }
            return predictions;
        }

        /** Returns the stacking strategy that {@code computeImpl()} records on the model output. */
        protected abstract StackedEnsembleModel.StackingStrategy strategy();

        /** Returns the frame that {@code computeImpl()} passes as the actuals when building the level-one training frame. */
        protected abstract Frame getActualTrainingFrame();

        /**
         * Supplies the predictions of one base model for inclusion in a level-one frame.
         *
         * @param var1 the base model
         * @param var2 the actuals frame the level-one frame is being built for
         * @param var3 {@code true} when the level-one training frame is being built,
         *             {@code false} when the validation frame is being built
         * @return the prediction frame of the base model
         */
        protected abstract Frame getPredictionsForBaseModel(Model var1, Frame var2, boolean var3);

        /**
         * Builds the predictions key {@code preds_<model checksum>_on_<frame checksum>}.
         * Only the two checksums contribute to the name; the key arguments are not used.
         */
        private Key<Frame> buildPredsKey(Key model_key, long model_checksum, Key frame_key, long frame_checksum) {
            StringBuilder keyName = new StringBuilder("preds_");
            keyName.append(model_checksum).append("_on_").append(frame_checksum);
            return Key.make(keyName.toString());
        }

        /**
         * Derives the predictions key for a model/frame pair from their checksums.
         *
         * @return the key, or {@code null} when either argument is {@code null}
         */
        protected Key<Frame> buildPredsKey(Model model, Frame frame) {
            if (frame == null || model == null) {
                return null;
            }
            return buildPredsKey(model._key, model.checksum(), frame._key, frame.checksum());
        }

        /**
         * Trains the stacked ensemble.
         *
         * <p>Sequence: run the expensive {@code init} pass and abort on validation
         * errors; create the ensemble model and record the stacking strategy;
         * under the job's lock, inherit properties from the base models and
         * publish the model; build the level-one training frame (and the
         * validation one when a validation frame is configured); resolve the
         * metalearner algorithm; create, initialize and run the metalearner.
         * Finally, if auto-parameter evaluation is enabled and the algorithm was
         * AUTO, the resolved algorithm is written back to the model parameters.
         *
         * @throws H2OModelBuilderIllegalArgumentException if {@code init} registered errors
         * @throws H2OIllegalArgumentException if the metalearner algorithm cannot be resolved
         */
        @Override
        public void computeImpl() {
            StackedEnsemble.this.init(true);
            if (StackedEnsemble.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(StackedEnsemble.this);
            }
            StackedEnsembleModel.StackedEnsembleParameters builderParms = (StackedEnsembleModel.StackedEnsembleParameters)StackedEnsemble.this._parms;
            StackedEnsemble.this._model = new StackedEnsembleModel(StackedEnsemble.this.dest(), builderParms, new StackedEnsembleModel.StackedEnsembleOutput(StackedEnsemble.this));
            final StackedEnsembleModel ensemble = StackedEnsemble.this._model;
            ((StackedEnsembleModel.StackedEnsembleOutput)ensemble._output)._stacking_strategy = strategy();
            try {
                ensemble.delete_and_lock(StackedEnsemble.this._job);
                ensemble.checkAndInheritModelProperties();
                ensemble.update(StackedEnsemble.this._job);
            }
            finally {
                ensemble.unlock(StackedEnsemble.this._job);
            }

            // Level-one frames: training always, validation only when a validation frame exists.
            String ensembleId = ensemble._key.toString();
            StackedEnsembleModel.StackedEnsembleParameters modelParms = (StackedEnsembleModel.StackedEnsembleParameters)ensemble._parms;
            Frame levelOneTrainingFrame = prepareLevelOneFrame("levelone_training_" + ensembleId, modelParms._base_models, getActualTrainingFrame(), true);
            Frame levelOneValidationFrame = null;
            if (modelParms.valid() != null) {
                levelOneValidationFrame = prepareLevelOneFrame("levelone_validation_" + ensembleId, modelParms._base_models, modelParms.valid(), false);
            }

            // Resolve the requested algorithm to the implementation actually used.
            Metalearner.Algorithm requestedAlgo = modelParms._metalearner_algorithm;
            Metalearner.Algorithm resolvedAlgo = Metalearners.getActualMetalearnerAlgo(requestedAlgo);
            if (resolvedAlgo == null) {
                throw new H2OIllegalArgumentException("Invalid `metalearner_algorithm`. Passed in " + requestedAlgo + " but must be one of " + Arrays.toString(Metalearner.Algorithm.values()));
            }

            Key<Model> metalearnerKey = Key.make("metalearner_" + requestedAlgo + "_" + ensemble._key);
            Job metalearnerJob = new Job(metalearnerKey, ModelBuilder.javaName(resolvedAlgo.toString()), "StackingEnsemble metalearner (" + requestedAlgo + ")");
            boolean hasMetaLearnerParams = modelParms._metalearner_parameters != null;
            long metalearnerSeed = modelParms._seed;
            Metalearner metalearner = Metalearners.createInstance(requestedAlgo.name());
            // Runtime budget: 0 when max_runtime_secs is unset, otherwise the remaining time but at least 1.
            metalearner.init(levelOneTrainingFrame, levelOneValidationFrame, modelParms._metalearner_parameters, ensemble, StackedEnsemble.this._job, metalearnerKey, metalearnerJob, builderParms, hasMetaLearnerParams, metalearnerSeed, builderParms._max_runtime_secs == 0.0 ? 0L : Math.max(StackedEnsemble.this.remainingTimeSecs(), 1L));
            metalearner.compute();

            StackedEnsembleModel.StackedEnsembleParameters finalParms = (StackedEnsembleModel.StackedEnsembleParameters)ensemble._parms;
            if (ensemble.evalAutoParamsEnabled && finalParms._metalearner_algorithm == Metalearner.Algorithm.AUTO) {
                finalParms._metalearner_algorithm = resolvedAlgo;
            }
        }
    }
}

