package org.kie.kogito.explainability.handlers;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;
import org.kie.kogito.explainability.ConversionUtils;
import org.kie.kogito.explainability.PredictionProviderFactory;
import org.kie.kogito.explainability.api.BaseExplainabilityRequestDto;
import org.kie.kogito.explainability.api.BaseExplainabilityResultDto;
import org.kie.kogito.explainability.api.CounterfactualExplainabilityRequestDto;
import org.kie.kogito.explainability.api.CounterfactualExplainabilityResultDto;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionDto;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainDto;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureDto;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualResult;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionFeatureDomain;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.models.BaseExplainabilityRequest;
import org.kie.kogito.explainability.models.CounterfactualExplainabilityRequest;
import org.kie.kogito.explainability.models.ModelIdentifier;
import org.kie.kogito.tracing.typedvalue.CollectionValue;
import org.kie.kogito.tracing.typedvalue.StructureValue;
import org.kie.kogito.tracing.typedvalue.TypedValue;

/**
 * {@link LocalExplainerServiceHandler} for counterfactual explanations.
 * <p>
 * Converts {@link CounterfactualExplainabilityRequestDto}s into {@link CounterfactualExplainabilityRequest}s,
 * builds the {@link CounterfactualPrediction} handed to the injected {@link CounterfactualExplainer}, and maps
 * the explainer's {@link CounterfactualResult}s (intermediate and final) back to
 * {@link CounterfactualExplainabilityResultDto}s.
 * <p>
 * Only "flat" models are accepted: structured or collection-typed inputs, goals or search domains are
 * rejected by {@link #getPrediction(CounterfactualExplainabilityRequest)}.
 */
