浏览代码

af-packet: Hold reference to bound network devices.

Old code was probably safe, but with this change we
can actually use the netdev object, not just compare
the pointer values.

Signed-off-by: Ben Greear <greearb@candelatech.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Ben Greear 14 年之前
父节点
当前提交
160ff18a07
共有 1 个文件被更改,包括 9 次插入5 次删除
  1. 9 5
      net/packet/af_packet.c

+ 9 - 5
net/packet/af_packet.c

@@ -1342,6 +1342,10 @@ static int packet_release(struct socket *sock)
 		__dev_remove_pack(&po->prot_hook);
 		__dev_remove_pack(&po->prot_hook);
 		__sock_put(sk);
 		__sock_put(sk);
 	}
 	}
+	if (po->prot_hook.dev) {
+		dev_put(po->prot_hook.dev);
+		po->prot_hook.dev = NULL;
+	}
 	spin_unlock(&po->bind_lock);
 	spin_unlock(&po->bind_lock);
 
 
 	packet_flush_mclist(sk);
 	packet_flush_mclist(sk);
@@ -1395,6 +1399,8 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc
 
 
 	po->num = protocol;
 	po->num = protocol;
 	po->prot_hook.type = protocol;
 	po->prot_hook.type = protocol;
+	if (po->prot_hook.dev)
+		dev_put(po->prot_hook.dev);
 	po->prot_hook.dev = dev;
 	po->prot_hook.dev = dev;
 
 
 	po->ifindex = dev ? dev->ifindex : 0;
 	po->ifindex = dev ? dev->ifindex : 0;
@@ -1439,10 +1445,8 @@ static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr,
 	strlcpy(name, uaddr->sa_data, sizeof(name));
 	strlcpy(name, uaddr->sa_data, sizeof(name));
 
 
 	dev = dev_get_by_name(sock_net(sk), name);
 	dev = dev_get_by_name(sock_net(sk), name);
-	if (dev) {
+	if (dev)
 		err = packet_do_bind(sk, dev, pkt_sk(sk)->num);
 		err = packet_do_bind(sk, dev, pkt_sk(sk)->num);
-		dev_put(dev);
-	}
 	return err;
 	return err;
 }
 }
 
 
@@ -1470,8 +1474,6 @@ static int packet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len
 			goto out;
 			goto out;
 	}
 	}
 	err = packet_do_bind(sk, dev, sll->sll_protocol ? : pkt_sk(sk)->num);
 	err = packet_do_bind(sk, dev, sll->sll_protocol ? : pkt_sk(sk)->num);
-	if (dev)
-		dev_put(dev);
 
 
 out:
 out:
 	return err;
 	return err;
@@ -2240,6 +2242,8 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void
 				}
 				}
 				if (msg == NETDEV_UNREGISTER) {
 				if (msg == NETDEV_UNREGISTER) {
 					po->ifindex = -1;
 					po->ifindex = -1;
+					if (po->prot_hook.dev)
+						dev_put(po->prot_hook.dev);
 					po->prot_hook.dev = NULL;
 					po->prot_hook.dev = NULL;
 				}
 				}
 				spin_unlock(&po->bind_lock);
 				spin_unlock(&po->bind_lock);