ec57ee9f4a510b7e2c5785c9cc9aa9d9a7adb618
[nepi.git] / src / nepi / util / plotter.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20 import logging
21 import os
22
23 try:
24     import networkx
25 except ImportError:
26     msg = "Networkx library is not installed, you will not be able to plot."
27     logger = logging.Logger("Plotter")
28     logger.debug(msg)
29
30 try:
31     import matplotlib.pyplot as plt
32 except ImportError:
33     msg = ("Matplotlib library is not installed, you will not be able "
34         "generate PNG plots.")
35     logger = logging.Logger("Plotter")
36     logger.debug(msg)
37
38 class PFormats:
39     DOT = "dot"
40     FIGURE = "figure"
41
42 class ECPlotter(object):
43     def plot(self, ec, dirpath = None, format= PFormats.FIGURE, 
44             show = False):
45         graph, labels = self._ec2graph(ec)
46
47         add_extension = False
48
49         if not dirpath:
50             import tempfile
51             dirpath = tempfile.mkdtemp()
52         
53         fpath = os.path.join(dirpath, "%s_%s" % (ec.exp_id, ec.run_id)) 
54
55         if format == PFormats.FIGURE:
56             pos = networkx.graphviz_layout(graph, prog="neato")
57             networkx.draw(graph, pos = pos, node_color="white", 
58                     node_size = 500, with_labels=True)
59            
60             label = "\n".join(map(lambda v: "%s: %s" % (v[0], v[1]), labels.iteritems()))
61             plt.annotate(label, xy=(0.05, 0.95), xycoords='axes fraction')
62            
63             fpath += ".png"
64
65             plt.savefig(fpath, bbox_inches="tight")
66             
67             if show:
68                 plt.show()
69
70         elif format == PFormats.DOT:
71             fpath += ".dot"
72
73             networkx.write_dot(graph, fpath)
74             
75             if show:
76                 import subprocess
77                 subprocess.call(["dot", "-Tps", fpath, "-o", "%s.ps" % fpath])
78                 subprocess.call(["evince","%s.ps" % fpath])
79         
80         return fpath
81
82     def _ec2graph(self, ec):
83         graph = networkx.Graph(graph = dict(overlap = "false"))
84
85         labels = dict()
86         connections = set()
87
88         for guid, rm in ec._resources.iteritems():
89             label = rm.get_rtype()
90
91             graph.add_node(guid,
92                 label = "%d %s" % (guid, label),
93                 width = 50/72.0, # 1 inch = 72 points
94                 height = 50/72.0, 
95                 shape = "circle")
96
97             labels[guid] = label
98
99             for guid2 in rm.connections:
100                 # Avoid adding a same connection twice
101                 if (guid2, guid) not in connections:
102                     connections.add((guid, guid2))
103
104         for (guid1, guid2) in connections:
105             graph.add_edge(guid1, guid2)
106
107         return graph, labels