#!/usr/bin/python3
# coding: utf8
#
# debtags - Implement package tags support for Debian
#
# Copyright (C) 2003--2012  Enrico Zini <enrico@debian.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

from __future__ import print_function
import argparse
import sys
import logging
import apt
import re
import os
import io
import math
import fnmatch
import shutil
from debian import deb822
from collections import Counter, deque
import tempfile

# Version string reported by the --version command line option
VERSION = "2.0"
# Default pathname of the tagged collection read by load_tags()
DEBTAGS_TAGS = "/var/lib/debtags/package-tags"
# Default pathname of the vocabulary read by load_vocabulary()
DEBTAGS_VOCABULARY = "/var/lib/debtags/vocabulary"

# Logger used for diagnostics; main() sets up the logging level and format
log = logging.getLogger("debtags")

class Fail(Exception):
    """
    Error whose message is meant to be shown to the user.

    main() catches it, prints the message to standard error and exits with
    status 1.
    """


class atomic_writer(object):
    """
    Atomically write to a file

    The data is written to a temporary file created in the same directory as
    the destination. When the ``with`` block ends without an exception, the
    temporary file is flushed, optionally synced to disk, given its final
    permissions and renamed over the destination. In every other case
    (exception in the block, or failure while finalising) the temporary file
    is removed and the destination is left untouched.

    Arguments:
      fname:  pathname of the file to create or replace
      mode:   mode string passed to open(); a "b" in it selects binary mode
      osmode: permission bits set on the file before it is renamed
      sync:   if true, call os.fdatasync() before renaming
      **kw:   extra keyword arguments passed to open()
    """
    def __init__(self, fname, mode, osmode=0o644, sync=True, **kw):
        self.fname = fname
        self.osmode = osmode
        self.sync = sync
        dirname = os.path.dirname(self.fname)
        self.fd, self.abspath = tempfile.mkstemp(dir=dirname, text="b" not in mode)
        try:
            self.outfd = open(self.fd, mode, closefd=True, **kw)
        except BaseException:
            # Do not leave the descriptor or the temporary file behind if the
            # stream cannot be created (for example because of a bad mode)
            try:
                os.close(self.fd)
            except OSError:
                # open() may already have closed the descriptor
                pass
            self._discard()
            raise

    def _discard(self):
        # Remove the temporary file, tolerating it being already gone
        try:
            os.unlink(self.abspath)
        except FileNotFoundError:
            pass

    def __enter__(self):
        return self.outfd

    def __exit__(self, exc_type, exc_val, exc_tb):
        committed = False
        try:
            if exc_type is None:
                self.outfd.flush()
                if self.sync:
                    os.fdatasync(self.fd)
                os.fchmod(self.fd, self.osmode)
                os.rename(self.abspath, self.fname)
                committed = True
        finally:
            try:
                # Covers both an exception raised inside the with block and
                # a failure in any of the finalisation steps above
                if not committed:
                    self._discard()
            finally:
                self.outfd.close()
        # Never swallow the exception raised inside the with block
        return False


def load_tags(pathname=DEBTAGS_TAGS):
    """
    Load tags from a debtags tagged collection file.

    Generate a sequence of (pkgname, set(tags))

    Each non-empty line is expected to look like ``pkgname: tag1, tag2``.
    A line without the ": " separator is reported as a package with an empty
    set of tags instead of aborting the whole load with a ValueError.
    """
    with io.open(pathname, "rt") as fd:
        for line in fd:
            line = line.strip()
            if not line:
                continue
            # Split on the first separator only, so that unexpected content
            # on the right hand side cannot break the unpacking
            pkg, sep, tags = line.partition(": ")
            if not sep:
                # "pkgname" or "pkgname:" (the trailing space was stripped)
                yield pkg.rstrip(":"), set()
                continue
            yield pkg, {t for t in tags.split(", ") if t}

def load_vocabulary(pathname=DEBTAGS_VOCABULARY):
    """
    Read the debtags vocabulary stored in pathname.

    Generate the deb822 paragraphs found in the file, one at a time; the file
    is kept open for as long as the generator is being consumed.
    """
    with io.open(pathname, "rt") as vocabulary:
        yield from deb822.Deb822.iter_paragraphs(vocabulary)


