/*
 * Decompiled with CFR 0.152.
 */
package hex.rulefit;

import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.ToEigenVec;
import hex.glm.GLMModel;
import hex.rulefit.Rule;
import hex.rulefit.RuleEnsemble;
import hex.rulefit.RuleFit;
import hex.rulefit.RuleFitMojoWriter;
import hex.rulefit.RuleFitUtils;
import hex.util.LinearAlgebraUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.fvec.Frame;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.TwoDimTable;

public class RuleFitModel
extends Model<RuleFitModel, RuleFitParameters, RuleFitOutput> {
    GLMModel glmModel;
    RuleEnsemble ruleEnsemble;

    @Override
    public ToEigenVec getToEigenVec() {
        return LinearAlgebraUtils.toEigen;
    }

    public RuleFitModel(Key<RuleFitModel> selfKey, RuleFitParameters parms, RuleFitOutput output, GLMModel glmModel, RuleEnsemble ruleEnsemble) {
        super(selfKey, parms, output);
        this.glmModel = glmModel;
        this.ruleEnsemble = ruleEnsemble;
    }

    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        assert (domain == null);
        switch (((RuleFitOutput)this._output).getModelCategory()) {
            case Binomial: {
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            }
            case Multinomial: {
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(((RuleFitOutput)this._output).nclasses(), domain, ((RuleFitParameters)this._parms)._auc_type);
            }
            case Regression: {
                return new ModelMetricsRegression.MetricBuilderRegression();
            }
        }
        throw H2O.unimpl("Invalid ModelCategory " + (Object)((Object)((RuleFitOutput)this._output).getModelCategory()));
    }

    @Override
    protected double[] score0(double[] data, double[] preds) {
        throw new UnsupportedOperationException("RuleFitModel doesn't support scoring on raw data. Use score() instead.");
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public Frame score(Frame fr, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) throws IllegalArgumentException {
        Frame adaptFrm = new Frame(fr);
        this.adaptTestForTrain(adaptFrm, true, false);
        Frame linearTest = new Frame(new Vec[0]);
        try {
            if (ModelType.RULES_AND_LINEAR.equals((Object)((RuleFitParameters)this._parms)._model_type) || ModelType.RULES.equals((Object)((RuleFitParameters)this._parms)._model_type)) {
                linearTest.add(this.ruleEnsemble.createGLMTrainFrame(adaptFrm, ((RuleFitParameters)this._parms)._max_rule_length - ((RuleFitParameters)this._parms)._min_rule_length + 1, ((RuleFitParameters)this._parms)._rule_generation_ntrees, ((RuleFitOutput)this._output).classNames(), ((RuleFitParameters)this._parms)._weights_column, false));
            }
            if (ModelType.RULES_AND_LINEAR.equals((Object)((RuleFitParameters)this._parms)._model_type) || ModelType.LINEAR.equals((Object)((RuleFitParameters)this._parms)._model_type)) {
                linearTest.add(RuleFitUtils.getLinearNames(adaptFrm.numCols(), adaptFrm.names()), adaptFrm.vecs());
            } else {
                linearTest.add(RuleFitUtils.getLinearNames(1, new String[]{((RuleFitParameters)this._parms)._response_column})[0], adaptFrm.vec(((RuleFitParameters)this._parms)._response_column));
            }
            Frame scored = this.glmModel.score(linearTest, destination_key, null, true);
            this.updateModelMetrics(this.glmModel, fr);
            Frame frame = scored;
            return frame;
        }
        finally {
            Frame.deleteTempFrameAndItsNonSharedVecs(linearTest, adaptFrm);
        }
    }

    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        super.remove_impl(fs, cascade);
        if (cascade) {
            this.glmModel.remove(fs);
        }
        return fs;
    }

    void updateModelMetrics(GLMModel glmModel, Frame fr) {
        for (Key<ModelMetrics> modelMetricsKey : ((GLMModel.GLMOutput)glmModel._output).getModelMetrics()) {
            if (modelMetricsKey.get() == null) continue;
            this.addModelMetrics(modelMetricsKey.get().deepCloneWithDifferentModelAndFrame(this, fr));
        }
    }

    @Override
    public RuleFitMojoWriter getMojo() {
        return new RuleFitMojoWriter(this);
    }

    @Override
    public boolean haveMojo() {
        return true;
    }

    public Frame predictRules(Frame frame, String[] ruleIds) {
        Frame adaptFrm = new Frame(frame);
        this.adaptTestForTrain(adaptFrm, true, false);
        List<String> linVarNames = Arrays.asList(this.glmModel.names()).stream().filter(name -> name.startsWith("linear.")).collect(Collectors.toList());
        ArrayList<Rule> rules = new ArrayList<Rule>();
        ArrayList<String> linearRules = new ArrayList<String>();
        for (int i = 0; i < ruleIds.length; ++i) {
            if (ruleIds[i].startsWith("linear.") && this.isLinearVar(ruleIds[i], linVarNames)) {
                linearRules.add(ruleIds[i]);
                continue;
            }
            rules.add(this.ruleEnsemble.getRuleByVarName(RuleFitUtils.readRuleId(ruleIds[i])));
        }
        RuleEnsemble subEnsemble = new RuleEnsemble(rules.toArray(new Rule[0]));
        Frame result = subEnsemble.transform(adaptFrm);
        for (int i = 0; i < linearRules.size(); ++i) {
            result.add((String)linearRules.get(i), Vec.makeOne(frame.numRows()));
        }
        result = new Frame(Key.make(), result.names(), result.vecs());
        DKV.put(result);
        return result;
    }

    private boolean isLinearVar(String potentialLinVarId, List<String> linVarNames) {
        for (String linVarName : linVarNames) {
            if (!potentialLinVarId.startsWith(linVarName)) continue;
            return true;
        }
        return false;
    }

    public static class RuleFitOutput
    extends Model.Output {
        public double[] _intercept;
        String[] _linear_names;
        public TwoDimTable _rule_importance = null;
        Key glmModelKey = null;
        String[] _dataFromRulesCodes;

        public RuleFitOutput(RuleFit b) {
            super(b);
        }
    }

    public static class RuleFitParameters
    extends Model.Parameters {
        public Algorithm _algorithm = Algorithm.AUTO;
        public int _min_rule_length = 3;
        public int _max_rule_length = 3;
        public int _max_num_rules = -1;
        public ModelType _model_type = ModelType.RULES_AND_LINEAR;
        public int _rule_generation_ntrees = 50;
        public boolean _remove_duplicates = true;
        public double[] _lambda;

        @Override
        public String algoName() {
            return "RuleFit";
        }

        @Override
        public String fullName() {
            return "RuleFit";
        }

        @Override
        public String javaName() {
            return RuleFitModel.class.getName();
        }

        @Override
        public long progressUnits() {
            return 1000000L;
        }

        public void validate(RuleFit rfit) {
            if (((RuleFitParameters)rfit._parms)._min_rule_length > ((RuleFitParameters)rfit._parms)._max_rule_length) {
                rfit.error("min_rule_length", "min_rule_length cannot be greater than max_rule_length. Current values:  min_rule_length = " + ((RuleFitParameters)rfit._parms)._min_rule_length + ", max_rule_length = " + ((RuleFitParameters)rfit._parms)._max_rule_length + ".");
            }
        }
    }

    public static enum ModelType {
        RULES,
        RULES_AND_LINEAR,
        LINEAR;

    }

    public static enum Algorithm {
        DRF,
        GBM,
        AUTO;

    }
}

