#! /usr/bin/python -tt

# Copyright 2007 Red Hat
# Author: James Antill <jantill@redhat.com>

import os
import sys
import glob
import re

import yum
from yum.update_md import UpdateMetadata


def read_netstr(io, lim):
    """ Read one netstr ("<len>:<data>,") from *io* and return <data>.

    io  -- object with a read(n) method returning at most n characters
           ("" at end of input).
    lim -- largest accepted length; longer netstrs are rejected.

    Newlines before the first non-zero length digit are skipped (debugging
    convenience). End of input while reading the length exits the process
    with status 0 (the peer closed the stream).

    Raises ValueError on a non-digit length character, a length above *lim*,
    data truncated by end of input, or a missing trailing ','.
    """
    num = 0
    while True:
        x = io.read(1)
        if x == ':': # End of num for netstr
            break
        if not num and (x == '\n'):
            continue # Skip newlines ... debugging
        if x == "":
            sys.exit(0)
        # x is exactly one character here; only ASCII digits are accepted.
        if x not in "0123456789":
            raise ValueError("Invalid input (not a number: '%s')" % x)
        num = num * 10 + int(x)
        # Check after accumulating so the limit applies to the final length.
        if num > lim:
            raise ValueError("Invalid input (too big)")
    data = io.read(num)
    while len(data) < num:
        tmp = io.read(num - len(data))
        if not tmp:
            # Without this check a short input would loop forever.
            raise ValueError("Invalid input (truncated data)")
        data += tmp
    if io.read(1) != ',':
        raise ValueError("Invalid input (bad length)")
    return data

class FakeStringIO:
    """ Very simple fake read IO from string.

    Supports sequential read(n) and a remaining() count. Reading past the
    end returns a short (possibly empty) string, like a real file at EOF.
    """
    def __init__(self, pbuf):
        self.buf  = pbuf
        self.used = 0
        self.size = len(pbuf)

    def read(self, x):
        """ Return up to *x* characters and advance the read position. """
        beg = self.used
        # Clamp to the end of the buffer: remaining() must never go negative,
        # because callers loop on it.
        end = min(self.used + x, self.size)
        self.used = end
        return str(self.buf[beg:end])

    def remaining(self):
        """ Return the number of characters not yet read. """
        return self.size - self.used

def explode_netstr(data):
    """ Split a string made of concatenated netstrs into their payloads.

    Returns the list of payload strings, in order. Any trailing text that is
    not a complete netstr is handed to read_netstr and handled there (it
    raises ValueError on malformed input and exits at end of input while
    reading a length).
    """
    io = FakeStringIO(data)
    ret = []
    # Compare with 0 rather than testing truthiness, so a negative count
    # (position past the end) stops the loop instead of reading at EOF.
    while io.remaining() > 0:
        ret.append(read_netstr(io, io.remaining()))
    return ret

def read_next_cmd(io = sys.stdin):
    """ Generate commands from *io* (stdin by default) forever.

    Each command arrives as one outer netstr of at most 4 MiB whose payload
    is itself a run of netstrs; every yielded value is the list of those
    inner payloads.
    """
    limit = 4 * 1024 * 1024
    while 1:
        outer = read_netstr(io, limit)
        yield explode_netstr(outer)

def str_netstr1(val):
    """ Encode a single value as a netstr: "<len>:<str(val)>,". """
    text = str(val)
    return str(len(text)) + ":" + text + ","
def str_netstr(*args, **kwords):
    """ Encode each of *args* as a netstr and concatenate the results.

    Keyword combine (default True): when true, the concatenation is wrapped
    in one more netstr so the whole record is a single item; when false,
    the bare concatenation is returned. Other keywords are ignored.
    """
    # Previously read the misspelled name "kwrds", which raised NameError
    # whenever combine= was passed.
    combine = kwords.get('combine', True)
    payload = "".join([str_netstr1(val) for val in args])
    if combine:
        payload = str_netstr1(payload)
    return payload

# End of NETSTR functions...

