#!/usr/bin/python
#
# rpmdeps
#
# list dependency tree
#
# Copyright (C) 2004,2007 Red Hat, Inc.
# Authors:
# Thomas Woerner <twoerner@redhat.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# version 2004-10-25-02

import sys, os, rpm, types

# ----------------------------------------------------------------------------

def evr_split(evr):
    """Split an "[epoch:]version[-release]" string into its three parts.

    Returns a tuple (epoch, version, release) of strings.
    - The version/release boundary is the first "-". Everything after it,
      including further "-" or ":" characters, is the release.
    - An epoch is recognised only from the first ":" that occurs before that
      "-". A missing epoch, or the literal "(none)", yields "0".
    - A missing release yields "".
    """
    head = evr.split("-", 1)
    if len(head) == 2:
        epoch_version, release = head
    else:
        epoch_version, release = head[0], ""

    epoch = "0"
    fields = epoch_version.split(":", 1)
    if len(fields) == 2:
        if fields[0] != "(none)":
            epoch = fields[0]
        version = fields[1]
    else:
        version = fields[0]

    return (epoch, version, release)

# ----------------------------------------------------------------------------

def evr_compare(evr1, comp, evr2):
    """Check whether the relation "evr1 <comp> evr2" holds.

    evr1, evr2 -- "[epoch:]version[-release]" strings (parsed by evr_split)
    comp       -- relation made of "<", ">" and "=", e.g. "<", "<=", "=",
                  ">=", ">"

    Returns 1 if the relation holds and 0 otherwise. Callers test the
    result against 1.

    If evr2 carries no release, the release of evr1 is ignored. A versioned
    requirement without release therefore only constrains epoch and version.
    """
    e1 = evr_split(evr1)
    e2 = evr_split(evr2)
    if e2[2] == "":
        # requirement without release: compare epoch and version only
        e1 = (e1[0], e1[1], "")
    order = rpm.labelCompare(e1, e2)
    # The old code left res == 1 untouched when evr1 > evr2, so every
    # relation ("<", "=", ...) was reported as satisfied in that case.
    if order < 0:
        satisfied = comp in ("<", "<=")
    elif order == 0:
        satisfied = comp in ("<=", "=", ">=")
    else:
        satisfied = comp in (">", ">=")
    if satisfied:
        return 1
    return 0

# ----------------------------------------------------------------------------

def get_hdr(name):
    """Return a list of rpm headers for name.

    Package file: if name is an existing file, it is read as a package
    file. If reading the header raises rpm.error, one more attempt is made
    with signature verification disabled. The result then holds at most
    one header, and is empty when the file could not be read.

    Installed package: otherwise name is looked up as a package name in the
    rpm database, and all matching installed headers are returned.

    NOTE: the retry calls setVSFlags() on the module-level transaction set
    ts. Signature checking therefore stays disabled for all later reads.
    """
    hdr = [ ]
    if not os.path.isfile(name):
        for h in ts.dbMatch("name", name):
            hdr.append(h)
        return hdr

    fdno = os.open(name, os.O_RDONLY)
    try:
        try:
            hdr.append(ts.hdrFromFdno(fdno))
        except rpm.error:
            # Retry on a fresh descriptor (presumably because the failed read
            # may have moved the file offset), this time without signature
            # verification. The first descriptor used to be leaked here.
            os.close(fdno)
            fdno = -1
            fdno = os.open(name, os.O_RDONLY)
            ts.setVSFlags(rpm._RPMVSF_NOSIGNATURES)
            try:
                hdr.append(ts.hdrFromFdno(fdno))
            except rpm.error:
                pass    # unreadable package file: return an empty list
    finally:
        # also closes the descriptor on the success path (previously leaked)
        if fdno >= 0:
            os.close(fdno)
    return hdr

# ----------------------------------------------------------------------------

