AVL Tree

Written by Luka Kerr on August 13, 2018

The code for a self-balancing AVL tree, written while at UNSW.

#include <stdlib.h>
#include <stdio.h>
#include <assert.h>

/**
 * Handy macros for operating on an AVL tree
 *
 * Every argument and every expansion is fully parenthesized so the macros
 * stay correct when given compound expressions (e.g. max(a & b, c)).
 * left/right/key/val expand to lvalues, so they can be assigned to.
 * Note: max and height evaluate their arguments more than once, so do not
 * pass expressions with side effects.
 */
#define left(tree) ((tree)->left)
#define right(tree) ((tree)->right)
#define key(tree) ((tree)->key)
#define val(tree) ((tree)->val)
#define max(x, y) ((x) > (y) ? (x) : (y))
// an empty (NULL) subtree has height 0
#define height(tree) ((tree) ? (tree)->height : 0)

/**
 * @typedef {struct AVLNode *}  AVL: a pointer to the root node of an AVL
 *                              tree; NULL represents an empty tree
 */
typedef struct AVLNode *AVL;

/**
 * @typedef  {struct}      AVLNode: a node in an AVL tree
 * @property {int}         key    - unique key of node; the tree is ordered by
 *                                  it (smaller keys left, larger keys right)
 * @property {int}         val    - value stored alongside the key
 * @property {int}         height - height of the subtree rooted at this node
 *                                  (a node without children has height 1)
 * @property {AVLNode *}   left   - pointer to left child node, or NULL
 * @property {AVLNode *}   right  - pointer to right child node, or NULL
 */
typedef struct AVLNode {
  int key;
  int val;
  int height;
  struct AVLNode *left;
  struct AVLNode *right;
} AVLNode;

/*
 * Forward declarations of the file-local (static) helpers. They are defined
 * after the public functions, which call them, so they must be declared here.
 */
static AVL searchNodeAVL(AVL tree, int key);
static AVLNode *newAVLNode(int key, int val);
static void recalc(AVL tree);
static AVL rotateRight(AVL tree);
static AVL rotateLeft(AVL tree);
static AVL rebalance(AVL tree);
static int balance(AVL tree);
static AVLNode *minAVLNode(AVL node);

/**
 * Creates a new, empty AVL tree and returns it
 * An empty tree is represented by a NULL root pointer, so nothing is
 * allocated here; nodes are allocated as keys are inserted with insertAVL.
 * The parameter list is (void) so the compiler rejects calls that pass
 * arguments.
 * @return {AVL}   A new AVL tree (NULL initially)
 */
AVL newAVL(void) {
  return NULL;
}

/**
 * Inserts a node with the given key and value into the tree, keeping the
 * tree balanced. If the key is already present, no node is added and the
 * existing value is left untouched.
 * @param  {AVL}   tree  The tree to insert into (NULL for an empty tree)
 * @param  {int}   key   The key of the new node
 * @param  {int}   val   The value of the new node
 * @return {AVL}         The root of the tree after insertion; the root may
 *                       change, so the caller must use the returned pointer
 */
AVL insertAVL(AVL tree, int key, int val) {
  // an empty subtree is exactly where the new node belongs
  if (tree == NULL) {
    return newAVLNode(key, val);
  }

  if (key != key(tree)) {
    // pick the child link on the side the key belongs to and recurse into it
    AVL *link = (key < key(tree)) ? &left(tree) : &right(tree);
    *link = insertAVL(*link, key, val);
  }

  // heights below may have changed; fix this node on the way back up
  return rebalance(tree);
}

/**
 * Deletes the node with the given key from a tree, keeping the tree
 * balanced. If the key is not present the tree is returned unchanged.
 * @param  {AVL}   tree  The tree to operate on (NULL for an empty tree)
 * @param  {int}   key   The key of the node to delete
 * @return {AVL}         The root of the tree with the node removed (NULL if
 *                       the tree is now empty); the root may change, so the
 *                       caller must use the returned pointer
 */
