diff --git a/sbol3/toplevel.py b/sbol3/toplevel.py index 0957168..5eb93e7 100644 --- a/sbol3/toplevel.py +++ b/sbol3/toplevel.py @@ -111,6 +111,14 @@ def clone(self, new_identity: str) -> 'TopLevel': obj.traverse(make_update_references_traverser(identity_map)) return obj + def copy(self, target_doc=None, target_namespace=None): + new_obj = super().copy(target_doc=target_doc, target_namespace=target_namespace) + # Need to set `document` on all children recursively. That's what happens when + # you assign to the `document` property of an Identified + new_obj.document = target_doc + # Comply with the contract of super.copy() + return new_obj + def make_erase_identity_traverser(identity_map: Dict[str, Identified])\ -> Callable[[Identified], None]: diff --git a/test/test_toplevel.py b/test/test_toplevel.py index d2dfbac..5c8f33e 100644 --- a/test/test_toplevel.py +++ b/test/test_toplevel.py @@ -1,3 +1,4 @@ +import os import posixpath import unittest @@ -14,8 +15,11 @@ sbol:type . """ +MODULE_LOCATION = os.path.dirname(os.path.abspath(__file__)) +SBOL3_LOCATION = os.path.join(MODULE_LOCATION, 'SBOLTestSuite', 'SBOL3') -class MyTestCase(unittest.TestCase): + +class TestTopLevel(unittest.TestCase): def setUp(self) -> None: sbol3.set_defaults() @@ -48,6 +52,29 @@ def test_no_namespace_in_file(self): report = doc.validate() self.assertTrue(len(report) > 0) + def test_copy(self): + # See https://github.com/SynBioDex/pySBOL3/issues/176 reopened + # Copying a tree of objects to a new document left the document + # pointer of the child objects unset. This caused "lookup" to + # fail. + dest_doc = sbol3.Document() + + def check_document(i: sbol3.Identified): + # Verify that the object has a document, and that it is the + # expected document. + self.assertIsNotNone(i.document) + self.assertEqual(dest_doc, i.document) + + test_path = os.path.join(SBOL3_LOCATION, 'multicellular', + 'multicellular.nt') + doc = sbol3.Document() + doc.read(test_path) + for obj in doc.objects: + obj.copy(target_doc=dest_doc) + self.assertEqual(len(doc), len(dest_doc)) + for obj in dest_doc.objects: + obj.traverse(check_document) + if __name__ == '__main__': unittest.main()