SSH daemonization test fix, along with environment setup fixes.
[nepi.git] / src / nepi / util / proxy.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import base64
5 import nepi.core.execute
6 import nepi.util.environ
7 from nepi.core.attributes import AttributesMap, Attribute
8 from nepi.util import server, validation
9 from nepi.util.constants import TIME_NOW, ATTR_NEPI_TESTBED_ENVIRONMENT_SETUP, DeploymentConfiguration as DC
10 import getpass
11 import cPickle
12 import sys
13 import time
14 import tempfile
15 import shutil
16 import functools
17 import os
18
# PROTOCOL REPLIES
OK = 0
ERROR = 1

# PROTOCOL INSTRUCTION MESSAGES
# The instruction code is the first "|"-separated field of every request.
# These values form the wire protocol shared by client stubs and servers;
# renumbering them breaks compatibility with peers using this module.
XML = 2
TRACE = 4
FINISHED = 5
START = 6
STOP = 7
SHUTDOWN = 8
CONFIGURE = 9
CREATE = 10
CREATE_SET = 11
FACTORY_SET = 12
CONNECT = 13
CROSS_CONNECT = 14
ADD_TRACE = 15
ADD_ADDRESS = 16
ADD_ROUTE = 17
DO_SETUP = 18
DO_CREATE = 19
DO_CONNECT_INIT = 20
DO_CONFIGURE = 21
DO_CROSS_CONNECT_INIT = 22
GET = 23
SET = 24
ACTION = 25
STATUS = 26
GUIDS = 27
GET_ROUTE = 28
GET_ADDRESS = 29
RECOVER = 30
DO_PRECONFIGURE = 31
GET_ATTRIBUTE_LIST = 32
DO_CONNECT_COMPL = 33
DO_CROSS_CONNECT_COMPL = 34
TESTBED_ID = 35
TESTBED_VERSION = 36
DO_PRESTART = 37
GET_FACTORY_ID = 38
GET_TESTBED_ID = 39
GET_TESTBED_VERSION = 40
TRACES_INFO = 41
EXEC_XML = 42

# Human-readable name of every reply code and instruction above; log_msg and
# log_reply look codes up here to produce readable debug lines, so every
# constant defined above needs an entry.
instruction_text = {
    OK:     "OK",
    ERROR:  "ERROR",
    XML:    "XML",
    EXEC_XML:    "EXEC_XML",
    TRACE:  "TRACE",
    FINISHED:   "FINISHED",
    START:  "START",
    STOP:   "STOP",
    RECOVER: "RECOVER",
    SHUTDOWN:   "SHUTDOWN",
    CONFIGURE:  "CONFIGURE",
    CREATE: "CREATE",
    CREATE_SET: "CREATE_SET",
    FACTORY_SET:    "FACTORY_SET",
    CONNECT:    "CONNECT",
    CROSS_CONNECT: "CROSS_CONNECT",
    ADD_TRACE:  "ADD_TRACE",
    ADD_ADDRESS:    "ADD_ADDRESS",
    ADD_ROUTE:  "ADD_ROUTE",
    DO_SETUP:   "DO_SETUP",
    DO_CREATE:  "DO_CREATE",
    DO_CONNECT_INIT: "DO_CONNECT_INIT",
    DO_CONNECT_COMPL: "DO_CONNECT_COMPL",
    DO_CONFIGURE:   "DO_CONFIGURE",
    DO_PRECONFIGURE:   "DO_PRECONFIGURE",
    DO_PRESTART:   "DO_PRESTART",
    DO_CROSS_CONNECT_INIT:  "DO_CROSS_CONNECT_INIT",
    DO_CROSS_CONNECT_COMPL: "DO_CROSS_CONNECT_COMPL",
    GET:    "GET",
    SET:    "SET",
    GET_ROUTE: "GET_ROUTE",
    GET_ADDRESS: "GET_ADDRESS",
    GET_ATTRIBUTE_LIST: "GET_ATTRIBUTE_LIST",
    GET_FACTORY_ID: "GET_FACTORY_ID",
    GET_TESTBED_ID: "GET_TESTBED_ID",
    GET_TESTBED_VERSION: "GET_TESTBED_VERSION",
    ACTION: "ACTION",
    STATUS: "STATUS",
    GUIDS:  "GUIDS",
    TESTBED_ID: "TESTBED_ID",
    TESTBED_VERSION: "TESTBED_VERSION",
    TRACES_INFO: "TRACES_INFO",
}
108
def log_msg(server, params):
    """Log an incoming request at debug level through ``server.log_debug``.

    ``params`` is the request split on "|": ``params[0]`` is the instruction
    code and the remaining items are its (still encoded) arguments.

    Logging is best-effort: failures are swallowed so that a logging
    problem never interferes with request handling. An unknown or
    malformed instruction code is logged as ``UNKNOWN(<code>)`` rather
    than dropping the whole line.
    """
    try:
        try:
            instr_txt = instruction_text[int(params[0])]
        except (KeyError, ValueError, TypeError, IndexError):
            # Unknown/malformed code: still log the request, verbatim code.
            instr_txt = "UNKNOWN(%s)" % (params[0] if params else "",)
        server.log_debug("%s - msg: %s [%s]" % (server.__class__.__name__,
            instr_txt, ", ".join(map(str, params[1:]))))
    except Exception:
        # don't die for logging
        pass
118
def log_reply(server, reply):
    """Log an outgoing reply at debug level through ``server.log_debug``.

    ``reply`` is expected to look like ``"<code>|<payload>"``. The code is
    shown by name and the payload is base64-decoded when possible. If the
    reply cannot be parsed (no "|", non-int or unknown code, not a string),
    it is logged verbatim instead.
    """
    try:
        res = reply.split("|")
        code_txt = instruction_text[int(res[0])]
        try:
            txt = base64.b64decode(res[1])
        except Exception:
            # payload is not valid base64: show it as-is
            txt = res[1]
        server.log_debug("%s - reply: %s %s" % (server.__class__.__name__,
                code_txt, txt))
    except Exception:
        # don't die for logging: fall back to the raw reply
        server.log_debug("%s - reply: %s" % (server.__class__.__name__,
                reply))