def _sense_flags(flags):
    """Return the relation string ("<", ">", "=" combined) encoded in RPMSENSE flags."""
    rel = ""
    if flags & rpm.RPMSENSE_LESS:
        rel += "<"
    if flags & rpm.RPMSENSE_GREATER:
        rel += ">"
    if flags & rpm.RPMSENSE_EQUAL:
        rel += "="
    return rel


class RPM:
    """One package (package file or installed package) and its dependencies.

    Requirements and provides are stored as (name, relation, version)
    tuples, where relation is a string such as "", "=" or ">=".
    """
    def __init__(self, name, hdr):
        self.clear()
        self.set(name, hdr)

    def clear(self):
        """Reset all package attributes to empty values."""
        self.p_file = ""            # name this object was created from (file or package name)
        self.p_name = ""
        self.p_epoch = ""
        self.p_version = ""
        self.p_release = ""
        self.p_arch = ""
        self.p_provides = [ ]       # (name, relation, version) tuples
        self.p_requires = [ ]       # distinct (name, relation, version) tuples
        self.p_requireflags = [ ]   # not filled by set()
        self.p_requireversion = [ ] # not filled by set()
        self.p_resolved = [ ]       # (requirement, providing RPM) tuples
        self.p_unresolved = [ ]     # requirement tuples without a provider yet
        self.p_filenames = [ ]      # file paths contained in the package
        self.p_installed = 0        # 1 if the header has an install TID > 0

    def set(self, name, hdr):
        """Fill the attributes from header hdr and pre-sort the requirements.

        name is stored as p_file. Each distinct requirement is recorded in
        p_requires. Requirements on "rpmlib..." capabilities are not checked
        further.

        A requirement that matches one of the package's own provides exactly
        (name, relation and version), or that names one of its own files, is
        a self dependency. It is added to p_resolved as (requirement, self)
        only when the module-level flag self_deps is 1; otherwise it is
        dropped. All other requirements go to p_unresolved, to be resolved
        later by check_append().
        """
        self.p_file = name
        self.p_name = hdr[rpm.RPMTAG_NAME]
        self.p_epoch = hdr[rpm.RPMTAG_EPOCH]
        if self.p_epoch is None:
            self.p_epoch = "0"
        self.p_version = hdr[rpm.RPMTAG_VERSION]
        self.p_release = hdr[rpm.RPMTAG_RELEASE]
        self.p_arch = hdr[rpm.RPMTAG_ARCH]
        # hdr[tag] can be None for an absent tag (see EPOCH above); treat an
        # absent file list as empty so membership tests below do not fail.
        self.p_filenames = hdr[rpm.RPMTAG_FILENAMES] or [ ]

        # Read each header array once; hdr[tag] is looked up per call.
        names = hdr[rpm.RPMTAG_PROVIDES] or [ ]
        flags = hdr[rpm.RPMTAG_PROVIDEFLAGS]
        versions = hdr[rpm.RPMTAG_PROVIDEVERSION]
        for i in range(len(names)):
            self.p_provides.append((names[i], _sense_flags(flags[i]),
                                    versions[i]))

        names = hdr[rpm.RPMTAG_REQUIRES] or [ ]
        flags = hdr[rpm.RPMTAG_REQUIREFLAGS]
        versions = hdr[rpm.RPMTAG_REQUIREVERSION]
        for i in range(len(names)):
            p = (names[i], _sense_flags(flags[i]), versions[i])
            if p in self.p_requires:
                continue    # duplicate requirement
            self.p_requires.append(p)

            if p[0].startswith("rpmlib"):
                continue    # rpm-internal feature, not a package dependency
            if p in self.p_provides or p[0] in self.p_filenames:
                # satisfied by the package itself
                if self_deps == 1:
                    self.p_resolved.append((p, self))
                continue

            self.p_unresolved.append(p)

        try:
            self.p_installed = int(int(hdr[rpm.RPMTAG_INSTALLTID]) > 0)
        except Exception:
            # presumably the tag is absent (None) for package files
            self.p_installed = 0

    def __cmp__(self, r):
        """Return 0 if r is the same package as self, otherwise 1.

        r may be another RPM or an rpm header. Two packages are the same
        when name and architecture match; the version is not compared. Only
        equality is meaningful here, because the result is never negative.
        """
        if isinstance(r, RPM):
            if self.p_name == r.p_name and self.p_arch == r.p_arch:
                return 0
        elif isinstance(r, rpm.hdr):
            if self.p_name == r[rpm.RPMTAG_NAME] and \
                   self.p_arch == r[rpm.RPMTAG_ARCH]:
                return 0
        return 1
    
