- subclass the correct class
[plcapi.git] / PLC / Sites.py
index aac4a6d..479c0fb 100644 (file)
@@ -3,14 +3,14 @@ import string
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Slices import Slice, Slices
 from PLC.PCUs import PCU, PCUs
 from PLC.Nodes import Node, Nodes
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Slices import Slice, Slices
 from PLC.PCUs import PCU, PCUs
 from PLC.Nodes import Node, Nodes
-from PLC.NodeGroups import NodeGroup, NodeGroups
 from PLC.Addresses import Address, Addresses
 from PLC.Addresses import Address, Addresses
-import PLC.Persons
+from PLC.Persons import Person, Persons
 
 class Site(Row):
     """
 
 class Site(Row):
     """
@@ -19,39 +19,59 @@ class Site(Row):
     dict. Commit to the database with sync().
     """
 
     dict. Commit to the database with sync().
     """
 
+    table_name = 'sites'
+    primary_key = 'site_id'
+    join_tables = ['person_site', 'site_address', 'peer_site']
     fields = {
         'site_id': Parameter(int, "Site identifier"),
         'name': Parameter(str, "Full site name", max = 254),
         'abbreviated_name': Parameter(str, "Abbreviated site name", max = 50),
         'login_base': Parameter(str, "Site slice prefix", max = 20),
         'is_public': Parameter(bool, "Publicly viewable site"),
     fields = {
         'site_id': Parameter(int, "Site identifier"),
         'name': Parameter(str, "Full site name", max = 254),
         'abbreviated_name': Parameter(str, "Abbreviated site name", max = 50),
         'login_base': Parameter(str, "Site slice prefix", max = 20),
         'is_public': Parameter(bool, "Publicly viewable site"),
-        'latitude': Parameter(float, "Decimal latitude of the site", min = -90.0, max = 90.0),
-        'longitude': Parameter(float, "Decimal longitude of the site", min = -180.0, max = 180.0),
-        'url': Parameter(str, "URL of a page that describes the site", max = 254),
-        'date_created': Parameter(int, "Date and time when site entry was created, in seconds since UNIX epoch"),
-        'last_updated': Parameter(int, "Date and time when site entry was last updated, in seconds since UNIX epoch"),
-        'deleted': Parameter(bool, "Has been deleted"),
+        'latitude': Parameter(float, "Decimal latitude of the site", min = -90.0, max = 90.0, nullok = True),
+        'longitude': Parameter(float, "Decimal longitude of the site", min = -180.0, max = 180.0, nullok = True),
+        'url': Parameter(str, "URL of a page that describes the site", max = 254, nullok = True),
+        'date_created': Parameter(int, "Date and time when site entry was created, in seconds since UNIX epoch", ro = True),
+        'last_updated': Parameter(int, "Date and time when site entry was last updated, in seconds since UNIX epoch", ro = True),
         'max_slices': Parameter(int, "Maximum number of slices that the site is able to create"),
         'max_slivers': Parameter(int, "Maximum number of slivers that the site is able to create"),
         'person_ids': Parameter([int], "List of account identifiers"),
         'slice_ids': Parameter([int], "List of slice identifiers"),
         'address_ids': Parameter([int], "List of address identifiers"),
         'max_slices': Parameter(int, "Maximum number of slices that the site is able to create"),
         'max_slivers': Parameter(int, "Maximum number of slivers that the site is able to create"),
         'person_ids': Parameter([int], "List of account identifiers"),
         'slice_ids': Parameter([int], "List of slice identifiers"),
         'address_ids': Parameter([int], "List of address identifiers"),
-        'pcu_ids': Parameter([int], "List of PCU identifiers"),
+        'pcu_ids': Parameter([int], "List of PCU identifiers"),
         'node_ids': Parameter([int], "List of site node identifiers"),
         'node_ids': Parameter([int], "List of site node identifiers"),
+        'peer_id': Parameter(int, "Peer to which this site belongs", nullok = True),
+        'peer_site_id': Parameter(int, "Foreign site identifier at peer", nullok = True),
         }
 
         }
 
-    def __init__(self, api, fields):
-        Row.__init__(self, fields)
-        self.api = api
+    # for Cache
+    class_key = 'login_base'
+    foreign_fields = ['abbreviated_name', 'name', 'is_public', 'latitude', 'longitude',
+                     'url', 'max_slices', 'max_slivers',
+                     ]
+    # forget about these ones, they are read-only anyway
+    # handling them causes Cache to re-sync all over again 
+    # 'last_updated', 'date_created'
+    foreign_xrefs = []
+
+    def validate_name(self, name):
+        if not len(name):
+            raise PLCInvalidArgument, "Name must be specified"
+
+        return name
+
+    validate_abbreviated_name = validate_name
 
     def validate_login_base(self, login_base):
 
     def validate_login_base(self, login_base):
