package org.kie.kogito.explainability.handlers;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.kie.kogito.explainability.ConversionUtils;
import org.kie.kogito.explainability.PredictionProviderFactory;
import org.kie.kogito.explainability.api.BaseExplainabilityRequest;
import org.kie.kogito.explainability.api.BaseExplainabilityResult;
import org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest;
import org.kie.kogito.explainability.api.CounterfactualExplainabilityResult;
import org.kie.kogito.explainability.api.CounterfactualSearchDomain;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue;
import org.kie.kogito.explainability.api.HasNameValue;
import org.kie.kogito.explainability.api.NamedTypedValue;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualResult;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.tracing.typedvalue.CollectionValue;
import org.kie.kogito.tracing.typedvalue.StructureValue;
import org.kie.kogito.tracing.typedvalue.TypedValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link LocalExplainerServiceHandler} producing counterfactual explanations.
 * <p>
 * A {@link CounterfactualExplainabilityRequest} is converted into a {@link CounterfactualPrediction},
 * explained by the injected {@link CounterfactualExplainer}, and every {@link CounterfactualResult}
 * is mapped to a {@link CounterfactualExplainabilityResult} (intermediate, final or failed).
 */
@ApplicationScoped
public class CounterfactualExplainerServiceHandler implements LocalExplainerServiceHandler<CounterfactualResult, CounterfactualExplainabilityRequest> {

    private static final Logger LOGGER = LoggerFactory.getLogger(CounterfactualExplainerServiceHandler.class);

    // Upper bound (seconds) for a request's maximum running time; derived from the messaging
    // "unprocessed-record-max-age" property, which is expressed in milliseconds.
    private final Long kafkaMaxRecordAgeSeconds;
    private final CounterfactualExplainer explainer;
    private final PredictionProviderFactory predictionProviderFactory;

    /**
     * No-args constructor; presumably present so the container can create a client proxy for this
     * {@code @ApplicationScoped} bean. The proxy delegates to a fully-injected instance, so the
     * fields are never used here.
     */
    public CounterfactualExplainerServiceHandler() {
        // final fields must be definitely assigned on every constructor path
        this.kafkaMaxRecordAgeSeconds = null;
        this.explainer = null;
        this.predictionProviderFactory = null;
    }

    /**
     * Creates the handler.
     *
     * @param counterfactualExplainer the explainer performing the counterfactual search
     * @param predictionProviderFactory factory for the remote model prediction provider
     * @param kafkaMaxRecordAgeMillis maximum age (ms) of an unprocessed messaging record; converted
     *        to whole seconds (floor) and used to cap each request's maximum running time
     */
    @Inject
    public CounterfactualExplainerServiceHandler(CounterfactualExplainer counterfactualExplainer,
            PredictionProviderFactory predictionProviderFactory,
            @ConfigProperty(name = "mp.messaging.incoming.trusty-explainability-request.throttled.unprocessed-record-max-age.ms", defaultValue = "60000") Long kafkaMaxRecordAgeMillis) {
        this.explainer = counterfactualExplainer;
        this.predictionProviderFactory = predictionProviderFactory;
        this.kafkaMaxRecordAgeSeconds = Math.floorDiv(kafkaMaxRecordAgeMillis.longValue(), 1000L);
    }

    /**
     * @return {@code true} when {@code cls} is {@link CounterfactualExplainabilityRequest} or a subtype
     */
    @Override
    public <T extends BaseExplainabilityRequest> boolean supports(Class<T> cls) {
        return CounterfactualExplainabilityRequest.class.isAssignableFrom(cls);
    }

    /**
     * Creates a prediction provider for the request's service URL, model identifier and goals.
     */
    @Override
    public PredictionProvider getPredictionProvider(CounterfactualExplainabilityRequest request) {
        return predictionProviderFactory.createPredictionProvider(request.getServiceUrl(),
                request.getModelIdentifier(),
                request.getGoals());
    }

    /**
     * Builds the {@link CounterfactualPrediction} to be explained.
     * <p>
     * The maximum running time is capped to the messaging record max-age, so a search cannot
     * outlive the message that triggered it.
     *
     * @throws IllegalArgumentException if inputs, goals or search domains contain structured or
     *         collection values (only flat models are supported), or if the execution id is not a UUID
     */
    @Override
    public Prediction getPrediction(CounterfactualExplainabilityRequest request) {
        List<NamedTypedValue> goals = toMapBasedSorting(request.getGoals());
        Collection<CounterfactualSearchDomain> searchDomains = request.getSearchDomains();
        Collection<NamedTypedValue> originalInputs = request.getOriginalInputs();

        Long maxRunningTimeSeconds = request.getMaxRunningTimeSeconds();
        if (maxRunningTimeSeconds != null && maxRunningTimeSeconds.longValue() > kafkaMaxRecordAgeSeconds.longValue()) {
            LOGGER.info("Maximum Running Timeout set to '{}'s since the provided value '{}'s exceeded the Messaging sub-system configuration '{}'s.",
                    kafkaMaxRecordAgeSeconds, maxRunningTimeSeconds, kafkaMaxRecordAgeSeconds);
            maxRunningTimeSeconds = kafkaMaxRecordAgeSeconds;
        }

        if (isUnsupportedModel(originalInputs, goals, searchDomains)) {
            throw new IllegalArgumentException("Counterfactual explanations only support flat models.");
        }

        PredictionInput input = new PredictionInput(ConversionUtils.toFeatureList(originalInputs, searchDomains));
        PredictionOutput output = new PredictionOutput(ConversionUtils.toOutputList(goals));
        return new CounterfactualPrediction(input,
                output,
                null,
                UUID.fromString(request.getExecutionId()),
                maxRunningTimeSeconds);
    }

