/*
 * Decompiled with CFR 0.152.
 */
package hex.psvm;

import hex.DataInfo;
import hex.FrameTask;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.psvm.PSVMModel;
import hex.psvm.SupportVector;
import hex.psvm.psvm.IncompleteCholeskyFactorization;
import hex.psvm.psvm.Kernel;
import hex.psvm.psvm.PrimalDualIPM;
import java.util.ArrayList;
import java.util.Arrays;
import water.AutoBuffer;
import water.DKV;
import water.H2O;
import water.Key;
import water.Lockable;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.TwoDimTable;
import water.util.VecUtils;

/**
 * Model builder for the experimental PSVM (binomial SVM) algorithm.
 *
 * <p>Training (see {@code SVMDriver#computeImpl}) proceeds as follows:
 * <ol>
 *   <li>It runs an Incomplete Cholesky Factorization ({@code IncompleteCholeskyFactorization.icf}).</li>
 *   <li>It solves for the alpha coefficients with {@code PrimalDualIPM.solve}.</li>
 *   <li>It thresholds alpha into support vectors ({@code RegulateAlphaTask}).</li>
 *   <li>It estimates {@code rho} from a sample of at most {@value #MAX_SV_SAMPLE_SIZE} support vectors.</li>
 *   <li>It stores the support vectors in compressed form, unless the estimated size is too large.</li>
 * </ol>
 */
