You are given a binary tree in which each node contains an integer value.
Find the number of paths that sum to a given value.
The path does not need to start or end at the root or a leaf, but it must go downwards (traveling only from parent nodes to child nodes).
The tree has no more than 1,000 nodes and the values are in the range -1,000,000 to 1,000,000.
Example:
root = [10,5,-3,3,2,null,11,3,-2,null,1], sum = 8

```
      10
     /  \
    5   -3
   / \    \
  3   2   11
 / \   \
3  -2   1
```

Return 3. The paths that sum to 8 are:
1. 5 -> 3
2. 5 -> 2 -> 1
3. -3 -> 11
The description was taken from https://leetcode.com/problems/path-sum-iii/.
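To make the problem definition concrete before optimizing, here is a brute-force sketch that starts a downward walk at every node and counts the prefixes of that walk that hit the target. The `TreeNode` class mirrors the one LeetCode provides. `build_tree` (which reads LeetCode-style level-order values) and `count_paths_brute_force` are hypothetical helpers written only for this illustration. Each node is revisited once per ancestor, so this takes O(N * H) time, which is O(N^2) on a skewed tree.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from LeetCode-style level-order values."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if any) is the left child, the one after it the right child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def count_paths_brute_force(root: Optional[TreeNode], target: int) -> int:
    """Hypothetical reference: try every node as the start of a downward path."""

    def from_here(node: Optional[TreeNode], remaining: int) -> int:
        # Count downward paths that start at the original node and reach `node`
        # with exactly `target` accumulated; keep going past hits (values may be negative).
        if not node:
            return 0
        remaining -= node.val
        hits = 1 if remaining == 0 else 0
        return hits + from_here(node.left, remaining) + from_here(node.right, remaining)

    if not root:
        return 0
    return (from_here(root, target)
            + count_paths_brute_force(root.left, target)
            + count_paths_brute_force(root.right, target))


root = build_tree([10, 5, -3, 3, 2, None, 11, 3, -2, None, 1])
print(count_paths_brute_force(root, 8))  # 3
```

A checker like this is handy for testing the faster solution on small random trees.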
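```python
# O(N) time, O(N) space
from typing import Optional


# Minimal binary-tree node matching LeetCode's definition
# (omit it when submitting on LeetCode, which already defines TreeNode).
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def pathSum(self, root: Optional[TreeNode], sum: int) -> int:
        count, k = 0, sum
        seen = {}  # running sum -> number of ancestors on the current path with that sum

        def dfs(node: Optional[TreeNode], curr_sum: int) -> None:
            if not node:
                return
            nonlocal count
            curr_sum += node.val
            # Situation 1: the path from the root to this node sums to k.
            if curr_sum == k:
                count += 1
            # Situation 2: excluding an ancestor prefix with sum curr_sum - k leaves k.
            if (curr_sum - k) in seen:
                count += seen[curr_sum - k]
            # Cache this node's running sum for its descendants.
            if curr_sum not in seen:
                seen[curr_sum] = 1
            else:
                seen[curr_sum] += 1
            dfs(node.left, curr_sum)
            dfs(node.right, curr_sum)
            # Backtrack so sibling branches don't see this node's sum.
            seen[curr_sum] -= 1

        dfs(root, 0)
        return count
```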
This problem is a variation of Subarray Sum Equals K, whose solution is also here on AlgoNinjutsu. The approach is also similar to Two-Sum: instead of looking up a complement value in a hash map, we look up a complement running sum.
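For comparison, here is a minimal sketch of the array version (Subarray Sum Equals K) written in the same style as the tree solution; the function name `subarray_sum` is our own. An array has only one path, so nothing needs to be undone, which is exactly what backtracking will add for the tree.

```python
from typing import List


def subarray_sum(nums: List[int], k: int) -> int:
    """Count contiguous subarrays of nums that sum to k (prefix-sum sketch)."""
    count, curr_sum = 0, 0
    seen = {}  # running sum -> how many earlier prefixes had it
    for num in nums:
        curr_sum += num
        # Situation 1: the prefix starting at index 0 sums to k.
        if curr_sum == k:
            count += 1
        # Situation 2: drop an earlier prefix whose sum is curr_sum - k
        # (the Two-Sum-style complement lookup).
        count += seen.get(curr_sum - k, 0)
        # Record this prefix only after the checks, so it can't match itself.
        seen[curr_sum] = seen.get(curr_sum, 0) + 1
    return count


print(subarray_sum([1, 1, 1], 2))  # 2
print(subarray_sum([1, 2, 3], 3))  # 2
```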
In order to find a path that sums to k, we may encounter two situations: either the running sum that started from the root equals k, or we have cached an earlier running sum that we can exclude from the current sum so that what remains equals k.
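Both situations come from one identity. Along the current root-to-node path, let v_1, ..., v_j be the node values and P_j their running sum, with P_0 = 0 for the empty prefix above the root. The downward path from node i+1 to node j then sums to:

```latex
P_j = \sum_{t=1}^{j} v_t, \qquad
\sum_{t=i+1}^{j} v_t \;=\; P_j - P_i \;=\; k
\iff P_i = P_j - k, \qquad 0 \le i < j
```

Taking i = 0 gives P_j = k, the first situation. Taking i >= 1 means an ancestor cached the running sum P_j - k, the second situation.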
However, instead of scanning a single array for these sums, we have to search through the branches of a binary tree. Since we have to go deep down each branch, that is an indication that we should use a depth-first search. Since each branch is a separate set of paths, that is an indication that we should use backtracking: once we finish exploring a branch, we undo its effect on the cached sums so that each branch is considered as its own individual path.
Let's start by initializing our count, our k value, and a dictionary (hash map) for the running sums we have seen.
```python
count, k = 0, sum
seen = {}
```
Let's create our recursive dfs function.
The function will take a node and the running sum of its parent nodes as its arguments.
```python
def dfs(node: Optional[TreeNode], curr_sum: int) -> None:
```
Since dfs is recursive, let's set our base case first. If we have reached a null node, there are no more values to add to the sum, so we return.
```python
if not node:
    return
```
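Otherwise, we start by declaring count as nonlocal, since it belongs to the enclosing pathSum function and we reassign it from inside dfs. Without this declaration, count += 1 would raise an UnboundLocalError.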
```python
nonlocal count
```
Then, we will add the current node's value to the current sum.
```python
curr_sum += node.val
```
Let's take this tree as an example input, and let's say k is eight.

```
    3
   / \
  5  -3
 / \   \
4   2  11
```
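To follow the walkthrough, it helps to see the running sum at every node of this tree. The sketch below records (node value, running sum) pairs in the same pre-order the DFS uses; `running_sums` is a hypothetical helper, not part of the solution.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def running_sums(root: Optional[TreeNode]) -> List[Tuple[int, int]]:
    """Hypothetical helper: (node value, root-to-node running sum) in pre-order."""
    out: List[Tuple[int, int]] = []

    def walk(node: Optional[TreeNode], curr_sum: int) -> None:
        if not node:
            return
        curr_sum += node.val
        out.append((node.val, curr_sum))
        walk(node.left, curr_sum)
        walk(node.right, curr_sum)

    walk(root, 0)
    return out


# The example tree from the walkthrough.
tree = TreeNode(3, TreeNode(5, TreeNode(4), TreeNode(2)), TreeNode(-3, None, TreeNode(11)))
print(running_sums(tree))
# [(3, 3), (5, 8), (4, 12), (2, 10), (-3, 0), (11, 11)]
```

With k = 8, node 5 reaches 8 directly (the first situation), and node 11 reaches 11, which is 3 more than k, matching the root's cached sum of 3 (the second situation).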
If the current sum equals k, that is the first situation mentioned previously: the running sum that started from the root equals k. In the example tree, the running sum at node 5 is 3 + 5 = 8, so the path 3 -> 5 is counted.
```python
if curr_sum == k:
    count += 1
```
Additionally, if the current sum minus k is in seen, that is the second situation mentioned previously: an ancestor's running sum is cached, and excluding it from the current sum leaves a path that also equals k.
```python
if (curr_sum - k) in seen:
    count += seen[curr_sum - k]
```
We cached three as the running sum at the root, the first node of the path 3 -> -3 -> 11. At the third node, 11, the current sum is 3 + (-3) + 11 = 11, so we subtract k (eight) from eleven and check whether any threes were cached. Excluding the root's running sum of three leaves the path -3 -> 11, which sums to k.
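Notice that we didn't just increment count by one, but by the number of cached sums with that value. Several ancestors can share the same running sum (for example, when the path contains zeros, or negative values that bring the sum back to an earlier value), and each of those ancestors marks the start of a different path ending at the current node.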
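To see why the multiplicity matters, here is a sketch on a single downward chain (a tree in which every node has only one child), written as a plain list. `count_on_path` and its `use_multiplicity` flag are hypothetical names for this illustration only. With values 0 -> 0 -> 5 and k = 5, three paths end at the last node: 0 -> 0 -> 5, 0 -> 5, and 5. The running sum 0 was cached twice, so adding just one undercounts.

```python
from typing import List


def count_on_path(values: List[int], k: int, use_multiplicity: bool = True) -> int:
    """Count downward sub-paths of a single root-to-leaf chain that sum to k."""
    count, curr_sum, seen = 0, 0, {}
    for v in values:
        curr_sum += v
        if curr_sum == k:
            count += 1
        if (curr_sum - k) in seen:
            # Each cached ancestor with this sum starts a different path.
            count += seen[curr_sum - k] if use_multiplicity else 1
        seen[curr_sum] = seen.get(curr_sum, 0) + 1
    return count


print(count_on_path([0, 0, 5], 5))                          # 3
print(count_on_path([0, 0, 5], 5, use_multiplicity=False))  # 2, one path missed
```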
After these checks, we cache the current sum so the node's descendants can use it. If it's already in seen, we increment its count; if not, we initialize it to one. Caching after the checks matters: otherwise, when k is zero, a node would match its own running sum and count an empty path.
```python
if curr_sum not in seen:
    seen[curr_sum] = 1
else:
    seen[curr_sum] += 1
```
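As a stylistic alternative, Python's `collections.defaultdict(int)` treats missing keys as 0, which collapses the if/else into one line. This is a sketch of that option, not a change in behavior; the list of sums below is made up for the demonstration.

```python
from collections import defaultdict

seen = defaultdict(int)  # missing keys read as 0
for curr_sum in [3, 8, 3]:
    seen[curr_sum] += 1  # replaces the if/else above

print(dict(seen))  # {3: 2, 8: 1}

# Caveat: indexing a missing key inserts it, so use .get() for lookups.
print(seen.get(-5, 0))  # 0, and -5 is not added
print(-5 in seen)       # False
```

The original code guards its lookup with `in seen`, so a plain dict with the if/else is equally correct.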
We then continue our DFS on the left and right children of the current node, passing along the running sum of the parent nodes plus the current node.
```python
dfs(node.left, curr_sum)
dfs(node.right, curr_sum)
```
Once both recursive calls return, we need to backtrack by decrementing the current sum's count in the seen dictionary.
This is because once we move on to another branch, we don't want it to see sums from the previous branch. A path must travel only from parent nodes to child nodes, so only the ancestors of the current node may mark where a path starts.
```python
seen[curr_sum] -= 1
```
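To see the leak that backtracking prevents, here is a sketch of the same algorithm with a hypothetical `backtrack` flag that can skip the decrement. On a root 0 with children 5 and 10 and k = 5, the valid paths are 0 -> 5 and 5. Without backtracking, the left child's cached sum of 5 is still visible in the right branch, so 10 - 5 = 5 produces a bogus match.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def path_sum(root: Optional[TreeNode], k: int, backtrack: bool = True) -> int:
    """Same counting as the solution; backtrack=False skips the decrement."""
    count = 0
    seen = {}

    def dfs(node: Optional[TreeNode], curr_sum: int) -> None:
        nonlocal count
        if not node:
            return
        curr_sum += node.val
        if curr_sum == k:
            count += 1
        count += seen.get(curr_sum - k, 0)
        seen[curr_sum] = seen.get(curr_sum, 0) + 1
        dfs(node.left, curr_sum)
        dfs(node.right, curr_sum)
        if backtrack:
            seen[curr_sum] -= 1

    dfs(root, 0)
    return count


tree = TreeNode(0, TreeNode(5), TreeNode(10))
print(path_sum(tree, 5))                   # 2
print(path_sum(tree, 5, backtrack=False))  # 3, the sibling's sum leaked
```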
Once our dfs function is built, we just need to make the initial call, passing the root as the starting node and zero as the current sum.
After every downward path that sums to k has been counted, we return the total.
```python
dfs(root, 0)
return count
```
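For reference, a common variant seeds seen with {0: 1} so that the first situation becomes a special case of the second: a running sum equal to k finds the seeded zero, which stands for the empty prefix above the root. The sketch below assumes that variant (the name `path_sum_seeded` is our own) and should count the same paths as the solution above, here checked on the problem's example tree.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def path_sum_seeded(root: Optional[TreeNode], k: int) -> int:
    """Variant: seeding seen with {0: 1} folds situation one into situation two."""
    count = 0
    seen = {0: 1}  # the empty prefix above the root has running sum 0

    def dfs(node: Optional[TreeNode], curr_sum: int) -> None:
        nonlocal count
        if not node:
            return
        curr_sum += node.val
        # When curr_sum == k, this lookup finds the seeded 0.
        count += seen.get(curr_sum - k, 0)
        seen[curr_sum] = seen.get(curr_sum, 0) + 1
        dfs(node.left, curr_sum)
        dfs(node.right, curr_sum)
        seen[curr_sum] -= 1  # backtrack

    dfs(root, 0)
    return count


# Example tree: [10,5,-3,3,2,null,11,3,-2,null,1]
root = TreeNode(
    10,
    TreeNode(5, TreeNode(3, TreeNode(3), TreeNode(-2)), TreeNode(2, None, TreeNode(1))),
    TreeNode(-3, None, TreeNode(11)),
)
print(path_sum_seeded(root, 8))  # 3
```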