/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.CustomMetric;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsRegression;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import org.apache.commons.lang.ArrayUtils;
import water.DKV;
import water.Key;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Merge;

public class ModelMetricsRegressionCoxPH
extends ModelMetricsRegression {
    private double _concordance;
    private long _concordant;
    private long _discordant;
    private long _tied_y;

    public double concordance() {
        return this._concordance;
    }

    public long concordant() {
        return this._concordant;
    }

    public long discordant() {
        return this._discordant;
    }

    public long tiedY() {
        return this._tied_y;
    }

    public ModelMetricsRegressionCoxPH(Model model, Frame frame, long nobs, double mse, double sigma, double mae, double rmsle, double meanResidualDeviance, CustomMetric customMetric, double concordance, long concordant, long discordant, long tied_y) {
        super(model, frame, nobs, mse, sigma, mae, rmsle, meanResidualDeviance, customMetric);
        this._concordance = concordance;
        this._concordant = concordant;
        this._discordant = discordant;
        this._tied_y = tied_y;
    }

    public static ModelMetricsRegressionCoxPH getFromDKV(Model model, Frame frame) {
        ModelMetrics mm = ModelMetrics.getFromDKV(model, frame);
        if (!(mm instanceof ModelMetricsRegressionCoxPH)) {
            throw new H2OIllegalArgumentException("Expected to find a Regression ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(), "Expected to find a ModelMetricsRegression for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + mm.getClass());
        }
        return (ModelMetricsRegressionCoxPH)mm;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        if (!Double.isNaN(this._concordance)) {
            sb.append(" concordance: " + (float)this._concordance + "\n");
        } else {
            sb.append(" concordance: N/A\n");
        }
        sb.append(" concordant: " + this._concordant + "\n");
        sb.append(" discordant: " + this._discordant + "\n");
        sb.append(" tied.y: " + this._tied_y + "\n");
        return sb.toString();
    }

    static class StatTree {
        final double[] values;
        final long[] counts;

        StatTree(double[] possibleValues) {
            assert (null != possibleValues);
            assert (StatTree.sortedAscending(possibleValues));
            this.values = new double[possibleValues.length];
            int filled = this.fillTree(possibleValues, 0, possibleValues.length, 0);
            this.addMissingValues(possibleValues, filled);
            this.counts = new long[possibleValues.length];
            assert (StatTree.containsAll(possibleValues, this.values));
            assert (StatTree.isSearchTree(this.values));
            assert (StatTree.allZeroes(this.counts));
        }

        private void addMissingValues(double[] possibleValues, int filled) {
            int missing = possibleValues.length - filled;
            for (int i = 0; i < missing; ++i) {
                this.values[filled + i] = possibleValues[i * 2];
            }
        }

        private int fillTree(double[] inputValues, int start, int stop, int rootIndex) {
            int len = stop - start;
            if (0 >= len) {
                return 0;
            }
            int lastFullRow = 32 - Integer.numberOfLeadingZeros(len + 1) - 1;
            int fillable = (1 << lastFullRow) - 1;
            int totalOverflow = len - fillable;
            int leftOverflow = Math.min(totalOverflow, 1 << lastFullRow - 1);
            int leftTreeSize = (1 << lastFullRow - 1) - 1 + leftOverflow;
            this.values[rootIndex] = inputValues[start + leftTreeSize];
            this.fillTree(inputValues, start, start + leftTreeSize, StatTree.leftChild(rootIndex));
            this.fillTree(inputValues, start + leftTreeSize + 1, stop, StatTree.rightChild(rootIndex));
            return fillable;
        }

        private static boolean sortedAscending(double[] a) {
            for (int i = 1; i < a.length; ++i) {
                if (!(a[i - 1] > a[i])) continue;
                return false;
            }
            return true;
        }

        private static boolean containsAll(double[] a, double[] b) {
            for (int i = 0; i < b.length; ++i) {
                if (ArrayUtils.contains(a, b[i])) continue;
                return false;
            }
            return true;
        }

        private static boolean isSearchTree(double[] a) {
            for (int i = 0; i < a.length; ++i) {
                int leftChild = StatTree.leftChild(i);
                if (leftChild < a.length && a[i] < a[leftChild]) {
                    return false;
                }
                int rightChild = StatTree.rightChild(i);
                if (rightChild >= a.length || !(a[i] > a[rightChild])) continue;
                return false;
            }
            return true;
        }

        private static boolean allZeroes(long[] a) {
            for (int i = 0; i < a.length; ++i) {
                if (0L == a[i]) continue;
                return false;
            }
            return true;
        }

        void insert(double value) {
            int i = 0;
            long n = this.values.length;
            while ((long)i < n) {
                double cur = this.values[i];
                int n2 = i;
                this.counts[n2] = this.counts[n2] + 1L;
                if (value < cur) {
                    i = StatTree.leftChild(i);
                    continue;
                }
                if (value > cur) {
                    i = StatTree.rightChild(i);
                    continue;
                }
                return;
            }
            throw new IllegalArgumentException("Value " + value + " not contained in tree. Tree counts now in illegal state;");
        }

        public int size() {
            return this.values.length;
        }

        public long len() {
            return this.counts[0];
        }

        RankAndCount rankAndCount(double value) {
            int i = 0;
            int rank = 0;
            long count = 0L;
            while (i < this.values.length) {
                double cur = this.values[i];
                if (value < cur) {
                    i = StatTree.leftChild(i);
                    continue;
                }
                if (value > cur) {
                    rank = (int)((long)rank + this.counts[i]);
                    int nexti = StatTree.rightChild(i);
                    if (nexti < this.values.length) {
                        rank = (int)((long)rank - this.counts[nexti]);
                        i = nexti;
                        continue;
                    }
                    return new RankAndCount(rank, count);
                }
                count = this.counts[i];
                int lefti = StatTree.leftChild(i);
                if (lefti < this.values.length) {
                    long nleft = this.counts[lefti];
                    count -= nleft;
                    rank = (int)((long)rank + nleft);
                    int righti = StatTree.rightChild(i);
                    if (righti < this.values.length) {
                        count -= this.counts[righti];
                    }
                }
                return new RankAndCount(rank, count);
            }
            return new RankAndCount(rank, count);
        }

        public String toString() {
            return this.toString(new StringBuilder()).toString();
        }

        private StringBuilder toString(StringBuilder strBuilder) {
            int i = 0;
            int to = 2;
            while (true) {
                if (i < to - 1) {
                    if (i >= this.values.length) {
                        return strBuilder;
                    }
                    strBuilder.append(this.values[i]).append('(').append(this.counts[i]).append(')').append(" ");
                    ++i;
                    continue;
                }
                strBuilder.append("\n");
                to *= 2;
            }
        }

        private static int leftChild(int i) {
            return 2 * i + 1;
        }

        private static int rightChild(int i) {
            return 2 * i + 2;
        }

        static class RankAndCount {
            final long rank;
            final long count;

            public RankAndCount(long rank, long count) {
                this.rank = rank;
                this.count = count;
            }

            public String toString() {
                return "RankAndCount{rank=" + this.rank + ", count=" + this.count + '}';
            }
        }
    }

    public static class MetricBuilderRegressionCoxPH<T extends MetricBuilderRegressionCoxPH<T>>
    extends ModelMetricsRegression.MetricBuilderRegression<T> {
        private final String startVecName;
        private final String stopVecName;
        private final boolean isStratified;
        private final String[] stratifyBy;

        public MetricBuilderRegressionCoxPH(String startVecName, String stopVecName, boolean isStratified, String[] stratifyByName) {
            this.startVecName = startVecName;
            this.stopVecName = stopVecName;
            this.isStratified = isStratified;
            this.stratifyBy = stratifyByName;
        }

        @Override
        public ModelMetricsRegressionCoxPH makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
            ModelMetricsRegression modelMetricsRegression = super.computeModelMetrics(m, f, adaptedFrame, preds);
            Stats stats = this.concordance(m, f, adaptedFrame, preds);
            ModelMetricsRegressionCoxPH mm = new ModelMetricsRegressionCoxPH(m, f, this._count, modelMetricsRegression.mse(), this.weightedSigma(), modelMetricsRegression.mae(), modelMetricsRegression.rmsle(), modelMetricsRegression.mean_residual_deviance(), this._customMetric, stats.c(), stats.nconcordant, stats.discordant(), stats.nties);
            if (m != null) {
                m.addModelMetrics(mm);
            }
            return mm;
        }

        private Stats concordance(Model m, Frame fr, Frame adaptFrm, Frame scored) {
            Vec startVec = adaptFrm.vec(this.startVecName);
            Vec stopVec = adaptFrm.vec(this.stopVecName);
            Vec statusVec = adaptFrm.lastVec();
            Vec estimateVec = scored.lastVec();
            List<Vec> strataVecs = this.isStratified ? Arrays.asList(this.stratifyBy).stream().map(s -> fr.vec((String)s)).collect(Collectors.toList()) : Collections.emptyList();
            return MetricBuilderRegressionCoxPH.concordance(startVec, stopVec, statusVec, strataVecs, estimateVec);
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        static Stats concordance(Vec startVec, Vec stopVec, Vec eventVec, List<Vec> strataVecs, Vec estimateVec) {
            try {
                Scope.enter();
                Vec durations = MetricBuilderRegressionCoxPH.durations(startVec, stopVec);
                Frame fr = MetricBuilderRegressionCoxPH.prepareFrameForConcordanceComputation(eventVec, strataVecs, estimateVec, durations);
                Stats stats = MetricBuilderRegressionCoxPH.concordanceStats(fr);
                return stats;
            }
            finally {
                Scope.exit(new Key[0]);
            }
        }

        private static Frame prepareFrameForConcordanceComputation(Vec eventVec, List<Vec> strataVecs, Vec estimateVec, Vec durations) {
            Frame fr = new Frame(new Vec[0]);
            fr.add("duration", durations);
            fr.add("event", eventVec);
            fr.add("estimate", estimateVec);
            for (int i = 0; i < strataVecs.size(); ++i) {
                fr.add("strata_" + i, strataVecs.get(i));
            }
            return fr;
        }

        private static Vec durations(Vec startVec, Vec stopVec) {
            if (null == startVec) {
                return stopVec;
            }
            Frame frame = ((MRTask)new MRTask(){

                @Override
                public void map(Chunk c0, Chunk c1, NewChunk nc) {
                    for (int i = 0; i < c0._len; ++i) {
                        nc.addNum(c1.atd(i) - c0.atd(i));
                    }
                }
            }.doAll((byte)3, startVec, stopVec)).outputFrame(new String[]{"durations"}, null);
            Vec result = frame.vec(0);
            DKV.put(result);
            Scope.track(result);
            return result;
        }

        private static Stats concordanceStats(Frame fr) {
            Frame withoutNas = MetricBuilderRegressionCoxPH.removeNAs(fr);
            int[] stratasAndDuration = new int[withoutNas.numCols() - 2];
            int[] strataIndexes = new int[withoutNas.numCols() - 3];
            for (int i2 = 0; i2 < strataIndexes.length; ++i2) {
                stratasAndDuration[i2] = i2 + 3;
                strataIndexes[i2] = i2 + 3;
            }
            stratasAndDuration[withoutNas.numCols() - 3] = 0;
            if (0L == withoutNas.numRows()) {
                return new Stats();
            }
            Frame sorted = withoutNas.sort(stratasAndDuration);
            Scope.track(sorted);
            List strataCols = Arrays.stream(strataIndexes).boxed().map(i -> new Vec.Reader(sorted.vec((int)i))).collect(Collectors.toList());
            long lastStart = 0L;
            ArrayList lastRow = new ArrayList(sorted.numCols() - 3);
            Stats statsAcc = new Stats();
            for (long i3 = 0L; i3 < sorted.numRows(); ++i3) {
                ArrayList<Double> row = new ArrayList<Double>(sorted.numCols() - 3);
                for (Vec.Reader strataCol : strataCols) {
                    row.add(strataCol.at(i3));
                }
                if (lastRow.equals(row)) continue;
                lastRow = row;
                Stats stats = MetricBuilderRegressionCoxPH.statsForAStrata(new Vec.Reader(sorted.vec("duration")), new Vec.Reader(sorted.vec("event")), new Vec.Reader(sorted.vec("estimate")), lastStart, i3);
                lastStart = i3;
                statsAcc = statsAcc.plus(stats);
            }
            Stats stats = MetricBuilderRegressionCoxPH.statsForAStrata(new Vec.Reader(sorted.vec("duration")), new Vec.Reader(sorted.vec("event")), new Vec.Reader(sorted.vec("estimate")), lastStart, sorted.numRows());
            return statsAcc.plus(stats);
        }

        private static Frame removeNAs(Frame fr) {
            int[] iDontWantNAsInThisCols = new int[]{0, 2};
            Frame withoutNas = ((Merge.RemoveNAsTask)new Merge.RemoveNAsTask(iDontWantNAsInThisCols).doAll(fr.types(), fr)).outputFrame(fr.names(), fr.domains());
            Scope.track(withoutNas);
            Arrays.stream(withoutNas.vecs()).forEach(Scope::track);
            withoutNas.replace(1, withoutNas.vec("event"));
            return withoutNas;
        }

        private static Stats statsForAStrata(Vec.Reader duration, Vec.Reader eventVec, Vec.Reader estimateVec, long firstIndex, long lastIndex) {
            boolean hasMoreDead;
            boolean hasMoreCensored;
            if (lastIndex == firstIndex) {
                return new Stats();
            }
            int countOfCensored = 0;
            int countOfDead = 0;
            for (long i2 = firstIndex; i2 < lastIndex; ++i2) {
                if (0.0 == eventVec.at(i2)) {
                    ++countOfCensored;
                    continue;
                }
                ++countOfDead;
            }
            long[] indexesOfDead = new long[countOfDead];
            long[] indexesOfCensored = new long[countOfCensored];
            countOfCensored = 0;
            countOfDead = 0;
            for (long i3 = firstIndex; i3 < lastIndex; ++i3) {
                if (0.0 == eventVec.at(i3)) {
                    indexesOfCensored[countOfCensored++] = i3;
                    continue;
                }
                indexesOfDead[countOfDead++] = i3;
            }
            assert ((long)(indexesOfCensored.length + indexesOfDead.length) == lastIndex - firstIndex);
            int diedIndex = 0;
            int censoredIndex = 0;
            DoubleStream estimatesOfDead = Arrays.stream(indexesOfDead).mapToDouble(i -> MetricBuilderRegressionCoxPH.estimateTime(estimateVec, i));
            StatTree timesToCompare = new StatTree(estimatesOfDead.distinct().sorted().toArray());
            long nTotals = 0L;
            long nConcordant = 0L;
            long nTied = 0L;
            while (true) {
                PairStats pairStats;
                hasMoreCensored = censoredIndex < indexesOfCensored.length;
                boolean bl = hasMoreDead = diedIndex < indexesOfDead.length;
                if (hasMoreCensored && (!hasMoreDead || MetricBuilderRegressionCoxPH.deadTime(duration, indexesOfDead[diedIndex]) > MetricBuilderRegressionCoxPH.deadTime(duration, indexesOfCensored[censoredIndex]))) {
                    pairStats = MetricBuilderRegressionCoxPH.handlePairs(indexesOfCensored, estimateVec, censoredIndex, timesToCompare);
                    nTotals += pairStats.pairs;
                    nConcordant += pairStats.concordant;
                    nTied += pairStats.tied;
                    censoredIndex = pairStats.next_ix;
                    continue;
                }
                if (!hasMoreDead || hasMoreCensored && !(MetricBuilderRegressionCoxPH.deadTime(duration, indexesOfDead[diedIndex]) <= MetricBuilderRegressionCoxPH.deadTime(duration, indexesOfCensored[censoredIndex]))) break;
                pairStats = MetricBuilderRegressionCoxPH.handlePairs(indexesOfDead, estimateVec, diedIndex, timesToCompare);
                for (int i4 = diedIndex; i4 < pairStats.next_ix; ++i4) {
                    double pred = MetricBuilderRegressionCoxPH.estimateTime(estimateVec, indexesOfDead[i4]);
                    timesToCompare.insert(pred);
                }
                nTotals += pairStats.pairs;
                nConcordant += pairStats.concordant;
                nTied += pairStats.tied;
                diedIndex = pairStats.next_ix;
            }
            assert (!hasMoreDead && !hasMoreCensored);
            return new Stats(nTotals, nConcordant, nTied);
        }

        private static double deadTime(Vec.Reader duration, long i) {
            return duration.at(i);
        }

        private static double estimateTime(Vec.Reader estimateVec, long i) {
            return -estimateVec.at(i);
        }

        static PairStats handlePairs(long[] truth, Vec.Reader estimateVec, int first_ix, StatTree statTree) {
            int next_ix;
            for (next_ix = first_ix; next_ix < truth.length && truth[next_ix] == truth[first_ix]; ++next_ix) {
            }
            long pairs = statTree.len() * (long)(next_ix - first_ix);
            long correct = 0L;
            long tied = 0L;
            for (int i = first_ix; i < next_ix; ++i) {
                double estimateTime = MetricBuilderRegressionCoxPH.estimateTime(estimateVec, truth[i]);
                StatTree.RankAndCount rankAndCount = statTree.rankAndCount(estimateTime);
                correct += rankAndCount.rank;
                tied += rankAndCount.count;
            }
            PairStats pairStats = new PairStats(pairs, correct, tied, next_ix);
            return pairStats;
        }

        static class PairStats {
            final long pairs;
            final long concordant;
            final long tied;
            final int next_ix;

            public PairStats(long pairs, long concordant, long tied, int next_ix) {
                this.pairs = pairs;
                this.concordant = concordant;
                this.tied = tied;
                this.next_ix = next_ix;
            }

            public String toString() {
                return "PairStats{pairs=" + this.pairs + ", concordant=" + this.concordant + ", tied=" + this.tied + ", next_ix=" + this.next_ix + '}';
            }
        }

        static class Stats {
            final long ntotals;
            final long nconcordant;
            final long nties;

            Stats() {
                this(0L, 0L, 0L);
            }

            Stats(long ntotals, long nconcordant, long nties) {
                this.ntotals = ntotals;
                this.nconcordant = nconcordant;
                this.nties = nties;
            }

            double c() {
                return ((double)this.nconcordant + 0.5 * (double)this.nties) / (double)this.ntotals;
            }

            long discordant() {
                return this.ntotals - this.nconcordant - this.nties;
            }

            public String toString() {
                return "Stats{ntotals=" + this.ntotals + ", nconcordant=" + this.nconcordant + ", ndiscordant=" + this.discordant() + ", nties=" + this.nties + '}';
            }

            Stats plus(Stats s2) {
                return new Stats(this.ntotals + s2.ntotals, this.nconcordant + s2.nconcordant, this.nties + s2.nties);
            }
        }
    }
}

