Search In An Ascending Matrix

February 10, 2012

Today’s exercise is taken from our inexhaustible list of interview questions:

Given an m by n matrix of integers with each row and column in ascending order, search the matrix and find the row and column where a key k appears, or report that k is not in the matrix. For instance, in the matrix

 1  5  7  9
 4  6 10 15
 8 11 12 19
14 16 18 21

the key 11 appears in row 2, column 1 (indexing from 0) and the key 13 is not present in the matrix. The obvious algorithm takes time O(m × n) to search the matrix row-by-row, but you must exploit the order in the matrix to find an algorithm that takes time O(m + n).
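Below is a minimal Python sketch of one O(m + n) approach: the walk from the top-right corner that the comments below discuss. It assumes the matrix is a list of equal-length lists and reports a missing key as `None`. The function name `search` is illustrative only.

The walk works like this. If the current element is larger than k, everything below it in its column is larger too, so move left. If it is smaller, everything to its left in its row is smaller too, so move down. Each comparison discards one row or one column, so at most m + n - 1 elements are examined.

```python
def search(matrix, k):
    """Return (row, col) of k in a row- and column-ascending matrix, or None."""
    if not matrix or not matrix[0]:
        return None
    row, col = 0, len(matrix[0]) - 1   # start at the top-right corner
    while row < len(matrix) and col >= 0:
        x = matrix[row][col]
        if x == k:
            return row, col
        if x > k:
            col -= 1   # everything below x in this column is also > k
        else:
            row += 1   # everything left of x in this row is also < k
    return None


M = [[1, 5, 7, 9],
     [4, 6, 10, 15],
     [8, 11, 12, 19],
     [14, 16, 18, 21]]

print(search(M, 11))  # (2, 1)
print(search(M, 13))  # None
```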

Your task is to write the requested search function. When you are finished, you are welcome to read or run a suggested solution, or to post your own solution or discuss the exercise in the comments below.


20 Responses to “Search In An Ascending Matrix”

  1. Not optimal, but has the required complexity.

    M = [ [  1,  5,  7,  9 ],
          [  4,  6, 10, 15 ],
          [  8, 11, 12, 19 ],
          [ 14, 16, 18, 21 ] ]
    
    def find(m, v):
        if len(m) > 0 and len(m[0]) > 0:
            return find_r(m, v, 0, 0, len(m) - 1, len(m[0]) - 1)
        return None
    
    def find_r(m, v, r1, c1, r2, c2):
        if r1 > r2 or c1 > c2:
            return None
        if not (m[r1][c1] <= v <= m[r1][c2]):
            return find_r(m, v, r1 + 1, c1, r2, c2)
        if not (m[r2][c1] <= v <= m[r2][c2]):
            return find_r(m, v, r1, c1, r2 - 1, c2)
        if not (m[r1][c1] <= v <= m[r2][c1]):
            return find_r(m, v, r1, c1 + 1, r2, c2)
        if not (m[r1][c2] <= v <= m[r2][c2]):
            return find_r(m, v, r1, c1, r2, c2 - 1)
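        # The 2nd and 3rd tests together give m[r2][c1] <= v <= m[r2][c1],
        # so the key sits at (r2, c1); (r1, c1) is only correct when r1 == r2.
        return (r2, c1)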
    
    print(find(M, 11))
    print(find(M, 13))
    
  2. Arthur said

    “If the key k is greater than the current element then it cannot be present in the current row”

    I’m sorry, I don’t understand that. Let’s look for the key 5 in the example. We start with element 1 and the condition quoted applies. So we go down. But wait, 5 can actually be found in the row in question.
    Hmm…

  3. Arthur, you apparently missed that programmingpraxis’s algorithm starts at the top-right corner, which is 9 in this example.

  4. Arthur said

    Yup, that’s it. Thanks. I was starting at the top-left.

  5. phillip said

    A while loop might be better here.
    matrixsearch

  6. ardnew said

    Clever algorithm, OP. Haven’t seen this one before.

    use strict;
    use warnings;
    
    sub search
    {
      my ($t, $k, $m, $n) = @_;
        
      return ( ) if $m == scalar @$t || $n < 0;      
      return ($m, $n) if $k == $$t[$m][$n];  
      return search($t, $k, 
        $m + ($k > $$t[$m][$n]), 
        $n - ($k < $$t[$m][$n]));
    }
    
    #
    #  Usage example
    #
    my @t = ([  1,  5,  7,  9 ],
             [  4,  6, 10, 15 ],
             [  8, 11, 12, 19 ],
             [ 14, 16, 18, 21 ]);
    
    my $n = scalar @t;                       # number of rows
    my $m = $n > 0 ? scalar @{$t[0]} : 0;    # number of columns
    
    while (<STDIN>)
    {
      chomp;
      my @c = search(\@t, $_, 0, $m - 1);
      print scalar @c ? "@c\n" : "not found\n";
    }
    
  7. ardnew said

    Correct me if I’m wrong, but all solutions posted by others (including OP) only move up or down, but not both, per iteration. The algorithms will not move both left and right in any one iteration.

    The solution I posted can move both left and right if the conditions allow it.

    Is there a reason you only move one direction per iteration?

  8. ardnew said

    Whoops, got my directions all mixed up.

    All other algorithms move either left or down, but not both, per iteration. The algorithms will not move both left and down in any one iteration.

    The solution I posted can move left and down simultaneously if the conditions allow it.

    Is there a reason you only move one direction per iteration?

  9. Yogesh said

    def find(m, k):
        h = 0
        for i in m:
            v = 0
            if k <= i[-1]:
                for j in i:
                    if j == k: return h, v
                    v += 1
            h += 1
        return "Not found"

    #Usage Example
    M = [ [ 1, 5, 7, 9 ],
          [ 4, 6, 10, 15 ],
          [ 8, 11, 12, 19 ],
          [ 14, 16, 18, 21 ] ]
    print(find(M, 18))

    I am pretty new to this site and don’t know if this is how we post answers, so please comment if it’s not.

  10. Mike said

    The same idea as ProgrammingPraxis’ solution, except I happened to start at the lower-left corner instead of the upper-right.

    
    def index(m, k):
        """Return indices of key 'k' in ordered matrix 'm'.
        
        >>> m = [[ 1, 5, 7, 9],
        ...      [ 4, 6,10,15],
        ...      [ 8,11,12,19],
        ...      [14,16,18,21]]
        >>> index(m,11)
        (2, 1)
        >>> index(m,21)
        (3, 3)
        >>> index(m,17)
        >>>
        """
    
        row_nos = iter(range(len(m)-1, -1, -1))
        col_nos = iter(range(len(m)))
    
        for row in row_nos:
            for col in col_nos:
                if m[row][col] > k:
                    break
                
                if m[row][col] == k:
                    return row, col
            
            else:
                break
            
        return None
    
    if __name__ == "__main__":
        import doctest
        doctest.testmod()
    
    
  11. Mike said

    Sorry, there’s a typo in line 17. Here’s the corrected code:

    
    def index(m, k):
        """Return indices of key 'k' in ordered matrix 'm'.
        
        >>> m = [[ 1, 5, 7, 9,13],
        ...      [ 4, 6,10,15,17],
        ...      [ 8,11,12,19,20],
        ...      [14,16,18,21,22]]
        >>> index(m,11)
        (2, 1)
        >>> index(m,21)
        (3, 3)
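        >>> index(m,17)
        (1, 4)
        >>> index(m,6)
        (1, 1)
        >>> index(m,3)
        >>>
        """
    
        # Start in the lower-left corner: moving up decreases the value,
        # moving right increases it.
        row, col = len(m) - 1, 0
        while row >= 0 and col < len(m[0]):
            if m[row][col] > k:
                row -= 1    # the rest of this row is larger still
            elif m[row][col] < k:
                col += 1    # everything above in this column is smaller still
            else:
                return row, col
            
        return None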
    
    if __name__ == "__main__":
        import doctest
        doctest.testmod()
    
    
  12. Not optimal, but wanted to post what I got:

    object AscendMatrixSearch extends App {
      val mat = List(List(1, 5, 7, 9), List(4, 6, 10, 15), List(8, 11, 12, 19), List(14, 16, 18, 21))
      def find(v: Int, mat: List[List[Int]]) =
        for (row <- mat if v <= row.last; i <- row if i == v) println("Found " + v)
      find(11, mat)
      find(13, mat)
    }

  13. Again, not as optimal as what’s here, but I forgot the indices last time:

    object AscendMatrixSearch extends App {
      val mat = List(List(1, 5, 7, 9), List(4, 6, 10, 15), List(8, 11, 12, 19), List(14, 16, 18, 21))
      def find(v: Int, mat: List[List[Int]]) =
        for (i <- 0 until mat.length; row = mat(i); j <- 0 until row.length if row(j) == v)
          println("Found " + v + " at (" + i + ", " + j + ")")
      find(15, mat)
      find(11, mat)
      find(33, mat)
      find(-24, mat)
    }

  14. Forgot the filter on rows, sorry for three posts (this is Scala code; is that allowed here?)

    object AscendMatrixSearch extends App {
      val mat = List(List(1, 5, 7, 9), List(4, 6, 10, 15), List(8, 11, 12, 19), List(14, 16, 18, 21))
      def find(v: Int, mat: List[List[Int]]) =
        for (i <- 0 until mat.length if v <= mat(i).last; row = mat(i); j <- 0 until row.length if row(j) == v)
          println("Found " + v + " at (" + i + ", " + j + ")")
      find(15, mat)
      find(11, mat)
      find(33, mat)
      find(-24, mat)
    }

  15.

    def search(k, matrix):
        """ We loop over the diagonal elements looking for the first one greater
            than the number, the pivot. Everything above and to the left of the
            pivot is at most the previous diagonal element, hence smaller, and
            everything from the pivot down and to the right is at least the
            pivot, hence larger. So we only search the two remaining blocks
            (above the pivot's row and right of its column, and from its row
            down, left of its column), each from its top-right corner, which
            keeps the algorithm O(m + n). Assumes a square matrix. """
    
        def block_search(k, r_lo, r_hi, c_lo, c_hi):
            # Search rows [r_lo, r_hi) x columns [c_lo, c_hi), starting at the
            # block's top-right corner and dropping one row or column per step.
            r, c = r_lo, c_hi - 1
            while r < r_hi and c >= c_lo:
                if k == matrix[r][c]:
                    return r, c
                elif k > matrix[r][c]:
                    r += 1
                else:
                    c -= 1
            return -1
    
        n = len(matrix)
        for i in range(n):
            pivot = matrix[i][i]
            if k == pivot:
                return i, i
            elif k < pivot:
                if i == 0:
                    # upper left element bigger than searched number,
                    # hence all elements of the matrix bigger also
                    # therefore not present
                    return -1
                # matrix[i - 1][i - 1] < k < matrix[i][i]
                upper = block_search(k, 0, i, i, n)
                if upper != -1:
                    return upper
                return block_search(k, i, n, 0, i)
        # k is larger than every diagonal element, hence not present
        return -1
    
    # Tests
    matrix = [[1, 5, 7, 9],
              [4, 6, 10, 15],
              [8, 11, 12, 19],
              [14, 16, 18, 21]]
    
    assert search(11, matrix) == (2, 1)
    assert search(5, matrix) == (0, 1)
    assert search(1, matrix) == (0, 0)
    assert search(22, matrix) == -1
    assert search(13, matrix) == -1
    assert search(14, matrix) == (3, 0)
    assert search(21, matrix) == (3, 3)
    assert search(9, matrix) == (0, 3)
    assert search(15, matrix) == (1, 3)
    
  16. DGel said

    My attempt at a Haskell solution. It could probably be done much more succinctly, but oh well.

    module Main where
    import Data.Array
    
    dat = listArray ((0,0),(3,3)) [1,5,7,9,4,6,10,15,8,11,12,19,14,16,18,21]
    
    find m x = find' 0 (snd (snd (bounds m))) m x
    
    find' i j m x 
        | i < 0 = Nothing
        | j < 0 = Nothing
        | i > (fst . snd . bounds $ m) = Nothing
        | j > (snd . snd . bounds $ m) = Nothing
        | x == (m ! (i,j)) = Just (i,j)
        | x < (m ! (i,j)) = find' i (j-1) m x
        | x > (m ! (i,j)) = find' (i+1) j m x
    
    main = do
        print $ find dat 11
        print $ find dat 13
    
  17. dmitru said

    My recursive solution:

    def find(M, key):
        cols = len(M[0])
        def f(i, j):
            if i < 0 or j >= cols:
                return False
            if key < M[i][j]:
                return f(i - 1, j)
            if key > M[i][j]:
                return f(i, j + 1)
            return True
        return f(len(M) - 1, 0)
    
    m = [[1, 5, 7, 9], [4, 6, 10, 15], [8, 11, 12, 19], [14, 16, 18, 21]]
    
    # Tests:
    print(find(m, 9))     # True
    print(find(m, 12))    # True
    print(find(m, 21))    # True
    print(find(m, 1000))  # False
    print(find(m, 17))    # False
    
  18. adeepak said

    My attempt in Python. The basic flow is:
    Scan a column from top to bottom and stop when you hit a higher value.
    Then move one column to the right and try to find a hit there.
    Keep repeating these two steps; you will eventually snake your way through the matrix.

    There is room to optimize this further.

    
    def scan_col(x, y):
        ''' Walk a column until you find a matching entry, OR
            - stop at the closest entry < matching value
            - stop at the end of the column
            Request for the next pos to start scan
        '''        
        for i in xrange(x, row):
            if (matrix[i][y] == val):
                print "Value Found at [%d][%d]" % (i , y)
                return (row, col)
            if (matrix[i][y] > val):
                i = i - 1;            
                break;
        return (i, y)
      
    def find_pos(x, y):
        ''' Find the next position to start the next scan.
            First go right from current location. 
            If the position is higher than value, find the closest lesser value (up)
            If you reach the top and still higher, exit. No match possible. 
        '''
        if (y == col - 1):
            print "Entry does not exist"
            return (row, col)
        y = y + 1
        for i in xrange(x, -1, -1):
            if (matrix[i][y] <= val):
                return (i, y)
        print "Entry does not exist"
        return (row, col)
    
    matrix = [ [1, 5, 7, 9],
               [4, 6, 10, 15],
               [8, 11, 12, 19],
               [14, 16, 18, 21],
              ]
    
    row = len(matrix)
    col = len(matrix[0])
    val = 9
    
    print matrix
    a, b, n = 0, 0, 0
    end = 0
    while (1):
        print "Start %d scan at:     [%d] [%d]" % (n, a, b)
        a, b = scan_col(a, b)
        if (a == row and b == col):
            break;
        print "Request find_pos at: [%d] [%d]" % (a, b)
        a, b = find_pos(a, b)
        if (a == row and b == col):
            break;
        n += 1
    
    
  19. Carl said

    In Racket:

    (define matrix '#(#(1 5 7 9)
                      #(4 6 10 15)
                      #(8 11 12 19)
                      #(14 16 18 21)))

    (define matrix-rows vector-length)

    (define (matrix-cols m)
      (vector-length (vector-ref m 0)))

    (define (matrix-ref m r c)
      (vector-ref (vector-ref m r) c))

    (define (find-in-ascending-matrix num)
      (let loop ((r0 0)
                 (r1 (sub1 (matrix-rows matrix)))
                 (c0 0)
                 (c1 (sub1 (matrix-cols matrix))))
        (if (and r0 r1 c0 c1 (>= r1 r0) (>= c1 c0))
            (if [= num (matrix-ref matrix r0 c0)]
                (vector r0 c0)
                (loop
                 (for/first ([i (in-range r0 (add1 r1))]
                             #:when (>= (matrix-ref matrix i c1) num))
                   i)
                 (for/first ([i (in-range r1 (sub1 r0) -1)]
                             #:when (= (matrix-ref matrix r1 i) num))
                   i)
                 (for/first ([i (in-range c1 (sub1 c0) -1)]
                             #:when (<= (matrix-ref matrix r0 i) num))
                   i)
                 ))
            #f)))

  20. Carl said

    Just noticed the method to post source:

    (define matrix '#(#(1  5   7  9)
                      #(4  6   10 15)    
                      #(8  11  12 19)                 
                      #(14 16  18 21)))
    
    (define matrix-rows vector-length)
    
    (define (matrix-cols m)
      (vector-length (vector-ref m 0)))
    
    (define (matrix-ref m r c)
      (vector-ref (vector-ref m r) c))
    
    (define (find-in-ascending-matrix num)
      (let loop ((r0 0) 
                 (r1 (sub1 (matrix-rows matrix)))
                 (c0 0)
                 (c1 (sub1 (matrix-cols matrix))))
        (printf "~a ~a  ~a ~a\n" r0 r1 c0 c1)
        (if (and r0 r1 c0 c1 (>= r1 r0) (>= c1 c0))
            (if [= num (matrix-ref matrix r0 c0)]
                (vector r0 c0)
                (loop
                 (for/first ([i (in-range r0 (add1 r1))]
                             #:when (>= (matrix-ref matrix i c1) num))
                   i)
                 (for/first ([i (in-range r1 (sub1 r0) -1)]
                            #:when (<= (matrix-ref matrix i c0) num))
                   i)
                 (for/first ([i (in-range c0 (add1 c1))]
                             #:when (>= (matrix-ref matrix r1 i) num))
                   i)
                 (for/first ([i (in-range c1 (sub1 c0) -1)]
                            #:when (<= (matrix-ref matrix r0 i) num))
                   i)
                 ))
            #f)))
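
Two properties of the top-right starting corner settle the discussion in comments 2 to 4 and 7 to 8.

First, the choice of corner. From the top-right, moving left decreases the value and moving down increases it. One comparison with k therefore always rules out a whole column (if the element is larger than k) or a whole row (if it is smaller). From the top-left corner, both neighbors are at least as large as the current element, so a comparison cannot choose between them.

Second, each step must move in only one direction. When k is larger, only the current row is ruled out, and k may still sit directly below. Searching for 15 from the 9 in the corner, for example, moves straight down to it. This also holds for the Perl version: its `$k > ...` and `$k < ...` tests are mutually exclusive, so it too takes one step per call.

The minimal Python sketch below illustrates both points and the resulting bound of m + n - 1 comparisons. The `walk` helper, which records the elements it visits, is hypothetical and not taken from any posted solution.

```python
def walk(matrix, k):
    """Hypothetical helper: top-right-corner walk that records visited elements.

    Returns (position or None, list of visited elements)."""
    visited = []
    row, col = 0, len(matrix[0]) - 1
    while row < len(matrix) and col >= 0:
        x = matrix[row][col]
        visited.append(x)
        if x == k:
            return (row, col), visited
        if x > k:
            col -= 1   # rules out exactly one column
        else:
            row += 1   # rules out exactly one row
    return None, visited


M = [[1, 5, 7, 9],
     [4, 6, 10, 15],
     [8, 11, 12, 19],
     [14, 16, 18, 21]]

print(walk(M, 5))    # ((0, 1), [9, 7, 5])
print(walk(M, 15))   # ((1, 3), [9, 15])

# No key needs more than m + n - 1 comparisons.
m, n = len(M), len(M[0])
print(max(len(walk(M, k)[1]) for k in range(23)), m + n - 1)   # 7 7
```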
    
