package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

/* loaded from: input_file:ai/djl/nn/ParallelBlock.class */
/**
 * A block that feeds the same input to every child block and merges the children's outputs.
 *
 * <p>On a forward pass each child is run on the identical input {@link NDList}. The per-child
 * outputs are collected, in the order the children are iterated, into a {@code List<NDList>}.
 * That list is handed to the combining {@link Function} supplied at construction time, and the
 * function's return value becomes this block's output.
 */
public class ParallelBlock extends AbstractBlock {

    /** Current serialization version; version 1 is still accepted by {@link #loadMetadata}. */
    private static final byte VERSION = 2;

    /** Merges the list of per-child outputs into this block's single output. */
    private final Function<List<NDList>, NDList> function;

    /**
     * Creates a parallel block with no children; add them with {@link #add(Block)} or
     * {@link #addAll(Block...)}.
     *
     * @param function combines the list of child outputs into this block's output
     */
    public ParallelBlock(Function<List<NDList>, NDList> function) {
        this(function, Collections.emptyList());
    }

    /**
     * Creates a parallel block with the given initial children.
     *
     * @param function combines the list of child outputs into this block's output
     * @param list the initial child blocks; {@code null} elements are skipped
     */
    public ParallelBlock(Function<List<NDList>, NDList> function, List<Block> list) {
        super(VERSION);
        this.function = function;
        addAll(list);
    }

    /**
     * Adds each of the given blocks as a child; {@code null} entries are skipped.
     *
     * @param blockArr the blocks to add
     * @return this block, for chaining
     */
    public final ParallelBlock addAll(Block... blockArr) {
        return addAll(Arrays.asList(blockArr));
    }

    /**
     * Adds every block in the collection as a child; {@code null} entries are skipped.
     *
     * @param collection the blocks to add
     * @return this block, for chaining
     */
    public final ParallelBlock addAll(Collection<? extends Block> collection) {
        for (Block block : collection) {
            add(block);
        }
        return this;
    }

    /**
     * Adds a single child block, named after its simple class name. A {@code null} block is
     * ignored.
     *
     * @param block the block to add, may be {@code null}
     * @return this block, for chaining
     */
    public final ParallelBlock add(Block block) {
        if (block != null) {
            addChildBlock(block.getClass().getSimpleName(), block);
        }
        return this;
    }

    /**
     * Wraps the function in a {@link LambdaBlock} and adds it as a child.
     *
     * @param function the per-branch transformation
     * @return this block, for chaining
     */
    public final ParallelBlock add(Function<NDList, NDList> function) {
        return add(new LambdaBlock(function));
    }

    /**
     * Runs every child on the same input and combines their outputs with {@code function}.
     *
     * @param parameterStore the parameter store passed through to each child
     * @param inputs the input shared by all children
     * @param training the training flag passed through to each child
     * @param params extra parameters passed through to each child
     * @return the result of applying {@code function} to the list of child outputs
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        List<NDList> childOutputs =
                children.values().stream()
                        .map(block -> block.forward(parameterStore, inputs, training, params))
                        .collect(Collectors.toList());
        return function.apply(childOutputs);
    }

    /** Initializes every child block with the same manager, data type and input shapes. */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        for (Block child : getChildren().values()) {
            child.initialize(manager, dataType, inputShapes);
        }
    }

    /**
     * Infers output shapes by running {@code function} on placeholder arrays whose shapes match
     * each child's inferred output shapes.
     *
     * @param manager the manager from which a temporary sub-manager is derived
     * @param inputShapes the input shapes given to every child
     * @return the shapes of the arrays returned by {@code function}
     * @throws IllegalArgumentException (via {@code Preconditions.checkArgument}) if there are no
     *     children
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        Preconditions.checkArgument(!children.isEmpty(), "The parallel block is empty");
        // Everything allocated for shape inference lives in the sub-manager and is released when
        // the try block exits. This covers the placeholders and whatever the children create.
        // Try-with-resources also keeps a close() failure from masking an earlier exception.
        try (NDManager subManager = manager.newSubManager()) {
            List<NDList> childOutputs = new ArrayList<>();
            for (Block child : children.values()) {
                Shape[] childShapes = child.getOutputShapes(subManager, inputShapes);
                NDList placeholders = new NDList(childShapes.length);
                for (Shape shape : childShapes) {
                    placeholders.add(subManager.create(shape));
                }
                childOutputs.add(placeholders);
            }
            NDList merged = function.apply(childOutputs);
            Shape[] outputShapes = new Shape[merged.size()];
            for (int i = 0; i < outputShapes.length; i++) {
                outputShapes[i] = merged.get(i).getShape();
            }
            return outputShapes;
        }
    }

    /**
     * Reads block metadata for the given encoding version.
     *
     * <p>Version 2 reads the stored input shapes; version 1 carries nothing extra to read.
     *
     * @param loadVersion the encoding version found in the stream
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the version is neither 1 nor 2
     */
    @Override
    public void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedModelException {
        if (loadVersion == VERSION) {
            readInputShapes(is);
        } else if (loadVersion != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
        }
    }

    /** Renders {@code Parallel(} followed by each child's description, tab-indented, then {@code )}. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Parallel(\n");
        for (Block child : children.values()) {
            // Prefix every line of the (possibly multi-line) child description with a tab.
            sb.append(child.toString().replaceAll("(?m)^", "\t")).append('\n');
        }
        sb.append(')');
        return sb.toString();
    }
}
