// ---------------------------------------------------------------------------
// BST.h
//
// A BST class that you can reuse by decorating.  It does insertion,
// lookup, rotations, and splays.  There is no remove operation,
// but you can call clear() to empty it out.
//
// The best way to use this class to make splay trees, AVL trees,
// red-black trees and so on is by delegation, though in some cases
// you may be able to get away with subclassing.  But as you know,
// delegation (composition) is generally preferable to subclassing.
// ---------------------------------------------------------------------------

#ifndef _BST_H
#define _BST_H

#include <stdexcept>
using namespace std;

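// A BST object contains only a pointer to its root node and its
// (cached) size.  It is assumed that whatever type is being
// contained in the tree has acceptable "<" and "==" operations that
// agree with each other: for any two items a and b, exactly one of
// a == b, a < b, or b < a should hold.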

template<class T>
class BST {

// Inner class for binary tree nodes.  Each node owns its data and its two
// subtrees, and keeps back-pointers to its parent and to the tree it belongs
// to (the tree pointer is what lets a rotation at the top update the root).
private:
  class Node {
  public:
    T data;
    Node* left;
    Node* right;
    Node* parent;
    BST* tree;

    // Construct a new (leaf) node holding a copy of data, whose parent is
    // `parent` (0 for the root) within `tree`.  The parent's child pointer
    // is NOT set here; the caller links the node in.
    Node(const T& data, Node* parent, BST* tree):
      data(data), left(0), right(0), parent(parent), tree(tree) {
    }

    // Destroy a node and its subtrees.  Deletion recurses once per level,
    // so stack depth grows with the height of the subtree.
    ~Node() {
      delete left;
      delete right;
    }

    // Return (a copy of) the data item stored in this node.
    T getData() const {
      return data;
    }

    // Replace the contents of this node with new data, returning
    // what used to be there.
    T setData(const T& data) {
      T oldData = this->data;
      this->data = data;
      return oldData;
    }

    // Link helpers: each one updates both directions of a link (tree->root
    // and parent, or child pointer and the child's parent) so the structure
    // stays consistent.
    void makeRoot() {tree->root = this; parent = 0;}
    void setLeft(Node* n) {left = n; if (n != 0) n->parent = this;}
    void setRight(Node* n) {right = n; if (n != 0) n->parent = this;}

    // Rotations.  Note I actually check whether the rotation can legally be
    // performed.  It's slower, I guess, but it's good bulletproofing.  In
    // production code, the check might be eliminated since the BST would be
    // totally buried.

    // Rotates this node down to the left: its right child takes its place
    // (becoming the root if this node was the root).  No-op without a right
    // child.
    void rotateLeft() {
      if (right == 0) return;
      Node* oldRight = right;
      setRight(oldRight->left);
      if (parent == 0) oldRight->makeRoot();
      else if (parent->left == this) parent->setLeft(oldRight);
      else parent->setRight(oldRight);
      oldRight->setLeft(this);
    }

    // Mirror image of rotateLeft: the left child takes this node's place.
    // No-op without a left child.
    void rotateRight() {
      if (left == 0) return;
      Node* oldLeft = left;
      setLeft(oldLeft->right);
      if (parent == 0) oldLeft->makeRoot();
      else if (parent->left == this) parent->setLeft(oldLeft);
      else parent->setRight(oldLeft);
      oldLeft->setRight(this);
    }

    // Splays this node to the root of its tree using zig, zig-zig and
    // zig-zag steps built from the primitive rotations.  Calling the rotate
    // methods costs a little performance; production code would do each
    // case inline.  NOTE: An interesting educational exercise would be to
    // comment out the "while" line and its ending brace so you can watch
    // each step of the splay if you have a good graphical tester.
    void splay() {
      while (this != tree->root) {
        if (parent->parent == 0) {
          if (this == parent->right) parent->rotateLeft();
          else parent->rotateRight();
        }
        else if (this == parent->right && parent == parent->parent->right) {
          parent->parent->rotateLeft();
          parent->rotateLeft();
        }
        else if (this == parent->left && parent == parent->parent->left) {
          parent->parent->rotateRight();
          parent->rotateRight();
        }
        else if (this == parent->right && parent == parent->parent->left) {
          parent->rotateLeft();
          parent->rotateRight();
        }
        else if (this == parent->left && parent == parent->parent->right) {
          parent->rotateRight();
          parent->rotateLeft();
        }
      }
    }

    // Deletes this node together with all of its descendants; ~Node frees
    // the children recursively, so no explicit recursion is needed here
    // (recursing AND deleting would free every child twice).  The pointer
    // that referred to this node (the parent's left/right, or tree->root)
    // and tree->size are NOT updated: the caller must fix them up and must
    // not touch this node afterwards.
    void destroySubtree() {
      delete this;
    }

  private:
    // A node owns its subtrees, so a shallow copy would delete them twice.
    // Declared but never defined, which forbids copying.
    Node(const Node&);
    Node& operator=(const Node&);
  };

// The tree itself is just a pointer to its root plus a cached node count.
private:
  Node* root;
  int size;

  // Since C++11, nested classes may access the enclosing class's private
  // members automatically (Node::makeRoot writes tree->root).  This
  // declaration grants the same access on older compilers.
  friend class Node;

// Constructors and Destructors
public:

  // Constructs a new EMPTY binary search tree.  The tree is represented
  // with a pointer to the root and its size.
  BST(): root(0), size(0) {
  }

  // Deletes all the nodes.
  ~BST() {
    clear();
  }

// Behavior
public:

  // Returns the number of items in the tree.
  int getSize() const {
    return size;
  }

  // Returns whether the tree is empty.
  bool isEmpty() const {
    return size == 0;
  }

  // Returns whether the given data is in some node in the tree.
  // We defer the work to nodeContaining(), which returns the node
  // containing the item if it exists and 0 otherwise.
  bool contains(const T& data) const {
    return nodeContaining(data) != 0;
  }

  // Adds a single data item to the tree as a new leaf.  If an item already
  // in the tree compares equal (==) to the new one, the stored item is
  // overwritten by the new item and the size does not change; the old item
  // is discarded (call nodeContaining(x)->setData(x) instead if you need
  // it back).  No rebalancing or splaying is done here.
  void add(const T& data) {
    if (root == 0) {
      root = new Node(data, 0, this);
      size++;
      // Return now.  Falling into the search below would compare data with
      // itself, and a value that is not equal to itself (e.g. NaN) would
      // then be inserted a second time.
      return;
    }
    Node* n = root;
    while (true) {
      if (data == n->data) {
        n->data = data;
        return;
      }
      if (data < n->data) {
        if (n->left == 0) {
          // Count the node only once it exists, so a throwing allocation
          // cannot leave size out of step with the tree.
          n->setLeft(new Node(data, n, this));
          size++;
          return;
        }
        n = n->left;
      } else { // data > n->data
        if (n->right == 0) {
          n->setRight(new Node(data, n, this));
          size++;
          return;
        }
        n = n->right;
      }
    }
  }


  // Removes all the items in the tree.
  void clear() {
    delete root;
    size = 0;
    root = 0;
  }

  // A visitor for the binary search tree.  It is a classic example of the
  // template method pattern.  Calling go(t) does an Euler tour of t --
  // depth-first, left-to-right -- so each node is reached three times:
  // onLeft before its left subtree, onBelow between its two subtrees, and
  // onRight after its right subtree.  To do traversals, simply make a
  // subclass of Visitor and override any or all of the five hooks
  // (initialize, onLeft, onBelow, onRight, getResult).
  class Visitor {
  protected:
    // BST::Node is private, so re-export it here under the same name.
    // Without this, a subclass defined outside BST could not spell the
    // parameter type of the callbacks it overrides.
    typedef typename BST<T>::Node Node;

    virtual void initialize() {}
    virtual void onLeft(Node*) {}
    virtual void onBelow(Node*) {}
    virtual void onRight(Node*) {}
    virtual T getResult() {return T();}

  public:
    // Visitors are meant to be subclassed, so deleting one through a
    // Visitor* must run the subclass destructor.
    virtual ~Visitor() {}

    // Runs initialize(), tours every node of t, then returns getResult().
    T go(const BST& t) {
      initialize();
      if (t.root != 0) {visitFrom(t.root);}
      return getResult();
    }
  private:
    // One Euler-tour step for the (non-empty) subtree rooted at n.
    void visitFrom(Node* n) {
      onLeft(n);
      if (n->left != 0) {visitFrom(n->left);}
      onBelow(n);
      if (n->right != 0) {visitFrom(n->right);}
      onRight(n);
    }
  };

  // Grants Visitor (which reads root in go()) access to BST's private
  // members on compilers that predate C++11's nested-class access rule.
  // It must follow the definition above: a friend declaration naming a
  // class not yet declared here would befriend an unrelated
  // namespace-scope class.
  friend class Visitor;

// Helper methods
public:

  // Returns the node whose data compares equal (==) to the given item, or
  // 0 if there is none.  It is public rather than protected because
  // testers may want direct access to the nodes.
  Node* nodeContaining(const T& data) const {
    for (Node* n = root; n != 0;) {
      if (data == n->data) {
        return n;
      } else if (data < n->data) {
        n = n->left;
      } else {
        n = n->right;
      }
    }

    // not found
    return 0;
  }


// Prohibit copying and assignment: the tree owns its nodes, so a shallow
// copy would delete them twice.  Declared but never defined.
private:
  BST(const BST&);
  BST& operator=(const BST&);

};

#endif