package org.kie.kogito.explainability.global.pdp;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.global.GlobalExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureDistribution;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PartialDependenceGraph;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.PredictionProviderMetadata;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/kie/kogito/explainability/global/pdp/PartialDependencePlotExplainer.class */
/**
 * Global explainer that builds partial dependence graphs.
 * <p>
 * For every feature of a {@link DataDistribution} and every model output index, the explainer
 * substitutes each of a set of sampled values for that feature into one shared sample of inputs,
 * queries the model, and aggregates the resulting outputs per substituted value:
 * <ul>
 *   <li>{@link Type#NUMBER} outputs are averaged over the predictions obtained;</li>
 *   <li>any other output type is collapsed to its most frequent value (as a string).</li>
 * </ul>
 * One {@link PartialDependenceGraph} is produced per (feature, output index) pair.
 */
public class PartialDependencePlotExplainer implements GlobalExplainer<List<PartialDependenceGraph>> {

    private static final Logger LOGGER = LoggerFactory.getLogger(PartialDependencePlotExplainer.class);

    private final PartialDependencePlotConfig config;

    /**
     * Creates an explainer with the given configuration.
     *
     * @param partialDependencePlotConfig configuration; its series length is used as the requested size
     *        of both the shared input sample and the per-feature value sample
     */
    public PartialDependencePlotExplainer(PartialDependencePlotConfig partialDependencePlotConfig) {
        this.config = partialDependencePlotConfig;
    }

    /** Creates an explainer with a default {@link PartialDependencePlotConfig}. */
    public PartialDependencePlotExplainer() {
        this(new PartialDependencePlotConfig());
    }

    /**
     * Explains the model using the data distribution and output shape declared in its metadata.
     *
     * @param predictionProvider the model to query
     * @param predictionProviderMetadata metadata supplying the data distribution and the number of outputs
     * @return one graph per (feature, output index) pair
     */
    @Override
    public List<PartialDependenceGraph> explainFromMetadata(PredictionProvider predictionProvider,
            PredictionProviderMetadata predictionProviderMetadata)
            throws InterruptedException, ExecutionException, TimeoutException {
        int outputCount = predictionProviderMetadata.getOutputShape().getOutputs().size();
        return explainFromDataDistribution(predictionProvider, outputCount,
                predictionProviderMetadata.getDataDistribution());
    }

    /**
     * Explains the model using the inputs of already-computed predictions as the data distribution.
     *
     * @param predictionProvider the model to query
     * @param collection predictions whose inputs form the data distribution; the number of outputs is
     *        read from an arbitrary element (0 when the collection is empty, yielding no graphs)
     * @return one graph per (feature, output index) pair
     */
    @Override
    public List<PartialDependenceGraph> explainFromPredictions(PredictionProvider predictionProvider,
            Collection<Prediction> collection) throws InterruptedException, ExecutionException, TimeoutException {
        int outputCount = collection.stream()
                .findAny()
                .map(prediction -> prediction.getOutput().getOutputs().size())
                .orElse(0);
        List<PredictionInput> inputs = collection.stream()
                .map(Prediction::getInput)
                .collect(Collectors.toList());
        return explainFromDataDistribution(predictionProvider, outputCount,
                new PredictionInputsDataDistribution(inputs));
    }

    /** Builds one graph per (feature, output index) pair over the given distribution. */
    private List<PartialDependenceGraph> explainFromDataDistribution(PredictionProvider predictionProvider,
            int outputCount, DataDistribution dataDistribution)
            throws InterruptedException, ExecutionException, TimeoutException {
        long start = System.currentTimeMillis();
        List<PartialDependenceGraph> graphs = new ArrayList<>();
        List<FeatureDistribution> featureDistributions = dataDistribution.asFeatureDistributions();
        // A single background sample is shared by every feature and output.
        List<PredictionInput> sample = dataDistribution.sample(config.getSeriesLength());
        for (FeatureDistribution featureDistribution : featureDistributions) {
            List<Value<?>> xs = sortedDistinctValues(featureDistribution, config.getSeriesLength());
            List<Feature> featureXs = new ArrayList<>(xs.size());
            for (Value<?> x : xs) {
                featureXs.add(FeatureFactory.copyOf(featureDistribution.getFeature(), x));
            }
            for (int outputIndex = 0; outputIndex < outputCount; outputIndex++) {
                graphs.add(getPartialDependenceGraph(predictionProvider, sample, xs, featureXs, outputIndex));
            }
        }
        LOGGER.debug("explanation time: {}ms", System.currentTimeMillis() - start);
        return graphs;
    }

    /**
     * Samples values of a feature, orders them by numeric value (ties broken by string form) and
     * removes duplicates.
     */
    private static List<Value<?>> sortedDistinctValues(FeatureDistribution featureDistribution, int sampleSize) {
        Comparator<Value<?>> byNumber = Comparator.comparingDouble(v -> v.asNumber());
        Comparator<Value<?>> byNumberThenString = byNumber.thenComparing(v -> v.asString());
        return featureDistribution.sample(sampleSize).stream()
                .sorted(byNumberThenString)
                .distinct()
                .collect(Collectors.toList());
    }

