package org.kie.kogito.trusty.service.common.api;

import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.mockito.InjectMock;
import io.restassured.RestAssured;
import io.restassured.filter.log.RequestLoggingFilter;
import io.restassured.filter.log.ResponseLoggingFilter;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.tracing.typedvalue.TypedValue;
import org.kie.kogito.trusty.service.common.TrustyService;
import org.kie.kogito.trusty.service.common.TrustyServiceTestUtils;
import org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse;
import org.kie.kogito.trusty.service.common.responses.CounterfactualResultsResponse;
import org.kie.kogito.trusty.service.common.responses.SalienciesResponse;
import org.kie.kogito.trusty.storage.api.model.BaseExplainabilityResult;
import org.kie.kogito.trusty.storage.api.model.CounterfactualDomainCategorical;
import org.kie.kogito.trusty.storage.api.model.CounterfactualDomainRange;
import org.kie.kogito.trusty.storage.api.model.CounterfactualExplainabilityRequest;
import org.kie.kogito.trusty.storage.api.model.CounterfactualExplainabilityResult;
import org.kie.kogito.trusty.storage.api.model.CounterfactualSearchDomain;
import org.kie.kogito.trusty.storage.api.model.ExplainabilityStatus;
import org.kie.kogito.trusty.storage.api.model.FeatureImportanceModel;
import org.kie.kogito.trusty.storage.api.model.LIMEExplainabilityResult;
import org.kie.kogito.trusty.storage.api.model.SaliencyModel;
import org.kie.kogito.trusty.storage.api.model.TypedVariableWithValue;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.testcontainers.shaded.org.apache.commons.lang.builder.CompareToBuilder;

@QuarkusTest
/* loaded from: input_file:org/kie/kogito/trusty/service/common/api/ExplainabilityApiV1IT.class */
class ExplainabilityApiV1IT {
    private static final String TEST_EXECUTION_ID = "executionId";
    private static final String TEST_COUNTERFACTUAL_ID = "counterfactualId";
    private static final CounterfactualExplainabilityResult SOLUTION1 = new CounterfactualExplainabilityResult(TEST_EXECUTION_ID, TEST_COUNTERFACTUAL_ID, "solution1", 0L, ExplainabilityStatus.SUCCEEDED, "", true, CounterfactualExplainabilityResult.Stage.INTERMEDIATE, Collections.emptyList(), Collections.emptyList());
    private static final CounterfactualExplainabilityResult SOLUTION2 = new CounterfactualExplainabilityResult(TEST_EXECUTION_ID, TEST_COUNTERFACTUAL_ID, "solution2", 1L, ExplainabilityStatus.SUCCEEDED, "", true, CounterfactualExplainabilityResult.Stage.FINAL, Collections.emptyList(), Collections.emptyList());

    @InjectMock
    TrustyService executionService;

    ExplainabilityApiV1IT() {
    }

    private static BaseExplainabilityResult buildValidExplainabilityResult() {
        return new LIMEExplainabilityResult(TEST_EXECUTION_ID, ExplainabilityStatus.SUCCEEDED, (String) null, List.of(new SaliencyModel("O1", "Output1", List.of(new FeatureImportanceModel("Feature1", Double.valueOf(0.49384d)), new FeatureImportanceModel("Feature2", Double.valueOf(-0.1084d)))), new SaliencyModel("O2", "Output2", List.of(new FeatureImportanceModel("Feature1", Double.valueOf(0.0d)), new FeatureImportanceModel("Feature2", Double.valueOf(0.70293d))))));
    }

    private static CounterfactualExplainabilityRequest buildValidCounterfactual() {
        return new CounterfactualExplainabilityRequest(TEST_EXECUTION_ID, TEST_COUNTERFACTUAL_ID);
    }

    private static List<CounterfactualExplainabilityResult> buildValidCounterfactualResults() {
        return List.of(SOLUTION1, SOLUTION2);
    }