135
def to_server_log_level(log_level):
    """Map a deployment-configuration log level onto a server log level.

    Only ``DC.DEBUG_LEVEL`` becomes ``server.DEBUG_LEVEL``; any other value
    (including None) becomes ``server.ERROR_LEVEL``.
    """
    if log_level == DC.DEBUG_LEVEL:
        return server.DEBUG_LEVEL
    return server.ERROR_LEVEL
142
def get_access_config_params(access_config):
    """Pull the daemon connection parameters out of ``access_config``.

    Returns the tuple
    ``(root_dir, log_level, user, host, port, key, agent, environment_setup)``.

    ``log_level`` is translated with ``to_server_log_level``. The SSH fields
    (user, host, port, key, agent) are read only when the deployment
    communication is ``DC.ACCESS_SSH``; otherwise they are all None.
    ``environment_setup`` is None when the configuration has no such
    attribute.
    """
    value_of = access_config.get_attribute_value

    root_dir = value_of(DC.ROOT_DIRECTORY)
    log_level = to_server_log_level(value_of(DC.LOG_LEVEL))
    communication = value_of(DC.DEPLOYMENT_COMMUNICATION)

    if access_config.has_attribute(DC.DEPLOYMENT_ENVIRONMENT_SETUP):
        environment_setup = value_of(DC.DEPLOYMENT_ENVIRONMENT_SETUP)
    else:
        environment_setup = None

    if communication != DC.ACCESS_SSH:
        user = host = port = agent = key = None
    else:
        user = value_of(DC.DEPLOYMENT_USER)
        host = value_of(DC.DEPLOYMENT_HOST)
        port = value_of(DC.DEPLOYMENT_PORT)
        agent = value_of(DC.USE_AGENT)
        key = value_of(DC.DEPLOYMENT_KEY)

    return (root_dir, log_level, user, host, port, key, agent, environment_setup)
161
class AccessConfiguration(AttributesMap):
    """Attribute map describing how to deploy/access a controller.

    Every attribute listed in ``Metadata.DEPLOYMENT_ATTRIBUTES`` is
    registered up front. ``params``, when given and non-empty, maps
    attribute names to raw values; each value is converted with the parser
    registered in ``Attribute.type_parsers`` for that attribute's type
    before being stored.
    """
    def __init__(self, params = None):
        super(AccessConfiguration, self).__init__()

        # Local import kept as in the original (presumably to avoid an
        # import cycle with nepi.core.metadata — TODO confirm).
        from nepi.core.metadata import Metadata

        for _name, attribute_spec in Metadata.DEPLOYMENT_ATTRIBUTES.iteritems():
            self.add_attribute(**attribute_spec)

        if not params:
            return
        for name, raw_value in params.iteritems():
            parse = Attribute.type_parsers[self.get_attribute_type(name)]
            self.set_attribute_value(name, parse(raw_value))
176
class TempDir(object):
    """A freshly created temporary directory that removes itself.

    The directory is created with ``tempfile.mkdtemp`` and deleted, with
    all of its contents, by ``cleanup()`` or when the object is finalized.
    """
    def __init__(self):
        self.path = tempfile.mkdtemp()

    def cleanup(self):
        """Remove the directory tree; safe to call more than once."""
        path = getattr(self, 'path', None)
        if path is not None:
            # ignore_errors: an already-removed directory must not raise.
            shutil.rmtree(path, ignore_errors=True)

    def __del__(self):
        # A finalizer may run after a failed __init__ (no self.path) or
        # during interpreter shutdown (module globals already cleared).
        # Exceptions cannot usefully propagate from __del__, so contain them.
        try:
            self.cleanup()
        except Exception:
            pass
183
class PermDir(object):
    """Holder for a caller-supplied directory that must be kept.

    Exposes the same ``path`` attribute as TempDir, but never creates or
    removes anything.
    """

    def __init__(self, path):
        # Stored as given; no validation or normalisation.
        self.path = path
187
def create_experiment_controller(xml, access_config = None):
    """Create an experiment controller for the experiment described by ``xml``.

    Without ``access_config``, or when its ``DC.DEPLOYMENT_MODE`` is unset
    or ``DC.MODE_SINGLE_PROCESS``, an in-process ExperimentController is
    built. Its root directory is ``DC.ROOT_DIRECTORY`` when the
    configuration has that attribute, otherwise a TempDir.

    With ``DC.MODE_DAEMON``, an ExperimentControllerProxy is returned. If
    ``DC.RECOVER`` is set, the proxy is created with ``launch=False``, so it
    only connects instead of launching a new daemon.

    Raises ValueError when recovery is requested in single-process mode,
    and RuntimeError for any other deployment mode.
    """
    if access_config:
        mode = access_config.get_attribute_value(DC.DEPLOYMENT_MODE)
        launch = not access_config.get_attribute_value(DC.RECOVER)
    else:
        mode = None
        launch = True

    if not mode or mode == DC.MODE_SINGLE_PROCESS:
        if not launch:
            raise ValueError("Unsupported instantiation mode: %s with launch=False" % (mode,))

        from nepi.core.execute import ExperimentController

        if not access_config or not access_config.has_attribute(DC.ROOT_DIRECTORY):
            root_dir = TempDir()
        else:
            root_dir = PermDir(access_config.get_attribute_value(DC.ROOT_DIRECTORY))
        controller = ExperimentController(xml, root_dir.path)

        # Keep the directory holder alive as long as the controller, so a
        # TempDir is only cleaned up when the controller goes away.
        controller._tempdir = root_dir

        return controller
    elif mode == DC.MODE_DAEMON:
        (root_dir, log_level, user, host, port, key, agent, environment_setup) = \
                get_access_config_params(access_config)
        return ExperimentControllerProxy(root_dir, log_level,
                experiment_xml = xml, host = host, port = port, user = user, ident_key = key,
                agent = agent, launch = launch,
                environment_setup = environment_setup)
    raise RuntimeError("Unsupported access configuration '%s'" % mode)
