package org.kie.pmml.models.regression.compiler.factories;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.body.VariableDeclarator;
import com.github.javaparser.ast.expr.CastExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.MethodReferenceExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.Statement;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
import com.github.javaparser.ast.type.Type;
import java.lang.invoke.SerializedLambda;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.regression.RegressionModel;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.exceptions.KiePMMLInternalException;
import org.kie.pmml.api.iinterfaces.SerializableFunction;
import org.kie.pmml.commons.Constants;
import org.kie.pmml.compiler.commons.utils.CommonCodegenUtils;
import org.kie.pmml.compiler.commons.utils.JavaParserUtils;
import org.kie.pmml.models.regression.compiler.dto.RegressionCompilationDTO;
import org.kie.pmml.models.regression.model.KiePMMLClassificationTable;
import org.kie.pmml.models.regression.model.KiePMMLRegressionTable;
import org.kie.pmml.models.regression.model.enums.REGRESSION_NORMALIZATION_METHOD;
import org.kie.pmml.models.regression.model.tuples.KiePMMLTableSourceCategory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:BOOT-INF/lib/kie-pmml-models-regression-compiler-8.38.1-SNAPSHOT.jar:org/kie/pmml/models/regression/compiler/factories/KiePMMLClassificationTableFactory.class */
/**
 * Factory for the classification table of a PMML regression model.
 * <p>
 * It can either build a runtime {@link KiePMMLClassificationTable} directly
 * ({@link #getClassificationTable}), or generate the Java source of a class derived from the
 * {@code KiePMMLClassificationTableTemplate.tmpl} template ({@link #getClassificationTableBuilders}).
 */
public class KiePMMLClassificationTableFactory {
    static final String GETKIEPMML_TABLE = "getKiePMMLTable";
    static final String CATEGORICAL_TABLE_MAP = "categoryTableMap";

    /** Normalization methods for which a probability-map function is available (unmodifiable). */
    public static final List<RegressionModel.NormalizationMethod> SUPPORTED_NORMALIZATION_METHODS =
            Collections.unmodifiableList(Arrays.asList(RegressionModel.NormalizationMethod.SOFTMAX,
                                                       RegressionModel.NormalizationMethod.SIMPLEMAX,
                                                       RegressionModel.NormalizationMethod.NONE,
                                                       RegressionModel.NormalizationMethod.LOGIT,
                                                       RegressionModel.NormalizationMethod.PROBIT,
                                                       RegressionModel.NormalizationMethod.CLOGLOG,
                                                       RegressionModel.NormalizationMethod.CAUCHIT));

    /** Normalization methods explicitly rejected with a {@link KiePMMLInternalException} (unmodifiable). */
    public static final List<RegressionModel.NormalizationMethod> UNSUPPORTED_NORMALIZATION_METHODS =
            Collections.unmodifiableList(Arrays.asList(RegressionModel.NormalizationMethod.EXP,
                                                       RegressionModel.NormalizationMethod.LOGLOG));

    /** Counter used to give every generated classification-table class a distinct name. */
    private static final AtomicInteger classArity = new AtomicInteger(0);
    private static final Logger logger = LoggerFactory.getLogger(KiePMMLClassificationTableFactory.class.getName());
    static final String KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE_JAVA = "KiePMMLClassificationTableTemplate.tmpl";
    static final String KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE = "KiePMMLClassificationTableTemplate";
    static final ClassOrInterfaceDeclaration CLASSIFICATION_TABLE_TEMPLATE =
            JavaParserUtils.getFromFileName(KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE_JAVA)
                    .getClassByName(KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE)
                    .orElseThrow(() -> new KiePMMLException("Main class not found: " + KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE));

    static {
        // Fail fast at class-load time if the template lacks the static getter that code generation rewrites.
        if (CLASSIFICATION_TABLE_TEMPLATE.getMethodsByName(GETKIEPMML_TABLE).isEmpty()) {
            throw new KiePMMLException(String.format("Missing method %s in %s", GETKIEPMML_TABLE,
                                                     KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE));
        }
    }

