Convert BST to Greater Tree

Patrick Leaton
Problem Description

Given the root of a Binary Search Tree (BST), convert it to a Greater Tree such that every key of the original BST is changed to the original key plus the sum of all keys greater than the original key in BST.

As a reminder, a binary search tree is a tree that satisfies these constraints:

  • The left subtree of a node contains only nodes with keys less than the node's key.
  • The right subtree of a node contains only nodes with keys greater than the node's key.
  • Both the left and right subtrees must also be binary search trees.

 

Example 1:

Input: root = [4,1,6,0,2,5,7,null,null,null,3,null,null,null,8]
Output: [30,36,21,36,35,26,15,null,null,null,33,null,null,null,8]

Example 2:

Input: root = [0,null,1]
Output: [1,null,1]

Example 3:

Input: root = [1,0,2]
Output: [3,3,2]

Example 4:

Input: root = [3,2,4,1]
Output: [7,9,4,10]

 

Constraints:

  • The number of nodes in the tree is in the range [0, 10^4].
  • -10^4 <= Node.val <= 10^4
  • All the values in the tree are unique.
  • root is guaranteed to be a valid binary search tree.

 

Note: This question is the same as 1038: https://leetcode.com/problems/binary-search-tree-to-greater-sum-tree/

The description was taken from https://leetcode.com/problems/convert-bst-to-greater-tree/.

Problem Solution

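from typing import Optional


# Definition for a binary tree node (LeetCode supplies this class).
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


#Recursive DFS
#O(N) Time, O(H) Space (H = tree height, N in the worst case)
class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        suffix_sum = 0

        def dfs(node: TreeNode) -> TreeNode:
            if not node:
                return
            nonlocal suffix_sum
            dfs(node.right)  # visit the larger keys first
            suffix_sum += node.val  # running total of every key >= node.val
            node.val = suffix_sum
            dfs(node.left)  # then the smaller keys
            return node

        return dfs(root)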
 
#Iterative DFS
#O(N) Time, O(H) Space (H = tree height, N in the worst case)
class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        if not root:
            return root

        stack = []

        # Push a node and its chain of right descendants onto the stack.
        def stack_right_children(node: TreeNode) -> None:
            while node:
                stack.append(node)
                node = node.right

        def dfs() -> TreeNode:
            suffix_sum = 0
            node = root
            stack_right_children(node)
            while stack:
                # Nodes come off the stack in descending key order.
                node = stack.pop()
                suffix_sum += node.val
                node.val = suffix_sum
                # Queue up the left subtree, largest key first.
                stack_right_children(node.left)
            return root

        return dfs()

Problem Explanation


Recursive Depth-First Search

Since this is a binary search tree, an inorder traversal (left subtree, node, right subtree) visits the values from least to greatest: the leftmost node comes first and the rightmost node comes last.

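To get the values from greatest to least, we just need to do the opposite: a reverse inorder traversal (right subtree, node, left subtree), which visits the rightmost node first and the leftmost node last.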
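As a small sketch of that visiting order, here is a reverse inorder traversal of Example 4's tree. The generator name reverse_inorder is our own, and TreeNode is a LeetCode-style node class redefined so the snippet runs on its own.

```python
from typing import Iterator, Optional


class TreeNode:
    """LeetCode-style binary tree node, redefined so this snippet runs alone."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def reverse_inorder(node: Optional[TreeNode]) -> Iterator[int]:
    """Yield the keys of a BST from greatest to least."""
    if node:
        yield from reverse_inorder(node.right)  # larger keys first
        yield node.val
        yield from reverse_inorder(node.left)  # then smaller keys


# Example 4's tree: 3 has left child 2 (which has left child 1) and right child 4.
root = TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4))
print(list(reverse_inorder(root)))  # [4, 3, 2, 1]
```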

Keeping a running sum of each value before the current one is known as a prefix sum. Since we are keeping a running sum of the current value and every value after it in sorted order, that is a suffix sum.
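To see why a suffix sum gives the answer, we can set the tree aside and work on the sorted keys of Example 1. Walking from the largest key down while keeping a running total produces exactly the Output values. This is only an illustration: greater_sums is a hypothetical helper, and the tree solution gets the same descending order from the traversal rather than from sorting.

```python
from itertools import accumulate
from typing import Dict, List


def greater_sums(keys: List[int]) -> Dict[int, int]:
    """Map each key to itself plus the sum of all larger keys."""
    descending = sorted(keys, reverse=True)
    # accumulate() yields running totals: 8, 8 + 7, 8 + 7 + 6, ...
    return dict(zip(descending, accumulate(descending)))


# The keys of Example 1, in any order.
print(greater_sums([4, 1, 6, 0, 2, 5, 7, 3, 8]))
# {8: 8, 7: 15, 6: 21, 5: 26, 4: 30, 3: 33, 2: 35, 1: 36, 0: 36}
```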


Let's start by initializing the suffix sum to zero.

        suffix_sum = 0


Next, we will create our DFS function for visiting each node.

        def dfs(node: TreeNode) -> TreeNode:

 

The base case for our recursive function is a null node. If we hit one, we have exhausted a DFS path, so we return from there.

            if not node:
                return

 

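Otherwise, we start each visit by declaring suffix_sum as nonlocal, which tells Python that the name refers to the variable in the enclosing convertBST scope. Without this declaration, the += assignment we make later would turn suffix_sum into a new local variable of dfs, and reading it would raise an UnboundLocalError.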

            nonlocal suffix_sum
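Here is a minimal standalone illustration of that scoping rule, using two hypothetical counter factories: one without the declaration and one with it.

```python
def make_counter_broken():
    total = 0

    def add(x):
        total += x  # += makes `total` local to add(), so this read fails
        return total

    return add


def make_counter():
    total = 0

    def add(x):
        nonlocal total  # rebind the `total` that belongs to make_counter()
        total += x
        return total

    return add


counter = make_counter()
counter(5)
print(counter(3))  # 8

try:
    make_counter_broken()(5)
except UnboundLocalError:
    print("make_counter_broken raised UnboundLocalError")
```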

 

As mentioned previously, to get the values in ascending order we would visit the leftmost node first.

Since we want the reverse order, we recurse into the right subtree first, so the rightmost (and largest) node is visited first.

            dfs(node.right)

 

Once the DFS yoyo has gone past the end of the right branch and come back up to the current node, every key greater than the current one has already been added to the suffix sum. We add the node's value to the suffix sum and set the node's value to the new suffix sum.

            suffix_sum += node.val
            node.val = suffix_sum

 

Afterward, we recurse into the left child, whose subtree holds only smaller keys.

            dfs(node.left)

 

Once we have visited the current node's right and left subtrees, we return the node.

The recursive calls ignore this return value. It only matters for the initial call, where dfs(root) hands back the converted root (or None for an empty tree) for convertBST to return.

            return node


Now that we have our DFS function built, we just need to make the initial call and return the result.

        return dfs(root)
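Putting the pieces together, here is a sketch that runs the recursive solution on the four examples. The LeetCode judge normally supplies TreeNode and converts between lists and trees for us; build_tree and to_level_order are hypothetical helpers written here to mimic that level-order format, with None standing in for null.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        suffix_sum = 0

        def dfs(node: Optional[TreeNode]) -> Optional[TreeNode]:
            if not node:
                return None
            nonlocal suffix_sum
            dfs(node.right)  # larger keys first
            suffix_sum += node.val  # now includes every key >= node.val
            node.val = suffix_sum
            dfs(node.left)  # then smaller keys
            return node

        return dfs(root)


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: LeetCode-style level-order list -> tree."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Hypothetical helper: tree -> level-order list, trailing Nones trimmed."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


examples = [
    ([4, 1, 6, 0, 2, 5, 7, None, None, None, 3, None, None, None, 8],
     [30, 36, 21, 36, 35, 26, 15, None, None, None, 33, None, None, None, 8]),
    ([0, None, 1], [1, None, 1]),
    ([1, 0, 2], [3, 3, 2]),
    ([3, 2, 4, 1], [7, 9, 4, 10]),
]
for given, expected in examples:
    result = to_level_order(Solution().convertBST(build_tree(given)))
    assert result == expected, (given, result)
print("all examples pass")
```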



Iterative Depth-First Search

Converting a recursive preorder traversal to an iterative one is the simplest type of tree DFS conversion. That usually just entails replacing the call stack with an explicit stack, replacing the function arguments with the fields of a tuple that is pushed onto the stack, and running a loop while the stack isn't empty.
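As a sketch of that recipe, here is a preorder traversal written both ways. The depth parameter is a hypothetical extra argument, included only so there is something besides the node to carry in each tuple.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def preorder_recursive(node: Optional[TreeNode], depth: int = 0,
                       out: Optional[List[Tuple[int, int]]] = None) -> List[Tuple[int, int]]:
    if out is None:
        out = []
    if node:
        out.append((node.val, depth))
        preorder_recursive(node.left, depth + 1, out)
        preorder_recursive(node.right, depth + 1, out)
    return out


def preorder_iterative(root: Optional[TreeNode]) -> List[Tuple[int, int]]:
    out: List[Tuple[int, int]] = []
    stack = [(root, 0)]  # each tuple holds the arguments of one "call"
    while stack:
        node, depth = stack.pop()
        if node:
            out.append((node.val, depth))
            # Push right before left so the left child is popped first.
            stack.append((node.right, depth + 1))
            stack.append((node.left, depth + 1))
    return out


root = TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4))
print(preorder_iterative(root))  # [(3, 0), (2, 1), (1, 2), (4, 1)]
```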

Converting a recursive inorder traversal takes more care, because each node has to be visited between its two subtrees instead of before them.

To do that, we can use a pointer to push each node on the root's leftmost branch onto a stack. Each time we pop and visit a node, we also push the leftmost branch of that node's right child.
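Here is a sketch of that ascending inorder version on Example 1's tree. The stack_left_children helper is a hypothetical name, chosen to mirror the solution's stack_right_children helper.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_iterative(root: Optional[TreeNode]) -> List[int]:
    stack: List[TreeNode] = []
    out: List[int] = []

    def stack_left_children(node: Optional[TreeNode]) -> None:
        # Push the node, its left child, that child's left child, and so on.
        while node:
            stack.append(node)
            node = node.left

    stack_left_children(root)
    while stack:
        node = stack.pop()
        out.append(node.val)  # visit: keys come out in ascending order
        stack_left_children(node.right)  # then the right child's leftmost branch
    return out


# Example 1's tree.
root = TreeNode(4,
                TreeNode(1, TreeNode(0), TreeNode(2, None, TreeNode(3))),
                TreeNode(6, TreeNode(5), TreeNode(7, None, TreeNode(8))))
print(inorder_iterative(root))  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
```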

Since we are doing a reverse inorder traversal, as in the recursive version, we mirror this process: we push the root's rightmost branch, and each time we pop and visit a node, we push the rightmost branch of its left child.


Let's start by checking if the root is null.

If that is the case, there are no nodes to convert, so we return the null root.

        if not root:
            return root


Otherwise, we will initialize our stack.

        stack = []


Next, we will make a helper function that pushes the current node and its chain of right descendants (its right child, that child's right child, and so on) onto the stack.

Until the pointer hits a null value, we push the node onto the stack and move on to its right child.

        def stack_right_children(node:TreeNode) -> None:
            while node:
                stack.append(node)
                node = node.right
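To check what the helper leaves on the stack, here is a short standalone run on Example 1's tree (TreeNode is redefined so the snippet runs alone). The root's rightmost branch is pushed from the top of the tree down, so the largest key ends up on top of the stack.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


stack = []


def stack_right_children(node):
    # Push the node, then keep following right pointers until None.
    while node:
        stack.append(node)
        node = node.right


# Example 1's tree.
root = TreeNode(4,
                TreeNode(1, TreeNode(0), TreeNode(2, None, TreeNode(3))),
                TreeNode(6, TreeNode(5), TreeNode(7, None, TreeNode(8))))
stack_right_children(root)
print([node.val for node in stack])  # [4, 6, 7, 8]; 8, the largest key, is on top
```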


Afterward, we will make our iterative DFS function.

        def dfs() -> TreeNode:

 

Let's initialize our suffix sum and our node pointer to the root, and fill our stack with the root's rightmost branch.

            suffix_sum = 0
            node = root
            stack_right_children(node)

 

While the stack isn't empty, we pop a node from the stack, add its value to the suffix sum, set its value to the new suffix sum, and then push its left child's rightmost branch.

This mimics the recursive version's order: a node is popped only after every node in its right subtree has been popped, so the values come off the stack from greatest to least.

            while stack:
                node = stack.pop()
                suffix_sum += node.val
                node.val = suffix_sum
                stack_right_children(node.left)
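To watch the loop work, here is a sketch that runs it on Example 4 and records, for each pop, the keys on the stack beforehand, the popped key, and the new suffix sum. convert_with_trace is a hypothetical instrumented copy of the loop, not part of the solution.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def convert_with_trace(root: Optional[TreeNode]) -> List[Tuple[List[int], int, int]]:
    """Run the loop and record (stack keys before the pop, popped key, suffix sum)."""
    stack: List[TreeNode] = []

    def stack_right_children(node: Optional[TreeNode]) -> None:
        while node:
            stack.append(node)
            node = node.right

    trace = []
    suffix_sum = 0
    stack_right_children(root)
    while stack:
        snapshot = [n.val for n in stack]  # keys still waiting, top of stack last
        node = stack.pop()
        key = node.val
        suffix_sum += key
        node.val = suffix_sum
        trace.append((snapshot, key, suffix_sum))
        stack_right_children(node.left)
    return trace


# Example 4's tree: [3, 2, 4, 1].
root = TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4))
for step in convert_with_trace(root):
    print(step)
# ([3, 4], 4, 4)
# ([3], 3, 7)
# ([2], 2, 9)
# ([1], 1, 10)
```

The trace ends with the same values as Example 4's output: 3 becomes 7, 2 becomes 9, 4 stays 4, and 1 becomes 10.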

 

Once our stack is empty, we have converted the whole tree, so we return the root.

            return root


Now that our DFS function is built, we just need to make the initial call.

        return dfs()
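As a final usage sketch, we can pull the loop out into a plain function and check it against a simpler description of the answer: with unique keys, each node's new value is the total of all keys minus the keys smaller than its original key. The names convert_bst, insert, right_chain, and check are hypothetical helpers. The right-skewed chain of 10^4 nodes is the deepest tree the constraints allow; with CPython's default recursion limit of 1000 (see sys.getrecursionlimit()), the recursive dfs would raise a RecursionError on it unless the limit were raised, while the explicit stack handles it without any recursion.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def convert_bst(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """The iterative solution's loop, written as a plain function."""
    stack: List[TreeNode] = []

    def stack_right_children(node: Optional[TreeNode]) -> None:
        while node:
            stack.append(node)
            node = node.right

    suffix_sum = 0
    stack_right_children(root)
    while stack:
        node = stack.pop()
        suffix_sum += node.val
        node.val = suffix_sum
        stack_right_children(node.left)
    return root


def insert(root: Optional[TreeNode], key: int) -> TreeNode:
    """Unbalanced BST insertion, using a loop instead of recursion."""
    new_node = TreeNode(key)
    if root is None:
        return new_node
    node = root
    while True:
        if key < node.val:
            if node.left is None:
                node.left = new_node
                return root
            node = node.left
        else:
            if node.right is None:
                node.right = new_node
                return root
            node = node.right


def right_chain(n: int) -> Optional[TreeNode]:
    """Right-skewed BST holding the keys 0..n-1; its depth is n."""
    root = None
    for key in reversed(range(n)):
        root = TreeNode(key, None, root)
    return root


def check(root: Optional[TreeNode]) -> bool:
    # Collect the nodes in ascending key order (iterative inorder).
    nodes: List[TreeNode] = []
    stack: List[TreeNode] = []
    node = root
    while stack or node:
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        nodes.append(node)
        node = node.right
    # Unique keys: "sum of keys >= k" equals the total minus the keys before k.
    keys = [n.val for n in nodes]
    total, smaller, expected = sum(keys), 0, []
    for k in keys:
        expected.append(total - smaller)
        smaller += k
    convert_bst(root)
    return [n.val for n in nodes] == expected


random.seed(0)
random_root = None
for key in random.sample(range(-10**4, 10**4 + 1), 500):
    random_root = insert(random_root, key)

print(check(random_root))         # True
print(check(right_chain(10**4)))  # True, with no recursion involved
print(check(None))                # True
```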