# ----------------------------------------------------------------------------

def _hdr_evr(h):
    """Return the "[epoch:]version-release" string of rpm header h."""
    epoch = h[rpm.RPMTAG_EPOCH]
    if epoch is None:
        epoch = ""
    else:
        epoch = "%s:" % epoch
    return "%s%s-%s" % (epoch, h[rpm.RPMTAG_VERSION], h[rpm.RPMTAG_RELEASE])


def _provider_for(h, pkgs, recursive):
    """Return the RPM in pkgs matching header h, or a new RPM built from h.

    A new RPM is appended to pkgs and has its own requirements resolved only
    when recursive is 1.
    """
    for r2 in pkgs:
        if r2.__cmp__(h) == 0:
            return r2
    r2 = RPM(h[rpm.RPMTAG_NAME], h)
    if recursive == 1:
        pkgs.append(r2)
        check_append(r2, pkgs, recursive)
    return r2


def check_append(r, list, recursive):
    """ resolve dependencies against installed and new packages

    r         -- RPM whose r.p_unresolved requirements are processed
    list      -- RPM objects known so far. The parameter name shadows the
                 builtin but is kept for compatibility with callers.
    recursive -- if 1, providers that are not yet in list are appended to
                 it and their own requirements are resolved too

    Each requirement (name, relation, version) is looked up in order:
      1. installed packages providing the capability; any provide of that
         name whose version satisfies the requirement is accepted, and
         otherwise the package's own EVR is checked,
      2. installed packages owning a file of that name,
      3. the other packages in list, matched by package name, provide name
         or file name and checked against the package's EVR.

    Stages 1 and 2 record the last matching installed header. A requirement
    with a provider is removed from r.p_unresolved. Unless the provider is
    the same package as r (same name and arch), it is also appended to
    r.p_resolved as (requirement, provider). Requirements without a
    provider stay in r.p_unresolved.
    """
    i = 0
    while i < len(r.p_unresolved):
        u = r.p_unresolved[i]
        req_name, req_flag, req_version = u
        provider = None

        # 1. installed packages providing the capability
        for h in ts.dbMatch(rpm.RPMTAG_PROVIDES, req_name):
            res = 1
            if req_version != "":   # we have a version requirement
                names = h[rpm.RPMTAG_PROVIDES]
                versions = h[rpm.RPMTAG_PROVIDEVERSION]
                for k in range(len(names)):
                    if names[k] != req_name:
                        continue
                    res = evr_compare(versions[k], req_flag, req_version)
                    if res == 1:
                        break   # any satisfying provide is enough
                if res != 1:
                    res = evr_compare(_hdr_evr(h), req_flag, req_version)
            if res == 1:
                provider = _provider_for(h, list, recursive)

        # 2. installed packages owning a file of that name
        if provider is None:
            for h in ts.dbMatch(rpm.RPMTAG_BASENAMES, req_name):
                # TODO: is this version test needed?
                if req_version == "" or \
                       evr_compare(_hdr_evr(h), req_flag, req_version) == 1:
                    provider = _provider_for(h, list, recursive)

        # 3. new packages (already in the rpm list)
        if provider is None:
            for r2 in list:
                if r == r2:
                    continue
                # p_provides holds (name, relation, version) tuples, so
                # compare against the provide names only.
                if req_name != r2.p_name and \
                       req_name not in r2.p_filenames and \
                       req_name not in [p[0] for p in r2.p_provides]:
                    continue
                if req_version != "":   # we have a version requirement
                    epoch = r2.p_epoch
                    if epoch != "":
                        epoch = "%s:" % epoch
                    evr = "%s%s-%s" % (epoch, r2.p_version, r2.p_release)
                    if evr_compare(evr, req_flag, req_version) != 1:
                        continue
                provider = r2
                break   # first match wins; resolved entry added below

        if provider is None:
            # requirement unresolved
            i += 1
        else:
            r.p_unresolved.remove(u)
            if r != provider:
                r.p_resolved.append((u, provider))

