/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.listener;

import ai.djl.Device;
import ai.djl.metric.Dimension;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.metric.Unit;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.listener.TrainingListenerAdapter;
import ai.djl.util.cuda.CudaUtils;
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.management.MemoryUsage;
import java.lang.management.RuntimeMXBean;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TrainingListener} that samples memory usage while a model trains.
 *
 * <p>After every training and validation batch it adds the following to the trainer's {@link
 * Metrics}:
 *
 * <ul>
 *   <li>{@code Heap} / {@code NonHeap}: bytes used by the JVM heap and non-heap pools;
 *   <li>{@code cpu} / {@code rss}: process CPU percentage and resident set size in bytes, obtained
 *       by running {@code ps} (Linux and macOS only);
 *   <li>{@code GPU-<i>}: the committed value reported by {@code CudaUtils.getGpuMemory} for each
 *       GPU index below {@code CudaUtils.getGpuCount()}.
 * </ul>
 *
 * <p>Sampling only happens when the system property {@code collect-memory} is {@code true} (see
 * {@link Boolean#getBoolean(String)}). When training ends, the collected values are appended to
 * {@code memory.log} inside the output directory, if one was supplied.
 */
public class MemoryTrainingListener extends TrainingListenerAdapter {

    private static final Logger logger = LoggerFactory.getLogger(MemoryTrainingListener.class);

    /** Directory receiving {@code memory.log}; {@code null} disables the dump. */
    private final String outputDir;

    /** Creates a listener that collects memory metrics but never writes them to disk. */
    public MemoryTrainingListener() {
        this(null);
    }

    /**
     * Creates a listener that collects memory metrics and dumps them when training ends.
     *
     * @param outputDir directory that receives {@code memory.log}; {@code null} disables dumping
     */
    public MemoryTrainingListener(String outputDir) {
        this.outputDir = outputDir;
    }

    /** {@inheritDoc} */
    @Override
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        collectMemoryInfo(trainer.getMetrics());
    }

    /** {@inheritDoc} */
    @Override
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        collectMemoryInfo(trainer.getMetrics());
    }

    /** {@inheritDoc} */
    @Override
    public void onTrainingEnd(Trainer trainer) {
        dumpMemoryInfo(trainer.getMetrics(), outputDir);
    }

    /**
     * Records a snapshot of the current memory usage into {@code metrics}.
     *
     * <p>Does nothing unless {@code metrics} is non-null and the system property {@code
     * collect-memory} is {@code true}.
     *
     * @param metrics the metrics to add to; may be {@code null}
     */
    public static void collectMemoryInfo(Metrics metrics) {
        if (metrics == null || !Boolean.getBoolean("collect-memory")) {
            return;
        }
        MemoryMXBean memBean = ManagementFactory.getMemoryMXBean();
        MemoryUsage heap = memBean.getHeapMemoryUsage();
        MemoryUsage nonHeap = memBean.getNonHeapMemoryUsage();
        long heapUsed = heap.getUsed();
        long nonHeapUsed = nonHeap.getUsed();
        getProcessInfo(metrics);
        metrics.addMetric("Heap", heapUsed, Unit.BYTES, new Dimension[0]);
        metrics.addMetric("NonHeap", nonHeapUsed, Unit.BYTES, new Dimension[0]);
        int gpuCount = CudaUtils.getGpuCount();
        for (int i = 0; i < gpuCount; ++i) {
            Device device = Device.gpu(i);
            MemoryUsage mem = CudaUtils.getGpuMemory(device);
            metrics.addMetric("GPU-" + i, mem.getCommitted(), Unit.BYTES, new Dimension[0]);
        }
    }

    /**
     * Appends every collected memory metric to {@code <logDir>/memory.log}, one per line.
     *
     * <p>The directory is created if needed. I/O failures are logged rather than thrown. Does
     * nothing if either argument is {@code null}.
     *
     * @param metrics the metrics previously filled by {@link #collectMemoryInfo(Metrics)}
     * @param logDir the directory to write {@code memory.log} into
     */
    public static void dumpMemoryInfo(Metrics metrics, String logDir) {
        if (metrics == null || logDir == null) {
            return;
        }
        try {
            Path dir = Paths.get(logDir);
            Files.createDirectories(dir);
            Path file = dir.resolve("memory.log");
            try (BufferedWriter writer =
                    Files.newBufferedWriter(
                            file, StandardOpenOption.CREATE, StandardOpenOption.APPEND)) {
                ArrayList<Metric> list = new ArrayList<>();
                list.addAll(metrics.getMetric("Heap"));
                list.addAll(metrics.getMetric("NonHeap"));
                list.addAll(metrics.getMetric("cpu"));
                list.addAll(metrics.getMetric("rss"));
                int gpuCount = CudaUtils.getGpuCount();
                for (int i = 0; i < gpuCount; ++i) {
                    list.addAll(metrics.getMetric("GPU-" + i));
                }
                for (Metric metric : list) {
                    writer.append(metric.toString());
                    writer.newLine();
                }
            }
        } catch (IOException e) {
            logger.error("Failed dump memory log", e);
        }
    }

    /**
     * Adds this process's CPU percentage ({@code cpu}) and resident set size in bytes ({@code rss})
     * to {@code metrics} by running {@code ps}.
     *
     * <p>Only runs on Linux and macOS. Failures and unparsable output are logged and otherwise
     * ignored, so metric collection can never abort training.
     */
    private static void getProcessInfo(Metrics metrics) {
        String osName = System.getProperty("os.name", "");
        if (!osName.startsWith("Linux") && !osName.startsWith("Mac")) {
            return;
        }
        RuntimeMXBean mxBean = ManagementFactory.getRuntimeMXBean();
        // NOTE(review): assumes the runtime name has the form "<pid>@<host>" as on OpenJDK/HotSpot;
        // the format is not guaranteed by the RuntimeMXBean specification.
        String pid = mxBean.getName().split("@")[0];
        String[] command = {"ps", "-o", "%cpu=", "-o", "rss=", "-p", pid};
        String cmd = String.join(" ", command);
        try {
            Process process = new ProcessBuilder(command).start();
            String line;
            try (InputStream is = process.getInputStream()) {
                // ps never reads stdin.
                process.getOutputStream().close();
                line = new String(readAll(is), StandardCharsets.UTF_8).trim();
            } finally {
                // This runs once per batch: release the stderr pipe as well instead of leaving
                // its file descriptor open until the Process object is garbage collected.
                process.getErrorStream().close();
            }
            String[] tokens = line.split("\\s+");
            if (tokens.length != 2) {
                logger.error("Invalid ps output: {}", line);
                return;
            }
            float cpu;
            long rss;
            try {
                cpu = Float.parseFloat(tokens[0]);
                // ps reports rss in 1024-byte units.
                rss = Long.parseLong(tokens[1]) * 1024L;
            } catch (NumberFormatException e) {
                // e.g. a locale-specific decimal separator; don't let this escape into the
                // training loop.
                logger.error("Invalid ps output: {}", line);
                return;
            }
            metrics.addMetric("cpu", Float.valueOf(cpu), Unit.PERCENT, new Dimension[0]);
            metrics.addMetric("rss", rss, Unit.BYTES, new Dimension[0]);
        } catch (IOException e) {
            logger.error("Failed execute cmd: " + cmd, e);
        }
    }

    /** Reads {@code is} to end-of-stream and returns every byte read. */
    private static byte[] readAll(InputStream is) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        byte[] buf = new byte[8192];
        int read;
        while ((read = is.read(buf)) != -1) {
            bos.write(buf, 0, read);
        }
        return bos.toByteArray();
    }
}