def output_tag_facets(pkg, tags):
    """
    Print "pkg: facet1, facet2" with the sorted facets of the given tags.

    The facet of a tag is the part before its first "::". Nothing is printed
    when tags is empty.
    """
    facet_names = set()
    for tag in tags:
        facet_names.add(tag.partition("::")[0])
    if facet_names:
        print("{}: {}".format(pkg, ", ".join(sorted(facet_names))))

def output_tag_pkgname(pkg, tags):
    """
    Print only the package name; tags is ignored.
    """
    print(str(pkg))

def output_tag_quiet(pkg, tags):
    """
    Print nothing: both arguments are ignored.
    """
    return None

def output_tag_tagcoll(pkg, tags):
    """
    Print "pkg: tag1, tag2" with the tags sorted.

    Nothing is printed when tags is empty.
    """
    if tags:
        ordered = sorted(tags)
        print("{}: {}".format(pkg, ", ".join(ordered)))

def output_pkg_pkgname(pkg, ver):
    """
    Print only the name of the package; ver is ignored.
    """
    name = pkg.name
    print(name)

def output_pkg_quiet(pkg, ver):
    """
    Print nothing: both arguments are ignored.
    """
    return None

def output_pkg_record(pkg, ver):
    """
    Print the record of the given package version; pkg is ignored.
    """
    record = ver.record
    print(record)

def output_pkg_short(pkg, ver):
    """
    Print "name - summary" for the given package and version.
    """
    print("{!s} - {!s}".format(pkg.name, ver.summary))


class ParseError(ValueError):
    """
    Raised by TagMatcher when a tag expression is not well formed.
    """


class TagMatcher:
    """
    Matcher for arbitrary boolean expressions of tag names or patterns

    An expression is made of tag names or shell-style patterns (as understood
    by fnmatch), combined with "&&", "||", "!" and parentheses.

    "&&" and "||" have the same precedence and group to the right, so
    "a && b || c" is read as "a && (b || c)"; "!" applies to the single
    operand that follows it. Use parentheses to group differently.

    A malformed expression raises ParseError.
    """
    class MatchNode:
        """Match if at least one tag matches the shell-style pattern."""
        def __init__(self, expr):
            self.expr = expr
            self.regexp = re.compile(fnmatch.translate(expr))
        def match(self, tags):
            return any(self.regexp.match(t) for t in tags)
        def __str__(self):
            return self.expr

    class NegationNode:
        """Match if the wrapped node does not match."""
        def __init__(self, subnode):
            self.subnode = subnode
        def match(self, tags):
            return not self.subnode.match(tags)
        def __str__(self):
            return "(not " + str(self.subnode) + ")"

    class AndNode:
        """Match if both operands match."""
        def __init__(self, op1, op2):
            self.op1 = op1
            self.op2 = op2
        def match(self, tags):
            return self.op1.match(tags) and self.op2.match(tags)
        def __str__(self):
            return "(" + str(self.op1) + " and " + str(self.op2) + ")"

    class OrNode:
        """Match if at least one operand matches."""
        def __init__(self, op1, op2):
            self.op1 = op1
            self.op2 = op2
        def match(self, tags):
            return self.op1.match(tags) or self.op2.match(tags)
        def __str__(self):
            return "(" + str(self.op1) + " or " + str(self.op2) + ")"

    def __init__(self, expr, inverted=False):
        """
        Parse expr; if inverted is true, negate the whole expression.
        """
        self.orig_expr = expr
        tokens = deque(self.tokenize(expr))
        self.matcher = self.parse_expr(tokens)
        if self.matcher is None:
            # Nothing to parse: refuse it here rather than failing later,
            # in match(), with an unrelated AttributeError
            raise ParseError("empty tag expression")
        if tokens:
            # parse_expr stops at a ")" it did not open: do not silently
            # ignore what follows
            raise ParseError("found unexpected token \"{}\"".format(tokens[0]))
        if inverted:
            self.matcher = self.NegationNode(self.matcher)

    def match(self, tags):
        """
        Return a true value if the given collection of tags matches.
        """
        return self.matcher.match(tags)

    def tokenize(self, s):
        """
        Split an expression into a list of operand and operator tokens.
        """
        return re.findall(r"(?:[^ \t\n&|!()]+|&&|\|\||\(|\)|!)", s)

    def parse_expr(self, toks):
        """
        Parse "operand [(&&|||) expr]" consuming tokens from the deque.

        Return the resulting node, or None if there are no tokens. A ")" is
        left in the deque for the caller to check.
        """
        op1 = self.parse_operand(toks)
        if op1 is None: return None

        if not toks: return op1

        if toks[0] == "||":
            toks.popleft()
            op2 = self.parse_expr(toks)
            if op2 is None: raise ParseError("or operation found without second operand")
            return self.OrNode(op1, op2)
        elif toks[0] == "&&":
            toks.popleft()
            op2 = self.parse_expr(toks)
            if op2 is None: raise ParseError("and operation found without second operand")
            return self.AndNode(op1, op2)
        elif toks[0] == ")":
            return op1
        else:
            raise ParseError("found unexpected token \"{}\"".format(toks[0]))

    def parse_operand(self, toks):
        """
        Parse a tag pattern, a negated operand or a parenthesised expression.

        Return the resulting node, or None if there are no tokens.
        """
        if not toks: return None

        if toks[0] in ("||", "&&", ")"):
            raise ParseError("{} found at start of expression".format(toks[0]))
        elif toks[0] == "(":
            # Discard the (
            toks.popleft()
            # Parse the expression inside
            res = self.parse_expr(toks)
            # Ensure it ends with a closed parenthesis
            if not toks or toks[0] != ")":
                raise ParseError("missing closed parenthesis")
            toks.popleft()
            return res
        elif toks[0] == "!":
            # Discard the !
            toks.popleft()
            # Parse the expression inside
            res = self.parse_operand(toks)
            if res is None:
                raise ParseError("negation found without operand")
            return self.NegationNode(res)
        else:
            return self.MatchNode(toks.popleft())


