Record.todict behaves suspisciously
[sfa.git] / sfa / storage / model.py
index 3be257f..7ce8345 100644 (file)
@@ -1,6 +1,7 @@
 from types import StringTypes
 from datetime import datetime
 
 from types import StringTypes
 from datetime import datetime
 
+from sqlalchemy import or_, and_ 
 from sqlalchemy import Column, Integer, String, DateTime
 from sqlalchemy import Table, Column, MetaData, join, ForeignKey
 from sqlalchemy.orm import relationship, backref
 from sqlalchemy import Column, Integer, String, DateTime
 from sqlalchemy import Table, Column, MetaData, join, ForeignKey
 from sqlalchemy.orm import relationship, backref
@@ -17,7 +18,7 @@ from sfa.util.xml import XML
 from sfa.trust.gid import GID
 
 ##############################
 from sfa.trust.gid import GID
 
 ##############################
-Base=declarative_base()
+Base = declarative_base()
 
 ####################
 # dicts vs objects
 
 ####################
 # dicts vs objects
@@ -74,7 +75,7 @@ class AlchemyObj(Record):
 # but we had to define another more internal column (classtype) so we 
 # accomodate variants in types like authority+am and the like
 
 # but we had to define another more internal column (classtype) so we 
 # accomodate variants in types like authority+am and the like
 
-class RegRecord (Base,AlchemyObj):
+class RegRecord (Base, AlchemyObj):
     __tablename__       = 'records'
     record_id           = Column (Integer, primary_key=True)
     # this is the discriminator that tells which class to use
     __tablename__       = 'records'
     record_id           = Column (Integer, primary_key=True)
     # this is the discriminator that tells which class to use
@@ -106,11 +107,19 @@ class RegRecord (Base,AlchemyObj):
         if dict:                                self.load_from_dict (dict)
 
     def __repr__(self):
         if dict:                                self.load_from_dict (dict)
 
     def __repr__(self):
-        result="<Record id=%s, type=%s, hrn=%s, authority=%s, pointer=%s" % \
-                (self.record_id, self.type, self.hrn, self.authority, self.pointer)
+        result="<Record id=%s, type=%s, hrn=%s, authority=%s" % \
+                (self.record_id, self.type, self.hrn, self.authority)
+#        for extra in ('pointer', 'email', 'name'):
+#        for extra in ('email', 'name'):
+# displaying names at this point it too dangerous, because of unicode
+        for extra in ('email'):
+            if hasattr(self, extra):
+                result += " {}={},".format(extra, getattr(self, extra))
         # skip the uniform '--- BEGIN CERTIFICATE --' stuff
         # skip the uniform '--- BEGIN CERTIFICATE --' stuff
-        if self.gid: result+=" gid=%s..."%self.gid[28:36]
-        else: result+=" nogid"
+        if self.gid:
+            result+=" gid=%s..."%self.gid[28:36]
+        else:
+            result+=" nogid"
         result += ">"
         return result
 
         result += ">"
         return result
 
@@ -125,30 +134,35 @@ class RegRecord (Base,AlchemyObj):
         else:                               return gid.save_to_string(save_parents=True)
 
     def validate_datetime (self, key, incoming):
         else:                               return gid.save_to_string(save_parents=True)
 
     def validate_datetime (self, key, incoming):
-        if isinstance (incoming, datetime):     return incoming
-        elif isinstance (incoming, (int,float)):return datetime.fromtimestamp (incoming)
-        else: logger.info("Cannot validate datetime for key %s with input %s"%\
-                              (key,incoming))
+        if isinstance (incoming, datetime):
+            return incoming
+        elif isinstance (incoming, (int, float)):
+            return datetime.fromtimestamp (incoming)
+        else:
+            logger.info("Cannot validate datetime for key %s with input %s"%\
+                        (key,incoming))
 
     @validates ('date_created')
 
     @validates ('date_created')
-    def validate_date_created (self, key, incoming): return self.validate_datetime (key, incoming)
+    def validate_date_created (self, key, incoming):
+        return self.validate_datetime (key, incoming)
 
     @validates ('last_updated')
 
     @validates ('last_updated')