218
def create_testbed_controller(testbed_id, testbed_version, access_config):
    """Create a controller for testbed ``testbed_id`` at ``testbed_version``.

    Without ``access_config``, or when its ``DC.DEPLOYMENT_MODE`` is unset
    or ``DC.MODE_SINGLE_PROCESS``, the controller is built in-process. With
    ``DC.MODE_DAEMON``, a TestbedControllerProxy is returned. If
    ``DC.RECOVER`` is set, the proxy is created with ``launch=False``, so it
    only connects instead of launching a new daemon.

    Raises ValueError when recovery is requested in single-process mode,
    and RuntimeError for any other deployment mode.
    """
    if access_config:
        mode = access_config.get_attribute_value(DC.DEPLOYMENT_MODE)
        launch = not access_config.get_attribute_value(DC.RECOVER)
    else:
        mode = None
        launch = True

    if not mode or mode == DC.MODE_SINGLE_PROCESS:
        if not launch:
            raise ValueError("Unsupported instantiation mode: %s with launch=False" % (mode,))
        return _build_testbed_controller(testbed_id, testbed_version)
    elif mode == DC.MODE_DAEMON:
        (root_dir, log_level, user, host, port, key, agent, environment_setup) = \
                get_access_config_params(access_config)
        return TestbedControllerProxy(root_dir, log_level, testbed_id = testbed_id, 
                testbed_version = testbed_version, host = host, port = port, ident_key = key,
                user = user, agent = agent, launch = launch,
                environment_setup = environment_setup)
    raise RuntimeError("Unsupported access configuration '%s'" % mode)
236
237 def _build_testbed_controller(testbed_id, testbed_version):
238     mod_name = nepi.util.environ.find_testbed(testbed_id)
239     
240     if not mod_name in sys.modules:
241         try:
242             __import__(mod_name)
243         except ImportError:
244             raise ImportError, "Cannot find module %s in %r" % (mod_name, sys.path)
245     
246     module = sys.modules[mod_name]
247     tc = module.TestbedController()
248     if tc.testbed_version != testbed_version:
249         raise RuntimeError("Bad testbed version on testbed %s. Asked for %s, got %s" % \
250                 (testbed_id, testbed_version, tc.testbed_version))
251     return tc
252
253 # Just a namespace class
class Marshalling:
    """Namespace of helpers for the "|"-separated request/reply protocol.

    - ``Decoders`` turn a received string field into a Python value. They
      are also copied into this class's namespace (see ``locals().update``
      below), so ``Marshalling.pickled_data`` etc. can be used as "types"
      in ``args`` / ``retval``.
    - ``Encoders`` are their inverses, used to build outgoing fields.
    - ``args``, ``retval``, ``retvoid`` and ``handles`` are method
      decorators for the server classes. They attach ``_argtypes``,
      ``_argencoders``, ``_retval`` and ``_handles_command`` to the
      decorated handler.
    """
    class Decoders:
        @staticmethod
        def pickled_data(sdata):
            # SECURITY NOTE(review): unpickles data received from the peer;
            # pickle can execute arbitrary code, so this is only safe when
            # the peer is trusted.
            return cPickle.loads(base64.b64decode(sdata))
        
        @staticmethod
        def base64_data(sdata):
            return base64.b64decode(sdata)
        
        @staticmethod
        def nullint(sdata):
            # the literal string "None" decodes to None, anything else to int
            return None if sdata == "None" else int(sdata)
        
        @staticmethod
        def bool(sdata):
            # only the exact string 'True' decodes as true
            return sdata == 'True'
        
    class Encoders:
        @staticmethod
        def pickled_data(data):
            return base64.b64encode(cPickle.dumps(data))
        
        @staticmethod
        def base64_data(data):
            return base64.b64encode(data)
        
        @staticmethod
        def nullint(data):
            # returns an int (not a str) for non-None values; it is rendered
            # with the "%s" format registered in _TYPE_ENCODERS
            return "None" if data is None else int(data)
        
        @staticmethod
        def bool(data):
            # `bool` here is the builtin (class-body names are not visible
            # inside methods), so this yields "True" / "False"
            return str(bool(data))
           
    # import into Marshalling all the decoders
    # they act as types
    # (vars() yields the raw staticmethod objects; looked up through the
    # class, e.g. Marshalling.pickled_data, they resolve to plain functions)
    locals().update([
        (typname, typ)
        for typname, typ in vars(Decoders).iteritems()
        if not typname.startswith('_')
    ])

    _TYPE_ENCODERS = dict([
        # type name (typ.__name__) -> (<encoding_function>, <formatting_string>)
        # One entry per decoder that has a same-named encoder.
        # NOTE: using Encoders inside this comprehension relies on Python 2
        # scoping; under Python 3 class-body names are not visible here.
        (typname, (getattr(Encoders,typname),"%s"))
        for typname in vars(Decoders)
        if not typname.startswith('_')
           and hasattr(Encoders,typname)
    ])

    # Builtins
    _TYPE_ENCODERS["float"] = (float, "%r")
    _TYPE_ENCODERS["int"] = (int, "%d")
    _TYPE_ENCODERS["long"] = (int, "%d")
    _TYPE_ENCODERS["str"] = (str, "%s")
    # NOTE(review): under Python 2, str() of a non-ASCII unicode value raises
    # UnicodeEncodeError, so such arguments cannot be encoded.
    _TYPE_ENCODERS["unicode"] = (str, "%s")
    
    # Generic encoder
    _TYPE_ENCODERS[None] = (str, "%s")
    
    @staticmethod
    def args(*types):
        """
        Decorator that converts the given function into one that takes
        a single "params" list, with each parameter marshalled according
        to the given factory callable (type constructors are accepted).
        
        The first argument (self) is left untouched.
        
        params[0] (the instruction code) is skipped and the remaining
        fields are zipped with *types*: missing trailing fields leave the
        wrapped function's defaults in effect, extra fields are ignored.
        
        Encoders are looked up in _TYPE_ENCODERS by type *name*; unknown
        names get the generic str / "%s" encoder.
        NOTE(review): because lookup is by name, the builtin bool would use
        Encoders.bool while still decoding with bool() (so "False" would
        decode as True) -- use Marshalling.bool instead.
        
        eg:
        
        @Marshalling.args(int,int,str,base64_data)
        def somefunc(self, someint, otherint, somestr, someb64):
           return someretval
        """
        def decor(f):
            @functools.wraps(f)
            def rv(self, params):
                return f(self, *[ ctor(val)
                                  for ctor,val in zip(types, params[1:]) ])
            
            rv._argtypes = types
            
            # Derive type encoders by looking up types in _TYPE_ENCODERS
            # make_proxy will use it to encode arguments in command strings
            # (presumably the BaseProxy stub machinery -- _make_stubs /
            # _make_message -- consumes _argencoders; confirm)
            argencoders = []
            TYPE_ENCODERS = Marshalling._TYPE_ENCODERS
            for typ in types:
                if typ.__name__ in TYPE_ENCODERS:
                    argencoders.append(TYPE_ENCODERS[typ.__name__])
                else:
                    # generic encoder
                    argencoders.append(TYPE_ENCODERS[None])
            
            rv._argencoders = tuple(argencoders)
            
            # preserve the return type recorded by an inner retval decorator
            rv._retval = getattr(f, '_retval', None)
            return rv
        return decor

    @staticmethod
    def retval(typ=Decoders.base64_data):
        """
        Decorator that converts the given function into one that 
        returns a properly encoded return string, given that the undecorated
        function returns suitable input for the encoding function.
        
        The optional typ argument specifies a type.
        For the default of base64_data, return values should be strings.
        The return value of the encoding method should be a string always.
        
        The reply is formatted as "%d|" + <type format> with OK as the
        code; a type whose name is not in _TYPE_ENCODERS uses the generic
        str / "%s" encoder.
        
        eg:
        
        @Marshalling.args(int,int,str,base64_data)
        @Marshalling.retval(str)
        def somefunc(self, someint, otherint, somestr, someb64):
           return someint
        """
        encode, fmt = Marshalling._TYPE_ENCODERS.get(
            typ.__name__,
            Marshalling._TYPE_ENCODERS[None])
        fmt = "%d|"+fmt
        
        def decor(f):
            @functools.wraps(f)
            def rv(self, *p, **kw):
                data = f(self, *p, **kw)
                return fmt % (
                    OK,
                    encode(data)
                )
            rv._retval = typ
            # preserve argument metadata from an inner args decorator
            rv._argtypes = getattr(f, '_argtypes', None)
            rv._argencoders = getattr(f, '_argencoders', None)
            return rv
        return decor
    
    @staticmethod
    def retvoid(f):
        """
        Decorator that converts the given function into one that 
        always return an encoded empty string.
        
        Useful for null-returning functions.
        The wrapped function's own return value is discarded; the reply
        is always "<OK>|".
        """
        OKRV = "%d|" % (OK,)
        
        @functools.wraps(f)
        def rv(self, *p, **kw):
            f(self, *p, **kw)
            return OKRV
        
        rv._retval = None
        rv._argtypes = getattr(f, '_argtypes', None)
        rv._argencoders = getattr(f, '_argencoders', None)
        return rv
    
    @staticmethod
    def handles(whichcommand):
        """
        Associates the method with a given command code for servers.
        It should always be the topmost decorator.
        
        The code is stored as f._handles_command; BaseServer.reply_action
        dispatches requests by comparing it with the instruction code.
        """
        def decor(f):
            f._handles_command = whichcommand
            return f
        return decor