# Separator between tags in the Tag field: a comma followed by whitespace
re_apt_tag_split = re.compile(r",\s+")
def tags_from_apt_record(rec):
    """
    Return the set of tags listed in the "Tag" field of a record.

    The set is empty when the field is missing or empty.
    """
    raw = rec.get("Tag", None)
    return set(re_apt_tag_split.split(raw)) if raw else set()


class Action:
    """
    Base class for the debtags subcommands.

    Subclasses implement main() and are added to ACTIONS by using register()
    as a class decorator. A subparser is then created for each registered
    class by calling its setup_parser().
    """
    # Action subclasses added by register(), in registration order
    ACTIONS = []

    def __init__(self):
        pass

    def main(self, args):
        """
        Run the action with the parsed command line arguments.

        Must be implemented by subclasses.
        """
        raise NotImplementedError("{} not implemented".format(type(self).__name__))

    @property
    def apt_cache(self):
        """
        The apt.Cache instance, created on first access and then reused.
        """
        res = getattr(self, "_apt_cache", None)
        if res is None:
            res = self._apt_cache = apt.Cache()
        return res

    @property
    def debtags_db(self):
        """
        The contents of the default tag database as a dict mapping package
        names to sets of tags, loaded on first access and then reused.
        """
        res = getattr(self, "_debtags_db", None)
        if res is None:
            self._debtags_db = res = dict(load_tags())
        return res

    def tags_from_apt(self):
        """
        Iterate a sequence of pkgname, set(tags) for each package in the apt
        cache

        Packages without a candidate version are skipped.
        """
        cache = self.apt_cache
        for pkg in cache:
            cand = pkg.candidate
            if not cand: continue
            tags = tags_from_apt_record(cand.record)
            yield pkg.name, tags

    def get_tag_output_func(self, args, default=output_tag_tagcoll):
        """
        Return a function that prints a line from a (pkg, tags) couple.

        The function will be picked according to command line arguments.
        """
        if args.facets:
            return output_tag_facets
        elif args.names:
            return output_tag_pkgname
        elif args.quiet:
            return output_tag_quiet
        else:
            return default

    def get_pkg_output_func(self, args, default=output_pkg_short):
        """
        Return a function that prints a python-apt package record.

        The function will be picked according to command line arguments.
        """
        if args.names:
            return output_pkg_pkgname
        elif args.quiet:
            return output_pkg_quiet
        elif args.full:
            return output_pkg_record
        elif args.short:
            return output_pkg_short
        else:
            return default

    def get_matcher(self, args):
        """
        Return a tag matcher for the current command line args

        Return None if no filter expression was given.
        """
        if not args.expr: return None
        return TagMatcher(args.expr, args.invert)

    @classmethod
    def setup_parser(cls, subparsers, output_coll=False, output_pkgs=False, matches=False, aliases=()):
        """
        Add a subparser for this action to the given command line parser.

        If output_coll is True, also add options to control printing of
        collections.

        If output_pkgs is True, also add options to control printing of package
        data.

        If matches is True, also add the --invert option and the optional
        filter expression argument.

        aliases is a sequence of alternative names for the subcommand.

        Return the new subparser, so that subclasses can add their own
        arguments to it.
        """
        sp = subparsers.add_parser(cls.NAME, help=cls.HELP, aliases=aliases)
        sp.set_defaults(action=cls)

        # Create the collection output group
        if output_coll:
            sp.add_argument("--facets", action="store_true", help="output only the names of the facets (mainly used for computing statistics)")

        if output_coll or output_pkgs:
            # OptionGroup* collOutputOpts = createGroup("Options controlling transformations of tag data on output");
            sp.add_argument("--names", action="store_true", help="output only the names of the packages")
            sp.add_argument("--quiet", action="store_true", help="do not write anything to standard output")

        # Create the package output group
        if output_pkgs:
            # OptionGroup* pkgOutputOpts = createGroup("Options controlling transformations of package data on output");
            sp.add_argument("--full", action="store_true", help="output the full record of package data")
            sp.add_argument("--short", action="store_true", help="output the names of the packages, plus a short description")

        # Create the matching options group
        if matches:
            sp.add_argument("--invert", "-i", action="store_true", help="invert the match, selecting non-matching items")
            sp.add_argument("expr", nargs="?", action="store", help="filter expression")

        return sp

    @classmethod
    def register(cls, c):
        """
        Class decorator that adds c to the list of available actions.

        If c has no NAME, its lowercased class name is used; if it has no
        HELP, its docstring is used.

        Return c, so that the decorated name keeps referring to the class.
        """
        # Ensure class has NAME and HELP
        name = getattr(c, "NAME", None)
        if name is None:
            c.NAME = c.__name__.lower()
        help_text = getattr(c, "HELP", None)
        if help_text is None:
            c.HELP = c.__doc__

        cls.ACTIONS.append(c)
        return c


