import java.util.LinkedList;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.Random;

/**
 * An ordered set of {@link Comparable} keys stored in a binary search tree.
 * <p>
 * Only keys are stored (there are no associated values). Inserting a key that
 * is already present leaves the tree unchanged, and {@code null} keys are
 * rejected by every public method that takes a key. The tree is never
 * rebalanced, so its shape depends on the order in which keys are inserted.
 * <p>
 * Code adapted from the binary search tree of Robert Sedgewick and Kevin Wayne.
 *
 * @param <Key>
 *            the type of the keys kept in the tree
 * @author William Killian
 * @author Robert Sedgewick
 * @author Kevin Wayne
 */
public class BinarySearchTree<Key extends Comparable<Key>> {
	private Node root; // root of the tree, null when the tree is empty

	/** A single tree node: one key plus links to its two subtrees. */
	private class Node {
		private final Key key; // keys in left are smaller, keys in right are larger
		private Node left, right; // left and right subtrees (either may be null)

		public Node(Key key, Node left, Node right) {
			this.key = key;
			this.left = left;
			this.right = right;
		}

		public Node(Key key) {
			this(key, null, null);
		}
	}

	/**
	 * Initializes an empty tree.
	 */
	public BinarySearchTree() {
		this.root = null;
	}

	/**
	 * Inserts the specified key into the tree. If an equal key (according to
	 * {@code compareTo}) is already present, the tree is left unchanged.
	 *
	 * @param key
	 *            the key
	 * @throws IllegalArgumentException
	 *             if {@code key} is {@code null}
	 */
	public void insert(Key key) {
		if (key == null) {
			throw new IllegalArgumentException("calls insert() with a null key");
		}
		root = insert(root, key);
	}

	/**
	 * Does this tree contain the given key?
	 *
	 * @param key
	 *            the key
	 * @return {@code true} if this tree contains {@code key} and {@code false}
	 *         otherwise
	 * @throws IllegalArgumentException
	 *             if {@code key} is {@code null}
	 */
	public boolean contains(Key key) {
		if (key == null) {
			throw new IllegalArgumentException("argument to contains() is null");
		}
		return contains(root, key);
	}

	/**
	 * Removes the specified key from this tree (if the key is in this tree).
	 * Removing a key that is not present is a no-op.
	 *
	 * @param key
	 *            the key
	 * @throws IllegalArgumentException
	 *             if {@code key} is {@code null}
	 */
	public void delete(Key key) {
		if (key == null) {
			throw new IllegalArgumentException("calls delete() with a null key");
		}
		root = delete(root, key);
	}

	/**
	 * Removes the smallest key from the tree.
	 *
	 * @throws NoSuchElementException
	 *             if the tree is empty
	 */
	public void deleteMin() {
		if (isEmpty()) {
			throw new NoSuchElementException("Symbol table underflow");
		}
		root = deleteMin(root);
	}

	/**
	 * Removes the largest key from the tree.
	 *
	 * @throws NoSuchElementException
	 *             if the tree is empty
	 */
	public void deleteMax() {
		if (isEmpty()) {
			throw new NoSuchElementException("Symbol table underflow");
		}
		root = deleteMax(root);
	}

	/**
	 * Returns the number of keys in this tree. The count is not cached: every
	 * call visits every node.
	 * 
	 * @return the number of keys in this tree
	 */
	public int size() {
		return size(root);
	}

	/**
	 * Returns the smallest key in the tree.
	 *
	 * @return the smallest key in the tree
	 * @throws NoSuchElementException
	 *             if the tree is empty
	 */
	public Key min() {
		if (isEmpty()) {
			throw new NoSuchElementException("calls min() with empty symbol table");
		}
		return min(root).key;
	}

	/**
	 * Returns the largest key in the tree.
	 *
	 * @return the largest key in the tree
	 * @throws NoSuchElementException
	 *             if the tree is empty
	 */
	public Key max() {
		if (isEmpty()) {
			throw new NoSuchElementException("calls max() with empty symbol table");
		}
		return max(root).key;
	}

	/**
	 * Returns the height of the tree (for debugging).
	 *
	 * @return the height of the tree (a 1-node tree has height 0, an empty tree
	 *         has height -1)
	 */
	public int height() {
		return height(root);
	}

	/**
	 * Returns true if this tree is empty.
	 * 
	 * @return {@code true} if this tree is empty; {@code false} otherwise
	 */
	public boolean isEmpty() {
		// Checking the root directly avoids counting every node via size().
		return root == null;
	}

	/**
	 * Returns the smallest key in the tree greater than or equal to {@code key}.
	 *
	 * @param key
	 *            the key
	 * @return the smallest key in the tree greater than or equal to {@code key},
	 *         or {@code null} if the tree is non-empty but every key in it is
	 *         smaller than {@code key}
	 * @throws NoSuchElementException
	 *             if the tree is empty
	 * @throws IllegalArgumentException
	 *             if {@code key} is {@code null}
	 */
	public Key ceiling(Key key) {
		if (key == null) {
			throw new IllegalArgumentException("argument to ceiling() is null");
		}
		if (isEmpty()) {
			throw new NoSuchElementException("calls ceiling() with empty symbol table");
		}
		Node x = ceiling(root, key);
		return (x == null) ? null : x.key;
	}