    private KiePMMLClassificationTableFactory() {
        // Static factory: not instantiable.
    }

    /**
     * Builds a runtime classification table from the given compilation DTO.
     * <p>
     * The table gets a random UUID as its name and an empty list as the builder's second argument.
     * Its category map holds the regression tables returned by
     * {@code KiePMMLRegressionTableFactory.getRegressionTables}.
     *
     * @param regressionCompilationDTO source of the regression tables, normalization methods, op type and target field
     * @return the built {@link KiePMMLClassificationTable}
     * @throws KiePMMLInternalException if the model normalization method is in {@link #UNSUPPORTED_NORMALIZATION_METHODS}
     */
    public static KiePMMLClassificationTable getClassificationTable(final RegressionCompilationDTO regressionCompilationDTO) {
        logger.trace("getClassificationTable {}", regressionCompilationDTO);
        final LinkedHashMap<String, KiePMMLRegressionTable> regressionTables =
                KiePMMLRegressionTableFactory.getRegressionTables(regressionCompilationDTO);
        final boolean isBinary = regressionCompilationDTO.isBinary(regressionTables.size());
        return (KiePMMLClassificationTable) KiePMMLClassificationTable.builder(UUID.randomUUID().toString(), Collections.emptyList())
                .withRegressionNormalizationMethod(regressionCompilationDTO.getDefaultREGRESSION_NORMALIZATION_METHOD())
                .withOpType(regressionCompilationDTO.getOP_TYPE())
                .withCategoryTableMap(regressionTables)
                .withProbabilityMapFunction(getProbabilityMapFunction(regressionCompilationDTO.getModelNormalizationMethod(), isBinary))
                .withIsBinary(isBinary)
                .withTargetField(regressionCompilationDTO.getTargetFieldName())
                .build();
    }

    /**
     * Generates the sources of the regression tables and of the classification table that wraps them.
     *
     * @param regressionCompilationDTO the compilation context
     * @return the regression-table sources from {@code KiePMMLRegressionTableFactory.getRegressionTableBuilders},
     * plus one extra entry for the classification table. That entry is keyed by its full class name and has an
     * empty category.
     */
    public static Map<String, KiePMMLTableSourceCategory> getClassificationTableBuilders(final RegressionCompilationDTO regressionCompilationDTO) {
        logger.trace("getClassificationTableBuilders {}", regressionCompilationDTO.getRegressionTables());
        final LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTableBuilders =
                KiePMMLRegressionTableFactory.getRegressionTableBuilders(regressionCompilationDTO);
        final Map.Entry<String, String> classificationTableBuilder =
                getClassificationTableBuilder(regressionCompilationDTO, regressionTableBuilders);
        regressionTableBuilders.put(classificationTableBuilder.getKey(),
                                    new KiePMMLTableSourceCategory(classificationTableBuilder.getValue(), ""));
        return regressionTableBuilders;
    }

