datapath: Adopt Generic Netlink-compatible locking.
[sliver-openvswitch.git] / datapath / table.c
index 35a532e..47fa016 100644 (file)
@@ -10,6 +10,7 @@
 #include "datapath.h"
 #include "table.h"
 
+#include <linux/genetlink.h>
 #include <linux/gfp.h>
 #include <linux/slab.h>
 #include <linux/mm.h>
@@ -30,6 +31,17 @@ struct tbl_bucket {
        struct tbl_node *objs[];
 };
 
+static struct tbl_bucket *get_bucket(struct tbl_bucket __rcu *bucket)
+{
+       return rcu_dereference_check(bucket, rcu_read_lock_held() ||
+                                    lockdep_genl_is_held());
+}
+
+static struct tbl_bucket *get_bucket_protected(struct tbl_bucket __rcu *bucket)
+{
+       return rcu_dereference_protected(bucket, lockdep_genl_is_held());
+}
+
 static inline int bucket_size(int n_objs)
 {
        return sizeof(struct tbl_bucket) + sizeof(struct tbl_node *) * n_objs;
@@ -196,7 +208,7 @@ struct tbl_node *tbl_lookup(struct tbl *table, void *target, u32 hash,
                            int (*cmp)(const struct tbl_node *, void *))
 {
        struct tbl_bucket __rcu **bucketp = find_bucket(table, hash);
-       struct tbl_bucket *bucket = rcu_dereference(*bucketp);
+       struct tbl_bucket *bucket = get_bucket(*bucketp);
        int index;
 
        if (!bucket)
@@ -237,7 +249,7 @@ int tbl_foreach(struct tbl *table,
                        struct tbl_bucket *bucket;
                        unsigned int i;
 
-                       bucket = rcu_dereference(l2[l2_idx]);
+                       bucket = get_bucket(l2[l2_idx]);
                        if (!bucket)
                                continue;
 
@@ -282,7 +294,7 @@ struct tbl_node *tbl_next(struct tbl *table, u32 *bucketp, u32 *objp)
                for (l2_idx = s_l2_idx; l2_idx < TBL_L2_SIZE; l2_idx++) {
                        struct tbl_bucket *bucket;
 
-                       bucket = rcu_dereference(l2[l2_idx]);
+                       bucket = get_bucket_protected(l2[l2_idx]);
                        if (bucket && s_obj < bucket->n_objs) {
                                *bucketp = (l1_idx << TBL_L1_SHIFT) + l2_idx;
                                *objp = s_obj + 1;
@@ -371,7 +383,7 @@ static void free_bucket_rcu(struct rcu_head *rcu)
 int tbl_insert(struct tbl *table, struct tbl_node *target, u32 hash)
 {
        struct tbl_bucket __rcu **oldp = find_bucket(table, hash);
-       struct tbl_bucket *old = rcu_dereference(*oldp);
+       struct tbl_bucket *old = get_bucket_protected(*oldp);
        unsigned int n = old ? old->n_objs : 0;
        struct tbl_bucket *new = bucket_alloc(n + 1);
 
@@ -409,7 +421,7 @@ int tbl_insert(struct tbl *table, struct tbl_node *target, u32 hash)
 int tbl_remove(struct tbl *table, struct tbl_node *target)
 {
        struct tbl_bucket __rcu **oldp = find_bucket(table, target->hash);
-       struct tbl_bucket *old = rcu_dereference(*oldp);
+       struct tbl_bucket *old = get_bucket_protected(*oldp);
        unsigned int n = old->n_objs;
        struct tbl_bucket *new;