    // true when any input, goal or search domain is nested (structure or collection)
    private boolean isUnsupportedModel(Collection<NamedTypedValue> inputs,
            Collection<NamedTypedValue> goals,
            Collection<CounterfactualSearchDomain> searchDomains) {
        return isUnsupportedTypedValue(inputs)
                || isUnsupportedTypedValue(goals)
                || isUnsupportedCounterfactualSearchDomain(searchDomains);
    }

    // true when any value is a StructureValue or CollectionValue
    private boolean isUnsupportedTypedValue(Collection<? extends HasNameValue<?>> values) {
        return values.stream()
                .map(HasNameValue::getValue)
                .anyMatch(value -> value instanceof StructureValue || value instanceof CollectionValue);
    }

    // true when any search domain is a structure or collection domain
    private boolean isUnsupportedCounterfactualSearchDomain(Collection<CounterfactualSearchDomain> searchDomains) {
        return searchDomains.stream()
                .map(CounterfactualSearchDomain::getValue)
                .anyMatch(value -> value instanceof CounterfactualSearchDomainStructureValue
                        || value instanceof CounterfactualSearchDomainCollectionValue);
    }

    /**
     * Re-orders values by passing them through a {@link HashMap} keyed by name, so that the
     * resulting order depends only on the names. For duplicate names the last value wins;
     * a {@code null} collection yields an empty list.
     */
    private List<NamedTypedValue> toMapBasedSorting(Collection<NamedTypedValue> values) {
        // HashMap (not Collectors.toMap) keeps the original semantics: null values allowed,
        // duplicates overwrite instead of throwing.
        Map<String, TypedValue> byName = new HashMap<>();
        if (values != null) {
            for (NamedTypedValue value : values) {
                byName.put(value.getName(), value.getValue());
            }
        }
        return byName.entrySet().stream()
                .map(entry -> new NamedTypedValue(entry.getKey(), entry.getValue()))
                .collect(Collectors.toList());
    }

    /** Maps a result to a {@code FINAL} stage explainability result. */
    @Override
    public BaseExplainabilityResult createSucceededResult(CounterfactualExplainabilityRequest request, CounterfactualResult result) {
        return buildResultFromExplanation(request, result, CounterfactualExplainabilityResult.Stage.FINAL);
    }

    /** Builds a failed result carrying the throwable's message. */
    @Override
    public BaseExplainabilityResult createFailedResult(CounterfactualExplainabilityRequest request, Throwable th) {
        return CounterfactualExplainabilityResult.buildFailed(request.getExecutionId(),
                request.getCounterfactualId(),
                th.getMessage());
    }

    /** Maps a result to an {@code INTERMEDIATE} stage explainability result. */
    @Override
    public BaseExplainabilityResult createIntermediateResult(CounterfactualExplainabilityRequest request, CounterfactualResult result) {
        return buildResultFromExplanation(request, result, CounterfactualExplainabilityResult.Stage.INTERMEDIATE);
    }

    /**
     * Converts a {@link CounterfactualResult} into a succeeded {@link CounterfactualExplainabilityResult}.
     *
     * @throws NullPointerException if the result has no output list
     * @throws IllegalStateException if the result has zero or more than one output set
     */
    private CounterfactualExplainabilityResult buildResultFromExplanation(CounterfactualExplainabilityRequest request,
            CounterfactualResult result,
            CounterfactualExplainabilityResult.Stage stage) {
        // validate outputs before doing any conversion work
        List<PredictionOutput> outputs = result.getOutput();
        if (outputs == null) {
            throw new NullPointerException(String.format("Null Outputs produced for Explanation with ExecutionId '%s' and CounterfactualId '%s'",
                    request.getExecutionId(), request.getCounterfactualId()));
        }
        if (outputs.isEmpty()) {
            throw new IllegalStateException(String.format("No Outputs produced for Explanation with ExecutionId '%s' and CounterfactualId '%s'",
                    request.getExecutionId(), request.getCounterfactualId()));
        }
        if (outputs.size() > 1) {
            throw new IllegalStateException(String.format("Multiple Output sets produced for Explanation with ExecutionId '%s' and CounterfactualId '%s'",
                    request.getExecutionId(), request.getCounterfactualId()));
        }

        return CounterfactualExplainabilityResult.buildSucceeded(request.getExecutionId(),
                request.getCounterfactualId(),
                result.getSolutionId().toString(),
                Long.valueOf(result.getSequenceId()),
                Boolean.valueOf(result.isValid()),
                stage,
                ConversionUtils.fromFeatureList(result.getEntities().stream()
                        .map(entity -> entity.asFeature())
                        .collect(Collectors.toList())),
                ConversionUtils.fromOutputs(outputs.get(0).getOutputs()));
    }

    /** Delegates to the injected {@link CounterfactualExplainer}. */
    @Override
    public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction,
            PredictionProvider predictionProvider,
            Consumer<CounterfactualResult> consumer) {
        return explainer.explainAsync(prediction, predictionProvider, consumer);
    }
}