    /**
     * Computes the graph of one output index for one feature.
     *
     * @throws IllegalStateException if no output at all was produced (e.g. no feature values or an empty sample)
     */
    private PartialDependenceGraph getPartialDependenceGraph(PredictionProvider predictionProvider,
            List<PredictionInput> sample, List<Value<?>> xs, List<Feature> featureXs, int outputIndex)
            throws InterruptedException, ExecutionException, TimeoutException {
        Output outputTemplate = null;
        List<Map<Value<?>, Long>> valueCounts = new ArrayList<>(featureXs.size());
        for (Feature featureX : featureXs) {
            // One map per feature value, created up-front so indices always line up with xs.
            Map<Value<?>, Long> counts = new HashMap<>();
            for (PredictionOutput predictionOutput : getOutputs(predictionProvider, prepareInputs(featureX, sample))) {
                Output output = predictionOutput.getOutputs().get(outputIndex);
                if (outputTemplate == null) {
                    outputTemplate = new Output(output.getName(), output.getType());
                }
                counts.merge(output.getValue(), 1L, Long::sum);
            }
            valueCounts.add(counts);
        }
        if (outputTemplate == null) {
            throw new IllegalStateException("cannot produce PDP for null decision");
        }
        // outputTemplate != null implies featureXs is non-empty.
        Feature feature = FeatureFactory.copyOf(featureXs.get(0), new Value<>(null));
        return new PartialDependenceGraph(feature, outputTemplate, xs,
                collapseMarginalImpacts(valueCounts, outputTemplate.getType()));
    }

    /** Collapses each per-value output histogram into a single y value. */
    private List<Value<?>> collapseMarginalImpacts(List<Map<Value<?>, Long>> valueCounts, Type type) {
        List<Value<?>> ys = new ArrayList<>(valueCounts.size());
        boolean numeric = Type.NUMBER.equals(type);
        for (Map<Value<?>, Long> counts : valueCounts) {
            ys.add(numeric ? meanOf(counts) : modeOf(counts));
        }
        return ys;
    }

    /**
     * Weighted mean of the numeric values in a histogram. The mean is taken over the number of
     * predictions actually observed, not the configured series length. Returns 0 for an empty histogram.
     */
    private static Value<?> meanOf(Map<Value<?>, Long> counts) {
        double weightedSum = 0d;
        long total = 0L;
        for (Map.Entry<Value<?>, Long> entry : counts.entrySet()) {
            weightedSum += entry.getKey().asNumber() * entry.getValue();
            total += entry.getValue();
        }
        return new Value<>(total == 0L ? 0d : weightedSum / total);
    }

    /** String form of the most frequent value in a histogram; a null-valued Value when empty. */
    private static Value<?> modeOf(Map<Value<?>, Long> counts) {
        long bestCount = 0L;
        String best = null;
        for (Map.Entry<Value<?>, Long> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                best = entry.getKey().asString();
            }
        }
        return new Value<>(best);
    }

    /** Runs the model asynchronously, bounded by the globally configured timeout. */
    private List<PredictionOutput> getOutputs(PredictionProvider predictionProvider, List<PredictionInput> inputs)
            throws InterruptedException, ExecutionException, TimeoutException {
        try {
            return predictionProvider.predictAsync(inputs)
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            // Trailing throwable argument makes SLF4J log the stack trace as well.
            LOGGER.error("Impossible to obtain prediction {}", e.getMessage(), e);
            throw e;
        }
    }

    /** Copies every input, replacing the feature with the same name as {@code feature} by it. */
    private List<PredictionInput> prepareInputs(Feature feature, List<PredictionInput> inputs) {
        List<PredictionInput> prepared = new ArrayList<>(inputs.size());
        for (PredictionInput input : inputs) {
            prepared.add(new PredictionInput(replaceFeatures(feature, input.getFeatures())));
        }
        return prepared;
    }

    /**
     * Returns copies of {@code features} in which any feature named like {@code feature} takes its value,
     * recursing into composite features.
     */
    private List<Feature> replaceFeatures(Feature feature, List<Feature> features) {
        List<Feature> replaced = new ArrayList<>(features.size());
        for (Feature candidate : features) {
            if (candidate.getName().equals(feature.getName())) {
                replaced.add(FeatureFactory.copyOf(candidate, feature.getValue()));
            } else if (Type.COMPOSITE == candidate.getType()) {
                // Composite features carry their children as the underlying List<Feature>.
                @SuppressWarnings("unchecked")
                List<Feature> children = (List<Feature>) candidate.getValue().getUnderlyingObject();
                replaced.add(FeatureFactory.newCompositeFeature(candidate.getName(), replaceFeatures(feature, children)));
            } else {
                replaced.add(FeatureFactory.copyOf(candidate, candidate.getValue()));
            }
        }
        return replaced;
    }
}