    /**
     * Generates the source of a new {@code KiePMMLClassificationTable<N>} class from the template.
     * <p>
     * N comes from a class-wide counter. The template's {@code getKiePMMLTable} method is rewritten to return
     * a table whose category map references the given regression-table classes.
     *
     * @param regressionCompilationDTO the compilation context (package name, normalization, op type, target field)
     * @param regressionTablesMap      regression-table full class names mapped to their source/category
     * @return an entry mapping the generated class's full name to its source code
     * @throws KiePMMLException if the generated compilation unit lacks the expected class or method
     */
    public static Map.Entry<String, String> getClassificationTableBuilder(final RegressionCompilationDTO regressionCompilationDTO,
                                                                          final LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap) {
        logger.trace("getClassificationTableBuilder {}", regressionTablesMap);
        final String className = "KiePMMLClassificationTable" + classArity.incrementAndGet();
        final CompilationUnit compilationUnit = JavaParserUtils.getKiePMMLModelCompilationUnit(className,
                                                                                               regressionCompilationDTO.getPackageName(),
                                                                                               KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE_JAVA,
                                                                                               KIE_PMML_CLASSIFICATION_TABLE_TEMPLATE);
        final ClassOrInterfaceDeclaration tableClass = compilationUnit.getClassByName(className)
                .orElseThrow(() -> new KiePMMLException("Main class not found: " + className));
        final MethodDeclaration staticGetter = tableClass.getMethodsByName(GETKIEPMML_TABLE).stream()
                .findFirst()
                .orElseThrow(() -> new KiePMMLException(String.format("Missing method %s in %s", GETKIEPMML_TABLE, className)));
        // Locale.ROOT: a default-locale lower-casing (e.g. Turkish) would turn 'I' into a dotless 'ı'.
        setStaticGetter(regressionCompilationDTO, regressionTablesMap, staticGetter, className.toLowerCase(Locale.ROOT));
        return new AbstractMap.SimpleEntry<>(JavaParserUtils.getFullClassName(compilationUnit), compilationUnit.toString());
    }

    /**
     * Returns the runtime probability-map function for the given normalization method.
     *
     * @param normalizationMethod the model normalization method (must not be {@code null})
     * @param isBinary            whether the table is binary; only affects {@code NONE}
     * @return a method reference into {@link KiePMMLClassificationTable}
     * @throws KiePMMLInternalException if the method is in {@link #UNSUPPORTED_NORMALIZATION_METHODS}
     * @throws KiePMMLException         if the method is otherwise unknown
     */
    static SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>> getProbabilityMapFunction(final RegressionModel.NormalizationMethod normalizationMethod,
                                                                                                                         final boolean isBinary) {
        if (UNSUPPORTED_NORMALIZATION_METHODS.contains(normalizationMethod)) {
            throw new KiePMMLInternalException(String.format("Unsupported NormalizationMethod %s", normalizationMethod));
        }
        return getProbabilityMapFunctionSupported(normalizationMethod, isBinary);
    }

    /**
     * Maps a supported normalization method to the matching {@code get*ProbabilityMap} static method of
     * {@link KiePMMLClassificationTable}.
     *
     * @throws NullPointerException if {@code normalizationMethod} is {@code null}
     * @throws KiePMMLException     for any method not handled below
     */
    static SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>> getProbabilityMapFunctionSupported(final RegressionModel.NormalizationMethod normalizationMethod,
                                                                                                                                  final boolean isBinary) {
        Objects.requireNonNull(normalizationMethod, "normalizationMethod");
        switch (normalizationMethod) {
            case SOFTMAX:
                return KiePMMLClassificationTable::getSOFTMAXProbabilityMap;
            case SIMPLEMAX:
                return KiePMMLClassificationTable::getSIMPLEMAXProbabilityMap;
            case NONE:
                return isBinary ? KiePMMLClassificationTable::getNONEBinaryProbabilityMap : KiePMMLClassificationTable::getNONEProbabilityMap;
            case LOGIT:
                return KiePMMLClassificationTable::getLOGITProbabilityMap;
            case PROBIT:
                return KiePMMLClassificationTable::getPROBITProbabilityMap;
            case CLOGLOG:
                return KiePMMLClassificationTable::getCLOGLOGProbabilityMap;
            case CAUCHIT:
                return KiePMMLClassificationTable::getCAUCHITProbabilityMap;
            default:
                throw new KiePMMLException("Unexpected NormalizationMethod " + normalizationMethod);
        }
    }

