#!/usr/bin/python
#############################################################
## rpm-deathwatch
##
##      This program is free software: you can redistribute it and/or modify
##      it under the terms of the GNU General Public License as published by
##      the Free Software Foundation, either version 3 of the License, or
##      (at your option) any later version.
##  
##      This program is distributed in the hope that it will be useful,
##      but WITHOUT ANY WARRANTY; without even the implied warranty of
##      MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
##      GNU General Public License for more details.
##  
##      You should have received a copy of the GNU General Public License
##      along with this program.  If not, see <https://www.gnu.org/licenses/>.
##
##      Author: Kyle Walker <kwalker@redhat.com>
##
#############################################################

from __future__ import print_function

import atexit
import argparse
import fcntl
import logging
import os
import platform
import re
import signal
import sys
import subprocess
import time

# RPMs that have to be installed before the systemtap module can be built.
# check_env() queries every entry containing "kernel" with the running kernel
# release appended to it.
STAP_PACKAGES = [
    'kernel',
    'kernel-debuginfo',
    'kernel-headers',
    'kernel-devel',
    'systemtap',
    'systemtap-runtime',
]

# Directory used as the working directory for stap and staprun
STAP_PATH = '/var/rpm-deathwatch/'
# Name given to the compiled module (passed to stap with -m)
STAP_MODULE_NAME = 'rpm_deathwatch'
# Script and compiled module file names, resolved relative to STAP_PATH
STAP_SCRIPT = 'rpm_deathwatch.stp'
STAP_MODULE = 'rpm_deathwatch.ko'
# Named pipe that staprun writes to and the monitor loop reads from
STAP_PIPE = '/var/run/rpm-deathwatch.fifo'
# File holding the pid that staprun printed when the module was loaded
STAP_PIDFILE = '/var/run/rpm-deathwatch-stap.pid'

# One logger for the whole utility; --debug raises the level to DEBUG
logging.basicConfig()
logger = logging.getLogger('rpm-deathwatch')
logger.setLevel(logging.INFO)

# Maintains output for the utility and determines if the writes should stop.
#
# Required attributes of the parsed command line arguments:
#    name:  The /<path>/<name> to prepend to each output file.
#    num:   The number of files to write prior to cleaning up the initial file.
#    size:  Size limit imposed on any individual output file (In MiB).
#    pid:   pid which, if killed, should stop the systemtap module.
class files(object):
    """Size-limited, rotating set of output files for the monitor.

    Lines are appended to the current file. Once that file grows past
    ``size`` MiB a new timestamped file is started, and once more than
    ``num`` files exist the oldest one is deleted.

    Arguments:
        args:    Parsed command line arguments; the ``name``, ``num``,
                 ``size`` and ``pid`` attributes are read.
        version: Major version of the running interpreter (2 or 3).
    """

    def __init__(self, args, version):
        self.name = args.name
        self.num = args.num
        self.size = args.size
        # The optional "\(" accepts the pid both bare and wrapped in literal
        # parentheses, and "\b" stops pid 42 from also matching pid 421.
        # NOTE(review): the stap script is not visible from here. If it prints
        # "(pid:N)", an unescaped "(pid:N)" group could never have matched --
        # confirm the exact output format against rpm_deathwatch.stp.
        self.pid = re.compile(r'was sent to [a-zA-Z]* \(?pid:{0}\b'.format(args.pid))
        self.version = version

        # Provisional value; newfile() records what was actually opened
        self._writemode = 'string' if version == 2 else 'bytes'

        self.filelist  = []
        self.openfile  = self.newfile()
        self.lastprint = time.time()

    @property
    def writemode(self):
        """'string' when the open file takes text, 'bytes' when it takes bytes."""
        return self._writemode

    def newfile(self):
        """Open a new unbuffered, timestamped output file and return it."""
        # NOTE(review): %I is the 12-hour clock, so a name can repeat between
        # AM and PM of the same day.
        filename = self.name + "-" + time.strftime("%m-%d-%Y-%I:%M:%S")
        try:
            f = open(filename, "w", buffering=0)
            mode = 'string'
        except ValueError:
            # Unbuffered text files are refused here, so fall back to binary
            f = open(filename, "wb", buffering=0)
            mode = 'bytes'
        self._writemode = mode

        # Initialize to zero since this avoids having to stat the file outside of the write operation
        self.cachetell = 0

        self.filelist.append(f.name)
        logger.info("\tOpened: {0}".format(f.name))

        return f

    def deletefile(self):
        """Remove the oldest output file from disk and from the file list."""
        oldest = self.filelist[0]
        logger.info("\tDeleting {0}".format(oldest))
        try:
            os.remove(oldest)
        except OSError as e:
            # A file that was already removed should not stop the monitoring
            logger.error("\tUnable to delete {0} - {1}".format(oldest, e))
        self.filelist.pop(0)

    def check_pid(self, buf):
        """Return 1 if buf reports a signal sent to the watched pid, else 0."""
        if self.pid.search(buf):
            return 1

        return 0

    def cleanup(self):
        """Rotate to a new file when the size limit is exceeded and prune old files."""
        # Progress is reported at most about once per second
        if int(time.time() - self.lastprint) > 1:
            logger.debug("\t\tFile: {0}\tWrote: {1}Mib ({2})\tLimit: {3}".format(self.openfile.name, btomib(self.cachetell), self.cachetell, mibtob(self.size)))
            self.lastprint = time.time()

        if self.cachetell > mibtob(self.size):
            logger.debug("\t{0} - Over size limit of {1}MiB - ".format(self.openfile.name, self.size))
            previous = self.openfile
            self.openfile = self.newfile()
            # Release the descriptor of the file that was just rotated out
            previous.close()

            # Check to see if we are over the passed number of overall files
            if len(self.filelist) >= (self.num + 1):
                logger.debug("\t\tNumber of open files {0} exceeds the limit of {1}.".format(len(self.filelist), self.num))
                self.deletefile()

    def write(self, buf):
        """Append buf (text) to the current file and update the cached size."""
        logger.debug(buf.rstrip('\n'))

        if self.writemode == 'string':
            self.openfile.write(buf)
            self.cachetell = self.openfile.tell()
        else:
            bytes_written = self.openfile.write(buf.encode('utf-8', 'strict'))
            self.cachetell += bytes_written