-        if not set(login_base).issubset(string.ascii_letters):
-            raise PLCInvalidArgument, "Login base must consist only of ASCII letters"
+        if not len(login_base):
+            raise PLCInvalidArgument, "Login base must be specified"
+
+        if not set(login_base).issubset(string.ascii_letters.lower()):
+            raise PLCInvalidArgument, "Login base must consist only of lowercase ASCII letters"
 
 
-        login_base = login_base.lower()
         conflicts = Sites(self.api, [login_base])
         conflicts = Sites(self.api, [login_base])
-        for site_id, site in conflicts.iteritems():
-            if not site['deleted'] and ('site_id' not in self or self['site_id'] != site_id):
+        for site in conflicts:
+            if 'site_id' not in self or self['site_id'] != site['site_id']:
                 raise PLCInvalidArgument, "login_base already in use"
 
         return login_base
                 raise PLCInvalidArgument, "login_base already in use"
 
         return login_base
@@ -70,107 +90,14 @@ class Site(Row):
 
         return longitude
 
 
         return longitude
 
-    def add_person(self, person, commit = True):
-        """
-        Add person to existing site.
-        """
-
-        assert 'site_id' in self
-        assert isinstance(person, PLC.Persons.Person)
-        assert 'person_id' in person
-
-        site_id = self['site_id']
-        person_id = person['person_id']
-        self.api.db.do("INSERT INTO person_site (person_id, site_id)" \
-                       " VALUES(%(person_id)d, %(site_id)d)",
-                       locals())
-
-        if commit:
-            self.api.db.commit()
-
-        if 'person_ids' in self and person_id not in self['person_ids']:
-            self['person_ids'].append(person_id)
-
-        if 'site_ids' in person and site_id not in person['site_ids']:
-            person['site_ids'].append(site_id)
+    validate_date_created = Row.validate_timestamp
+    validate_last_updated = Row.validate_timestamp
 
 
-    def remove_person(self, person, commit = True):
-        """
-        Remove person from existing site.
-        """
-
-        assert 'site_id' in self
-        assert isinstance(person, PLC.Persons.Person)
-        assert 'person_id' in person
-
-        site_id = self['site_id']
-        person_id = person['person_id']
-        self.api.db.do("DELETE FROM person_site" \
-                       " WHERE person_id = %(person_id)d" \
-                       " AND site_id = %(site_id)d",
-                       locals())
-
-        if commit:
-            self.api.db.commit()
-
-        if 'person_ids' in self and person_id in self['person_ids']:
-            self['person_ids'].remove(person_id)
-
-        if 'site_ids' in person and site_id in person['site_ids']:
-            person['site_ids'].remove(site_id)
-
-    def sync(self, commit = True):
-        """
-        Flush changes back to the database.
-        """
+    add_person = Row.add_object(Person, 'person_site')
+    remove_person = Row.remove_object(Person, 'person_site')
 
 
-        self.validate()
-
-        try:
-            if not self['name'] or \
-               not self['abbreviated_name'] or \
-               not self['login_base']:
-                raise KeyError
-        except KeyError:
-            raise PLCInvalidArgument, "name, abbreviated_name, and login_base must all be specified"
-
-        # Fetch a new site_id if necessary
-        if 'site_id' not in self:
-            rows = self.api.db.selectall("SELECT NEXTVAL('sites_site_id_seq') AS site_id")
-            if not rows:
-                raise PLCDBError, "Unable to fetch new site_id"
-            self['site_id'] = rows[0]['site_id']
-            insert = True
-        else:
-            insert = False
-
-        # Filter out fields that cannot be set or updated directly
-        sites_fields = self.api.db.fields('sites')
-        fields = dict(filter(lambda (key, value): key in sites_fields,
-                             self.items()))
-        for ro_field in 'date_created', 'last_updated':
-            if ro_field in fields:
-                del fields[ro_field]
-
-        # Parameterize for safety
-        keys = fields.keys()
-        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
-
-        if insert:
-            # Insert new row in sites table
-            sql = "INSERT INTO sites (%s) VALUES (%s)" % \
-                  (", ".join(keys), ", ".join(values))
-        else:
-            # Update existing row in sites table
-            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
-            sql = "UPDATE sites SET " + \
-                  ", ".join(columns) + \
-                  " WHERE site_id = %(site_id)d"
-
-        self.api.db.do(sql, fields)
-
-        if commit:
-            self.api.db.commit()
+    add_address = Row.add_object(Address, 'site_address')
+    remove_address = Row.remove_object(Address, 'site_address')
 
     def delete(self, commit = True):
         """
 
     def delete(self, commit = True):
         """
@@ -181,14 +108,13 @@ class Site(Row):
 
         # Delete accounts of all people at the site who are not
         # members of at least one other non-deleted site.
 
         # Delete accounts of all people at the site who are not
         # members of at least one other non-deleted site.