    @Test
    void testSalienciesWithExplainabilityResult() {
        mockServiceWithExplainabilityResult();
        SalienciesResponse salienciesResponse = (SalienciesResponse) RestAssured.given().filter(new ResponseLoggingFilter()).when().get("/executions/decisions/executionId/explanations/saliencies", new Object[0]).as(SalienciesResponse.class);
        Assertions.assertNotNull(salienciesResponse);
        Assertions.assertNotNull(salienciesResponse.getSaliencies());
        Assertions.assertSame(2, Integer.valueOf(salienciesResponse.getSaliencies().size()));
        List list = (List) salienciesResponse.getSaliencies().stream().sorted((saliencyModel, saliencyModel2) -> {
            return new CompareToBuilder().append(saliencyModel.getOutcomeName(), saliencyModel2.getOutcomeName()).toComparison();
        }).collect(Collectors.toList());
        Assertions.assertNotNull(list.get(0));
        Assertions.assertEquals("Output1", ((SaliencyModel) list.get(0)).getOutcomeName());
        Assertions.assertNotNull(((SaliencyModel) list.get(0)).getFeatureImportance());
        Assertions.assertSame(2, Integer.valueOf(((SaliencyModel) list.get(0)).getFeatureImportance().size()));
        Assertions.assertEquals("Feature1", ((FeatureImportanceModel) ((SaliencyModel) list.get(0)).getFeatureImportance().get(0)).getFeatureName());
        Assertions.assertEquals(0.49384d, ((FeatureImportanceModel) ((SaliencyModel) list.get(0)).getFeatureImportance().get(0)).getFeatureScore());
        Assertions.assertEquals("Feature2", ((FeatureImportanceModel) ((SaliencyModel) list.get(0)).getFeatureImportance().get(1)).getFeatureName());
        Assertions.assertEquals(-0.1084d, ((FeatureImportanceModel) ((SaliencyModel) list.get(0)).getFeatureImportance().get(1)).getFeatureScore());
        Assertions.assertNotNull(list.get(1));
        Assertions.assertEquals("Output2", ((SaliencyModel) list.get(1)).getOutcomeName());
        Assertions.assertNotNull(((SaliencyModel) list.get(1)).getFeatureImportance());
        Assertions.assertSame(2, Integer.valueOf(((SaliencyModel) list.get(1)).getFeatureImportance().size()));
        Assertions.assertEquals("Feature1", ((FeatureImportanceModel) ((SaliencyModel) list.get(1)).getFeatureImportance().get(0)).getFeatureName());
        Assertions.assertEquals(0.0d, ((FeatureImportanceModel) ((SaliencyModel) list.get(1)).getFeatureImportance().get(0)).getFeatureScore());
        Assertions.assertEquals("Feature2", ((FeatureImportanceModel) ((SaliencyModel) list.get(1)).getFeatureImportance().get(1)).getFeatureName());
        Assertions.assertEquals(0.70293d, ((FeatureImportanceModel) ((SaliencyModel) list.get(1)).getFeatureImportance().get(1)).getFeatureScore());
    }

    @Test
    void testSalienciesWithNullExplainabilityResult() {
        mockServiceWithNullExplainabilityResult();
        RestAssured.given().filter(new ResponseLoggingFilter()).when().get("/executions/decisions/executionId/explanations/saliencies", new Object[0]).then().statusCode(400);
    }

    @Test
    void testSalienciesWithoutExplainabilityResult() {
        mockServiceWithoutExplainabilityResult();
        RestAssured.given().filter(new ResponseLoggingFilter()).when().get("/executions/decisions/executionId/explanations/saliencies", new Object[0]).then().statusCode(400);
    }