def package_list(x):
    """ Return yum's package lists narrowed by *x* (e.g. 'installed').

    A fresh YumBase is created with empty localPackages/updates, configured
    with plugins disabled, and its doPackageLists(x) result is returned.
    """
    base = yum.YumBase()

    base.localPackages = []
    base.updates = []

    base.doConfigSetup(init_plugins=False)

    holder = base.doPackageLists(x)
    return holder
    
def installed_lister():
    """ Generate the installed package objects one at a time. """
    holder = package_list('installed')
    for po in holder.installed:
        yield po
    
def available_lister():
    """ Generate the available package objects one at a time. """
    holder = package_list('available')
    for po in holder.available:
        yield po
    
def extras_lister():
    """ Generate the package objects yum reports as 'extras'. """
    holder = package_list('extras')
    for po in holder.extras:
        yield po
    
def updates_lister(full_list, sec_only):
    """ Yield (package, metadata dict) pairs for the available updates.

    full_list -- when false (and sec_only is false) the repos' update
                 metadata is not loaded: every update is paired with a dummy
                 {'type' : 'unknown'} dict.
    sec_only  -- only yield updates whose metadata 'type' is 'security'.
                 This always needs the update metadata, whatever full_list.

    Otherwise each update is paired with its notice's get_metadata() dict,
    or with {'type' : 'unknown'} when no notice was found.
    """
    ygh = package_list('updates')

    if not full_list and not sec_only:
        for pkg in ygh.updates:
            yield (pkg, {'type' : 'unknown'})
        # Must stop here, otherwise every update is yielded a second time by
        # the metadata loop below.
        return

    # Collect each distinct repo that provides at least one update.
    repos = []
    for pkg in ygh.updates:
        if pkg.repo not in repos:
            repos.append(pkg.repo)

    upmd = UpdateMetadata()
    for repo in repos:
        try: # attempt to grab the updateinfo.xml.gz from the repodata
            upmd.add(repo)
        except yum.Errors.RepoMDError:
            pass # No metadata found for this repo; its updates get no notice

    # Walk our list of updates, outputting data:
    for pkg in ygh.updates:
        notice = upmd.get_notice((pkg.name, pkg.ver, pkg.rel))
        if notice:
            md = notice.get_metadata()
        elif sec_only:
            continue # no notice, so it cannot be shown as a security update
        else:
            md = {'type' : 'unknown'}
        if not md:
            continue
        if sec_only and md['type'] != 'security':
            continue
        yield (pkg, md)
                

def obsoletes_lister():
    """ Yield the (pkg, instpkg) pairs from yum's obsoletesTuples list.

    Uses the 'obsoletes' narrow: in yum's doPackageLists() that narrow is
    the one that fills obsoletesTuples, while the 'recent' narrow asked for
    here before leaves it empty.
    NOTE(review): based on yum 3.x doPackageLists; confirm against the
    installed yum version.
    """
    ygh = package_list('obsoletes')
    for (pkg, instpkg) in ygh.obsoletesTuples:
        yield (pkg, instpkg)
    
def recent_lister():
    """ Generate the package objects yum reports as recently added. """
    holder = package_list('recent')
    for po in holder.recent:
        yield po
    
def list_cmd(cmd, args):
    sec_only = False
    ret = [str_netstr1("ok")]
    if len(args) == 1:
        sub2func = {"installed" : installed_lister,
                    "available" : available_lister,
                    "extras" : extras_lister,
                    "recent" : recent_lister}
        if args[0] in sub2func:
            for pkg in sub2func[args[0]]():
                x = str_netstr('pkg', pkg.name,
                               'epoch', pkg.epoch,
                               'version', pkg.version,
                               'release', pkg.release,
                               'arch', pkg.arch)
                ret.append(x)
            print str_netstr1("".join(ret)),
            return
        if args[0] == "obsoletes":
            for (opkg, ipkg) in obsoletes_lister():
                x = str_netstr('pkg', opkg.name,
                               'epoch', opkg.epoch,
                               'version', opkg.version,
                               'release', opkg.release,
                               'arch', opkg.arch)
                y = str_netstr('pkg', ipkg.name,
                               'epoch', ipkg.epoch,
                               'version', ipkg.version,
                               'release', ipkg.release,
                               'arch', ipkg.arch)
                ret.append(str_netstr1(x + y))
            print str_netstr1("".join(ret)),
            return        
            
    ok = False
    if len(args) >= 1:
        if args[0] == "security":
            sec_only = True
            args.pop(0)
    if len(args) >= 1:
        if args[0] == "updates":
            ok = True
    if not ok:
        print str_netstr("err", "unknown argument"),
        return # Don't do anything

    for pkg, md in updates_lister(False, sec_only):
        x = str_netstr('pkg', pkg.name,
                       'epoch', pkg.epoch,
                       'version', pkg.version,
                       'release', pkg.release,
                       'arch', pkg.arch)
        ret.append(x)
    print str_netstr1("".join(ret)),
        