public class PSVM
extends ModelBuilder<PSVMModel, PSVMModel.PSVMParameters, PSVMModel.PSVMModelOutput> {

    /** Parameter value ({@code _gamma}, {@code _rank_ratio}) requesting that a default be derived from the data. */
    private static final double AUTO = -1.0;

    /** Upper bound on the number of support vectors sampled for rho / model-size estimation. */
    private static final long MAX_SV_SAMPLE_SIZE = 1000L;

    private static final String[] SUMMARY_COL_HEADERS = {
            "Number of Support Vectors",
            "Number of Bounded Support Vectors",
            "Raw Model Size in Bytes",
            "rho",
            "Number of Iterations",
            "Surrogate Gap",
            "Primal Residual",
            "Dual Residual"
    };
    private static final String[] SUMMARY_COL_TYPES = {
            "long", "long", "long", "double", "long", "double", "double", "double"
    };
    private static final String[] SUMMARY_COL_FORMATS = {
            "%d", "%d", "%d", "%.5f", "%d", "%.5f", "%.5f", "%.5f"
    };

    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Binomial};
    }

    @Override
    public boolean isSupervised() {
        return true;
    }

    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    public PSVM(boolean startup_once) {
        super(new PSVMModel.PSVMParameters(), startup_once);
    }

    public PSVM(PSVMModel.PSVMParameters parms) {
        super(parms);
        this.init(false);
    }

    @Override
    public void init(boolean expensive) {
        // No PSVM-specific validation yet; everything is handled by ModelBuilder.
        super.init(expensive);
    }

    /**
     * Accepts either a 2-level categorical response, or a numeric integer response.
     * A numeric response must have min -1, max +1, and no zeros, i.e. it uses only -1/+1.
     */
    @Override
    public void checkDistributions() {
        if (this._response.isCategorical()) {
            if (this._response.cardinality() != 2) {
                this.error("_response", "Expected a binary categorical response, but instead got response with " + this._response.cardinality() + " categories.");
            }
        } else if (this._response.min() != -1.0 || this._response.max() != 1.0 || !this._response.isInt() || this._response.nzCnt() != this._response.length()) {
            this.error("_response", "Non-categorical response provided, please make sure the response is either binary categorical response or uses only values -1/+1 in case of numerical response.");
        }
    }

    @Override
    protected boolean computePriorClassDistribution() {
        return false;
    }

    @Override
    protected int init_getNClass() {
        return 2;
    }

    @Override
    protected ModelBuilder.Driver trainModelImpl() {
        return new SVMDriver();
    }

    /**
     * Computes the requested rank of the ICF matrix.
     *
     * @param rankRatio {@link #AUTO} for {@code sqrt(numRows)}, otherwise the fraction of {@code numRows}
     * @param numRows number of training rows
     * @return the rank, truncated toward zero
     */
    private int getRankICF(double rankRatio, long numRows) {
        if (rankRatio == AUTO) {
            return (int) Math.sqrt(numRows);
        }
        return (int) (rankRatio * (double) numRows);
    }

    /**
     * Converts a dense row into a {@link SupportVector} with the given alpha.
     *
     * @throws UnsupportedOperationException if {@code row} is sparse
     */
    private static SupportVector toSupportVector(double alpha, DataInfo.Row row) {
        if (row.isSparse()) {
            throw new UnsupportedOperationException("Sparse rows are not yet supported");
        }
        return new SupportVector().fill(alpha, row.numVals, row.binIds);
    }

    /**
     * Estimates the byte size of all compressed support vectors from a sample.
     *
     * <p>The sizes of the sampled rows are summed. When fewer rows were sampled than there are support
     * vectors, the sum is extrapolated linearly. The divisor is the number of rows actually sampled:
     * per-node flooring in {@code CollectSupportVecSamplesTask} may return fewer rows than requested.
     *
     * @param sampleSVs sampled support-vector rows
     * @param numSvs total number of support vectors
     * @return estimated size in bytes
     */
    private static long estimateSupportVectorsSize(DataInfo.Row[] sampleSVs, long numSvs) {
        long sampledBytes = 0L;
        for (DataInfo.Row sv : sampleSVs) {
            sampledBytes += (long) toSupportVector(0.0, sv).estimateSize();
        }
        int sampled = sampleSVs.length;
        if (sampled > 0 && numSvs > (long) sampled) {
            return sampledBytes * numSvs / (long) sampled;
        }
        return sampledBytes;
    }

    /**
     * Maps a binary categorical vector to a new numeric vector.
     * Level 0 becomes -1 and every other level becomes +1.
     * The helper is static so that the anonymous task does not capture the builder instance.
     */
    @SuppressWarnings("rawtypes") // an anonymous MRTask cannot name its own self type
    private static Vec toPlusMinusOne(Vec categorical) {
        return new MRTask() {
            @Override
            public void map(Chunk c, NewChunk nc) {
                for (int i = 0; i < c._len; ++i) {
                    nc.addNum(c.at8(i) == 0L ? -1.0 : 1.0);
                }
            }
        }.doAll(Vec.T_NUM, categorical).outputFrame().vec(0);
    }

    /** Builds the one-row "Model Summary" table from the model output and the final IPM progress report. */
    private static TwoDimTable createModelSummaryTable(PSVMModel.PSVMModelOutput output, IPMInfo ipmInfo) {
        TwoDimTable table = new TwoDimTable("Model Summary", null, new String[1],
                SUMMARY_COL_HEADERS.clone(), SUMMARY_COL_TYPES.clone(), SUMMARY_COL_FORMATS.clone(), "");
        int row = 0;
        int col = 0;
        table.set(row, col++, output._svs_count);
        table.set(row, col++, output._bsv_count);
        table.set(row, col++, output._compressed_svs != null ? output._compressed_svs.length : -1);
        table.set(row, col++, output._rho);
        table.set(row, col++, ipmInfo._iter);
        table.set(row, col++, ipmInfo._sgap);
        table.set(row, col++, ipmInfo._resp);
        table.set(row, col++, ipmInfo._resd);
        assert (col == SUMMARY_COL_HEADERS.length);
        return table;
    }

    /** Records the most recent progress report of the IPM solver. */
    private static class IPMInfo
    implements PrimalDualIPM.ProgressObserver {
        int _iter;
        double _sgap;
        double _resp;
        double _resd;
        boolean _converged;

        private IPMInfo() {
        }

        @Override
        public void reportProgress(int iter, double sgap, double resp, double resd, boolean converged) {
            this._iter = iter;
            this._sgap = sgap;
            this._resp = resp;
            this._resd = resd;
            this._converged = converged;
        }
    }

    /**
     * Post-processes the IPM solution in place.
     *
     * <p>For each row, {@code alpha <= sv_threshold} is set to NA (not a support vector). Otherwise the row
     * index is emitted to the output vec, and alpha is clamped to the class bound {@code C} when within
     * {@code sv_threshold} of it (counted as a bounded support vector). Alpha is then multiplied by the label.
     */
    private static class RegulateAlphaTask
    extends MRTask<RegulateAlphaTask> {
        private double _c_pos;
        private double _c_neg;
        private double _sv_threshold;
        long _svs_count;
        long _bsv_count;

        private RegulateAlphaTask(double c_pos, double c_neg, double sv_threshold) {
            this._c_pos = c_pos;
            this._c_neg = c_neg;
            this._sv_threshold = sv_threshold;
        }

        @Override
        public void map(Chunk alpha, Chunk label, NewChunk nc) {
            for (int i = 0; i < alpha._len; ++i) {
                double x = alpha.atd(i);
                if (x <= this._sv_threshold) {
                    alpha.setNA(i);
                    continue;
                }
                ++this._svs_count;
                nc.addNum(alpha.start() + (long) i);
                double y = label.atd(i);
                double c = y > 0.0 ? this._c_pos : this._c_neg;
                double regulated;
                if (c - x <= this._sv_threshold) {
                    regulated = c;
                    ++this._bsv_count;
                } else {
                    regulated = x;
                }
                alpha.set(i, regulated * y);
            }
        }

        @Override
        public void reduce(RegulateAlphaTask mrt) {
            this._svs_count += mrt._svs_count;
            this._bsv_count += mrt._bsv_count;
        }

        /** Copies the counters into {@code o} and returns the vec of support-vector row indices. */
        private Vec updateStats(PSVMModel.PSVMModelOutput o) {
            o._svs_count = this._svs_count;
            o._bsv_count = this._bsv_count;
            return this.outputFrame().vec(0);
        }
    }

    /**
     * Collects deep copies of a sample of support-vector rows.
     * Each node contributes a share of the sample proportional to its local number of support vectors.
     */
    private static class CollectSupportVecSamplesTask
    extends FrameTask<CollectSupportVecSamplesTask> {
        private Vec _svs;
        private int _num_selected;
        DataInfo.Row[][] _selected;
        private transient long[] _local_selected_idxs;

        @Override
        protected void setupLocal() {
            super.setupLocal();
            this._selected = new DataInfo.Row[H2O.CLOUD.size()][];
            int[] cids = VecUtils.getLocalChunkIds(this._svs);
            long localSvs = 0L;
            for (int cidx : cids) {
                localSvs += (long) this._svs.chunkLen(cidx);
            }
            long totalSvs = this._svs.length();
            // Guard against division by zero when there are no support vectors at all.
            int localContribution = totalSvs == 0L ? 0 : (int) ((long) this._num_selected * localSvs / totalSvs);
            this._local_selected_idxs = new long[localContribution];
            this._selected[H2O.SELF.index()] = new DataInfo.Row[localContribution];
            int filled = 0;
            collect:
            for (int cidx : cids) {
                if (filled == localContribution) break;
                Chunk svIndices = this._svs.chunkForChunkIdx(cidx);
                for (int i = 0; i < svIndices._len; ++i) {
                    this._local_selected_idxs[filled++] = svIndices.at8(i);
                    if (filled == localContribution) break collect;
                }
            }
            // Sorted so that skipRow/processRow can use binary search.
            Arrays.sort(this._local_selected_idxs);
        }

        CollectSupportVecSamplesTask(DataInfo dinfo, Vec svs, int num_selected) {
            super(null, dinfo);
            this._svs = svs;
            this._num_selected = num_selected;
        }

        @Override
        protected boolean skipRow(long gid) {
            return Arrays.binarySearch(this._local_selected_idxs, gid) < 0;
        }

        @Override
        protected void processRow(long gid, DataInfo.Row r) {
            int idx = Arrays.binarySearch(this._local_selected_idxs, gid);
            this._selected[H2O.SELF.index()][idx] = r.deepClone();
        }

        @Override
        public void reduce(CollectSupportVecSamplesTask mrt) {
            for (int i = 0; i < H2O.CLOUD.size(); ++i) {
                if (mrt._selected[i] != null) {
                    this._selected[i] = mrt._selected[i];
                }
            }
        }

        DataInfo.Row[] getSelected() {
            return ArrayUtils.flat(this._selected);
        }
    }

    /**
     * Estimates rho. For every sampled support vector s, the task accumulates
     * {@code sum_j alpha_j * K(x_j, s)} over all rows with non-NA alpha. Then
     * {@code rho = mean_s(response_s - sum)}.
     */
    private static class CalculateRhoTask
    extends FrameTask<CalculateRhoTask> {
        DataInfo.Row[] _selected;
        Vec _alpha;
        Kernel _kernel;
        double[] _rhos;
        transient long _offset;
        transient Chunk _alphaChunk;

        public CalculateRhoTask(DataInfo dinfo, DataInfo.Row[] selected, Vec alpha, Kernel kernel) {
            super(null, dinfo);
            this._selected = selected;
            this._alpha = alpha;
            this._kernel = kernel;
        }

        @Override
        public void map(Chunk[] chunks, NewChunk[] outputs) {
            this._alphaChunk = this._alpha.chunkForChunkIdx(chunks[0].cidx());
            this._offset = this._alphaChunk.start();
            this._rhos = new double[this._selected.length];
            super.map(chunks, outputs);
        }

        @Override
        protected boolean skipRow(long gid) {
            // Rows with NA alpha are not support vectors (see RegulateAlphaTask).
            return this._alphaChunk.isNA((int) (gid - this._offset));
        }

        @Override
        protected void processRow(long gid, DataInfo.Row r) {
            double rowAlpha = this._alphaChunk.atd((int) (gid - this._offset)); // invariant across samples
            for (int i = 0; i < this._selected.length; ++i) {
                this._rhos[i] += rowAlpha * this._kernel.calcKernel(r, this._selected[i]);
            }
        }

        @Override
        public void reduce(CalculateRhoTask mrt) {
            this._rhos = ArrayUtils.add(this._rhos, mrt._rhos);
        }

        double getRho() {
            double rho = 0.0;
            for (int i = 0; i < this._selected.length; ++i) {
                rho += this._selected[i].response[0] - this._rhos[i];
            }
            return rho / (double) this._selected.length;
        }
    }

    /**
     * Serializes each row with non-NA alpha into one byte array, using {@code SupportVector.compress}.
     * The last column of the input frame must be alpha.
     */
    private static class CompressVectorsTask
    extends MRTask<CompressVectorsTask> {
        private final DataInfo _dinfo;
        private byte[] _csvs;

        CompressVectorsTask(DataInfo dinfo) {
            this._dinfo = dinfo;
        }

        @Override
        public void map(Chunk[] acs) {
            AutoBuffer ab = new AutoBuffer();
            Chunk alpha = acs[acs.length - 1];
            Chunk[] cs = Arrays.copyOf(acs, acs.length - 1);
            DataInfo.Row row = this._dinfo.newDenseRow();
            SupportVector sv = new SupportVector();
            for (int i = 0; i < alpha._len; ++i) {
                if (alpha.isNA(i)) continue;
                this._dinfo.extractDenseRow(cs, i, row);
                sv.fill(alpha.atd(i), row.numVals, row.binIds);
                sv.compress(ab);
            }
            this._csvs = ab.buf();
        }

        @Override
        public void reduce(CompressVectorsTask mrt) {
            this._csvs = ArrayUtils.append(this._csvs, mrt._csvs);
        }
    }

    private class SVMDriver
    extends ModelBuilder.Driver {
        private SVMDriver() {
            super();
        }

        /**
         * Builds the training DataInfo. The response is mapped to a numeric -1/+1 vec when it is
         * categorical, and a zero-initialised "two_norm_sq" column is appended.
         *
         * @throws IllegalStateException if the response contains NAs
         */
        DataInfo adaptTrain() {
            Frame adapted = new Frame(PSVM.this.train());
            adapted.remove(PSVM.this._parms._response_column);
            if (PSVM.this.response().naCnt() > 0L) {
                throw new IllegalStateException("NA values in response column are currently not supported.");
            }
            Vec numericResp = PSVM.this.response().domain() == null
                    ? PSVM.this.response()
                    : toPlusMinusOne(PSVM.this.response());
            adapted.add(PSVM.this._parms._response_column, numericResp);
            adapted.add("two_norm_sq", Scope.track(adapted.anyVec().makeZero()));
            return new DataInfo(adapted, null, 2, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, false, false, false, null).disableIntercept();
        }

        /** Returns a copy of the adapted frame without the internal "two_norm_sq" column. */
        Frame prototypeFrame(DataInfo di) {
            Frame f = new Frame(di._adaptedFrame);
            f.remove("two_norm_sq");
            return f;
        }

        @Override
        public void computeImpl() {
            PSVMModel model = null;
            try {
                PSVM.this.init(true);
                PSVM.this._job.update(0L, "Initializing model training");
                PSVMModel.PSVMParameters parms = PSVM.this._parms;

                DataInfo di = this.adaptTrain();
                Scope.track_generic(di);
                DKV.put(di);
                if (parms._gamma == AUTO) {
                    parms._gamma = 1.0 / (double) di.fullN();
                    Log.info("Set gamma = " + parms._gamma);
                }
                Vec responseVec = di._adaptedFrame.vec(parms._response_column);
                PSVMModel.PSVMModelOutput initialOutput = new PSVMModel.PSVMModelOutput(PSVM.this, this.prototypeFrame(di), PSVM.this.response().domain());
                model = new PSVMModel(PSVM.this._result, parms, initialOutput);
                model.delete_and_lock(PSVM.this._job);
                PSVMModel.PSVMModelOutput output = model._output;

                int rank = PSVM.this.getRankICF(parms._rank_ratio, di._adaptedFrame.numRows());
                Log.info("Desired rank of ICF matrix = " + rank);
                PSVM.this._job.update(0L, "Running Incomplete Cholesky Factorization");
                Frame icf = IncompleteCholeskyFactorization.icf(di, parms.kernel(), rank, parms._fact_threshold);
                Scope.track(icf);

                PSVM.this._job.update(0L, "Running IPM");
                IPMInfo ipmInfo = new IPMInfo();
                Vec alpha = PrimalDualIPM.solve(icf, responseVec, parms.ipmParms(), ipmInfo);
                icf.remove();
                Log.info("IPM finished");

                Vec svs = new RegulateAlphaTask(parms.c_pos(), parms.c_neg(), parms._sv_threshold)
                        .doAll(Vec.T_NUM, alpha, responseVec)
                        .updateStats(output);
                assert (svs.length() == output._svs_count);
                Scope.track(svs);
                if (svs.length() == 0L) {
                    // Without this check the sampling task below fails with an ArithmeticException (/ by zero).
                    throw new IllegalStateException("No support vectors found: all alpha coefficients are <= sv_threshold ("
                            + parms._sv_threshold + "). Try lowering sv_threshold.");
                }

                Frame alphaFr = new Frame(Key.make(model._key + "_alpha"));
                alphaFr.add("alpha", alpha);
                DKV.put(alphaFr);
                output._alpha_key = alphaFr._key;

                int sampleSize = (int) Math.min(svs.length(), MAX_SV_SAMPLE_SIZE);
                DataInfo.Row[] sampleSVs = new CollectSupportVecSamplesTask(di, svs, sampleSize).doAll(di._adaptedFrame).getSelected();
                output._rho = new CalculateRhoTask(di, sampleSVs, alpha, parms.kernel()).doAll(di._adaptedFrame).getRho();

                long estimatedSize = estimateSupportVectorsSize(sampleSVs, svs.length());
                boolean tooBig = estimatedSize >= Integer.MAX_VALUE;
                if (!tooBig) {
                    Frame fr = new Frame(di._adaptedFrame);
                    fr.add("__alpha", alpha);
                    output._compressed_svs = new CompressVectorsTask(di).doAll(fr)._csvs;
                    assert (svs.length() > (long) sampleSize || (long) output._compressed_svs.length == estimatedSize);
                } else {
                    Log.err("Estimated model size (" + estimatedSize + "B) exceeds limits of DKV. Support vectors will not be stored.");
                    model.addWarning("Model too big (size " + estimatedSize + "B) exceeds maximum model size. Support vectors will not be stored as a part of the model. You can still inspect what vectors were chosen and what are their alpha coefficients (see Frame alpha in model output).");
                    output._compressed_svs = new byte[0];
                }
                Log.info("Total #support vectors: " + output._svs_count + " (size in memory " + estimatedSize + "B)");
                output._model_summary = createModelSummaryTable(output, ipmInfo);
                model.update(PSVM.this._job);

                // Scoring needs the stored support vectors, which are absent when the model is too big.
                if (!tooBig) {
                    this.computeMetrics(model, parms);
                }
                // alpha is referenced by the alpha frame kept in the model output; keep it alive.
                Scope.untrack(alpha._key);
            }
            finally {
                if (model != null) {
                    model.unlock(PSVM.this._job);
                }
            }
        }

        /**
         * Scores the training frame (unless disabled in the parameters) and the validation frame
         * (when present), and stores the metrics in the model output.
         */
        private void computeMetrics(PSVMModel model, PSVMModel.PSVMParameters parms) {
            if (parms._disable_training_metrics) {
                String noMetricsWarning = "Not creating training metrics: scoring disabled (use disable_training_metrics = false to override)";
                Log.warn(noMetricsWarning);
                model.addWarning(noMetricsWarning);
            } else {
                PSVM.this._job.update(0L, "Scoring training frame");
                Frame scoringTrain = new Frame(PSVM.this.train());
                model.adaptTestForTrain(scoringTrain, true, true);
                model._output._training_metrics = model.makeModelMetrics(PSVM.this.train(), scoringTrain, "Training metrics");
            }
            if (PSVM.this.valid() != null) {
                PSVM.this._job.update(0L, "Scoring validation frame");
                Frame scoringValid = new Frame(PSVM.this.valid());
                model.adaptTestForTrain(scoringValid, true, true);
                model._output._validation_metrics = model.makeModelMetrics(PSVM.this.valid(), scoringValid, "Validation metrics");
            }
        }
    }
}

