/*
 * Decompiled with CFR 0.152.
 */
package org.drools.retediagram;

import guru.nidi.graphviz.engine.Format;
import guru.nidi.graphviz.engine.Graphviz;
import guru.nidi.graphviz.model.MutableGraph;
import guru.nidi.graphviz.parse.Parser;
import java.awt.Desktop;
import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.drools.core.base.ClassObjectType;
import org.drools.core.base.ObjectType;
import org.drools.core.common.BaseNode;
import org.drools.core.reteoo.AccumulateNode;
import org.drools.core.reteoo.AlphaNode;
import org.drools.core.reteoo.BetaNode;
import org.drools.core.reteoo.EntryPointNode;
import org.drools.core.reteoo.JoinNode;
import org.drools.core.reteoo.LeftInputAdapterNode;
import org.drools.core.reteoo.LeftTupleSource;
import org.drools.core.reteoo.NotNode;
import org.drools.core.reteoo.ObjectSink;
import org.drools.core.reteoo.ObjectSource;
import org.drools.core.reteoo.ObjectTypeNode;
import org.drools.core.reteoo.Rete;
import org.drools.core.reteoo.RightInputAdapterNode;
import org.drools.core.reteoo.RuleTerminalNode;
import org.drools.core.reteoo.Sink;
import org.drools.core.rule.constraint.BetaNodeFieldConstraint;
import org.drools.kiesession.rulebase.InternalKnowledgeBase;
import org.kie.api.KieBase;
import org.kie.api.runtime.KieRuntime;
import org.kie.api.runtime.KieSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Renders a Drools Rete network as a Graphviz diagram.
 * <p>
 * The network reachable from the entry-point nodes is first written as a DOT ({@code .gv}) file in the configured
 * output directory; that file is then optionally rendered to SVG and/or PNG through graphviz-java, and the rendered
 * files are optionally opened with the desktop's default viewer. Configure instances fluently through the
 * {@code config*} methods, typically starting from {@link #newInstance()}.
 */
public class ReteDiagram {
    private static final Logger LOG = LoggerFactory.getLogger(ReteDiagram.class);
    /** Encoding of the generated DOT file; Graphviz interprets DOT input as UTF-8 unless told otherwise. */
    private static final String DOT_CHARSET = "UTF-8";
    private Layout layout;
    private File outputPath;
    private boolean prefixTimestamp;
    private boolean outputSVG;
    private boolean outputPNG;
    private boolean openSVG;
    private boolean openPNG;
    private boolean printDebugVerticalCluster = false;

    private ReteDiagram() {
    }

    /**
     * Creates a diagrammer with default settings: {@link Layout#VLEVEL} layout, files written to a fresh temporary
     * directory with a timestamp prefix, both SVG and PNG rendered, neither opened automatically.
     * <p>
     * If the temporary directory cannot be created, the current working directory is used and a warning is logged.
     *
     * @return a new, fully configured instance
     */
    public static ReteDiagram newInstance() {
        File outpath = new File(".");
        try {
            outpath = Files.createTempDirectory("retediagram").toFile();
        } catch (Exception e) {
            // Best effort: fall back to the working directory instead of failing, but keep the cause visible.
            LOG.warn("Unable to create a temporary directory for Rete diagrams; using {}", outpath.getAbsolutePath(), e);
        }
        return new ReteDiagram()
                .configLayout(Layout.VLEVEL)
                .configFilenameScheme(outpath, true)
                .configGraphvizRender(true, true)
                .configOpenFile(false, false);
    }

    /**
     * Sets how nodes are grouped in the generated graph.
     *
     * @param layout the layout strategy
     * @return this instance, for chaining
     */
    public ReteDiagram configLayout(Layout layout) {
        this.layout = layout;
        return this;
    }

    /**
     * Makes the invisible level-marker nodes and edges visible, which helps when debugging the vertical layout.
     *
     * @param printDebugVerticalCluster {@code true} to draw the level markers
     * @return this instance, for chaining
     */
    public ReteDiagram configPrintDebugVerticalCluster(boolean printDebugVerticalCluster) {
        this.printDebugVerticalCluster = printDebugVerticalCluster;
        return this;
    }

    /**
     * Sets where files are written and whether their names start with a {@code yyyyMMddHHmmssSSS.} timestamp.
     *
     * @param outputPath      directory receiving the {@code .gv}/{@code .svg}/{@code .png} files
     * @param prefixTimestamp {@code true} to prefix file names with the current timestamp
     * @return this instance, for chaining
     */
    public ReteDiagram configFilenameScheme(File outputPath, boolean prefixTimestamp) {
        this.outputPath = outputPath;
        this.prefixTimestamp = prefixTimestamp;
        return this;
    }

    /**
     * Selects which rendered formats are produced from the DOT file.
     *
     * @param outputSVG {@code true} to render an SVG file
     * @param outputPNG {@code true} to render a PNG file
     * @return this instance, for chaining
     */
    public ReteDiagram configGraphvizRender(boolean outputSVG, boolean outputPNG) {
        this.outputSVG = outputSVG;
        this.outputPNG = outputPNG;
        return this;
    }

    /**
     * Selects which successfully rendered files are opened with the desktop's default application.
     *
     * @param openSVG {@code true} to open the SVG file
     * @param openPNG {@code true} to open the PNG file
     * @return this instance, for chaining
     */
    public ReteDiagram configOpenFile(boolean openSVG, boolean openPNG) {
        this.openSVG = openSVG;
        this.openPNG = openPNG;
        return this;
    }

    /** Diagrams the Rete network of {@code kbase}, which must be an {@link InternalKnowledgeBase}. */
    public void diagramRete(KieBase kbase) {
        this.diagramRete((InternalKnowledgeBase) kbase);
    }

    /** Diagrams the Rete network of the session's knowledge base. */
    public void diagramRete(KieRuntime session) {
        this.diagramRete((InternalKnowledgeBase) session.getKieBase());
    }

    /** Diagrams the Rete network of the session's knowledge base. */
    public void diagramRete(KieSession session) {
        this.diagramRete((InternalKnowledgeBase) session.getKieBase());
    }

    /** Diagrams the Rete network of {@code kBase}. */
    public void diagramRete(InternalKnowledgeBase kBase) {
        this.diagramRete(kBase.getRete());
    }

    /**
     * Writes the DOT file for {@code rete}, then renders and opens SVG/PNG files as configured.
     * <p>
     * Failures are logged, never thrown. If the DOT file cannot be written, no rendering is attempted; a rendered file
     * is only opened if it was actually produced.
     *
     * @param rete the network to diagram
     */
    public void diagramRete(Rete rete) {
        String timestampPrefix = new SimpleDateFormat("yyyyMMddHHmmssSSS").format(new Date());
        String fileNameNoExtension = (this.prefixTimestamp ? timestampPrefix + "." : "") + rete.getRuleBase().getId();
        File gvFile = new File(this.outputPath, fileNameNoExtension + ".gv");
        File svgFile = new File(this.outputPath, fileNameNoExtension + ".svg");
        File pngFile = new File(this.outputPath, fileNameNoExtension + ".png");

        if (!this.writeGvFile(rete, gvFile)) {
            // Rendering a missing or truncated DOT file would only produce further, misleading errors.
            return;
        }
        LOG.info("Written gvFile: {}", gvFile);

        boolean svgWritten = false;
        if (this.outputSVG) {
            try {
                MutableGraph g = new Parser().read(gvFile);
                Graphviz.fromGraph(g).render(Format.SVG).toFile(svgFile);
                LOG.info("Written svgFile: {}", svgFile);
                svgWritten = true;
            } catch (Exception e) {
                LOG.error("Error building SVG file", e);
            }
        }
        boolean pngWritten = false;
        if (this.outputPNG) {
            try {
                MutableGraph g = new Parser().read(gvFile);
                Graphviz.fromGraph(g).render(Format.PNG).toFile(pngFile);
                LOG.info("Written pngFile: {}", pngFile);
                pngWritten = true;
            } catch (Exception e) {
                LOG.error("Error building PNG file", e);
            }
        }
        if (svgWritten && this.openSVG) {
            openFile(svgFile, "Error opening SVG file");
        }
        if (pngWritten && this.openPNG) {
            openFile(pngFile, "Error opening PNG file");
        }
    }

    /** Writes the DOT description of {@code rete} to {@code gvFile}; returns {@code false} (after logging) on failure. */
    private boolean writeGvFile(Rete rete, File gvFile) {
        try (FileOutputStream fileOut = new FileOutputStream(gvFile);
             PrintStream out = new PrintStream(fileOut, false, DOT_CHARSET)) {
            this.writeDot(rete, out);
            // PrintStream never throws on write failures; it only records them, so ask explicitly.
            if (out.checkError()) {
                LOG.error("Error building diagram: failed writing {}", gvFile);
                return false;
            }
            return true;
        } catch (Exception e) {
            LOG.error("Error building diagram", e);
            return false;
        }
    }

    /** Emits the complete {@code digraph} for the network reachable from the entry points of {@code rete}. */
    private void writeDot(Rete rete, PrintStream out) {
        out.println("digraph g {\ngraph [fontname = \"Overpass\" fontsize=11];\n node [fontname = \"Overpass\" fontsize=11];\n edge [fontname = \"Overpass\" fontsize=11];");
        Map<Class<? extends BaseNode>, Set<BaseNode>> levelMap = new HashMap<>();
        Map<Class<? extends BaseNode>, List<BaseNode>> nodeMap = new HashMap<>();
        List<Vertex<BaseNode, BaseNode>> vertexes = new ArrayList<>();
        Set<Integer> visitedNodesIDs = new HashSet<>();
        for (EntryPointNode entryPointNode : rete.getEntryPointNodes().values()) {
            visitNodes(entryPointNode, visitedNodesIDs, nodeMap, vertexes, levelMap);
        }
        out.println("");
        printNodeMap(nodeMap, out);
        out.println("");
        printVertexes(vertexes, out);
        out.println("");
        this.printLevelMap(levelMap, out, vertexes);
        out.println("");
        if (this.layout == Layout.PARTITION) {
            printPartitionMap(nodeMap, out);
        }
        out.println("}");
    }

    /** Opens {@code file} with the desktop's default application, logging {@code errorMessage} on failure. */
    private static void openFile(File file, String errorMessage) {
        try {
            Desktop.getDesktop().open(file);
        } catch (Exception e) {
            // Includes HeadlessException / UnsupportedOperationException on environments without a desktop.
            LOG.error(errorMessage, e);
        }
    }

    /** Prints one DOT edge statement per vertex. */
    private static void printVertexes(List<Vertex<BaseNode, BaseNode>> vertexes, PrintStream out) {
        for (Vertex<BaseNode, BaseNode> v : vertexes) {
            out.println(printNodeId(v.from) + " -> " + printNodeId(v.to) + " ;");
        }
    }

    /**
     * Prints a DOT node statement for every visited node: entry points, object-type nodes, alpha nodes,
     * left-input adapters, right-input adapters, beta nodes and rule terminals first (in that order), then any
     * remaining node kinds so that every edge endpoint carries attributes.
     */
    private static void printNodeMap(Map<Class<? extends BaseNode>, List<BaseNode>> nodeMap, PrintStream out) {
        List<BaseNode> ordered = new ArrayList<>();
        ordered.addAll(nodeMap.getOrDefault(EntryPointNode.class, Collections.emptyList()));
        ordered.addAll(nodeMap.getOrDefault(ObjectTypeNode.class, Collections.emptyList()));
        ordered.addAll(nodeMap.getOrDefault(AlphaNode.class, Collections.emptyList()));
        ordered.addAll(nodesAssignableTo(nodeMap, LeftInputAdapterNode.class));
        ordered.addAll(nodeMap.getOrDefault(RightInputAdapterNode.class, Collections.emptyList()));
        ordered.addAll(nodesAssignableTo(nodeMap, BetaNode.class));
        ordered.addAll(nodeMap.getOrDefault(RuleTerminalNode.class, Collections.emptyList()));

        Set<Integer> printedIds = new HashSet<>();
        for (BaseNode node : ordered) {
            printedIds.add(node.getId());
        }
        for (List<BaseNode> group : nodeMap.values()) {
            for (BaseNode node : group) {
                if (printedIds.add(node.getId())) {
                    ordered.add(node);
                }
            }
        }
        printNodeMapNodes(ordered, out);
    }

    /**
     * Prints a DOT node statement ({@code id [attributes] ;}) for each node.
     *
     * @param nodes the nodes to print; {@code null} is treated as empty
     * @param out   destination stream
     */
    public static void printNodeMapNodes(List<BaseNode> nodes, PrintStream out) {
        if (nodes == null) {
            return;
        }
        for (BaseNode node : nodes) {
            out.println(printNodeId(node) + " " + printNodeAttributes(node) + " ;");
        }
    }

    /** Groups all nodes by rule-base partition id (0 when unset) and prints one dotted cluster per partition. */
    private static void printPartitionMap(Map<Class<? extends BaseNode>, List<BaseNode>> nodeMap, PrintStream out) {
        Map<Integer, List<BaseNode>> byPartition = nodeMap.values().stream()
                .flatMap(List::stream)
                .collect(Collectors.groupingBy(n -> n.getPartitionId() == null ? 0 : n.getPartitionId().getId()));
        for (Map.Entry<Integer, List<BaseNode>> partition : byPartition.entrySet()) {
            printClusterMapCluster("P" + partition.getKey(), new HashSet<>(partition.getValue()), out);
        }
    }

    /**
     * Prints the vertical levels used to rank nodes, followed by the chain of marker edges that orders them.
     * <p>
     * Levels: object types, alpha nodes, left-input adapters, right-input adapters, beta nodes that (transitively)
     * feed a right-input adapter, beta nodes without an object type, the remaining beta nodes, and rule terminals.
     */
    private void printLevelMap(Map<Class<? extends BaseNode>, Set<BaseNode>> levelMap, PrintStream out, List<Vertex<BaseNode, BaseNode>> vertexes) {
        this.printLevelMapLevel("l1", new HashSet<>(nodesAssignableTo(levelMap, ObjectTypeNode.class)), out);
        this.printLevelMapLevel("l2", new HashSet<>(nodesAssignableTo(levelMap, AlphaNode.class)), out);
        this.printLevelMapLevel("l3", new HashSet<>(nodesAssignableTo(levelMap, LeftInputAdapterNode.class)), out);
        Set<BaseNode> lria = new HashSet<>(nodesAssignableTo(levelMap, RightInputAdapterNode.class));
        this.printLevelMapLevel("lria", lria, out);

        List<Vertex<BaseNode, BaseNode>> onlyBetas = new ArrayList<>();
        for (Vertex<BaseNode, BaseNode> v : vertexes) {
            if (v.from instanceof BetaNode) {
                onlyBetas.add(v);
            }
        }
        // Direct beta predecessors of each right-input adapter...
        Set<BaseNode> lriaSources = new HashSet<>();
        for (BaseNode ria : lria) {
            for (Vertex<BaseNode, BaseNode> v : onlyBetas) {
                if (v.to.equals(ria)) {
                    lriaSources.add(v.from);
                }
            }
        }
        // ...plus all their beta ancestors. Collected separately: adding to lriaSources while iterating it would
        // throw ConcurrentModificationException.
        Set<BaseNode> ancestors = new HashSet<>();
        for (BaseNode source : lriaSources) {
            ancestors.addAll(collectAncestors(source, onlyBetas));
        }
        lriaSources.addAll(ancestors);
        this.printLevelMapLevel("lriaSources", lriaSources, out);

        List<BaseNode> betas = nodesAssignableTo(levelMap, BetaNode.class);
        Set<BaseNode> lsubbeta = new HashSet<>();
        for (BaseNode b : betas) {
            if (((BetaNode) b).getObjectType() == null) {
                lsubbeta.add(b);
            }
        }
        this.printLevelMapLevel("lsubbeta", lsubbeta, out);
        Set<BaseNode> l4 = new HashSet<>();
        for (BaseNode b : betas) {
            if (!lriaSources.contains(b) && !lsubbeta.contains(b)) {
                l4.add(b);
            }
        }
        this.printLevelMapLevel("l4", l4, out);
        this.printLevelMapLevel("l5", new HashSet<>(nodesAssignableTo(levelMap, RuleTerminalNode.class)), out);
        out.println((this.printDebugVerticalCluster ? "" : " edge[style=invis];\n") + " l1->l2->l3->lriaSources->lria->lsubbeta->l4->l5;");
    }

    /** Flattens the node groups whose class is {@code type} or a subtype of it. */
    private static List<BaseNode> nodesAssignableTo(Map<Class<? extends BaseNode>, ? extends Collection<BaseNode>> byClass, Class<?> type) {
        List<BaseNode> result = new ArrayList<>();
        for (Map.Entry<Class<? extends BaseNode>, ? extends Collection<BaseNode>> entry : byClass.entrySet()) {
            if (type.isAssignableFrom(entry.getKey())) {
                result.addAll(entry.getValue());
            }
        }
        return result;
    }

    /**
     * Returns every node from which {@code to} is reachable by following {@code vertexes} forwards.
     * Iterative with a visited set, so shared ancestors are walked once and cycles cannot recurse forever.
     */
    private static Set<BaseNode> collectAncestors(BaseNode to, Collection<Vertex<BaseNode, BaseNode>> vertexes) {
        Set<BaseNode> acc = new HashSet<>();
        List<BaseNode> pending = new ArrayList<>();
        pending.add(to);
        while (!pending.isEmpty()) {
            BaseNode current = pending.remove(pending.size() - 1);
            for (Vertex<BaseNode, BaseNode> v : vertexes) {
                if (v.to.equals(current) && acc.add(v.from)) {
                    pending.add(v.from);
                }
            }
        }
        return acc;
    }

    /** Prints a dotted, labelled DOT cluster containing the given nodes. */
    private static void printClusterMapCluster(String levelId, Set<BaseNode> value, PrintStream out) {
        StringBuilder nodeIds = new StringBuilder();
        for (BaseNode n : value) {
            nodeIds.append(printNodeId(n)).append("; ");
        }
        out.println(String.format(" subgraph cluster_%1$s{style=dotted; labelloc=b; label=\"%1$s\"; %2$s}", levelId, nodeIds));
    }

    /**
     * Prints one level: a marker node named {@code levelId} plus the member nodes, either as a plain subgraph
     * ({@link Layout#PARTITION}) or as a {@code rank=same} group (other layouts).
     */
    private void printLevelMapLevel(String levelId, Set<BaseNode> value, PrintStream out) {
        StringBuilder nodeIds = new StringBuilder();
        for (BaseNode n : value) {
            nodeIds.append(printNodeId(n)).append("; ");
        }
        String marker = this.printDebugVerticalCluster ? "shape=point, xlabel=\"%1$s\"" : "shape=none, label=\"\"";
        String template = this.layout == Layout.PARTITION
                ? " subgraph %1$s{%1$s[" + marker + "]; %2$s}"
                : " {rank=same; %1$s[" + marker + "]; %2$s}";
        out.println(String.format(template, levelId, nodeIds));
    }

    /**
     * Depth-first walk from {@code node} through its sinks, recording each node once (by id) in both maps and
     * one vertex per parent/child pair. Sinks that are not {@link BaseNode}s are skipped.
     */
    private static void visitNodes(BaseNode node, Set<Integer> visitedNodesIDs, Map<Class<? extends BaseNode>, List<BaseNode>> nodeMap, List<Vertex<BaseNode, BaseNode>> vertexes, Map<Class<? extends BaseNode>, Set<BaseNode>> levelMap) {
        if (!visitedNodesIDs.add(node.getId())) {
            return;
        }
        addToNodeMap(node, nodeMap);
        addToLevel(node, levelMap);
        Sink[] sinks = getSinks(node);
        if (sinks == null) {
            return;
        }
        for (Sink sink : sinks) {
            // Check before casting: the vertex needs a BaseNode on both ends.
            if (!(sink instanceof BaseNode)) {
                continue;
            }
            BaseNode child = (BaseNode) sink;
            vertexes.add(Vertex.of(node, child));
            visitNodes(child, visitedNodesIDs, nodeMap, vertexes, levelMap);
        }
    }

    /** Appends {@code node} to the list kept for its exact runtime class. */
    private static void addToNodeMap(BaseNode node, Map<Class<? extends BaseNode>, List<BaseNode>> nodeMap) {
        nodeMap.computeIfAbsent(node.getClass(), k -> new ArrayList<>()).add(node);
    }

    /** Adds {@code node} to the set kept for its exact runtime class. */
    private static void addToLevel(BaseNode node, Map<Class<? extends BaseNode>, Set<BaseNode>> levelMap) {
        levelMap.computeIfAbsent(node.getClass(), k -> new HashSet<>()).add(node);
    }

    /** Returns the DOT identifier of {@code node}: a kind prefix ({@code EP}, {@code OTN}, ..., {@code UNK}) plus its id. */
    private static String printNodeId(BaseNode node) {
        if (node instanceof EntryPointNode) {
            return "EP" + node.getId();
        }
        if (node instanceof ObjectTypeNode) {
            return "OTN" + node.getId();
        }
        if (node instanceof AlphaNode) {
            return "AN" + node.getId();
        }
        if (node instanceof LeftInputAdapterNode) {
            return "LIA" + node.getId();
        }
        if (node instanceof RightInputAdapterNode) {
            return "RIA" + node.getId();
        }
        if (node instanceof BetaNode) {
            return "BN" + node.getId();
        }
        if (node instanceof RuleTerminalNode) {
            return "RTN" + node.getId();
        }
        return "UNK" + node.getId();
    }

    /**
     * Returns the DOT attribute list ({@code [...]}) describing {@code node}'s shape and label. All interpolated
     * text is escaped for its context (quoted DOT string or HTML-like label).
     */
    private static String printNodeAttributes(BaseNode node) {
        if (node instanceof EntryPointNode) {
            EntryPointNode n = (EntryPointNode) node;
            return String.format("[shape=circle width=0.15 fillcolor=black style=filled label=\"\" xlabel=\"%1$s\"]", escapeDot(n.getEntryPoint().getEntryPointId()));
        }
        if (node instanceof ObjectTypeNode) {
            ObjectTypeNode n = (ObjectTypeNode) node;
            return String.format("[shape=rect style=rounded label=\"%1$s\"]", escapeDot(strObjectType(n.getObjectType())));
        }
        if (node instanceof AlphaNode) {
            AlphaNode n = (AlphaNode) node;
            return String.format("[label=\"%1$s\"]", escapeDot(String.valueOf(n.getConstraint())));
        }
        if (node instanceof LeftInputAdapterNode) {
            return "[shape=house orientation=-90]";
        }
        if (node instanceof RightInputAdapterNode) {
            return "[shape=house orientation=90]";
        }
        if (node instanceof JoinNode) {
            BetaNode n = (BetaNode) node;
            BetaNodeFieldConstraint[] constraints = n.getConstraints();
            String label = "\u22c8";
            if (constraints.length > 0) {
                label = strObjectType(n.getObjectType(), false) + "( " + joinConstraints(constraints) + " )";
            }
            return String.format("[shape=box label=\"%1$s\" href=\"http://drools.org\"]", escapeDot(label));
        }
        if (node instanceof NotNode) {
            NotNode n = (NotNode) node;
            String label = "\u22c8";
            if (n.getObjectType() != null) {
                StringBuilder sb = new StringBuilder(strObjectType(n.getObjectType(), false)).append('(');
                if (n.getConstraints().length > 0) {
                    sb.append(' ').append(joinConstraints(n.getConstraints())).append(' ');
                }
                label = sb.append(')').toString();
            }
            return String.format("[shape=box label=\"not( %1$s )\"]", escapeDot(label));
        }
        if (node instanceof AccumulateNode) {
            AccumulateNode n = (AccumulateNode) node;
            // HTML-like label: '<', '>' and '&' in the interpolated text must be entity-escaped.
            return String.format("[shape=box label=<%1$s<BR/>%2$s<BR/>%3$s>]",
                    escapeHtml(String.valueOf(n)),
                    escapeHtml(Arrays.asList(n.getAccumulate().getAccumulators()).toString()),
                    escapeHtml(Arrays.asList(n.getConstraints()).toString()));
        }
        if (node instanceof RuleTerminalNode) {
            RuleTerminalNode n = (RuleTerminalNode) node;
            return String.format("[shape=doublecircle width=0.2 fillcolor=black style=filled label=\"\" xlabel=\"%1$s\" href=\"http://drools.org\"]", escapeDot(n.getRule().getName()));
        }
        return String.format("[shape=box style=dotted label=\"%1$s\"]", escapeDot(String.valueOf(node)));
    }

    /** Joins the {@code toString()} of each constraint with {@code ", "}. */
    private static String joinConstraints(Object[] constraints) {
        return Arrays.stream(constraints).map(Object::toString).collect(Collectors.joining(", "));
    }

    /** Short label for an object type, with the abbreviated package prepended. */
    private static String strObjectType(ObjectType ot) {
        return strObjectType(ot, true);
    }

    /**
     * Short label for an object type: the (optionally package-abbreviated) simple class name for class-based types,
     * otherwise {@code "??"} followed by its {@code toString()} (or {@code "null"}).
     */
    private static String strObjectType(ObjectType ot, boolean prependAbbrPackage) {
        if (ot instanceof ClassObjectType) {
            return abbrvClassForObjectType((ClassObjectType) ot, prependAbbrPackage);
        }
        return "??" + (ot == null ? "null" : ot.toString());
    }

    /**
     * Returns the simple class name, optionally preceded by the first letter of each package segment
     * (e.g. {@code o.d.Person}). Classes without a (named) package get no prefix.
     */
    private static String abbrvClassForObjectType(ClassObjectType cot, boolean prependAbbrPackage) {
        Class<?> classType = cot.getClassType();
        StringBuilder result = new StringBuilder();
        if (prependAbbrPackage) {
            // getPackage() may be null (and the name empty) for default-package classes.
            Package pkg = classType.getPackage();
            String packageName = pkg == null ? "" : pkg.getName();
            for (String token : packageName.split("\\.")) {
                if (!token.isEmpty()) {
                    result.append(token.charAt(0)).append('.');
                }
            }
        }
        result.append(classType.getSimpleName());
        return result.toString();
    }

    /**
     * Escapes text for use inside a double-quoted DOT string. Backslashes are escaped first so that a trailing
     * {@code \} or an existing {@code \"} cannot terminate the string early. {@code null} becomes {@code "null"}.
     */
    private static String escapeDot(String string) {
        if (string == null) {
            return "null";
        }
        return string.replace("\\", "\\\\").replace("\"", "\\\"");
    }

    /** Entity-escapes {@code & < > "} for use inside a Graphviz HTML-like label. {@code null} becomes {@code "null"}. */
    private static String escapeHtml(String string) {
        if (string == null) {
            return "null";
        }
        StringBuilder sb = new StringBuilder(string.length());
        for (int i = 0; i < string.length(); i++) {
            char c = string.charAt(i);
            switch (c) {
                case '&':
                    sb.append("&amp;");
                    break;
                case '<':
                    sb.append("&lt;");
                    break;
                case '>':
                    sb.append("&gt;");
                    break;
                case '"':
                    sb.append("&quot;");
                    break;
                default:
                    sb.append(c);
            }
        }
        return sb.toString();
    }

    /**
     * Returns the downstream sinks of {@code node}.
     * <p>
     * Entry points yield their object-type nodes; object sources yield their object sinks; left-tuple sources yield
     * their left-tuple sinks.
     *
     * @param node the node to inspect
     * @return the sinks, or {@code null} when {@code node} is none of the above kinds (e.g. a terminal node)
     */
    public static Sink[] getSinks(BaseNode node) {
        if (node instanceof EntryPointNode) {
            Collection<? extends Sink> otns = ((EntryPointNode) node).getObjectTypeNodes().values();
            return otns.toArray(new Sink[0]);
        }
        if (node instanceof ObjectSource) {
            return ((ObjectSource) node).getObjectSinkPropagator().getSinks();
        }
        if (node instanceof LeftTupleSource) {
            return ((LeftTupleSource) node).getSinkPropagator().getSinks();
        }
        return null;
    }

    /**
     * An immutable directed edge between two nodes.
     *
     * @param <F> type of the source end
     * @param <T> type of the target end
     */
    public static class Vertex<F, T> {
        public final F from;
        public final T to;

        public Vertex(F from, T to) {
            this.from = from;
            this.to = to;
        }

        /** Creates an edge {@code from -> to}. */
        public static <F, T> Vertex<F, T> of(F from, T to) {
            return new Vertex<F, T>(from, to);
        }
    }

    /** How nodes are grouped in the generated graph. */
    public static enum Layout {
        /** Levels are emitted as plain subgraphs and, additionally, one dotted cluster per rule-base partition. */
        PARTITION,
        /** Levels are emitted as {@code rank=same} groups stacked vertically. */
        VLEVEL;

    }
}

