cosmetic
[sfa.git] / sfa / storage / persistentobjs.py
index a1f4cb6..8ff81ed 100644 (file)
@@ -6,9 +6,11 @@ from sqlalchemy import Table, Column, MetaData, join, ForeignKey
 from sqlalchemy.orm import relationship, backref
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import object_mapper
+from sqlalchemy.orm import validates
 from sqlalchemy.ext.declarative import declarative_base
 
 from sfa.util.sfalogging import logger
+from sfa.util.xml import XML 
 
 from sfa.trust.gid import GID
 
@@ -41,7 +43,7 @@ Base=declarative_base()
 # so the latter obj.todict() seems more reliable but more hacky as is relies on the form of fields, so this can probably be improved
 #
 # (*) finally for converting a dictionary into an sqlalchemy object, we provide
-# obj.set_from_dict(dict)
+# obj.load_from_dict(dict)
 
 class AlchemyObj:
     def __iter__(self): 
@@ -54,66 +56,102 @@ class AlchemyObj:
         d=self.__dict__
         keys=[k for k in d.keys() if not k.startswith('_')]
         return dict ( [ (k,d[k]) for k in keys ] )
-    def set_from_dict (self, d):
+    def load_from_dict (self, d):
         for (k,v) in d.iteritems():
             # experimental
-            if isinstance(v, StringTypes):
-                if v.lower() in ['true']: v=True
-                if v.lower() in ['false']: v=False
+            if isinstance(v, StringTypes) and v.lower() in ['true']: v=True
+            if isinstance(v, StringTypes) and v.lower() in ['false']: v=False
             setattr(self,k,v)
+    
+    # in addition we provide convenience for converting to and from xml records
+    # for this purpose only, we need the subclasses to define 'fields' as either 
+    # a list or a dictionary
+    def xml_fields (self):
+        fields=self.fields
+        if isinstance(fields,dict): fields=fields.keys()
+        return fields
 
-##############################
-class Type (Base):
-    __table__ = Table ('types', Base.metadata,
-                       Column ('type',String, primary_key=True)
-                       )
-    def __init__ (self, type): self.type=type
-    def __repr__ (self): return "<Type %s>"%self.type
+    def save_as_xml (self):
+        # xxx not sure about the scope here
+        input_dict = dict( [ (key, getattr(self.key), ) for key in self.xml_fields() if getattr(self,key,None) ] )
+        xml_record=XML("<record />")
+        xml_record.parse_dict (input_dict)
+        return xml_record.toxml()
+
+    def dump(self, dump_parents=False):
+        for key in self.fields:
+            if key == 'gid' and self.gid:
+                gid = GID(string=self.gid)
+                print "    %s:" % key
+                gid.dump(8, dump_parents)
+            elif getattr(self,key,None):    
+                print "    %s: %s" % (key, getattr(self,key))
     
-#BUILTIN_TYPES = [ 'authority', 'slice', 'node', 'user' ]
-# xxx for compat but sounds useless
-BUILTIN_TYPES = [ 'authority', 'slice', 'node', 'user',
-                  'authority+sa', 'authority+am', 'authority+sm' ]
-
-def insert_builtin_types(dbsession):
-    for type in BUILTIN_TYPES :
-        count = dbsession.query (Type).filter_by (type=type).count()
-        if count==0:
-            dbsession.add (Type (type))
-    dbsession.commit()
+#    # only intended for debugging 
+#    def inspect (self, logger, message=""):
+#        logger.info("%s -- Inspecting AlchemyObj -- attrs"%message)
+#        for k in dir(self):
+#            if not k.startswith('_'):
+#                logger.info ("  %s: %s"%(k,getattr(self,k)))
+#        logger.info("%s -- Inspecting AlchemyObj -- __dict__"%message)
+#        d=self.__dict__
+#        for (k,v) in d.iteritems():
+#            logger.info("[%s]=%s"%(k,v))
+
 
 ##############################