    /**
     * Rewrites the body of the template's static getter so that it builds a fully configured classification table.
     * <p>
     * First, a populated {@code LinkedHashMap} named {@code categoryTableMap_<variableName>} is declared. It maps each
     * regression-table category to a {@code <tableClassName>.getKiePMMLTable()} call. The original body's statements
     * follow. The arguments of the builder chain that initializes {@link Constants#TO_RETURN} are also replaced.
     *
     * @param regressionCompilationDTO the compilation context
     * @param regressionTablesMap      regression-table full class names mapped to their source/category
     * @param staticGetter             the template's {@code getKiePMMLTable} method; modified in place
     * @param variableName             used as the table name and as the suffix of the map variable
     * @throws KiePMMLException if the method has no body, or no {@code TO_RETURN} variable with an initializer
     */
    static void setStaticGetter(final RegressionCompilationDTO regressionCompilationDTO,
                                final LinkedHashMap<String, KiePMMLTableSourceCategory> regressionTablesMap,
                                final MethodDeclaration staticGetter,
                                final String variableName) {
        final BlockStmt originalBody = staticGetter.getBody()
                .orElseThrow(() -> new KiePMMLException(String.format("Missing body in %s", staticGetter)));
        final VariableDeclarator toReturnDeclarator = CommonCodegenUtils.getVariableDeclarator(originalBody, Constants.TO_RETURN)
                .orElseThrow(() -> new KiePMMLException(String.format("Missing expected variable '%s' in body %s", Constants.TO_RETURN, originalBody)));
        final BlockStmt newBody = new BlockStmt();

        // category -> <tableClassName>.getKiePMMLTable()
        final LinkedHashMap<String, Expression> categoryTableCalls = new LinkedHashMap<>();
        regressionTablesMap.forEach((tableClassName, tableSourceCategory) ->
                categoryTableCalls.put(tableSourceCategory.getCategory(),
                                       new MethodCallExpr(new NameExpr(tableClassName), GETKIEPMML_TABLE)));
        final String categoryTableMapName = String.format("%s_%s", CATEGORICAL_TABLE_MAP, variableName);
        CommonCodegenUtils.createPopulatedLinkedHashMap(newBody,
                                                        categoryTableMapName,
                                                        Arrays.asList(String.class.getSimpleName(), KiePMMLRegressionTable.class.getName()),
                                                        categoryTableCalls);

        final MethodCallExpr builderChain = toReturnDeclarator.getInitializer()
                .orElseThrow(() -> new KiePMMLException(String.format("Missing '%s' initializer in %s", Constants.TO_RETURN, originalBody)))
                .asMethodCallExpr();
        final boolean isBinary = regressionCompilationDTO.isBinary(regressionTablesMap.size());
        CommonCodegenUtils.getChainedMethodCallExprFrom("builder", builderChain)
                .setArgument(0, new StringLiteralExpr(variableName));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withRegressionNormalizationMethod", builderChain)
                .setArgument(0, getEnumConstantExpression(regressionCompilationDTO.getDefaultREGRESSION_NORMALIZATION_METHOD()));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withOpType", builderChain)
                .setArgument(0, getEnumConstantExpression(regressionCompilationDTO.getOP_TYPE()));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withCategoryTableMap", builderChain)
                .setArgument(0, new NameExpr(categoryTableMapName));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withProbabilityMapFunction", builderChain)
                .setArgument(0, getProbabilityMapFunctionExpression(regressionCompilationDTO.getModelNormalizationMethod(), isBinary));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withIsBinary", builderChain)
                .setArgument(0, CommonCodegenUtils.getExpressionForObject(isBinary));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withTargetField", builderChain)
                .setArgument(0, CommonCodegenUtils.getExpressionForObject(regressionCompilationDTO.getTargetFieldName()));
        CommonCodegenUtils.getChainedMethodCallExprFrom("withTargetCategory", builderChain)
                .setArgument(0, CommonCodegenUtils.getExpressionForObject(null));

        // The map declaration must precede the original statements that reference it.
        originalBody.getStatements().forEach(newBody::addStatement);
        staticGetter.setBody(newBody);
    }

