trash old code
[sfa.git] / sfa / importer / plimporter.py
1 import os
2
3 from sfa.util.config import Config
4 from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn
5 from sfa.util.plxrn import hostname_to_hrn, slicename_to_hrn, email_to_hrn, hrn_to_pl_slicename
6
7 from sfa.trust.gid import create_uuid    
8 from sfa.trust.certificate import convert_public_key, Keypair
9
10 from sfa.storage.alchemy import dbsession
11 from sfa.storage.model import RegRecord, RegAuthority, RegSlice, RegNode, RegUser, RegKey
12
13 from sfa.plc.plshell import PlShell    
14
15 def _get_site_hrn(interface_hrn, site):
16     # Hardcode 'internet2' into the hrn for sites hosting
17     # internet2 nodes. This is a special operation for some vini
18     # sites only
19     hrn = ".".join([interface_hrn, site['login_base']]) 
20     if ".vini" in interface_hrn and interface_hrn.endswith('vini'):
21         if site['login_base'].startswith("i2") or site['login_base'].startswith("nlr"):
22             hrn = ".".join([interface_hrn, "internet2", site['login_base']])
23     return hrn
24
25
class PlImporter:
    """Import the contents of a PLC into the SFA registry database.

    ``run`` reads the local (non-peer) sites, persons, keys, nodes and slices
    from PLC through ``PlShell``, creates the registry records (authorities,
    users, nodes, slices) that are missing, regenerates the gid of users who
    have a PLC key the registry does not know about, and finally deletes the
    registry records that no longer match anything in PLC.
    """

    def __init__ (self, auth_hierarchy, logger):
        """
        :param auth_hierarchy: used to create/look up authorities and to mint gids
        :param logger: sfa logger (info, warn, warning and log_exc are used)
        """
        self.auth_hierarchy = auth_hierarchy
        self.logger = logger
        # (type, hrn) -> record and (type, pointer) -> record, for the registry
        # records that existed before the current import; rebuilt by run()
        self.records_by_type_hrn = {}
        self.records_by_type_pointer = {}

    def add_options (self, parser):
        """Register command-line options; this importer has none for now."""
        pass

    @staticmethod
    def _is_vini_interface (interface_hrn):
        """Tell whether interface_hrn designates a VINI registry."""
        return ".vini" in interface_hrn and interface_hrn.endswith('vini')

    def _import_authority (self, site_hrn, pointer, log_format):
        """Create authority site_hrn and store its RegAuthority record.

        The authority is first created in the hierarchy if it does not exist.
        Exceptions are left to the caller.  log_format is a %-format string
        that receives the new record and is logged once it is committed.
        """
        urn = hrn_to_urn(site_hrn, 'authority')
        if not self.auth_hierarchy.auth_exists(urn):
            self.auth_hierarchy.create_auth(urn)
        auth_info = self.auth_hierarchy.get_auth_info(urn)
        auth_record = RegAuthority(hrn=site_hrn, gid=auth_info.get_gid_object(),
                                   pointer=pointer,
                                   authority=get_authority(site_hrn))
        auth_record.just_created()
        dbsession.add(auth_record)
        dbsession.commit()
        self.logger.info(log_format % auth_record)

    # this keeps the vini special case out of the way of run()
    def create_special_vini_record (self, interface_hrn):
        """On a VINI registry, make sure the placeholder internet2 authority exists.

        That authority groups the i2*/nlr* sites (see _get_site_hrn).  It has
        no PLC counterpart, hence the fake site_id -1.  Does nothing on other
        registries or when the authority record already exists.
        """
        if not self._is_vini_interface(interface_hrn):
            return
        i2site = {'name': 'Internet2', 'login_base': 'internet2', 'site_id': -1}
        site_hrn = _get_site_hrn(interface_hrn, i2site)
        if ('authority', site_hrn) in self.records_by_type_hrn:
            return
        # fixed: the pointer comes from i2site (the original referenced an
        # undefined name 'site' here)
        self._import_authority(site_hrn, i2site['site_id'],
                               "PlImporter: Imported authority (vini site) %s")

    def run (self, options):
        """Synchronise the registry with PLC: import what is new, drop what is stale.

        :param options: command-line options (unused, see add_options)
        """
        config = Config ()
        interface_hrn = config.SFA_INTERFACE_HRN
        root_auth = config.SFA_REGISTRY_ROOT_AUTH
        shell = PlShell (config)

        ######## index the existing SFA records
        # load them once so that both indexes come from the same snapshot
        records = dbsession.query(RegRecord).all()
        self.records_by_type_hrn = \
            dict([((record.type, record.hrn), record) for record in records])
        self.records_by_type_pointer = \
            dict([((record.type, record.pointer), record)
                  for record in records if record.pointer != -1])

        ######## retrieve PLC records (local objects only, i.e. peer_id None)
        # retrieve only the fields needed
        sites = shell.GetSites({'peer_id': None, 'enabled' : True},
                               ['site_id','login_base','node_ids','slice_ids','person_ids',])
        sites_by_login_base = dict([(site['login_base'], site) for site in sites])

        persons = shell.GetPersons({'peer_id': None, 'enabled': True},
                                   ['person_id', 'email', 'key_ids', 'site_ids'])
        persons_by_id = dict([(person['person_id'], person) for person in persons])

        keys_by_person_id = self._get_keys_by_person_id(shell, persons)

        nodes = shell.GetNodes({'peer_id': None}, ['node_id', 'hostname', 'site_id'])
        nodes_by_id = dict([(node['node_id'], node) for node in nodes])

        slices = shell.GetSlices({'peer_id': None}, ['slice_id', 'name'])
        slices_by_id = dict([(plc_slice['slice_id'], plc_slice) for plc_slice in slices])

        self.create_special_vini_record(interface_hrn)

        ######## import, site by site
        for site in sites:
            site_hrn = _get_site_hrn(interface_hrn, site)
            # create the site authority unless an authority record with that
            # hrn is already in the registry
            if ('authority', site_hrn) not in self.records_by_type_hrn:
                try:
                    self._import_authority(site_hrn, site['site_id'],
                                           "PlImporter: imported authority (site) : %s")
                except Exception:
                    # without its authority there is no point in importing the
                    # site's child records (nodes, slices, persons)
                    self.logger.log_exc("PlImporter: failed to import site. Skipping child records")
                    continue
            self._import_nodes(site, site_hrn, nodes_by_id)
            self._import_slices(site, interface_hrn, slices_by_id)
            self._import_persons(site, site_hrn, persons_by_id, keys_by_person_id)

        ######## remove stale records
        self._remove_stale_records(interface_hrn, root_auth, sites, sites_by_login_base,
                                   persons_by_id, slices_by_id, nodes_by_id)

    def _get_keys_by_person_id (self, shell, persons):
        """Fetch the PLC keys of persons and return a dict person_id -> [key].

        Key ids that GetKeys did not return are skipped with a warning rather
        than aborting the whole import.
        """
        key_ids = []
        for person in persons:
            key_ids.extend(person['key_ids'])
        keys = shell.GetKeys({'peer_id': None, 'key_id': key_ids})
        keys_by_id = dict([(key['key_id'], key) for key in keys])

        keys_by_person_id = {}
        for person in persons:
            plc_keys = []
            for key_id in person['key_ids']:
                if key_id in keys_by_id:
                    plc_keys.append(keys_by_id[key_id])
                else:
                    self.logger.warning("PlImporter: key %s of person %s not found"
                                        % (key_id, person['person_id']))
            keys_by_person_id[person['person_id']] = plc_keys
        return keys_by_person_id

    def _import_nodes (self, site, site_hrn, nodes_by_id):
        """Create a RegNode for each node of site that is not in the registry yet."""
        site_auth = get_authority(site_hrn)
        site_name = get_leaf(site_hrn)
        for node_id in site['node_ids']:
            node = nodes_by_id.get(node_id)
            if node is None:
                continue
            # hrns longer than 64 characters are truncated
            hrn = hostname_to_hrn(site_auth, site_name, node['hostname'])[:64]
            if ('node', hrn) in self.records_by_type_hrn:
                continue
            try:
                pkey = Keypair(create=True)
                urn = hrn_to_urn(hrn, 'node')
                node_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey)
                node_record = RegNode(hrn=hrn, gid=node_gid,
                                      pointer=node['node_id'],
                                      authority=get_authority(hrn))
                node_record.just_created()
                dbsession.add(node_record)
                dbsession.commit()
                self.logger.info("PlImporter: imported node: %s" % node_record)
            except Exception:
                self.logger.log_exc("PlImporter: failed to import node")

    def _import_slices (self, site, interface_hrn, slices_by_id):
        """Create a RegSlice for each slice of site that is not in the registry yet."""
        for slice_id in site['slice_ids']:
            plc_slice = slices_by_id.get(slice_id)
            if plc_slice is None:
                continue
            hrn = slicename_to_hrn(interface_hrn, plc_slice['name'])
            if ('slice', hrn) in self.records_by_type_hrn:
                continue
            try:
                pkey = Keypair(create=True)
                urn = hrn_to_urn(hrn, 'slice')
                slice_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey)
                slice_record = RegSlice(hrn=hrn, gid=slice_gid,
                                        pointer=plc_slice['slice_id'],
                                        authority=get_authority(hrn))
                slice_record.just_created()
                dbsession.add(slice_record)
                dbsession.commit()
                self.logger.info("PlImporter: imported slice: %s" % slice_record)
            except Exception:
                self.logger.log_exc("PlImporter: failed to  import slice")

    def _import_persons (self, site, site_hrn, persons_by_id, keys_by_person_id):
        """Create the RegUser of each person of site, or refresh its gid/key.

        An existing user record (found by hrn, else by person_id) is refreshed
        when one of the person's PLC keys is not among its registry keys.
        """
        for person_id in site['person_ids']:
            person = persons_by_id.get(person_id)
            if person is None:
                self.logger.warning ("PlImporter: skipping person %s"%person_id)
                continue
            # hrns longer than 64 characters are truncated
            hrn = email_to_hrn(site_hrn, person['email'])[:64]

            previous_record = self.records_by_type_hrn.get(('user', hrn))
            if not previous_record:
                previous_record = self.records_by_type_pointer.get(('user', person_id))

            plc_keys = keys_by_person_id.get(person_id, [])
            sfa_keys = previous_record.reg_keys if previous_record else []
            # if the user's keys changed in PLC, the gid must be regenerated
            sfa_key_strings = [reg_key.key for reg_key in sfa_keys]
            update_record = False
            for key in plc_keys:
                if key['key'] not in sfa_key_strings:
                    update_record = True
                    break

            if previous_record and not update_record:
                continue
            try:
                pubkey = None
                if plc_keys:
                    # arbitrarily pick the first key of the set
                    pubkey = plc_keys[0]
                    try:
                        pkey = convert_public_key(pubkey['key'])
                    except Exception:
                        self.logger.warn('PlImporter: unable to convert public key for %s' % hrn)
                        pkey = Keypair(create=True)
                else:
                    # the user has no keys: create a random keypair for the user's gid
                    self.logger.warn("PlImporter: person %s does not have a PL public key"%hrn)
                    pkey = Keypair(create=True)
                urn = hrn_to_urn(hrn, 'user')
                person_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey)
                if previous_record:
                    previous_record.gid = person_gid
                    if pubkey:
                        previous_record.reg_keys = [RegKey(pubkey['key'], pubkey['key_id'])]
                    # persist the update now rather than relying on a later commit
                    dbsession.commit()
                    self.logger.info("PlImporter: updated person: %s" % previous_record)
                else:
                    new_record = RegUser(hrn=hrn, gid=person_gid,
                                         pointer=person['person_id'],
                                         authority=get_authority(hrn),
                                         email=person['email'])
                    if pubkey:
                        new_record.reg_keys = [RegKey(pubkey['key'], pubkey['key_id'])]
                    else:
                        # was the undefined name 'logger', whose NameError made
                        # every key-less user fail to import
                        self.logger.warning("No key found for user %s"%new_record)
                    dbsession.add(new_record)
                    dbsession.commit()
                    self.logger.info("PlImporter: imported person: %s" % new_record)
            except Exception:
                self.logger.log_exc("PlImporter: failed to import person.")

    def _remove_stale_records (self, interface_hrn, root_auth, sites, sites_by_login_base,
                               persons_by_id, slices_by_id, nodes_by_id):
        """Delete the registry records that no longer match an object in PLC.

        Only records that existed before this import run (those indexed in
        self.records_by_type_hrn) are considered.  The system records, records
        from peer authorities, vini's internet2 placeholder and records whose
        type is not authority/user/slice/node are always kept.
        """
        sites_by_id = dict([(site['site_id'], site) for site in sites])
        system_records = [interface_hrn, root_auth, interface_hrn + '.slicemanager']
        on_vini = self._is_vini_interface(interface_hrn)

        # query again (a single SELECT) instead of walking the objects loaded at
        # the start of run(): the commits in between may have expired them,
        # which would reload each one separately
        for record in dbsession.query(RegRecord):
            record_hrn = record.hrn
            # records created during this run come straight from PLC data
            if (record.type, record_hrn) not in self.records_by_type_hrn:
                continue
            if record_hrn in system_records:
                continue
            if record.peer_authority:
                continue
            # don't delete vini's internet2 placeholder record; it has no PLC
            # counterpart and would otherwise be deleted
            if on_vini and record_hrn.endswith("internet2"):
                continue

            # PLC ids are looked up by the record's pointer, then the hrn is
            # checked against the PLC object it points to
            if isinstance(record, RegAuthority):
                site = sites_by_id.get(record.pointer)
                # same hrn computation as the import, so vini internet2 sites match
                found = site is not None and \
                    _get_site_hrn(interface_hrn, site) == record_hrn

            elif isinstance(record, RegUser):
                login_base = get_leaf(get_authority(record_hrn))
                username = get_leaf(record_hrn)
                site = sites_by_login_base.get(login_base)
                person = persons_by_id.get(record.pointer)
                found = False
                if site is not None and person is not None:
                    tmp_username = person['email'].split("@")[0]
                    alt_username = tmp_username.replace(".", "_").replace("+", "_")
                    found = username in [tmp_username, alt_username] and \
                        site['site_id'] in person['site_ids']

            elif isinstance(record, RegSlice):
                slicename = hrn_to_pl_slicename(record_hrn)
                plc_slice = slices_by_id.get(record.pointer)
                found = plc_slice is not None and plc_slice['name'] == slicename

            elif isinstance(record, RegNode):
                login_base = get_leaf(get_authority(record_hrn))
                nodename = Xrn.unescape(get_leaf(record_hrn))
                site = sites_by_login_base.get(login_base)
                node = nodes_by_id.get(record.pointer)
                found = site is not None and node is not None and \
                    node['hostname'] == nodename and \
                    node['site_id'] == site['site_id']

            else:
                continue

            if not found:
                try:
                    self.logger.info("PlImporter: deleting record: %s" % record)
                    dbsession.delete(record)
                    dbsession.commit()
                except Exception:
                    self.logger.log_exc("PlImporter: failded to delete record")
325