-    def validate_last_updated (self, key, incoming): return self.validate_datetime (key, incoming)
+    def validate_last_updated (self, key, incoming):
+        return self.validate_datetime (key, incoming)
 
     # xxx - there might be smarter ways to handle get/set'ing gid using validation hooks 
     def get_gid_object (self):
 
     # xxx - there might be smarter ways to handle get/set'ing gid using validation hooks 
     def get_gid_object (self):
-        if not self.gid: return None
-        else: return GID(string=self.gid)
+        if not self.gid:        return None
+        else:                   return GID(string=self.gid)
 
     def just_created (self):
 
     def just_created (self):
-        now=datetime.now()
-        self.date_created=now
-        self.last_updated=now
+        now = datetime.utcnow()
+        self.date_created = now
+        self.last_updated = now
 
     def just_updated (self):
 
     def just_updated (self):
-        now=datetime.now()
-        self.last_updated=now
+        now = datetime.utcnow()
+        self.last_updated = now
 
 #################### cross-relations tables
 # authority x user (pis) association
 
 #################### cross-relations tables
 # authority x user (pis) association
@@ -173,31 +187,40 @@ class RegAuthority (RegRecord):
     __mapper_args__     = { 'polymorphic_identity' : 'authority' }
     record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
     #### extensions come here
     __mapper_args__     = { 'polymorphic_identity' : 'authority' }
     record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
     #### extensions come here
+    name                = Column ('name', String)
+    #### extensions come here
     reg_pis             = relationship \
         ('RegUser',
     reg_pis             = relationship \
         ('RegUser',
-         secondary=authority_pi_table,
-         primaryjoin=RegRecord.record_id==authority_pi_table.c.authority_id,
-         secondaryjoin=RegRecord.record_id==authority_pi_table.c.pi_id,
-         backref='reg_authorities_as_pi')
+         secondary = authority_pi_table,
+         primaryjoin = RegRecord.record_id==authority_pi_table.c.authority_id,
+         secondaryjoin = RegRecord.record_id==authority_pi_table.c.pi_id,
+         backref = 'reg_authorities_as_pi',
+        )
     
     def __init__ (self, **kwds):
     
     def __init__ (self, **kwds):
+        # handle local settings
+        if 'name' in kwds:
+            self.name = kwds.pop('name')
         # fill in type if not previously set
         # fill in type if not previously set
-        if 'type' not in kwds: kwds['type']='authority'
+        if 'type' not in kwds:
+            kwds['type']='authority'
         # base class constructor
         RegRecord.__init__(self, **kwds)
 
     # no proper data yet, just hack the typename
     def __repr__ (self):
         # base class constructor
         RegRecord.__init__(self, **kwds)
 
     # no proper data yet, just hack the typename
     def __repr__ (self):
-        return RegRecord.__repr__(self).replace("Record","Authority")
+        result = RegRecord.__repr__(self).replace("Record", "Authority")
+# here again trying to display names that can be utf8 is too dangerous        
+#        result.replace(">", " name={}>".format(self.name))
+        return result
 
 
-    def update_pis (self, pi_hrns):
-        # don't ruin the import of that file in a client world
-        from sfa.storage.alchemy import dbsession
+    def update_pis (self, pi_hrns, dbsession):
         # strip that in case we have <researcher> words </researcher>
         pi_hrns = [ x.strip() for x in pi_hrns ]
         # strip that in case we have <researcher> words </researcher>
         pi_hrns = [ x.strip() for x in pi_hrns ]
-        request = dbsession.query (RegUser).filter(RegUser.hrn.in_(pi_hrns))
-        logger.info ("RegAuthority.update_pis: %d incoming pis, %d matches found"%(len(pi_hrns),request.count()))
-        pis = dbsession.query (RegUser).filter(RegUser.hrn.in_(pi_hrns)).all()
+        request = dbsession.query(RegUser).filter(RegUser.hrn.in_(pi_hrns))
+        logger.info("RegAuthority.update_pis: %d incoming pis, %d matches found"\
+                    % (len(pi_hrns), request.count()))
+        pis = dbsession.query(RegUser).filter(RegUser.hrn.in_(pi_hrns)).all()
         self.reg_pis = pis
 
 ####################
         self.reg_pis = pis
 
 ####################
