/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.engine.rpc;

import ai.djl.Model;
import ai.djl.engine.rpc.RpcClient;
import ai.djl.engine.rpc.RpcTranslator;
import ai.djl.engine.rpc.TypeConverter;
import ai.djl.inference.streaming.ChunkedBytesSupplier;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;
import java.io.IOException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * A {@link TranslatorFactory} that creates {@link RpcTranslator} instances.
 *
 * <p>Each translator pairs an {@link RpcClient}, resolved from the translator arguments, with a
 * {@link TypeConverter} that maps between the caller's input/output types and {@link Input} /
 * {@link Output}. When no converter is supplied, a reflection/JSON based default converter is used
 * and any input/output pair is accepted.
 */
public class RpcTranslatorFactory implements TranslatorFactory {

    private final TypeConverter<?, ?> converter;
    private final Set<Pair<Type, Type>> supportedTypes;

    /**
     * Creates a factory that accepts any input/output pair and builds a default converter for
     * each translator it creates.
     */
    public RpcTranslatorFactory() {
        this.converter = null;
        this.supportedTypes = Collections.emptySet();
    }

    /**
     * Creates a factory bound to a single converter.
     *
     * @param converter the converter shared by every translator this factory creates; its {@link
     *     TypeConverter#getSupportedType()} becomes the only type pair this factory supports
     * @throws NullPointerException if {@code converter} is {@code null}
     */
    public RpcTranslatorFactory(TypeConverter<?, ?> converter) {
        this.converter = Objects.requireNonNull(converter, "converter");
        this.supportedTypes = Collections.singleton(converter.getSupportedType());
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException if {@link #isSupported(Class, Class)} rejects the pair
     * @throws TranslateException if an {@link IOException} occurs while obtaining the RPC client
     */
    @Override
    public <I, O> Translator<I, O> newInstance(
            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments)
            throws TranslateException {
        if (!isSupported(input, output)) {
            throw new IllegalArgumentException("Unsupported input/output types.");
        }
        try {
            RpcClient client = RpcClient.getClient(arguments);
            return new RpcTranslator<I, O>(client, resolveConverter(input, output));
        } catch (IOException e) {
            throw new TranslateException(e);
        }
    }

    /**
     * Returns the configured converter, or a new default converter for {@code input}/{@code
     * output} when none was configured.
     */
    @SuppressWarnings("unchecked")
    private <I, O> TypeConverter<I, O> resolveConverter(Class<I> input, Class<O> output) {
        if (converter == null) {
            return new DefaultTypeConverter<>(input, output);
        }
        // Only reached after isSupported(input, output) accepted the pair, which (presumably via
        // getSupportedTypes()) matches it against the converter's declared type pair.
        return (TypeConverter<I, O>) converter;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns an empty set when no converter was configured, otherwise a singleton holding the
     * converter's supported type pair.
     */
    @Override
    public Set<Pair<Type, Type>> getSupportedTypes() {
        return supportedTypes;
    }

    /**
     * {@inheritDoc}
     *
     * <p>Without a configured converter every pair is accepted; otherwise the decision is
     * delegated to {@link TranslatorFactory}'s default implementation.
     */
    @Override
    public boolean isSupported(Class<?> input, Class<?> output) {
        if (converter == null) {
            return true;
        }
        return TranslatorFactory.super.isSupported(input, output);
    }

    /**
     * Fallback converter used when the factory has no explicit converter.
     *
     * <p>Inputs are sent as text, raw bytes or JSON. Outputs are decoded as the raw {@link Output},
     * a {@link String}, through a static {@code fromJson} method declared on the output type, or
     * with Gson.
     */
    private static final class DefaultTypeConverter<I, O> implements TypeConverter<I, O> {

        private final Class<I> input;
        private final Class<O> output;
        /** Static {@code fromJson(String)} declared on the output type, or {@code null}. */
        private final Method fromJson;
        /** Static {@code fromJson(Iterator)} declared on the output type, or {@code null}. */
        private final Method fromJsonStream;

        DefaultTypeConverter(Class<I> input, Class<O> output) {
            this.input = input;
            this.output = output;
            this.fromJsonStream = findStaticFromJson(output, Iterator.class);
            this.fromJson = findStaticFromJson(output, String.class);
        }

        /**
         * Looks up a {@code fromJson(parameterType)} method declared directly on {@code type}.
         *
         * @return the method if it exists and is static, otherwise {@code null}
         */
        private static Method findStaticFromJson(Class<?> type, Class<?> parameterType) {
            Method method;
            try {
                method = type.getDeclaredMethod("fromJson", parameterType);
            } catch (NoSuchMethodException ignored) {
                // Optional hook: without it, fromOutput falls back to the next decoding strategy.
                return null;
            }
            // The method is later invoked with a null receiver, which only works for static
            // methods; an instance method would make Method.invoke throw NullPointerException.
            return Modifier.isStatic(method.getModifiers()) ? method : null;
        }

        /** {@inheritDoc} */
        @Override
        public Pair<Type, Type> getSupportedType() {
            return new Pair<Type, Type>(input, output);
        }

        /**
         * {@inheritDoc}
         *
         * <p>An {@link Input} is passed through unchanged. A {@link String} or {@code byte[]} is
         * added as-is. Any other value is wrapped with {@link BytesSupplier#wrapAsJson}.
         */
        @Override
        public Input toInput(I in) {
            if (in instanceof Input) {
                return (Input) in;
            }
            Input converted = new Input();
            if (in instanceof String) {
                converted.add((String) in);
            } else if (in instanceof byte[]) {
                converted.add((byte[]) in);
            } else {
                converted.add(BytesSupplier.wrapAsJson(in));
            }
            return converted;
        }

        /**
         * {@inheritDoc}
         *
         * @return the decoded value, or {@code null} if a successful response carries no data
         * @throws TranslateException if the response code is not 200, or a {@code fromJson} method
         *     cannot be invoked
         */
        @Override
        @SuppressWarnings("unchecked")
        public O fromOutput(Output out) throws TranslateException {
            if (output == Output.class) {
                return (O) out;
            }
            int code = out.getCode();
            BytesSupplier data = out.getData();
            if (code != 200) {
                String error =
                        data == null
                                ? out.getMessage()
                                : out.getMessage() + " " + data.getAsString();
                throw new TranslateException(error);
            }
            if (data == null) {
                // Successful response without a payload: there is nothing to decode.
                return null;
            }
            if (output == String.class) {
                return (O) data.getAsString();
            }
            try {
                // Streamed responses are handed to fromJson(Iterator) chunk by chunk when available.
                if (data instanceof ChunkedBytesSupplier && fromJsonStream != null) {
                    ChunkIterator it = new ChunkIterator((ChunkedBytesSupplier) data);
                    return (O) fromJsonStream.invoke(null, it);
                }
                if (fromJson != null) {
                    return (O) fromJson.invoke(null, data.getAsString());
                }
            } catch (ReflectiveOperationException e) {
                throw new TranslateException("Failed convert from json", e);
            }
            return (O) JsonUtils.GSON.fromJson(data.getAsString(), output);
        }
    }

    /**
     * Adapts a {@link ChunkedBytesSupplier} to an {@code Iterator<String>}, decoding each chunk as
     * UTF-8.
     *
     * <p>If the thread is interrupted while waiting for a chunk, the iterator marks itself as
     * exhausted and {@link #next()} returns {@code null}.
     */
    private static final class ChunkIterator implements Iterator<String> {

        /** Timeout, in seconds, passed to {@link ChunkedBytesSupplier#nextChunk} for each chunk. */
        private static final long CHUNK_TIMEOUT_SECONDS = 20L;

        private final ChunkedBytesSupplier cbs;
        private boolean error;

        ChunkIterator(ChunkedBytesSupplier cbs) {
            this.cbs = cbs;
        }

        /** {@inheritDoc} */
        @Override
        public boolean hasNext() {
            if (error) {
                return false;
            }
            return cbs.hasNext();
        }

        /** {@inheritDoc} */
        @Override
        public String next() {
            try {
                byte[] chunk = cbs.nextChunk(CHUNK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
                return new String(chunk, StandardCharsets.UTF_8);
            } catch (InterruptedException e) {
                // Restore the interrupt flag so code further up the stack can still observe it,
                // and stop iteration so the consumer terminates.
                Thread.currentThread().interrupt();
                error = true;
                return null;
            }
        }
    }
}

