|
@@ -858,16 +858,23 @@ netlink_unlock_table(void)
|
|
|
wake_up(&nl_table_wait);
|
|
|
}
|
|
|
|
|
|
+static bool netlink_compare(struct net *net, struct sock *sk)
|
|
|
+{
|
|
|
+ return net_eq(sock_net(sk), net);
|
|
|
+}
|
|
|
+
|
|
|
static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
|
|
|
{
|
|
|
- struct nl_portid_hash *hash = &nl_table[protocol].hash;
|
|
|
+ struct netlink_table *table = &nl_table[protocol];
|
|
|
+ struct nl_portid_hash *hash = &table->hash;
|
|
|
struct hlist_head *head;
|
|
|
struct sock *sk;
|
|
|
|
|
|
read_lock(&nl_table_lock);
|
|
|
head = nl_portid_hashfn(hash, portid);
|
|
|
sk_for_each(sk, head) {
|
|
|
- if (net_eq(sock_net(sk), net) && (nlk_sk(sk)->portid == portid)) {
|
|
|
+ if (table->compare(net, sk) &&
|
|
|
+ (nlk_sk(sk)->portid == portid)) {
|
|
|
sock_hold(sk);
|
|
|
goto found;
|
|
|
}
|
|
@@ -980,7 +987,8 @@ netlink_update_listeners(struct sock *sk)
|
|
|
|
|
|
static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
|
|
|
{
|
|
|
- struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash;
|
|
|
+ struct netlink_table *table = &nl_table[sk->sk_protocol];
|
|
|
+ struct nl_portid_hash *hash = &table->hash;
|
|
|
struct hlist_head *head;
|
|
|
int err = -EADDRINUSE;
|
|
|
struct sock *osk;
|
|
@@ -990,7 +998,8 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
|
|
|
head = nl_portid_hashfn(hash, portid);
|
|
|
len = 0;
|
|
|
sk_for_each(osk, head) {
|
|
|
- if (net_eq(sock_net(osk), net) && (nlk_sk(osk)->portid == portid))
|
|
|
+ if (table->compare(net, osk) &&
|
|
|
+ (nlk_sk(osk)->portid == portid))
|
|
|
break;
|
|
|
len++;
|
|
|
}
|
|
@@ -1165,6 +1174,7 @@ static int netlink_release(struct socket *sock)
|
|
|
kfree_rcu(old, rcu);
|
|
|
nl_table[sk->sk_protocol].module = NULL;
|
|
|
nl_table[sk->sk_protocol].bind = NULL;
|
|
|
+ nl_table[sk->sk_protocol].compare = NULL;
|
|
|
nl_table[sk->sk_protocol].flags = 0;
|
|
|
nl_table[sk->sk_protocol].registered = 0;
|
|
|
}
|
|
@@ -1187,7 +1197,8 @@ static int netlink_autobind(struct socket *sock)
|
|
|
{
|
|
|
struct sock *sk = sock->sk;
|
|
|
struct net *net = sock_net(sk);
|
|
|
- struct nl_portid_hash *hash = &nl_table[sk->sk_protocol].hash;
|
|
|
+ struct netlink_table *table = &nl_table[sk->sk_protocol];
|
|
|
+ struct nl_portid_hash *hash = &table->hash;
|
|
|
struct hlist_head *head;
|
|
|
struct sock *osk;
|
|
|
s32 portid = task_tgid_vnr(current);
|
|
@@ -1199,7 +1210,7 @@ retry:
|
|
|
netlink_table_grab();
|
|
|
head = nl_portid_hashfn(hash, portid);
|
|
|
sk_for_each(osk, head) {
|
|
|
- if (!net_eq(sock_net(osk), net))
|
|
|
+ if (!table->compare(net, osk))
|
|
|
continue;
|
|
|
if (nlk_sk(osk)->portid == portid) {
|
|
|
/* Bind collision, search negative portid values. */
|
|
@@ -2315,9 +2326,12 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,
|
|
|
rcu_assign_pointer(nl_table[unit].listeners, listeners);
|
|
|
nl_table[unit].cb_mutex = cb_mutex;
|
|
|
nl_table[unit].module = module;
|
|
|
+ nl_table[unit].compare = netlink_compare;
|
|
|
if (cfg) {
|
|
|
nl_table[unit].bind = cfg->bind;
|
|
|
nl_table[unit].flags = cfg->flags;
|
|
|
+ if (cfg->compare)
|
|
|
+ nl_table[unit].compare = cfg->compare;
|
|
|
}
|
|
|
nl_table[unit].registered = 1;
|
|
|
} else {
|
|
@@ -2740,6 +2754,7 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
|
|
|
{
|
|
|
struct sock *s;
|
|
|
struct nl_seq_iter *iter;
|
|
|
+ struct net *net;
|
|
|
int i, j;
|
|
|
|
|
|
++*pos;
|
|
@@ -2747,11 +2762,12 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
|
|
|
if (v == SEQ_START_TOKEN)
|
|
|
return netlink_seq_socket_idx(seq, 0);
|
|
|
|
|
|
+ net = seq_file_net(seq);
|
|
|
iter = seq->private;
|
|
|
s = v;
|
|
|
do {
|
|
|
s = sk_next(s);
|
|
|
- } while (s && sock_net(s) != seq_file_net(seq));
|
|
|
+ } while (s && !nl_table[s->sk_protocol].compare(net, s));
|
|
|
if (s)
|
|
|
return s;
|
|
|
|
|
@@ -2763,7 +2779,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
|
|
|
|
|
|
for (; j <= hash->mask; j++) {
|
|
|
s = sk_head(&hash->table[j]);
|
|
|
- while (s && sock_net(s) != seq_file_net(seq))
|
|
|
+
|
|
|
+ while (s && !nl_table[s->sk_protocol].compare(net, s))
|
|
|
s = sk_next(s);
|
|
|
if (s) {
|
|
|
iter->link = i;
|