@@ -211,36 +234,41 @@ class RegSlice (RegRecord):
          secondary=slice_researcher_table,
          primaryjoin=RegRecord.record_id==slice_researcher_table.c.slice_id,
          secondaryjoin=RegRecord.record_id==slice_researcher_table.c.researcher_id,
          secondary=slice_researcher_table,
          primaryjoin=RegRecord.record_id==slice_researcher_table.c.slice_id,
          secondaryjoin=RegRecord.record_id==slice_researcher_table.c.researcher_id,
-         backref='reg_slices_as_researcher')
+         backref='reg_slices_as_researcher',
+        )
 
     def __init__ (self, **kwds):
 
     def __init__ (self, **kwds):
-        if 'type' not in kwds: kwds['type']='slice'
+        if 'type' not in kwds:
+            kwds['type']='slice'
         RegRecord.__init__(self, **kwds)
 
     def __repr__ (self):
         RegRecord.__init__(self, **kwds)
 
     def __repr__ (self):
-        return RegRecord.__repr__(self).replace("Record","Slice")
+        return RegRecord.__repr__(self).replace("Record", "Slice")
 
 
-    def update_researchers (self, researcher_hrns):
-        # don't ruin the import of that file in a client world
-        from sfa.storage.alchemy import dbsession
+    def update_researchers (self, researcher_hrns, dbsession):
         # strip that in case we have <researcher> words </researcher>
         researcher_hrns = [ x.strip() for x in researcher_hrns ]
         request = dbsession.query (RegUser).filter(RegUser.hrn.in_(researcher_hrns))
         # strip that in case we have <researcher> words </researcher>
         researcher_hrns = [ x.strip() for x in researcher_hrns ]
         request = dbsession.query (RegUser).filter(RegUser.hrn.in_(researcher_hrns))
-        logger.info ("RegSlice.update_researchers: %d incoming researchers, %d matches found"%(len(researcher_hrns),request.count()))
+        logger.info ("RegSlice.update_researchers: %d incoming researchers, %d matches found"\
+                     % (len(researcher_hrns), request.count()))
         researchers = dbsession.query (RegUser).filter(RegUser.hrn.in_(researcher_hrns)).all()
         self.reg_researchers = researchers
 
     # when dealing with credentials, we need to retrieve the PIs attached to a slice
         researchers = dbsession.query (RegUser).filter(RegUser.hrn.in_(researcher_hrns)).all()
         self.reg_researchers = researchers
 
     # when dealing with credentials, we need to retrieve the PIs attached to a slice
+    # WARNING: with the move to passing dbsessions around, we face a glitch here because this
+    # helper function is called from the trust/ area that
     def get_pis (self):
     def get_pis (self):
-        # don't ruin the import of that file in a client world
-        from sfa.storage.alchemy import dbsession
+        from sqlalchemy.orm import sessionmaker
+        Session = sessionmaker()
+        dbsession = Session.object_session(self)
         from sfa.util.xrn import get_authority
         authority_hrn = get_authority(self.hrn)
         auth_record = dbsession.query(RegAuthority).filter_by(hrn=authority_hrn).first()
         return auth_record.reg_pis
         
     @validates ('expires')
         from sfa.util.xrn import get_authority
         authority_hrn = get_authority(self.hrn)
         auth_record = dbsession.query(RegAuthority).filter_by(hrn=authority_hrn).first()
         return auth_record.reg_pis
         
     @validates ('expires')
-    def validate_expires (self, key, incoming): return self.validate_datetime (key, incoming)
+    def validate_expires (self, key, incoming):
+        return self.validate_datetime (key, incoming)
 
 ####################
 class RegNode (RegRecord):
 
 ####################
 class RegNode (RegRecord):
@@ -248,12 +276,13 @@ class RegNode (RegRecord):
     __mapper_args__     = { 'polymorphic_identity' : 'node' }
     record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
     
     __mapper_args__     = { 'polymorphic_identity' : 'node' }
     record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
     