@Action.register
class Tag(Action):
    """
    (not implemented anymore)
    """
    def main(self, args):
        """
        Always fail, explaining what to use instead of this subcommand.
        """
        # This subcommand used to view and edit the tags for a package:
        #   tag add <package> <tags...>
        #   tag rm  <package> <tags...>
        #   tag ls  <package>
        message = ("debtags tag is no longer supported."
                   " You can use a text editor to create a tag patch,"
                   " see http://debtags.debian.net/api/ for the format.")
        raise Fail(message)


@Action.register
class Cat(Action):
    """
    output the lines of the full package tag database that match
    the given tag expression. A tag expression (given as a single
    argument) is an arbitrarily complex binary expression of tag
    names. For example: role::program && ((use::editing || use::viewing)
    && !works-with::text)
    """
    def main(self, args):
        """
        Print the database entries selected by the optional expression.

        Return 0 if at least one entry was selected, else 1.
        """
        db = self.debtags_db
        emit = self.get_tag_output_func(args)
        matcher = self.get_matcher(args)
        selected = 0
        for pkg, tags in db.items():
            # Without a matcher every entry is selected
            if not matcher or matcher.match(tags):
                emit(pkg, tags)
                selected += 1

        if selected:
            return 0
        return 1

    @classmethod
    def setup_parser(cls, subparsers):
        """
        Add the subparser, also reachable with the "grep" alias.
        """
        return super().setup_parser(subparsers, output_coll=True, matches=True, aliases=["grep"])