422
class BaseServer(server.Server):
    """Request dispatcher shared by the NEPI daemon servers.

    A request is a "|"-separated line whose first field is an instruction
    code. ``reply_action`` finds the public method tagged with that code by
    ``Marshalling.handles``, searching the class and all its bases, and
    returns that method's encoded reply. Any failure becomes an ERROR reply
    carrying a base64-encoded error text.
    """

    def _find_command_handler(self, instruction):
        """Return the bound method handling ``instruction``, or None.

        Classes are searched in MRO order, so handlers inherited from base
        classes are found and the most-derived definition wins.
        """
        for klass in type(self).__mro__:
            for mname, meth in vars(klass).items():
                if mname.startswith('_'):
                    continue
                if getattr(meth, '_handles_command', None) == instruction:
                    return getattr(self, mname)
        return None

    def reply_action(self, msg):
        """Dispatch one request line ``msg`` and return the reply string."""
        if not msg:
            result = base64.b64encode("Invalid command line")
            reply = "%d|%s" % (ERROR, result)
        else:
            params = msg.split("|")
            log_msg(self, params)
            try:
                # Parsed inside the try so that a malformed instruction field
                # produces an ERROR reply instead of escaping reply_action.
                instruction = int(params[0])
                handler = self._find_command_handler(instruction)
                if handler is not None:
                    reply = handler(params)
                else:
                    error = "Invalid instruction %s" % instruction
                    self.log_error(error)
                    result = base64.b64encode(error)
                    reply = "%d|%s" % (ERROR, result)
            except Exception:
                error = self.log_error()
                result = base64.b64encode(error)
                reply = "%d|%s" % (ERROR, result)
        log_reply(self, reply)
        return reply