AVL deleteAVL(AVL tree, int key) {
  if (tree == NULL) {
    return NULL;
  }

  if (key < key(tree)) {
    left(tree) = deleteAVL(left(tree), key);
  } else if (key > key(tree)) {
    right(tree) = deleteAVL(right(tree), key);
  } else if (left(tree) == NULL || right(tree) == NULL) {
    // node with <= 1 child: the child (possibly NULL) takes the node's place.
    // In an AVL tree an only child is always a leaf, so it needs no rebalance.
    AVLNode *child = left(tree) ? left(tree) : right(tree);
    free(tree);
    return child;
  } else {
    // node with 2 children: overwrite it with its in-order successor (the
    // smallest node of the right subtree), then delete that successor.
    // Both key AND value must be copied, otherwise the successor's value is
    // lost and its key ends up paired with the deleted node's value.
    AVLNode *succ = minAVLNode(right(tree));
    key(tree) = key(succ);
    val(tree) = val(succ);
    right(tree) = deleteAVL(right(tree), key(tree));
  }

  // a subtree may have shrunk: recompute the height and rotate if needed
  return rebalance(tree);
}

/**
 * Looks up the value stored under a key
 * @param  {AVL}   tree  The tree to search
 * @param  {int}   key   The key to use to search
 * @return {int}         Returns -1 if the node doesn't exist, or
 *                       returns the value of the node if it exists
 *                       (a stored value of -1 is therefore indistinguishable
 *                       from "not found")
 */
int searchAVL(AVL tree, int key) {
  AVL curr = tree;

  // walk down from the root, going left or right by key comparison
  while (curr != NULL) {
    if (key == key(curr)) {
      return val(curr);
    }
    curr = (key < key(curr)) ? left(curr) : right(curr);
  }

  return -1;
}

/**
 * Swaps the values stored under two keys. The nodes themselves stay where
 * they are, so the key ordering and the shape of the tree are unaffected.
 * Nothing happens unless both keys exist in the tree.
 * @param  {AVL}   tree  The tree to operate on (must not be NULL)
 * @param  {int}   k1    The first key
 * @param  {int}   k2    The second key
 * @return {void}
 */
void swapAVL(AVL tree, int k1, int k2) {
  assert(tree != NULL);

  AVL first = searchNodeAVL(tree, k1);
  AVL second = searchNodeAVL(tree, k2);

  // both keys must be present for the swap to take place
  if (first == NULL || second == NULL) {
    return;
  }

  int held = val(second);
  val(second) = val(first);
  val(first) = held;
}

/**
 * Releases every node of the given AVL tree
 * The pointer passed in is invalid afterwards.
 * @param  {AVL}   tree  The tree to free (NULL is allowed and does nothing)
 * @return {void}
 */
void freeAVL(AVL tree) {
  if (tree != NULL) {
    // both subtrees go first, the node itself last
    AVL lhs = left(tree);
    AVL rhs = right(tree);

    freeAVL(lhs);
    freeAVL(rhs);
    free(tree);
  }
}

/*
 * AVL helper/static functions
 */

/**
 * Finds the node holding a given key
 * @param  {AVL}   tree  The tree to search
 * @param  {int}   key   The key to use to search
 * @return {AVL}         Returns NULL if the node doesn't exist, or
 *                       returns the node if it exists
 */
static AVL searchNodeAVL(AVL tree, int key) {
  AVL curr = tree;

  // descend until the key matches or we fall off the bottom of the tree
  while (curr != NULL && key != key(curr)) {
    curr = (key < key(curr)) ? left(curr) : right(curr);
  }

  return curr;
}

/**
 * Creates a new node with a key and a value
 * The node has no children and height 1. If memory cannot be allocated the
 * program prints a message to stderr and aborts; this check is explicit
 * (rather than an assert) so it still happens when NDEBUG is defined.
 * @param  {int}   key   The key of the new node
 * @param  {int}   val   The value of the new node
 * @return {AVLNode *}   A pointer to a new node (never NULL); it is released
 *                       by deleteAVL or freeAVL
 */
