package org.kie.kogito.explainability.handlers;

import com.fasterxml.jackson.databind.node.IntNode;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.api.BaseExplainabilityRequestDto;
import org.kie.kogito.explainability.api.ExplainabilityStatus;
import org.kie.kogito.explainability.api.FeatureImportanceDto;
import org.kie.kogito.explainability.api.LIMEExplainabilityRequestDto;
import org.kie.kogito.explainability.api.LIMEExplainabilityResultDto;
import org.kie.kogito.explainability.api.ModelIdentifierDto;
import org.kie.kogito.explainability.api.SaliencyDto;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.models.BaseExplainabilityRequest;
import org.kie.kogito.explainability.models.LIMEExplainabilityRequest;
import org.kie.kogito.explainability.models.ModelIdentifier;
import org.kie.kogito.tracing.typedvalue.CollectionValue;
import org.kie.kogito.tracing.typedvalue.StructureValue;
import org.kie.kogito.tracing.typedvalue.UnitValue;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

/**
 * Unit tests for {@link LimeExplainerServiceHandler}.
 *
 * <p>Covers request-type support checks, DTO-to-model request conversion, construction of a
 * {@link Prediction} from typed input/output values, creation of succeeded/failed result DTOs, and
 * delegation of {@code explainAsync} to the wrapped {@link LimeExplainer}.
 */
public class LimeExplainerServiceHandlerTest {

    private static final String EXECUTION_ID = "executionId";
    private static final String SERVICE_URL = "serviceURL";
    private static final ModelIdentifier MODEL_IDENTIFIER = new ModelIdentifier("resourceType", "resourceId");
    private static final ModelIdentifierDto MODEL_IDENTIFIER_DTO = new ModelIdentifierDto("resourceType", "resourceId");

    private LimeExplainer explainer;
    private LimeExplainerServiceHandler handler;

    @BeforeEach
    public void setup() {
        this.explainer = Mockito.mock(LimeExplainer.class);
        this.handler = new LimeExplainerServiceHandler(this.explainer);
    }

    @Test
    public void testSupports() {
        Assertions.assertTrue(this.handler.supports(LIMEExplainabilityRequest.class));
        Assertions.assertFalse(this.handler.supports(BaseExplainabilityRequest.class));
    }

    @Test
    public void testSupportsDto() {
        Assertions.assertTrue(this.handler.supportsDto(LIMEExplainabilityRequestDto.class));
        Assertions.assertFalse(this.handler.supportsDto(BaseExplainabilityRequestDto.class));
    }

    @Test
    public void testExplainabilityRequestFrom() {
        LIMEExplainabilityRequestDto requestDto = new LIMEExplainabilityRequestDto(
                EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER_DTO, Collections.emptyMap(), Collections.emptyMap());

        LIMEExplainabilityRequest request = this.handler.explainabilityRequestFrom(requestDto);

        Assertions.assertEquals(requestDto.getExecutionId(), request.getExecutionId());
        Assertions.assertEquals(requestDto.getServiceUrl(), request.getServiceUrl());
        Assertions.assertEquals(requestDto.getModelIdentifier().getResourceId(), request.getModelIdentifier().getResourceId());
        Assertions.assertEquals(requestDto.getModelIdentifier().getResourceType(), request.getModelIdentifier().getResourceType());
        Assertions.assertEquals(requestDto.getInputs(), request.getInputs());
        Assertions.assertEquals(requestDto.getOutputs(), request.getOutputs());
    }

    @Test
    public void testGetPredictionWithEmptyDefinition() {
        Prediction prediction = this.handler.getPrediction(new LIMEExplainabilityRequest(
                EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, Collections.emptyMap(), Collections.emptyMap()));

        Assertions.assertTrue(prediction.getInput().getFeatures().isEmpty());
        Assertions.assertTrue(prediction.getOutput().getOutputs().isEmpty());
    }

