# -*- coding: utf-8 -*-
import collections.abc
import logging
import seamm
import pprint
"""A simple graph structure for holding the flowchart. This handles a
directed graph -- all edges have a direction implied -- with zero or
more edges from or to each node.
"""
logger = logging.getLogger(__name__)
[docs]
class Graph(object):
    """A datastructure for holding a directed graph with multiple (parallel) edges.

    Nodes are stored in a dictionary keyed by their ``uuid`` attribute, so any
    object exposing a hashable ``uuid`` can be used as a node. Edges are keyed
    by ``(u.uuid, v.uuid, edge_type, edge_subtype)``: the same two nodes may be
    joined by several edges provided their type/subtype pairs differ.
    """

    def __init__(self):
        """Create an empty graph."""
        # node.uuid -> node
        self._node = {}
        # (u.uuid, v.uuid, edge_type, edge_subtype) -> edge object
        self._edge = {}

    def __iter__(self):
        """Iterate over the nodes, in the order they were added."""
        return iter(self._node.values())

    def __contains__(self, node):
        """Return True if a node with the same ``uuid`` is in the graph."""
        return node.uuid in self._node

    def add_node(self, node):
        """Add a node to the graph.

        Parameters
        ----------
        node : object
            The node to add; it must have a hashable ``uuid`` attribute.

        Returns
        -------
        object
            The node that was added.

        Raises
        ------
        RuntimeError
            If a node with the same ``uuid`` is already in the graph.
        """
        if node in self:
            raise RuntimeError("node is already in the graph")
        self._node[node.uuid] = node
        return node

    def remove_node(self, node):
        """Remove a node from the graph.

        NOTE(review): edges that start or end at ``node`` are *not* removed
        here, so ``edges()`` keeps reporting them until the caller removes
        them with ``remove_edge``.

        Raises
        ------
        RuntimeError
            If the node is not in the graph.
        """
        if node not in self:
            raise RuntimeError("node is not in the graph")
        del self._node[node.uuid]

    def clear(self):
        """Remove all nodes and edges from the graph."""
        self._node = {}
        self._edge = {}

    def add_edge(
        self, u, v, edge_type=None, edge_subtype=None, edge_class=None, **kwargs
    ):
        """Create an edge from ``u`` to ``v`` and return it.

        Either node is first added to the graph if it is not already present.
        An existing edge with the same ``(u, v, edge_type, edge_subtype)`` key
        is silently replaced by the new one.

        Parameters
        ----------
        u, v : object
            The source and destination nodes (each needs a ``uuid``).
        edge_type, edge_subtype : optional
            Part of the edge's key; also passed to the edge constructor.
        edge_class : callable, optional
            Factory for the edge object, called as
            ``edge_class(graph, u, v, edge_type=..., edge_subtype=..., **kwargs)``.
            Defaults to ``seamm.Edge``.
        **kwargs
            Extra data passed through to the edge constructor.

        Returns
        -------
        object
            The newly created edge.
        """
        if u not in self:
            self.add_node(u)
        if v not in self:
            self.add_node(v)
        # seamm.Edge is only looked up when no explicit class is supplied.
        factory = seamm.Edge if edge_class is None else edge_class
        edge = factory(
            self, u, v, edge_type=edge_type, edge_subtype=edge_subtype, **kwargs
        )
        self._edge[(u.uuid, v.uuid, edge_type, edge_subtype)] = edge
        return edge

    def remove_edge(self, u, v, edge_type=None, edge_subtype=None):
        """Remove the edge from ``u`` to ``v`` with the given type and subtype.

        Raises
        ------
        RuntimeError
            If no such edge exists.
        """
        try:
            del self._edge[(u.uuid, v.uuid, edge_type, edge_subtype)]
        except KeyError:
            raise RuntimeError("edge does not exist!") from None

    def edges(self, node=None, direction="both"):
        """Return the edges of the graph, optionally only those of one node.

        Parameters
        ----------
        node : object, optional
            If None, return a view of every edge in the graph. Otherwise
            return only edges that start and/or end at ``node``.
        direction : {"both", "out", "in"}
            Which edges of ``node`` to return. With "both", each item is a
            ``(tag, edge)`` tuple where ``tag`` is "out" or "in"; a self-loop
            therefore appears twice. With "out" or "in", items are bare edges.

        Returns
        -------
        dict_values or list
            A view of all edges when ``node`` is None, otherwise a new list.

        Raises
        ------
        RuntimeError
            If ``direction`` is not "both", "out" or "in".
        """
        if node is None:
            return self._edge.values()

        if direction not in ("both", "out", "in"):
            raise RuntimeError("Don't recognize direction '{}'!".format(direction))

        h = node.uuid
        result = []
        for (h1, h2, _, _), edge in self._edge.items():
            if direction == "both":
                if h1 == h:
                    result.append(("out", edge))
                if h2 == h:
                    result.append(("in", edge))
            elif direction == "out":
                if h1 == h:
                    result.append(edge)
            elif h2 == h:  # direction == "in"
                result.append(edge)
        return result

    def has_edge(self, u, v, edge_type=None, edge_subtype=None):
        """Return True if an edge from ``u`` to ``v`` with this type/subtype exists."""
        return (u.uuid, v.uuid, edge_type, edge_subtype) in self._edge
 
