/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.internal.memory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.function.Predicate;
import org.nd4j.autodiff.samediff.internal.IDependeeGroup;
import org.nd4j.autodiff.samediff.internal.IDependencyMap;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Index from {@link INDArray} ids to the dependants registered against them.
 *
 * <p>Each call to {@link #add} records the element under the id of every non-null array in the
 * dependee group, tagged with that group's own id. The "ForEach" operations consider every
 * dependant registered under any array of the supplied group, regardless of which group added
 * it. The "ForGroup"/"Group" operations only consider entries tagged with the supplied group's
 * id.
 *
 * <p>This class performs no synchronization.
 *
 * @param <K> dependee group type; the arrays in its collection are used as lookup keys
 * @param <V> dependant type
 */
public class DependencyMap<K extends IDependeeGroup<INDArray>, V>
implements IDependencyMap<K, V> {

    /** Array id -> set of (owning group id, dependant) pairs registered under that array. */
    private final HashMap<Long, HashSet<Pair<Long, V>>> map = new HashMap<>();

    @Override
    public void clear() {
        map.clear();
    }

    /**
     * Registers {@code element} under every non-null array of {@code dependeeGroup}, tagged with
     * the group's id. Null arrays in the group are skipped.
     */
    @Override
    public void add(K dependeeGroup, V element) {
        long groupId = dependeeGroup.getId();
        for (INDArray arr : dependeeGroup.getCollection()) {
            if (arr == null) {
                continue;
            }
            map.computeIfAbsent(arr.getId(), key -> new HashSet<>()).add(Pair.create(groupId, element));
        }
    }

    @Override
    public boolean isEmpty() {
        return map.isEmpty();
    }

    /**
     * Returns the distinct dependants registered under any array of {@code dependeeGroup},
     * whichever group registered them.
     */
    @Override
    public Iterable<V> getDependantsForEach(K dependeeGroup) {
        HashSet<V> result = new HashSet<>();
        for (INDArray arr : dependeeGroup.getCollection()) {
            HashSet<Pair<Long, V>> entries = entriesFor(arr);
            if (entries == null) {
                continue;
            }
            for (Pair<Long, V> entry : entries) {
                result.add(entry.getSecond());
            }
        }
        return result;
    }

    /**
     * Returns the distinct dependants registered under the arrays of {@code dependeeGroup} whose
     * tag equals this group's id.
     */
    @Override
    public Iterable<V> getDependantsForGroup(K dependeeGroup) {
        long groupId = dependeeGroup.getId();
        HashSet<V> result = new HashSet<>();
        for (INDArray arr : dependeeGroup.getCollection()) {
            HashSet<Pair<Long, V>> entries = entriesFor(arr);
            if (entries == null) {
                continue;
            }
            for (Pair<Long, V> entry : entries) {
                if (entry.getFirst().longValue() == groupId) {
                    result.add(entry.getSecond());
                }
            }
        }
        return result;
    }

    /**
     * Returns {@code true} if any array of {@code dependeeGroup} has an entry tagged with this
     * group's id.
     */
    @Override
    public boolean containsAnyForGroup(K dependeeGroup) {
        long groupId = dependeeGroup.getId();
        for (INDArray arr : dependeeGroup.getCollection()) {
            HashSet<Pair<Long, V>> entries = entriesFor(arr);
            if (entries == null) {
                continue;
            }
            for (Pair<Long, V> entry : entries) {
                if (entry.getFirst().longValue() == groupId) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Removes every entry tagged with {@code dependeeGroup}'s id from the group's arrays. Array
     * keys that are left with no entries are dropped.
     */
    @Override
    public void removeGroup(K dependeeGroup) {
        removeForGroup(dependeeGroup, v -> true, null);
    }

    /**
     * Same as {@link #removeGroup}, but returns the distinct dependants that were removed.
     */
    @Override
    public Iterable<V> removeGroupReturn(K dependeeGroup) {
        HashSet<V> removed = new HashSet<>();
        removeForGroup(dependeeGroup, v -> true, removed);
        return removed;
    }

    /**
     * Removes all entries registered under any array of {@code dependeeGroup}, regardless of which
     * group added them.
     */
    @Override
    public void removeForEach(K dependeeGroup) {
        for (INDArray arr : dependeeGroup.getCollection()) {
            if (arr != null) {
                map.remove(arr.getId());
            }
        }
    }

    /**
     * Same as {@link #removeForEach}, but returns the distinct dependants that were removed.
     */
    @Override
    public Iterable<V> removeForEachResult(K dependeeGroup) {
        HashSet<V> removed = new HashSet<>();
        for (INDArray arr : dependeeGroup.getCollection()) {
            if (arr == null) {
                continue;
            }
            // A single remove both detaches and returns the entry set.
            HashSet<Pair<Long, V>> entries = map.remove(arr.getId());
            if (entries == null) {
                continue;
            }
            for (Pair<Long, V> entry : entries) {
                removed.add(entry.getSecond());
            }
        }
        return removed;
    }

    /**
     * Returns {@code true} if any non-null array of {@code dependeeGroup} has at least one entry.
     */
    @Override
    public boolean containsAny(K dependeeGroup) {
        for (INDArray arr : dependeeGroup.getCollection()) {
            if (arr != null && map.containsKey(arr.getId())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Removes the entries tagged with {@code dependeeGroup}'s id whose dependant satisfies
     * {@code predicate}, and returns the distinct dependants that were removed. Array keys that
     * are left with no entries are dropped.
     */
    @Override
    public Iterable<V> removeGroupReturn(K dependeeGroup, Predicate<V> predicate) {
        HashSet<V> removed = new HashSet<>();
        removeForGroup(dependeeGroup, predicate, removed);
        return removed;
    }

    /** Returns the entry set registered under {@code arr}, or {@code null} if arr is null or has none. */
    private HashSet<Pair<Long, V>> entriesFor(INDArray arr) {
        return arr == null ? null : map.get(arr.getId());
    }

    /**
     * Shared removal loop for the "Group" removal operations.
     *
     * <p>For each non-null array of {@code dependeeGroup}, it removes the entries tagged with the
     * group's id whose dependant satisfies {@code predicate}. The predicate is only evaluated for
     * entries with a matching tag. An array key is dropped once it has no entries left. When
     * {@code removed} is non-null, each removed dependant is added to it.
     */
    private void removeForGroup(K dependeeGroup, Predicate<V> predicate, Collection<V> removed) {
        long groupId = dependeeGroup.getId();
        for (INDArray arr : dependeeGroup.getCollection()) {
            if (arr == null) {
                continue;
            }
            long arrId = arr.getId();
            HashSet<Pair<Long, V>> entries = map.get(arrId);
            if (entries == null) {
                continue;
            }
            // removeIf avoids HashSet.removeAll(List), which turns quadratic when every entry
            // matches (it then iterates the set and calls List.contains for each element).
            entries.removeIf(entry -> {
                if (entry.getFirst().longValue() != groupId || !predicate.test(entry.getSecond())) {
                    return false;
                }
                if (removed != null) {
                    removed.add(entry.getSecond());
                }
                return true;
            });
            if (entries.isEmpty()) {
                map.remove(arrId);
            }
        }
    }
}