	/**
	 * Returns the largest key in the tree less than or equal to {@code key}.
	 *
	 * @param key
	 *            the key
	 * @return the largest key in the tree less than or equal to {@code key}, or
	 *         {@code null} if the tree is non-empty but every key in it is larger
	 *         than {@code key}
	 * @throws NoSuchElementException
	 *             if the tree is empty
	 * @throws IllegalArgumentException
	 *             if {@code key} is {@code null}
	 */
	public Key floor(Key key) {
		if (key == null) {
			throw new IllegalArgumentException("argument to floor() is null");
		}
		if (isEmpty()) {
			throw new NoSuchElementException("calls floor() with empty symbol table");
		}
		Node x = floor(root, key);
		return (x == null) ? null : x.key;
	}

	/**
	 * Return the number of keys in the tree strictly less than {@code key}.
	 *
	 * @param key
	 *            the key
	 * @return the number of keys in the tree strictly less than {@code key}
	 * @throws IllegalArgumentException
	 *             if {@code key} is {@code null}
	 */
	public int rank(Key key) {
		if (key == null) {
			throw new IllegalArgumentException("argument to rank() is null");
		}
		return rank(key, root);
	}

	/**
	 * Returns all keys in the tree, in ascending order, as an {@code Iterable}.
	 * To iterate over all of the keys in a tree named {@code st}, use the foreach
	 * notation: {@code for (Key key : st.keys())}.
	 *
	 * @return all keys in the tree; an empty {@code Iterable} if the tree is empty
	 */
	public Iterable<Key> keys() {
		if (isEmpty()) {
			return new LinkedList<Key>();
		}
		return keys(min(), max());
	}

	/**
	 * Returns all keys in the tree in the given range, in ascending order, as an
	 * {@code Iterable}.
	 *
	 * @param lo
	 *            minimum endpoint
	 * @param hi
	 *            maximum endpoint
	 * @return all keys in the tree between {@code lo} (inclusive) and {@code hi}
	 *         (inclusive)
	 * @throws IllegalArgumentException
	 *             if either {@code lo} or {@code hi} is {@code null}
	 */
	public Iterable<Key> keys(Key lo, Key hi) {
		if (lo == null) {
			throw new IllegalArgumentException("first argument to keys() is null");
		}
		if (hi == null) {
			throw new IllegalArgumentException("second argument to keys() is null");
		}
		Queue<Key> queue = new LinkedList<Key>();
		keys(root, queue, lo, hi);
		return queue;
	}

	/**
	 * Returns the number of keys in the tree in the given range.
	 *
	 * @param lo
	 *            minimum endpoint
	 * @param hi
	 *            maximum endpoint
	 * @return the number of keys in the tree between {@code lo} (inclusive) and
	 *         {@code hi} (inclusive); 0 if {@code lo} is greater than {@code hi}
	 * @throws IllegalArgumentException
	 *             if either {@code lo} or {@code hi} is {@code null}
	 */
	public int size(Key lo, Key hi) {
		if (lo == null) {
			throw new IllegalArgumentException("first argument to size() is null");
		}
		if (hi == null) {
			throw new IllegalArgumentException("second argument to size() is null");
		}
		if (lo.compareTo(hi) > 0) {
			return 0;
		}
		// rank() counts keys strictly below its argument, so hi itself has to
		// be added back in when it is present.
		int below = rank(hi) - rank(lo);
		return contains(hi) ? below + 1 : below;
	}

	/**
	 * Returns the keys in the tree in level order (for debugging).
	 *
	 * @return the keys in the tree in level order traversal
	 */
	public Iterable<Key> levelOrder() {
		Queue<Key> keys = new LinkedList<Key>();
		Queue<Node> pending = new LinkedList<Node>();
		if (root != null) {
			pending.add(root);
		}
		while (!pending.isEmpty()) {
			Node x = pending.remove();
			keys.add(x.key);
			if (x.left != null) {
				pending.add(x.left);
			}
			if (x.right != null) {
				pending.add(x.right);
			}
		}
		return keys;
	}