def btomib(val):
    """Convert a byte count to MiB.

    Always returns a float so that sizes below one MiB are not truncated to
    zero by integer division on Python 2.
    """
    return val / (1024.0 * 1024.0)

def mibtob(val):
    """Convert a size given in MiB to bytes."""
    kib = val * 1024
    return kib * 1024

def check_env(version):
    """Verify that every RPM listed in STAP_PACKAGES is installed.

    Arguments:
        version: Major version of the running interpreter (2 or 3); on 3 the
                 output of rpm is decoded before it is inspected.

    Returns:
        0 when nothing is missing, 1 after logging the missing packages.
    """
    logger.info('Checking the environment')
    running_kernel = platform.uname()[2]

    absent = []
    for pkg in STAP_PACKAGES:
        # Kernel packages have to match the running kernel release exactly
        query = "{0}-{1}".format(pkg, running_kernel) if 'kernel' in pkg else "{0}".format(pkg)

        proc = subprocess.Popen(['/bin/rpm', '-q', query], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output = proc.communicate()
        text = output[0] if version == 2 else output[0].decode()

        if 'not installed' in text:
            # The second space-separated word of the rpm message is recorded
            absent.append(text.split(' ')[1])

    if not absent:
        return 0

    logger.error('The following RPMs are missing from the system and will be necessary to use this utility')
    logger.error('')
    for name in absent:
        logger.error('    {0}'.format(name))
    logger.error('')
    logger.error('The following yum command is generally needed in order to install the necessary dependencies:')
    if version == 2:
        hint = '    # yum install -y kernel-{devel,headers}-$(uname -r) systemtap && debuginfo-install -y kernel'
    else:
        hint = '    # yum install -y kernel-{devel,headers}-$(uname -r) systemtap && yum debuginfo-install -y kernel'
    logger.error(hint)
    return 1

def build_module(version):
    """Compile the systemtap script into a kernel module with stap.

    Arguments:
        version: Major version of the running interpreter (2 or 3); on 3 the
                 error output of stap is decoded before it is logged.

    Returns:
        The exit status of stap; 0 on success.
    """
    logger.info('Building the systemtap module to report termination signals')
    stap = subprocess.Popen(['/usr/bin/stap', '--suppress-handler-errors', '-DMAXSTRINGLEN=4096', '-p4', '-m{0}'.format(STAP_MODULE_NAME), STAP_SCRIPT], cwd=STAP_PATH, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    # stdout is still collected so that stap cannot block on a full pipe, but
    # only stderr is reported.
    _, errorval = stap.communicate()
    if version != 2:
        # Undecodable bytes are replaced so that reporting a build failure
        # cannot itself fail with a UnicodeDecodeError
        errorval = errorval.decode('utf-8', 'replace')

    if stap.returncode:
        logger.error('Failure while building systemtap module')
        logger.error('\t{0}'.format(errorval))

    return stap.returncode

def check_proc(pid):
    """Report whether the process with the given pid is a stapio process.

    Arguments:
        pid: Process id to look up; a false value short-circuits to 0.

    Returns:
        1 when the command name printed by ps contains 'stapio', else 0.
    """
    if not pid:
        return 0

    ps_cmd = ['/bin/ps', '-o', 'comm=', '-p', '{0}'.format(pid)]
    stdout_data = subprocess.Popen(ps_cmd, stdout=subprocess.PIPE).communicate()[0]
    try:
        command = stdout_data.decode()
    except (UnicodeDecodeError, AttributeError):
        # Fall back to the raw value when it cannot be decoded
        command = stdout_data

    return 1 if 'stapio' in command else 0

def load_module(version):
    """Load the compiled systemtap module with staprun and record its pid.

    staprun is started with -D and -o pointing at STAP_PIPE; the number it
    prints on stdout is written to STAP_PIDFILE.

    Arguments:
        version: Major version of the running interpreter (2 or 3); on 3 the
                 output of staprun is decoded.

    Returns:
        A (stap, failed) tuple. stap is the Popen object for staprun, or None
        when staprun was never started. failed is True when the module could
        not be loaded and False otherwise.
    """
    stap = None

    with open('/proc/modules', 'r') as f:
        if STAP_MODULE_NAME in f.read():
            logger.error('Module already loaded. Not attempting a further load operation.')
            return (stap, True)

    if not os.path.exists(STAP_PIPE):
        # os.mkfifo() returns None and signals failure by raising OSError
        try:
            os.mkfifo(STAP_PIPE)
        except OSError as error:
            logger.error('Unable to create {0} - {1}'.format(STAP_PIPE, error))
            return (stap, True)

    module_path = os.path.join(STAP_PATH, STAP_MODULE)
    if not os.path.exists(module_path):
        logger.error('{0} not found.'.format(module_path))
        return (stap, True)

    logger.info('Loading the systemtap module')
    stap = subprocess.Popen(['/usr/bin/staprun', '-D', '-o{0}'.format(STAP_PIPE), '{0}'.format(STAP_MODULE)], cwd=STAP_PATH, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    buf = stap.communicate()
    if version == 2:
        retval = buf[0]
        errorval = buf[1]
    else:  # Version is 3 and so we get bytes back
        retval = buf[0].decode()
        errorval = buf[1].decode()

    logger.debug("Received \"{0}\" from {1}".format(retval, stap))

    if stap.returncode:
        # Report what staprun wrote to stderr
        logger.error("Error during module load\n--------------------\n{0}\n--------------------".format(errorval))
        return (stap, True)

    try:
        stapio_pid = int(retval)
    except ValueError:
        logger.error("Unexpected output from staprun, expected a pid - \"{0}\"".format(retval))
        return (stap, True)

    logger.debug("Writing pid \"{0}\" to {1}".format(stapio_pid, STAP_PIDFILE))
    try:
        with open(STAP_PIDFILE, "w") as f:
            f.write('{0}'.format(stapio_pid))
    except (IOError, OSError) as e:
        # Only logged, as before: the load itself succeeded
        logger.error("Unable to write pid {0} to {1} - {2}".format(stapio_pid, STAP_PIDFILE, e))

    return (stap, False)

def start_monitoring(args, stap, version):
    """Copy lines from the systemtap pipe into the rotating output files.

    The loop ends when a line reports a signal sent to the pid given on the
    command line, when the process recorded in STAP_PIDFILE can no longer be
    signalled, or when the user interrupts the program.

    Arguments:
        args:    Parsed command line arguments, handed to files().
        stap:    Popen object returned by load_module().
        version: Major version of the running interpreter (2 or 3).

    Returns:
        0 when the watched pid was reported as signalled, 1 otherwise.
    """
    logger.info('Starting monitoring')
    output_file = files(args, version)
    try:
        logger.info('Writing output to {0}'.format(output_file.openfile.name))

        if not os.path.exists(STAP_PIDFILE):
            logger.error("Error while reading {0}".format(STAP_PIDFILE))
            return 1

        try:
            with open(STAP_PIDFILE, 'r') as f:
                monitorpid = int(f.read())
        except ValueError:
            # Empty or corrupt pid file
            logger.error("Error while reading {0}".format(STAP_PIDFILE))
            return 1
        logger.info('Monitoring {0}'.format(monitorpid))

        # This is necessary due to python3 not allowing unbuffered string reads
        read_as = 'r' if version == 2 else 'rb'

        try:
            with open(STAP_PIPE, read_as, buffering=0) as f:
                while True:
                    buf = f.readline()
                    if version != 2:
                        buf = buf.decode()

                    output_file.cleanup()
                    output_file.write(buf)
                    if output_file.check_pid(buf):
                        return 0

                    # returncode is only refreshed by poll()/wait()/communicate(),
                    # none of which are called inside this loop
                    if stap.returncode:
                        logger.error('{0} Exited - {1}'.format(stap, stap.communicate()))
                        return 1

                    # Signal 0 only tests whether the process can be signalled
                    try:
                        os.kill(monitorpid, 0)
                    except OSError:
                        logger.error('{0} No longer exists'.format(monitorpid))
                        return 1
        except (KeyboardInterrupt, SystemExit):
            # This suppresses the <CTRL>-c backtrace splat, including while
            # blocked waiting for the pipe to be opened
            return 1
    finally:
        # Release the output file that is currently open
        output_file.openfile.close()

def main(version, args = None):
    """Run the check, build, load and monitor steps in order.

    Arguments:
        version: Major version of the running interpreter (2 or 3).
        args:    Parsed command line arguments, handed to start_monitoring().

    Returns:
        1 as soon as any step reports a failure; None when monitoring ended
        because the watched pid was reported as killed.
    """
    # Short-circuit evaluation keeps the steps in their original order
    if check_env(version) or build_module(version):
        return 1

    stap, failed = load_module(version)
    if failed or start_monitoring(args, stap, version):
        return 1

    logger.info('Exiting due to monitored pid being killed')

@atexit.register
def kill_stapio():
    """Terminate the process recorded in STAP_PIDFILE when the program exits.

    Best effort: every failure is ignored. The pid file is removed even when
    the signal could not be delivered, so that a later run cannot signal an
    unrelated process that has since been given the same pid.
    """
    try:
        with open(STAP_PIDFILE, 'r') as f:
            buf = f.read()
    except (IOError, OSError):
        # No pid file (or it is unreadable), so there is nothing to clean up
        return

    if buf.strip():
        try:
            os.kill(int(buf), signal.SIGTERM)
        except (ValueError, OSError):
            pass  # We really don't care about failures

    try:
        os.unlink(STAP_PIDFILE)
    except OSError:
        pass

if __name__ == '__main__':
    # argparse is already imported at the top of the file
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--name', required=True, help='The /<path>/<name> to prepend to each logging output file.')
    parser.add_argument('--num', type=int, default=5, help='The number of files to write prior to cleaning up the initial file.')
    parser.add_argument('--size', type=int, default=100,  help='Size limit imposed to any individual logging file (In MiB).')
    parser.add_argument('--pid', type=int, help='PID killed where the trace should be stopped.')
    parser.add_argument('--debug', action='store_true', help='Enable debugging for the script.')
    args = parser.parse_args()

    if args.debug:
        logger.setLevel(logging.DEBUG)

    # Hand main()'s result to the shell so that failures give a non-zero exit
    # status; a None result still exits with status 0
    sys.exit(main(sys.version_info[0], args))

