Binary search without division, shifts or multiplications

I subscribed to Daily Coding Problem a while ago on a friend's recommendation. I only have the free version, just for fun: you get one problem per day, with no solution (unless you pay), and the difficulty varies.

Last week I got a problem that asks you to implement a search function that runs in \(O(\log n)\) time over a sorted array without using division, shifts, or multiplications.

The problem is fun, and at least in my case, my first intuition was wrong. I will go through my solutions and compare them.

Without the restriction

If we didn’t have the restriction on operations, we would implement a binary search and be done with it. I wrote a binary search so we can compare the traditional solution to our other proposals.

def bsearch(arr, n):
    i, j = 0, len(arr) - 1
    while i < j:
        mid = (i + j) // 2
        if arr[mid] == n:
            return mid
        if arr[mid] < n:
            i = mid + 1
        else:
            j = mid
    if arr[i] == n:
        return i
    
    return -1

This is standard binary search, runs in \(O(\log n)\) time, and returns the position where the element is in the array or -1 if the element is not present.

First approach

As I mentioned before, my first intuition was not quite right. I thought of iterating over exponential searches. Exponential search consists of two stages:

  1. A procedure that limits the range where we want to search.
  2. A binary search in the range we determined the element may be from step 1.

The way you determine the range to restrict the search to is by advancing through the array, usually doubling the last position, and stopping as soon as you find an element that is bigger than the one you are looking for. At that point, you know the element is between your previous checkpoint and the position you are looking at now. The pseudo-code looks like this:

prev, pos = 0, 1
while arr[pos] < value: // add check for bounds
    prev, pos = pos, 2*pos
binary_search(arr, prev, pos) // from prev to pos

Note that this allows us to get to the right range in \(O(\log n)\) time and then do a binary search, so the total runtime is \(O(\log n)\). To be more precise, it's \(O(\log \text{pos})\), which makes this search algorithm interesting when the value is close to the beginning, or when the array size is unbounded.
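For reference, the pseudo-code above can be fleshed out into a runnable version (shown only for comparison: it doubles via multiplication, so it would not satisfy the problem's restriction). The names here are my own, and the standard library's bisect stands in for the binary-search stage:

```python
from bisect import bisect_left

def exp_search(arr, value):
    """Plain exponential search: gallop to find a range, then binary search.
    Uses multiplication (2 * pos), so it does NOT satisfy the restriction."""
    if not arr:
        return -1
    prev, pos = 0, 1
    # gallop until we pass the value or run off the end
    while pos < len(arr) and arr[pos] < value:
        prev, pos = pos, 2 * pos
    hi = min(pos, len(arr) - 1)
    # binary search restricted to arr[prev..hi]
    i = bisect_left(arr, value, prev, hi + 1)
    if i <= hi and arr[i] == value:
        return i
    return -1
```
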

The problem is that our binary search implementation uses forbidden operations, so, for this problem, my plan was to run the first stage of exponential search repeatedly, shrinking the segment each time, until the element was found or the segment became empty. The code is as follows:

def iter_exp_search(arr, n):
    add, prev, pos = 1, 0, 0
    end = len(arr) - 1
    while prev <= end:
        if arr[pos] == n:
            return pos
        elif arr[pos] < n:
            # keep galloping: double the step (add + add avoids 2 * add)
            prev = pos + 1
            pos = min(pos + add, end)
            add += add
        elif arr[pos] > n:
            # overshot: restart the gallop inside [prev, pos - 1]
            end = pos - 1
            pos = prev
            add = 1
    return -1

We will verify the implementation later, but for now, we can already see that this does not work. It doesn't use any forbidden operation, but the runtime is \(O(\log^2 n)\). Let's check this: call \(T(n)\) the number of comparisons done by the algorithm, in the worst case, to search an array of \(n = 2^k\) elements. The recurrence is as follows:

\(T(n) = \log n + T(n/2)\)

The first term is the initial iteration searching for the right segment, and the second assumes we end up in the biggest segment. We can rename \(T(n) = T(2^k) = G(k)\) and rewrite the equation:

\(G(k) = k + G(k-1)\)

Let’s write it as:

\(G(i) - G(i-1) = i\)

If we sum over all \(i\) from 1 to \(k\), we get:

\(\sum_{i=1}^k \left(G(i) - G(i-1)\right) = \sum_{i=1}^k i\)

\(G(k) - G(0) = \frac{k(k+1)}{2}\)

We can revert this now to:

\(T(n) = \frac{\log n \times (1 + \log n)}{2}\)

\(T(n) = O(\log^2 n)\)

Fail! We didn't get \(O(\log n)\). Intuitively, this makes sense: exponential search is two \(O(\log n)\) operations, while the iterative exponential search is an \(O(\log n)\) operation repeated roughly \(\log n\) times.
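As a quick sanity check of the derivation, we can compare the recurrence against the closed form numerically (taking \(G(0)=0\), i.e. no comparisons for an empty range):

```python
def g(k):
    # recurrence from the analysis: G(k) = k + G(k-1), with G(0) = 0
    return 0 if k == 0 else k + g(k - 1)

# closed form: G(k) = k * (k + 1) / 2
for k in range(12):
    assert g(k) == k * (k + 1) // 2
print("closed form matches the recurrence")
```
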

Going back to binary search

The solution is much easier. If we look at the binary search, we only use division to compute a midpoint between i and j. We can do this using powers of two instead (it's not the exact average, but good enough):

def bsearch_nodiv(arr, n):
    # build the powers of two up to len(arr) using only additions
    powers = [0, 1]
    next_power = powers[-1] + powers[-1]
    while next_power < len(arr):
        powers.append(next_power)
        next_power = powers[-1] + powers[-1]
    
    i, j, power_index = 0, len(arr) - 1, len(powers) - 1
    while i < j:
        # find the largest power that keeps mid inside the array
        while i + powers[power_index] >= len(arr):
            power_index -= 1
        mid = i + powers[power_index]
        if arr[mid] == n:
            return mid
        if arr[mid] < n:
            i = mid + 1
        else:
            j = mid
        power_index -= 1
    if i < len(arr) and arr[i] == n:
        return i
    
    return -1

This is easier to analyze: the powers array has at most \(O(\log n)\) values, and we visit each one at most once while running the algorithm, so we are safe now.

Checking our implementation

We start by implementing an elementary test. The test generates arrays with the elements from 0 to n-1, and searches for every element (found) and every half value i + 0.5 (not found). To validate our search functions, we will use a reference implementation that exploits the structure of the array:

def arr(n):
    return [i for i in range(n)]
    
def find(arr, n):
    if n == int(n) and n < len(arr):
        return n
    return -1

To check a function we will write a helper:

def check(N, search_fn):
    array = arr(N)
    for i in range(N):
        if search_fn(array, i) != find(array, i):
            print("Error for", i, array, search_fn(array, i), find(array, i))
            return False
        if search_fn(array, i + 0.5) != find(array, i + 0.5):
            print("Error for", i + 0.5, array, search_fn(array, i + 0.5), find(array, i + 0.5))
            return False
    return True

We will also implement a linear search to make timing comparisons more interesting:

def lsearch(arr, n):
    for i, v in enumerate(arr):
        if v == n:
            return i
        if v > n:
            break
    return -1

First, let’s validate our search functions:

for fn in [lsearch, bsearch, bsearch_nodiv, iter_exp_search]:
    print([check(t, fn) for t in [10, 100, 1000, 10000]])

#[True, True, True, True]
#[True, True, True, True]
#[True, True, True, True]
#[True, True, True, True]

Once we know they are doing what we want (or at least we have some evidence), we can measure timing. Here is a set of helper functions to do so:

import time
import numpy as np

def time_search(fn, array, queries):
    reps = 10
    def timer():
        res = 0
        # warm-up pass before we start the clock
        for q in queries[:min(10000, len(queries))]:
            res += fn(array, q)
        t_ini = time.process_time_ns()
        for q in queries:
            res += fn(array, q)
        t_fin = time.process_time_ns()
        return ((t_fin - t_ini) / len(queries) / 1000, res)
    results = [timer() for _ in range(reps)]
    times = [x[0] for x in results]
    checksums = [x[1] for x in results]
    times.sort()
    checksums.sort()
    if checksums[0] != checksums[-1]:
        print("checksums don't match!!")
    # drop the two fastest and two slowest runs (small cheat to make it prettier)
    return np.mean(times[2:-2]), checksums[0]

import random
import matplotlib.pyplot as plt

def plot_data(search_fn, queries, sizes, size_fn):
    x = []
    y = []
    checksum = 0
    for s in range(1, sizes + 1):
        test_size = size_fn(s)
        array = [i for i in range(test_size)]
        query_sample = [v + 0.5 
          for v in random.sample(array, min(test_size, queries))]
        t, r = time_search(search_fn, array, query_sample)
        checksum += r
        x.append(test_size)
        y.append(t)
    return [x, y]

def plot(experiments, queries, sizes, size_fn, fname):
    data = {}
    for search_fn, fn_name in experiments:
        data[fn_name] = plot_data(search_fn, queries, sizes, size_fn)
    
    fig, ax = plt.subplots()
    for fn_name in data:
        ax.plot(data[fn_name][0], data[fn_name][1], label=fn_name)
    ax.legend(loc='upper left', shadow=False, fontsize='small')

    plt.ylabel('Time in microsec')
    plt.xlabel('Number of elements in array')
    plt.savefig(fname)
    plt.show()

Let’s try our functions:

exps = [
    (iter_exp_search, 'iterative exp-search'),
    (bsearch, 'binary search'),
    (bsearch_nodiv, 'binary search (no div)'),
    (lsearch, 'linear search')
]

plot(exps, 10000, 25, lambda x: 40 * x, 'figure.png')

The result looks like this:

Roughly what we expected. Let's drop linear search and see how things behave for bigger inputs:

exps = [
    (iter_exp_search, 'iterative exp-search'),
    (bsearch, 'binary search'),
    (bsearch_nodiv, 'binary search (no div)')
]

plot(exps, 10000, 6, lambda x: 2**(12 + x), 'iter4.png')

Improving our solution

One thing we can notice is that we compute the table of powers of two on every search. We can precompute it instead. The only disadvantage is that we limit the size of the arrays we can search (not an issue, since we can store a table bigger than any practical array size):

powers_arr = [0, 1]
next_power = powers_arr[-1] + powers_arr[-1]
while next_power < 10**9:
    powers_arr.append(next_power)
    next_power = powers_arr[-1] + powers_arr[-1]

def bsearch_nodiv2(arr, n):
    power_index = len(powers_arr) - 1
    i, j = 0, len(arr) - 1
    while i < j:
        while i + powers_arr[power_index] >= len(arr):
            power_index -= 1
        mid = i + powers_arr[power_index]
        if arr[mid] == n:
            return mid
        if arr[mid] < n:
            i = mid + 1
        else:
            j = mid
        power_index -= 1
    if i < len(arr) and arr[i] == n:
        return i
    
    return -1

def bsearch_nodiv3(arr, n):
    # walk up the table until the power reaches the array size
    power_index = 0
    while powers_arr[power_index] < len(arr):
        power_index += 1
    power_index -= 1
    i, j = 0, len(arr) - 1
    while i < j:
        while i + powers_arr[power_index] >= len(arr):
            power_index -= 1
        mid = i + powers_arr[power_index]
        if arr[mid] == n:
            return mid
        if arr[mid] < n:
            i = mid + 1
        else:
            j = mid
        power_index -= 1
    if i < len(arr) and arr[i] == n:
        return i
    
    return -1

One version (nr. 2) starts from the end of the powers table to find the entry matching the array size; the other (nr. 3) starts from the beginning. They favor large and small arrays, respectively. Let's first check them:

for fn in [bsearch_nodiv2, bsearch_nodiv3]:
    print([check(t, fn) for t in [10, 100, 1000, 10000]])

# [True, True, True, True]
# [True, True, True, True]

And now we can plot and check whether we improve with them:

exps = [
    (bsearch, 'binary search'),
    (bsearch_nodiv, 'binary search (no div)'),
    (bsearch_nodiv2, 'binary search (no div) v2'),
    (bsearch_nodiv3, 'binary search (no div) v3')
]

plot(exps, 10000, 6, lambda x: 2**(12 + x), 'iter5.png')

Checking our models

We concluded that the iterative exponential search is \(O(\log^2 n)\) and that the other solutions are \(O(\log n)\) (I'm leaving linear search out; from the plot it looks quite linear ;-)).

We will now fit our data to these models. For that, I'll first write a helper that fits the results to a model to estimate its parameters, and then plots both the original data and the model's predictions.

from scipy.optimize import curve_fit
import numpy as np
from math import log
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

def model_plot(fn, fn_name, queries, sizes, size_fn, model_fn, init_params, file_name):
    x, y = plot_data(fn, queries, sizes, size_fn)
    x, y = np.array(x), np.array(y)
    
    estimated_params, _ = curve_fit(model_fn, x[:-1], y[:-1], p0=init_params)
    print(f'estimated parameters: {estimated_params}')
    
    y_pred = [model_fn(v, *estimated_params) for v in x]
    print(f'mse: {mean_squared_error(y, y_pred)}')
    
    fig, ax = plt.subplots()
    ax.plot(x, y, label=fn_name)
    ax.plot(x, y_pred, label=fn_name+' (predicted)')
    ax.legend(loc='upper left', shadow=True, fontsize='small')

    plt.ylabel('Time in microsec')
    plt.xlabel('Number of elements in array')
    plt.savefig(file_name)
    plt.show()

We will now use it on our solutions, fitting them against a function of the form \(a\log n + b\).

QUERIES = 10000
IT = 14
IT_SIZE_FN = lambda x: 2**(4 + x)

def time_log(x, a, b):
    return a * np.log(x) + b

model_plot(bsearch, 'binary search', QUERIES, IT, 
  IT_SIZE_FN, time_log, [1, 1], 'binary_pred.png')
model_plot(bsearch_nodiv3, 'binary search nodiv3', QUERIES, IT, 
  IT_SIZE_FN, time_log, [1, 1], 'binary_nodiv3_pred.png')
model_plot(iter_exp_search, 'iter exp-search', QUERIES, IT, 
  IT_SIZE_FN, time_log, [1, 1], 'iter_exp_search.png')

The resulting plots are:

The models work quite well for the first two, but not for the third, as expected. The mse for the three experiments is 0.004, 0.007, and 2.283. Let's try the last one with a better model:

def time_logsqr(x, a, b, c):
    return a * np.log(x) * np.log(x) + b * np.log(x) + c

model_plot(iter_exp_search, 'iter exp-search', QUERIES, IT, 
  IT_SIZE_FN, time_logsqr, [1, 1, 1], 'iter_exp_search_logsqr.png')

The result is:

Much better now; the mse is 0.06.

And with that we can wrap this one up :).
