We'll start by looking at a simple example:
First, let's consider the straightforward way, using a loop:

public class PowerWithIteration {

    public static void main (String[] argv) {
        int p = power (3, 2);
        System.out.println ( "3^2 = " + p);

        p = power (3, 4);
        System.out.println ( "3^4 = " + p);

        p = power (2, 8);
        System.out.println ( "2^8 = " + p);
    }

    static int power (int a, int b) {
        int p = 1;              // a^0
        while (b > 0) {
            p = p * a;          // b times through loop.
            b --;
        }
        return p;
    }

}
Next, let's look at the recursive version of the same:

public class PowerExample {

    public static void main (String[] argv) {
        int p = power (3, 2);
        System.out.println ( "3^2 = " + p);

        p = power (3, 4);
        System.out.println ( "3^4 = " + p);

        p = power (2, 8);
        System.out.println ( "2^8 = " + p);
    }

    static int power (int a, int b) {
        int p;
        if (b == 0) {
            p = 1;
        }
        else {
            // Note use of recursion:
            p = a * power (a, b-1);
        }
        return p;
    }

}

How does this work?
To see what's happening more clearly, we'll print something in each call:

public class PowerExample2 {

    public static void main (String[] argv) {
        // We added a third parameter to track which call we're in.
        int p = power (3, 2, 0);
        System.out.println ( "3^2 = " + p);

        p = power (3, 4, 0);
        System.out.println ( "3^4 = " + p);

        p = power (2, 8, 0);
        System.out.println ( "2^8 = " + p);
    }

    static int power (int a, int b, int level) {
        // Print the "level", along with extra blanks.
        System.out.println ( makeBlanks(level) + "Level " + level + ": b=" + b );

        int p;
        if (b == 0) {
            p = 1;
        }
        else {
            p = a * power (a, b-1, level+1);
        }

        // Now the result.
        System.out.println ( makeBlanks(level) + "Level " + level + ": p=" + p );
        return p;
    }

    static String makeBlanks (int n) {
        String str = "";
        for (int i=0; i < n; i++) {
            str += " ";
        }
        return str;
    }

}

Here's some of the output, annotated:

Level 0: b=2        // The first time power() is called a=3, b=2
 Level 1: b=1       // First recursion: b=1
  Level 2: b=0      // Second recursion: b=0
  Level 2: p=1      // No further recursion: "bottom out" case when b=0
 Level 1: p=3       // Back to first recursion, when b=1
Level 0: p=9        // Back to first call, when b=2
3^2 = 9
Level 0: b=4        // First time power() is called with a=3, b=4
 Level 1: b=3       // First recursion: b=3
  Level 2: b=2      // Second recursion: b=2
   Level 3: b=1     // Third recursion: b=1
    Level 4: b=0    // Fourth recursion: b=0
    Level 4: p=1    // Bottom out and return p=1
   Level 3: p=3     // Back to 3rd recursion, compute p=3
  Level 2: p=9      // Back to 2nd recursion, compute p=9
 Level 1: p=27      // Back to 1st recursion, p=27 now
Level 0: p=81       // Back to original call, p=81 (final result to return to main)
3^4 = 81
To see this a little differently, let's examine the state of memory during these calls for the 3^2 case.
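One way to picture that state is to mimic the call stack by hand. The following is only a sketch for illustration, not a description of how Java actually lays out memory: the class name PowerStackSketch and the use of ArrayDeque are assumptions, and each pushed value stands in for one pending call's own copy of b.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class PowerStackSketch {

    // Hypothetical helper: computes a^b by imitating the call stack by hand.
    static int power (int a, int b) {
        Deque<Integer> pending = new ArrayDeque<>();

        // "Going down": each recursive call gets its own copy of b.
        while (b > 0) {
            pending.push (b);
            b --;
        }

        // Bottom-out case: b == 0 gives p = 1.
        int p = 1;

        // "Coming back up": each pending call multiplies the returned value by a.
        while (! pending.isEmpty()) {
            int savedB = pending.pop ();
            p = a * p;
            System.out.println ("Returning from call with b=" + savedB + ": p=" + p);
        }
        return p;
    }

    public static void main (String[] argv) {
        System.out.println ("3^2 = " + power (3, 2));
    }

}
```

The values come off the stack in the reverse of the order in which they were pushed, which is the same order in which the recursive calls return in the annotated output above.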
Next, we will point out that the original example with recursion could have been written more "tightly" (compactly) as follows:

public class PowerExample3 {

    public static void main (String[] argv) {
        // ...
    }

    static int power (int a, int b) {
        if (b == 0) {
            return 1;
        }
        return (a * power (a, b-1));
    }

}
Now for another example:
public class FactorialExample {

    public static void main (String[] argv) {
        System.out.println ( "3! = " + factorial(3) );
        System.out.println ( "5! = " + factorial(5) );
        System.out.println ( "5! x 3! = " + (factorial(3) * factorial(5)) );
    }

    static int factorial (int n) {
        if (n == 1) {
            return 1;
        }
        return ( n * factorial(n-1) );
    }

}
We need to point out one more important idea:
static int power (int a, int b) {
    // ...
}

In this case, b is the parameter that changes and controls the recursion.

public class PowerExample4 {

    // Declare "a" as global to be set in main() and accessed elsewhere.
    static int a;

    public static void main (String[] argv) {
        a = 3;
        int p = power (2);      // Only "b" is passed as parameter.
        System.out.println ( "3^2 = " + p);

        p = power (4);
        System.out.println ( "3^4 = " + p);

        a = 2;
        p = power (8);
        System.out.println ( "2^8 = " + p);
    }

    // Only one parameter.
    static int power (int b) {
        int p;
        if (b == 0) {
            p = 1;
        }
        else {
            p = (a * power (b-1));
        }
        System.out.println ("Intermediate result: " + a + "^" + b + "=" + p);
        return p;
    }

}
Next, let's see what happens when b is a global variable:

public class PowerExample5 {

    // Declare "a" as global to be set in main() and accessed elsewhere.
    static int a;

    // Now try "b"
    static int b;

    public static void main (String[] argv) {
        a = 3;
        b = 2;
        int p = power ();       // No parameters.
        System.out.println ( "3^2 = " + p);

        b = 4;
        p = power ();
        System.out.println ( "3^4 = " + p);

        a = 2;
        b = 8;
        p = power ();
        System.out.println ( "2^8 = " + p);
    }

    // No parameter.
    static int power () {
        int p;
        if (b == 0) {
            p = 1;
        }
        else {
            b = b - 1;
            p = a * power();
        }
        System.out.println ("Intermediate result: " + a + "^" + b + "=" + p);
        return p;
    }

}
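Tracing PowerExample5 by hand suggests that the returned value is still correct, but all the calls share a single b. By the time any call prints its intermediate result, b has already been counted down to 0, and it stays 0 after power() returns. The sketch below uses the same logic with the printing removed to show that side effect; the class name GlobalExponentSketch is hypothetical.

```java
class GlobalExponentSketch {

    static int a;
    static int b;

    // Same logic as power() in PowerExample5, without the printing.
    static int power () {
        if (b == 0) {
            return 1;
        }
        b = b - 1;
        return a * power ();
    }

    public static void main (String[] argv) {
        a = 3;
        b = 2;
        int p = power ();
        // The shared b has been counted down by the recursion.
        System.out.println ("p=" + p + ", b after the call=" + b);
    }

}
```

With a parameter, each call keeps its own copy of b, so the caller's value is untouched. With a global, the caller has to reset b before every use, as main() does in PowerExample5.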
We know how to loop through an array to search for a particular value.
Let's see how this can be done using recursion:

public class ArraySearch {

    public static void main (String[] argv) {
        // Fill an array with some random values - for testing.
        int[] testData = makeRandomArray (10);

        // A random search term.
        int searchTerm = UniformRandom.uniform (1, 100);

        // Call the recursive search method.
        boolean found = search (testData, searchTerm, 0);
        System.out.println ("found=" + found);
    }

    static boolean search (int[] A, int value, int index) {
        // Two "bottom out" cases:
        if (index >= A.length) {
            return false;
        }
        if (A[index] == value) {
            return true;
        }

        // Else try further into the array:
        return search (A, value, index+1);
    }

    static int[] makeRandomArray (int length) {
        // ... we've seen this before ...
    }

}
We'll use the following idea to check whether a string is a palindrome:
public class Palindrome {

    public static void main (String[] argv) {
        String str = "redder";
        System.out.println ( str + " " + checkPalindrome(str) );

        str = "river";
        System.out.println ( str + " " + checkPalindrome(str) );

        str = "neveroddoreven";
        System.out.println ( str + " " + checkPalindrome(str) );
    }

    static String checkPalindrome (String str) {
        // Two bottom out cases:
        if ( (str.length() == 0) || (str.length() == 1) ) {
            return "is a palindrome";
        }
        if ( str.charAt(0) != str.charAt(str.length()-1) ) {
            return "is not a palindrome";
        }

        // First and last letters matched. Remove them and check remaining recursively.
        String nextStr = str.substring (1, str.length()-1);
        return checkPalindrome (nextStr);
    }

}
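The same idea can be expressed without building a shorter string at each step, by letting two index parameters control the recursion, much as index does in ArraySearch. This is a sketch of one possible variation, not a replacement for the version above; the class name PalindromeIndexSketch and the method isPalindrome are hypothetical.

```java
class PalindromeIndexSketch {

    // Hypothetical variant: compare the characters at positions left and right,
    // moving inward, instead of creating a substring for each call.
    static boolean isPalindrome (String str, int left, int right) {
        // Bottom-out case: zero or one character left to examine.
        if (left >= right) {
            return true;
        }
        // Bottom-out case: the outer letters don't match.
        if (str.charAt(left) != str.charAt(right)) {
            return false;
        }
        // Outer letters matched: check what lies between them.
        return isPalindrome (str, left+1, right-1);
    }

    public static void main (String[] argv) {
        String[] tests = {"redder", "river", "neveroddoreven"};
        for (String s : tests) {
            System.out.println (s + " " + isPalindrome (s, 0, s.length()-1));
        }
    }

}
```

Here left and right together play the role that the shrinking string plays in checkPalindrome(): each call works on a smaller problem until one of the bottom-out cases applies.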
Consider this problem of seating people in a row of seats:
1 2 _
1 _ 2
2 1 _
2 _ 1
_ 1 2
_ 2 1
First, let's solve the problem of merely counting the number of such arrangements:

public class PermutationSeating {

    public static void main (String[] argv) {
        int numSeats = 3;       // M
        int numPeople = 2;      // K
        int n = countPermutations (numSeats, numPeople);
        System.out.println (numPeople + " can sit on " + numSeats + " seats in " + n + " different arrangements");
    }

    static int countPermutations (int numSpaces, int numRemaining) {
        // Given numRemaining to assign among numSpaces seats,
        // count all possible arrangements.

        // Bottom out case: if none are remaining, there's only one way to do that.
        if (numRemaining == 0) {
            return 1;
        }

        // Otherwise, obtain the count for a smaller version of the problem.
        int n = countPermutations (numSpaces-1, numRemaining-1);

        // Since one person can choose from among numSpaces, there are
        // numSpaces ways of doing that, which we multiply with n.
        return (numSpaces * n);
    }

}
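Unrolling the recursion shows that countPermutations(M, K) multiplies K factors: M * (M-1) * ... * (M-K+1). The sketch below writes that product as a loop so the two can be compared; the class name PermutationCountCheck and the method countByProduct are hypothetical names introduced for this check.

```java
class PermutationCountCheck {

    // Recursive count, as in PermutationSeating.
    static int countPermutations (int numSpaces, int numRemaining) {
        if (numRemaining == 0) {
            return 1;
        }
        return numSpaces * countPermutations (numSpaces-1, numRemaining-1);
    }

    // The same product written as a loop: M * (M-1) * ... * (M-K+1).
    static int countByProduct (int numSeats, int numPeople) {
        int n = 1;
        for (int i=0; i < numPeople; i++) {
            n = n * (numSeats - i);
        }
        return n;
    }

    public static void main (String[] argv) {
        System.out.println ("M=3, K=2: " + countPermutations (3, 2) + " vs " + countByProduct (3, 2));
        System.out.println ("M=5, K=2: " + countPermutations (5, 2) + " vs " + countByProduct (5, 2));
    }

}
```

For M=3, K=2 both give 3 * 2 = 6, which matches the six arrangements listed earlier.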
Now let's modify the program to print the actual permutations:

import java.util.*;

public class PermutationSeating2 {

    // We'll use a global to count the number of permutations.
    static int count;

    public static void main (String[] argv) {
        // Test case 1: M=3, K=2.
        int numSeats = 3;
        int numPeople = 2;
        int[] seats = new int [numSeats];
        count = 0;
        printPermutations (numSeats, numPeople, seats, 1);
        System.out.println (" => " + count + " permutations");

        // Test case 2: M=5, K=2.
        numSeats = 5;
        numPeople = 2;
        seats = new int [numSeats];
        count = 0;
        printPermutations (numSeats, numPeople, seats, 1);
        System.out.println (" => " + count + " permutations");
    }

    static void printPermutations (int numSpaces, int numRemaining, int[] seats, int person) {
        // Bottom-out case. Note that we are printing here, since each time
        // we get here we complete one permutation.
        if (numRemaining == 0) {
            // Print.
            System.out.println ( Arrays.toString(seats) );
            // Remember to increment the number of permutations found.
            count ++;
            return;
        }

        // Otherwise, non-base case: look for an empty spot for "person"
        for (int i=0; i < seats.length; i++) {
            if (seats[i] == 0) {
                // Empty spot.
                seats[i] = person;

                // Recursively assign remaining, starting with person+1
                printPermutations (numSpaces-1, numRemaining-1, seats, person+1);

                // Important: we need to un-do the seating for other trials.
                seats[i] = 0;
            }
        } //end-for
    }

}

Note:
static void printPermutations (int numSpaces, int numRemaining, int[] seats, int person)
// Try a seat.
seats[i] = person;

// Recursively assign remaining, starting with person+1
printPermutations (numSpaces-1, numRemaining-1, seats, person+1);

// Important: now un-do the seating for the next iteration of the loop.
seats[i] = 0;

[1, 2, 0]
[1, 0, 2]
[2, 1, 0]
[0, 1, 2]
[2, 0, 1]
[0, 2, 1]
 => 6 permutations
[1, 2, 0, 0, 0]
[1, 0, 2, 0, 0]
[1, 0, 0, 2, 0]
[1, 0, 0, 0, 2]
[2, 1, 0, 0, 0]
[0, 1, 2, 0, 0]
[0, 1, 0, 2, 0]
[0, 1, 0, 0, 2]
[2, 0, 1, 0, 0]
[0, 2, 1, 0, 0]
[0, 0, 1, 2, 0]
[0, 0, 1, 0, 2]
[2, 0, 0, 1, 0]
[0, 2, 0, 1, 0]
[0, 0, 2, 1, 0]
[0, 0, 0, 1, 2]
[2, 0, 0, 0, 1]
[0, 2, 0, 0, 1]
[0, 0, 2, 0, 1]
[0, 0, 0, 2, 1]
 => 20 permutations
Next, we'll look at another way to handle the un-do part of the program:

import java.util.*;

public class PermutationSeating3 {

    static int count;

    public static void main (String[] argv) {
        // ...
    }

    static void printPermutations (int numSpaces, int numRemaining, int[] seats, int person) {
        if (numRemaining == 0) {
            System.out.println ( Arrays.toString(seats) );
            count ++;
            return;
        }

        // Look for an empty spot.
        for (int i=0; i < seats.length; i++) {
            if (seats[i] == 0) {
                // Make a copy so that we don't need to un-do.
                int[] seatsCopy = copy (seats);

                // Assign and recurse.
                seatsCopy[i] = person;
                printPermutations (numSpaces-1, numRemaining-1, seatsCopy, person+1);

                // Don't need this: seats[i] = 0;
            }
        }
    }

    static int[] copy (int[] A) {
        int[] B = new int [A.length];
        for (int i=0; i < A.length; i++) {
            B[i] = A[i];
        }
        return B;
    }

}
Next we'll solve a slight variation of the problem:
import java.util.*;

public class PermutationSeating4 {

    static int count;

    public static void main (String[] argv) {
        // ...
    }

    static void printPermutations (int numSpaces, int numRemaining, int[] seats, int person) {
        if (numRemaining == 0) {
            // Print.
            System.out.println ( Arrays.toString(seats) );
            count ++;
            return;
        }

        // Look for an empty spot for this person.
        for (int i=0; i < seats.length; i++) {

            // Check for "banned" configuration:
            if (person == 1) {
                if ( (i == 0) || (i == seats.length-1) ) {
                    // Person 1 can't sit at the ends.
                    // Skip to next loop iteration.
                    continue;
                }
            }

            // We reach here if it's ok to explore further.
            if (seats[i] == 0) {
                // Empty spot.
                seats[i] = person;
                printPermutations (numSpaces-1, numRemaining-1, seats, person+1);
                seats[i] = 0;
            }
        }
    }

}
Consider this problem:
public class Manhattan {

    public static void main (String[] argv) {
        // Test case 1: go from (1,1) to (0,0)
        int r = 1, c = 1;
        int n = countPaths (r, c);
        System.out.println ("r=" + r + " c=" + c + " => n=" + n);

        // Test case 2: go from (2,2) to (0,0)
        r = 2;
        c = 2;
        n = countPaths (r, c);
        System.out.println ("r=" + r + " c=" + c + " => n=" + n);

        // Test case 3: go from (5,7) to (0,0)
        r = 5;
        c = 7;
        n = countPaths (r, c);
        System.out.println ("r=" + r + " c=" + c + " => n=" + n);
    }

    static int countPaths (int numRows, int numCols) {
        // Bottom out case: there's only one way to (0,0).
        // Note: it's || and not &&.
        if ( (numRows == 0) || (numCols == 0) ) {
            return 1;
        }

        // Otherwise, reduce to two sub-problems and add.
        int downCount = countPaths (numRows-1, numCols);
        int rightCount = countPaths (numRows, numCols-1);
        return (downCount + rightCount);
    }

}

Note:
int downCount = countPaths (numRows-1, numCols);
int rightCount = countPaths (numRows, numCols-1);
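As a cross-check on the recursive count, a standard counting argument applies: every path from (r,c) to (0,0) consists of r row-reducing moves and c column-reducing moves in some order, so the number of paths is the binomial coefficient C(r+c, r). The sketch below compares the two; the class name ManhattanCountCheck and the helper binomial are hypothetical names introduced here.

```java
class ManhattanCountCheck {

    // Recursive count, as in Manhattan.
    static int countPaths (int numRows, int numCols) {
        if ( (numRows == 0) || (numCols == 0) ) {
            return 1;
        }
        return countPaths (numRows-1, numCols) + countPaths (numRows, numCols-1);
    }

    // C(n, k) computed incrementally. After step i the value is C(n-k+i, i),
    // itself a whole number, so the integer division is exact.
    static long binomial (int n, int k) {
        long result = 1;
        for (int i=1; i <= k; i++) {
            result = result * (n - k + i) / i;
        }
        return result;
    }

    public static void main (String[] argv) {
        int[][] tests = { {1,1}, {2,2}, {5,7} };
        for (int[] t : tests) {
            int r = t[0], c = t[1];
            System.out.println ("r=" + r + " c=" + c
                                + " recursive=" + countPaths (r, c)
                                + " formula=" + binomial (r+c, r));
        }
    }

}
```

The formula also explains why the bottom-out test uses || rather than &&: once either coordinate reaches 0, only one kind of move remains, so there is exactly one way to finish.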
Let's now modify the code to print out the different paths:
[2,2] -> [1,2] -> [0,2] -> [0,1] -> [0,0]
[2,2] -> [1,2] -> [1,1] -> [0,1] -> [0,0]
[2,2] -> [1,2] -> [1,1] -> [1,0] -> [0,0]
[2,2] -> [2,1] -> [1,1] -> [0,1] -> [0,0]
[2,2] -> [2,1] -> [1,1] -> [1,0] -> [0,0]
[2,2] -> [2,1] -> [2,0] -> [1,0] -> [0,0]

(There are 6 different paths.)

public class Manhattan2 {

    public static void main (String[] argv) {
        // Test case 1:
        int r = 1, c = 1;
        int n = countPaths (r, c, "[1,1]");
        System.out.println ("r=" + r + " c=" + c + " => n=" + n);

        // Test case 2:
        r = 2;
        c = 2;
        n = countPaths (r, c, "[2,2]");
        System.out.println ("r=" + r + " c=" + c + " => n=" + n);
    }

    static int countPaths (int numRows, int numCols, String partialPath) {
        // Bottom out case: this is more complicated now.
        if (numRows == 0) {
            // Make the path across the columns.
            String finalStr = partialPath;
            for (int c=numCols-1; c>=0; c--) {
                finalStr += " -> [0," + c + "]";
            }
            System.out.println (finalStr);
            return 1;
        }
        else if (numCols == 0) {
            // Make the path down rows.
            String finalStr = partialPath;
            for (int r=numRows-1; r>=0; r--) {
                finalStr += " -> [" + r + ",0]";
            }
            System.out.println (finalStr);
            return 1;
        }

        // Otherwise, reduce problem size.

        // Downwards.
        String downpathStr = partialPath + " -> " + "[" + (numRows-1) + "," + numCols + "]";
        int downCount = countPaths (numRows-1, numCols, downpathStr);

        // Rightwards.
        String rightpathStr = partialPath + " -> " + "[" + (numRows) + "," + (numCols-1) + "]";
        int rightCount = countPaths (numRows, numCols-1, rightpathStr);

        // Add the two.
        return (downCount + rightCount);
    }

}
Recursion is not always the best solution:
Sometimes, recursion can be downright wasteful, as we'll see in the next example.
Let's consider a recursive program to compute Fibonacci numbers.
1st Fibonacci number: 0
2nd Fibonacci number: 1
3rd Fibonacci number: 1
4th Fibonacci number: 2
5th Fibonacci number: 3
6th Fibonacci number: 5
7th Fibonacci number: 8
...

public class Fibonacci {

    public static void main (String[] argv) {
        // Test case 1:
        int n = 5;
        int f = fibonacci (n);
        System.out.println ("f(" + n + ") = " + f);

        // Test case 2:
        n = 20;
        f = fibonacci (n);
        System.out.println ("f(" + n + ") = " + f);
    }

    static int fibonacci (int n) {
        // Base cases:
        if (n == 1) {
            return 0;
        }
        else if (n == 2) {
            return 1;
        }

        // f_n = f_{n-1} + f_{n-2}
        int f_n_minus_one = fibonacci (n-1);
        int f_n_minus_two = fibonacci (n-2);
        return f_n_minus_one + f_n_minus_two;
    }

}

Thus, previous terms are computed recursively.
The program can be written more compactly:

public class Fibonacci2 {

    public static void main (String[] argv) {
        // ...
    }

    static int fibonacci (int n) {
        // Base cases rolled into one:
        if (n <= 2) {
            return (n-1);
        }
        return ( fibonacci(n-1) + fibonacci(n-2) );
    }

}
Let us now count how often the recursive method fibonacci() is called:

public class Fibonacci3 {

    // A counter, accessible everywhere.
    static int numCalls;

    public static void main (String[] argv) {
        // Test case 1:
        int n = 5;
        // Need to initialize:
        numCalls = 0;
        int f = fibonacci (n);
        System.out.println ("f(" + n + ") = " + f + " numCalls=" + numCalls);

        // Test case 2:
        n = 20;
        // Need to initialize:
        numCalls = 0;
        f = fibonacci (n);
        System.out.println ("f(" + n + ") = " + f + " numCalls=" + numCalls);
    }

    static int fibonacci (int n) {
        // We are recording the number of times this method is called.
        numCalls ++;

        if (n <= 2) {
            return (n-1);
        }
        return ( fibonacci(n-1) + fibonacci(n-2) );
    }

}
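The number of calls can also be worked out without running fibonacci() itself: a call with n <= 2 counts as one call, and any other call counts as one plus the calls made by its two sub-calls. The sketch below evaluates that recurrence directly; the class name FibonacciCallCount and the method calls are hypothetical names introduced for illustration.

```java
class FibonacciCallCount {

    // Number of calls made by the recursive fibonacci(n) in Fibonacci3:
    // one for the call itself, plus the calls made by its two sub-calls.
    static int calls (int n) {
        if (n <= 2) {
            return 1;
        }
        return 1 + calls (n-1) + calls (n-2);
    }

    public static void main (String[] argv) {
        System.out.println ("calls(5) = " + calls (5));
        System.out.println ("calls(20) = " + calls (20));
    }

}
```

For n = 5 this gives 9 calls, and for n = 20 it gives 13529. The call count satisfies almost the same recurrence as the Fibonacci numbers themselves, so it grows just as quickly, because the same terms are recomputed many times over.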
To fix the problem:
public class Fibonacci4 {

    // Same counter as before.
    static int numCalls;

    // We'll use an array to record previously computed values:
    static int[] fValues;

    public static void main (String[] argv) {
        // Test case 1:
        int n = 5;
        // Initialize both the counter and the array:
        numCalls = 0;
        fValues = new int [n+1];
        int f = fibonacci (n);
        System.out.println ("f(" + n + ") = " + f + " numCalls=" + numCalls);

        // Test case 2:
        n = 20;
        // Initialize both the counter and the array:
        numCalls = 0;
        fValues = new int [n+1];
        f = fibonacci (n);
        System.out.println ("f(" + n + ") = " + f + " numCalls=" + numCalls);
    }

    static int fibonacci (int n) {
        // We are recording the number of times this method is called.
        numCalls ++;

        if (n <= 2) {
            // First time we reach here, we store the values.
            fValues[n] = n-1;
            return (n-1);
        }

        // If the values haven't been computed, then compute and store.
        if (fValues[n-1] == 0) {
            fValues[n-1] = fibonacci(n-1);
        }
        if (fValues[n-2] == 0) {
            fValues[n-2] = fibonacci(n-2);
        }

        // By the time we reach here, previous fib values have been stored.
        fValues[n] = fValues[n-1] + fValues[n-2];
        return fValues[n];
    }

}
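Now the results are reasonable: because each value is stored the first time it is computed, the number of calls grows in proportion to n instead of exponentially.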
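One detail worth noting is that 0 is a legitimate value in this numbering (the 1st Fibonacci number), so an array entry of 0 cannot by itself distinguish "not yet computed" from "computed, and equal to 0". A possible variation, sketched below under that observation, marks unfilled entries with -1 instead. The class name FibonacciMemoSketch, the array name memo, and the wrapper compute are hypothetical.

```java
import java.util.Arrays;

class FibonacciMemoSketch {

    static int numCalls;

    // memo[n] holds the n-th Fibonacci number, or -1 if not computed yet.
    static int[] memo;

    static int fibonacci (int n) {
        numCalls ++;

        if (n <= 2) {
            return n-1;
        }

        // Already computed: return the stored value without recursing.
        if (memo[n] != -1) {
            return memo[n];
        }

        // Compute once, store, and return.
        memo[n] = fibonacci (n-1) + fibonacci (n-2);
        return memo[n];
    }

    // Hypothetical wrapper that resets the counter and the table before each use.
    static int compute (int n) {
        numCalls = 0;
        memo = new int [n+1];
        Arrays.fill (memo, -1);
        return fibonacci (n);
    }

    public static void main (String[] argv) {
        int f = compute (20);
        System.out.println ("f(20) = " + f + " numCalls=" + numCalls);
    }

}
```

The trade-off is one extra initialization step (filling the array with -1) in exchange for a test that does not depend on which Fibonacci values happen to be zero.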
Of course, Fibonacci numbers can be computed without arrays very easily in an iterative manner:

public class Fibonacci5 {

    public static void main (String[] argv) {
        // ...
    }

    static int fibonacci (int n) {
        // Base cases: same as in recursive version.
        if (n == 1) {
            return 0;
        }
        else if (n == 2) {
            return 1;
        }

        // Now we know n >= 3.

        // Start with first and second terms.
        int fPrev = 1;
        int fPrevPrev = 0;
        int f = -1;

        // A simple iteration takes us to the n-th term.
        for (int k=3; k <= n; k++) {
            f = fPrev + fPrevPrev;
            fPrevPrev = fPrev;
            fPrev = f;
        }

        return f;
    }

}