0d8e19d780c3c90a3b84b0d85958ec8693932239
[sfa.git] / geni / gimport.py
1 ##
2 # Import PLC records into the Geni database. It is indended that this tool be
3 # run once to create Geni records that reflect the current state of the
4 # planetlab database.
5 #
6 # The import tool assumes that the existing PLC hierarchy should all be part
7 # of "planetlab.us" (see the root_auth and level1_auth variables below).
8 #
9 # Public keys are extracted from the users' SSH keys automatically and used to
10 # create GIDs. This is relatively experimental as a custom tool had to be
11 # written to perform conversion from SSH to OpenSSL format. It only supports
12 # RSA keys at this time, not DSA keys.
13 ##
14
15 import getopt
16 import sys
17 import tempfile
18
19 from geni.util.cert import *
20 from geni.util.trustedroot import *
21 from geni.util.hierarchy import *
22 from geni.util.record import *
23 from geni.util.genitable import *
24 from geni.util.misc import *
25
26 shell = None
27
28 ##
29 # Two authorities are specified: the root authority and the level1 authority.
30
31 root_auth = "planetlab"
32 level1_auth = "planetlab.us"
33
34
35 def un_unicode(str):
36    if isinstance(str, unicode):
37        return str.encode("ascii", "ignore")
38    else:
39        return str
40
41 def cleanup_string(str):
42     # pgsql has a fit with strings that have high ascii in them, so filter it
43     # out when generating the hrns.
44     tmp = ""
45     for c in str:
46         if ord(c) < 128:
47             tmp = tmp + c
48     str = tmp
49
50     str = un_unicode(str)
51     str = str.replace(" ", "_")
52     str = str.replace(".", "_")
53     str = str.replace("(", "_")
54     str = str.replace("'", "_")
55     str = str.replace(")", "_")
56     str = str.replace('"', "_")
57     return str
58
59 def process_options():
60    global hrn
61
62    (options, args) = getopt.getopt(sys.argv[1:], '', [])
63    for opt in options:
64        name = opt[0]
65        val = opt[1]
66
67 def connect_shell():
68     global pl_auth, shell
69
70     # get PL account settings from config module
71     pl_auth = get_pl_auth()
72
73     # connect to planetlab
74     if "Url" in pl_auth:
75         from geni.util import remoteshell
76         shell = remoteshell.RemoteShell()
77     else:
78         import PLC.Shell
79         shell = PLC.Shell.Shell(globals = globals())
80
81 def get_auth_table(auth_name):
82     AuthHierarchy = Hierarchy()
83     auth_info = AuthHierarchy.get_auth_info(auth_name)
84
85     table = GeniTable(hrn=auth_name,
86                       cninfo=auth_info.get_dbinfo())
87
88     # if the table doesn't exist, then it means we haven't put any records
89     # into this authority yet.
90
91     if not table.exists():
92         report.trace("Import: creating table for authority " + auth_name)
93         table.create()
94
95     return table
96
97 def get_pl_pubkey(key_id):
98     keys = shell.GetKeys(pl_auth, [key_id])
99     if keys:
100         key_str = keys[0]['key']
101
102         if "ssh-dss" in key_str:
103             print "XXX: DSA key encountered, ignoring"
104             return None
105
106         # generate temporary files to hold the keys
107         (ssh_f, ssh_fn) = tempfile.mkstemp()
108         ssl_fn = tempfile.mktemp()
109
110         os.write(ssh_f, key_str)
111         os.close(ssh_f)
112
113         cmd = "../keyconvert/keyconvert " + ssh_fn + " " + ssl_fn
114         print cmd
115         os.system(cmd)
116
117         # this check leaves the temporary file containing the public key so
118         # that it can be expected to see why it failed.
119         # TODO: for production, cleanup the temporary files
120         if not os.path.exists(ssl_fn):
121             report.trace("  failed to convert key from " + ssh_fn + " to " + ssl_fn)
122             return None
123
124         k = Keypair()
125         try:
126             k.load_pubkey_from_file(ssl_fn)
127         except:
128             print "XXX: Error while converting key: ", key_str
129             k = None
130
131         # remove the temporary files
132         os.remove(ssh_fn)
133         os.remove(ssl_fn)
134
135         return k
136     else:
137         return None
138
139 def person_to_hrn(parent_hrn, person):
140     personname = person['last_name'] + "_" + person['first_name']
141
142     personname = cleanup_string(personname)
143
144     hrn = parent_hrn + "." + personname
145     return hrn
146
147 def import_person(parent_hrn, person):
148     AuthHierarchy = Hierarchy()
149     hrn = person_to_hrn(parent_hrn, person)
150
151     # ASN.1 will have problems with hrn's longer than 64 characters
152     if len(hrn) > 64:
153         hrn = hrn[:64]
154
155     report.trace("Import: importing person " + hrn)
156
157     table = get_auth_table(parent_hrn)
158
159     person_record = table.resolve("user", hrn)
160     if not person_record:
161         key_ids = person["key_ids"]
162
163         if key_ids:
164             # get the user's private key from the SSH keys they have uploaded
165             # to planetlab
166             pkey = get_pl_pubkey(key_ids[0])
167         else:
168             # the user has no keys
169             report.trace("   person " + hrn + " does not have a PL public key")
170             pkey = None
171
172         # if a key is unavailable, then we still need to put something in the
173         # user's GID. So make one up.
174         if not pkey:
175             pkey = Keypair(create=True)
176
177         person_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
178         person_record = GeniRecord(name=hrn, gid=person_gid, type="user", pointer=person['person_id'])
179         report.trace("  inserting user record for " + hrn)
180         table.insert(person_record)
181     else:
182         key_ids = person["key_ids"]
183         if key_ids:
184             pkey = get_pl_pubkey(key_ids[0])
185             person_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
186             person_record = GeniRecord(name=hrn, gid=person_gid, type="user", pointer=person['person_id'])
187             report.trace("  updating user record for " + hrn)
188             table.update(person_record)
189             
190 def import_slice(parent_hrn, slice):
191     AuthHierarchy = Hierarchy()
192     slicename = slice['name'].split("_",1)[-1]
193     slicename = cleanup_string(slicename)
194
195     if not slicename:
196         report.error("Import_Slice: failed to parse slice name " + slice['name'])
197         return
198
199     hrn = parent_hrn + "." + slicename
200     report.trace("Import: importing slice " + hrn)
201
202     table = get_auth_table(parent_hrn)
203
204     slice_record = table.resolve("slice", hrn)
205     if not slice_record:
206         pkey = Keypair(create=True)
207         slice_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
208         slice_record = GeniRecord(name=hrn, gid=slice_gid, type="slice", pointer=slice['slice_id'])
209         report.trace("  inserting slice record for " + hrn)
210         table.insert(slice_record)
211
212 def import_node(parent_hrn, node):
213     AuthHierarchy = Hierarchy()
214     nodename = node['hostname']
215     nodename = cleanup_string(nodename)
216
217     if not nodename:
218         report.error("Import_node: failed to parse node name " + node['hostname'])
219         return
220
221     hrn = parent_hrn + "." + nodename
222
223     # ASN.1 will have problems with hrn's longer than 64 characters
224     if len(hrn) > 64:
225         hrn = hrn[:64]
226
227     report.trace("Import: importing node " + hrn)
228
229     table = get_auth_table(parent_hrn)
230
231     node_record = table.resolve("node", hrn)
232     if not node_record:
233         pkey = Keypair(create=True)
234         node_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
235         node_record = GeniRecord(name=hrn, gid=node_gid, type="node", pointer=node['node_id'])
236         report.trace("  inserting node record for " + hrn)
237         table.insert(node_record)
238
239 def import_site(parent_hrn, site):
240     AuthHierarchy = Hierarchy()
241     sitename = site['login_base']
242     sitename = cleanup_string(sitename)
243
244     hrn = parent_hrn + "." + sitename
245
246     report.trace("Import_Site: importing site " + hrn)
247
248     # create the authority
249     if not AuthHierarchy.auth_exists(hrn):
250         AuthHierarchy.create_auth(hrn)
251
252     auth_info = AuthHierarchy.get_auth_info(hrn)
253
254     table = get_auth_table(parent_hrn)
255
256     sa_record = table.resolve("sa", hrn)
257     if not sa_record:
258         sa_record = GeniRecord(name=hrn, gid=auth_info.get_gid_object(), type="sa", pointer=site['site_id'])
259         report.trace("  inserting sa record for " + hrn)
260         table.insert(sa_record)
261
262     ma_record = table.resolve("ma", hrn)
263     if not ma_record:
264         ma_record = GeniRecord(name=hrn, gid=auth_info.get_gid_object(), type="ma", pointer=site['site_id'])
265         report.trace("  inserting ma record for " + hrn)
266         table.insert(ma_record)
267
268     for person_id in site['person_ids']:
269         persons = shell.GetPersons(pl_auth, [person_id])
270         if persons:
271             try: 
272                 import_person(hrn, persons[0])
273             except:
274                 report.trace("Failed to import: %s" % persons[0])
275     for slice_id in site['slice_ids']:
276         slices = shell.GetSlices(pl_auth, [slice_id])
277         if slices:
278             try:
279                 import_slice(hrn, slices[0])
280             except:
281                 report.trace("Failed to import: %s" % slices[0])
282     for node_id in site['node_ids']:
283         nodes = shell.GetNodes(pl_auth, [node_id])
284         if nodes:
285             try:
286                 import_node(hrn, nodes[0])
287             except:
288                 report.trace("Failed to import: %s" % nodes[0])
289
290 def create_top_level_auth_records(hrn):
291     parent_hrn = get_authority(hrn)
292     print hrn, ":", parent_hrn  
293     auth_info = AuthHierarchy.get_auth_info(parent_hrn)
294     table = get_auth_table(parent_hrn)
295
296     sa_record = table.resolve("sa", hrn)
297     if not sa_record:
298         sa_record = GeniRecord(name=hrn, gid=auth_info.get_gid_object(), type="sa", pointer=-1)
299         report.trace("  inserting sa record for " + hrn)
300         table.insert(sa_record)
301
302     ma_record = table.resolve("ma", hrn)
303     if not ma_record:
304         ma_record = GeniRecord(name=hrn, gid=auth_info.get_gid_object(), type="ma", pointer=-1)
305         report.trace("  inserting ma record for " + hrn)
306         table.insert(ma_record)
307
308 def main():
309     global AuthHierarchy
310     global TrustedRoots
311
312     process_options()
313
314     AuthHierarchy = Hierarchy()
315     TrustedRoots = TrustedRootList()
316
317     print "Import: creating top level authorities"
318
319     if not AuthHierarchy.auth_exists(root_auth):
320         AuthHierarchy.create_auth(root_auth)
321     #create_top_level_auth_records(root_auth)
322     if not AuthHierarchy.auth_exists(level1_auth):
323         AuthHierarchy.create_auth(level1_auth)
324     create_top_level_auth_records(level1_auth)
325
326     print "Import: adding", root_auth, "to trusted list"
327     root = AuthHierarchy.get_auth_info(root_auth)
328     TrustedRoots.add_gid(root.get_gid_object())
329
330     connect_shell()
331
332     sites = shell.GetSites(pl_auth)
333     for site in sites:
334         import_site(level1_auth, site)
335
336 if __name__ == "__main__":
337     main()