-    def __init__ (self, **kwds):
-        if 'type' not in kwds: kwds['type']='node'
+    def __init__(self, **kwds):
+        if 'type' not in kwds:
+            kwds['type']='node'
         RegRecord.__init__(self, **kwds)
 
     def __repr__ (self):
         RegRecord.__init__(self, **kwds)
 
     def __repr__ (self):
-        return RegRecord.__repr__(self).replace("Record","Node")
+        return RegRecord.__repr__(self).replace("Record", "Node")
 
 ####################
 class RegUser (RegRecord):
 
 ####################
 class RegUser (RegRecord):
@@ -267,20 +296,22 @@ class RegUser (RegRecord):
     # a 'keys' tag, and assigning a list of strings in a reference column like this crashes
     reg_keys            = relationship \
         ('RegKey', backref='reg_user',
     # a 'keys' tag, and assigning a list of strings in a reference column like this crashes
     reg_keys            = relationship \
         ('RegKey', backref='reg_user',
-         cascade="all, delete, delete-orphan")
+         cascade = "all, delete, delete-orphan",
+        )
     
     # so we can use RegUser (email=.., hrn=..) and the like
     def __init__ (self, **kwds):
         # handle local settings
     
     # so we can use RegUser (email=.., hrn=..) and the like
     def __init__ (self, **kwds):
         # handle local settings
-        if 'email' in kwds: self.email=kwds.pop('email')
-        if 'type' not in kwds: kwds['type']='user'
+        if 'email' in kwds:
+            self.email = kwds.pop('email')
+        if 'type' not in kwds:
+            kwds['type'] = 'user'
         RegRecord.__init__(self, **kwds)
 
     # append stuff at the end of the record __repr__
     def __repr__ (self): 
         RegRecord.__init__(self, **kwds)
 
     # append stuff at the end of the record __repr__
     def __repr__ (self): 
-        result = RegRecord.__repr__(self).replace("Record","User")
-        result.replace (">"," email=%s"%self.email)
-        result += ">"
+        result = RegRecord.__repr__(self).replace("Record", "User")
+        result.replace(">", " email={}>".format(self.email))
         return result
 
     @validates('email') 
         return result
 
     @validates('email') 
@@ -296,21 +327,103 @@ class RegUser (RegRecord):
 class RegKey (Base):
     __tablename__       = 'keys'
     key_id              = Column (Integer, primary_key=True)
 class RegKey (Base):
     __tablename__       = 'keys'
     key_id              = Column (Integer, primary_key=True)
-    record_id             = Column (Integer, ForeignKey ("records.record_id"))
+    record_id           = Column (Integer, ForeignKey ("records.record_id"))
     key                 = Column (String)
     pointer             = Column (Integer, default = -1)
     
     def __init__ (self, key, pointer=None):
     key                 = Column (String)
     pointer             = Column (Integer, default = -1)
     
     def __init__ (self, key, pointer=None):
-        self.key=key
-        if pointer: self.pointer=pointer
+        self.key = key
+        if pointer:
+            self.pointer = pointer
 
     def __repr__ (self):
 
     def __repr__ (self):
-        result="<key id=%s key=%s..."%(self.key_id,self.key[8:16],)
-        try:    result += " user=%s"%self.reg_user.record_id
+        result = "<key id=%s key=%s..." % (self.key_id, self.key[8:16],)
+        try:    result += " user=%s" % self.reg_user.record_id
         except: result += " no-user"
         result += ">"
         return result
 
         except: result += " no-user"
         result += ">"
         return result
 