[docs]
class Edge(collections.abc.MutableMapping):
    """An edge of a graph whose data is exposed as a mutable mapping.

    The items ``node1``, ``node2``, ``edge_type`` and ``edge_subtype`` are
    always present. Any extra keyword arguments given at construction are
    stored alongside them. The owning graph is kept in ``self.graph``.
    """

    def __init__(
        self, graph, node1, node2, edge_type="execution", edge_subtype="next", **kwargs
    ):
        self.graph = graph
        self._data = dict(kwargs)
        self._data["node1"] = node1
        self._data["node2"] = node2
        self._data["edge_type"] = edge_type
        self._data["edge_subtype"] = edge_subtype

    def __getitem__(self, key):
        """Allow [] access to the dictionary!"""
        return self._data[key]

    def __setitem__(self, key, value):
        """Allow x[key] access to the data"""
        self._data[key] = value

    def __delitem__(self, key):
        """Allow deletion of keys"""
        del self._data[key]

    def __iter__(self):
        """Allow iteration over the object"""
        return iter(self._data)

    def __len__(self):
        """The len() command"""
        return len(self._data)

    def __repr__(self):
        """The string representation of this object"""
        return repr(self._data)

    def __str__(self):
        """The pretty string representation of this object"""
        return pprint.pformat(self._data)

    def __contains__(self, item):
        """Return a boolean indicating if a key exists."""
        return item in self._data

    def __eq__(self, other):
        """Return True if ``other`` holds the same data as this edge.

        Objects without an ``_data`` attribute are not comparable; returning
        NotImplemented lets Python fall back to its default (identity)
        comparison instead of raising AttributeError.
        """
        try:
            other_data = other._data
        except AttributeError:
            return NotImplemented
        return self._data == other_data

    def copy(self):
        """Return a shallow copy of the data as a plain dict."""
        return self._data.copy()

    @property
    def node1(self):
        """The node this edge starts from."""
        return self._data["node1"]

    @property
    def node2(self):
        """The node this edge points to."""
        return self._data["node2"]

    @property
    def edge_type(self):
        """The type of the edge."""
        return self._data["edge_type"]

    @property
    def edge_subtype(self):
        """The subtype of the edge."""
        return self._data["edge_subtype"]
if __name__ == "__main__":
    import uuid

    class Node(object):
        """Minimal stand-in node used only by this demo."""

        def __init__(self, **kwargs):
            # Graph looks nodes up by ``node.uuid``; without this attribute
            # the first graph.add_edge() call fails with AttributeError.
            self.uuid = uuid.uuid4().int
            self.data = dict(**kwargs)

        def __str__(self):
            return pprint.pformat(self.data)

    graph = Graph()
    start = Node(title="start")
    node1 = Node(title="node1")
    node2 = Node(title="node2")
    edge1 = graph.add_edge(start, node1)
    print("edge1 = {}".format(edge1))
    edge2 = graph.add_edge(node1, node2)
    print("edge2 = {}".format(edge2))
    print("nodes:")
    for node in graph:
        print(" node = {}".format(node))
        print()
    print("edges:")
    for edge in graph.edges():
        print(edge)
        print()