588c75597ea07ed239ed0bf55a0ca86fb14ea8a5
[sfa.git] / geni / gimport.py
1 #!/usr/bin/python
2 #
3 #
4 ##
5 # Import PLC records into the Geni database. It is indended that this tool be
6 # run once to create Geni records that reflect the current state of the
7 # planetlab database.
8 #
9 # The import tool assumes that the existing PLC hierarchy should all be part
10 # of "planetlab.us" (see the root_auth and level1_auth variables below).
11 #
12 # Public keys are extracted from the users' SSH keys automatically and used to
13 # create GIDs. This is relatively experimental as a custom tool had to be
14 # written to perform conversion from SSH to OpenSSL format. It only supports
15 # RSA keys at this time, not DSA keys.
16 ##
17
18 import getopt
19 import sys
20 import tempfile
21
22 from geni.util.cert import *
23 from geni.util.trustedroot import *
24 from geni.util.hierarchy import *
25 from geni.util.record import *
26 from geni.util.genitable import *
27 from geni.util.misc import *
28 from geni.util.config import *
29
30 # get PL account settings from config module
31 pl_auth = get_pl_auth()
32
33 # connect to planetlab
34 if "Url" in pl_auth:
35     from geni.util import remoteshell
36     shell = remoteshell.RemoteShell()
37 else:
38     import PLC.Shell
39     shell = PLC.Shell.Shell(globals = globals())
40
41 ##
42 # Two authorities are specified: the root authority and the level1 authority.
43
44 #root_auth = "plc"
45 #level1_auth = None
46
47 #root_auth = "planetlab"
48 #level1_auth = "planetlab.us"
49 config = Config()
50
51 root_auth = config.GENI_REGISTRY_ROOT_AUTH
52 level1_auth = config.GENI_REGISTRY_LEVEL1_AUTH
53 if not level1_auth or level1_auth in ['']:
54     level1_auth = None
55
56 keyconvert = 'keyconvert'
57 loaded = False
58 default_path = "/usr/share/keyconvert/" + keyconvert
59 cwd = os.path.dirname(os.path.abspath(__file__))
60 alt_path = os.sep.join(cwd.split(os.sep)[:-1] + ['keyconvert', 'keyconvert'])
61 geni_path = config.GENI_BASE_DIR + os.sep + "keyconvert/keyconvert"
62 files = [default_path, alt_path, geni_path]
63 for path in files:
64     if os.path.isfile(path):
65         keyconvert_fn = path
66         loaded = True
67         break
68
69 if not loaded:
70     raise Exception, "Could not find keyconvert in " + ", ".join(files)        
71
72 def un_unicode(str):
73    if isinstance(str, unicode):
74        return str.encode("ascii", "ignore")
75    else:
76        return str
77
78 def cleanup_string(str):
79     # pgsql has a fit with strings that have high ascii in them, so filter it
80     # out when generating the hrns.
81     tmp = ""
82     for c in str:
83         if ord(c) < 128:
84             tmp = tmp + c
85     str = tmp
86
87     str = un_unicode(str)
88     str = str.replace(" ", "_")
89     str = str.replace(".", "_")
90     str = str.replace("(", "_")
91     str = str.replace("'", "_")
92     str = str.replace(")", "_")
93     str = str.replace('"', "_")
94     return str
95
96 def process_options():
97    global hrn
98
99    (options, args) = getopt.getopt(sys.argv[1:], '', [])
100    for opt in options:
101        name = opt[0]
102        val = opt[1]
103
104 def connect_shell():
105     global pl_auth, shell
106
107     # get PL account settings from config module
108     pl_auth = get_pl_auth()
109
110     # connect to planetlab
111     if "Url" in pl_auth:
112         from geni.util import remoteshell
113         shell = remoteshell.RemoteShell()
114     else:
115         import PLC.Shell
116         shell = PLC.Shell.Shell(globals = globals())
117
118     return shell
119
120 def get_auth_table(auth_name):
121     AuthHierarchy = Hierarchy()
122     auth_info = AuthHierarchy.get_auth_info(auth_name)
123
124     table = GeniTable(hrn=auth_name,
125                       cninfo=auth_info.get_dbinfo())
126
127     # if the table doesn't exist, then it means we haven't put any records
128     # into this authority yet.
129
130     if not table.exists():
131         report.trace("Import: creating table for authority " + auth_name)
132         table.create()
133
134     return table
135
136 def get_pl_pubkey(key_id):
137     keys = shell.GetKeys(pl_auth, [key_id])
138     if keys:
139         key_str = keys[0]['key']
140
141         if "ssh-dss" in key_str:
142             print "XXX: DSA key encountered, ignoring"
143             return None
144
145         # generate temporary files to hold the keys
146         (ssh_f, ssh_fn) = tempfile.mkstemp()
147         ssl_fn = tempfile.mktemp()
148
149         os.write(ssh_f, key_str)
150         os.close(ssh_f)
151
152         if not os.path.exists(keyconvert_fn):
153             report.trace("  keyconvert utility " + str(keyconvert_fn) + " does not exist");
154             sys.exit(-1)
155
156         cmd = keyconvert_fn + " " + ssh_fn + " " + ssl_fn
157         print cmd
158         os.system(cmd)
159
160         # this check leaves the temporary file containing the public key so
161         # that it can be expected to see why it failed.
162         # TODO: for production, cleanup the temporary files
163         if not os.path.exists(ssl_fn):
164             report.trace("  failed to convert key from " + ssh_fn + " to " + ssl_fn)
165             return None
166
167         k = Keypair()
168         try:
169             k.load_pubkey_from_file(ssl_fn)
170         except:
171             print "XXX: Error while converting key: ", key_str
172             k = None
173
174         # remove the temporary files
175         os.remove(ssh_fn)
176         os.remove(ssl_fn)
177
178         return k
179     else:
180         return None
181
182 def person_to_hrn(parent_hrn, person):
183     # the old way - Lastname_Firstname
184     #personname = person['last_name'] + "_" + person['first_name']
185
186     # the new way - use email address up to the "@" 
187     personname = person['email'].split("@")[0]
188
189     personname = cleanup_string(personname)
190
191     hrn = parent_hrn + "." + personname
192     return hrn
193
194 def import_person(parent_hrn, person):
195     AuthHierarchy = Hierarchy()
196     hrn = person_to_hrn(parent_hrn, person)
197
198     # ASN.1 will have problems with hrn's longer than 64 characters
199     if len(hrn) > 64:
200         hrn = hrn[:64]
201
202     report.trace("Import: importing person " + hrn)
203
204     table = get_auth_table(parent_hrn)
205
206     person_record = table.resolve("user", hrn)
207     if not person_record:
208         key_ids = []
209         if 'key_ids' in person:    
210             key_ids = person["key_ids"]
211
212         if key_ids:
213             # get the user's private key from the SSH keys they have uploaded
214             # to planetlab
215             pkey = get_pl_pubkey(key_ids[0])
216         else:
217             # the user has no keys
218             report.trace("   person " + hrn + " does not have a PL public key")
219             pkey = None
220
221         # if a key is unavailable, then we still need to put something in the
222         # user's GID. So make one up.
223         if not pkey:
224             pkey = Keypair(create=True)
225
226         person_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
227         person_record = GeniRecord(name=hrn, gid=person_gid, type="user", pointer=person['person_id'])
228         report.trace("  inserting user record for " + hrn)
229         table.insert(person_record)
230     else:
231         key_ids = person["key_ids"]
232
233     if key_ids:
234         pkey = get_pl_pubkey(key_ids[0])
235         person_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
236         person_record = GeniRecord(name=hrn, gid=person_gid, type="user", pointer=person['person_id'])
237         report.trace("  updating user record for " + hrn)
238         table.update(person_record)
239         
240 def import_slice(parent_hrn, slice):
241     AuthHierarchy = Hierarchy()
242     slicename = slice['name'].split("_",1)[-1]
243     slicename = cleanup_string(slicename)
244
245     if not slicename:
246         report.error("Import_Slice: failed to parse slice name " + slice['name'])
247         return
248
249     hrn = parent_hrn + "." + slicename
250     report.trace("Import: importing slice " + hrn)
251
252     table = get_auth_table(parent_hrn)
253
254     slice_record = table.resolve("slice", hrn)
255     if not slice_record:
256         pkey = Keypair(create=True)
257         slice_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
258         slice_record = GeniRecord(name=hrn, gid=slice_gid, type="slice", pointer=slice['slice_id'])
259         report.trace("  inserting slice record for " + hrn)
260         table.insert(slice_record)
261
262 def import_node(parent_hrn, node):
263     AuthHierarchy = Hierarchy()
264     nodename = node['hostname'].split(".")[0]
265     nodename = cleanup_string(nodename)
266
267     if not nodename:
268         report.error("Import_node: failed to parse node name " + node['hostname'])
269         return
270
271     hrn = parent_hrn + "." + nodename
272
273     # ASN.1 will have problems with hrn's longer than 64 characters
274     if len(hrn) > 64:
275         hrn = hrn[:64]
276
277     report.trace("Import: importing node " + hrn)
278
279     table = get_auth_table(parent_hrn)
280
281     node_record = table.resolve("node", hrn)
282     if not node_record:
283         pkey = Keypair(create=True)
284         node_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
285         node_record = GeniRecord(name=hrn, gid=node_gid, type="node", pointer=node['node_id'])
286         report.trace("  inserting node record for " + hrn)
287         table.insert(node_record)
288
289 def import_site(parent_hrn, site):
290     AuthHierarchy = Hierarchy()
291     sitename = site['login_base']
292     sitename = cleanup_string(sitename)
293     
294     hrn = parent_hrn + "." + sitename
295
296     # Hardcode 'internet2' into the hrn for sites hosting 
297     # internet2 nodes. This is a special operation for some vini
298     # sites
299     if parent_hrn.find(".vini.") != -1 and 'abbreviated_name' in site:
300         abv_name = site['abbreviated_name']
301         if abv_name.find("I2") != -1 or abv_name.find("NLR") != -1:
302             if sitename.startswith("ii"): sitename.replace("ii", "")
303             if sitename.startswith("nlr"): sitename.replace("nlr", "")
304             hrn = ".".join([parent_hrn, "internet2", sitename]) 
305  
306     report.trace("Import_Site: importing site " + hrn)
307
308     # create the authority
309     if not AuthHierarchy.auth_exists(hrn):
310         AuthHierarchy.create_auth(hrn)
311
312     auth_info = AuthHierarchy.get_auth_info(hrn)
313
314     table = get_auth_table(parent_hrn)
315
316     auth_record = table.resolve("authority", hrn)
317     if not auth_record:
318         auth_record = GeniRecord(name=hrn, gid=auth_info.get_gid_object(), type="authority", pointer=site['site_id'])
319         report.trace("  inserting authority record for " + hrn)
320         table.insert(auth_record)
321
322     if 'person_ids' in site: 
323         for person_id in site['person_ids']:
324             persons = shell.GetPersons(pl_auth, [person_id])
325             if persons:
326                 try: 
327                     import_person(hrn, persons[0])
328                 except:
329                     report.trace("Failed to import: %s" % persons[0])
330     if 'slice_ids' in site:
331         for slice_id in site['slice_ids']:
332             slices = shell.GetSlices(pl_auth, [slice_id])
333             if slices:
334                 try:
335                     import_slice(hrn, slices[0])
336                 except:
337                     report.trace("Failed to import: %s" % slices[0])
338     if 'node_ids' in site:
339         for node_id in site['node_ids']:
340             nodes = shell.GetNodes(pl_auth, [node_id])
341             if nodes:
342                 try:
343                     import_node(hrn, nodes[0])
344                 except:
345                     report.trace("Failed to import: %s" % nodes[0])
346
347 def create_top_level_auth_records(hrn):
348     parent_hrn = get_authority(hrn)
349     print hrn, ":", parent_hrn
350     if not parent_hrn:
351         parent_hrn = hrn    
352     auth_info = AuthHierarchy.get_auth_info(parent_hrn)
353     table = get_auth_table(parent_hrn)
354
355     auth_record = table.resolve("authority", hrn)
356     if not auth_record:
357         auth_record = GeniRecord(name=hrn, gid=auth_info.get_gid_object(), type="authority", pointer=-1)
358         report.trace("  inserting authority record for " + hrn)
359         table.insert(auth_record)
360
361 def main():
362     global AuthHierarchy
363     global TrustedRoots
364
365     process_options()
366
367     print "Base Directory: ", config.GENI_BASE_DIR
368
369     AuthHierarchy = Hierarchy()
370     TrustedRoots = TrustedRootList()
371
372     print "Import: creating top level authorities"
373
374     if not AuthHierarchy.auth_exists(root_auth):
375         AuthHierarchy.create_auth(root_auth)
376
377     create_top_level_auth_records(root_auth)
378     if level1_auth:
379         if not AuthHierarchy.auth_exists(level1_auth):
380             AuthHierarchy.create_auth(level1_auth)
381         create_top_level_auth_records(level1_auth)
382         import_auth = level1_auth
383     else:
384         import_auth = root_auth
385
386     print "Import: adding", root_auth, "to trusted list"
387     root = AuthHierarchy.get_auth_info(root_auth)
388     TrustedRoots.add_gid(root.get_gid_object())
389
390     connect_shell()
391
392     sites = shell.GetSites(pl_auth)
393     for site in sites:
394         import_site(import_auth, site)
395
396 if __name__ == "__main__":
397     main()