jsonrpc.py: Don't swallow errors in transact_block().
[sliver-openvswitch.git] / python / ovs / jsonrpc.py
1 # Copyright (c) 2010, 2011 Nicira Networks
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import errno
16 import os
17
18 import ovs.json
19 import ovs.poller
20 import ovs.reconnect
21 import ovs.stream
22 import ovs.timeval
23 import ovs.vlog
24
# Status value meaning "end of stream" / orderly disconnection.  It is stored
# in Connection.status alongside the positive errno values used for failures.
EOF = -1
# Module logger: an ovs.vlog.Vlog named "jsonrpc".
vlog = ovs.vlog.Vlog("jsonrpc")
27
28
class Message(object):
    """One JSON-RPC message.

    Every message has a 'type' (one of the T_* constants below) plus the
    optional members 'method', 'params', 'result', 'error' and 'id'.  Which
    of those members must be present, and which must be absent, depends on
    the type; see is_valid()."""

    T_REQUEST = 0               # Request.
    T_NOTIFY = 1                # Notification.
    T_REPLY = 2                 # Successful reply.
    T_ERROR = 3                 # Error reply.

    __types = {T_REQUEST: "request",
               T_NOTIFY: "notification",
               T_REPLY: "reply",
               T_ERROR: "error"}

    def __init__(self, type_, method, params, result, error, id):
        self.type = type_
        self.method = method
        self.params = params
        self.result = result
        self.error = error
        self.id = id

    # Counter behind _create_id(); shared by every Message in the process.
    _next_id = 0

    @staticmethod
    def _create_id():
        """Hands out the current counter value as a request id and advances
        the class-wide counter."""
        allocated = Message._next_id
        Message._next_id = allocated + 1
        return allocated

    @staticmethod
    def create_request(method, params):
        """Builds a request for 'method' carrying a freshly allocated id."""
        return Message(Message.T_REQUEST, method=method, params=params,
                       result=None, error=None, id=Message._create_id())

    @staticmethod
    def create_notify(method, params):
        """Builds a notification (a request that has no id)."""
        return Message(Message.T_NOTIFY, method=method, params=params,
                       result=None, error=None, id=None)

    @staticmethod
    def create_reply(result, id):
        """Builds a successful reply to the request identified by 'id'."""
        return Message(Message.T_REPLY, method=None, params=None,
                       result=result, error=None, id=id)

    @staticmethod
    def create_error(error, id):
        """Builds an error reply to the request identified by 'id'."""
        return Message(Message.T_ERROR, method=None, params=None,
                       result=None, error=error, id=id)

    @staticmethod
    def type_to_string(type_):
        """Maps a T_* constant to its human-readable name.

        Raises KeyError for an unknown type."""
        return Message.__types[type_]

    def __validate_arg(self, value, name, must_have):
        """Checks that member 'name' (currently 'value') is present exactly
        when 'must_have' is true.

        Returns None when the rule holds, otherwise a description of the
        violation."""
        if (value is not None) == bool(must_have):
            return None
        if must_have:
            requirement = "must"
        else:
            requirement = "must not"
        return "%s %s have \"%s\"" % (Message.type_to_string(self.type),
                                      requirement, name)

    def is_valid(self):
        """Returns None if this message is well-formed, otherwise a string
        describing the first problem found."""
        if self.params is not None and type(self.params) != list:
            return "\"params\" must be JSON array"

        # For each type: whether (method, params, result, error, id) is
        # required.  A member that is not required must be absent.
        requirements = {
            Message.T_REQUEST: (True, True, False, False, True),
            Message.T_NOTIFY: (True, True, False, False, False),
            Message.T_REPLY: (False, False, True, False, True),
            Message.T_ERROR: (False, False, False, True, True),
        }.get(self.type)
        if requirements is None:
            return "invalid JSON-RPC message type %s" % self.type

        members = (("method", self.method),
                   ("params", self.params),
                   ("result", self.result),
                   ("error", self.error),
                   ("id", self.id))
        for (name, value), must_have in zip(members, requirements):
            problem = self.__validate_arg(value, name, must_have)
            if problem:
                return problem
        return None

    @staticmethod
    def from_json(json):
        """Converts the decoded JSON value 'json' into a Message.

        Returns the new Message, or a string explaining why 'json' is not a
        valid JSON-RPC message.  The caller's dict is left unmodified."""
        if type(json) != dict:
            return "message is not a JSON object"

        # Pop known members off a private copy; anything left over is an
        # unexpected member.
        members = dict(json)

        if "method" not in members:
            method = None
        else:
            method = members.pop("method")
            if type(method) not in [str, unicode]:
                return "method is not a JSON string"

        params = members.pop("params", None)
        result = members.pop("result", None)
        error = members.pop("error", None)
        id_ = members.pop("id", None)
        if members:
            extra = members.popitem()[0]
            return "message has unexpected member \"%s\"" % extra

        # The first non-null member among result, error and id decides the
        # message type; with none of them present it is a notification.
        msg_type = Message.T_NOTIFY
        for value, candidate in ((result, Message.T_REPLY),
                                 (error, Message.T_ERROR),
                                 (id_, Message.T_REQUEST)):
            if value is not None:
                msg_type = candidate
                break

        msg = Message(msg_type, method, params, result, error, id_)
        problem = msg.is_valid()
        if problem is not None:
            return problem
        return msg

    def to_json(self):
        """Returns this message as a dict ready for JSON encoding.

        Null members are omitted, with three exceptions:
        - error replies always include "result";
        - successful replies always include "error";
        - notifications always include "id"."""
        always = {"result": self.type == Message.T_ERROR,
                  "error": self.type == Message.T_REPLY,
                  "id": self.type == Message.T_NOTIFY}
        json = {}
        for key, value in (("method", self.method),
                           ("params", self.params),
                           ("result", self.result),
                           ("error", self.error),
                           ("id", self.id)):
            if value is not None or always.get(key, False):
                json[key] = value
        return json

    def __str__(self):
        """Returns a one-line, comma-separated description for logging."""
        pieces = [Message.type_to_string(self.type)]
        if self.method is not None:
            pieces.append("method=\"%s\"" % self.method)
        for label, value in (("params", self.params),
                             ("result", self.result),
                             ("error", self.error),
                             ("id", self.id)):
            if value is not None:
                pieces.append(label + "=" + ovs.json.to_string(value))
        return ", ".join(pieces)