+class SliverAllocation(Base,AlchemyObj):
+    __tablename__       = 'sliver_allocation'
+    sliver_id           = Column(String, primary_key=True)
+    client_id           = Column(String)
+    component_id        = Column(String)
+    slice_urn           = Column(String)
+    allocation_state    = Column(String)
+
+    def __init__(self, **kwds):
+        if 'sliver_id' in kwds:
+            self.sliver_id = kwds['sliver_id']
+        if 'client_id' in kwds:
+            self.client_id = kwds['client_id']
+        if 'component_id' in kwds:
+            self.component_id = kwds['component_id']
+        if 'slice_urn' in kwds:
+            self.slice_urn = kwds['slice_urn']
+        if 'allocation_state' in kwds:
+            self.allocation_state = kwds['allocation_state']
+
+    def __repr__(self):
+        result = "<sliver_allocation sliver_id=%s allocation_state=%s"\
+                 % (self.sliver_id, self.allocation_state)
+        return result
+
+    @validates('allocation_state')
+    def validate_allocation_state(self, key, state):
+        allocation_states = ['geni_unallocated', 'geni_allocated', 'geni_provisioned']
+        assert state in allocation_states
+        return state
+
+    @staticmethod    
+    def set_allocations(sliver_ids, state, dbsession):
+        if not isinstance(sliver_ids, list):
+            sliver_ids = [sliver_ids]
+        sliver_state_updated = {}
+        constraint = SliverAllocation.sliver_id.in_(sliver_ids)
+        sliver_allocations = dbsession.query (SliverAllocation).filter(constraint)
+        sliver_ids_found = []
+        for sliver_allocation in sliver_allocations:
+            sliver_allocation.allocation_state = state
+            sliver_ids_found.append(sliver_allocation.sliver_id)
+
+        # Some states may not have been updated becuase no sliver allocation state record
+        # exists for the sliver. Insert new allocation records for these slivers and set
+        # it to geni_allocated.
+        sliver_ids_not_found = set(sliver_ids).difference(sliver_ids_found)
+        for sliver_id in sliver_ids_not_found:
+            record = SliverAllocation(sliver_id=sliver_id, allocation_state=state)
+            dbsession.add(record)
+        dbsession.commit()
+
+    @staticmethod
+    def delete_allocations(sliver_ids, dbsession):
+        if not isinstance(sliver_ids, list):
+            sliver_ids = [sliver_ids]
+        constraint = SliverAllocation.sliver_id.in_(sliver_ids)
+        sliver_allocations = dbsession.query(SliverAllocation).filter(constraint)
+        for sliver_allocation in sliver_allocations:
+            dbsession.delete(sliver_allocation)
+        dbsession.commit()
+    
+    def sync(self, dbsession):
+        constraints = [SliverAllocation.sliver_id == self.sliver_id]
+        results = dbsession.query(SliverAllocation).filter(and_(*constraints))
+        records = []
+        for result in results:
+            records.append(result) 
+        
+        if not records:
+            dbsession.add(self)
+        else:
+            record = records[0]
+            record.sliver_id = self.sliver_id
+            record.client_id  = self.client_id
+            record.component_id  = self.component_id
+            record.slice_urn  = self.slice_urn
+            record.allocation_state = self.allocation_state
+        dbsession.commit()    
+        
+
 ##############################
 # although the db needs of course to be reachable for the following functions
 # the schema management functions are here and not in alchemy
 ##############################
 # although the db needs of course to be reachable for the following functions
 # the schema management functions are here and not in alchemy
@@ -329,7 +442,8 @@ def drop_tables(engine):
 
 ##############################
 # create a record of the right type from either a dict or an xml string
 
 ##############################
 # create a record of the right type from either a dict or an xml string
-def make_record (dict={}, xml=""):
+def make_record (dict=None, xml=""):
+    if dict is None: dict={}
     if dict:    return make_record_dict (dict)
     elif xml:   return make_record_xml (xml)
     else:       raise Exception("make_record has no input")
     if dict:    return make_record_dict (dict)
     elif xml:   return make_record_xml (xml)
     else:       raise Exception("make_record has no input")
@@ -337,27 +451,27 @@ def make_record (dict={}, xml=""):
 # convert an incoming record - typically from xmlrpc - into an object
 def make_record_dict (record_dict):
     assert ('type' in record_dict)
 # convert an incoming record - typically from xmlrpc - into an object
 def make_record_dict (record_dict):
     assert ('type' in record_dict)
-    type=record_dict['type'].split('+')[0]
-    if type=='authority':
-        result=RegAuthority (dict=record_dict)
-    elif type=='user':
-        result=RegUser (dict=record_dict)
-    elif type=='slice':
-        result=RegSlice (dict=record_dict)
-    elif type=='node':
-        result=RegNode (dict=record_dict)
+    type = record_dict['type'].split('+')[0]
+    if type == 'authority':
+        result = RegAuthority (dict=record_dict)
+    elif type == 'user':
+        result = RegUser (dict=record_dict)
+    elif type == 'slice':
+        result = RegSlice (dict=record_dict)
+    elif type == 'node':
+        result = RegNode (dict=record_dict)
     else:
         logger.debug("Untyped RegRecord instance")
     else:
         logger.debug("Untyped RegRecord instance")