    @Test
    void testCounterfactualRequest() {
        ArgumentCaptor forClass = ArgumentCaptor.forClass(List.class);
        ArgumentCaptor forClass2 = ArgumentCaptor.forClass(List.class);
        mockServiceWithCounterfactualRequest();
        CounterfactualRequestResponse counterfactualRequestResponse = (CounterfactualRequestResponse) RestAssured.given().filter(new RequestLoggingFilter()).filter(new ResponseLoggingFilter()).contentType("application/json").body(TrustyServiceTestUtils.getCounterfactualJsonRequest()).when().post("/executions/decisions/executionId/explanations/counterfactuals", new Object[0]).as(CounterfactualRequestResponse.class);
        Assertions.assertNotNull(counterfactualRequestResponse);
        Assertions.assertNotNull(counterfactualRequestResponse.getExecutionId());
        Assertions.assertNotNull(counterfactualRequestResponse.getCounterfactualId());
        Assertions.assertEquals(counterfactualRequestResponse.getExecutionId(), TEST_EXECUTION_ID);
        Assertions.assertEquals(counterfactualRequestResponse.getCounterfactualId(), TEST_COUNTERFACTUAL_ID);
        ((TrustyService) Mockito.verify(this.executionService)).requestCounterfactuals((String) ArgumentMatchers.eq(TEST_EXECUTION_ID), (List) forClass.capture(), (List) forClass2.capture());
        List list = (List) forClass.getValue();
        Assertions.assertNotNull(list);
        Assertions.assertEquals(2, list.size());
        TypedVariableWithValue typedVariableWithValue = (TypedVariableWithValue) list.get(0);
        Assertions.assertEquals(TypedValue.Kind.UNIT, typedVariableWithValue.getKind());
        Assertions.assertEquals("deposit", typedVariableWithValue.getName());
        Assertions.assertEquals("number", typedVariableWithValue.getTypeRef());
        Assertions.assertEquals(5000, typedVariableWithValue.getValue().asInt());
        TypedVariableWithValue typedVariableWithValue2 = (TypedVariableWithValue) list.get(1);
        Assertions.assertEquals(TypedValue.Kind.UNIT, typedVariableWithValue2.getKind());
        Assertions.assertEquals("approved", typedVariableWithValue2.getName());
        Assertions.assertEquals("boolean", typedVariableWithValue2.getTypeRef());
        Assertions.assertEquals(Boolean.TRUE, Boolean.valueOf(typedVariableWithValue2.getValue().asBoolean()));
        List list2 = (List) forClass2.getValue();
        Assertions.assertNotNull(list2);
        Assertions.assertEquals(3, list2.size());
        CounterfactualSearchDomain counterfactualSearchDomain = (CounterfactualSearchDomain) list2.get(0);
        Assertions.assertTrue(counterfactualSearchDomain.isFixed().booleanValue());
        Assertions.assertEquals(TypedValue.Kind.UNIT, counterfactualSearchDomain.getKind());
        Assertions.assertEquals("age", counterfactualSearchDomain.getName());
        Assertions.assertEquals("number", counterfactualSearchDomain.getTypeRef());
        Assertions.assertNull(counterfactualSearchDomain.getDomain());
        CounterfactualSearchDomain counterfactualSearchDomain2 = (CounterfactualSearchDomain) list2.get(1);
        Assertions.assertFalse(counterfactualSearchDomain2.isFixed().booleanValue());
        Assertions.assertEquals(TypedValue.Kind.UNIT, counterfactualSearchDomain2.getKind());
        Assertions.assertEquals("income", counterfactualSearchDomain2.getName());
        Assertions.assertEquals("number", counterfactualSearchDomain2.getTypeRef());
        Assertions.assertNotNull(counterfactualSearchDomain2.getDomain());
        Assertions.assertTrue(counterfactualSearchDomain2.getDomain() instanceof CounterfactualDomainRange);
        CounterfactualDomainRange domain = counterfactualSearchDomain2.getDomain();
        Assertions.assertEquals(0, domain.getLowerBound().asInt());
        Assertions.assertEquals(1000, domain.getUpperBound().asInt());
        CounterfactualSearchDomain counterfactualSearchDomain3 = (CounterfactualSearchDomain) list2.get(2);
        Assertions.assertFalse(counterfactualSearchDomain3.isFixed().booleanValue());
        Assertions.assertEquals(TypedValue.Kind.UNIT, counterfactualSearchDomain3.getKind());
        Assertions.assertEquals("taxCode", counterfactualSearchDomain3.getName());
        Assertions.assertEquals("string", counterfactualSearchDomain3.getTypeRef());
        Assertions.assertNotNull(counterfactualSearchDomain3.getDomain());
        Assertions.assertTrue(counterfactualSearchDomain3.getDomain() instanceof CounterfactualDomainCategorical);
        CounterfactualDomainCategorical domain2 = counterfactualSearchDomain3.getDomain();
        Assertions.assertEquals(3, domain2.getCategories().size());
        Assertions.assertTrue(((List) domain2.getCategories().stream().map((v0) -> {
            return v0.asText();
        }).collect(Collectors.toList())).containsAll(Arrays.asList("A", "B", "C")));
    }