+# various kinds of records are implemented as an inheritance hierarchy
+# RegRecord is the base class for all actual variants
+# a first draft was using 'type' as the discriminator for the inheritance
+# but we had to define another more internal column (classtype) so we 
+# accomodate variants in types like authority+am and the like
+
 class RegRecord (Base,AlchemyObj):
     # xxx tmp would be 'records'
-    __table__ = Table ('records', Base.metadata,
-                       Column ('record_id', Integer, primary_key=True),
-                       Column ('type', String, ForeignKey ("types.type")),
-                       Column ('hrn',String),
-                       Column ('gid',String),
-                       Column ('authority',String),
-                       Column ('peer_authority',String),
-                       Column ('pointer',Integer,default=-1),
-                       Column ('date_created',DateTime),
-                       Column ('last_updated',DateTime),
-                       )
-    def __init__ (self, type, hrn=None, gid=None, authority=None, peer_authority=None, pointer=-1):
-        self.type=type
-        if hrn: self.hrn=hrn
+    __tablename__       = 'records'
+    record_id           = Column (Integer, primary_key=True)
+    # this is the discriminator that tells which class to use
+    classtype           = Column (String)
+    type                = Column (String)
+    hrn                 = Column (String)
+    gid                 = Column (String)
+    authority           = Column (String)
+    peer_authority      = Column (String)
+    pointer             = Column (Integer, default=-1)
+    date_created        = Column (DateTime)
+    last_updated        = Column (DateTime)
+    # use the 'type' column to decide which subclass the object is of
+    __mapper_args__     = { 'polymorphic_on' : classtype }
+
+    fields = [ 'type', 'hrn', 'gid', 'authority', 'peer_authority' ]
+    def __init__ (self, type=None, hrn=None, gid=None, authority=None, peer_authority=None, 
+                  pointer=None, dict=None):
+        if type:                                self.type=type
+        if hrn:                                 self.hrn=hrn
         if gid: 
-            if isinstance(gid, StringTypes): self.gid=gid
-            else: self.gid=gid.save_to_string(save_parents=True)
-        if authority: self.authority=authority
-        if peer_authority: self.peer_authority=peer_authority
-        self.pointer=pointer
+            if isinstance(gid, StringTypes):    self.gid=gid
+            else:                               self.gid=gid.save_to_string(save_parents=True)
+        if authority:                           self.authority=authority
+        if peer_authority:                      self.peer_authority=peer_authority
+        if pointer:                             self.pointer=pointer
+        if dict:                                self.load_from_dict (dict)
 
     def __repr__(self):
-        result="[Record(record_id=%s, hrn=%s, type=%s, authority=%s, pointer=%s" % \
-                (self.record_id, self.hrn, self.type, self.authority, self.pointer)
-        if self.gid: result+=" %s..."%self.gid[:10]
-        else: result+=" no-gid"
+        result="[Record id=%s, type=%s, hrn=%s, authority=%s, pointer=%s" % \
+                (self.record_id, self.type, self.hrn, self.authority, self.pointer)
+        # skip the uniform '--- BEGIN CERTIFICATE --' stuff
+        if self.gid: result+=" gid=%s..."%self.gid[28:36]
+        else: result+=" nogid"
         result += "]"
         return result
 
+    @validates ('gid')
+    def validate_gid (self, key, gid):
+        if gid is None:                     return
+        elif isinstance(gid, StringTypes):  return gid
+        else:                               return gid.save_to_string(save_parents=True)
+
+    # xxx - there might be smarter ways to handle get/set'ing gid using validation hooks 
     def get_gid_object (self):
         if not self.gid: return None
         else: return GID(string=self.gid)
@@ -128,38 +166,98 @@ class RegRecord (Base,AlchemyObj):
         self.last_updated=now
 
 ##############################
