/*
 * Decompiled with CFR 0.152.
 */
package hex.modelselection;

import hex.DataInfo;
import hex.Model;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.modelselection.ModelSelectionModel;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import water.DKV;
import water.Key;
import water.Lockable;
import water.fvec.Frame;

/**
 * Static helpers for model selection: enumerating predictor subsets, building one training frame
 * per subset, configuring and launching the corresponding GLM builders, and choosing the best
 * resulting model by R2.
 */
public class ModelSelectionUtils {

    /**
     * Builds one training frame for each of {@code numModels} successive combinations of
     * {@code predNum} predictors drawn from {@code predNames}, and stores each frame in the DKV.
     *
     * <p>Combinations are enumerated in lexicographic order of predictor indices. The enumeration
     * starts at {0, 1, ..., predNum-1} and is advanced by {@link #updatePredIndices}.
     *
     * @param parms      model-selection parameters supplying the training frame and special columns
     * @param predNum    number of predictors in each combination
     * @param predNames  candidate predictor column names
     * @param numModels  number of combinations/frames to generate (NOTE(review): presumably
     *                   C(predNames.length, predNum); confirm against callers)
     * @param foldColumn optional fold column copied into every frame, may be {@code null}
     * @return the generated frames, already put into the DKV
     */
    public static Frame[] generateTrainingFrames(ModelSelectionModel.ModelSelectionParameters parms, int predNum, String[] predNames, int numModels, String foldColumn) {
        int maxPredNum = predNames.length;
        Frame[] trainFrames = new Frame[numModels];
        // First combination: {0, 1, ..., predNum-1}.
        int[] predIndices = IntStream.range(0, predNum).toArray();
        // Position i may hold at most (maxPredNum - predNum + i).
        int[] bounds = IntStream.range(maxPredNum - predNum, maxPredNum).toArray();
        for (int frameCount = 0; frameCount < numModels; ++frameCount) {
            trainFrames[frameCount] = generateOneFrame(predIndices, parms, predNames, foldColumn);
            DKV.put(trainFrames[frameCount]);
            updatePredIndices(predIndices, bounds);
        }
        return trainFrames;
    }

    /**
     * Advances {@code currentPredIndices} in place to the next combination in lexicographic order.
     *
     * <p>The right-most position still below its bound is incremented, and every later position is
     * reset to consecutive values after it. If every position is already at its bound, the array is
     * left unchanged.
     *
     * @param currentPredIndices current combination; modified in place
     * @param indicesBounds      per-position inclusive upper bound
     */
    public static void updatePredIndices(int[] currentPredIndices, int[] indicesBounds) {
        int lastPredInd = currentPredIndices.length - 1;
        for (int index = lastPredInd; index >= 0; --index) {
            if (currentPredIndices[index] < indicesBounds[index]) {
                currentPredIndices[index]++;
                updateLaterIndices(currentPredIndices, index, lastPredInd);
                return;
            }
        }
    }

    /**
     * Sets each position after {@code indexUpdated}, up to and including {@code lastPredInd}, to one
     * more than the position before it.
     *
     * @param currentPredIndices combination being rewritten in place
     * @param indexUpdated       position that was just incremented
     * @param lastPredInd        last position of the combination
     */
    public static void updateLaterIndices(int[] currentPredIndices, int indexUpdated, int lastPredInd) {
        for (int index = indexUpdated; index < lastPredInd; ++index) {
            currentPredIndices[index + 1] = currentPredIndices[index] + 1;
        }
    }

    /**
     * Creates a new frame (fresh key, not yet in the DKV) that shares vecs with the training frame.
     *
     * <p>The frame contains the selected predictors plus, when configured, the weights, offset and
     * fold columns, with the response column last.
     *
     * @param predIndices indices into {@code predNames} of the predictors to include
     * @param parms       model-selection parameters supplying the training frame and special columns
     * @param predNames   predictor column names
     * @param foldColumn  optional fold column, may be {@code null}
     * @return the assembled frame
     */
    public static Frame generateOneFrame(int[] predIndices, ModelSelectionModel.ModelSelectionParameters parms, String[] predNames, String foldColumn) {
        Frame predVecs = new Frame(Key.make());
        Frame train = parms.train();
        for (int predVecNum : predIndices) {
            predVecs.add(predNames[predVecNum], train.vec(predNames[predVecNum]));
        }
        if (parms._weights_column != null) {
            predVecs.add(parms._weights_column, train.vec(parms._weights_column));
        }
        if (parms._offset_column != null) {
            predVecs.add(parms._offset_column, train.vec(parms._offset_column));
        }
        if (foldColumn != null) {
            predVecs.add(foldColumn, train.vec(foldColumn));
        }
        predVecs.add(parms._response_column, train.vec(parms._response_column));
        return predVecs;
    }