451
class TestbedControllerServer(BaseServer):
    """Server process that wraps a single testbed controller.

    Requests are dispatched by BaseServer.reply_action to the handler whose
    Marshalling.handles code matches the instruction. On each handler,
    Marshalling.args decodes the request fields into positional arguments,
    and Marshalling.retval / Marshalling.retvoid encode the "<OK>|..." reply.
    Almost every handler simply forwards to the same-named member of
    self._testbed.
    """
    def __init__(self, root_dir, log_level, testbed_id, testbed_version, environment_setup):
        super(TestbedControllerServer, self).__init__(root_dir, log_level, 
            environment_setup = environment_setup )
        self._testbed_id = testbed_id
        self._testbed_version = testbed_version
        # Built in post_daemonize(); handlers dereference it, so they fail
        # until it has been set.
        self._testbed = None

    def post_daemonize(self):
        # Presumably invoked by server.Server once the process has
        # daemonized (hook defined outside this file -- confirm), so the
        # controller is created inside the daemon process.
        self._testbed = _build_testbed_controller(self._testbed_id, 
                self._testbed_version)

    @Marshalling.handles(GUIDS)
    @Marshalling.args()
    @Marshalling.retval( Marshalling.pickled_data )
    def guids(self):
        # attribute read on the testbed, not a method call
        return self._testbed.guids

    @Marshalling.handles(TESTBED_ID)
    @Marshalling.args()
    @Marshalling.retval()
    def testbed_id(self):
        return str(self._testbed.testbed_id)

    @Marshalling.handles(TESTBED_VERSION)
    @Marshalling.args()
    @Marshalling.retval()
    def testbed_version(self):
        return str(self._testbed.testbed_version)

    @Marshalling.handles(CREATE)
    @Marshalling.args(int, str)
    @Marshalling.retvoid
    def defer_create(self, guid, factory_id):
        self._testbed.defer_create(guid, factory_id)

    @Marshalling.handles(TRACE)
    @Marshalling.args(int, str, Marshalling.base64_data)
    @Marshalling.retval()
    def trace(self, guid, trace_id, attribute):
        # NOTE(review): unlike ExperimentControllerServer.trace, the result
        # is not wrapped in str(); a non-string result (e.g. None) makes the
        # default base64 encoder fail, and the request gets an ERROR reply.
        return self._testbed.trace(guid, trace_id, attribute)

    @Marshalling.handles(TRACES_INFO)
    @Marshalling.args()
    @Marshalling.retval( Marshalling.pickled_data )
    def traces_info(self):
        return self._testbed.traces_info()

    @Marshalling.handles(START)
    @Marshalling.args()
    @Marshalling.retvoid
    def start(self):
        self._testbed.start()

    @Marshalling.handles(STOP)
    @Marshalling.args()
    @Marshalling.retvoid
    def stop(self):
        self._testbed.stop()

    @Marshalling.handles(SHUTDOWN)
    @Marshalling.args()
    @Marshalling.retvoid
    def shutdown(self):
        self._testbed.shutdown()

    @Marshalling.handles(CONFIGURE)
    @Marshalling.args(Marshalling.base64_data, Marshalling.pickled_data)
    @Marshalling.retvoid
    def defer_configure(self, name, value):
        self._testbed.defer_configure(name, value)

    @Marshalling.handles(CREATE_SET)
    @Marshalling.args(int, Marshalling.base64_data, Marshalling.pickled_data)
    @Marshalling.retvoid
    def defer_create_set(self, guid, name, value):
        self._testbed.defer_create_set(guid, name, value)

    @Marshalling.handles(FACTORY_SET)
    @Marshalling.args(Marshalling.base64_data, Marshalling.pickled_data)
    @Marshalling.retvoid
    def defer_factory_set(self, name, value):
        self._testbed.defer_factory_set(name, value)

    @Marshalling.handles(CONNECT)
    @Marshalling.args(int, str, int, str)
    @Marshalling.retvoid
    def defer_connect(self, guid1, connector_type_name1, guid2, connector_type_name2):
        self._testbed.defer_connect(guid1, connector_type_name1, guid2, 
            connector_type_name2)

    @Marshalling.handles(CROSS_CONNECT)
    @Marshalling.args(int, str, int, int, str, str, str)
    @Marshalling.retvoid
    def defer_cross_connect(self, 
            guid, connector_type_name,
            cross_guid, cross_testbed_guid,
            cross_testbed_id, cross_factory_id,
            cross_connector_type_name):
        self._testbed.defer_cross_connect(guid, connector_type_name, cross_guid, 
            cross_testbed_guid, cross_testbed_id, cross_factory_id, 
            cross_connector_type_name)

    @Marshalling.handles(ADD_TRACE)
    @Marshalling.args(int, str)
    @Marshalling.retvoid
    def defer_add_trace(self, guid, trace_id):
        self._testbed.defer_add_trace(guid, trace_id)

    @Marshalling.handles(ADD_ADDRESS)
    @Marshalling.args(int, str, int, Marshalling.pickled_data)
    @Marshalling.retvoid
    def defer_add_address(self, guid, address, netprefix, broadcast):
        self._testbed.defer_add_address(guid, address, netprefix,
                broadcast)

    @Marshalling.handles(ADD_ROUTE)
    @Marshalling.args(int, str, int, str, int)
    @Marshalling.retvoid
    def defer_add_route(self, guid, destination, netprefix, nexthop, metric):
        self._testbed.defer_add_route(guid, destination, netprefix, nexthop, metric)

    @Marshalling.handles(DO_SETUP)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_setup(self):
        self._testbed.do_setup()

    @Marshalling.handles(DO_CREATE)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_create(self):
        self._testbed.do_create()

    @Marshalling.handles(DO_CONNECT_INIT)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_connect_init(self):
        self._testbed.do_connect_init()

    @Marshalling.handles(DO_CONNECT_COMPL)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_connect_compl(self):
        self._testbed.do_connect_compl()

    @Marshalling.handles(DO_CONFIGURE)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_configure(self):
        self._testbed.do_configure()

    @Marshalling.handles(DO_PRECONFIGURE)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_preconfigure(self):
        self._testbed.do_preconfigure()

    @Marshalling.handles(DO_PRESTART)
    @Marshalling.args()
    @Marshalling.retvoid
    def do_prestart(self):
        self._testbed.do_prestart()

    @Marshalling.handles(DO_CROSS_CONNECT_INIT)
    @Marshalling.args( Marshalling.Decoders.pickled_data )
    @Marshalling.retvoid
    def do_cross_connect_init(self, cross_data):
        self._testbed.do_cross_connect_init(cross_data)

    @Marshalling.handles(DO_CROSS_CONNECT_COMPL)
    @Marshalling.args( Marshalling.Decoders.pickled_data )
    @Marshalling.retvoid
    def do_cross_connect_compl(self, cross_data):
        self._testbed.do_cross_connect_compl(cross_data)

    @Marshalling.handles(GET)
    @Marshalling.args(int, Marshalling.base64_data, str)
    @Marshalling.retval( Marshalling.pickled_data )
    def get(self, guid, name, time):
        return self._testbed.get(guid, name, time)

    @Marshalling.handles(SET)
    @Marshalling.args(int, Marshalling.base64_data, Marshalling.pickled_data, str)
    @Marshalling.retvoid
    def set(self, guid, name, value, time):
        self._testbed.set(guid, name, value, time)

    @Marshalling.handles(GET_ADDRESS)
    @Marshalling.args(int, int, Marshalling.base64_data)
    @Marshalling.retval()
    def get_address(self, guid, index, attribute):
        # str() so the default base64 encoder always receives a string
        return str(self._testbed.get_address(guid, index, attribute))

    @Marshalling.handles(GET_ROUTE)
    @Marshalling.args(int, int, Marshalling.base64_data)
    @Marshalling.retval()
    def get_route(self, guid, index, attribute):
        # str() so the default base64 encoder always receives a string
        return str(self._testbed.get_route(guid, index, attribute))

    @Marshalling.handles(ACTION)
    @Marshalling.args(str, int, Marshalling.base64_data)
    @Marshalling.retvoid
    def action(self, time, guid, command):
        # note the wire order: time comes before guid
        self._testbed.action(time, guid, command)

    @Marshalling.handles(STATUS)
    @Marshalling.args(Marshalling.nullint)
    @Marshalling.retval(int)
    def status(self, guid):
        # guid may arrive as the string "None" (decoded to None by nullint)
        return self._testbed.status(guid)

    @Marshalling.handles(GET_ATTRIBUTE_LIST)
    @Marshalling.args(int, Marshalling.nullint, Marshalling.bool)
    @Marshalling.retval( Marshalling.pickled_data )
    def get_attribute_list(self, guid, filter_flags = None, exclude = False):
        # The defaults only apply when the request omits the trailing fields
        # (Marshalling.args zips the declared types with the received fields).
        return self._testbed.get_attribute_list(guid, filter_flags, exclude)

    @Marshalling.handles(GET_FACTORY_ID)
    @Marshalling.args(int)
    @Marshalling.retval()
    def get_factory_id(self, guid):
        return self._testbed.get_factory_id(guid)
