/*
 * Decompiled with CFR 0.152.
 */
package org.wildfly.extension.ai.chat;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import org.jboss.as.controller.AttributeDefinition;
import org.jboss.as.controller.ExpressionResolver;
import org.jboss.as.controller.OperationContext;
import org.jboss.as.controller.OperationFailedException;
import org.jboss.dmr.ModelNode;
import org.wildfly.extension.ai.AIAttributeDefinitions;
import org.wildfly.extension.ai.chat.AbstractChatModelProviderServiceConfigurator;
import org.wildfly.extension.ai.chat.GeminiChatLanguageModelProviderRegistrar;
import org.wildfly.extension.ai.injection.chat.WildFlyChatModelConfig;
import org.wildfly.extension.ai.injection.chat.WildFlyGeminiChatModelConfig;
import org.wildfly.service.capture.ValueRegistry;
import org.wildfly.subsystem.service.ResourceServiceInstaller;

public class GeminiChatModelProviderServiceConfigurator
extends AbstractChatModelProviderServiceConfigurator {
    private static final Map<String, String> HARM_CATEGORIES = Map.of(GeminiChatLanguageModelProviderRegistrar.HATE_SPEECH.getName(), "HARM_CATEGORY_HATE_SPEECH", GeminiChatLanguageModelProviderRegistrar.SEXUALLY_EXPLICIT.getName(), "HARM_CATEGORY_SEXUALLY_EXPLICIT", GeminiChatLanguageModelProviderRegistrar.DANGEROUS_CONTENT.getName(), "HARM_CATEGORY_DANGEROUS_CONTENT", GeminiChatLanguageModelProviderRegistrar.HARASSMENT.getName(), "HARM_CATEGORY_HARASSMENT", GeminiChatLanguageModelProviderRegistrar.CIVIC_INTEGRITY.getName(), "HARM_CATEGORY_CIVIC_INTEGRITY");

    public GeminiChatModelProviderServiceConfigurator(ValueRegistry<String, WildFlyChatModelConfig> registry) {
        super(registry);
    }

    public ResourceServiceInstaller configure(OperationContext context, ModelNode model) throws OperationFailedException {
        final Long connectTimeOut = AIAttributeDefinitions.CONNECT_TIMEOUT.resolveModelAttribute(context, model).asLong();
        final String key = AIAttributeDefinitions.API_KEY.resolveModelAttribute(context, model).asString();
        final Boolean allowCodeExecution = GeminiChatLanguageModelProviderRegistrar.ALLOWED_CODE_EXECUTION.resolveModelAttribute(context, model).asBooleanOrNull();
        final Boolean enableEnhancedCivicAnswers = GeminiChatLanguageModelProviderRegistrar.ENABLE_ENHANCED_CIVIC_ANSWERS.resolveModelAttribute(context, model).asBooleanOrNull();
        final Double frequencyPenalty = AIAttributeDefinitions.FREQUENCY_PENALTY.resolveModelAttribute(context, model).asDoubleOrNull();
        final Boolean includeCodeExecutionOutput = GeminiChatLanguageModelProviderRegistrar.INCLUDE_CODE_EXECUTION_OUTPUT.resolveModelAttribute(context, model).asBooleanOrNull();
        final Boolean includeThoughts = GeminiChatLanguageModelProviderRegistrar.INCLUDE_THOUGHTS.resolveModelAttribute(context, model).asBooleanOrNull();
        final Integer logProbs = GeminiChatLanguageModelProviderRegistrar.LOG_PROBS.resolveModelAttribute(context, model).asIntOrNull();
        final Boolean logRequests = AIAttributeDefinitions.LOG_REQUESTS.resolveModelAttribute(context, model).asBooleanOrNull();
        final Boolean logResponses = AIAttributeDefinitions.LOG_RESPONSES.resolveModelAttribute(context, model).asBooleanOrNull();
        final Integer maxOutputTokens = GeminiChatLanguageModelProviderRegistrar.MAX_OUTPUT_TOKEN.resolveModelAttribute(context, model).asIntOrNull();
        final String modelName = AIAttributeDefinitions.MODEL_NAME.resolveModelAttribute(context, model).asString();
        final Double presencePenalty = AIAttributeDefinitions.PRESENCE_PENALTY.resolveModelAttribute(context, model).asDoubleOrNull();
        final boolean isJson = AIAttributeDefinitions.ResponseFormat.isJson(AIAttributeDefinitions.RESPONSE_FORMAT.resolveModelAttribute(context, model).asStringOrNull());
        final Boolean responseLogprobs = GeminiChatLanguageModelProviderRegistrar.RESPONSE_LOG_PROBS.resolveModelAttribute(context, model).asBooleanOrNull();
        final Boolean returnThinking = GeminiChatLanguageModelProviderRegistrar.RETURN_THINKING.resolveModelAttribute(context, model).asBooleanOrNull();
        final Integer seed = AIAttributeDefinitions.SEED.resolveModelAttribute(context, model).asIntOrNull();
        final List stopSequences = AIAttributeDefinitions.STOP_SEQUENCES.unwrap((ExpressionResolver)context, model);
        final Boolean streaming = AIAttributeDefinitions.STREAMING.resolveModelAttribute(context, model).asBooleanOrNull();
        final Double temperature = AIAttributeDefinitions.TEMPERATURE.resolveModelAttribute(context, model).asDoubleOrNull();
        final Integer thinkingBudget = GeminiChatLanguageModelProviderRegistrar.THINKING_BUDGET.resolveModelAttribute(context, model).asIntOrNull();
        final Integer topK = GeminiChatLanguageModelProviderRegistrar.TOP_K.resolveModelAttribute(context, model).asIntOrNull();
        final Double topP = AIAttributeDefinitions.TOP_P.resolveModelAttribute(context, model).asDoubleOrNull();
        final boolean isObservable = context.getCapabilityServiceSupport().hasCapability("org.wildfly.extension.opentelemetry");
        final Map<String, String> safetySettingsConfig = this.safetySettingConfig(context, model);
        Supplier<WildFlyChatModelConfig> factory = new Supplier<WildFlyChatModelConfig>(this){
            final /* synthetic */ GeminiChatModelProviderServiceConfigurator this$0;
            {
                this.this$0 = this$0;
            }

            @Override
            public WildFlyChatModelConfig get() {
                return new WildFlyGeminiChatModelConfig().allowCodeExecution(allowCodeExecution).apiKey(key).enableEnhancedCivicAnswers(enableEnhancedCivicAnswers).frequencyPenalty(frequencyPenalty).includeCodeExecutionOutput(includeCodeExecutionOutput).includeThoughts(includeThoughts).logprobs(logProbs).logRequests(logRequests).logResponses(logResponses).maxOutputTokens(maxOutputTokens).modelName(modelName).presencePenalty(presencePenalty).responseLogprobs(responseLogprobs).returnThinking(returnThinking).seed(seed).setJson(isJson).setObservable(isObservable).safetySettings(safetySettingsConfig).setStreaming(streaming.booleanValue()).stopSequences(stopSequences).thinkingBudget(thinkingBudget).temperature(temperature).timeout(connectTimeOut.longValue()).topK(topK).topP(topP);
            }
        };
        return this.installService(context.getCurrentAddressValue(), factory);
    }

    private Map<String, String> safetySettingConfig(OperationContext context, ModelNode model) throws OperationFailedException {
        HashMap<String, String> safetySettings = new HashMap<String, String>();
        this.setSafetySettingConfig(safetySettings, (AttributeDefinition)GeminiChatLanguageModelProviderRegistrar.HATE_SPEECH, context, model);
        this.setSafetySettingConfig(safetySettings, (AttributeDefinition)GeminiChatLanguageModelProviderRegistrar.SEXUALLY_EXPLICIT, context, model);
        this.setSafetySettingConfig(safetySettings, (AttributeDefinition)GeminiChatLanguageModelProviderRegistrar.DANGEROUS_CONTENT, context, model);
        this.setSafetySettingConfig(safetySettings, (AttributeDefinition)GeminiChatLanguageModelProviderRegistrar.HARASSMENT, context, model);
        this.setSafetySettingConfig(safetySettings, (AttributeDefinition)GeminiChatLanguageModelProviderRegistrar.CIVIC_INTEGRITY, context, model);
        return safetySettings;
    }

    private void setSafetySettingConfig(Map<String, String> safetySettings, AttributeDefinition att, OperationContext context, ModelNode model) throws OperationFailedException {
        String value = att.resolveModelAttribute(context, model).asStringOrNull();
        if (value != null) {
            safetySettings.put(HARM_CATEGORIES.get(att.getName()), value);
        }
    }
}