    /**
     * Returns a new {@link BitSet}, sized for {@code totalPredSize} bits, with every index in
     * {@code currIndices} set.
     */
    public static BitSet setBitSet(int[] currIndices, int totalPredSize) {
        BitSet predSet = new BitSet(totalPredSize);
        setBitSet(predSet, currIndices);
        return predSet;
    }

    /** Sets every index in {@code currIndices} on {@code predBitSet}. */
    public static void setBitSet(BitSet predBitSet, int[] currIndices) {
        for (int predIndex : currIndices) {
            predBitSet.set(predIndex);
        }
    }

    /**
     * Builds one training frame per candidate predictor in {@code validSubsets}. Each candidate is
     * inserted at {@code newPredPos} into {@code currSubsetIndices}, and every created frame is put
     * into the DKV.
     *
     * <p>When {@code usedCombo} is non-null and the resulting subset has more than one predictor, a
     * subset whose predictor set is already in {@code usedCombo} is skipped. Otherwise the subset is
     * recorded in {@code usedCombo} and a frame is built. Subsets produced earlier in the same call
     * are also skipped.
     *
     * @param parms             model-selection parameters
     * @param predictorNames    all candidate predictor names
     * @param foldColumn        optional fold column, may be {@code null}
     * @param currSubsetIndices current predictor subset (not modified)
     * @param newPredPos        position at which each candidate predictor is inserted
     * @param validSubsets      candidate predictor indices to try at {@code newPredPos}
     * @param usedCombo         predictor sets already generated; updated in place, may be {@code null}
     * @return the newly generated frames
     */
    public static Frame[] generateMaxRTrainingFrames(ModelSelectionModel.ModelSelectionParameters parms, String[] predictorNames, String foldColumn, List<Integer> currSubsetIndices, int newPredPos, List<Integer> validSubsets, Set<BitSet> usedCombo) {
        List<Frame> trainFramesList = new ArrayList<>();
        List<Integer> changedSubset = new ArrayList<>(currSubsetIndices);
        // Placeholder slot; it is overwritten with each candidate before being read.
        changedSubset.add(newPredPos, -1);
        int[] predIndices = changedSubset.stream().mapToInt(Integer::intValue).toArray();
        BitSet tempIndices = new BitSet(predictorNames.length);
        boolean checkDuplicates = usedCombo != null && predIndices.length > 1;
        for (int predIndex : validSubsets) {
            predIndices[newPredPos] = predIndex;
            if (checkDuplicates) {
                tempIndices.clear();
                setBitSet(tempIndices, predIndices);
                // Set.add returns false when this predictor set was already generated.
                if (!usedCombo.add((BitSet) tempIndices.clone())) {
                    continue;
                }
            }
            Frame trainFrame = generateOneFrame(predIndices, parms, predictorNames, foldColumn);
            DKV.put(trainFrame);
            trainFramesList.add(trainFrame);
        }
        return trainFramesList.toArray(new Frame[0]);
    }

    /**
     * Returns a shallow copy of the last {@code numModels} rows of {@code array}, keeping their order.
     * Each non-null row is cloned; {@code null} rows stay {@code null}.
     */
    public static String[][] shrinkStringArray(String[][] array, int numModels) {
        int start = array.length - numModels;
        String[][] newArray = new String[numModels][];
        for (int index = 0; index < numModels; ++index) {
            String[] row = array[start + index];
            newArray[index] = row == null ? null : row.clone();
        }
        return newArray;
    }

    /**
     * Returns a copy of the last {@code numModels} rows of {@code array}, keeping their order.
     * Each non-null row is cloned; {@code null} rows stay {@code null}.
     */
    public static double[][] shrinkDoubleArray(double[][] array, int numModels) {
        int start = array.length - numModels;
        double[][] newArray = new double[numModels][];
        for (int index = 0; index < numModels; ++index) {
            double[] row = array[start + index];
            newArray[index] = row == null ? null : row.clone();
        }
        return newArray;
    }

    /** Returns a new array holding the last {@code numModels} keys of {@code array}, in order. */
    public static Key[] shrinkKeyArray(Key[] array, int numModels) {
        Key[] newArray = new Key[numModels];
        System.arraycopy(array, array.length - numModels, newArray, 0, numModels);
        return newArray;
    }

    /** Formats {@code val} as its {@link Double#toString} values separated by {@code ", "}. */
    public static String joinDouble(double[] val) {
        return Arrays.stream(val).mapToObj(Double::toString).collect(Collectors.joining(", "));
    }