-        result=RegRecord (dict=record_dict)
-    logger.info ("converting dict into Reg* with type=%s"%type)
-    logger.info ("returning=%s"%result)
+        result = RegRecord (dict=record_dict)
+    logger.info("converting dict into Reg* with type=%s"%type)
+    logger.info("returning=%s"%result)
     # xxx todo
     # register non-db attributes in an extensions field
     return result
         
     # xxx todo
     # register non-db attributes in an extensions field
     return result
         
-def make_record_xml (xml):
-    xml_record = XML(xml)
-    xml_dict = xml_record.todict()
+def make_record_xml (xml_str):
+    xml = XML(xml_str)
+    xml_dict = xml.todict()
     logger.info("load from xml, keys=%s"%xml_dict.keys())
     return make_record_dict (xml_dict)
 
     logger.info("load from xml, keys=%s"%xml_dict.keys())
     return make_record_dict (xml_dict)
 
@@ -368,27 +482,28 @@ def make_record_xml (xml):
 # were the relationships data came from the testbed side
 # for each type, a dict of the form {<field-name-exposed-in-record>:<alchemy_accessor_name>}
 # so after that, an 'authority' record will e.g. have a 'reg-pis' field with the hrns of its pi-users
 # were the relationships data came from the testbed side
 # for each type, a dict of the form {<field-name-exposed-in-record>:<alchemy_accessor_name>}
 # so after that, an 'authority' record will e.g. have a 'reg-pis' field with the hrns of its pi-users
-augment_map={'authority': {'reg-pis':'reg_pis',},
-             'slice': {'reg-researchers':'reg_researchers',},
-             'user': {'reg-pi-authorities':'reg_authorities_as_pi',
-                      'reg-slices':'reg_slices_as_researcher',},
-             }
+augment_map = {'authority': {'reg-pis' : 'reg_pis',},
+               'slice': {'reg-researchers' : 'reg_researchers',},
+               'user': {'reg-pi-authorities' : 'reg_authorities_as_pi',
+                        'reg-slices' : 'reg_slices_as_researcher',},
+           }
+
 
 
-def augment_with_sfa_builtins (local_record):
+def augment_with_sfa_builtins(local_record):
     # don't ruin the import of that file in a client world
     from sfa.util.xrn import Xrn
     # add a 'urn' field
     # don't ruin the import of that file in a client world
     from sfa.util.xrn import Xrn
     # add a 'urn' field
-    setattr(local_record,'reg-urn',Xrn(xrn=local_record.hrn,type=local_record.type).urn)
+    setattr(local_record, 'reg-urn', Xrn(xrn=local_record.hrn, type=local_record.type).urn)
     # users have keys and this is needed to synthesize 'users' sent over to CreateSliver
     # users have keys and this is needed to synthesize 'users' sent over to CreateSliver
-    if local_record.type=='user':
+    if local_record.type == 'user':
         user_keys = [ key.key for key in local_record.reg_keys ]
         setattr(local_record, 'reg-keys', user_keys)
     # search in map according to record type
         user_keys = [ key.key for key in local_record.reg_keys ]
         setattr(local_record, 'reg-keys', user_keys)
     # search in map according to record type
-    type_map=augment_map.get(local_record.type,{})
+    type_map = augment_map.get(local_record.type, {})
     # use type-dep. map to do the job
     # use type-dep. map to do the job
-    for (field_name,attribute) in type_map.items():
+    for (field_name, attribute) in type_map.items():
         # get related objects
         # get related objects
-        related_records = getattr(local_record,attribute,[])
+        related_records = getattr(local_record, attribute, [])
         hrns = [ r.hrn for r in related_records ]
         setattr (local_record, field_name, hrns)
     
         hrns = [ r.hrn for r in related_records ]
         setattr (local_record, field_name, hrns)