#include "datapath.h"
#include "table.h"
+#include <linux/genetlink.h>
#include <linux/gfp.h>
#include <linux/slab.h>
#include <linux/mm.h>
struct tbl_node *objs[];
};
+static struct tbl_bucket *get_bucket(struct tbl_bucket __rcu *bucket)
+{
+ return rcu_dereference_check(bucket, rcu_read_lock_held() ||
+ lockdep_genl_is_held());
+}
+
+static struct tbl_bucket *get_bucket_protected(struct tbl_bucket __rcu *bucket)
+{
+ return rcu_dereference_protected(bucket, lockdep_genl_is_held());
+}
+
static inline int bucket_size(int n_objs)
{
return sizeof(struct tbl_bucket) + sizeof(struct tbl_node *) * n_objs;
return &table->buckets[l1][l2];
}
-static int search_bucket(const struct tbl_bucket *bucket, void *target, u32 hash,
- int (*cmp)(const struct tbl_node *, void *))
+static int search_bucket(const struct tbl_bucket *bucket, void *target, int len, u32 hash,
+ int (*cmp)(const struct tbl_node *, void *, int len))
{
int i;
for (i = 0; i < bucket->n_objs; i++) {
struct tbl_node *obj = bucket->objs[i];
- if (obj->hash == hash && likely(cmp(obj, target)))
+ if (obj->hash == hash && likely(cmp(obj, target, len)))
return i;
}
* @table: hash table to search
* @target: identifier for the object that is being searched for, will be
* provided as an argument to @cmp when making comparisions
+ * @len: length of @target in bytes, will be provided as an argument to @cmp
+ * when making comparisons
* @hash: hash of @target
* @cmp: comparision function to match objects with the given hash, returns
* nonzero if the objects match, zero otherwise
* Searches @table for an object identified by @target. Returns the tbl_node
* contained in the object if successful, otherwise %NULL.
*/
-struct tbl_node *tbl_lookup(struct tbl *table, void *target, u32 hash,
- int (*cmp)(const struct tbl_node *, void *))
+struct tbl_node *tbl_lookup(struct tbl *table, void *target, int len, u32 hash,
+ int (*cmp)(const struct tbl_node *, void *, int))
{
struct tbl_bucket __rcu **bucketp = find_bucket(table, hash);
- struct tbl_bucket *bucket = rcu_dereference(*bucketp);
+ struct tbl_bucket *bucket = get_bucket(*bucketp);
int index;
if (!bucket)
return NULL;
- index = search_bucket(bucket, target, hash, cmp);
+ index = search_bucket(bucket, target, len, hash, cmp);
if (index < 0)
return NULL;
struct tbl_bucket *bucket;
unsigned int i;
- bucket = rcu_dereference(l2[l2_idx]);
+ bucket = get_bucket(l2[l2_idx]);
if (!bucket)
continue;
for (l2_idx = s_l2_idx; l2_idx < TBL_L2_SIZE; l2_idx++) {
struct tbl_bucket *bucket;
- bucket = rcu_dereference(l2[l2_idx]);
+ bucket = get_bucket_protected(l2[l2_idx]);
if (bucket && s_obj < bucket->n_objs) {
*bucketp = (l1_idx << TBL_L1_SHIFT) + l2_idx;
*objp = s_obj + 1;
int n_buckets = table->n_buckets * 2;
struct tbl *new_table;
- if (n_buckets >= TBL_MAX_BUCKETS) {
+ if (n_buckets > TBL_MAX_BUCKETS) {
err = -ENOSPC;
goto error;
}
err = -ENOMEM;
- new_table = tbl_create(TBL_MIN_BUCKETS);
+ new_table = tbl_create(n_buckets);
if (!new_table)
goto error;
int tbl_insert(struct tbl *table, struct tbl_node *target, u32 hash)
{
struct tbl_bucket __rcu **oldp = find_bucket(table, hash);
- struct tbl_bucket *old = rcu_dereference(*oldp);
+ struct tbl_bucket *old = get_bucket_protected(*oldp);
unsigned int n = old ? old->n_objs : 0;
struct tbl_bucket *new = bucket_alloc(n + 1);
int tbl_remove(struct tbl *table, struct tbl_node *target)
{
struct tbl_bucket __rcu **oldp = find_bucket(table, target->hash);
- struct tbl_bucket *old = rcu_dereference(*oldp);
+ struct tbl_bucket *old = get_bucket_protected(*oldp);
unsigned int n = old->n_objs;
struct tbl_bucket *new;