@Action.register
class Check(Action):
    """
    Check that all the tags in the given tagged collection are present
    in the tag vocabulary.  Checks the main database if no file is
    specified
    """
    def main(self, args):
        """
        Report the tags used in args.tagfile that are not in the vocabulary.

        Return 0 if every tag is known, else print a summary and return 1.
        """
        all_tags = { rec["Tag"] for rec in load_vocabulary() if "Tag" in rec }

        # Number of packages in which each unknown tag was found
        missing = Counter()
        for pkg, tags in load_tags(args.tagfile):
            missing.update(t for t in tags if t not in all_tags)

        if not missing:
            return 0

        missing_count = sum(missing.values())

        print("{} tags were found in packages but not in the vocabulary".format(len(missing)))
        print("This happened {} times".format(missing_count))
        print("The tags found in the collection but not in the vocabulary are:")

        # Width of the columns: the longest tag name, and the number of
        # digits of the total, which no single count can exceed. Counting
        # the digits directly is also right for 1, 10, 100..., where
        # ceil(log10(n)) is one digit short.
        missing_tagname_col_width = max(len(t) for t in missing)
        count_col_width = len(str(missing_count))

        for tag, count in sorted(missing.items()):
            print(tag.rjust(missing_tagname_col_width), "in", str(count).rjust(count_col_width), "packages")
        return 1

    @classmethod
    def setup_parser(cls, subparsers):
        """
        Add the subparser, with an optional tag file argument.
        """
        sp = super().setup_parser(subparsers, output_coll=True)
        sp.add_argument("tagfile", action="store", nargs="?", help="tag file to read", default=DEBTAGS_TAGS)
        return sp



@Action.register
class Diff(Action):
    """
    Create a tag patch between the current tag database and the tag
    collection [filename].  Standard input is used if filename is not specified
    """
    def main(self, args):
        """
        Print, for each package whose tags differ between the default tag
        database and args.tagfile, a line with the tags to add ("+tag") and
        to remove ("-tag") to go from the former to the latter.
        """
        system = dict(load_tags())
        user = dict(load_tags(args.tagfile))

        # Packages only in the system database: all their tags are removed
        for pkg in system.keys() - user.keys():
            removed = ["-" + t for t in sorted(system[pkg])]
            print("{}: {}".format(pkg, ", ".join(removed)))

        # Packages only in the given collection: all their tags are added
        for pkg in user.keys() - system.keys():
            added = ["+" + t for t in sorted(user[pkg])]
            print("{}: {}".format(pkg, ", ".join(added)))

        # Packages in both: print the changes, if any, additions first
        for pkg in system.keys() & user.keys():
            old_tags = system[pkg]
            new_tags = user[pkg]
            changes = ["+" + t for t in sorted(new_tags - old_tags)]
            changes += ["-" + t for t in sorted(old_tags - new_tags)]
            if not changes:
                continue
            print("{}: {}".format(pkg, ", ".join(changes)))
        return 0

    @classmethod
    def setup_parser(cls, subparsers):
        """
        Add the subparser, also reachable with the "mkpatch" alias.
        """
        parser = super().setup_parser(subparsers, aliases=["mkpatch"])
        parser.add_argument("tagfile", action="store", help="tag file to read")
        return parser


@Action.register
class Search(Action):
    """
    Output the packages matching the given tag expression
    """
    def main(self, args):
        """
        Print every package version selected by the optional expression.

        Return 0 if at least one version was selected, else 1.
        """
        cache = self.apt_cache

        # When invoked as "dumpavail" the default is to print full records
        if args.command_name == "dumpavail":
            fallback = output_pkg_record
        else:
            fallback = output_pkg_short
        output = self.get_pkg_output_func(args, default=fallback)

        matcher = self.get_matcher(args)

        selected = 0
        for pkg in cache:
            for ver in pkg.versions:
                # The tags are only extracted when there is something to match
                if matcher and not matcher.match(tags_from_apt_record(ver.record)):
                    continue
                output(pkg, ver)
                selected += 1

        if selected:
            return 0
        return 1

    @classmethod
    def setup_parser(cls, subparsers):
        """
        Add the subparser, also reachable with the "dumpavail" alias.
        """
        return super().setup_parser(subparsers, output_pkgs=True, matches=True, aliases=["dumpavail"])


@Action.register
class Show(Action):
    """
    (deprecated) mostly the same as apt-cache show <pkgnames>
    """
    def main(self, args):
        """
        Print the candidate version of each package named in args.pkgname.

        Raise Fail if a package is not in the apt cache or has no candidate
        version.
        """
        cache = self.apt_cache
        output = self.get_pkg_output_func(args, default=output_pkg_record)

        for name in args.pkgname:
            if name not in cache:
                raise Fail("package {} not found".format(name))
            pkg = cache[name]
            ver = pkg.candidate
            if ver is None:
                # The output functions that read the version would fail with
                # an AttributeError on None
                raise Fail("package {} has no candidate version".format(name))
            output(pkg, ver)
        return 0

    @classmethod
    def setup_parser(cls, subparsers):
        """
        Add the subparser, taking one or more package names.
        """
        sp = super().setup_parser(subparsers, output_pkgs=True)
        sp.add_argument("pkgname", nargs="+", action="store", help="name of the package to show")
        return sp