    @Test
    public void testGetPredictionWithNonEmptyDefinition() {
        Prediction prediction = this.handler.getPrediction(new LIMEExplainabilityRequest(
                EXECUTION_ID,
                SERVICE_URL,
                MODEL_IDENTIFIER,
                Map.of("input1", new UnitValue("number", new IntNode(20)),
                        "input2", new StructureValue("number", Map.of("input2b", new UnitValue("number", new IntNode(55)))),
                        "input3", new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100))))),
                Map.of("output1", new UnitValue("number", new IntNode(20)),
                        "output2", new StructureValue("number", Map.of("output2b", new UnitValue("number", new IntNode(55)))),
                        "output3", new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100)))))));

        // --- Inputs ---
        List<Feature> features = prediction.getInput().getFeatures();
        Assertions.assertEquals(3, features.size());

        // Unit value -> scalar numeric feature.
        Feature input1 = findFeature(features, "input1");
        Assertions.assertEquals(Type.NUMBER, input1.getType());
        Assertions.assertEquals(20.0d, input1.getValue().asNumber());

        // Structure value -> composite feature holding named child features.
        Feature input2 = findFeature(features, "input2");
        Assertions.assertEquals(Type.COMPOSITE, input2.getType());
        List<Feature> input2Children = underlyingList(input2.getValue());
        Assertions.assertEquals(1, input2Children.size());
        Feature input2b = findFeature(input2Children, "input2b");
        Assertions.assertEquals(Type.NUMBER, input2b.getType());
        Assertions.assertEquals(55.0d, input2b.getValue().asNumber());

        // Collection value -> composite feature holding one feature per element.
        Feature input3 = findFeature(features, "input3");
        Assertions.assertEquals(Type.COMPOSITE, input3.getType());
        List<Feature> input3Elements = underlyingList(input3.getValue());
        Assertions.assertEquals(1, input3Elements.size());
        Feature input3Element = input3Elements.get(0);
        Assertions.assertEquals(Type.NUMBER, input3Element.getType());
        Assertions.assertEquals(100.0d, input3Element.getValue().asNumber());

        // --- Outputs ---
        List<Output> outputs = prediction.getOutput().getOutputs();
        Assertions.assertEquals(3, outputs.size());

        Output output1 = findOutput(outputs, "output1");
        Assertions.assertEquals(Type.NUMBER, output1.getType());
        Assertions.assertEquals(20.0d, output1.getValue().asNumber());

        Output output2 = findOutput(outputs, "output2");
        // Previously asserted the type of the *input2* feature here by mistake.
        Assertions.assertEquals(Type.COMPOSITE, output2.getType());
        List<Output> output2Children = underlyingList(output2.getValue());
        Assertions.assertEquals(1, output2Children.size());
        Output output2b = findOutput(output2Children, "output2b");
        Assertions.assertEquals(Type.NUMBER, output2b.getType());
        Assertions.assertEquals(55.0d, output2b.getValue().asNumber());

        Output output3 = findOutput(outputs, "output3");
        Assertions.assertEquals(Type.COMPOSITE, output3.getType());
        List<Output> output3Elements = underlyingList(output3.getValue());
        Assertions.assertEquals(1, output3Elements.size());
        Output output3Element = output3Elements.get(0);
        Assertions.assertEquals(Type.NUMBER, output3Element.getType());
        Assertions.assertEquals(100.0d, output3Element.getValue().asNumber());
    }

    @Test
    public void testCreateSucceededResultDto() {
        LIMEExplainabilityRequest request = new LIMEExplainabilityRequest(
                EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, Collections.emptyMap(), Collections.emptyMap());
        Saliency saliency = new Saliency(
                new Output("salary", Type.NUMBER),
                List.of(new FeatureImportance(new Feature("age", Type.NUMBER, new Value(Double.valueOf(25.0d))), 5.0d),
                        new FeatureImportance(new Feature("dependents", Type.NUMBER, new Value(2)), -11.0d)));

        LIMEExplainabilityResultDto dto = this.handler.createSucceededResultDto(request, Map.of("s1", saliency));

        Assertions.assertNotNull(dto);
        Assertions.assertEquals(ExplainabilityStatus.SUCCEEDED, dto.getStatus());
        Assertions.assertEquals(EXECUTION_ID, dto.getExecutionId());
        Assertions.assertEquals(1, dto.getSaliencies().size());
        Assertions.assertTrue(dto.getSaliencies().containsKey("s1"));

        SaliencyDto saliencyDto = dto.getSaliencies().get("s1");
        List<FeatureImportanceDto> importances = saliencyDto.getFeatureImportance();
        Assertions.assertEquals(2, importances.size());
        Assertions.assertEquals("age", importances.get(0).getFeatureName());
        Assertions.assertEquals(5.0d, importances.get(0).getScore());
        Assertions.assertEquals("dependents", importances.get(1).getFeatureName());
        Assertions.assertEquals(-11.0d, importances.get(1).getScore());
    }

    @Test
    public void testCreateFailedResultDto() {
        LIMEExplainabilityRequest request = new LIMEExplainabilityRequest(
                EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, Collections.emptyMap(), Collections.emptyMap());

        LIMEExplainabilityResultDto dto =
                this.handler.createFailedResultDto(request, new NullPointerException("Something went wrong"));

        Assertions.assertNotNull(dto);
        Assertions.assertEquals(ExplainabilityStatus.FAILED, dto.getStatus());
        Assertions.assertEquals("Something went wrong", dto.getStatusDetails());
        Assertions.assertEquals(EXECUTION_ID, dto.getExecutionId());
    }

    @Test
    public void testExplainAsyncDelegation() {
        Prediction prediction = Mockito.mock(Prediction.class);
        PredictionProvider predictionProvider = Mockito.mock(PredictionProvider.class);

        this.handler.explainAsync(prediction, predictionProvider);

        Mockito.verify(this.explainer)
                .explainAsync(ArgumentMatchers.eq(prediction), ArgumentMatchers.eq(predictionProvider));
    }

    /**
     * Returns the first feature named {@code name}, failing the test if there is none.
     *
     * @param features features to search
     * @param name expected feature name
     * @return the matching feature
     */
    private static Feature findFeature(List<Feature> features, String name) {
        Optional<Feature> match = features.stream().filter(f -> name.equals(f.getName())).findFirst();
        Assertions.assertTrue(match.isPresent(), "Missing feature: " + name);
        return match.get();
    }

    /**
     * Returns the first output named {@code name}, failing the test if there is none.
     *
     * @param outputs outputs to search
     * @param name expected output name
     * @return the matching output
     */
    private static Output findOutput(List<Output> outputs, String name) {
        Optional<Output> match = outputs.stream().filter(o -> name.equals(o.getName())).findFirst();
        Assertions.assertTrue(match.isPresent(), "Missing output: " + name);
        return match.get();
    }

    /**
     * Asserts that {@code value} wraps a {@link List} and returns it with the caller's element type.
     * The element type is not checked at runtime; callers name the type they expect for composite
     * features/outputs.
     *
     * @param value the composite value to unwrap
     * @param <T> expected element type
     * @return the underlying list
     */
    @SuppressWarnings("unchecked")
    private static <T> List<T> underlyingList(Value value) {
        Object underlying = value.getUnderlyingObject();
        Assertions.assertTrue(underlying instanceof List, "Expected a List but got: " + underlying);
        return (List<T>) underlying;
    }
}