static AVLNode *newAVLNode(int key, int val) {
  AVLNode *node = malloc(sizeof *node);
  if (node == NULL) {
    fprintf(stderr, "newAVLNode: out of memory\n");
    abort();
  }

  node->key = key;
  node->val = val;
  node->height = 1;
  node->left = NULL;
  node->right = NULL;

  return node;
}

/**
 * Recomputes the stored height of a node from the heights of its children
 * (one more than the taller child; a missing child counts as height 0)
 * @param  {AVL}   tree  The given tree (must not be NULL)
 * @return {void}
 */
static void recalc(AVL tree) {
  int lh = height(left(tree));
  int rh = height(right(tree));

  tree->height = 1 + (lh > rh ? lh : rh);
}

/**
 * Rotates the given AVL tree right: the left child becomes the new root and
 * the old root becomes its right child
 * @param  {AVL}   tree  The given tree (must have a left child)
 * @return {AVL}         The rotated tree (the former left child)
 */
static AVL rotateRight(AVL tree) {
  AVL pivot = left(tree);
  AVL moved = right(pivot);

  // the pivot's right subtree is re-homed as the old root's left subtree
  right(pivot) = tree;
  left(tree) = moved;

  // the old root now sits below the pivot, so its height is fixed first
  recalc(tree);
  recalc(pivot);

  return pivot;
}

/**
 * Rotates the given AVL tree left: the right child becomes the new root and
 * the old root becomes its left child
 * @param  {AVL}   tree  The given tree (must have a right child)
 * @return {AVL}         The rotated tree (the former right child)
 */
static AVL rotateLeft(AVL tree) {
  AVL pivot = right(tree);
  AVL moved = left(pivot);

  // the pivot's left subtree is re-homed as the old root's right subtree
  left(pivot) = tree;
  right(tree) = moved;

  // the old root now sits below the pivot, so its height is fixed first
  recalc(tree);
  recalc(pivot);

  return pivot;
}

/**
 * Rebalances the given AVL tree
 * Refreshes the root's height, then, if one side is exactly two levels
 * taller than the other, restores balance with a single or double rotation.
 * @param  {AVL}   tree  The given tree (must not be NULL)
 * @return {AVL}         The balanced tree (its root may differ from the input)
 */
static AVL rebalance(AVL tree) {
  recalc(tree);

  // positive: left side is taller; negative: right side is taller
  int skew = height(left(tree)) - height(right(tree));

  if (skew == 2) {
    AVL child = left(tree);
    // left-right case: straighten the child first so one rotation suffices
    if (height(right(child)) > height(left(child))) {
      left(tree) = rotateLeft(child);
    }
    return rotateRight(tree);
  }

  if (skew == -2) {
    AVL child = right(tree);
    // right-left case: straighten the child first so one rotation suffices
    if (height(left(child)) > height(right(child))) {
      right(tree) = rotateRight(child);
    }
    return rotateLeft(tree);
  }

  return tree;
}

/**
 * Returns the balance of the given AVL tree
 * @param  {AVL}   tree  The given tree (NULL is allowed)
 * @return {int}         The height of the root's left subtree minus the
 *                       height of its right subtree; 0 for an empty tree
 */
static int balance(AVL tree) {
  return tree ? height(left(tree)) - height(right(tree)) : 0;
}

/**
 * Returns the minimum node in the given AVL tree
 * @param  {AVL}   node  The root of the (sub)tree to search; must not be
 *                       NULL, since it is dereferenced unconditionally
 * @return {AVLNode *}   The minimum node in the tree (left most node)
 */
static AVLNode *minAVLNode(AVL node) {
  AVLNode *leftmost = node;

  // keep following left links; the last node reached holds the smallest key
  for (AVLNode *next = left(node); next != NULL; next = left(next)) {
    leftmost = next;
  }

  return leftmost;
}