    @Test
    void testCounterfactualRequestWithStructuredModel() {
        ArgumentCaptor forClass = ArgumentCaptor.forClass(List.class);
        ArgumentCaptor forClass2 = ArgumentCaptor.forClass(List.class);
        mockServiceWithCounterfactualRequest();
        CounterfactualRequestResponse counterfactualRequestResponse = (CounterfactualRequestResponse) RestAssured.given().filter(new RequestLoggingFilter()).filter(new ResponseLoggingFilter()).contentType("application/json").body(TrustyServiceTestUtils.getCounterfactualWithStructuredModelJsonRequest()).when().post("/executions/decisions/executionId/explanations/counterfactuals", new Object[0]).as(CounterfactualRequestResponse.class);
        Assertions.assertNotNull(counterfactualRequestResponse);
        Assertions.assertNotNull(counterfactualRequestResponse.getExecutionId());
        Assertions.assertNotNull(counterfactualRequestResponse.getCounterfactualId());
        Assertions.assertEquals(counterfactualRequestResponse.getExecutionId(), TEST_EXECUTION_ID);
        Assertions.assertEquals(counterfactualRequestResponse.getCounterfactualId(), TEST_COUNTERFACTUAL_ID);
        ((TrustyService) Mockito.verify(this.executionService)).requestCounterfactuals((String) ArgumentMatchers.eq(TEST_EXECUTION_ID), (List) forClass.capture(), (List) forClass2.capture());
        List list = (List) forClass.getValue();
        Assertions.assertNotNull(list);
        Assertions.assertEquals(1, list.size());
        TypedVariableWithValue typedVariableWithValue = (TypedVariableWithValue) list.get(0);
        Assertions.assertEquals(TypedValue.Kind.STRUCTURE, typedVariableWithValue.getKind());
        Assertions.assertEquals("Fine", typedVariableWithValue.getName());
        Assertions.assertEquals("tFine", typedVariableWithValue.getTypeRef());
        Assertions.assertEquals(2, typedVariableWithValue.getComponents().size());
        Iterator it = typedVariableWithValue.getComponents().iterator();
        TypedVariableWithValue typedVariableWithValue2 = (TypedVariableWithValue) it.next();
        TypedVariableWithValue typedVariableWithValue3 = (TypedVariableWithValue) it.next();
        Assertions.assertEquals(TypedValue.Kind.UNIT, typedVariableWithValue2.getKind());
        Assertions.assertEquals("Amount", typedVariableWithValue2.getName());
        Assertions.assertEquals("number", typedVariableWithValue2.getTypeRef());
        Assertions.assertEquals(100, typedVariableWithValue2.getValue().asInt());
        Assertions.assertNull(typedVariableWithValue2.getComponents());
        Assertions.assertEquals(TypedValue.Kind.UNIT, typedVariableWithValue3.getKind());
        Assertions.assertEquals("Points", typedVariableWithValue3.getName());
        Assertions.assertEquals("number", typedVariableWithValue3.getTypeRef());
        Assertions.assertEquals(0, typedVariableWithValue3.getValue().asInt());
        Assertions.assertNull(typedVariableWithValue3.getComponents());
        List list2 = (List) forClass2.getValue();
        Assertions.assertNotNull(list2);
        Assertions.assertEquals(1, list2.size());
        CounterfactualSearchDomain counterfactualSearchDomain = (CounterfactualSearchDomain) list2.get(0);
        Assertions.assertFalse(counterfactualSearchDomain.isFixed().booleanValue());
        Assertions.assertEquals(TypedValue.Kind.STRUCTURE, counterfactualSearchDomain.getKind());
        Assertions.assertEquals("Violation", counterfactualSearchDomain.getName());
        Assertions.assertEquals("tViolation", counterfactualSearchDomain.getTypeRef());
        Assertions.assertNull(counterfactualSearchDomain.getDomain());
        Assertions.assertEquals(3, counterfactualSearchDomain.getComponents().size());
        Iterator it2 = counterfactualSearchDomain.getComponents().iterator();
        CounterfactualSearchDomain counterfactualSearchDomain2 = (CounterfactualSearchDomain) it2.next();
        CounterfactualSearchDomain counterfactualSearchDomain3 = (CounterfactualSearchDomain) it2.next();
        CounterfactualSearchDomain counterfactualSearchDomain4 = (CounterfactualSearchDomain) it2.next();
        Assertions.assertFalse(counterfactualSearchDomain2.isFixed().booleanValue());
        Assertions.assertEquals(TypedValue.Kind.UNIT, counterfactualSearchDomain2.getKind());
        Assertions.assertEquals("Type", counterfactualSearchDomain2.getName());
        Assertions.assertEquals("string", counterfactualSearchDomain2.getTypeRef());
        Assertions.assertNotNull(counterfactualSearchDomain2.getDomain());
        Assertions.assertTrue(counterfactualSearchDomain2.getDomain() instanceof CounterfactualDomainCategorical);
        CounterfactualDomainCategorical domain = counterfactualSearchDomain2.getDomain();
        Assertions.assertEquals(2, domain.getCategories().size());
        Assertions.assertTrue(((List) domain.getCategories().stream().map((v0) -> {
            return v0.asText();
        }).collect(Collectors.toList())).containsAll(Arrays.asList("speed", "driving under the influence")));
        Assertions.assertFalse(counterfactualSearchDomain3.isFixed().booleanValue());
        Assertions.assertEquals(TypedValue.Kind.UNIT, counterfactualSearchDomain3.getKind());
        Assertions.assertEquals("Actual Speed", counterfactualSearchDomain3.getName());
        Assertions.assertEquals("number", counterfactualSearchDomain3.getTypeRef());
        Assertions.assertNotNull(counterfactualSearchDomain3.getDomain());
        Assertions.assertTrue(counterfactualSearchDomain3.getDomain() instanceof CounterfactualDomainRange);
        CounterfactualDomainRange domain2 = counterfactualSearchDomain3.getDomain();
        Assertions.assertEquals(0, domain2.getLowerBound().asInt());
        Assertions.assertEquals(100, domain2.getUpperBound().asInt());
        Assertions.assertFalse(counterfactualSearchDomain4.isFixed().booleanValue());
        Assertions.assertEquals(TypedValue.Kind.UNIT, counterfactualSearchDomain4.getKind());
        Assertions.assertEquals("Speed Limit", counterfactualSearchDomain4.getName());
        Assertions.assertEquals("number", counterfactualSearchDomain4.getTypeRef());
        Assertions.assertNotNull(counterfactualSearchDomain4.getDomain());
        Assertions.assertTrue(counterfactualSearchDomain4.getDomain() instanceof CounterfactualDomainRange);
        CounterfactualDomainRange domain3 = counterfactualSearchDomain4.getDomain();
        Assertions.assertEquals(0, domain3.getLowerBound().asInt());
        Assertions.assertEquals(100, domain3.getUpperBound().asInt());
    }

