diff --git a/typetapper/hierarchy_graph.py b/typetapper/hierarchy_graph.py index efcc5d2..9ebe901 100644 --- a/typetapper/hierarchy_graph.py +++ b/typetapper/hierarchy_graph.py @@ -101,8 +101,8 @@ class HierarchicalGraph(RelativeAtomGraph): # if item is a group, all edges will have two sides # if item is an atom, all edges will have one side if isinstance(item, RelativeAtom): - yield from ((None, (item, succ, key)) for succ in self.__graph.succ[item] for key in succ) - yield from (((pred, item, key), None) for pred in self.__graph.pred[item] for key in pred) + yield from ((None, (item, succ, key)) for succ, keys in self.__graph.succ[item].items() for key in keys) + yield from (((pred, item, key), None) for pred, keys in self.__graph.pred[item].items() for key in keys) else: yield from ((self.__graph.edges[(item, succ, key)]['prev'], (item, succ, key)) for succ in self.__graph.succ[item] for key in succ) @@ -274,6 +274,14 @@ class HierarchicalGraph(RelativeAtomGraph): self.move_node_in(node, moveto) self.move_node_in(node, new_parent) + def create_group(self, nodes: List[RelativeAtomOrGroup], parent: RelativeAtomGroup) -> RelativeAtomGroup: + if any(self._parent(node) is not parent for node in nodes): + raise ValueError("Not a child of parent") + group = self._add_group(parent) + for node in nodes: + self.move_node_in(node, group) + return group + def _parent(self, item: RelativeAtomOrGroup) -> RelativeAtomGroup: result = item.parent if isinstance(item, RelativeAtomGroup) else self._atom_parents[item] if result is None: diff --git a/typetapper/relative_graph.py b/typetapper/relative_graph.py index 0045f55..fcfcc7f 100644 --- a/typetapper/relative_graph.py +++ b/typetapper/relative_graph.py @@ -112,9 +112,11 @@ class RelativeAtomGraph: relsucc = RelativeAtom(atom=succ, callstack=callstack) res = self._add_node(relsucc, path) if is_pred: - self._add_edge(relsucc, relatom) + if not self.__graph.has_edge(relsucc, relatom): + self._add_edge(relsucc, relatom) else: - self._add_edge(relatom, relsucc) + if not self.__graph.has_edge(relsucc, relatom): + self._add_edge(relsucc, relatom) return relsucc if res else None @staticmethod