/* rbtree_test.c (5.3 KB) */
  1. #include <linux/module.h>
  2. #include <linux/rbtree_augmented.h>
  3. #include <linux/random.h>
  4. #include <asm/timex.h>
  /* Test sizing: tree size, benchmark repetitions and validation passes. */
  enum {
	NODES		= 100,		/* nodes inserted into the tree per pass */
	PERF_LOOPS	= 100000,	/* timed insert-all/erase-all rounds */
	CHECK_LOOPS	= 100,		/* rounds run with invariant checking */
  };
  8. struct test_node {
  9. struct rb_node rb;
  10. u32 key;
  11. /* following fields used for testing augmented rbtree functionality */
  12. u32 val;
  13. u32 augmented;
  14. };
  15. static struct rb_root root = RB_ROOT;
  16. static struct test_node nodes[NODES];
  17. static struct rnd_state rnd;
  18. static void insert(struct test_node *node, struct rb_root *root)
  19. {
  20. struct rb_node **new = &root->rb_node, *parent = NULL;
  21. u32 key = node->key;
  22. while (*new) {
  23. parent = *new;
  24. if (key < rb_entry(parent, struct test_node, rb)->key)
  25. new = &parent->rb_left;
  26. else
  27. new = &parent->rb_right;
  28. }
  29. rb_link_node(&node->rb, parent, new);
  30. rb_insert_color(&node->rb, root);
  31. }
  32. static inline void erase(struct test_node *node, struct rb_root *root)
  33. {
  34. rb_erase(&node->rb, root);
  35. }
  36. static inline u32 augment_recompute(struct test_node *node)
  37. {
  38. u32 max = node->val, child_augmented;
  39. if (node->rb.rb_left) {
  40. child_augmented = rb_entry(node->rb.rb_left, struct test_node,
  41. rb)->augmented;
  42. if (max < child_augmented)
  43. max = child_augmented;
  44. }
  45. if (node->rb.rb_right) {
  46. child_augmented = rb_entry(node->rb.rb_right, struct test_node,
  47. rb)->augmented;
  48. if (max < child_augmented)
  49. max = child_augmented;
  50. }
  51. return max;
  52. }
  53. RB_DECLARE_CALLBACKS(static, augment_callbacks, struct test_node, rb,
  54. u32, augmented, augment_recompute)
  55. static void insert_augmented(struct test_node *node, struct rb_root *root)
  56. {
  57. struct rb_node **new = &root->rb_node, *rb_parent = NULL;
  58. u32 key = node->key;
  59. u32 val = node->val;
  60. struct test_node *parent;
  61. while (*new) {
  62. rb_parent = *new;
  63. parent = rb_entry(rb_parent, struct test_node, rb);
  64. if (parent->augmented < val)
  65. parent->augmented = val;
  66. if (key < parent->key)
  67. new = &parent->rb.rb_left;
  68. else
  69. new = &parent->rb.rb_right;
  70. }
  71. node->augmented = val;
  72. rb_link_node(&node->rb, rb_parent, new);
  73. rb_insert_augmented(&node->rb, root, &augment_callbacks);
  74. }
  75. static void erase_augmented(struct test_node *node, struct rb_root *root)
  76. {
  77. rb_erase_augmented(&node->rb, root, &augment_callbacks);
  78. }
  79. static void init(void)
  80. {
  81. int i;
  82. for (i = 0; i < NODES; i++) {
  83. nodes[i].key = prandom_u32_state(&rnd);
  84. nodes[i].val = prandom_u32_state(&rnd);
  85. }
  86. }
  87. static bool is_red(struct rb_node *rb)
  88. {
  89. return !(rb->__rb_parent_color & 1);
  90. }
  /*
   * black_path_count - count the black nodes from @rb up to the root,
   * including both ends. Returns 0 when @rb is NULL.
   */
  static int black_path_count(struct rb_node *rb)
  {
	int blacks = 0;

	while (rb) {
		if (!is_red(rb))
			blacks++;
		rb = rb_parent(rb);
	}
	return blacks;
  }
  98. static void check_postorder(int nr_nodes)
  99. {
  100. struct rb_node *rb;
  101. int count = 0;
  102. for (rb = rb_first_postorder(&root); rb; rb = rb_next_postorder(rb))
  103. count++;
  104. WARN_ON_ONCE(count != nr_nodes);
  105. }
  106. static void check(int nr_nodes)
  107. {
  108. struct rb_node *rb;
  109. int count = 0, blacks = 0;
  110. u32 prev_key = 0;
  111. for (rb = rb_first(&root); rb; rb = rb_next(rb)) {
  112. struct test_node *node = rb_entry(rb, struct test_node, rb);
  113. WARN_ON_ONCE(node->key < prev_key);
  114. WARN_ON_ONCE(is_red(rb) &&
  115. (!rb_parent(rb) || is_red(rb_parent(rb))));
  116. if (!count)
  117. blacks = black_path_count(rb);
  118. else
  119. WARN_ON_ONCE((!rb->rb_left || !rb->rb_right) &&
  120. blacks != black_path_count(rb));
  121. prev_key = node->key;
  122. count++;
  123. }
  124. WARN_ON_ONCE(count != nr_nodes);
  125. WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root))) - 1);
  126. check_postorder(nr_nodes);
  127. }
  128. static void check_augmented(int nr_nodes)
  129. {
  130. struct rb_node *rb;
  131. check(nr_nodes);
  132. for (rb = rb_first(&root); rb; rb = rb_next(rb)) {
  133. struct test_node *node = rb_entry(rb, struct test_node, rb);
  134. WARN_ON_ONCE(node->augmented != augment_recompute(node));
  135. }
  136. }
  137. static int __init rbtree_test_init(void)
  138. {
  139. int i, j;
  140. cycles_t time1, time2, time;
  141. printk(KERN_ALERT "rbtree testing");
  142. prandom_seed_state(&rnd, 3141592653589793238ULL);
  143. init();
  144. time1 = get_cycles();
  145. for (i = 0; i < PERF_LOOPS; i++) {
  146. for (j = 0; j < NODES; j++)
  147. insert(nodes + j, &root);
  148. for (j = 0; j < NODES; j++)
  149. erase(nodes + j, &root);
  150. }
  151. time2 = get_cycles();
  152. time = time2 - time1;
  153. time = div_u64(time, PERF_LOOPS);
  154. printk(" -> %llu cycles\n", (unsigned long long)time);
  155. for (i = 0; i < CHECK_LOOPS; i++) {
  156. init();
  157. for (j = 0; j < NODES; j++) {
  158. check(j);
  159. insert(nodes + j, &root);
  160. }
  161. for (j = 0; j < NODES; j++) {
  162. check(NODES - j);
  163. erase(nodes + j, &root);
  164. }
  165. check(0);
  166. }
  167. printk(KERN_ALERT "augmented rbtree testing");
  168. init();
  169. time1 = get_cycles();
  170. for (i = 0; i < PERF_LOOPS; i++) {
  171. for (j = 0; j < NODES; j++)
  172. insert_augmented(nodes + j, &root);
  173. for (j = 0; j < NODES; j++)
  174. erase_augmented(nodes + j, &root);
  175. }
  176. time2 = get_cycles();
  177. time = time2 - time1;
  178. time = div_u64(time, PERF_LOOPS);
  179. printk(" -> %llu cycles\n", (unsigned long long)time);
  180. for (i = 0; i < CHECK_LOOPS; i++) {
  181. init();
  182. for (j = 0; j < NODES; j++) {
  183. check_augmented(j);
  184. insert_augmented(nodes + j, &root);
  185. }
  186. for (j = 0; j < NODES; j++) {
  187. check_augmented(NODES - j);
  188. erase_augmented(nodes + j, &root);
  189. }
  190. check_augmented(0);
  191. }
  192. return -EAGAIN; /* Fail will directly unload the module */
  193. }
  194. static void __exit rbtree_test_exit(void)
  195. {
  196. printk(KERN_ALERT "test exit\n");
  197. }
  198. module_init(rbtree_test_init)
  199. module_exit(rbtree_test_exit)
  200. MODULE_LICENSE("GPL");
  201. MODULE_AUTHOR("Michel Lespinasse");
  202. MODULE_DESCRIPTION("Red Black Tree test");