@ApplicationScoped
public class CounterfactualExplainerServiceHandler
        implements LocalExplainerServiceHandler<CounterfactualResult, CounterfactualExplainabilityRequest, CounterfactualExplainabilityRequestDto> {

    private final CounterfactualExplainer explainer;
    private final PredictionProviderFactory predictionProviderFactory;

    /**
     * @param counterfactualExplainer explainer that performs the counterfactual search
     * @param predictionProviderFactory factory used to build the model-invoking {@link PredictionProvider}
     */
    @Inject
    public CounterfactualExplainerServiceHandler(CounterfactualExplainer counterfactualExplainer,
            PredictionProviderFactory predictionProviderFactory) {
        this.explainer = counterfactualExplainer;
        this.predictionProviderFactory = predictionProviderFactory;
    }

    /** Returns {@code true} when {@code cls} is {@link CounterfactualExplainabilityRequest} or a subtype of it. */
    @Override
    public <T extends BaseExplainabilityRequest> boolean supports(Class<T> cls) {
        return CounterfactualExplainabilityRequest.class.isAssignableFrom(cls);
    }

    /** Returns {@code true} when {@code cls} is {@link CounterfactualExplainabilityRequestDto} or a subtype of it. */
    @Override
    public <T extends BaseExplainabilityRequestDto> boolean supportsDto(Class<T> cls) {
        return CounterfactualExplainabilityRequestDto.class.isAssignableFrom(cls);
    }

    /**
     * Copies the fields of the incoming DTO into a {@link CounterfactualExplainabilityRequest}; the model
     * identifier is converted with {@link ModelIdentifier#from}.
     */
    @Override
    public CounterfactualExplainabilityRequest explainabilityRequestFrom(CounterfactualExplainabilityRequestDto dto) {
        return new CounterfactualExplainabilityRequest(
                dto.getExecutionId(),
                dto.getCounterfactualId(),
                dto.getServiceUrl(),
                ModelIdentifier.from(dto.getModelIdentifier()),
                dto.getOriginalInputs(),
                dto.getGoals(),
                dto.getSearchDomains());
    }

    /**
     * Creates the {@link PredictionProvider} for this request via the injected factory, passing the request's
     * service URL, model identifier and goals.
     */
    @Override
    public PredictionProvider getPredictionProvider(CounterfactualExplainabilityRequest request) {
        return this.predictionProviderFactory.createPredictionProvider(
                request.getServiceUrl(), request.getModelIdentifier(), request.getGoals());
    }

    /**
     * Builds the {@link CounterfactualPrediction} describing the original inputs, the goals (desired outputs)
     * and the search domains of the request.
     *
     * @throws IllegalArgumentException if any input, goal or search domain is structured or a collection
     *         ("Counterfactual explanations only support flat models."), or if the execution id is not a
     *         valid UUID string ({@link UUID#fromString})
     */
    @Override
    public Prediction getPrediction(CounterfactualExplainabilityRequest request) {
        Map<String, TypedValue> originalInputs = request.getOriginalInputs();
        Map<String, TypedValue> goals = request.getGoals();
        Map<String, CounterfactualSearchDomainDto> searchDomains = request.getSearchDomains();

        if (isUnsupportedModel(originalInputs, goals, searchDomains)) {
            throw new IllegalArgumentException("Counterfactual explanations only support flat models.");
        }

        PredictionInput input = new PredictionInput(ConversionUtils.toFeatureList(originalInputs));
        PredictionOutput goalOutput = new PredictionOutput(ConversionUtils.toOutputList(goals));
        PredictionFeatureDomain featureDomain =
                new PredictionFeatureDomain(ConversionUtils.toFeatureDomainList(searchDomains));

        return new CounterfactualPrediction(
                input,
                goalOutput,
                featureDomain,
                ConversionUtils.toFeatureConstraintList(searchDomains),
                (DataDistribution) null, // explicitly typed null: no data distribution is supplied
                UUID.fromString(request.getExecutionId()));
    }

    /** Returns {@code true} if any input, goal or search domain is non-flat (structure or collection). */
    private boolean isUnsupportedModel(Map<String, TypedValue> originalInputs,
            Map<String, TypedValue> goals,
            Map<String, CounterfactualSearchDomainDto> searchDomains) {
        return isUnsupportedTypedValue(originalInputs.values())
                || isUnsupportedTypedValue(goals.values())
                || isUnsupportedCounterfactualSearchDomain(searchDomains.values());
    }

    /** Returns {@code true} if any value is a {@link StructureValue} or a {@link CollectionValue}. */
    private boolean isUnsupportedTypedValue(Collection<? extends TypedValue> values) {
        return values.stream()
                .anyMatch(value -> value instanceof StructureValue || value instanceof CollectionValue);
    }

    /** Returns {@code true} if any search domain is a structure or a collection domain. */
    private boolean isUnsupportedCounterfactualSearchDomain(Collection<? extends CounterfactualSearchDomainDto> domains) {
        return domains.stream()
                .anyMatch(domain -> domain instanceof CounterfactualSearchDomainStructureDto
                        || domain instanceof CounterfactualSearchDomainCollectionDto);
    }

    /** Maps a final explainer result to a result DTO with stage {@code FINAL}. */
    @Override
    public BaseExplainabilityResultDto createSucceededResultDto(CounterfactualExplainabilityRequest request,
            CounterfactualResult result) {
        return buildResultDtoFromExplanation(request, result, CounterfactualExplainabilityResultDto.Stage.FINAL);
    }

    /**
     * Builds a failed result DTO carrying {@code throwable.getMessage()}. Note that
     * {@link Throwable#getMessage()} may be {@code null}; it is passed through unchanged.
     */
    @Override
    public BaseExplainabilityResultDto createFailedResultDto(CounterfactualExplainabilityRequest request,
            Throwable throwable) {
        return CounterfactualExplainabilityResultDto.buildFailed(
                request.getExecutionId(), request.getCounterfactualId(), throwable.getMessage());
    }

    /** Maps an intermediate explainer result to a result DTO with stage {@code INTERMEDIATE}. */
    @Override
    public BaseExplainabilityResultDto createIntermediateResultDto(CounterfactualExplainabilityRequest request,
            CounterfactualResult result) {
        return buildResultDtoFromExplanation(request, result, CounterfactualExplainabilityResultDto.Stage.INTERMEDIATE);
    }

    /**
     * Converts a {@link CounterfactualResult} into a succeeded result DTO for the given stage.
     *
     * @throws NullPointerException if the result carries {@code null} outputs
     * @throws IllegalStateException if the result carries zero or more than one output set
     */
    private CounterfactualExplainabilityResultDto buildResultDtoFromExplanation(CounterfactualExplainabilityRequest request,
            CounterfactualResult result,
            CounterfactualExplainabilityResultDto.Stage stage) {
        // Validate first so a malformed result fails before any entity conversion work.
        PredictionOutput output = requireSingleOutput(request, result);
        return CounterfactualExplainabilityResultDto.buildSucceeded(
                request.getExecutionId(),
                request.getCounterfactualId(),
                result.getSolutionId().toString(),
                result.getSequenceId(),
                result.isValid(),
                stage,
                ConversionUtils.fromFeatureList(result.getEntities().stream()
                        .map(entity -> entity.asFeature())
                        .collect(Collectors.toList())),
                ConversionUtils.fromOutputs(output.getOutputs()));
    }

    /** Returns the single output set of {@code result}, or throws if there is not exactly one. */
    private static PredictionOutput requireSingleOutput(CounterfactualExplainabilityRequest request,
            CounterfactualResult result) {
        List<PredictionOutput> outputs = result.getOutput();
        if (Objects.isNull(outputs)) {
            throw new NullPointerException(describeFailure("Null Outputs produced", request));
        }
        if (outputs.isEmpty()) {
            throw new IllegalStateException(describeFailure("No Outputs produced", request));
        }
        if (outputs.size() > 1) {
            throw new IllegalStateException(describeFailure("Multiple Output sets produced", request));
        }
        return outputs.get(0);
    }

    /** Formats an error message identifying the explanation by execution id and counterfactual id. */
    private static String describeFailure(String problem, CounterfactualExplainabilityRequest request) {
        return String.format("%s for Explanation with ExecutionId '%s' and CounterfactualId '%s'",
                problem, request.getExecutionId(), request.getCounterfactualId());
    }

    /**
     * Runs the counterfactual search asynchronously by delegating to the injected explainer.
     *
     * @param prediction the prediction to explain
     * @param predictionProvider provider used to invoke the model
     * @param consumer receives intermediate results as they are produced
     * @return a future completing with the final result
     */
    public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction,
            PredictionProvider predictionProvider,
            Consumer<CounterfactualResult> consumer) {
        return this.explainer.explainAsync(prediction, predictionProvider, consumer);
    }
}
