package ai.djl.translate;

import ai.djl.Application;
import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.modality.cv.translator.ImageServingTranslator;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Type;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TranslatorFactory} that creates translators for model serving, i.e. translators that
 * take an {@link Input} and produce an {@link Output}.
 *
 * <p>{@link #newInstance} picks the translator in this order:
 *
 * <ol>
 *   <li>the {@link TranslatorFactory} named by the {@code translatorFactory} argument, if it can be
 *       loaded and supports {@code Input}/{@code Output};
 *   <li>a {@link ServingTranslator} shipped with the model in its {@code libs} directory (or
 *       {@code lib} when {@code libs} is not a directory): the class named by the {@code
 *       translator} argument if given, otherwise the first one found under {@code classes} and
 *       then in the {@code .jar} files of that directory. {@code .java} sources under {@code
 *       classes} are compiled beforehand;
 *   <li>a default translator chosen from the {@code application} argument.
 * </ol>
 */
public class ServingTranslatorFactory implements TranslatorFactory {

    private static final Logger logger = LoggerFactory.getLogger(ServingTranslatorFactory.class);

    /** {@inheritDoc} */
    @Override
    public Set<Pair<Type, Type>> getSupportedTypes() {
        return Collections.singleton(new Pair<>(Input.class, Output.class));
    }

    /**
     * {@inheritDoc}
     *
     * <p>Recognized arguments are {@code translatorFactory}, {@code translator} and {@code
     * application} (see the class documentation). The argument map is passed along to whichever
     * translator or factory ends up being used.
     *
     * @throws IllegalArgumentException if the requested types are not supported by this factory
     * @throws TranslateException if the {@code translator} argument is set but no {@link
     *     ServingTranslator} with that name could be loaded, or if the delegated {@code
     *     translatorFactory} throws it
     */
    @Override
    @SuppressWarnings("unchecked") // isSupported() guarantees I = Input and O = Output
    public <I, O> Translator<I, O> newInstance(
            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments)
            throws TranslateException {
        if (!isSupported(input, output)) {
            throw new IllegalArgumentException("Unsupported input/output types.");
        }

        String factoryClass = ArgumentsUtil.stringValue(arguments, "translatorFactory");
        if (factoryClass != null && !factoryClass.isEmpty()) {
            TranslatorFactory factory = loadTranslatorFactory(factoryClass);
            if (factory != null && factory.isSupported(input, output)) {
                logger.info("Using TranslatorFactory: {}", factory.getClass().getName());
                return factory.newInstance(input, output, model, arguments);
            }
        }

        // Read like the other arguments so a non-String value cannot cause a ClassCastException.
        String translatorClass = ArgumentsUtil.stringValue(arguments, "translator");
        Path modelPath = model.getModelPath();
        Path libPath = modelPath.resolve("libs");
        if (!Files.isDirectory(libPath)) {
            libPath = modelPath.resolve("lib");
            if (!Files.isDirectory(libPath) && translatorClass == null) {
                return (Translator<I, O>) loadDefaultTranslator(arguments);
            }
        }

        ServingTranslator translator = findTranslator(libPath, translatorClass);
        if (translator != null) {
            translator.setArguments(arguments);
            logger.info("Using translator: {}", translator.getClass().getName());
            return (Translator<I, O>) translator;
        }
        if (translatorClass != null) {
            throw new TranslateException("Failed to load translator: " + translatorClass);
        }
        return (Translator<I, O>) loadDefaultTranslator(arguments);
    }

    /**
     * Looks for a {@link ServingTranslator} bundled with the model.
     *
     * <p>{@code .java} files under {@code libPath/classes} are compiled first. A class loader over
     * {@code libPath/classes} and every {@code .jar} directly inside {@code libPath} (parent: the
     * context class loader) is then searched.
     *
     * @param libPath the model's {@code libs} or {@code lib} directory; it may not exist
     * @param className the translator class to load, or {@code null}/empty to scan for one
     * @return the translator, or {@code null} if none was found or an I/O error occurred
     */
    private ServingTranslator findTranslator(Path libPath, String className) {
        Path classesDir = libPath.resolve("classes");
        compileJavaClass(classesDir);

        URLClassLoader classLoader = null;
        ServingTranslator translator = null;
        try {
            List<Path> jarFiles = listJarFiles(libPath);
            List<URL> urls = new ArrayList<>(jarFiles.size() + 1);
            urls.add(classesDir.toUri().toURL());
            for (Path jar : jarFiles) {
                urls.add(jar.toUri().toURL());
            }
            classLoader =
                    new URLClassLoader(
                            urls.toArray(new URL[0]), ClassLoaderUtils.getContextClassLoader());
            translator = searchTranslator(classLoader, classesDir, jarFiles, className);
            return translator;
        } catch (IOException e) {
            logger.debug("Failed to find Translator", e);
            return null;
        } finally {
            // A found translator keeps using the loader it came from, so only an unused loader is
            // closed (it may otherwise keep the model's jar files open).
            if (translator == null && classLoader != null) {
                closeQuietly(classLoader);
            }
        }
    }

    /**
     * Returns the {@code .jar} files directly inside {@code dir}, or an empty list when {@code
     * dir} is not a directory.
     */
    private static List<Path> listJarFiles(Path dir) throws IOException {
        if (!Files.isDirectory(dir)) {
            return Collections.emptyList();
        }
        try (Stream<Path> stream = Files.list(dir)) {
            return stream.filter(p -> p.toString().endsWith(".jar")).collect(Collectors.toList());
        }
    }

    /**
     * Loads {@code className} when it is non-empty; otherwise scans {@code classesDir} and then
     * each of {@code jarFiles} for the first {@link ServingTranslator}.
     */
    private ServingTranslator searchTranslator(
            ClassLoader classLoader, Path classesDir, List<Path> jarFiles, String className)
            throws IOException {
        if (className != null && !className.isEmpty()) {
            logger.info("Trying to loading specified Translator: {}", className);
            return initTranslator(classLoader, className);
        }
        ServingTranslator translator = scanDirectory(classLoader, classesDir);
        if (translator != null) {
            return translator;
        }
        for (Path jar : jarFiles) {
            translator = scanJarFile(classLoader, jar);
            if (translator != null) {
                return translator;
            }
        }
        return null;
    }

    /** Closes an unused class loader, logging (not propagating) any failure. */
    private static void closeQuietly(URLClassLoader classLoader) {
        try {
            classLoader.close();
        } catch (IOException e) {
            logger.debug("Failed to close class loader", e);
        }
    }

    /**
     * Tries every {@code .class} file under {@code dir} and returns the first one that can be
     * instantiated as a {@link ServingTranslator}.
     *
     * @return the translator, or {@code null} if {@code dir} is not a directory or has none
     * @throws IOException if walking the directory fails
     */
    private ServingTranslator scanDirectory(ClassLoader classLoader, Path dir) throws IOException {
        if (!Files.isDirectory(dir)) {
            logger.debug("Directory not exists: {}", dir);
            return null;
        }
        List<Path> classFiles;
        try (Stream<Path> stream = Files.walk(dir)) {
            classFiles =
                    stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
                            .collect(Collectors.toList());
        }
        for (Path classFile : classFiles) {
            String className =
                    toClassName(dir.relativize(classFile).toString(), File.separatorChar);
            ServingTranslator translator = initTranslator(classLoader, className);
            if (translator != null) {
                logger.info("Found translator in model directory: {}", className);
                return translator;
            }
        }
        return null;
    }

    /**
     * Tries every {@code .class} entry of the jar at {@code jarPath} and returns the first one
     * that can be instantiated as a {@link ServingTranslator}, or {@code null} if there is none.
     *
     * @throws IOException if the jar cannot be opened
     */
    private ServingTranslator scanJarFile(ClassLoader classLoader, Path jarPath)
            throws IOException {
        try (JarFile jarFile = new JarFile(jarPath.toFile())) {
            Enumeration<JarEntry> entries = jarFile.entries();
            while (entries.hasMoreElements()) {
                String name = entries.nextElement().getName();
                if (!name.endsWith(".class")) {
                    continue;
                }
                // Jar entry names always use '/' regardless of platform.
                String className = toClassName(name, '/');
                ServingTranslator translator = initTranslator(classLoader, className);
                if (translator != null) {
                    logger.info("Found translator {} in jar {}", className, jarPath);
                    return translator;
                }
            }
        }
        return null;
    }

    /**
     * Converts a relative {@code .class} file path into a class name, e.g. {@code a/b/C.class}
     * into {@code a.b.C}.
     *
     * @param classFile the relative path, ending in {@code .class}
     * @param separator the path separator used in {@code classFile}
     */
    private static String toClassName(String classFile, char separator) {
        return classFile.substring(0, classFile.lastIndexOf('.')).replace(separator, '.');
    }

    /**
     * Instantiates {@code className} as a {@link TranslatorFactory} through its public no-arg
     * constructor.
     *
     * @return the factory, or {@code null} if it cannot be loaded or instantiated
     */
    private TranslatorFactory loadTranslatorFactory(String className) {
        try {
            return Class.forName(className)
                    .asSubclass(TranslatorFactory.class)
                    .getConstructor()
                    .newInstance();
        } catch (Throwable e) {
            // Best effort: any failure just means the factory is not used.
            logger.trace("Not able to load TranslatorFactory: {}", className, e);
            return null;
        }
    }

    /**
     * Instantiates {@code className}, loaded through {@code classLoader}, as a {@link
     * ServingTranslator} through its public no-arg constructor.
     *
     * @return the translator, or {@code null} if the class is missing, is not a {@code
     *     ServingTranslator}, or cannot be instantiated
     */
    private ServingTranslator initTranslator(ClassLoader classLoader, String className) {
        try {
            // initialize=false: the directory/jar scans probe every class, so do not run static
            // initializers of classes that turn out not to be translators. Creating an instance
            // below initializes the class when it is one.
            Class<? extends ServingTranslator> clazz =
                    Class.forName(className, false, classLoader)
                            .asSubclass(ServingTranslator.class);
            return clazz.getConstructor().newInstance();
        } catch (Throwable e) {
            // Expected for most scanned classes; parameterized so nothing is built unless traced.
            logger.trace("Not able to load Translator: {}", className, e);
            return null;
        }
    }

    /**
     * Returns the fallback translator: an image-classification serving translator when the {@code
     * application} argument resolves to {@link Application.CV#IMAGE_CLASSIFICATION}, otherwise the
     * translator produced by {@link NoopServingTranslatorFactory}.
     */
    private Translator<Input, Output> loadDefaultTranslator(Map<String, ?> arguments) {
        String application = ArgumentsUtil.stringValue(arguments, "application");
        if (application != null
                && Application.CV.IMAGE_CLASSIFICATION.equals(Application.of(application))) {
            return getImageClassificationTranslator(arguments);
        }
        return new NoopServingTranslatorFactory()
                .newInstance(Input.class, Output.class, null, arguments);
    }

    /** Wraps an {@link ImageClassificationTranslator} built from {@code arguments} for serving. */
    private Translator<Input, Output> getImageClassificationTranslator(Map<String, ?> arguments) {
        return new ImageServingTranslator(ImageClassificationTranslator.builder(arguments).build());
    }

    /**
     * Compiles all {@code .java} files under {@code dir} with the system Java compiler. Without a
     * {@code -d} option javac writes each class file next to its source file.
     *
     * <p>Best effort: every failure is logged and otherwise ignored.
     */
    private void compileJavaClass(Path dir) {
        try {
            if (!Files.isDirectory(dir)) {
                logger.debug("Directory not exists: {}", dir);
                return;
            }
            String[] sources;
            try (Stream<Path> stream = Files.walk(dir)) {
                sources =
                        stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".java"))
                                .map(p -> p.toAbsolutePath().toString())
                                .toArray(String[]::new);
            }
            if (sources.length == 0) {
                return;
            }
            JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
            if (compiler == null) {
                // getSystemJavaCompiler() returns null when the runtime provides no compiler (JRE).
                logger.warn("Failed to compile bundled java file: no system Java compiler");
                return;
            }
            // null streams: javac uses System.in/out/err; a non-zero status means errors occurred.
            int status = compiler.run(null, null, null, sources);
            if (status != 0) {
                logger.warn("Failed to compile bundled java file, javac status: {}", status);
            }
        } catch (Throwable e) {
            logger.warn("Failed to compile bundled java file", e);
        }
    }
}
