Adding base RMs for ns-3
[nepi.git] / src / nepi / resources / ns3 / resource_manager_generator.py
index f5d4acc..4042bf8 100644 (file)
@@ -23,6 +23,51 @@ from nepi.resources.ns3.ns3wrapper import load_ns3_module
 import os
 import re
 
+def select_base_class(ns3, tid): 
+    rtype = tid.GetName()
+
+    type_id = ns3.TypeId()
+    appbase = type_id.LookupByName("ns3::Application")
+    devicebase = type_id.LookupByName("ns3::NetDevice")
+    channelbase = type_id.LookupByName("ns3::Channel")
+    queuebase = type_id.LookupByName("ns3::Queue")
+    lossbase = type_id.LookupByName("ns3::PropagationLossModel")
+    delaybase = type_id.LookupByName("ns3::PropagationDelayModel")
+    managerbase = type_id.LookupByName("ns3::WifiRemoteStationManager")
+
+    if tid.IsChiledOf(appbase):
+       base_class_import = "from nepi.resources.ns3.ns3application import NS3BaseApplication"
+       base_class = "NS3BaseApplication"
+    elif tid.IsChiledOf(devicebase):
+       base_class_import = "from nepi.resources.ns3.ns3device import NS3BaseNetDevice"
+       base_class = "NS3BaseNetDevice"
+    elif tid.IsChiledOf(channelbase):
+       base_class_import = "from nepi.resources.ns3.ns3channel import NS3BaseChannel"
+       base_class = "NS3BaseChannel"
+    elif tid.IsChiledOf(queuebase):
+       base_class_import = "from nepi.resources.ns3.ns3queue import NS3BaseQueue"
+       base_class = "NS3BaseQueue"
+    elif tid.IsChiledOf(lossbase):
+       base_class_import = "from nepi.resources.ns3.ns3loss import NS3BasePropagationLossModel"
+       base_class = "NS3BasePropagationLossDelay"
+    elif tid.IsChiledOf(delaybase):
+       base_class_import = "from nepi.resources.ns3.ns3delay import NS3BasePropagationDelayModel"
+       base_class = "NS3BasePropagationDelayModel"
+    elif tid.IsChiledOf(managerbase):
+       base_class_import = "from nepi.resources.ns3.ns3manager import NS3BaseWifiRemoteStationManager"
+       base_class = "NS3BaseWifiRemoteStationManager"
+    elif rtype == "ns3::Node":
+       base_class_import = "from nepi.resources.ns3.ns3node import NS3BaseNode"
+       base_class = "NS3BaseNode"
+    elif rtype == "ns3::Ipv4L3Protocol":
+       base_class_import = "from nepi.resources.ns3.ns3ipv4protocol import NS3BaseIpV4Protocol"
+       base_class = "NS3BaseIpV4Protocol"
+    else:
+       base_class_import = "from nepi.resources.ns3.ns3base import NS3Base"
+       base_class = "NS3Base"
+
+    return (base_class_import, base_class)
+
 def create_ns3_rms():
     ns3 = load_ns3_module()
 
@@ -51,12 +96,11 @@ def create_ns3_rms():
         attributes = "\n" + attributes if attributes else "pass"
         traces = "\n" + traces if traces else "pass"
 
+        (base_class_import, base_class) = select_base_class(ns3, tid)
+
         rtype = tid.GetName()
         category = tid.GetGroupName()
 
-        base_class_import = "from nepi.resources.ns3.ns3base import NS3Base"
-        base_clas = "NS3Base"
         classname = rtype.replace("ns3::", "NS3").replace("::","")
         uncamm_rtype = re.sub('([a-z])([A-Z])', r'\1-\2', rtype).lower()
         short_rtype = uncamm_rtype.replace("::","-")