package org.kie.kogito.explainability;

import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Assertions;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.ValidationUtils;
import org.mockito.Mockito;

/**
 * Shared fixtures for explainability tests.
 *
 * <p>Provides toy {@link PredictionProvider} models whose outputs are trivially derivable from
 * their inputs, Mockito-backed {@link Feature} stubs, and a few data/assertion helpers. Every model
 * computes asynchronously via {@link CompletableFuture#supplyAsync} and emits exactly one
 * {@link PredictionOutput} per {@link PredictionInput}, in input order.
 */
public class TestUtils {

    /** Name of the feature that carries the operator for {@link #getSymbolicArithmeticModel()}. */
    private static final String OPERAND_FEATURE_NAME = "operand";

    /**
     * Builds a model that echoes one input feature back as its only output.
     *
     * @param i index of the feature to echo; every input must have more than {@code i} features
     * @return a provider emitting, per input, an output named {@code "feature-<i>"} with the
     *     selected feature's type and value and a score of 1.0
     */
    public static PredictionProvider getFeaturePassModel(int i) {
        return list -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> outputs = new LinkedList<>();
            for (PredictionInput input : list) {
                Feature feature = input.getFeatures().get(i);
                outputs.add(new PredictionOutput(
                        List.of(new Output("feature-" + i, feature.getType(), feature.getValue(), 1.0d))));
            }
            return outputs;
        });
    }

    /**
     * Builds a model that sums the numeric values of all features except one.
     *
     * @param i index of the feature to leave out of the sum
     * @return a provider emitting, per input, a NUMBER output named {@code "sum-but<i>"} with
     *     score 1.0
     */
    public static PredictionProvider getSumSkipModel(int i) {
        return list -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> outputs = new LinkedList<>();
            for (PredictionInput input : list) {
                double sum = sumAllBut(input.getFeatures(), i);
                outputs.add(new PredictionOutput(
                        List.of(new Output("sum-but" + i, Type.NUMBER, new Value(sum), 1.0d))));
            }
            return outputs;
        });
    }

    /**
     * Builds a model that reports whether one feature's numeric value is even.
     *
     * @param i index of the feature to test
     * @return a provider emitting, per input, a BOOLEAN output named {@code "feature-<i>"} that is
     *     {@code true} when {@code value % 2.0 == 0.0}, with score 1.0
     */
    public static PredictionProvider getEvenFeatureModel(int i) {
        return list -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> outputs = new LinkedList<>();
            for (PredictionInput input : list) {
                double value = input.getFeatures().get(i).getValue().asNumber();
                outputs.add(new PredictionOutput(
                        List.of(new Output("feature-" + i, Type.BOOLEAN, new Value(value % 2.0d == 0.0d), 1.0d))));
            }
            return outputs;
        });
    }

    /**
     * Builds a model that reports whether the sum of all features but one is even.
     *
     * @param i index of the feature to leave out of the sum
     * @return a provider emitting, per input, a BOOLEAN output named {@code "sum-even-but<i>"}
     *     with score 1.0
     */
    public static PredictionProvider getEvenSumModel(int i) {
        return list -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> outputs = new LinkedList<>();
            for (PredictionInput input : list) {
                double sum = sumAllBut(input.getFeatures(), i);
                // The sum is truncated toward zero by the int cast before the parity check.
                boolean even = ((int) sum) % 2 == 0;
                outputs.add(new PredictionOutput(
                        List.of(new Output("sum-even-but" + i, Type.BOOLEAN, new Value(even), 1.0d))));
            }
            return outputs;
        });
    }

    /**
     * Builds a model that checks whether the sum of all features falls within a window.
     *
     * @param d centre of the accepted window
     * @param d2 half-width of the accepted window
     * @return a provider emitting, per input, a BOOLEAN output named {@code "inside"} that is
     *     {@code true} when the sum lies in {@code [d - d2, d + d2]}; the score is
     *     {@code 1 - |sum - d|}, so it can be negative when the sum is far from the centre
     */
    public static PredictionProvider getSumThresholdModel(double d, double d2) {
        return list -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> outputs = new LinkedList<>();
            for (PredictionInput input : list) {
                double sum = 0.0d;
                for (Feature feature : input.getFeatures()) {
                    sum += feature.getValue().asNumber();
                }
                boolean inside = sum >= d - d2 && sum <= d + d2;
                outputs.add(new PredictionOutput(
                        List.of(new Output("inside", Type.BOOLEAN, new Value(inside), 1.0d - Math.abs(sum - d)))));
            }
            return outputs;
        });
    }

    /**
     * Builds a toy spam classifier for text features.
     *
     * <p>An input is flagged as spam when any of its features, read as a string and split on
     * single spaces, contains one of the blacklisted tokens {@code money}, {@code $}, {@code £} or
     * {@code bitcoin}.
     *
     * @return a provider emitting, per input, a BOOLEAN output named {@code "spam"} with score 1.0
     */
    public static PredictionProvider getDummyTextClassifier() {
        List<String> blackList = Arrays.asList("money", "$", "£", "bitcoin");
        return list -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> outputs = new LinkedList<>();
            for (PredictionInput input : list) {
                boolean spam = false;
                for (Feature feature : input.getFeatures()) {
                    if (containsAnyToken(feature.getValue().asString(), blackList)) {
                        spam = true;
                        break;
                    }
                }
                outputs.add(new PredictionOutput(
                        List.of(new Output("spam", Type.BOOLEAN, new Value(spam), 1.0d))));
            }
            return outputs;
        });
    }

    /**
     * Builds a model that folds an arithmetic operator over the numeric features.
     *
     * <p>Each input must contain a feature named {@code "operand"} whose string value is one of
     * {@code +}, {@code -}, {@code *} or {@code /}. That operator is applied, in feature order, to
     * an accumulator that starts at 0 and to each remaining feature's numeric value. Unrecognised
     * operators leave the accumulator at 0.
     *
     * <p>NOTE(review): because the accumulator starts at 0, {@code *} and {@code /} yield 0 for
     * finite non-zero operands. Confirm that this is what the callers expect.
     *
     * @return a provider emitting, per input, a NUMBER output named {@code "result"} with score
     *     1.0; the future completes exceptionally with {@link IllegalArgumentException} when an
     *     input has no {@code "operand"} feature
     */
    public static PredictionProvider getSymbolicArithmeticModel() {
        return list -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> outputs = new LinkedList<>();
            for (PredictionInput input : list) {
                List<Feature> features = input.getFeatures();
                String operand = features.stream()
                        .filter(f -> OPERAND_FEATURE_NAME.equals(f.getName()))
                        .map(f -> f.getValue().asString())
                        .findFirst()
                        .orElseThrow(() -> new IllegalArgumentException("No valid operand found in features"));
                double result = 0.0d;
                for (Feature feature : features) {
                    if (OPERAND_FEATURE_NAME.equals(feature.getName())) {
                        continue;
                    }
                    // Read the value only inside a recognised case, so unknown operators never touch it.
                    switch (operand) {
                        case "+":
                            result += feature.getValue().asNumber();
                            break;
                        case "-":
                            result -= feature.getValue().asNumber();
                            break;
                        case "*":
                            result *= feature.getValue().asNumber();
                            break;
                        case "/":
                            result /= feature.getValue().asNumber();
                            break;
                        default:
                            break;
                    }
                }
                outputs.add(new PredictionOutput(
                        List.of(new Output("result", Type.NUMBER, new Value(result), 1.0d))));
            }
            return outputs;
        });
    }

    /**
     * Builds a classifier that ignores its inputs and always answers {@code false}.
     *
     * @return a provider emitting, per input, a BOOLEAN output named {@code "class"} with value
     *     {@code false} and score 1.0
     */
    public static PredictionProvider getFixedOutputClassifier() {
        return list -> CompletableFuture.supplyAsync(() -> {
            List<PredictionOutput> outputs = new LinkedList<>();
            // Exactly one output per input; the input contents are deliberately ignored.
            for (PredictionInput ignored : list) {
                outputs.add(new PredictionOutput(
                        List.of(new Output("class", Type.BOOLEAN, new Value(false), 1.0d))));
            }
            return outputs;
        });
    }

    /**
     * Returns a mocked numeric feature whose value is 1.0.
     *
     * @return a Mockito stub as produced by {@link #getMockedNumericFeature(double)}
     */
    public static Feature getMockedNumericFeature() {
        return getMockedNumericFeature(1.0d);
    }

    /**
     * Returns a mocked feature of the given type, named {@code "f-<TYPE>"}, wrapping {@code value}.
     *
     * @param type the type reported by {@code getType()}
     * @param value the value reported by {@code getValue()}
     * @return a Mockito stub of {@link Feature}
     */
    public static Feature getMockedFeature(Type type, Value value) {
        return mockFeature(type, "f-" + type.name(), value);
    }

    /**
     * Returns a mocked TEXT feature named {@code "f-text"}.
     *
     * <p>The feature's mocked value returns {@code str} from {@code getUnderlyingObject()} and
     * {@code asString()}, and {@link Double#NaN} from {@code asNumber()}.
     *
     * @param str the text carried by the feature
     * @return a Mockito stub of {@link Feature}
     */
    public static Feature getMockedTextFeature(String str) {
        Value value = Mockito.mock(Value.class);
        Mockito.when(value.getUnderlyingObject()).thenReturn(str);
        Mockito.when(value.asNumber()).thenReturn(Double.NaN);
        Mockito.when(value.asString()).thenReturn(str);
        return mockFeature(Type.TEXT, "f-text", value);
    }

    /**
     * Returns a mocked NUMBER feature named {@code "f-num"}.
     *
     * <p>The feature's mocked value returns {@code d} (boxed) from {@code getUnderlyingObject()},
     * {@code d} from {@code asNumber()}, and {@code String.valueOf(d)} from {@code asString()}.
     *
     * @param d the numeric value carried by the feature
     * @return a Mockito stub of {@link Feature}
     */
    public static Feature getMockedNumericFeature(double d) {
        Value value = Mockito.mock(Value.class);
        Mockito.when(value.getUnderlyingObject()).thenReturn(d);
        Mockito.when(value.asNumber()).thenReturn(d);
        Mockito.when(value.asString()).thenReturn(String.valueOf(d));
        return mockFeature(Type.NUMBER, "f-num", value);
    }

    /**
     * Asserts that LIME saliency stability validation completes without throwing.
     *
     * <p>All arguments are forwarded unchanged to
     * {@code ValidationUtils.validateLocalSaliencyStability}. NOTE(review): {@code i}, {@code d}
     * and {@code d2} are presumably a top-k size and the minimum positive/negative stability
     * rates; confirm against {@code ValidationUtils}.
     *
     * @param predictionProvider the model being explained
     * @param prediction the prediction to explain
     * @param limeExplainer the explainer under test
     * @param i forwarded integer parameter
     * @param d forwarded first threshold
     * @param d2 forwarded second threshold
     */
    public static void assertLimeStability(PredictionProvider predictionProvider, Prediction prediction, LimeExplainer limeExplainer, int i, double d, double d2) {
        // Block-bodied lambda keeps this unambiguously an Executable (not a ThrowingSupplier).
        Assertions.assertDoesNotThrow(() -> {
            ValidationUtils.validateLocalSaliencyStability(predictionProvider, prediction, limeExplainer, i, d, d2);
        });
    }

    /**
     * Appends {@code i} synthetic two-column samples to {@code list} and fills their weights.
     *
     * <p>Sample {@code k} has features {@code [k % 2, (k + 1) % 2]} (as 0.0/1.0), target 0.0 when
     * {@code k % 3 == 0} and 1.0 otherwise, and weight {@code dArr[k]} of 0.2 for even {@code k}
     * and 0.8 for odd {@code k}.
     *
     * @param i number of samples to generate
     * @param list destination for the generated (features, target) pairs; appended to
     * @param dArr destination for per-sample weights; must have length of at least {@code i}
     */
    public static void fillBalancedDataForFiltering(int i, List<Pair<double[], Double>> list, double[] dArr) {
        for (int sample = 0; sample < i; sample++) {
            double[] features = new double[2];
            for (int col = 0; col < features.length; col++) {
                features[col] = (sample + col) % 2 == 0 ? 0.0d : 1.0d;
            }
            list.add(Pair.of(features, sample % 3 == 0 ? 0.0d : 1.0d));
            dArr[sample] = sample % 2 == 0 ? 0.2d : 0.8d;
        }
    }

    /** Sums the numeric values of all features except the one at {@code skipIndex}. */
    private static double sumAllBut(List<Feature> features, int skipIndex) {
        double sum = 0.0d;
        for (int j = 0; j < features.size(); j++) {
            if (j != skipIndex) {
                sum += features.get(j).getValue().asNumber();
            }
        }
        return sum;
    }

    /** Returns true if any single-space-separated token of {@code text} is in {@code tokens}. */
    private static boolean containsAnyToken(String text, List<String> tokens) {
        for (String token : text.split(" ")) {
            if (tokens.contains(token)) {
                return true;
            }
        }
        return false;
    }

    /** Creates a Mockito {@link Feature} stub reporting the given type, name and value. */
    private static Feature mockFeature(Type type, String name, Value value) {
        Feature feature = Mockito.mock(Feature.class);
        Mockito.when(feature.getType()).thenReturn(type);
        Mockito.when(feature.getName()).thenReturn(name);
        Mockito.when(feature.getValue()).thenReturn(value);
        return feature;
    }
}
