package org.kie.kogito.explainability;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
import java.util.stream.Stream;
import javax.enterprise.inject.Instance;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.ThrowingSupplier;
import org.kie.kogito.explainability.api.BaseExplainabilityResultDto;
import org.kie.kogito.explainability.api.CounterfactualExplainabilityResultDto;
import org.kie.kogito.explainability.api.ExplainabilityStatus;
import org.kie.kogito.explainability.api.FeatureImportanceDto;
import org.kie.kogito.explainability.api.LIMEExplainabilityResultDto;
import org.kie.kogito.explainability.api.SaliencyDto;
import org.kie.kogito.explainability.handlers.CounterfactualExplainerServiceHandler;
import org.kie.kogito.explainability.handlers.LimeExplainerServiceHandler;
import org.kie.kogito.explainability.handlers.LocalExplainerServiceHandler;
import org.kie.kogito.explainability.handlers.LocalExplainerServiceHandlerRegistry;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.models.ModelIdentifier;
import org.kie.kogito.tracing.typedvalue.TypedValue;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

/**
 * Unit tests for {@link ExplanationServiceImpl}.
 *
 * <p>The service under test is wired to a real {@link LocalExplainerServiceHandlerRegistry} whose
 * CDI {@link Instance} is a mock. Each test therefore decides which handlers are discovered by
 * stubbing {@code instance.stream()}. The handlers are Mockito spies wrapping real handler
 * implementations that are backed by mocked explainers. Every prediction provider created through
 * the mocked {@link PredictionProviderFactory} resolves to {@link #predictionProviderMock}.
 */
class ExplanationServiceImplTest {

    /** Maximum running time (seconds) passed to the counterfactual handler under test. */
    private static final Long MAX_RUNNING_TIME_SECONDS = 60L;

    private Instance instance;
    private ExplanationServiceImpl explanationService;
    private LimeExplainer limeExplainerMock;
    private LimeExplainerServiceHandler limeExplainerServiceHandlerSpy;
    private CounterfactualExplainer cfExplainerMock;
    private CounterfactualExplainerServiceHandler cfExplainerServiceHandlerSpy;
    private LocalExplainerServiceHandlerRegistry explainerServiceHandlerRegistry;
    private PredictionProvider predictionProviderMock;
    private Consumer<BaseExplainabilityResultDto> callbackMock;

    @BeforeEach
    @SuppressWarnings("unchecked") // Consumer.class cannot carry its generic parameter
    void init() {
        instance = Mockito.mock(Instance.class);
        limeExplainerMock = Mockito.mock(LimeExplainer.class);
        cfExplainerMock = Mockito.mock(CounterfactualExplainer.class);
        PredictionProviderFactory predictionProviderFactory = Mockito.mock(PredictionProviderFactory.class);
        explainerServiceHandlerRegistry = new LocalExplainerServiceHandlerRegistry(instance);
        // Spies keep the real handler logic while still allowing individual methods to be stubbed.
        limeExplainerServiceHandlerSpy =
                Mockito.spy(new LimeExplainerServiceHandler(limeExplainerMock, predictionProviderFactory));
        cfExplainerServiceHandlerSpy = Mockito.spy(new CounterfactualExplainerServiceHandler(
                cfExplainerMock, predictionProviderFactory, MAX_RUNNING_TIME_SECONDS));
        predictionProviderMock = Mockito.mock(PredictionProvider.class);
        callbackMock = Mockito.mock(Consumer.class);
        explanationService = new ExplanationServiceImpl(explainerServiceHandlerRegistry);
        // Any provider request, whatever its arguments, yields the shared prediction provider mock.
        Mockito.when(predictionProviderFactory.createPredictionProvider(
                (String) ArgumentMatchers.any(),
                (ModelIdentifier) ArgumentMatchers.any(),
                (Map) ArgumentMatchers.any()))
                .thenReturn(predictionProviderMock);
    }

