package org.kie.kogito.explainability.global.pdp;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.kie.kogito.explainability.global.GlobalExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.FeatureDistribution;
import org.kie.kogito.explainability.model.PartialDependenceGraph;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.PredictionProviderMetadata;
import org.kie.kogito.explainability.utils.DataUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/kogito/explainability/global/pdp/PartialDependencePlotExplainer.class */
/**
 * Global explainer that builds one {@link PartialDependenceGraph} per (input feature, model output) pair.
 *
 * <p>For each input feature, {@code seriesLength} x-values are generated between the feature's min and max
 * (via {@link DataUtils#generateSamples}). For each x-value, {@code seriesLength} synthetic inputs are built.
 * In each synthetic input the feature under study is fixed to that x-value. Every other feature is taken from
 * a background sample drawn with {@link DataUtils#generateData} from that feature's mean and standard
 * deviation. The model scores of those inputs are averaged per output.
 */
public class PartialDependencePlotExplainer implements GlobalExplainer<Collection<PartialDependenceGraph>> {
    private static final Logger LOGGER = LoggerFactory.getLogger(PartialDependencePlotExplainer.class);
    private static final int DEFAULT_SERIES_LENGTH = 100;

    /** Number of x-values per graph, and number of background samples averaged for each x-value. */
    private final int seriesLength;

    /**
     * Creates an explainer with a custom series length.
     *
     * @param seriesLength number of x-values per graph and number of background samples per x-value
     * @throws IllegalArgumentException if {@code seriesLength} is not positive
     */
    public PartialDependencePlotExplainer(int seriesLength) {
        if (seriesLength < 1) {
            throw new IllegalArgumentException("seriesLength must be positive, got " + seriesLength);
        }
        this.seriesLength = seriesLength;
    }

    /** Creates an explainer using {@link #DEFAULT_SERIES_LENGTH} (100). */
    public PartialDependencePlotExplainer() {
        this(DEFAULT_SERIES_LENGTH);
    }

    /**
     * Computes the partial dependence graphs for every input feature and every model output.
     *
     * <p>Graphs are returned feature-major: all outputs of feature 0, then all outputs of feature 1, and so on.
     * The feature distributions are indexed with the same index as the input-shape features, so both lists are
     * assumed to be aligned and at least as long as the input shape.
     *
     * @param predictionProvider         the model to query
     * @param predictionProviderMetadata input/output shapes and the data distribution of the inputs
     * @return one graph per (feature, output) pair
     */
    @Override
    public Collection<PartialDependenceGraph> explain(PredictionProvider predictionProvider,
            PredictionProviderMetadata predictionProviderMetadata) {
        long start = System.currentTimeMillis();
        List<PartialDependenceGraph> graphs = new ArrayList<>();
        DataDistribution dataDistribution = predictionProviderMetadata.getDataDistribution();
        List<FeatureDistribution> featureDistributions = dataDistribution.getFeatureDistributions();
        int featureCount = predictionProviderMetadata.getInputShape().getFeatures().size();
        int outputCount = predictionProviderMetadata.getOutputShape().getOutputs().size();

        for (int featureIndex = 0; featureIndex < featureCount; featureIndex++) {
            FeatureDistribution distribution = featureDistributions.get(featureIndex);
            double[] featureXs = DataUtils.generateSamples(distribution.getMin(), distribution.getMax(), seriesLength);
            // Query the model once per x-value; each prediction already carries the scores of every output.
            double[][] outputYs = averageScores(predictionProvider, featureDistributions, featureCount, outputCount,
                    featureIndex, featureXs);
            for (int outputIndex = 0; outputIndex < outputCount; outputIndex++) {
                // Give every graph its own x-values array, so mutating one graph cannot affect another.
                graphs.add(new PartialDependenceGraph(
                        predictionProviderMetadata.getInputShape().getFeatures().get(featureIndex),
                        featureXs.clone(), outputYs[outputIndex]));
            }
        }
        LOGGER.debug("explanation time: {}ms", System.currentTimeMillis() - start);
        return graphs;
    }

    /**
     * Averages model scores over a freshly drawn background sample, for each x-value of one feature.
     *
     * @return {@code result[output][x]} = mean score of {@code output} when feature {@code featureIndex}
     *         equals {@code featureXs[x]}
     */
    private double[][] averageScores(PredictionProvider predictionProvider,
            List<FeatureDistribution> featureDistributions, int featureCount, int outputCount, int featureIndex,
            double[] featureXs) {
        // Background values for every feature; the column for featureIndex is ignored below.
        double[][] background = new double[featureCount][];
        for (int f = 0; f < featureCount; f++) {
            FeatureDistribution d = featureDistributions.get(f);
            background[f] = DataUtils.generateData(d.getMean(), d.getStdDev(), seriesLength);
        }

        double[][] averages = new double[outputCount][featureXs.length];
        for (int x = 0; x < featureXs.length; x++) {
            List<PredictionInput> inputs = new ArrayList<>(seriesLength);
            for (int s = 0; s < seriesLength; s++) {
                double[] row = new double[featureCount];
                for (int f = 0; f < featureCount; f++) {
                    row[f] = f == featureIndex ? featureXs[x] : background[f][s];
                }
                inputs.add(new PredictionInput(DataUtils.doublesToFeatures(row)));
            }
            for (PredictionOutput prediction : predictionProvider.predict(inputs)) {
                for (int o = 0; o < outputCount; o++) {
                    averages[o][x] += prediction.getOutputs().get(o).getScore() / seriesLength;
                }
            }
        }
        return averages;
    }
}
