/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.kie.kogito.Application;
import org.kie.kogito.explainability.ApplicationMock;
import org.kie.kogito.explainability.SpringBootExplainableResource;
import org.kie.kogito.explainability.model.ModelIdentifier;
import org.kie.kogito.explainability.model.PredictInput;
import org.kie.kogito.explainability.model.PredictOutput;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;

/**
 * Tests for {@link SpringBootExplainableResource#predict}, run against an {@link ApplicationMock}.
 *
 * <p>Every request targets the "Traffic Violation" DMN model. Only the violation's speed limit
 * varies between cases.
 */
class SpringBootExplainableResourceTest {

    private static final String MODEL_NAMESPACE =
            "https://github.com/kiegroup/drools/kie-dmn/_A4BCA8B8-CF08-433F-93B2-A2598F19ECFF";
    private static final String MODEL_NAME = "Traffic Violation";
    private static final String SUSPENDED_KEY = "Should the driver be suspended?";
    private static final String FINE_KEY = "Fine";

    private final SpringBootExplainableResource resource = new SpringBootExplainableResource(new ApplicationMock());

    /** Driving at 120 with a limit of 40 is expected to produce a suspension and a fine. */
    @Test
    void explainServiceTest() {
        List<PredictInput> inputs = Collections.singletonList(createInput(40));

        // predict(...) may also carry a String error body, so its payload is cast here.
        List<?> outputs = (List<?>) resource.predict(inputs).getBody();
        Assertions.assertNotNull(outputs);
        Assertions.assertEquals(1, outputs.size());

        PredictOutput output = (PredictOutput) outputs.get(0);
        Assertions.assertNotNull(output.getResult());
        Assertions.assertNotNull(output.getModelIdentifier());

        Map<?, ?> result = output.getResult();
        Assertions.assertTrue(result.containsKey(SUSPENDED_KEY));
        Assertions.assertEquals("Yes", result.get(SUSPENDED_KEY));
        Assertions.assertTrue(result.containsKey(FINE_KEY));

        Map<?, ?> fine = (Map<?, ?>) result.get(FINE_KEY);
        Assertions.assertNotNull(fine);
        // BigDecimal.equals is scale-sensitive; valueOf(long) produces scale-0 expected values.
        Assertions.assertEquals(BigDecimal.valueOf(7L), fine.get("Points"));
        Assertions.assertEquals(BigDecimal.valueOf(1000L), fine.get("Amount"));
    }

    /** With two inputs, the second one (limit 120, driving at 120) yields no suspension and no fine. */
    @Test
    void explainServiceTestMultipleInputs() {
        List<PredictInput> inputs = Arrays.asList(createInput(40), createInput(120));

        List<?> outputs = (List<?>) resource.predict(inputs).getBody();
        Assertions.assertNotNull(outputs);
        Assertions.assertEquals(2, outputs.size());

        PredictOutput output = (PredictOutput) outputs.get(1);
        Assertions.assertNotNull(output);
        Assertions.assertNotNull(output.getResult());
        Assertions.assertNotNull(output.getModelIdentifier());

        Map<?, ?> result = output.getResult();
        Assertions.assertTrue(result.containsKey(SUSPENDED_KEY));
        Assertions.assertEquals("No", result.get(SUSPENDED_KEY));
        Assertions.assertNull(result.get(FINE_KEY));
    }

    /** An empty request list produces an empty, non-null output list. */
    @Test
    void explainServiceTestNoInputs() {
        List<?> outputs = (List<?>) resource.predict(Collections.emptyList()).getBody();
        Assertions.assertNotNull(outputs);
        Assertions.assertEquals(0, outputs.size());
    }

    /** An unknown model resource id yields HTTP 400 and a "not found" message body. */
    @Test
    void explainServiceFail() {
        String unknownResourceId = "unknown:model";
        PredictInput input = createInput(10);
        input.getModelIdentifier().setResourceId(unknownResourceId);

        ResponseEntity<?> responseEntity = resource.predict(Collections.singletonList(input));
        Assertions.assertEquals(HttpStatus.BAD_REQUEST.value(), responseEntity.getStatusCodeValue());
        Assertions.assertEquals("Model " + unknownResourceId + " not found.", responseEntity.getBody());
    }

    /**
     * Builds a "dmn" predict request for the Traffic Violation model.
     *
     * <p>The driver is 25 years old with 100 points. The violation is of type "speed" at an actual
     * speed of 120.
     *
     * @param speedLimit the "Speed Limit" value placed in the violation
     * @return a new, mutable {@link PredictInput}
     */
    private PredictInput createInput(int speedLimit) {
        String resourceId = String.format("%s:%s", MODEL_NAMESPACE, MODEL_NAME);

        // Value types must be Object so the nested maps can share one payload map type.
        Map<String, Object> driver = new HashMap<>();
        driver.put("Age", 25);
        driver.put("Points", 100);

        Map<String, Object> violation = new HashMap<>();
        violation.put("Type", "speed");
        violation.put("Actual Speed", 120);
        violation.put("Speed Limit", speedLimit);

        Map<String, Object> payload = new HashMap<>();
        payload.put("Driver", driver);
        payload.put("Violation", violation);

        ModelIdentifier modelIdentifier = new ModelIdentifier("dmn", resourceId);
        return new PredictInput(modelIdentifier, payload);
    }
}