178
179
class Connection(object):
    """A JSON-RPC connection on top of an ovs.stream.Stream.

    Outgoing messages are serialized into the 'output' buffer and flushed by
    run().  Incoming bytes accumulate in 'input' and are fed to an
    ovs.json.Parser until a complete message is available.

    'status' is 0 while the connection is usable.  After a failure it holds
    a positive errno value or EOF, and no further I/O is attempted."""

    def __init__(self, stream):
        self.name = stream.name
        self.stream = stream
        self.status = 0
        self.input = ""
        self.output = ""
        self.parser = None

    def close(self):
        """Closes the underlying stream.  Calling close() again is harmless."""
        if self.stream is not None:
            self.stream.close()
            self.stream = None

    def run(self):
        """Flushes as much queued output as the stream accepts without
        blocking.

        A send failure other than EAGAIN is logged and moves the connection
        into the error state."""
        if self.status:
            return

        while self.output:
            retval = self.stream.send(self.output)
            if retval >= 0:
                self.output = self.output[retval:]
                continue
            # A negative return value is a negated errno code.
            if retval != -errno.EAGAIN:
                vlog.warn("%s: send error: %s" %
                          (self.name, os.strerror(-retval)))
                self.error(-retval)
            break

    def wait(self, poller):
        """Registers with 'poller' the events that run() needs."""
        if not self.status:
            self.stream.run_wait(poller)
            if self.output:
                # Wake up once the stream can accept more queued output.
                # (The poller argument was previously missing here, which
                # raised TypeError whenever output was pending.)
                self.stream.send_wait(poller)

    def get_status(self):
        """Returns 0 if the connection is healthy, otherwise its error."""
        return self.status

    def get_backlog(self):
        """Returns the number of queued but unsent output bytes.

        Returns 0 once the connection has failed."""
        if self.status != 0:
            return 0
        return len(self.output)

    def __log_msg(self, title, msg):
        vlog.dbg("%s: %s %s" % (self.name, title, msg))

    def send(self, msg):
        """Queues Message 'msg' for transmission.

        If nothing else was already queued, tries to send it immediately.
        Returns 0 on success or the connection's error status."""
        if self.status:
            return self.status

        self.__log_msg("send", msg)

        was_empty = not self.output
        self.output += ovs.json.to_string(msg.to_json())
        if was_empty:
            self.run()
        return self.status

    def send_block(self, msg):
        """Sends 'msg' and blocks until it is fully transmitted or the
        connection fails.  Returns 0 on success, otherwise the error status."""
        error = self.send(msg)
        if error:
            return error

        while True:
            self.run()
            if not self.get_backlog() or self.get_status():
                return self.status

            poller = ovs.poller.Poller()
            self.wait(poller)
            poller.block()

    def recv(self):
        """Tries to receive one message without blocking.

        Returns:
        - (0, msg) once a complete message has been parsed;
        - (errno.EAGAIN, None) if no complete message is available yet;
        - (status, None) after the connection has failed or reached EOF."""
        if self.status:
            return self.status, None

        while True:
            if not self.input:
                error, data = self.stream.recv(4096)
                if error:
                    if error == errno.EAGAIN:
                        return error, None
                    # XXX rate-limit
                    vlog.warn("%s: receive error: %s"
                              % (self.name, os.strerror(error)))
                    self.error(error)
                    return self.status, None
                if not data:
                    # Empty read: the peer closed the stream.
                    self.error(EOF)
                    return EOF, None
                self.input += data
            else:
                if self.parser is None:
                    self.parser = ovs.json.Parser()
                # Drop the part of the input the parser has consumed.
                self.input = self.input[self.parser.feed(self.input):]
                if self.parser.is_done():
                    msg = self.__process_msg()
                    if msg is not None:
                        return 0, msg
                    return self.status, None

    def recv_block(self):
        """Like recv(), but blocks until a message arrives or an error other
        than EAGAIN occurs.  Queued output keeps being flushed meanwhile."""
        while True:
            error, msg = self.recv()
            if error != errno.EAGAIN:
                return error, msg

            self.run()

            poller = ovs.poller.Poller()
            self.wait(poller)
            self.recv_wait(poller)
            poller.block()

    def transact_block(self, request):
        """Sends 'request' and blocks until the reply or error reply with
        the same id arrives.  Unrelated incoming messages are discarded.

        Returns (0, reply) on success, or (error, None) if sending or
        receiving fails."""
        id_ = request.id

        error = self.send(request)
        reply = None
        while not error:
            error, reply = self.recv_block()
            if (reply is not None
                    and reply.type in (Message.T_REPLY, Message.T_ERROR)
                    and reply.id == id_):
                break
        return error, reply

    def __process_msg(self):
        """Turns the parser's finished result into a Message.

        Returns the Message.  If the input was not valid JSON or not a valid
        JSON-RPC message, logs a warning, puts the connection into the EPROTO
        error state, and returns None."""
        json = self.parser.finish()
        self.parser = None
        if type(json) in [str, unicode]:
            # A string result from finish() is treated as a parse error.
            # XXX rate-limit
            vlog.warn("%s: error parsing stream: %s" % (self.name, json))
            self.error(errno.EPROTO)
            return None

        msg = Message.from_json(json)
        if not isinstance(msg, Message):
            # from_json() returned an error description instead.
            # XXX rate-limit
            vlog.warn("%s: received bad JSON-RPC message: %s"
                      % (self.name, msg))
            self.error(errno.EPROTO)
            return None

        self.__log_msg("received", msg)
        return msg

    def recv_wait(self, poller):
        """Arranges for 'poller' to wake up when recv() may make progress."""
        if self.status or self.input:
            # recv() can return right away: it reports the error, or it
            # parses the data already buffered.
            poller.immediate_wake()
        else:
            self.stream.recv_wait(poller)

    def error(self, error):
        """Puts the connection into error state 'error' (a positive errno
        value or EOF).

        Closes the stream and discards queued output.  Only the first error
        is recorded."""
        if self.status == 0:
            self.status = error
            if self.stream is not None:
                self.stream.close()
            self.output = ""