# ----------------------------------------------------------------------------

def usage():
    """Print the command line help, including the program name, to stdout."""
    print("""Usage: %s [-v[v]] [-r] [-nr] [-sd] [<rpm name/package>...]

  -h  | --help                print help
  -v  | --verbose             be more verbose
  -r  | --recursive           search recursively
  -nr | --no-resolved         do not list resolved dependencies
  -sd | --self-dependencies   list self dependencies

This program prints a list of packages for the dependencies.
The option -v enables the listing of each dependency with the corresponding
package. Given twice (-vv), the file name of packages that are not installed
is printed as well.
    """ % sys.argv[0])

# ----------------------------------------------------------------------------
            
# Module-level state: ts is used by get_hdr() and check_append(), and
# self_deps is read by RPM.set(), so these must stay global.
rpms = [ ]

print_resolved = 1
verbose = 0
self_deps = 0
recursive = 0

ts = rpm.ts()

if __name__ == '__main__':
    if len(sys.argv) == 1:
        usage()
        sys.exit(0)
    for arg in sys.argv[1:]:
        if arg in ("-h", "--help"):
            usage()
            sys.exit(0)
        elif arg in ("-r", "--recursive"):
            recursive = 1
        elif arg in ("-nr", "--no-resolved"):
            print_resolved = 0
        elif arg in ("-sd", "--self-dependencies"):
            self_deps = 1
        elif arg in ("-v", "--verbose"):
            verbose += 1
        else:
            # anything else is a package file or an installed package name
            hdrs = get_hdr(arg)
            if len(hdrs) == 0:
                sys.stderr.write(
                    "'%s' neither is a rpm nor is it installed.\n" % arg)
                sys.exit(-1)
            for h in hdrs:
                if h not in rpms:   # same name/arch already listed (RPM.__cmp__)
                    rpms.append(RPM(arg, h))

    # check_append may append to rpms; the loop also visits those entries
    for r in rpms:
        check_append(r, rpms, recursive)

    for r in rpms:
        if len(r.p_unresolved) > 0 or \
               (print_resolved == 1 and len(r.p_resolved) > 0):
            title = "%s-%s-%s" % (r.p_name, r.p_version, r.p_release)
            if r.p_installed == 0 and verbose > 1:
                title += ": %s" % r.p_file
            print(title)

        if print_resolved == 1:
            # resolved dependencies
            pkgs = [ ]
            for dep_req, provider in r.p_resolved:
                pkg = "%s-%s-%s" % (provider.p_name, provider.p_version,
                                    provider.p_release)
                if verbose > 0:
                    dep = dep_req[0]
                    if dep_req[2] != "":
                        dep = "%s %s %s" % (dep_req[0], dep_req[1], dep_req[2])
                    print("\t%s --> %s" % (dep, pkg))
                elif pkg not in pkgs:
                    pkgs.append(pkg)
            if verbose == 0:
                pkgs.sort()
                for pkg in pkgs:
                    print("\t%s" % pkg)

        for dep_name, dep_flag, dep_version in r.p_unresolved:
            dep = dep_name
            if dep_version != "":
                dep = "%s %s %s" % (dep_name, dep_flag, dep_version)
            print("\trequires '%s'" % dep)