package org.kie.kogito.explainability;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.model.ModelIdentifier;
import org.kie.kogito.explainability.model.PredictInput;
import org.kie.kogito.explainability.model.PredictOutput;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;

/**
 * Unit tests for {@link SpringBootExplainableResource#predict}, run against the application supplied by
 * {@code ApplicationMock} and the model identified by {@code Constants.MODEL_NAMESPACE:Constants.MODEL_NAME}.
 */
class SpringBootExplainableResourceTest {

    /** Decision name in the model result that holds the suspension verdict. */
    private static final String SUSPENDED_KEY = "Should the driver be suspended?";

    /** Decision name in the model result that holds the fine (a map with "Points" and "Amount"). */
    private static final String FINE_KEY = "Fine";

    final SpringBootExplainableResource resource = new SpringBootExplainableResource(new ApplicationMock());

    @Test
    void explainServiceTest() {
        List<PredictOutput> outputs = outputsOf(resource.predict(Collections.singletonList(createInput(40))));
        Assertions.assertEquals(1, outputs.size());

        // Actual speed 120 against a limit of 40: the driver is expected to be suspended and fined.
        Map<String, ?> result = resultOf(outputs.get(0));
        Assertions.assertTrue(result.containsKey(SUSPENDED_KEY));
        Assertions.assertEquals("Yes", result.get(SUSPENDED_KEY));
        assertFine(result, BigDecimal.valueOf(7L), BigDecimal.valueOf(1000L));
    }

    @Test
    void explainServiceTestMultipleInputs() {
        List<PredictOutput> outputs =
                outputsOf(resource.predict(Arrays.asList(createInput(40), createInput(120))));
        Assertions.assertEquals(2, outputs.size());

        // Outputs are expected in input order; the first input is the same one used in explainServiceTest.
        Map<String, ?> first = resultOf(outputs.get(0));
        Assertions.assertEquals("Yes", first.get(SUSPENDED_KEY));
        assertFine(first, BigDecimal.valueOf(7L), BigDecimal.valueOf(1000L));

        // Actual speed equals the limit (120): no suspension and no fine.
        Map<String, ?> second = resultOf(outputs.get(1));
        Assertions.assertTrue(second.containsKey(SUSPENDED_KEY));
        Assertions.assertEquals("No", second.get(SUSPENDED_KEY));
        Assertions.assertNull(second.get(FINE_KEY));
    }

    @Test
    void explainServiceTestNoInputs() {
        List<PredictOutput> outputs = outputsOf(resource.predict(Collections.emptyList()));
        Assertions.assertEquals(0, outputs.size());
    }

    @Test
    void explainServiceFail() {
        PredictInput input = createInput(10);
        input.getModelIdentifier().setResourceId("unknown:model");

        ResponseEntity<?> response = resource.predict(Collections.singletonList(input));

        // An unknown model is reported as 400 with a plain-text message body.
        Assertions.assertEquals(HttpStatus.BAD_REQUEST, response.getStatusCode());
        Assertions.assertEquals("Model unknown:model not found.", response.getBody());
    }

    /**
     * Asserts that a successful response carries a list body and returns it as prediction outputs.
     *
     * @param response the response returned by {@code predict}
     * @return the body, cast to a list of {@link PredictOutput}
     */
    @SuppressWarnings("unchecked") // element type is not checkable at runtime; elements are cast on access
    private static List<PredictOutput> outputsOf(ResponseEntity<?> response) {
        Object body = response.getBody();
        Assertions.assertNotNull(body);
        Assertions.assertTrue(body instanceof List, "Expected a List body but got " + body.getClass());
        return (List<PredictOutput>) body;
    }

    /**
     * Asserts that an output is non-null and has both a model identifier and a result, then returns the result.
     *
     * @param output a single prediction output
     * @return the output's result map
     */
    private static Map<String, ?> resultOf(PredictOutput output) {
        Assertions.assertNotNull(output);
        Assertions.assertNotNull(output.getModelIdentifier());
        Map<String, ?> result = output.getResult();
        Assertions.assertNotNull(result);
        return result;
    }

    /**
     * Asserts that the result contains a {@code "Fine"} map with the given points and amount.
     * {@link BigDecimal#equals} is scale-sensitive, so the expected values must match the model's scale.
     *
     * @param result the model result map
     * @param points expected value of {@code Fine.Points}
     * @param amount expected value of {@code Fine.Amount}
     */
    private static void assertFine(Map<String, ?> result, BigDecimal points, BigDecimal amount) {
        Object fine = result.get(FINE_KEY);
        Assertions.assertTrue(fine instanceof Map, "Expected a map for \"" + FINE_KEY + "\" but got " + fine);
        Map<?, ?> fineMap = (Map<?, ?>) fine;
        Assertions.assertEquals(points, fineMap.get("Points"));
        Assertions.assertEquals(amount, fineMap.get("Amount"));
    }

    /**
     * Builds a DMN request for a 25-year-old driver with 100 points who was doing 120 in a zone with the given
     * speed limit.
     *
     * @param speedLimit value for {@code Violation."Speed Limit"}
     * @return a request targeting {@code Constants.MODEL_NAMESPACE:Constants.MODEL_NAME}
     */
    private PredictInput createInput(int speedLimit) {
        Map<String, Object> driver = new HashMap<>();
        driver.put("Age", 25);
        driver.put("Points", 100);

        Map<String, Object> violation = new HashMap<>();
        violation.put("Type", "speed");
        violation.put("Actual Speed", 120);
        violation.put("Speed Limit", speedLimit);

        Map<String, Object> request = new HashMap<>();
        request.put("Driver", driver);
        request.put("Violation", violation);

        String resourceId = String.format("%s:%s", Constants.MODEL_NAMESPACE, Constants.MODEL_NAME);
        return new PredictInput(new ModelIdentifier("dmn", resourceId), request);
    }
}