    /**
     * Returns a {@code SimpleEnumName.CONSTANT} reference expression for {@code enumConstant}.
     * <p>
     * For {@code null}, it returns the expression that {@code CommonCodegenUtils.getExpressionForObject(null)} produces.
     */
    private static Expression getEnumConstantExpression(final Enum<?> enumConstant) {
        if (enumConstant == null) {
            return CommonCodegenUtils.getExpressionForObject(null);
        }
        // getDeclaringClass(), not getClass(): a constant with its own body is an anonymous subclass whose simple name is "".
        return new NameExpr(enumConstant.getDeclaringClass().getSimpleName() + "." + enumConstant.name());
    }

    /**
     * Returns the source expression of the probability-map function for the given normalization method.
     *
     * @throws KiePMMLInternalException if the method is in {@link #UNSUPPORTED_NORMALIZATION_METHODS}
     * @throws KiePMMLException         if the method is otherwise unknown
     */
    static Expression getProbabilityMapFunctionExpression(final RegressionModel.NormalizationMethod normalizationMethod,
                                                          final boolean isBinary) {
        if (UNSUPPORTED_NORMALIZATION_METHODS.contains(normalizationMethod)) {
            throw new KiePMMLInternalException(String.format("Unsupported NormalizationMethod %s", normalizationMethod));
        }
        return getProbabilityMapFunctionSupportedExpression(normalizationMethod, isBinary);
    }

    /**
     * Builds a cast method reference to the matching static method of {@code KiePMMLClassificationTable}, for example
     * {@code (SerializableFunction<LinkedHashMap<String, Double>, LinkedHashMap<String, Double>>) KiePMMLClassificationTable::getSOFTMAXProbabilityMap}.
     * <p>
     * The generated scope is the simple name {@code KiePMMLClassificationTable}, so the generated source must have
     * that class in scope.
     *
     * @throws NullPointerException if {@code normalizationMethod} is {@code null}
     * @throws KiePMMLException     if the method is not in {@link #SUPPORTED_NORMALIZATION_METHODS}, mirroring
     *                              {@link #getProbabilityMapFunctionSupported}
     */
    static MethodReferenceExpr getProbabilityMapFunctionSupportedExpression(final RegressionModel.NormalizationMethod normalizationMethod,
                                                                            final boolean isBinary) {
        Objects.requireNonNull(normalizationMethod, "normalizationMethod");
        if (!SUPPORTED_NORMALIZATION_METHODS.contains(normalizationMethod)) {
            // Without this check the generated source would reference a non-existent get<METHOD>ProbabilityMap.
            throw new KiePMMLException("Unexpected NormalizationMethod " + normalizationMethod);
        }
        String methodSuffix = normalizationMethod.name();
        if (RegressionModel.NormalizationMethod.NONE == normalizationMethod && isBinary) {
            methodSuffix = methodSuffix + "Binary";
        }
        final String methodName = String.format("get%sProbabilityMap", methodSuffix);

        // Two distinct AST nodes for the two type arguments, so one node is never attached to two parents.
        final String mapTypeName = LinkedHashMap.class.getCanonicalName();
        final List<String> mapTypeArguments = Arrays.asList(String.class.getSimpleName(), Double.class.getSimpleName());
        final ClassOrInterfaceType inputType = CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(mapTypeName, mapTypeArguments);
        final ClassOrInterfaceType outputType = CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypeNames(mapTypeName, mapTypeArguments);

        final CastExpr castExpr = new CastExpr();
        castExpr.setType(CommonCodegenUtils.getTypedClassOrInterfaceTypeByTypes(SerializableFunction.class.getCanonicalName(),
                                                                                Arrays.asList(inputType, outputType)));
        castExpr.setExpression("KiePMMLClassificationTable");

        final MethodReferenceExpr methodReferenceExpr = new MethodReferenceExpr();
        methodReferenceExpr.setScope(castExpr);
        methodReferenceExpr.setIdentifier(methodName);
        return methodReferenceExpr;
    }
}
