package ai.djl.training;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.training.dataset.Batch;
import ai.djl.training.listener.TrainingListener;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;

/* loaded from: input_file:ai/djl/training/ParallelTrain.class */
/**
 * Trains a single {@link Batch} by splitting it across the trainer's devices and running each
 * slice's forward pass, loss evaluation and backward pass as a separate task on an internal
 * thread pool.
 *
 * <p>Instances own a fixed-size thread pool. Call {@link #close()} when training is finished to
 * release it. The pool threads are daemon threads, so an instance that is never closed does not
 * prevent the JVM from exiting.
 *
 * <p>NOTE(review): each task opens its own gradient collector concurrently with the others. This
 * assumes the underlying engine tolerates one collector per thread; confirm for each engine.
 */
public final class ParallelTrain implements AutoCloseable {
    private final ExecutorService executor;

    /**
     * Creates a parallel-training helper with one worker thread per device.
     *
     * @param deviceArr the devices training is spread across; only its length is used, to size the
     *     worker pool
     * @throws NullPointerException if {@code deviceArr} is null
     * @throws IllegalArgumentException if {@code deviceArr} is empty
     */
    public ParallelTrain(Device[] deviceArr) {
        Objects.requireNonNull(deviceArr, "deviceArr");
        if (deviceArr.length == 0) {
            throw new IllegalArgumentException(
                    "At least one device is required for parallel training.");
        }
        this.executor = newDaemonPool(deviceArr.length);
    }

    /**
     * Runs one training step on {@code batch}.
     *
     * <p>The batch is split across {@code trainer.getDevices()}, and each split is trained on a
     * worker thread. This method blocks until every split has finished. Training listeners are
     * notified via {@code onTrainingBatch} only if every split succeeded.
     *
     * @param trainer the trainer providing devices, data manager, loss and gradient collectors
     * @param batch the batch to train on; it must be on the same engine as {@code trainer}
     * @throws IllegalArgumentException if the batch and the trainer use different engines
     * @throws IllegalStateException if the calling thread is interrupted while waiting, or if a
     *     split fails with a checked exception (the failure is attached as the cause)
     * @throws RuntimeException an unchecked exception thrown by a split is rethrown unchanged;
     *     failures from other splits are attached as suppressed exceptions
     */
    public void trainBatch(Trainer trainer, Batch batch) {
        if (trainer.getManager().getEngine() != batch.getManager().getEngine()) {
            throw new IllegalArgumentException(
                    "The data must be on the same engine as the trainer. You may need to change one of your NDManagers.");
        }
        Batch[] splits = batch.split(trainer.getDevices(), false);
        // Concurrent maps: every worker thread writes its own device's labels/predictions.
        TrainingListener.BatchData batchData =
                new TrainingListener.BatchData(
                        batch, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());

        List<Future<Boolean>> futures = new ArrayList<>(splits.length);
        for (Batch split : splits) {
            futures.add(
                    executor.submit(
                            () -> {
                                try (GradientCollector collector =
                                        trainer.newGradientCollector()) {
                                    NDList data = trainer.getDataManager().getData(split);
                                    NDList labels = trainer.getDataManager().getLabels(split);
                                    NDList preds = trainer.forward(data);

                                    // The "backward" window starts before the loss is
                                    // evaluated, so it includes loss evaluation time.
                                    long time = System.nanoTime();
                                    collector.backward(trainer.getLoss().evaluate(labels, preds));
                                    trainer.addMetric("backward", time);

                                    time = System.nanoTime();
                                    batchData.getLabels().put(labels.get(0).getDevice(), labels);
                                    batchData
                                            .getPredictions()
                                            .put(preds.get(0).getDevice(), preds);
                                    trainer.addMetric("training-metrics", time);
                                }
                                return true;
                            }));
        }
        awaitAll(futures);

        trainer.notifyListeners(listener -> listener.onTrainingBatch(trainer, batchData));
    }

    /**
     * Shuts down the worker pool. Tasks that are already submitted are allowed to finish, and no
     * new batches can be trained afterwards.
     */
    @Override
    public void close() {
        executor.shutdown();
    }

    /** Creates a fixed-size pool whose threads are daemons, so a leaked pool cannot block exit. */
    private static ExecutorService newDaemonPool(int threads) {
        ThreadFactory base = Executors.defaultThreadFactory();
        return Executors.newFixedThreadPool(
                threads,
                runnable -> {
                    Thread thread = base.newThread(runnable);
                    thread.setDaemon(true);
                    return thread;
                });
    }

    /**
     * Waits for every task to finish and then rethrows the first failure. Any later failures are
     * attached to it as suppressed exceptions.
     */
    private static void awaitAll(List<Future<Boolean>> futures) {
        Throwable failure = null;
        for (Future<Boolean> future : futures) {
            try {
                future.get();
            } catch (InterruptedException e) {
                // Stop waiting, cancel outstanding work on a best-effort basis, and keep the
                // thread's interrupted status visible to callers.
                for (Future<Boolean> pending : futures) {
                    pending.cancel(true);
                }
                Thread.currentThread().interrupt();
                throw new IllegalStateException(
                        "Interrupted while waiting for parallel training tasks", e);
            } catch (ExecutionException e) {
                Throwable cause = e.getCause() != null ? e.getCause() : e;
                if (failure == null) {
                    failure = cause;
                } else if (failure != cause) {
                    failure.addSuppressed(cause);
                }
            }
        }
        if (failure instanceof RuntimeException) {
            throw (RuntimeException) failure;
        }
        if (failure instanceof Error) {
            throw (Error) failure;
        }
        if (failure != null) {
            throw new IllegalStateException("Parallel training task failed", failure);
        }
    }
}