	/**
	 * Renders the tree rotated 90 degrees counter-clockwise: one key per line,
	 * largest key first, each key indented by four spaces per level of depth.
	 *
	 * @return the textual rendering; the empty string for an empty tree
	 */
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		appendSubtree(root, 0, sb);
		return sb.toString();
	}

	// Inserts key into the subtree rooted at x and returns the subtree's root.
	private Node insert(Node x, Key key) {
		if (x == null) {
			return new Node(key);
		}
		int cmp = key.compareTo(x.key);
		if (cmp < 0) {
			x.left = insert(x.left, key);
		} else if (cmp > 0) {
			x.right = insert(x.right, key);
		}
		// cmp == 0: the key is already present, nothing to do.
		return x;
	}

	// Is key present in the subtree rooted at x? The caller guarantees key != null.
	private boolean contains(Node x, Key key) {
		while (x != null) {
			int cmp = key.compareTo(x.key);
			if (cmp < 0) {
				x = x.left;
			} else if (cmp > 0) {
				x = x.right;
			} else {
				return true;
			}
		}
		return false;
	}

	// Removes key from the subtree rooted at x and returns the subtree's new root.
	private Node delete(Node x, Key key) {
		if (x == null) {
			return null;
		}
		int cmp = key.compareTo(x.key);
		if (cmp < 0) {
			x.left = delete(x.left, key);
		} else if (cmp > 0) {
			x.right = delete(x.right, key);
		} else {
			if (x.right == null) {
				return x.left;
			}
			if (x.left == null) {
				return x.right;
			}
			// Two children: replace x by its successor, the minimum of the
			// right subtree. The successor must be located before deleteMin
			// unlinks it from that subtree.
			Node t = x;
			x = min(t.right);
			x.right = deleteMin(t.right);
			x.left = t.left;
		}
		return x;
	}

	// Unlinks the smallest node of the non-null subtree x; returns the new root.
	private Node deleteMin(Node x) {
		if (x.left == null) {
			return x.right;
		}
		x.left = deleteMin(x.left);
		return x;
	}

	// Unlinks the largest node of the non-null subtree x; returns the new root.
	private Node deleteMax(Node x) {
		if (x.right == null) {
			return x.left;
		}
		x.right = deleteMax(x.right);
		return x;
	}

	// Number of keys in the subtree rooted at x, counted by visiting every node.
	private int size(Node x) {
		if (x == null) {
			return 0;
		}
		return 1 + size(x.left) + size(x.right);
	}

	// Leftmost (smallest) node of the non-null subtree x.
	private Node min(Node x) {
		while (x.left != null) {
			x = x.left;
		}
		return x;
	}

	// Rightmost (largest) node of the non-null subtree x.
	private Node max(Node x) {
		while (x.right != null) {
			x = x.right;
		}
		return x;
	}

	// Height of the subtree rooted at x; -1 for an empty subtree.
	private int height(Node x) {
		if (x == null) {
			return -1;
		}
		return 1 + Math.max(height(x.left), height(x.right));
	}

	// Node holding the smallest key >= key in the subtree x, or null if none.
	private Node ceiling(Node x, Key key) {
		if (x == null) {
			return null;
		}
		int cmp = key.compareTo(x.key);
		if (cmp == 0) {
			return x;
		}
		if (cmp < 0) {
			// x is a candidate; a tighter one can only be in the left subtree.
			Node t = ceiling(x.left, key);
			return (t != null) ? t : x;
		}
		return ceiling(x.right, key);
	}

	// Node holding the largest key <= key in the subtree x, or null if none.
	private Node floor(Node x, Key key) {
		if (x == null) {
			return null;
		}
		int cmp = key.compareTo(x.key);
		if (cmp == 0) {
			return x;
		}
		if (cmp < 0) {
			return floor(x.left, key);
		}
		// x is a candidate; a tighter one can only be in the right subtree.
		Node t = floor(x.right, key);
		return (t != null) ? t : x;
	}

	// Number of keys in the subtree less than key.
	private int rank(Key key, Node x) {
		if (x == null) {
			return 0;
		}
		int cmp = key.compareTo(x.key);
		if (cmp < 0) {
			return rank(key, x.left);
		} else if (cmp > 0) {
			// x and everything to its left are smaller than key.
			return 1 + size(x.left) + rank(key, x.right);
		} else {
			return size(x.left);
		}
	}

	// Appends, in ascending order, the keys of subtree x that lie in [lo, hi].
	private void keys(Node x, Queue<Key> queue, Key lo, Key hi) {
		if (x == null) {
			return;
		}
		int cmplo = lo.compareTo(x.key);
		int cmphi = hi.compareTo(x.key);
		if (cmplo < 0) {
			keys(x.left, queue, lo, hi);
		}
		if (cmplo <= 0 && cmphi >= 0) {
			queue.add(x.key);
		}
		if (cmphi > 0) {
			keys(x.right, queue, lo, hi);
		}
	}

	// Appends the sideways rendering of subtree x to sb: right subtree, then x
	// itself indented by two spaces per unit of level, then the left subtree.
	// Children are drawn at level + 2. One shared builder is used so that
	// subtree strings are not re-copied at every level.
	private void appendSubtree(Node x, int level, StringBuilder sb) {
		if (x == null) {
			return;
		}
		appendSubtree(x.right, level + 2, sb);
		for (int i = 0; i < level; ++i) {
			sb.append(' ').append(' ');
		}
		sb.append(x.key);
		sb.append('\n');
		appendSubtree(x.left, level + 2, sb);
	}
}