@Action.register
class Submit(Action):
    """
    Upload the given patch file to the central tag repository
    (uses debtags-submit-patch)
    """
    def main(self, args):
        """
        Replace the current process with debtags-submit-patch.

        Does not return on success. Raise Fail if the command cannot be
        run.
        """
        cmd = ["debtags-submit-patch"]
        if args.verbose:
            cmd.append("--verbose")
        cmd.append(args.patchfile)
        try:
            os.execvp(cmd[0], cmd)
        except OSError as e:
            # Report a missing or non executable helper as a normal error
            # message instead of a traceback
            raise Fail("cannot run {}: {}".format(cmd[0], e))

    @classmethod
    def setup_parser(cls, subparsers):
        """
        Add the subparser, taking the pathname of the patch to submit.
        """
        sp = super().setup_parser(subparsers)
        sp.add_argument("patchfile", action="store", help="pathname of the tag diff file to submit")
        return sp


@Action.register
class Tagcat(Action):
    """
    Output the tag vocabulary
    """
    def main(self, args):
        """
        Print every paragraph of the default vocabulary and return 0.
        """
        vocabulary = load_vocabulary()
        for paragraph in vocabulary:
            print(paragraph)
        return 0


@Action.register
class TagShow(Action):
    """
    Show the vocabulary informations about a tag
    """
    def main(self, args):
        """
        Print the vocabulary paragraphs whose Tag field is args.tag.

        Return 0 if at least one was found, else log a message at info level
        and return 1.
        """
        hits = 0
        for rec in load_vocabulary():
            if rec.get("Tag", None) == args.tag:
                hits += 1
                print(rec)
        if hits:
            return 0
        log.info("Tag {} was not found in the tag vocabulary".format(args.tag))
        return 1

    @classmethod
    def setup_parser(cls, subparsers):
        """
        Add the subparser, taking the name of the tag to show.
        """
        parser = super().setup_parser(subparsers)
        parser.add_argument("tag", action="store", help="name of the tag to show")
        return parser


@Action.register
class TagSearch(Action):
    """
    Show a summary of all tags whose data contains the given strings
    """
    def main(self, args):
        """
        Print "name - short description" for each vocabulary entry whose
        name or description contains one of args.terms, ignoring case.

        Return 0 if at least one entry was printed, else 1.
        """
        # The terms are matched literally: any one of them is enough
        alternatives = "|".join(re.escape(term) for term in args.terms)
        re_search = re.compile(alternatives, re.I)

        status = 1

        for rec in load_vocabulary():
            # An entry is named by its Facet field or else by its Tag field
            name = rec.get("Facet", "") or rec.get("Tag", "")
            description = rec.get("Description", "")
            if re_search.search(name) or re_search.search(description):
                status = 0
                # Only the first line of the description is shown
                print(name, "-", description.split("\n", 1)[0])

        return status

    @classmethod
    def setup_parser(cls, subparsers):
        """
        Add the subparser, taking one or more search terms.
        """
        parser = super().setup_parser(subparsers)
        parser.add_argument("terms", nargs="+", action="store", help="strings used to search the tag data")
        return parser