675
class ExperimentControllerServer(BaseServer):
    """Server process that wraps a single ExperimentController.

    Requests are dispatched by BaseServer.reply_action to the handler whose
    Marshalling.handles code matches the instruction. Each handler decodes
    its arguments with Marshalling.args, encodes its result with
    Marshalling.retval / retvoid, and forwards to the same-named member of
    self._experiment.
    """
    def __init__(self, root_dir, log_level, experiment_xml, environment_setup):
        super(ExperimentControllerServer, self).__init__(root_dir, log_level, 
            environment_setup = environment_setup )
        self._experiment_xml = experiment_xml
        # Created in post_daemonize(); handlers dereference it.
        self._experiment = None

    def post_daemonize(self):
        # Presumably invoked by server.Server after daemonizing (defined
        # outside this file -- confirm). self._root_dir is not assigned in
        # this class; presumably the base __init__ stores root_dir there.
        from nepi.core.execute import ExperimentController
        self._experiment = ExperimentController(self._experiment_xml, 
            root_dir = self._root_dir)

    @Marshalling.handles(GUIDS)
    @Marshalling.args()
    @Marshalling.retval( Marshalling.pickled_data )
    def guids(self):
        # attribute read on the experiment, not a method call
        return self._experiment.guids

    @Marshalling.handles(XML)
    @Marshalling.args()
    @Marshalling.retval()
    def experiment_design_xml(self):
        return self._experiment.experiment_design_xml
        
    @Marshalling.handles(EXEC_XML)
    @Marshalling.args()
    @Marshalling.retval()
    def experiment_execute_xml(self):
        return self._experiment.experiment_execute_xml
        
    @Marshalling.handles(TRACE)
    @Marshalling.args(int, str, Marshalling.base64_data)
    @Marshalling.retval()
    def trace(self, guid, trace_id, attribute):
        # str() so the default base64 encoder always receives a string
        return str(self._experiment.trace(guid, trace_id, attribute))

    @Marshalling.handles(TRACES_INFO)
    @Marshalling.args()
    @Marshalling.retval( Marshalling.pickled_data )
    def traces_info(self):
        return self._experiment.traces_info()

    @Marshalling.handles(FINISHED)
    @Marshalling.args(int)
    @Marshalling.retval(Marshalling.bool)
    def is_finished(self, guid):
        # encoded as the string "True" / "False"
        return self._experiment.is_finished(guid)

    @Marshalling.handles(GET)
    @Marshalling.args(int, Marshalling.base64_data, str)
    @Marshalling.retval( Marshalling.pickled_data )
    def get(self, guid, name, time):
        return self._experiment.get(guid, name, time)

    @Marshalling.handles(SET)
    @Marshalling.args(int, Marshalling.base64_data, Marshalling.pickled_data, str)
    @Marshalling.retvoid
    def set(self, guid, name, value, time):
        self._experiment.set(guid, name, value, time)

    @Marshalling.handles(START)
    @Marshalling.args()
    @Marshalling.retvoid
    def start(self):
        self._experiment.start()

    @Marshalling.handles(STOP)
    @Marshalling.args()
    @Marshalling.retvoid
    def stop(self):
        self._experiment.stop()

    @Marshalling.handles(RECOVER)
    @Marshalling.args()
    @Marshalling.retvoid
    def recover(self):
        self._experiment.recover()

    @Marshalling.handles(SHUTDOWN)
    @Marshalling.args()
    @Marshalling.retvoid
    def shutdown(self):
        self._experiment.shutdown()

    @Marshalling.handles(GET_TESTBED_ID)
    @Marshalling.args(int)
    @Marshalling.retval()
    def get_testbed_id(self, guid):
        return self._experiment.get_testbed_id(guid)

    @Marshalling.handles(GET_FACTORY_ID)
    @Marshalling.args(int)
    @Marshalling.retval()
    def get_factory_id(self, guid):
        return self._experiment.get_factory_id(guid)

    @Marshalling.handles(GET_TESTBED_VERSION)
    @Marshalling.args(int)
    @Marshalling.retval()
    def get_testbed_version(self, guid):
        return self._experiment.get_testbed_version(guid)