342
343
class Session(object):
    """A JSON-RPC session with reconnection.

    A Session owns at most one Connection ('rpc') at a time.
    - While an outgoing connection attempt is in progress, its stream is
      kept in 'stream'.
    - A passive (listening) session keeps its listener in 'pstream'.
    - The ovs.reconnect.Reconnect object in 'reconnect' chooses when to
      connect, disconnect or send an "echo" probe.
    - 'seqno' is incremented on every connection attempt and every
      disconnection, so callers can notice the connection changed."""

    def __init__(self, reconnect, rpc):
        self.reconnect = reconnect
        self.rpc = rpc
        self.stream = None
        self.pstream = None
        self.seqno = 0

    @staticmethod
    def open(name):
        """Creates and returns a Session that maintains a JSON-RPC session to
        'name', which should be a string acceptable to ovs.stream.Stream or
        ovs.stream.PassiveStream's initializer.

        If 'name' is an active connection method, e.g. "tcp:127.1.2.3", the new
        session connects and reconnects, with back-off, to 'name'.

        If 'name' is a passive connection method, e.g. "ptcp:", the new session
        listens for connections to 'name'.  It maintains at most one connection
        at any given time.  Any new connection causes the previous one (if any)
        to be dropped."""
        reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec())
        reconnect.set_name(name)
        reconnect.enable(ovs.timeval.msec())

        if ovs.stream.PassiveStream.is_valid_name(name):
            reconnect.set_passive(True, ovs.timeval.msec())

        return Session(reconnect, None)

    @staticmethod
    def open_unreliably(jsonrpc):
        """Returns a Session wrapping the already-established Connection
        'jsonrpc'.

        The Reconnect object is marked connected, quiet, and limited to 0
        tries (set_max_tries(0)), presumably so that the session does not
        reconnect once the connection is lost."""
        reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec())
        reconnect.set_quiet(True)
        reconnect.set_name(jsonrpc.name)
        reconnect.set_max_tries(0)
        reconnect.connected(ovs.timeval.msec())
        return Session(reconnect, jsonrpc)

    def close(self):
        """Closes the connection, any in-progress stream, and any listener."""
        if self.rpc is not None:
            self.rpc.close()
            self.rpc = None
        if self.stream is not None:
            self.stream.close()
            self.stream = None
        if self.pstream is not None:
            self.pstream.close()
            self.pstream = None

    def __disconnect(self):
        """Drops the current connection or in-progress connection attempt,
        if any, and bumps seqno.  The listener (pstream) is left open."""
        if self.rpc is not None:
            self.rpc.error(EOF)
            self.rpc.close()
            self.rpc = None
            self.seqno += 1
        elif self.stream is not None:
            self.stream.close()
            self.stream = None
            self.seqno += 1

    def __connect(self):
        """Starts a new connection attempt, or starts listening for a
        passive session, after dropping any current connection."""
        self.__disconnect()

        name = self.reconnect.get_name()
        if not self.reconnect.is_passive():
            error, self.stream = ovs.stream.Stream.open(name)
            if not error:
                self.reconnect.connecting(ovs.timeval.msec())
            else:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)
        elif self.pstream is None:
            # Open the listener only if one is not already open.  The old
            # test ("is not None") never opened a listener, and it would
            # have leaked an existing one by reopening it.
            error, self.pstream = ovs.stream.PassiveStream.open(name)
            if not error:
                self.reconnect.listening(ovs.timeval.msec())
            else:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)

        self.seqno += 1

    def run(self):
        """Does one round of non-blocking work:
        - accepts a new connection on the listener, if any;
        - runs the current connection and notices failures;
        - advances an in-progress connection attempt;
        - carries out the action returned by self.reconnect.run()."""
        if self.pstream is not None:
            error, stream = self.pstream.accept()
            if error == 0:
                if self.rpc is not None or self.stream is not None:
                    # XXX rate-limit
                    vlog.info("%s: new connection replacing active "
                              "connection" % self.reconnect.get_name())
                    self.__disconnect()
                self.reconnect.connected(ovs.timeval.msec())
                self.rpc = Connection(stream)
            elif error != errno.EAGAIN:
                self.reconnect.listen_error(ovs.timeval.msec(), error)
                self.pstream.close()
                self.pstream = None

        if self.rpc is not None:
            self.rpc.run()
            error = self.rpc.get_status()
            if error != 0:
                self.reconnect.disconnected(ovs.timeval.msec(), error)
                self.__disconnect()
        elif self.stream is not None:
            self.stream.run()
            error = self.stream.connect()
            if error == 0:
                self.reconnect.connected(ovs.timeval.msec())
                self.rpc = Connection(self.stream)
                self.stream = None
            elif error != errno.EAGAIN:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)
                self.stream.close()
                self.stream = None

        action = self.reconnect.run(ovs.timeval.msec())
        if action == ovs.reconnect.CONNECT:
            self.__connect()
        elif action == ovs.reconnect.DISCONNECT:
            self.reconnect.disconnected(ovs.timeval.msec(), 0)
            self.__disconnect()
        elif action == ovs.reconnect.PROBE:
            if self.rpc is not None:
                # Probe the peer with an "echo" request.  recv() suppresses
                # the reply by recognizing the fixed id "echo".
                request = Message.create_request("echo", [])
                request.id = "echo"
                self.rpc.send(request)
        else:
            assert action is None

    def wait(self, poller):
        """Registers with 'poller' the events that run() needs."""
        if self.rpc is not None:
            self.rpc.wait(poller)
        elif self.stream is not None:
            self.stream.run_wait(poller)
            self.stream.connect_wait(poller)
        if self.pstream is not None:
            self.pstream.wait(poller)
        self.reconnect.wait(poller, ovs.timeval.msec())

    def get_backlog(self):
        """Returns the connection's unsent output size, or 0 if there is no
        connection."""
        if self.rpc is not None:
            return self.rpc.get_backlog()
        return 0

    def get_name(self):
        return self.reconnect.get_name()

    def send(self, msg):
        """Sends 'msg' on the current connection.

        Returns the Connection.send() status, or errno.ENOTCONN if there is
        no connection."""
        if self.rpc is not None:
            return self.rpc.send(msg)
        return errno.ENOTCONN

    def recv(self):
        """Returns the next received Message for the caller, or None.

        Incoming "echo" requests are answered automatically.  Replies to our
        own "echo" probes are suppressed."""
        if self.rpc is not None:
            error, msg = self.rpc.recv()
            if not error:
                self.reconnect.received(ovs.timeval.msec())
                if msg.type == Message.T_REQUEST and msg.method == "echo":
                    # Echo request.  Send reply.
                    self.send(Message.create_reply(msg.params, msg.id))
                elif msg.type == Message.T_REPLY and msg.id == "echo":
                    # It's a reply to our echo request.  Suppress it.
                    pass
                else:
                    return msg
        return None

    def recv_wait(self, poller):
        if self.rpc is not None:
            self.rpc.recv_wait(poller)

    def is_alive(self):
        """Returns True if the session is connected, is connecting, or may
        still try again (max tries unlimited or not yet exhausted)."""
        if self.rpc is not None or self.stream is not None:
            return True
        max_tries = self.reconnect.get_max_tries()
        return max_tries is None or max_tries > 0

    def is_connected(self):
        return self.rpc is not None

    def get_seqno(self):
        return self.seqno

    def force_reconnect(self):
        self.reconnect.force_reconnect(ovs.timeval.msec())