@Action.register
class Update(Action):
    """
    Updates the package tag database (requires root)
    Collect package tag data from the sources listed in
    /etc/debtags/sources.list, then regenerate the debtags
    tag database and main index.
    It needs to be run as root
    """
    # NOTE: the commented-out declarations below appear to be kept from an
    # earlier implementation; main() does not implement the --local and
    # --reindex options they describe.
    # // Create the collection update group
    # OptionGroup* updateOpts = createGroup("");
    # misc_local = updateOpts->add<BoolOption>("local", 0, "local", "",
    #         "do not download files when performing an update");
    # misc_reindex = updateOpts->add<BoolOption>("reindex", 0, "reindex",
    #         "", "do not download any file, just do reindexing if needed");
    # update = addEngine("update", "",
    # update->add(updateOpts);
    def main(self, args):
        """
        Regenerate the vocabulary and package-tags files in /var/lib/debtags/

        The vocabulary is copied from /usr/share/debtags/vocabulary, and the
        package tags are taken from the candidate versions in the apt cache.

        Raise Fail if the database directory is missing or not writable.
        """
        dbdir = "/var/lib/debtags/"
        if not os.access(dbdir, os.W_OK):
            if not os.path.exists(dbdir):
                raise Fail("Output directory {} does not exist".format(dbdir))
            else:
                raise Fail("I do not have permission to write to {}".format(dbdir))

        # Set a standard umask since we create system files
        orig_umask = os.umask(0o022)
        try:
            # Load the vocabulary shipped with debtags
            shutil.copyfile("/usr/share/debtags/vocabulary", os.path.join(dbdir, "vocabulary"))

            # Load tags from apt and write them out
            with atomic_writer(os.path.join(dbdir, "package-tags"), "wt") as out:
                for pkg, tags in self.tags_from_apt():
                    if not tags: continue
                    print("{}: {}".format(pkg, ", ".join(sorted(tags))), file=out)
        finally:
            # Do not leave the process with a changed umask, whatever happens
            os.umask(orig_umask)
        return 0


@Action.register
class VocFilter(Action):
    """
    Filter out the tags that are not found in the given vocabulary file
    """
    def main(self, args):
        """
        Print the collection in args.tagfile, keeping only the tags that are
        listed in the vocabulary args.vocabulary.

        The output format follows the --facets, --names and --quiet options
        that setup_parser() registers; without them, packages left with no
        tags are not printed.
        """
        all_tags = { rec["Tag"] for rec in load_vocabulary(args.vocabulary) if "Tag" in rec }

        # Honour the collection output options instead of ignoring them
        output_func = self.get_tag_output_func(args)

        for pkg, tags in load_tags(args.tagfile):
            output_func(pkg, {t for t in tags if t in all_tags})

        return 0

    @classmethod
    def setup_parser(cls, subparsers):
        """
        Add the subparser, with the --vocabulary option and an optional tag
        file argument.
        """
        sp = super().setup_parser(subparsers, output_coll=True)
        sp.add_argument("--vocabulary", action="store", default=DEBTAGS_VOCABULARY,
                        help="vocabulary file to use instead of the current debtags vocabulary")
        sp.add_argument("tagfile", action="store", nargs="?", help="tag file to read", default=DEBTAGS_TAGS)
        return sp



def main():
    """
    Command line entry point.

    Parse the command line, set up logging according to --debug and
    --verbose, then run the selected action and exit with the value it
    returns. A Fail exception is printed to standard error and turned into
    exit status 1.
    """
    parser = argparse.ArgumentParser(
        description="Command line interface to access and manipulate Debian Package Tags")
    parser.add_argument("--version", action="version", version="%(prog)s {}".format(VERSION))
    parser.add_argument("--verbose", "-v", action="store_true", help="enable verbose output")
    parser.add_argument("--debug", action="store_true", help="enable debugging output (including verbose output)")

    subparsers = parser.add_subparsers(title="subcommands",
                                       description="valid subcommands",
                                       dest="command_name",
                                       help="additional help")
    for a in Action.ACTIONS:
        a.setup_parser(subparsers)

    args = parser.parse_args()

    # The "action" default is only set by the subparsers: without a
    # subcommand there is nothing to run, so print the usage and exit
    # instead of failing with an AttributeError
    if getattr(args, "action", None) is None:
        parser.error("no subcommand specified")

    # Setup logging
    FORMAT = "%(message)s"
    if args.debug:
        logging.basicConfig(level=logging.DEBUG, stream=sys.stderr, format=FORMAT)
    elif args.verbose:
        logging.basicConfig(level=logging.INFO, stream=sys.stderr, format=FORMAT)
    else:
        logging.basicConfig(level=logging.WARNING, stream=sys.stderr, format=FORMAT)

    action = args.action()
    try:
        sys.exit(action.main(args))
    except Fail as e:
        print(e, file=sys.stderr)
        sys.exit(1)

# Run the command line interface only when this file is executed as a script
if __name__ == "__main__":
    main()