    @Test
    void testLIMEExplainAsyncSucceeded() {
        testLIMEExplainAsyncSuccess(() -> (BaseExplainabilityResultDto) explanationService
                .explainAsync(TestUtils.LIME_REQUEST, callbackMock)
                .toCompletableFuture()
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
    }

    @Test
    void testLIMEExplainAsyncSucceededWithoutCallback() {
        testLIMEExplainAsyncSuccess(() -> (BaseExplainabilityResultDto) explanationService
                .explainAsync(TestUtils.LIME_REQUEST)
                .toCompletableFuture()
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
    }

    /**
     * Registers only the LIME handler, stubs the LIME explainer to complete with
     * {@code TestUtils.SALIENCY_MAP}, runs {@code invocation}, and verifies that the result was
     * mapped into a successful {@link LIMEExplainabilityResultDto}.
     *
     * @param invocation calls the service and returns its (awaited) result
     */
    void testLIMEExplainAsyncSuccess(ThrowingSupplier<BaseExplainabilityResultDto> invocation) {
        Mockito.when(instance.stream()).thenReturn(Stream.of(limeExplainerServiceHandlerSpy));
        Mockito.when(limeExplainerMock.explainAsync(
                ArgumentMatchers.any(Prediction.class),
                ArgumentMatchers.eq(predictionProviderMock),
                ArgumentMatchers.any(Consumer.class)))
                .thenReturn(CompletableFuture.completedFuture(TestUtils.SALIENCY_MAP));

        BaseExplainabilityResultDto result = Assertions.assertDoesNotThrow(invocation);
        Assertions.assertNotNull(result);
        Assertions.assertTrue(result instanceof LIMEExplainabilityResultDto);
        LIMEExplainabilityResultDto limeResult = (LIMEExplainabilityResultDto) result;

        Assertions.assertEquals(TestUtils.EXECUTION_ID, limeResult.getExecutionId());
        Assertions.assertSame(ExplainabilityStatus.SUCCEEDED, limeResult.getStatus());
        Assertions.assertNull(limeResult.getStatusDetails());
        Assertions.assertEquals(TestUtils.SALIENCY_MAP.size(), limeResult.getSaliencies().size());
        Assertions.assertTrue(limeResult.getSaliencies().containsKey("key"));

        SaliencyDto saliencyDto = (SaliencyDto) limeResult.getSaliencies().get("key");
        Assertions.assertEquals(TestUtils.SALIENCY.getPerFeatureImportance().size(),
                saliencyDto.getFeatureImportance().size());
        FeatureImportanceDto featureImportanceDto = (FeatureImportanceDto) saliencyDto.getFeatureImportance().get(0);
        Assertions.assertEquals(TestUtils.FEATURE_IMPORTANCE_1.getFeature().getName(),
                featureImportanceDto.getFeatureName());
        Assertions.assertEquals(TestUtils.FEATURE_IMPORTANCE_1.getScore(),
                featureImportanceDto.getScore().doubleValue(), 0.01d);
    }

    @Test
    void testCounterfactualsExplainAsyncSucceeded() {
        testCounterfactualsExplainAsyncSuccess(() -> (BaseExplainabilityResultDto) explanationService
                .explainAsync(TestUtils.COUNTERFACTUAL_REQUEST, callbackMock)
                .toCompletableFuture()
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
    }

    @Test
    void testCounterfactualsExplainAsyncSucceededWithoutCallback() {
        testCounterfactualsExplainAsyncSuccess(() -> (BaseExplainabilityResultDto) explanationService
                .explainAsync(TestUtils.COUNTERFACTUAL_REQUEST)
                .toCompletableFuture()
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
    }

    /**
     * Registers only the counterfactual handler, stubs the counterfactual explainer to complete
     * with {@code TestUtils.COUNTERFACTUAL_RESULT}, runs {@code invocation}, and verifies that the
     * result was mapped into a successful {@link CounterfactualExplainabilityResultDto}.
     *
     * @param invocation calls the service and returns its (awaited) result
     */
    void testCounterfactualsExplainAsyncSuccess(ThrowingSupplier<BaseExplainabilityResultDto> invocation) {
        Mockito.when(instance.stream()).thenReturn(Stream.of(cfExplainerServiceHandlerSpy));
        Mockito.when(cfExplainerMock.explainAsync(
                ArgumentMatchers.any(Prediction.class),
                ArgumentMatchers.eq(predictionProviderMock),
                ArgumentMatchers.any(Consumer.class)))
                .thenReturn(CompletableFuture.completedFuture(TestUtils.COUNTERFACTUAL_RESULT));

        BaseExplainabilityResultDto result = Assertions.assertDoesNotThrow(invocation);
        Assertions.assertNotNull(result);
        Assertions.assertTrue(result instanceof CounterfactualExplainabilityResultDto);
        CounterfactualExplainabilityResultDto cfResult = (CounterfactualExplainabilityResultDto) result;

        Assertions.assertEquals(TestUtils.EXECUTION_ID, cfResult.getExecutionId());
        Assertions.assertEquals(TestUtils.COUNTERFACTUAL_ID, cfResult.getCounterfactualId());
        Assertions.assertSame(ExplainabilityStatus.SUCCEEDED, cfResult.getStatus());
        Assertions.assertNull(cfResult.getStatusDetails());
        Assertions.assertEquals(TestUtils.COUNTERFACTUAL_RESULT.getEntities().size(), cfResult.getInputs().size());
        Assertions.assertEquals(TestUtils.COUNTERFACTUAL_RESULT.getOutput().size(), cfResult.getOutputs().size());
        Assertions.assertTrue(cfResult.getOutputs().containsKey("output1"));

        TypedValue typedValue = (TypedValue) cfResult.getOutputs().get("output1");
        Assertions.assertTrue(typedValue.isUnit());
        Assertions.assertEquals(Double.class.getSimpleName(), typedValue.toUnit().getType());
        Assertions.assertEquals(555.0d, typedValue.toUnit().getValue().asDouble());
    }

    @Test
    void testServiceCallFailed() {
        RuntimeException runtimeException = new RuntimeException("Something bad happened");
        Mockito.when(instance.stream()).thenReturn(Stream.of(limeExplainerServiceHandlerSpy));
        // The handler lookup itself fails, so the exception is expected to surface synchronously.
        Mockito.doThrow(runtimeException).when(limeExplainerServiceHandlerSpy).supports(ArgumentMatchers.any());

        Assertions.assertThrows(RuntimeException.class,
                () -> explanationService.explainAsync(TestUtils.LIME_REQUEST, callbackMock)
                        .toCompletableFuture()
                        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
    }

    @Test
    void testServiceCallFailedNoMatchingServiceHandlers() {
        Mockito.when(instance.stream()).thenReturn(Stream.empty());

        Assertions.assertThrows(IllegalArgumentException.class,
                () -> explanationService.explainAsync(TestUtils.LIME_REQUEST, callbackMock)
                        .toCompletableFuture()
                        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
    }

    @Test
    void testLIMEExplainAsyncFailed() {
        RuntimeException runtimeException = new RuntimeException("Something bad happened");
        Mockito.when(instance.stream()).thenReturn(Stream.of(limeExplainerServiceHandlerSpy));
        Mockito.when(limeExplainerMock.explainAsync(
                ArgumentMatchers.any(Prediction.class),
                ArgumentMatchers.eq(predictionProviderMock),
                ArgumentMatchers.any(Consumer.class)))
                .thenThrow(runtimeException);

        // An explainer failure is expected to be reported as a FAILED result, not as an exception.
        BaseExplainabilityResultDto result = Assertions.assertDoesNotThrow(
                () -> (BaseExplainabilityResultDto) explanationService
                        .explainAsync(TestUtils.LIME_REQUEST, callbackMock)
                        .toCompletableFuture()
                        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
        Assertions.assertNotNull(result);
        Assertions.assertTrue(result instanceof LIMEExplainabilityResultDto);
        LIMEExplainabilityResultDto limeResult = (LIMEExplainabilityResultDto) result;

        Assertions.assertEquals(TestUtils.EXECUTION_ID, limeResult.getExecutionId());
        Assertions.assertSame(ExplainabilityStatus.FAILED, limeResult.getStatus());
        Assertions.assertEquals("Something bad happened", limeResult.getStatusDetails());
    }

    @Test
    void testCounterfactualsExplainAsyncFailed() {
        RuntimeException runtimeException = new RuntimeException("Something bad happened");
        Mockito.when(instance.stream()).thenReturn(Stream.of(cfExplainerServiceHandlerSpy));
        Mockito.when(cfExplainerMock.explainAsync(
                ArgumentMatchers.any(Prediction.class),
                ArgumentMatchers.eq(predictionProviderMock),
                ArgumentMatchers.any(Consumer.class)))
                .thenThrow(runtimeException);

        // An explainer failure is expected to be reported as a FAILED result, not as an exception.
        BaseExplainabilityResultDto result = Assertions.assertDoesNotThrow(
                () -> (BaseExplainabilityResultDto) explanationService
                        .explainAsync(TestUtils.COUNTERFACTUAL_REQUEST, callbackMock)
                        .toCompletableFuture()
                        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
        Assertions.assertNotNull(result);
        Assertions.assertTrue(result instanceof CounterfactualExplainabilityResultDto);
        CounterfactualExplainabilityResultDto cfResult = (CounterfactualExplainabilityResultDto) result;

        Assertions.assertEquals(TestUtils.EXECUTION_ID, cfResult.getExecutionId());
        Assertions.assertSame(ExplainabilityStatus.FAILED, cfResult.getStatus());
        Assertions.assertEquals("Something bad happened", cfResult.getStatusDetails());
    }

    @Test
    void testServiceHandlerLookupLIME() {
        // Both handlers are registered; the registry must pick the LIME one for a LIME request.
        Mockito.when(instance.stream())
                .thenReturn(Stream.of(limeExplainerServiceHandlerSpy, cfExplainerServiceHandlerSpy));
        Mockito.when(limeExplainerMock.explainAsync(
                ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any()))
                .thenReturn(CompletableFuture.completedFuture(TestUtils.SALIENCY_MAP));

        BaseExplainabilityResultDto result = Assertions.assertDoesNotThrow(
                () -> (BaseExplainabilityResultDto) explanationService
                        .explainAsync(TestUtils.LIME_REQUEST, callbackMock)
                        .toCompletableFuture()
                        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
        Assertions.assertNotNull(result);
        Assertions.assertTrue(result instanceof LIMEExplainabilityResultDto);
    }

    @Test
    void testServiceHandlerLookupCounterfactuals() {
        // Both handlers are registered; the registry must pick the counterfactual one here.
        Mockito.when(instance.stream())
                .thenReturn(Stream.of(limeExplainerServiceHandlerSpy, cfExplainerServiceHandlerSpy));
        Mockito.when(cfExplainerMock.explainAsync(
                ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any()))
                .thenReturn(CompletableFuture.completedFuture(TestUtils.COUNTERFACTUAL_RESULT));

        BaseExplainabilityResultDto result = Assertions.assertDoesNotThrow(
                () -> (BaseExplainabilityResultDto) explanationService
                        .explainAsync(TestUtils.COUNTERFACTUAL_REQUEST, callbackMock)
                        .toCompletableFuture()
                        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
        Assertions.assertNotNull(result);
        Assertions.assertTrue(result instanceof CounterfactualExplainabilityResultDto);
    }
}