777
778 class BaseProxy(object):
779     _ServerClass = None
780     _ServerClassModule = "nepi.util.proxy"
781     
782     def __init__(self, 
783             ctor_args, root_dir, 
784             launch = True, host = None, 
785             port = None, user = None, ident_key = None, agent = None,
786             environment_setup = ""):
787         if launch:
788             # ssh
789             if host != None:
790                 python_code = (
791                     "from %(classmodule)s import %(classname)s;"
792                     "s = %(classname)s%(ctor_args)r;"
793                     "s.run()" 
794                 % dict(
795                     classname = self._ServerClass.__name__,
796                     classmodule = self._ServerClassModule,
797                     ctor_args = ctor_args
798                 ) )
799                 proc = server.popen_ssh_subprocess(python_code, host = host,
800                     port = port, user = user, agent = agent,
801                     ident_key = ident_key,
802                     environment_setup = environment_setup,
803                     waitcommand = True)
804                 if proc.poll():
805                     err = proc.stderr.read()
806                     raise RuntimeError, "Server could not be executed: %s" % (err,)
807             else:
808                 # launch daemon
809                 s = self._ServerClass(*ctor_args)
810                 s.run()
811
812         # connect client to server
813         self._client = server.Client(root_dir, host = host, port = port, 
814                 user = user, agent = agent, 
815                 environment_setup = environment_setup)
816     
817     @staticmethod
818     def _make_message(argtypes, argencoders, command, methname, classname, *args):
819         if len(argtypes) != len(argencoders):
820             raise ValueError, "Invalid arguments for _make_message: "\
821                 "in stub method %s of class %s "\
822                 "argtypes and argencoders must match in size" % (
823                     methname, classname )
824         if len(argtypes) != len(args):
825             raise ValueError, "Invalid arguments for _make_message: "\
826                 "in stub method %s of class %s "\
827                 "expected %d arguments, got %d" % (
828                     methname, classname,
829                     len(argtypes), len(args))
830         
831         buf = []
832         for argnum, (typ, (encode, fmt), val) in enumerate(zip(argtypes, argencoders, args)):
833             try:
834                 buf.append(fmt % encode(val))
835             except:
836                 import traceback
837                 raise TypeError, "Argument %d of stub method %s of class %s "\
838                     "requires a value of type %s, but got %s - nested error: %s" % (
839                         argnum, methname, classname,
840                         getattr(typ, '__name__', typ), type(val),
841                         traceback.format_exc()
842                 )
843         
844         return "%d|%s" % (command, '|'.join(buf))
845     
846     @staticmethod
847     def _parse_reply(rvtype, methname, classname, reply):
848         if not reply:
849             raise RuntimeError, "Invalid reply: %r "\
850                 "for stub method %s of class %s" % (
851                     reply,
852                     methname,
853                     classname)
854         
855         try:
856             result = reply.split("|")
857             code = int(result[0])
858             text = result[1]
859         except:
860             import traceback
861             raise TypeError, "Return value of stub method %s of class %s "\
862                 "cannot be parsed: must be of type %s, got %r - nested error: %s" % (
863                     methname, classname,
864                     getattr(rvtype, '__name__', rvtype), reply,
865                     traceback.format_exc()
866             )
867         if code == ERROR:
868             text = base64.b64decode(text)
869             raise RuntimeError(text)
870         elif code == OK:
871             try:
872                 if rvtype is None:
873                     return
874                 else:
875                     return rvtype(text)
876             except:
877                 import traceback
878                 raise TypeError, "Return value of stub method %s of class %s "\
879                     "cannot be parsed: must be of type %s - nested error: %s" % (
880                         methname, classname,
881                         getattr(rvtype, '__name__', rvtype),
882                         traceback.format_exc()
883                 )
884         else:
885             raise RuntimeError, "Invalid reply: %r "\
886                 "for stub method %s of class %s - unknown code" % (
887                     reply,
888                     methname,
889                     classname)
890     
891     @staticmethod
892     def _make_stubs(server_class, template_class):
893         """
894         Returns a dictionary method_name -> method
895         with stub methods.
896         
897         Usage:
898         
899             class SomeProxy(BaseProxy):
900                ...
901                
902                locals().update( BaseProxy._make_stubs(
903                     ServerClass,
904                     TemplateClass
905                ) )
906         
907         ServerClass is the corresponding Server class, as
908         specified in the _ServerClass class method (_make_stubs
909         is static and can't access the method), and TemplateClass
910         is the ultimate implementation class behind the server,
911         from which argument names and defaults are taken, to
912         maintain meaningful interfaces.
913         """
914         rv = {}
915         
916         class NONE: pass
917         
918         import os.path
919         func_template_path = os.path.join(
920             os.path.dirname(__file__),
921             'proxy_stub.tpl')
922         func_template_file = open(func_template_path, "r")
923         func_template = func_template_file.read()
924         func_template_file.close()
925         
926         for methname in vars(template_class).copy():
927             if methname.endswith('_deferred'):
928                 # cannot wrap deferreds...
929                 continue
930             dmethname = methname+'_deferred'
931             if hasattr(server_class, methname) and not methname.startswith('_'):
932                 template_meth = getattr(template_class, methname)
933                 server_meth = getattr(server_class, methname)
934                 
935                 command = getattr(server_meth, '_handles_command', None)
936                 argtypes = getattr(server_meth, '_argtypes', None)
937                 argencoders = getattr(server_meth, '_argencoders', None)
938                 rvtype = getattr(server_meth, '_retval', None)
939                 doprop = False
940                 
941                 if hasattr(template_meth, 'fget'):
942                     # property getter
943                     template_meth = template_meth.fget
944                     doprop = True
945                 
946                 if command is not None and argtypes is not None and argencoders is not None:
947                     # We have an interface method...
948                     code = template_meth.func_code
949                     argnames = code.co_varnames[:code.co_argcount]
950                     argdefaults = ( (NONE,) * (len(argnames) - len(template_meth.func_defaults or ()))
951                                   + (template_meth.func_defaults or ()) )
952                     
953                     func_globals = dict(
954                         BaseProxy = BaseProxy,
955                         argtypes = argtypes,
956                         argencoders = argencoders,
957                         rvtype = rvtype,
958                         functools = functools,
959                     )
960                     context = dict()
961                     
962                     func_text = func_template % dict(
963                         self = argnames[0],
964                         args = '%s' % (','.join(argnames[1:])),
965                         argdefs = ','.join([
966                             argname if argdef is NONE
967                             else "%s=%r" % (argname, argdef)
968                             for argname, argdef in zip(argnames[1:], argdefaults[1:])
969                         ]),
970                         command = command,
971                         methname = methname,
972                         classname = server_class.__name__
973                     )
974                     
975                     func_text = compile(
976                         func_text,
977                         func_template_path,
978                         'exec')
979                     
980                     exec func_text in func_globals, context
981                     
982                     if doprop:
983                         rv[methname] = property(context[methname])
984                         rv[dmethname] = property(context[dmethname])
985                     else:
986                         rv[methname] = context[methname]
987                         rv[dmethname] = context[dmethname]
988                     
989                     # inject _deferred into core classes
990                     if hasattr(template_class, methname) and not hasattr(template_class, dmethname):
991                         def freezename(methname, dmethname):
992                             def dmeth(self, *p, **kw): 
993                                 return getattr(self, methname)(*p, **kw)
994                             dmeth.__name__ = dmethname
995                             return dmeth
996                         dmeth = freezename(methname, dmethname)
997                         setattr(template_class, dmethname, dmeth)
998         
999         return rv
1000                         
class TestbedControllerProxy(BaseProxy):
    """
    Client-side proxy for a TestbedControllerServer.

    Its public interface is the set of stub methods generated by
    BaseProxy._make_stubs from nepi.core.execute.TestbedController, plus a
    shutdown override that also stops the server side.
    """
    
    _ServerClass = TestbedControllerServer
    
    def __init__(self, root_dir, log_level, testbed_id = None, 
            testbed_version = None, launch = True, host = None, 
            port = None, user = None, ident_key = None, agent = None,
            environment_setup = ""):
        """
        :param root_dir: server root directory, forwarded to BaseProxy.
        :param log_level: passed to the server constructor.
        :param testbed_id: testbed identifier; required when launch is True.
        :param testbed_version: testbed version; required when launch is True.
        :param launch: whether BaseProxy should launch the server.
        :param host: connection setting forwarded to BaseProxy.
        :param port: connection setting forwarded to BaseProxy.
        :param user: connection setting forwarded to BaseProxy.
        :param ident_key: connection setting forwarded to BaseProxy.
        :param agent: connection setting forwarded to BaseProxy.
        :param environment_setup: passed both to the server constructor and
            to BaseProxy.
        :raises RuntimeError: if launch is requested without both testbed_id
            and testbed_version.
        """
        if launch and (testbed_id is None or testbed_version is None):
            raise RuntimeError("To launch a TestbedControllerServer a "
                    "testbed_id and testbed_version are required")
        super(TestbedControllerProxy,self).__init__(
            ctor_args = (root_dir, log_level, testbed_id, testbed_version, environment_setup),
            root_dir = root_dir,
            launch = launch, host = host, port = port, user = user,
            ident_key = ident_key, agent = agent, 
            environment_setup = environment_setup)

    locals().update( BaseProxy._make_stubs(
        server_class = TestbedControllerServer,
        template_class = nepi.core.execute.TestbedController,
    ) )
    
    # Shutdown stops the serverside...
    # The default value is evaluated at class-creation time, so '_stub' is
    # the generated 'shutdown' stub installed by locals().update() above.
    def shutdown(self, _stub = shutdown):
        """
        Run the remote shutdown stub, then send the client's stop request
        and wait for its reply. Returns whatever the stub returned.
        """
        rv = _stub(self)
        self._client.send_stop()
        self._client.read_reply() # wait for it
        return rv