    @Test
    void testCounterfactualResultsWithRequest() {
        mockServiceWithCounterfactualRequest();
        mockServiceWithCounterfactualResults();
        RestAssured.given().filter(new RequestLoggingFilter()).filter(new ResponseLoggingFilter()).contentType("application/json").body(TrustyServiceTestUtils.getCounterfactualJsonRequest()).when().post("/executions/decisions/executionId/explanations/counterfactuals", new Object[0]).as(CounterfactualRequestResponse.class);
        CounterfactualResultsResponse counterfactualResultsResponse = (CounterfactualResultsResponse) RestAssured.given().filter(new RequestLoggingFilter()).filter(new ResponseLoggingFilter()).when().get("/executions/decisions/executionId/explanations/counterfactuals/counterfactualId", new Object[0]).as(CounterfactualResultsResponse.class);
        Assertions.assertNotNull(counterfactualResultsResponse);
        Assertions.assertNotNull(counterfactualResultsResponse.getExecutionId());
        Assertions.assertNotNull(counterfactualResultsResponse.getCounterfactualId());
        Assertions.assertEquals(counterfactualResultsResponse.getExecutionId(), TEST_EXECUTION_ID);
        Assertions.assertEquals(counterfactualResultsResponse.getCounterfactualId(), TEST_COUNTERFACTUAL_ID);
        Assertions.assertEquals(2, counterfactualResultsResponse.getSolutions().size());
    }

