/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.prims;

import hex.Model;
import java.util.Arrays;
import water.MRTask;
import water.fvec.C0DChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.ArrayUtils;

/**
 * Rapids primitive {@code (predicted.vs.actual.by.var model frame variable predicted)}.
 *
 * <p>For every level of the categorical column {@code variable} of {@code frame}, computes the
 * mean of the model's predictions (first column of {@code predicted}) and the mean of the actual
 * response. Both means are weighted by the model's weights column when {@code frame} contains it.
 * Rows whose {@code variable} value is missing are collected into an extra trailing level whose
 * label is {@code null}.
 *
 * <p>The result is a three-column frame:
 * <ol>
 *   <li>the level label, in a column named after {@code variable};</li>
 *   <li>the mean prediction, in a column named after the first column of {@code predicted};</li>
 *   <li>the mean actual value, in a column named {@code "actual"}.</li>
 * </ol>
 */
public class AstPredictedVsActualByVar
extends AstPrimitive<AstPredictedVsActualByVar> {
    @Override
    public String[] args() {
        // Argument names in call order; kept in sync with nargs().
        return new String[]{"model", "frame", "variable", "predicted"};
    }

    @Override
    public int nargs() {
        // The primitive name itself plus its 4 arguments.
        return 1 + 4;
    }

    @Override
    public String str() {
        return "predicted.vs.actual.by.var";
    }

    /**
     * Evaluates the primitive.
     *
     * @throws IllegalArgumentException if the model is unsupervised or multinomial; if the
     *     variable column is missing or not categorical; if the predictions frame is empty or
     *     its row count differs from the input frame; if the response column is missing; or if
     *     the actual and predicted domains differ.
     */
    @Override
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Model model = stk.track(asts[1].exec(env)).getModel();
        if (!model.isSupervised()) {
            throw new IllegalArgumentException("Only supervised models are supported for calculating predicted v actual");
        }
        Model.Output output = (Model.Output) model._output;
        if (output.isMultinomialClassifier()) {
            throw new IllegalArgumentException("Multinomial classification models are not supported by predicted v actual");
        }

        Frame frame = stk.track(asts[2].exec(env)).getFrame();
        String variable = stk.track(asts[3].exec(env)).getStr();
        Vec varVec = frame.vec(variable);
        if (varVec == null) {
            throw new IllegalArgumentException("Frame doesn't contain column '" + variable + "'.");
        }
        // Grouping is by level index, so a domain is required. domain() is null for non-categorical vecs.
        if (varVec.domain() == null) {
            throw new IllegalArgumentException("Column '" + variable + "' needs to be categorical.");
        }

        Frame preds = stk.track(asts[4].exec(env)).getFrame();
        if (preds.numCols() == 0) {
            throw new IllegalArgumentException("Frame of predictions doesn't contain any columns.");
        }
        if (frame.numRows() != preds.numRows()) {
            throw new IllegalArgumentException("Input frame and frame of predictions need to have same number of rows.");
        }

        Vec predicted = preds.vec(0);
        Vec actual = frame.vec(output.responseName());
        if (actual == null) {
            throw new IllegalArgumentException("Frame doesn't contain response column '" + output.responseName() + "'.");
        }
        // Null when the frame has no weights column; the task then weights every row by 1.
        Vec weights = frame.vec(output.weightsName());
        if (actual.domain() != predicted.domain() && !Arrays.equals(actual.domain(), predicted.domain())) {
            throw new IllegalArgumentException("Actual and predicted need to have identical domain.");
        }

        // Column order must match the chunk indices read in PredictedVsActualByVar.map().
        Vec[] vs = weights == null
                ? new Vec[]{predicted, actual, varVec}
                : new Vec[]{predicted, actual, varVec, weights};
        PredictedVsActualByVar pva = (PredictedVsActualByVar) new PredictedVsActualByVar(varVec).doAll(vs);

        // The trailing null label corresponds to the extra "missing value" accumulator slot.
        String[] domainExt = ArrayUtils.append(varVec.domain(), null);
        Vec[] resultVecs = new Vec[]{
                Vec.makeVec(domainExt, Vec.newKey()),
                Vec.makeVec(pva._preds, Vec.newKey()),
                Vec.makeVec(pva._acts, Vec.newKey())
        };
        Frame result = new Frame(new String[]{variable, preds.name(0), "actual"}, resultVecs);
        return new ValFrame(result);
    }

    /**
     * Accumulates weighted sums of predictions, actuals and weights per categorical level, then
     * turns the sums into weighted means in {@link #postGlobal()}.
     *
     * <p>Expected chunk layout: {@code [predicted, actual, variable]} optionally followed by
     * {@code weights}.
     */
    static class PredictedVsActualByVar
    extends MRTask<PredictedVsActualByVar> {
        /** Number of accumulator slots: one per categorical level plus one for missing values. */
        private final int _s;
        private double[] _preds;
        private double[] _acts;
        private double[] _weights;

        /**
         * @param varVec categorical grouping column
         * @throws IllegalArgumentException if {@code varVec} has no domain (is not categorical)
         */
        public PredictedVsActualByVar(Vec varVec) {
            String[] domain = varVec.domain();
            if (domain == null) {
                throw new IllegalArgumentException("Grouping column needs to be categorical.");
            }
            this._s = domain.length + 1;
        }

        @Override
        public void map(Chunk[] cs) {
            _preds = new double[_s];
            _acts = new double[_s];
            _weights = new double[_s];
            Chunk predChunk = cs[0];
            Chunk actChunk = cs[1];
            Chunk varChunk = cs[2];
            // Without a weights column, every row gets a constant weight of 1.
            Chunk weightChunk = cs.length == 4 ? cs[3] : new C0DChunk(1.0, predChunk._len);
            for (int i = 0; i < actChunk._len; i++) {
                // Skip rows with a missing actual, prediction or weight. NaN would poison the sums,
                // and skipping keeps both means over the same set of rows.
                if (actChunk.isNA(i) || predChunk.isNA(i) || weightChunk.isNA(i)) {
                    continue;
                }
                double weight = weightChunk.atd(i);
                if (weight == 0.0) {
                    continue;
                }
                // Missing grouping values go to the last slot; otherwise the value is the level index.
                int level = varChunk.isNA(i) ? _s - 1 : (int) varChunk.atd(i);
                _preds[level] += weight * predChunk.atd(i);
                _acts[level] += weight * actChunk.atd(i);
                _weights[level] += weight;
            }
        }

        @Override
        public void reduce(PredictedVsActualByVar mrt) {
            _preds = ArrayUtils.add(_preds, mrt._preds);
            _acts = ArrayUtils.add(_acts, mrt._acts);
            _weights = ArrayUtils.add(_weights, mrt._weights);
        }

        /**
         * Converts the per-level weighted sums into weighted means. Levels with no contributing
         * rows (total weight 0) are left at 0.
         */
        @Override
        protected void postGlobal() {
            for (int i = 0; i < _weights.length; i++) {
                if (_weights[i] == 0.0) {
                    continue;
                }
                _preds[i] /= _weights[i];
                _acts[i] /= _weights[i];
            }
        }
    }
}