    /**
     * Finds the model whose R2 strictly exceeds {@code lastBestR2} and every other candidate.
     *
     * <p>Every non-null candidate that is not the winner is deleted. The winner is kept. {@code null}
     * entries are ignored.
     *
     * @param lastBestR2   R2 threshold a candidate must beat
     * @param bestR2Models candidate models; entries may be {@code null}
     * @return index of the winning model, or -1 if no candidate beats {@code lastBestR2} (in which
     *         case all non-null candidates have been deleted)
     */
    public static int findBestR2Model(double lastBestR2, GLMModel[] bestR2Models) {
        // -1 means "no leader yet"; a previous leader is deleted only once it has been beaten.
        int bestIndex = -1;
        double currBestR2 = lastBestR2;
        for (int index = 0; index < bestR2Models.length; ++index) {
            GLMModel candidate = bestR2Models[index];
            if (candidate == null) {
                continue;
            }
            double candidateR2 = candidate.r2();
            if (candidateR2 > currBestR2) {
                if (bestIndex >= 0) {
                    bestR2Models[bestIndex].delete();
                }
                bestIndex = index;
                currBestR2 = candidateR2;
            } else {
                candidate.delete();
            }
        }
        return bestIndex;
    }

    /**
     * Creates one {@link GLMModel.GLMParameters} per training frame.
     *
     * <p>Matching fields are copied reflectively from {@code parms}. Then the training frame and the
     * cross-validation settings are set explicitly.
     *
     * @param trainingFrames frames to train on, one per model
     * @param parms          model-selection parameters to copy from
     * @param nfolds         number of CV folds
     * @param foldColumn     fold column name, may be {@code null}
     * @param foldAssignment fold assignment scheme
     * @return one parameter object per training frame
     */
    public static GLMModel.GLMParameters[] generateGLMParameters(Frame[] trainingFrames, ModelSelectionModel.ModelSelectionParameters parms, int nfolds, String foldColumn, Model.Parameters.FoldAssignmentScheme foldAssignment) {
        int numModels = trainingFrames.length;
        GLMModel.GLMParameters[] params = new GLMModel.GLMParameters[numModels];
        // Reflection lookups are loop-invariant; fetch them once.
        Field[] field1 = ModelSelectionModel.ModelSelectionParameters.class.getDeclaredFields();
        Field[] field2 = Model.Parameters.class.getDeclaredFields();
        for (int index = 0; index < numModels; ++index) {
            params[index] = new GLMModel.GLMParameters();
            setParamField(parms, params[index], false, field1, Collections.emptyList());
            setParamField(parms, params[index], true, field2, Collections.emptyList());
            params[index]._train = trainingFrames[index]._key;
            params[index]._nfolds = nfolds;
            params[index]._fold_column = foldColumn;
            params[index]._fold_assignment = foldAssignment;
        }
        return params;
    }

    /**
     * Best-effort reflective copy of field values from {@code params} into {@code glmParam}.
     *
     * <p>Each instance field in {@code paramFields} is matched by name against the fields declared
     * by {@code glmParam}'s class. When {@code superClassParams} is true, the match is made against
     * that class's direct superclass instead.
     *
     * <p>Static fields are never copied, so class-level state is not overwritten. Fields that have
     * no counterpart, or whose counterpart cannot be set reflectively (for example final or
     * non-public fields), are silently left at their existing values.
     *
     * @param params           source parameters
     * @param glmParam         destination GLM parameters
     * @param superClassParams whether to look up fields on the superclass of {@code glmParam}'s class
     * @param paramFields      fields of {@code params} to copy
     * @param excludeList      field names to skip; {@code null} is treated as empty
     */
    public static void setParamField(Model.Parameters params, GLMModel.GLMParameters glmParam, boolean superClassParams, Field[] paramFields, List<String> excludeList) {
        Class<?> targetClass = superClassParams ? glmParam.getClass().getSuperclass() : glmParam.getClass();
        for (Field oneField : paramFields) {
            String fieldName = oneField.getName();
            if (Modifier.isStatic(oneField.getModifiers())) {
                continue;
            }
            if (excludeList != null && excludeList.contains(fieldName)) {
                continue;
            }
            try {
                Field glmField = targetClass.getDeclaredField(fieldName);
                if (Modifier.isStatic(glmField.getModifiers())) {
                    continue;
                }
                glmField.set(glmParam, oneField.get(params));
            } catch (IllegalAccessException | NoSuchFieldException ignored) {
                // Intentional: parameters GLM does not have, or cannot be set reflectively, keep their
                // GLM defaults.
            }
        }
    }

