Allow for static initialisation.
[libeze] / ez-tree.c
1 /* ez-tree.c Balanced (avl) tree.
2
3    Copyright (C) 1998-2002, 2004 Free Software Foundation, Inc.
4    Copyright (C) 2019 Michael Zucchi
5
6    This program is free software; you can redistribute it and/or
7    modify it under the terms of the GNU General Public License as
8    published by the Free Software Foundation; either version 3 of the
9    License, or (at your option) any later version.
10
11    This program is distributed in the hope that it will be useful, but
12    WITHOUT ANY WARRANTY; without even the implied warranty of
13    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
14    See the GNU General Public License for more details.
15
16    You should have received a copy of the GNU General Public License
17    along with this program; if not, write to the Free Software
18    Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
19    02111-1307, USA.
20
21 */
22
23 #include <stdlib.h>
24 #include <stdio.h>
25
26 #include "ez-tree.h"
27
28 /**
29  * This is an implementation of an AVL balanced binary tree optimised
30  * for code and memory size.  The source itself is also intended to be
31  * concise.
32  *
33  * Each tree node takes only 2 pointers of overhead.  The balance
34  * factor is stored in the lower 2 bits of the link pointers.  This
35  * makes for some messy programming but the overhead isn't high.
36  *
37  * The code-size is reduced by using parameterisation and other small
38  * tricks.
39  *
40  * As with the other ez- containers normal operation involves no
41  * resource allocation.  Operations cannot fail assuming valid
42  * program state.  Likewise there is almost no checking of any
43  * state; the caller is entirely responsible for correct operation.
44  *
45  * The insert function was coded from Algorithm A in section 6.2.2 in
46  * ``The Art of Computer Programming'' Volume 3, Donald E. Knuth.
47  *
48  * Most of the the delete function was taken from GNU libavl 2.0.  It
49  * was then parameterised and broken apart into reusuable components.
50  *
51  * The rest I just made up.
52  */
53
54 /* ********************************************************************** */
55
56 #define BALANCE0 0
57
58 /*
59   Balance is stored as a 2-bit signed integer in the 2 lsb bits of
60   link[0].
61
62   A constraint is that all nodes must be allocated to at least
63   4-byte boundaries, which is always the case with any known memory
64   allocator.
65
66   This is all a bit fiddly but will save at least 3 bytes per node on
67   32-bit systems and 7 on 64-bit systems.
68  */
69
70 static __inline__ int node_balance(struct ez_node *n) __attribute__((always_inline));
71 static __inline__ void node_set_balance(struct ez_node *node, int balance) __attribute__((always_inline));
72 static __inline__ void *node_link(struct ez_node *n, int i) __attribute__((always_inline));
73 static __inline__ void node_set_link(struct ez_node *n, int i, struct ez_node *link) __attribute__((always_inline));
74
75 static void *fail_depth(void) __attribute__((noinline));
76
77 static void *fail_depth(void) {
78         fprintf(stderr, "EZ_TREE_MAX_DEPTH=%d exceeded\n", EZ_TREE_MAX_DEPTH);
79         return NULL;
80 }
81
82 static __inline__ int node_balance(struct ez_node *n) {
83         int i = n->u.tree.link[0];
84         return i << 30 >> 30;
85 }
86
87 static __inline__ void node_set_balance(struct ez_node *node, int balance) {
88         node->u.tree.link[0] = (node->u.tree.link[0] & ~3) | (balance & 3);
89 }
90
91 static __inline__ void *node_link(struct ez_node *n, int i) {
92         return (void *)(n->u.tree.link[i] & ~3);
93 }
94
95 static __inline__ void node_set_link(struct ez_node *n, int i, struct ez_node *link) {
96         n->u.tree.link[i] = (n->u.tree.link[i] & 3) | (intptr_t)link;
97 }
98
99 /**
100  * Perform a single node rotation.
101  *
102  * See Knuth $6.2.3 expression (1) Case 1, where s is A, and r is B.
103  * The names s and r are also from Knuth $6.2.3 algorithm A.
104  *
105  * @param s is really A
106  * @param r is really B
107  * @param ai is the link direction one takes to go from A to B.  It must be 0 or 1.
108  * @param a a balance factor to set on the nodes.  This depends on
109  * other state, it must be -1, 0, or 1.
110  */
111 static struct ez_node *node_rotate_single(struct ez_node * __restrict s, struct ez_node * __restrict r, int ai, int a) {
112         node_set_link(s, ai, node_link(r, ai ^ 1));
113         node_set_link(r, ai ^ 1, s);
114
115         node_set_balance(s, -a);
116         node_set_balance(r, +a);
117
118         return r;
119 }
120
121 /**
122  * Perform a double node rotation.
123  *
124  * See Knuth $6.2.3 expression (1) Case 2, where s is A, and r is B.
125  * The names s and r are also from Knuth $6.2.3 algorithm A.
126  *
127  * @param s is really A
128  * @param r is really B
129  * @param ai is the link direction one takes to go from A to B.  It must be 0 or 1.
130  * @param a is the node comparision value matching ai.  It must be -1 or +1.
131  */
132 static struct ez_node *node_rotate_double(struct ez_node *s, struct ez_node *r, int ai, int a) {
133         struct ez_node *p = node_link(r, ai ^ 1);
134         int pb = node_balance(p);
135
136         node_set_link(r, ai ^ 1, node_link(p, ai));
137         node_set_link(p, ai, r);
138         node_set_link(s, ai, node_link(p, ai ^ 1));
139         node_set_link(p, ai ^ 1, s);
140
141         node_set_balance(s, pb == +a ? -a : 0);
142         node_set_balance(r, pb == -a ? +a : 0);
143         node_set_balance(p, 0);
144
145         return p;
146 }
147
148 /**
149  * Internal node remove function.
150  *
151  * This is much much nastier than it looks.  The rebalance code was
152  * taken from libavl and paramterised etc.
153  *
154  * The stack is a list of all nodes navigated to @param p including
155  * the node link directions taken at each level.  The stack includes
156  * the psuedo-root node at stack[0] and the node to remove at sp[-1].
157  *
158  * @param stack stack base.
159  * @param sp stack pointer to one beyond the end of the stack.
160  */
161 static void node_remove(struct ez_tree_scan_info *stack, struct ez_tree_scan_info *sp) {
162         struct ez_node *p = sp[-1].node;
163         struct ez_node *r = ez_node_right(p);
164         struct ez_tree_scan_info *si;
165
166         if (!r) {
167                 node_set_link(sp[-2].node, sp[-2].link, ez_node_left(p));
168                 sp -= 1;
169         } else {
170                 struct ez_node *s = ez_node_left(r);
171
172                 if (!s) {
173                         node_set_link(r, 0, ez_node_left(p));
174                         node_set_balance(r, node_balance(p));
175                         node_set_link(sp[-2].node, sp[-2].link, r);
176                         sp[-1] = (struct ez_tree_scan_info) { r, 1 };
177                 } else {
178                         struct ez_node *q;
179                         struct ez_tree_scan_info *psp = sp;
180
181                         *sp++ = (struct ez_tree_scan_info){ r, 0 };
182                         while ((q = ez_node_left(s))) {
183                                 *sp++ = (struct ez_tree_scan_info){ s, 0 };
184                                 r = s;
185                                 s = q;
186                         }
187
188                         node_set_link(s, 0, ez_node_left(p));
189                         node_set_link(r, 0, ez_node_right(s));
190                         node_set_link(s, 1, ez_node_right(p));
191                         node_set_balance(s, node_balance(p));
192
193                         node_set_link(psp[-2].node, psp[-2].link, s);
194                         psp[-1] = (struct ez_tree_scan_info){ s, 1 };
195                 }
196         }
197
198         // Rebalance tree
199         // basic code from libavl but paramterised as in Knuth
200         si = sp-1;
201         while (si > stack) {
202                 struct ez_node *y = si->node;
203                 int ai = si->link;
204                 int aj = ai ^ 1;
205                 int a = ai * 2 - 1;
206                 int yb = node_balance(y);
207
208                 if (a == yb)
209                         node_set_balance(y, 0);
210                 else if (yb == 0) {
211                         node_set_balance(y, -a);
212                         break;
213                 } else {
214                         struct ez_node *x = node_link(y, aj);
215                         int xb = node_balance(x);
216                         struct ez_node *w;
217
218                         if (xb == a)
219                                 w = node_rotate_double(y, x, aj, -a);
220                         else
221                                 w = node_rotate_single(y, x, ai ^ 1, xb == 0 ? a : 0);
222
223                         node_set_link(si[-1].node, si[-1].link, w);
224
225                         if (xb == 0)
226                                 break;
227                 }
228
229                 si -= 1;
230         }
231
232         stack[0].node->u.tree.link[0] -= 1;
233 }
234
235 /**
236  * Recursive tree free operation.
237  *
238  * Doesn't update any nodes.
239  */
240 static void tree_free(ez_tree *tree, struct ez_node *n, ez_free_fn node_free) {
241         while (n) {
242                 struct ez_node *l = ez_node_left(n);
243                 struct ez_node *r = ez_node_right(n);
244
245                 tree_free(tree, r, node_free);
246                 node_free(n);
247                 n = l;
248         }
249 }
250
251 /* ********************************************************************** */
252
253 void ez_tree_init(struct ez_tree *tree, ez_cmp_fn node_cmp) {
254         tree->root.u.tree.link[0] = 0;
255         tree->root.u.tree.link[1] = 0;
256         tree->node_cmp = node_cmp;
257 }
258
259 void ez_tree_clear(ez_tree *tree, ez_free_fn node_free) {
260         tree_free(tree, ez_tree_root(tree), node_free);
261         tree->root.u.tree.link[0] = 0;
262         tree->root.u.tree.link[1] = 0;
263 }
264
265 void *ez_tree_get(ez_tree *tree, const void *node) {
266         const struct ez_node *k = node;
267         struct ez_node *p = ez_tree_root(tree);
268         ez_cmp_fn node_cmp = tree->node_cmp;
269
270         while (p) {
271                 int cmp = node_cmp(k, p);
272
273                 if (cmp)
274                         p = node_link(p, cmp > 0);
275                 else
276                         break;
277         }
278
279         return p;
280 }
281
282 void *ez_tree_put(ez_tree *tree, void *node) {
283         struct ez_node *k = node;
284         struct ez_tree_scan_info stack[EZ_TREE_MAX_DEPTH];
285         struct ez_tree_scan_info *sp = stack, *se = sp + EZ_TREE_MAX_DEPTH - 1;
286         ez_cmp_fn node_cmp = tree->node_cmp;
287
288         k->u.tree.link[0] = BALANCE0; // balance=0
289         node_set_link(k, 1, 0);
290
291         if (!tree->root.u.tree.link[1]) {
292                 tree->root.u.tree.link[1] = (intptr_t)k;
293                 tree->root.u.tree.link[0] += 1;
294                 return NULL;
295         }
296
297         // Search for node, keeping track of last non-balanced node in si
298         struct ez_tree_scan_info *si = stack+1;
299         int cmp = 1;
300         struct ez_node *p, *n;
301
302         *sp++ = (struct ez_tree_scan_info){ &tree->root, 1 };
303         n = p = ez_node_right(&tree->root);
304         while (n && cmp != 0) {
305                 if (sp >= se)
306                         return fail_depth();
307                 p = n;
308                 cmp = node_cmp(k, p);
309                 *sp++ = (struct ez_tree_scan_info){ p, cmp > 0 };
310                 n = node_link(p, cmp > 0);
311                 if (n && node_balance(n))
312                         si = sp;
313         }
314
315         // Matched, replace and return old node
316         if (cmp == 0) {
317                 node_set_link(sp[-2].node, sp[-2].link, k);
318                 *k = *p;
319                 return p;
320         }
321
322         // Insert new node
323         node_set_link(p, cmp > 0, k);
324         tree->root.u.tree.link[0] += 1;
325         *sp = (struct ez_tree_scan_info){ k, 0 };
326
327         // Fix balance factors between s and q
328         for (struct ez_tree_scan_info *pi = si+1;pi != sp; pi+=1)
329                 node_set_balance(pi->node, pi->link * 2 - 1);
330
331         // Balance
332         struct ez_node *s = si[0].node;
333         struct ez_node *r = si[1].node;
334         int ai = si->link;
335         int a = ai * 2 - 1; // map 0,1 to -1,+1
336         int balance = node_balance(s);
337
338         if (balance == 0) {
339                 node_set_balance(s, a);
340         } else if (balance == -a) {
341                 node_set_balance(s, 0);
342         } else {
343                 struct ez_node *q;
344
345                 if (node_balance(r) == a)
346                         q = node_rotate_single(s, r, ai, 0);
347                 else
348                         q = node_rotate_double(s, r, ai, a);
349
350                 node_set_link(si[-1].node, si[-1].link, q);
351         }
352
353         return NULL;
354 }
355
356 void *ez_tree_remove(ez_tree *tree, const void *key) {
357         struct ez_node *p = ez_node_right(&tree->root);
358         struct ez_node *n = p;
359         struct ez_tree_scan_info stack[EZ_TREE_MAX_DEPTH];
360         struct ez_tree_scan_info *sp = stack, *se = sp + EZ_TREE_MAX_DEPTH;
361         ez_cmp_fn node_cmp = tree->node_cmp;
362         int cmp = 1;
363
364         *sp++ = (struct ez_tree_scan_info) { &tree->root, 1 };
365         while (n && cmp != 0) {
366                 if (sp >= se)
367                         return fail_depth();
368                 p = n;
369                 cmp = node_cmp(key, p);
370                 *sp++ = (struct ez_tree_scan_info) { p, cmp > 0 };
371                 n = node_link(p, cmp > 0);
372         }
373
374         if (cmp == 0)
375                 node_remove(stack, sp);
376
377         return p;
378 }
379
380 static void *ez_tree_scan_dir(ez_tree_scan *scan, enum ez_node_link_t dir) {
381         int l = scan->level;
382         int nir = dir ^ 1;
383         struct ez_node *p = scan->stack[l-1].node;
384         struct ez_node *r = node_link(p, dir);
385
386         if (r) {
387                 scan->stack[l-1].link = dir;
388                 do {
389                         if (l >= EZ_TREE_MAX_DEPTH)
390                                 return fail_depth();
391                         scan->stack[l++] = (struct ez_tree_scan_info) { r, nir };
392                         r = node_link(r, nir);
393                 } while (r);
394                 scan->level = l;
395                 return scan->stack[l-1].node;
396         } else {
397                 while (l > 2 && scan->stack[l-2].link == dir)
398                         l--;
399                 scan->level = l - 1;
400                 if (l > 2)
401                         return scan->stack[l-2].node;
402                 return NULL;
403         }
404 }
405
406 void *ez_tree_scan_init_key(ez_tree *tree, ez_tree_scan *scan, enum ez_node_link_t scan_dir, enum ez_node_link_t key_dir, void *key) {
407         struct ez_node *p = ez_node_right(&tree->root);
408         ez_cmp_fn node_cmp = tree->node_cmp;
409         int l = 0;
410         int cmp = 1;
411
412         scan->tree = tree;
413         scan->scan_dir = scan_dir;
414         scan->stack[l++] = (struct ez_tree_scan_info) { &tree->root, 1 };
415
416         if (key) {
417                 while (p && cmp != 0) {
418                         if (l >= EZ_TREE_MAX_DEPTH)
419                                 return fail_depth();
420                         cmp = node_cmp(key, p);
421                         scan->stack[l++] = (struct ez_tree_scan_info) { p, cmp > 0 };
422                         p = node_link(p, cmp > 0);
423                 }
424         } else {
425                 while (p) {
426                         if (l >= EZ_TREE_MAX_DEPTH)
427                                 return fail_depth();
428                         scan->stack[l++] = (struct ez_tree_scan_info) { p, 1 - scan_dir };
429                         p = ez_node_link(p, 1 - scan_dir);
430                 }
431         }
432         scan->level = l;
433
434         if (l > 1) {
435                 if (key) {
436                         int c = (cmp > 0) - (cmp < 0); // -1 0 1 from compare
437                         int a = key_dir * 2 - 1;       // -1 x 1 from direction
438
439                         if (c == a)
440                                 return ez_tree_scan_dir(scan, key_dir);
441                         // same as:
442                         //if ((cmp < 0 && dir == EZ_LINK_LEFT)
443                         //    || (cmp > 0 && dir == EZ_LINK_RIGHT))
444                         //      return ez_tree_scan_next(scan, dir);
445                 }
446
447                 return scan->stack[l-1].node;
448         } else
449                 return NULL;
450 }
451
452 void *ez_tree_scan_init(ez_tree *tree, ez_tree_scan *scan, enum ez_node_link_t scan_dir) {
453         return ez_tree_scan_init_key(tree, scan, scan_dir, 1 - scan_dir, NULL);
454 }
455
456 void *ez_tree_scan_next(ez_tree_scan *scan) {
457         return ez_tree_scan_dir(scan, scan->scan_dir);
458 }
459
460 void *ez_tree_scan_prev(ez_tree_scan *scan) {
461         return ez_tree_scan_dir(scan, 1 - scan->scan_dir);
462 }
463
464 void *ez_tree_scan_remove(ez_tree_scan *scan, enum ez_node_link_t dir) {
465         int l = scan->level;
466         struct ez_tree_scan_info *sp = scan->stack + l;
467
468         if (l > 1) {
469                 struct ez_node *last = sp[-1].node;
470
471                 node_remove(scan->stack, sp);
472
473                 // If node_remove could maintain the stack or indicate when it
474                 // is still valid this could be avoided, but its messy.
475
476                 if (dir != -1)
477                         return ez_tree_scan_init_key(scan->tree, scan, scan->scan_dir, dir, last);
478         }
479         return NULL;
480 }
481
482 void *ez_tree_scan_put(ez_tree_scan *scan, void *node) {
483         struct ez_node *n = node;
484         struct ez_tree_scan_info *sp = scan->stack + scan->level;
485         struct ez_node *o = sp[-1].node;
486
487         *n = *o;
488         sp[-1].node = n;
489         node_set_link(sp[-2].node, sp[-2].link, n);
490
491         return o;
492 }