1030     
1031
class ExperimentControllerProxy(BaseProxy):
    """
    Client-side proxy for an ExperimentControllerServer.

    Its public interface is the set of stub methods generated by
    BaseProxy._make_stubs from nepi.core.execute.ExperimentController, plus a
    shutdown override that also stops the server side.
    """
    _ServerClass = ExperimentControllerServer
    
    def __init__(self, root_dir, log_level, experiment_xml = None, 
            launch = True, host = None, port = None, user = None, 
            ident_key = None, agent = None, environment_setup = ""):
        """
        :param root_dir: server root directory, forwarded to BaseProxy.
        :param log_level: passed to the server constructor.
        :param experiment_xml: experiment description; required when launch
            is True.
        :param launch: whether BaseProxy should launch the server.
        :param host: connection setting forwarded to BaseProxy.
        :param port: connection setting forwarded to BaseProxy.
        :param user: connection setting forwarded to BaseProxy.
        :param ident_key: connection setting forwarded to BaseProxy.
        :param agent: connection setting forwarded to BaseProxy.
        :param environment_setup: passed both to the server constructor and
            to BaseProxy.
        :raises RuntimeError: if launch is requested without experiment_xml.
        """
        if launch and experiment_xml is None:
            # Implicit literal concatenation, not a backslash inside the
            # string, so no indentation whitespace leaks into the message.
            raise RuntimeError("To launch an ExperimentControllerServer an "
                    "xml description of the experiment is required")
        super(ExperimentControllerProxy,self).__init__(
            ctor_args = (root_dir, log_level, experiment_xml, environment_setup),
            root_dir = root_dir,
            launch = launch, host = host, port = port, user = user,
            ident_key = ident_key, agent = agent, 
            environment_setup = environment_setup)

    locals().update( BaseProxy._make_stubs(
        server_class = ExperimentControllerServer,
        template_class = nepi.core.execute.ExperimentController,
    ) )

    
    # Shutdown stops the serverside...
    # The default value is evaluated at class-creation time, so '_stub' is
    # the generated 'shutdown' stub installed by locals().update() above.
    def shutdown(self, _stub = shutdown):
        """
        Run the remote shutdown stub, then send the client's stop request
        and wait for its reply. Returns whatever the stub returned.
        """
        rv = _stub(self)
        self._client.send_stop()
        self._client.read_reply() # wait for it
        return rv
1060