-        persons = PLC.Persons.Persons(self.api, self['person_ids'])
-        for person_id, person in persons.iteritems():
+        persons = Persons(self.api, self['person_ids'])
+        for person in persons:
             delete = True
 
             person_sites = Sites(self.api, person['site_ids'])
             delete = True
 
             person_sites = Sites(self.api, person['site_ids'])
-            for person_site_id, person_site in person_sites.iteritems():
-                if person_site_id != self['site_id'] and \
-                   not person_site['deleted']:
+            for person_site in person_sites:
+                if person_site['site_id'] != self['site_id']:
                     delete = False
                     break
 
                     delete = False
                     break
 
@@ -197,29 +123,28 @@ class Site(Row):
 
         # Delete all site addresses
         addresses = Addresses(self.api, self['address_ids'])
 
         # Delete all site addresses
         addresses = Addresses(self.api, self['address_ids'])
-        for address in addresses.values():
-           address.delete(commit = False)
+        for address in addresses:
+            address.delete(commit = False)
 
         # Delete all site slices
         slices = Slices(self.api, self['slice_ids'])
 
         # Delete all site slices
         slices = Slices(self.api, self['slice_ids'])
-        for slice in slices.values():
-           slice.delete(commit = False)
+        for slice in slices:
+            slice.delete(commit = False)
 
         # Delete all site PCUs
 
         # Delete all site PCUs
-        pcus = PCUs(self.api, self['pcu_ids'])
-        # for pcu in pcus.values():
-        #    pcu.delete(commit = False)
+        pcus = PCUs(self.api, self['pcu_ids'])
+        for pcu in pcus:
+            pcu.delete(commit = False)
 
         # Delete all site nodes
         nodes = Nodes(self.api, self['node_ids'])
 
         # Delete all site nodes
         nodes = Nodes(self.api, self['node_ids'])
-        for node in nodes.values():
+        for node in nodes:
             node.delete(commit = False)
 
         # Clean up miscellaneous join tables
             node.delete(commit = False)
 
         # Clean up miscellaneous join tables
-        for table in ['person_site']:
-            self.api.db.do("DELETE FROM %s" \
-                           " WHERE site_id = %d" % \
-                           (table, self['site_id']), self)
+        for table in self.join_tables:
+            self.api.db.do("DELETE FROM %s WHERE site_id = %d" % \
+                           (table, self['site_id']))
 
         # Mark as deleted
         self['deleted'] = True
 
         # Mark as deleted
         self['deleted'] = True
@@ -228,36 +153,24 @@ class Site(Row):
 class Sites(Table):
     """
     Representation of row(s) from the sites table in the
 class Sites(Table):
     """
     Representation of row(s) from the sites table in the
-    database. Specify fields to limit columns to just the specified
-    fields.
+    database.
     """
 
     """
 
-    def __init__(self, api, site_id_or_login_base_list = None, fields = Site.fields):
-        self.api = api
+    def __init__(self, api, site_filter = None, columns = None):
+        Table.__init__(self, api, Site, columns)
 
         sql = "SELECT %s FROM view_sites WHERE deleted IS False" % \
 
         sql = "SELECT %s FROM view_sites WHERE deleted IS False" % \
-              ", ".join(fields)
-
-        if site_id_or_login_base_list:
-            # Separate the list into integers and strings
-            site_ids = filter(lambda site_id: isinstance(site_id, (int, long)),
-                              site_id_or_login_base_list)
-            login_bases = filter(lambda login_base: isinstance(login_base, StringTypes),
-                                 site_id_or_login_base_list)
-            sql += " AND (False"
-            if site_ids:
-                sql += " OR site_id IN (%s)" % ", ".join(map(str, site_ids))
-            if login_bases:
-                sql += " OR login_base IN (%s)" % ", ".join(api.db.quote(login_bases))
-            sql += ")"
-
-        rows = self.api.db.selectall(sql)
-
-        for row in rows:
-            self[row['site_id']] = site = Site(api, row)
-            for aggregate in ['person_ids', 'slice_ids', 'address_ids',
-                              'pcu_ids', 'node_ids']:
-                if not site.has_key(aggregate) or site[aggregate] is None:
-                    site[aggregate] = []
-                else:
-                    site[aggregate] = map(int, site[aggregate].split(','))
+              ", ".join(self.columns)
+
+        if site_filter is not None:
+            if isinstance(site_filter, (list, tuple, set)):
+                # Separate the list into integers and strings
+                ints = filter(lambda x: isinstance(x, (int, long)), site_filter)
+                strs = filter(lambda x: isinstance(x, StringTypes), site_filter)
+                site_filter = Filter(Site.fields, {'site_id': ints, 'login_base': strs})
+                sql += " AND (%s)" % site_filter.sql(api, "OR")
+            elif isinstance(site_filter, dict):
+                site_filter = Filter(Site.fields, site_filter)
+                sql += " AND (%s)" % site_filter.sql(api, "AND")
+
+        self.selectall(sql)