parse_xml() method now determines the xml namespace elements
authorTony Mack <tmack@paris.CS.Princeton.EDU>
Tue, 4 Oct 2011 18:45:47 +0000 (14:45 -0400)
committerTony Mack <tmack@paris.CS.Princeton.EDU>
Tue, 4 Oct 2011 18:45:47 +0000 (14:45 -0400)
sfa/rspecs/xml.py

index 57cccbf..ac6526c 100755 (executable)
@@ -28,22 +28,18 @@ class XpathFilter:
 
 class XML:
  
-    def __init__(self, xml=""):
-        self.header = None 
-        self.template = None 
-        self.xml = None
+    def __init__(self, xml=None):
+        self.root = None
         self.namespaces = None
-        if xml:
+        self.default_namespace = None
+        self.schema = None
+        if isinstance(xml, basestring):
             self.parse_xml(xml)
-        else:
-            self.create()
+        elif isinstance(xml, etree._ElementTree):
+            self.root = xml.getroot()
+        elif isinstance(xml, etree._Element):
+            self.root = xml 
 
-    def create(self):
-        """
-        Create root element
-        """
-        self.parse_rspec(self.template)
-    
     def parse_xml(self, xml):
         """
         parse rspec into etree
@@ -57,7 +53,25 @@ class XML:
                 tree = etree.parse(StringIO(xml), parser)
             except Exception, e:
                 raise InvalidRSpec(str(e))
-        self.xml = tree.getroot()  
+        self.root = tree.getroot()
+        # set namespaces map
+        self.namespaces = dict(self.root.nsmap)
+        # If the 'None' exist, then it's pointing to the default namespace. This makes 
+        # it hard for us to write xpath queries for the default naemspace because lxml 
+        # wont understand a None prefix. We will just associate the default namespeace 
+        # with a key named 'default'.     
+        if None in self.namespaces:
+            default_namespace = self.namespaces.pop(None)
+            self.namespaces['default'] = default_namespace
+
+        # set schema 
+        for key in self.root.attrib.keys():
+            if key.endswith('schemaLocation'):
+                # schema location should be at the end of the list
+                schema_parts  = self.root.attrib[key].split(' ')
+                self.schema = schema_parts[1]    
+                namespace, schema  = schema_parts[0], schema_parts[1]
+                break
 
     def validate(self, schema):
         """
@@ -65,14 +79,19 @@ class XML:
         """
         relaxng_doc = etree.parse(schema)
         relaxng = etree.RelaxNG(relaxng_doc)
-        if not relaxng(self.xml):
+        if not relaxng(self.root):
             error = relaxng.error_log.last_error
             message = "%s (line %s)" % (error.message, error.line)
             raise InvalidRSpec(message)
         return True
 
-    def xpath(self, xpath):
-        return self.xml.xpath(xpath, namespaces=self.namespaces)
+    def xpath(self, xpath, namespaces=None):
+        if not namespaces:
+            namespaces = self.namespaces
+        return self.root.xpath(xpath, namespaces=namespaces)
+
+    def set(self, key, value):
+        return self.root.set(key, value)
 
     def add_attribute(self, elem, name, value):
         """
@@ -88,7 +107,7 @@ class XML:
         not specified. 
         """
         if parent == None:
-            parent = self.xml
+            parent = self.root
         element = etree.SubElement(parent, name)
         if text:
             element.text = text
@@ -114,7 +133,7 @@ class XML:
         specified root_node if specified, otherwise start at tree's root.   
         """
         if not root_node:
-            root_node = self.xml
+            root_node = self.root
 
         if not element_name.startswith('//'):
             element_name = '//' + element_name
@@ -157,19 +176,11 @@ class XML:
     def merge(self, in_xml):
         pass
 
-    def cleanup(self):
-        """
-        Optional method which inheriting classes can choose to implent. 
-        """
-        pass 
-
     def __str__(self):
         return self.toxml()
 
-    def toxml(self, cleanup=False):
-        if cleanup:
-            self.cleanup()
-        return self.header + etree.tostring(self.xml, pretty_print=True)  
+    def toxml(self):
+        return etree.tostring(self.root, pretty_print=True)  
         
     def save(self, filename):
         f = open(filename, 'w')