    @Test
    void testCounterfactualResultsWithRequestWithoutResults() {
        mockServiceWithCounterfactualRequest();
        RestAssured.given().filter(new RequestLoggingFilter()).filter(new ResponseLoggingFilter()).contentType("application/json").body(TrustyServiceTestUtils.getCounterfactualJsonRequest()).when().post("/executions/decisions/executionId/explanations/counterfactuals", new Object[0]).as(CounterfactualRequestResponse.class);
        CounterfactualResultsResponse counterfactualResultsResponse = (CounterfactualResultsResponse) RestAssured.given().filter(new RequestLoggingFilter()).filter(new ResponseLoggingFilter()).when().get("/executions/decisions/executionId/explanations/counterfactuals/counterfactualId", new Object[0]).as(CounterfactualResultsResponse.class);
        Assertions.assertNotNull(counterfactualResultsResponse);
        Assertions.assertNotNull(counterfactualResultsResponse.getExecutionId());
        Assertions.assertNotNull(counterfactualResultsResponse.getCounterfactualId());
        Assertions.assertEquals(counterfactualResultsResponse.getExecutionId(), TEST_EXECUTION_ID);
        Assertions.assertEquals(counterfactualResultsResponse.getCounterfactualId(), TEST_COUNTERFACTUAL_ID);
        Assertions.assertTrue(counterfactualResultsResponse.getSolutions().isEmpty());
    }

    @Test
    void testCounterfactualResultsWithoutRequest() {
        RestAssured.given().filter(new ResponseLoggingFilter()).when().get("/executions/decisions/executionId/explanations/counterfactuals/counterfactualId", new Object[0]).then().statusCode(400);
    }

    private void mockServiceWithExplainabilityResult() {
        Mockito.when(this.executionService.getExplainabilityResultById((String) ArgumentMatchers.eq(TEST_EXECUTION_ID), (Class) ArgumentMatchers.any())).thenReturn(buildValidExplainabilityResult());
    }

    private void mockServiceWithNullExplainabilityResult() {
        Mockito.when(this.executionService.getExplainabilityResultById(ArgumentMatchers.anyString(), (Class) ArgumentMatchers.any())).thenReturn((Object) null);
    }

    private void mockServiceWithoutExplainabilityResult() {
        Mockito.when(this.executionService.getExplainabilityResultById(ArgumentMatchers.anyString(), (Class) ArgumentMatchers.any())).thenThrow(new Throwable[]{new IllegalArgumentException("Explainability result does not exist.")});
    }

    private void mockServiceWithCounterfactualRequest() {
        Mockito.when(this.executionService.requestCounterfactuals((String) ArgumentMatchers.eq(TEST_EXECUTION_ID), (List) ArgumentMatchers.any(), (List) ArgumentMatchers.any())).thenReturn(buildValidCounterfactual());
        Mockito.when(this.executionService.getCounterfactualRequest((String) ArgumentMatchers.eq(TEST_EXECUTION_ID), (String) ArgumentMatchers.eq(TEST_COUNTERFACTUAL_ID))).thenReturn(buildValidCounterfactual());
    }

    private void mockServiceWithCounterfactualResults() {
        Mockito.when(this.executionService.getCounterfactualResults((String) ArgumentMatchers.eq(TEST_EXECUTION_ID), (String) ArgumentMatchers.eq(TEST_COUNTERFACTUAL_ID))).thenReturn(buildValidCounterfactualResults());
    }
}