def info_cmd(cmd, args):
    sec_only = False
    ok = False
    if len(args) >= 1:
        if args[0] == "security":
            sec_only = True
            args.pop(0)
    if len(args) >= 1:
        if args[0] == "updates":
            ok = True
    if not ok:
        print str_netstr1(""),
        return # Don't do anything

    ret = [str_netstr1("ok")]
    for pkg, md in updates_lister(True, sec_only):
        def refs():
            ret = []
            for ref in md['references']:
                ret.append(str_netstr('type', ref['type'], 'id', ref['id'],
                                      'href', ref['href'],
                                      'title', ref['title']))
            return "".join(ret)

        if 'id' not in md:
            upd = str_netstr('type', md['type'], combine=False)
        else:
            upd = str_netstr('type', md['type'],
                             'status', md['status'],
                             'issued', md['issued'],
                             'updated', md['updated'],
                             'references', refs(),
                             'title', md['title'],
                             'description', md['description'],
                             combine=False)

        x = str_netstr('pkg', pkg.name,
                       'epoch', pkg.epoch,
                       'version', pkg.version,
                       'release', pkg.release,
                       'arch', pkg.arch,
                       'update', upd)
        ret.append(x)
    print str_netstr1("".join(ret)),

def logs_cmd(cmd, args):
    num = 2
    if len(args) >= 2:
        if args[0] == "num":
            args.pop(0)
            num = int(args.pop(0))
    if len(args) >= 1:
        if args[0] == "all":
            num = 0
            args.pop(0)    
    if len(args) >= 1:
        if args[0] == "last":
            num = 1
            args.pop(0)
    if not os.path.exists("/var/log/yum.log"):
        print str_netstr1(""),
        return
    logs = ["/var/log/yum.log"] + glob.glob("/var/log/yum.log.*")
    if num and num < len(logs):
        logs = logs[0:num]
    logs.reverse()

    # Examples....
    # Aug 06 19:02:16 Updated: openoffice.org-writer.x86_64 1:2.2.1-18.1.fc7
    # Aug 06 19:02:30 Updated: yum-changelog.noarch 1.1.6-1.fc7
    # Sep 05 17:56:33 Updated: evolution - 2.10.3-4.fc7.i386
    # Sep 07 13:28:57 Installed: DevIL - 1.6.8-0.13.rc2.fc7.x86_64
    # Sep 05 03:12:54 Erased: iwlwifi-firmware
    # Is this localized?
    # Is epoch in there?

    map_mon2num = {'Jan' :  1, 'Feb' :  2, 'Mar' :  3, 'Apr' :  4,
                   'May' :  5, 'Jun' :  6, 'Jul' :  7, 'Aug' :  8,
                   'Sep' :  9, 'Oct' : 10, 'Nov' : 11, 'Dec' : 12}
    mon_match = "|".join(map_mon2num.keys())

    possible_re_matches = []

    # Updated/Installed Old style, with epoch
    line_re = re.compile("^(" + mon_match + ") ([^ ]+) " +
                         "([^:]+):([^:]+):([^:]+) " +
                         "(Updated|Installed): " +
                         "([^ ]+)[.]([^.]+) ([^:]+):([^-]+)-(.*)\n$")
    possible_re_matches.append(line_re)
    # Updated/Installed Old style, no epoch
    line_re = re.compile("^(" + mon_match + ") ([^ ]+) " +
                         "([^:]+):([^:]+):([^:]+) " +
                         "(Updated|Installed): " +
                         "([^ ]+)[.]([^.]+) ([^-]+)-(.*)\n$")
    possible_re_matches.append(line_re)
    # Updated/Installed New style, with epoch
    line_re = re.compile("^(" + mon_match + ") ([^ ]+) " +
                         "([^:]+):([^:]+):([^:]+) " +
                         "(Updated|Installed): " +
                         "([^ ]+) (-) ([^:]+):([^-]+)-(.*)[.]([^.]+)\n$")
    possible_re_matches.append(line_re)
    # Updated/Installed New style, no epoch
    line_re = re.compile("^(" + mon_match + ") ([^ ]+) " +
                         "([^:]+):([^:]+):([^:]+) " +
                         "(Updated|Installed): " +
                         "([^ ]+) (-) ([^-]+)-(.*)[.]([^.]+)\n$")
    possible_re_matches.append(line_re)
    # Erase
    line_re = re.compile("^(" + mon_match + ") ([^ ]+) " +
                         "([^:]+):([^:]+):([^:]+) " +
                         "(Erased): " +
                         "([^ ]+)\n$")
    possible_re_matches.append(line_re)

    ret = [str_netstr1("ok")]
    for log in logs:
      for line in file(log).readlines():
        for line_re in possible_re_matches:
            matcher = line_re.match(line)
            if matcher:
                break
        if not matcher:
            print >>sys.stderr, "DBG:", "Failed to match=%s" % line,
            continue # Didn't understand a line of the log file
        grps = matcher.groups()


        # Erase
        if len(grps) == 7:
            ret.append(str_netstr('month', map_mon2num[grps[0]], 'day', grps[1],
                                  'hour', grps[2], 'minute', grps[3],
                                  'second', grps[4], 'type', grps[5],
                                  'pkg', grps[6]))
            continue

        # Update or install
        if False:
            pass
        elif len(grps) == 12 and grps[7] == '-': # "new" style no epoch
            ret.append(str_netstr('month', map_mon2num[grps[0]], 'day', grps[1],
                                  'hour', grps[2], 'minute', grps[3],
                                  'second', grps[4], 'type', grps[5],
                                  'pkg', grps[6], 'epoch', grps[8],
                                  'version', grps[9], 'release', grps[10],
                                  'arch', grps[11]))
        elif len(grps) == 11 and grps[7] == '-': # "new" style no epoch
            ret.append(str_netstr('month', map_mon2num[grps[0]], 'day', grps[1],
                                  'hour', grps[2], 'minute', grps[3],
                                  'second', grps[4], 'type', grps[5],
                                  'pkg', grps[6],
                                  'version', grps[8], 'release', grps[9],
                                  'arch', grps[10]))
        elif len(grps) == 11: # "old" style with epoch
            ret.append(str_netstr('month', map_mon2num[grps[0]], 'day', grps[1],
                                  'hour', grps[2], 'minute', grps[3],
                                  'second', grps[4], 'type', grps[5],
                                  'pkg', grps[6], 'epoch', grps[8],
                                  'version', grps[9], 'release', grps[10],
                                  'arch', grps[7]))
        elif len(grps) == 10: # "old" style no epoch
            ret.append(str_netstr('month', map_mon2num[grps[0]], 'day', grps[1],
                                  'hour', grps[2], 'minute', grps[3],
                                  'second', grps[4], 'type', grps[5],
                                  'pkg', grps[6],
                                  'version', grps[8], 'release', grps[9],
                                  'arch', grps[7]))
            
    print str_netstr1("".join(ret)),