-class User (Base):
-    __table__ = Table ('users', Base.metadata,
-                       Column ('user_id', Integer, primary_key=True),
-                       Column ('record_id',Integer, ForeignKey('records.record_id')),
-                       Column ('email', String),
-                       )
-    def __init__ (self, email):
-        self.email=email
-    def __repr__ (self): return "<User(%d) %s, record_id=%d>"%(self.user_id,self.email,self.record_id,)
-                           
-record_table = RegRecord.__table__
-user_table = User.__table__
-record_user_join = join (record_table, user_table)
-
-class UserRecord (Base):
-    __table__ = record_user_join
-    record_id = column_property (record_table.c.record_id, user_table.c.record_id)
-    user_id = user_table.c.user_id
-    def __init__ (self, gid, email):
-        self.type='user'
-        self.gid=gid
-        self.email=email
-    def __repr__ (self): return "<UserRecord %s %s>"%(self.email,self.gid)
-
-##############################    
+class RegUser (RegRecord):
+    __tablename__       = 'users'
+    # these objects will have type='user' in the records table
+    __mapper_args__     = { 'polymorphic_identity' : 'user' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+    email               = Column ('email', String)
+    
+    # append stuff at the end of the record __repr__
+    def __repr__ (self): 
+        result = RegRecord.__repr__(self).replace("Record","User")
+        result.replace ("]"," email=%s"%self.email)
+        return result
+    
+    @validates('email') 
+    def validate_email(self, key, address):
+        assert '@' in address
+        return address
+
+class RegAuthority (RegRecord):
+    __tablename__       = 'authorities'
+    __mapper_args__     = { 'polymorphic_identity' : 'authority' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+    
+    # no proper data yet, just hack the typename
+    def __repr__ (self):
+        return RegRecord.__repr__(self).replace("Record","Authority")
+
+class RegSlice (RegRecord):
+    __tablename__       = 'slices'
+    __mapper_args__     = { 'polymorphic_identity' : 'slice' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+    
+    def __repr__ (self):
+        return RegRecord.__repr__(self).replace("Record","Slice")
+
+class RegNode (RegRecord):
+    __tablename__       = 'nodes'
+    __mapper_args__     = { 'polymorphic_identity' : 'node' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+    
+    def __repr__ (self):
+        return RegRecord.__repr__(self).replace("Record","Node")
+
+##############################
+# although the db needs of course to be reachable,
+# the schema management functions are here and not in alchemy
+# because the actual details of the classes need to be known
 def init_tables(dbsession):
     logger.info("Initializing db schema and builtin types")
-    engine=dbsession.get_bind()
+    # the doc states we could retrieve the engine this way
+    # engine=dbsession.get_bind()
+    # however I'm getting this
+    # TypeError: get_bind() takes at least 2 arguments (1 given)
+    # so let's import alchemy - but not from toplevel 
+    from sfa.storage.alchemy import engine
     Base.metadata.create_all(engine)
-    insert_builtin_types(dbsession)
 
 def drop_tables(dbsession):
     logger.info("Dropping tables")
-    engine=dbsession.get_bind()
+    # same as for init_tables
+    from sfa.storage.alchemy import engine
     Base.metadata.drop_all(engine)
+
+##############################
+# create a record of the right type from either a dict or an xml string
+def make_record (dict={}, xml=""):
+    if dict:    return make_record_dict (dict)
+    elif xml:   return make_record_xml (xml)
+    else:       raise Exception("make_record has no input")
+
+# convert an incoming record - typically from xmlrpc - into an object
+def make_record_dict (record_dict):
+    assert ('type' in record_dict)
+    type=record_dict['type'].split('+')[0]
+    if type=='authority':
+        result=RegAuthority (dict=record_dict)
+    elif type=='user':
+        result=RegUser (dict=record_dict)
+    elif type=='slice':
+        result=RegSlice (dict=record_dict)
+    elif type=='node':
+        result=RegNode (dict=record_dict)
+    else:
+        result=RegRecord (dict=record_dict)
+    logger.info ("converting dict into Reg* with type=%s"%type)
+    logger.info ("returning=%s"%result)
+    # xxx todo
+    # register non-db attributes in an extensions field
+    return result
+        
+def make_record_xml (xml):
+    xml_record = XML(xml)
+    xml_dict = xml_record.todict()
+    logger.info("load from xml, keys=%s"%xml_dict.keys())
+    return make_record_dict (xml_dict)