jsonrpc: Fix Python implementation of inactivity logic.
[sliver-openvswitch.git] / python / ovs / jsonrpc.py
1 # Copyright (c) 2010, 2011, 2012 Nicira, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import errno
16 import os
17
18 import ovs.json
19 import ovs.poller
20 import ovs.reconnect
21 import ovs.stream
22 import ovs.timeval
23 import ovs.util
24 import ovs.vlog
25
26 EOF = ovs.util.EOF
27 vlog = ovs.vlog.Vlog("jsonrpc")
28
29
30 class Message(object):
31     T_REQUEST = 0               # Request.
32     T_NOTIFY = 1                # Notification.
33     T_REPLY = 2                 # Successful reply.
34     T_ERROR = 3                 # Error reply.
35
36     __types = {T_REQUEST: "request",
37                T_NOTIFY: "notification",
38                T_REPLY: "reply",
39                T_ERROR: "error"}
40
41     def __init__(self, type_, method, params, result, error, id):
42         self.type = type_
43         self.method = method
44         self.params = params
45         self.result = result
46         self.error = error
47         self.id = id
48
49     _next_id = 0
50
51     @staticmethod
52     def _create_id():
53         this_id = Message._next_id
54         Message._next_id += 1
55         return this_id
56
57     @staticmethod
58     def create_request(method, params):
59         return Message(Message.T_REQUEST, method, params, None, None,
60                        Message._create_id())
61
62     @staticmethod
63     def create_notify(method, params):
64         return Message(Message.T_NOTIFY, method, params, None, None,
65                        None)
66
67     @staticmethod
68     def create_reply(result, id):
69         return Message(Message.T_REPLY, None, None, result, None, id)
70
71     @staticmethod
72     def create_error(error, id):
73         return Message(Message.T_ERROR, None, None, None, error, id)
74
75     @staticmethod
76     def type_to_string(type_):
77         return Message.__types[type_]
78
79     def __validate_arg(self, value, name, must_have):
80         if (value is not None) == (must_have != 0):
81             return None
82         else:
83             type_name = Message.type_to_string(self.type)
84             if must_have:
85                 verb = "must"
86             else:
87                 verb = "must not"
88             return "%s %s have \"%s\"" % (type_name, verb, name)
89
90     def is_valid(self):
91         if self.params is not None and type(self.params) != list:
92             return "\"params\" must be JSON array"
93
94         pattern = {Message.T_REQUEST: 0x11001,
95                    Message.T_NOTIFY:  0x11000,
96                    Message.T_REPLY:   0x00101,
97                    Message.T_ERROR:   0x00011}.get(self.type)
98         if pattern is None:
99             return "invalid JSON-RPC message type %s" % self.type
100
101         return (
102             self.__validate_arg(self.method, "method", pattern & 0x10000) or
103             self.__validate_arg(self.params, "params", pattern & 0x1000) or
104             self.__validate_arg(self.result, "result", pattern & 0x100) or
105             self.__validate_arg(self.error, "error", pattern & 0x10) or
106             self.__validate_arg(self.id, "id", pattern & 0x1))
107
108     @staticmethod
109     def from_json(json):
110         if type(json) != dict:
111             return "message is not a JSON object"
112
113         # Make a copy to avoid modifying the caller's dict.
114         json = dict(json)
115
116         if "method" in json:
117             method = json.pop("method")
118             if type(method) not in [str, unicode]:
119                 return "method is not a JSON string"
120         else:
121             method = None
122
123         params = json.pop("params", None)
124         result = json.pop("result", None)
125         error = json.pop("error", None)
126         id_ = json.pop("id", None)
127         if len(json):
128             return "message has unexpected member \"%s\"" % json.popitem()[0]
129
130         if result is not None:
131             msg_type = Message.T_REPLY
132         elif error is not None:
133             msg_type = Message.T_ERROR
134         elif id_ is not None:
135             msg_type = Message.T_REQUEST
136         else:
137             msg_type = Message.T_NOTIFY
138
139         msg = Message(msg_type, method, params, result, error, id_)
140         validation_error = msg.is_valid()
141         if validation_error is not None:
142             return validation_error
143         else:
144             return msg
145
146     def to_json(self):
147         json = {}
148
149         if self.method is not None:
150             json["method"] = self.method
151
152         if self.params is not None:
153             json["params"] = self.params
154
155         if self.result is not None or self.type == Message.T_ERROR:
156             json["result"] = self.result
157
158         if self.error is not None or self.type == Message.T_REPLY:
159             json["error"] = self.error
160
161         if self.id is not None or self.type == Message.T_NOTIFY:
162             json["id"] = self.id
163
164         return json
165
166     def __str__(self):
167         s = [Message.type_to_string(self.type)]
168         if self.method is not None:
169             s.append("method=\"%s\"" % self.method)
170         if self.params is not None:
171             s.append("params=" + ovs.json.to_string(self.params))
172         if self.result is not None:
173             s.append("result=" + ovs.json.to_string(self.result))
174         if self.error is not None:
175             s.append("error=" + ovs.json.to_string(self.error))
176         if self.id is not None:
177             s.append("id=" + ovs.json.to_string(self.id))
178         return ", ".join(s)
179
180
181 class Connection(object):
182     def __init__(self, stream):
183         self.name = stream.name
184         self.stream = stream
185         self.status = 0
186         self.input = ""
187         self.output = ""
188         self.parser = None
189         self.received_bytes = 0
190
191     def close(self):
192         self.stream.close()
193         self.stream = None
194
195     def run(self):
196         if self.status:
197             return
198
199         while len(self.output):
200             retval = self.stream.send(self.output)
201             if retval >= 0:
202                 self.output = self.output[retval:]
203             else:
204                 if retval != -errno.EAGAIN:
205                     vlog.warn("%s: send error: %s" %
206                               (self.name, os.strerror(-retval)))
207                     self.error(-retval)
208                 break
209
210     def wait(self, poller):
211         if not self.status:
212             self.stream.run_wait(poller)
213             if len(self.output):
214                 self.stream.send_wait(poller)
215
216     def get_status(self):
217         return self.status
218
219     def get_backlog(self):
220         if self.status != 0:
221             return 0
222         else:
223             return len(self.output)
224
225     def get_received_bytes(self):
226         return self.received_bytes
227
228     def __log_msg(self, title, msg):
229         vlog.dbg("%s: %s %s" % (self.name, title, msg))
230
231     def send(self, msg):
232         if self.status:
233             return self.status
234
235         self.__log_msg("send", msg)
236
237         was_empty = len(self.output) == 0
238         self.output += ovs.json.to_string(msg.to_json())
239         if was_empty:
240             self.run()
241         return self.status
242
243     def send_block(self, msg):
244         error = self.send(msg)
245         if error:
246             return error
247
248         while True:
249             self.run()
250             if not self.get_backlog() or self.get_status():
251                 return self.status
252
253             poller = ovs.poller.Poller()
254             self.wait(poller)
255             poller.block()
256
257     def recv(self):
258         if self.status:
259             return self.status, None
260
261         while True:
262             if not self.input:
263                 error, data = self.stream.recv(4096)
264                 if error:
265                     if error == errno.EAGAIN:
266                         return error, None
267                     else:
268                         # XXX rate-limit
269                         vlog.warn("%s: receive error: %s"
270                                   % (self.name, os.strerror(error)))
271                         self.error(error)
272                         return self.status, None
273                 elif not data:
274                     self.error(EOF)
275                     return EOF, None
276                 else:
277                     self.input += data
278                     self.received_bytes += len(data)
279             else:
280                 if self.parser is None:
281                     self.parser = ovs.json.Parser()
282                 self.input = self.input[self.parser.feed(self.input):]
283                 if self.parser.is_done():
284                     msg = self.__process_msg()
285                     if msg:
286                         return 0, msg
287                     else:
288                         return self.status, None
289
290     def recv_block(self):
291         while True:
292             error, msg = self.recv()
293             if error != errno.EAGAIN:
294                 return error, msg
295
296             self.run()
297
298             poller = ovs.poller.Poller()
299             self.wait(poller)
300             self.recv_wait(poller)
301             poller.block()
302
303     def transact_block(self, request):
304         id_ = request.id
305
306         error = self.send(request)
307         reply = None
308         while not error:
309             error, reply = self.recv_block()
310             if (reply
311                 and (reply.type == Message.T_REPLY
312                      or reply.type == Message.T_ERROR)
313                 and reply.id == id_):
314                 break
315         return error, reply
316
317     def __process_msg(self):
318         json = self.parser.finish()
319         self.parser = None
320         if type(json) in [str, unicode]:
321             # XXX rate-limit
322             vlog.warn("%s: error parsing stream: %s" % (self.name, json))
323             self.error(errno.EPROTO)
324             return
325
326         msg = Message.from_json(json)
327         if not isinstance(msg, Message):
328             # XXX rate-limit
329             vlog.warn("%s: received bad JSON-RPC message: %s"
330                       % (self.name, msg))
331             self.error(errno.EPROTO)
332             return
333
334         self.__log_msg("received", msg)
335         return msg
336
337     def recv_wait(self, poller):
338         if self.status or self.input:
339             poller.immediate_wake()
340         else:
341             self.stream.recv_wait(poller)
342
343     def error(self, error):
344         if self.status == 0:
345             self.status = error
346             self.stream.close()
347             self.output = ""
348
349
350 class Session(object):
351     """A JSON-RPC session with reconnection."""
352
353     def __init__(self, reconnect, rpc):
354         self.reconnect = reconnect
355         self.rpc = rpc
356         self.stream = None
357         self.pstream = None
358         self.seqno = 0
359
360     @staticmethod
361     def open(name):
362         """Creates and returns a Session that maintains a JSON-RPC session to
363         'name', which should be a string acceptable to ovs.stream.Stream or
364         ovs.stream.PassiveStream's initializer.
365
366         If 'name' is an active connection method, e.g. "tcp:127.1.2.3", the new
367         session connects and reconnects, with back-off, to 'name'.
368
369         If 'name' is a passive connection method, e.g. "ptcp:", the new session
370         listens for connections to 'name'.  It maintains at most one connection
371         at any given time.  Any new connection causes the previous one (if any)
372         to be dropped."""
373         reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec())
374         reconnect.set_name(name)
375         reconnect.enable(ovs.timeval.msec())
376
377         if ovs.stream.PassiveStream.is_valid_name(name):
378             reconnect.set_passive(True, ovs.timeval.msec())
379
380         if ovs.stream.stream_or_pstream_needs_probes(name):
381             reconnect.set_probe_interval(0)
382
383         return Session(reconnect, None)
384
385     @staticmethod
386     def open_unreliably(jsonrpc):
387         reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec())
388         reconnect.set_quiet(True)
389         reconnect.set_name(jsonrpc.name)
390         reconnect.set_max_tries(0)
391         reconnect.connected(ovs.timeval.msec())
392         return Session(reconnect, jsonrpc)
393
394     def close(self):
395         if self.rpc is not None:
396             self.rpc.close()
397             self.rpc = None
398         if self.stream is not None:
399             self.stream.close()
400             self.stream = None
401         if self.pstream is not None:
402             self.pstream.close()
403             self.pstream = None
404
405     def __disconnect(self):
406         if self.rpc is not None:
407             self.rpc.error(EOF)
408             self.rpc.close()
409             self.rpc = None
410             self.seqno += 1
411         elif self.stream is not None:
412             self.stream.close()
413             self.stream = None
414             self.seqno += 1
415
416     def __connect(self):
417         self.__disconnect()
418
419         name = self.reconnect.get_name()
420         if not self.reconnect.is_passive():
421             error, self.stream = ovs.stream.Stream.open(name)
422             if not error:
423                 self.reconnect.connecting(ovs.timeval.msec())
424             else:
425                 self.reconnect.connect_failed(ovs.timeval.msec(), error)
426         elif self.pstream is not None:
427             error, self.pstream = ovs.stream.PassiveStream.open(name)
428             if not error:
429                 self.reconnect.listening(ovs.timeval.msec())
430             else:
431                 self.reconnect.connect_failed(ovs.timeval.msec(), error)
432
433         self.seqno += 1
434
435     def run(self):
436         if self.pstream is not None:
437             error, stream = self.pstream.accept()
438             if error == 0:
439                 if self.rpc or self.stream:
440                     # XXX rate-limit
441                     vlog.info("%s: new connection replacing active "
442                               "connection" % self.reconnect.get_name())
443                     self.__disconnect()
444                 self.reconnect.connected(ovs.timeval.msec())
445                 self.rpc = Connection(stream)
446             elif error != errno.EAGAIN:
447                 self.reconnect.listen_error(ovs.timeval.msec(), error)
448                 self.pstream.close()
449                 self.pstream = None
450
451         if self.rpc:
452             backlog = self.rpc.get_backlog()
453             self.rpc.run()
454             if self.rpc.get_backlog() < backlog:
455                 # Data previously caught in a queue was successfully sent (or
456                 # there's an error, which we'll catch below).
457                 #
458                 # We don't count data that is successfully sent immediately as
459                 # activity, because there's a lot of queuing downstream from
460                 # us, which means that we can push a lot of data into a
461                 # connection that has stalled and won't ever recover.
462                 self.reconnect.activity(ovs.timeval.msec())
463
464             error = self.rpc.get_status()
465             if error != 0:
466                 self.reconnect.disconnected(ovs.timeval.msec(), error)
467                 self.__disconnect()
468         elif self.stream is not None:
469             self.stream.run()
470             error = self.stream.connect()
471             if error == 0:
472                 self.reconnect.connected(ovs.timeval.msec())
473                 self.rpc = Connection(self.stream)
474                 self.stream = None
475             elif error != errno.EAGAIN:
476                 self.reconnect.connect_failed(ovs.timeval.msec(), error)
477                 self.stream.close()
478                 self.stream = None
479
480         action = self.reconnect.run(ovs.timeval.msec())
481         if action == ovs.reconnect.CONNECT:
482             self.__connect()
483         elif action == ovs.reconnect.DISCONNECT:
484             self.reconnect.disconnected(ovs.timeval.msec(), 0)
485             self.__disconnect()
486         elif action == ovs.reconnect.PROBE:
487             if self.rpc:
488                 request = Message.create_request("echo", [])
489                 request.id = "echo"
490                 self.rpc.send(request)
491         else:
492             assert action == None
493
494     def wait(self, poller):
495         if self.rpc is not None:
496             self.rpc.wait(poller)
497         elif self.stream is not None:
498             self.stream.run_wait(poller)
499             self.stream.connect_wait(poller)
500         if self.pstream is not None:
501             self.pstream.wait(poller)
502         self.reconnect.wait(poller, ovs.timeval.msec())
503
504     def get_backlog(self):
505         if self.rpc is not None:
506             return self.rpc.get_backlog()
507         else:
508             return 0
509
510     def get_name(self):
511         return self.reconnect.get_name()
512
513     def send(self, msg):
514         if self.rpc is not None:
515             return self.rpc.send(msg)
516         else:
517             return errno.ENOTCONN
518
519     def recv(self):
520         if self.rpc is not None:
521             received_bytes = self.rpc.get_received_bytes()
522             error, msg = self.rpc.recv()
523             if received_bytes != self.rpc.get_received_bytes():
524                 # Data was successfully received.
525                 #
526                 # Previously we only counted receiving a full message as
527                 # activity, but with large messages or a slow connection that
528                 # policy could time out the session mid-message.
529                 self.reconnect.activity(ovs.timeval.msec())
530
531             if not error:
532                 if msg.type == Message.T_REQUEST and msg.method == "echo":
533                     # Echo request.  Send reply.
534                     self.send(Message.create_reply(msg.params, msg.id))
535                 elif msg.type == Message.T_REPLY and msg.id == "echo":
536                     # It's a reply to our echo request.  Suppress it.
537                     pass
538                 else:
539                     return msg
540         return None
541
542     def recv_wait(self, poller):
543         if self.rpc is not None:
544             self.rpc.recv_wait(poller)
545
546     def is_alive(self):
547         if self.rpc is not None or self.stream is not None:
548             return True
549         else:
550             max_tries = self.reconnect.get_max_tries()
551             return max_tries is None or max_tries > 0
552
553     def is_connected(self):
554         return self.rpc is not None
555
556     def get_seqno(self):
557         return self.seqno
558
559     def force_reconnect(self):
560         self.reconnect.force_reconnect(ovs.timeval.msec())