def update_cmd(cmd, userlist):
    """ Handle the 'update' command.

    Currently disabled: the early "if True:" return below always prints an
    empty netstr, so everything after it is unreachable.  The dead code
    looks like it was adapted from a YumBase update method (hence the local
    named "self"): it would add updates/obsoletes to the transaction set and
    print a netstr of how many members were added.

    NOTE(review): before re-enabling, TS_OBSOLETED is not defined or
    imported anywhere visible in this file, and yum.packages, yum.misc and
    yum.logginglevels are only reachable if "import yum" happens to import
    them -- confirm.  The YumBase here is also never given
    doConfigSetup(), unlike package_list().
    """

    # Feature switched off: reply with an empty netstr.
    if True:
        print str_netstr1(""),
        return
    # ---- Unreachable from here on. ----
    self = yum.YumBase()

    oldcount = len(self.tsInfo)
    # NOTE: this value is unused; "installed" is rebound by the loop below.
    installed = self.rpmdb.simplePkgList()
    updates = self.up.getUpdatesTuples()
    if self.conf.obsoletes:
        obsoletes = self.up.getObsoletesTuples(newest=1)
    else:
        obsoletes = []

    # No package arguments: mark every obsolete and update.
    if not userlist:
            for (obsoleting, installed) in obsoletes:
                obsoleting_pkg = self.getPackageObject(obsoleting)
                installed_pkg =  self.rpmdb.searchPkgTuple(installed)[0]
                self.tsInfo.addObsoleting(obsoleting_pkg, installed_pkg)
                self.tsInfo.addObsoleted(installed_pkg, obsoleting_pkg)

            for (new, old) in updates:
                txmbrs = self.tsInfo.getMembers(pkgtup=old)

                # NOTE(review): TS_OBSOLETED is undefined here (NameError).
                if txmbrs and txmbrs[0].output_state == TS_OBSOLETED:
                    self.verbose_logger.log(yum.logginglevels.DEBUG_2, 'Not Updating Package that is already obsoleted: %s.%s %s:%s-%s', old)
                else:
                    updating_pkg = self.getPackageObject(new)
                    updated_pkg = self.rpmdb.searchPkgTuple(old)[0]
                    self.tsInfo.addUpdate(updating_pkg, updated_pkg)
    # Package arguments given: only update what they match.
    else:
            # go through the userlist - look for items that are local rpms. If we find them
            # pass them off to localInstall() and then move on
            localupdates = []
            for item in userlist:
                if os.path.exists(item) and item[-4:] == '.rpm': # this is hurky, deal w/it
                    localupdates.append(item)

            if len(localupdates) > 0:
                val, msglist = self.localInstall(filelist=localupdates, updateonly=1)
                # Mutates the caller's list: local rpms are removed from it.
                for item in localupdates:
                    userlist.remove(item)

            # we've got a userlist, match it against updates tuples and populate
            # the tsInfo with the matches
            updatesPo = []
            for (new, old) in updates:
                (n,a,e,v,r) = new
                updatesPo.extend(self.pkgSack.searchNevra(name=n, arch=a, epoch=e, 
                                 ver=v, rel=r))

            exactmatch, matched, unmatched = yum.packages.parsePackages(
                                                updatesPo, userlist, casematch=1
)
            # Unmatched arguments are silently ignored (logging disabled).
            for userarg in unmatched:
                pass
#                if not quiet:
#                    self.logger.error('Could not find update match for %s' % userarg)

            updateMatches = yum.misc.unique(matched + exactmatch)
            for po in updateMatches:
                for (new, old) in updates:
                    if po.pkgtup == new:
                        updated_pkg = self.rpmdb.searchPkgTuple(old)[0]
                        self.tsInfo.addUpdate(po, updated_pkg)


    # Reply with the number of transaction members added ("msg" is unused).
    if len(self.tsInfo) > oldcount:
        change = len(self.tsInfo) - oldcount
        msg = '%d packages marked for Update' % change
        print str_netstr(change),
#       return 2, [msg]
    else:
        print str_netstr(0),
#        return 0, ['No Packages marked for Update']

def exit_cmd(cmd, args):
    """ Exit """
    print str_netstr("exit"),
    sys.exit(0)
    
netstr_cmds = {'list' : list_cmd, 'info' : info_cmd, 'logs' : logs_cmd,
               'update' : update_cmd,
               'exit' : exit_cmd}

for i in read_next_cmd():
    if not len(i):
        continue

    cmd = i.pop(0)
    if cmd not in netstr_cmds:
        print str_netstr1(""),
        sys.exit(0)

    netstr_cmds[cmd](cmd, i)
    
