task_work.c 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #include <linux/spinlock.h>
  2. #include <linux/task_work.h>
  3. #include <linux/tracehook.h>
  4. int
  5. task_work_add(struct task_struct *task, struct callback_head *twork, bool notify)
  6. {
  7. struct callback_head *last, *first;
  8. unsigned long flags;
  9. /*
  10. * Not inserting the new work if the task has already passed
  11. * exit_task_work() is the responisbility of callers.
  12. */
  13. raw_spin_lock_irqsave(&task->pi_lock, flags);
  14. last = task->task_works;
  15. first = last ? last->next : twork;
  16. twork->next = first;
  17. if (last)
  18. last->next = twork;
  19. task->task_works = twork;
  20. raw_spin_unlock_irqrestore(&task->pi_lock, flags);
  21. /* test_and_set_bit() implies mb(), see tracehook_notify_resume(). */
  22. if (notify)
  23. set_notify_resume(task);
  24. return 0;
  25. }
  26. struct callback_head *
  27. task_work_cancel(struct task_struct *task, task_work_func_t func)
  28. {
  29. unsigned long flags;
  30. struct callback_head *last, *res = NULL;
  31. raw_spin_lock_irqsave(&task->pi_lock, flags);
  32. last = task->task_works;
  33. if (last) {
  34. struct callback_head *q = last, *p = q->next;
  35. while (1) {
  36. if (p->func == func) {
  37. q->next = p->next;
  38. if (p == last)
  39. task->task_works = q == p ? NULL : q;
  40. res = p;
  41. break;
  42. }
  43. if (p == last)
  44. break;
  45. q = p;
  46. p = q->next;
  47. }
  48. }
  49. raw_spin_unlock_irqrestore(&task->pi_lock, flags);
  50. return res;
  51. }
  52. void task_work_run(void)
  53. {
  54. struct task_struct *task = current;
  55. struct callback_head *p, *q;
  56. while (1) {
  57. raw_spin_lock_irq(&task->pi_lock);
  58. p = task->task_works;
  59. task->task_works = NULL;
  60. raw_spin_unlock_irq(&task->pi_lock);
  61. if (unlikely(!p))
  62. return;
  63. q = p->next; /* head */
  64. p->next = NULL; /* cut it */
  65. while (q) {
  66. p = q->next;
  67. q->func(q);
  68. q = p;
  69. }
  70. }
  71. }