    /** Creates one (not yet started) {@link GLM} builder per parameter object. */
    public static GLM[] buildGLMBuilders(GLMModel.GLMParameters[] trainingParams) {
        int numModels = trainingParams.length;
        GLM[] builders = new GLM[numModels];
        for (int index = 0; index < numModels; ++index) {
            builders[index] = new GLM(trainingParams[index]);
        }
        return builders;
    }

    /**
     * Removes each frame's key from the DKV; {@code null} entries are skipped.
     *
     * <p>Only the key is removed (via {@code DKV.remove}), not the frame's vecs. This matters because
     * frames built by {@link #generateOneFrame} share their vecs with the original training frame.
     */
    public static void removeTrainingFrames(Frame[] trainingFrames) {
        for (Frame oneFrame : trainingFrames) {
            if (oneFrame != null) {
                DKV.remove(oneFrame._key);
            }
        }
    }

    /**
     * Collects the model from every builder and keeps the one with the highest R2, deleting the
     * others.
     *
     * <p>When a model was trained with {@code _nfolds > 0}, its cross-validated R2 is used instead
     * (see {@link #modelR2}). Any finite R2, including a negative one, can win. {@code null} is
     * returned only when there are no builders or every model's R2 is NaN.
     *
     * @param glmResults builders whose models are compared
     * @return the best model, or {@code null} as described above
     */
    public static GLMModel findBestModel(GLM[] glmResults) {
        double bestR2Val = Double.NEGATIVE_INFINITY;
        GLMModel bestModel = null;
        for (GLM oneBuilder : glmResults) {
            GLMModel oneModel = (GLMModel) oneBuilder.get();
            double currR2 = modelR2(oneModel);
            if (currR2 > bestR2Val) {
                bestR2Val = currR2;
                if (bestModel != null) {
                    bestModel.delete();
                }
                bestModel = oneModel;
            } else {
                oneModel.delete();
            }
        }
        return bestModel;
    }

    /**
     * Returns the R2 used to rank {@code model}.
     *
     * <p>When {@code _nfolds > 0} and the cross-validation metrics summary has an {@code "r2"} row,
     * column 0 of that row is used. Otherwise the training R2 from {@code model.r2()} is used.
     */
    private static double modelR2(GLMModel model) {
        double r2 = model.r2();
        if (((GLMModel.GLMParameters) model._parms)._nfolds > 0) {
            GLMModel.GLMOutput output = (GLMModel.GLMOutput) model._output;
            int r2Index = Arrays.asList(output._cross_validation_metrics_summary.getRowHeaders()).indexOf("r2");
            if (r2Index >= 0) {
                Float cvR2 = (Float) output._cross_validation_metrics_summary.get(r2Index, 0);
                if (cvR2 != null) {
                    r2 = cvR2.doubleValue();
                }
            }
        }
        return r2;
    }

    /**
     * Returns the predictor column names: the adapted frame's column names minus the names returned
     * by {@code parms.getNonPredictors()} and minus {@code foldColumn} (when non-null). Only the
     * first occurrence of each excluded name is removed.
     */
    public static String[] extractPredictorNames(ModelSelectionModel.ModelSelectionParameters parms, DataInfo dinfo, String foldColumn) {
        List<String> frameNames = new ArrayList<>(Arrays.asList(dinfo._adaptedFrame.names()));
        for (String col : parms.getNonPredictors()) {
            frameNames.remove(col);
        }
        if (foldColumn != null) {
            frameNames.remove(foldColumn);
        }
        return frameNames.toArray(new String[0]);
    }

    /**
     * Returns the entries of the model output's {@code _names} that also appear in
     * {@code coefNames}, in {@code _names} order.
     */
    public static List<String> extraModelColumnNames(List<String> coefNames, GLMModel bestModel) {
        List<String> coefUsed = new ArrayList<>();
        for (String columnName : ((GLMModel.GLMOutput) bestModel._output)._names) {
            if (coefNames.contains(columnName)) {
                coefUsed.add(columnName);
            }
        }
        return coefUsed;
    }

    /**
     * Updates {@code validSubset} in place after the chosen subset changes from
     * {@code originalSubset} to {@code currSubsetIndices}.
     *
     * <p>Indices that were dropped (present only in {@code originalSubset}) become valid again.
     * Indices that were added (present only in {@code currSubsetIndices}) are removed.
     */
    public static void updateValidSubset(List<Integer> validSubset, List<Integer> originalSubset, List<Integer> currSubsetIndices) {
        List<Integer> onlyInOriginal = new ArrayList<>(originalSubset);
        onlyInOriginal.removeAll(currSubsetIndices);
        List<Integer> onlyInCurr = new ArrayList<>(currSubsetIndices);
        onlyInCurr.removeAll(originalSubset);
        validSubset.addAll(onlyInOriginal);
        validSubset.removeAll(onlyInCurr);
    }
}

