#!/usr/bin/python3
# mini-git-tag-fsck
# Handling of tag2upload SOURCE_VERSION.git.tar.xz archives
#
# Copyright (C)2024 Sean Whitton
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import argparse
import email
import hashlib
import itertools
import os
import re
import subprocess
import sys
import tarfile
import tempfile
import zlib

def audit(tarball, distro, out_txt, out_sig):
    """Audit a tag2upload SOURCE_VERSION.git.tar.xz archive.

    Extract TARBALL into a private temporary directory and find the single
    DISTRO/* tag in the bare repository it must contain.  Using only the
    loose object files, check that the tag, the commit it points to, and
    every tree and blob reachable from that commit are well-formed and
    match their object IDs.

    For each blob, write a NUL-terminated record "MODE SHA1    PATH" to
    standard output.  Only once every check has passed, write the tag's
    text (everything before the PGP signature) to OUT_TXT and the
    signature block to OUT_SIG.  OUT_TXT defaults to TARBALL + ".tag" and
    OUT_SIG defaults to OUT_TXT + ".asc".

    Raises Exception (NotImplementedError for a tag of a non-commit) on the
    first problem found.
    """
    extract = tempfile.TemporaryDirectory()
    # Resolve the extraction root once.  If TMPDIR goes through a symlink,
    # realpath() of a member would otherwise never appear to lie inside it.
    extract_root = os.path.realpath(extract.name)

    with tarfile.open(name=tarball) as tb:
        # This is similar to using tarfile.data_filter from Python 3.12+.
        for member in tb.getmembers():
            if not (member.isreg() or member.isdir()):
                raise Exception(f"{tarball} contains non-regular file")
            if os.path.isabs(member.name):
                raise Exception(
                    f"{tarball} contains member with absolute name")
            target = os.path.realpath(os.path.join(extract_root, member.name))
            # Compare whole path components.  A plain string-prefix test
            # would accept /tmp/tmpXYZ-evil as lying inside /tmp/tmpXYZ.
            if os.path.commonpath([extract_root, target]) != extract_root:
                raise Exception(
                    f"{tarball} contains member with path-join trickery")
        # Where this Python provides it, also apply tarfile's own "data"
        # filter as a second line of defence.  Passing a filter explicitly
        # also avoids the warning that 3.12+ emits when none is given.
        if hasattr(tarfile, "data_filter"):
            tb.extractall(path=extract_root, filter="data")
        else:
            tb.extractall(path=extract_root)

    extract_list = os.listdir(extract_root)
    if len(extract_list) != 1:
        raise Exception(f"{tarball} has !=1 top-level members")
    repo = os.path.join(extract_root, extract_list[0])

    def read_obj(expected_type, oid):
        """Yield, in chunks, the body of loose object OID in REPO.

        Raise if OID is not a lowercase 40-digit hex SHA-1 or if the object
        is not of EXPECTED_TYPE.  The length and SHA-1 checks run only after
        the last chunk, so callers must consume every chunk for them to run.
        """
        if re.fullmatch(r'[0-9a-f]{40}', oid) is None:
            raise Exception("Found invalid object ID")
        obj = os.path.join("objects", oid[:2], oid[2:])
        head, gen = iter_peel_field(read_deflated(os.path.join(repo, obj)))
        if (m := re.fullmatch(rb'([a-z]+) (\d+)', head)) is None:
            raise Exception(f"{obj} is malformed")
        actual_type = m.group(1).decode()
        expected_len = int(m.group(2))
        if actual_type != expected_type:
            raise Exception(f"{obj} is not {expected_type}")
        h = hashlib.sha1()
        h.update(head+b"\0")
        total_len = 0
        for data in gen:
            total_len += len(data)
            h.update(data)
            yield data
        if total_len != expected_len:
            raise Exception(f"{obj} is corrupt: incorrect length")
        if h.hexdigest() != oid:
            raise Exception(f"{obj} is corrupt: checksum mismatch")

    distro_tags_dir = os.path.join(repo, "refs", "tags", distro)
    dtags_list = os.listdir(distro_tags_dir)
    if len(dtags_list) != 1:
        raise Exception(f"{tarball} has !=1 {distro}/ tags or is not bare")
    dtag_name = f"{distro}/{dtags_list[0]}"
    with open(os.path.join(distro_tags_dir, dtags_list[0])) as f:
        dtag_target = f.read().splitlines()[0]

    # Avoid using any information from inside the tarball to default these.
    out_txt = out_txt or tarball+".tag"
    out_sig = out_sig or out_txt+".asc"

    headers = True
    # OUT starts out aliasing TAG and is switched to SIG at the signature.
    out = tag = []
    sig = []
    dtag_commit = dtag_type = dtag_tag = None
    # read_obj validates dtag_target, checking it's a valid Git object name.
    for line in iter_slurp_lines(read_obj("tag", dtag_target)):
        if headers:
            if line.startswith("object "):
                if dtag_commit:
                    raise Exception(
                        f"Multiple 'object' headers in {dtag_target}")
                else:
                    dtag_commit = line[7:]
            elif line.startswith("type "):
                if dtag_type:
                    raise Exception(
                        f"Multiple 'type' headers in {dtag_target}")
                elif (dtag_type := line[5:]) != "commit":
                    raise NotImplementedError(
                        "DEP-14 tag points to a non-commit")
            elif line.startswith("tag "):
                if dtag_tag:
                    raise Exception(
                        f"Multiple 'tag' headers in {dtag_target}")
                elif (dtag_tag := line[4:]) != dtag_name:
                    raise Exception("Conflicting names for the DEP-14 tag")
            elif len(line) == 0:
                headers = False
        elif line == "-----BEGIN PGP SIGNATURE-----":
            out = sig
        out.append(line+"\n")
    if not dtag_commit:
        raise Exception("DEP-14 tag does not point to a commit")
    # Without these checks, a tag object that lacks one of these headers
    # would silently skip the corresponding consistency check above.
    if dtag_type is None:
        raise Exception(f"No 'type' header in {dtag_target}")
    if dtag_tag is None:
        raise Exception(f"No 'tag' header in {dtag_target}")

    dtag_tree = None
    # read_obj validates dtag_commit, checking it's a valid Git object name.
    for line in iter_slurp_lines(read_obj("commit", dtag_commit)):
        if line.startswith("tree "):
            if dtag_tree:
                raise Exception(f"Multiple 'tree' headers in {dtag_commit}")
            else:
                dtag_tree = line[5:]
        elif len(line) == 0:
            break
    if not dtag_tree:
        raise Exception(f"{dtag_commit} does not point to a tree")

    def ls_tree(parent, tree):
        """Verify TREE and, recursively, everything under it.

        Print each blob's record with its path prefixed by PARENT.
        """
        # Trees are small enough to hold whole.  Joining first means
        # read_obj() has verified the tree's length and checksum before any
        # entry is acted on.  Parsing a flat buffer also avoids stacking one
        # itertools.chain per entry, which nests without bound on big trees.
        data = b"".join(read_obj("tree", tree))
        pos = 0
        while pos < len(data):
            # Each entry is b"MODE NAME\0" followed by a 20-byte binary ID.
            nul = data.find(b"\0", pos)
            if nul == -1 or nul + 21 > len(data):
                raise Exception(f"Truncated entry in tree {tree}")
            item = data[pos:nul]
            sha1 = data[nul+1:nul+21].hex()
            pos = nul + 21

            mode, entry = item.decode().split(" ", 1)
            entry = os.path.join(parent, entry)

            if mode == "40000":
                ls_tree(entry, sha1)
            elif mode in ("100755", "100644", "120000"):
                # Iterate through it in order to verify its checksum.
                for _ in read_obj("blob", sha1):
                    pass
                sys.stdout.write(f"{mode} {sha1}    {entry}\0")
            else:
                raise Exception("Invalid tree entry mode")
    ls_tree("", dtag_tree)

    # Avoid writing these files out unless all ls_tree() checks passed.
    # This actually closes the handles, s.t. if, e.g., ENOSPC, we will die.
    # The text was decoded as UTF-8, so write it back the same way.
    with open(out_txt, "w", encoding="utf-8") as f:
        f.writelines(tag)
    with open(out_sig, "w", encoding="utf-8") as f:
        f.writelines(sig)

# While tags, commits and trees oughtn't be too big,
# we should avoid loading huge blobs into memory all at once.
def read_deflated(path):
    """Yield the zlib-decompressed contents of the file at PATH, in chunks.

    Input is read 64 KiB at a time, so large objects need not be held in
    memory whole.  Empty chunks are never yielded.

    Raises Exception if the file ends before the zlib stream does, and
    zlib.error if the compressed data is corrupt.
    """
    decomp = zlib.decompressobj()
    with open(path, "rb") as f:
        while buf := f.read(2**16):
            if data := decomp.decompress(buf):
                yield data
        if data := decomp.flush():
            yield data
    # Without this check, a truncated file would silently yield only a
    # prefix of its contents.
    if not decomp.eof:
        raise Exception(f"{path} is truncated: incomplete zlib stream")

def iter_slurp_lines(binary_iterable):
    """Read all of BINARY_ITERABLE, decode it as UTF-8, and split into lines.

    BINARY_ITERABLE yields chunks of bytes.  Lines are separated by LF
    characters only, and the separators are dropped.  A trailing LF does
    not produce a final empty line.

    Raises UnicodeDecodeError if the stream is not valid UTF-8.
    """
    # Join before decoding: a chunk boundary may fall in the middle of a
    # multi-byte UTF-8 sequence, and decoding chunks one by one would fail.
    text = b"".join(binary_iterable).decode()
    # Split on LF only, since Git's object formats are LF-delimited.
    # str.splitlines() also breaks on \r, \x85, \u2028 and others, so text
    # reassembled from its result could differ from the signed bytes.
    lines = text.split("\n")
    if lines[-1] == "":
        lines.pop()
    return lines

def iter_peel_bytes(nbytes, binary_iterable):
    """Split the first NBYTES bytes off a chunked byte stream.

    BINARY_ITERABLE is an iterator yielding chunks of bytes.  Return a pair
    of the first NBYTES bytes of the stream and an iterator over the rest.
    Any surplus from the last chunk read is pushed back in front of the
    remainder.  StopIteration propagates if the stream ends before NBYTES
    bytes have been read.

    """
    chunks = []
    have = 0
    while have < nbytes:
        chunk = next(binary_iterable)
        chunks.append(chunk)
        have += len(chunk)
    joined = b"".join(chunks)
    # Exact fit: nothing to push back, so skip building a chain.
    if have == nbytes:
        return joined, binary_iterable
    surplus = joined[nbytes:]
    return joined[:nbytes], itertools.chain([surplus], binary_iterable)

def iter_peel_field(binary_iterable):
    """Split one NUL-terminated field off a chunked byte stream.

    BINARY_ITERABLE is an iterator yielding chunks of bytes.  Read until the
    first NUL byte or end-of-stream.  Return a pair of the bytes before the
    NUL (or everything read, if there was none) and an iterator over the
    bytes after it.  The NUL itself is consumed.  StopIteration propagates
    if the stream is empty to begin with.

    """
    chunks = [next(binary_iterable)]
    while b"\0" not in chunks[-1]:
        chunk = next(binary_iterable, None)
        if chunk is None:
            break
        chunks.append(chunk)
    joined = b"".join(chunks)
    field, sep, remainder = joined.partition(b"\0")
    if not sep:
        return joined, binary_iterable
    # Only push bytes back when something follows the NUL in this chunk.
    if remainder:
        binary_iterable = itertools.chain([remainder], binary_iterable)
    return field, binary_iterable

def prepare(repo, distro, out_tar, deepen_to, utag_name, utag_target, dfiles):
    """Produce a SOURCE_VERSION.git.tar.xz for upload from the repo REPO.

    The DISTRO/VERSION (DEP-14) tag for the version at the top of
    debian/changelog is shallowly fetched into a fresh bare repository.
    The fetch goes deep enough to reach UTAG_NAME (if given) and every
    committish in DEEPEN_TO.  The objects are then rearranged: the ones
    'mini-git-tag-fsck --audit' reads (the tag, its commit, and that
    commit's trees and blobs) are left loose, and everything else is
    packed.  The result is written as an xz-compressed tarball to OUT_TAR,
    which defaults to ../SOURCE_VERSION.git.tar.xz relative to the current
    directory.  If DFILES is true, a line for the tarball is appended to
    REPO/debian/files.

    NOTE(review): UTAG_TARGET is accepted but not currently used.
    """
    # For --audit mode, we must require only core modules.
    import pygit2

    orig_repo = pygit2.Repository(repo)
    orig_repo_dir = os.path.abspath(repo)
    out_repo_parent = tempfile.TemporaryDirectory()
    source, version = dpkg_parsechangelog(repo)
    fn_version = re.sub(r'^\d+:', "", version)
    # The Debian revision is whatever follows the *last* hyphen, and an
    # upstream version may itself contain hyphens: 1.0-rc1-2 -> 1.0-rc1.
    fn_upstream_version = re.sub(r"-[^-]*$", "", fn_version)
    pkg_dir_name = f"{source}-{fn_upstream_version}"
    out_repo_dir = os.path.join(out_repo_parent.name, pkg_dir_name)
    out_repo = pygit2.init_repository(out_repo_dir, bare=True)

    def get_tag(tag):
        if not tag.startswith("refs/tags/"):
            tag = f"refs/tags/{tag}"
        try:
            return orig_repo.get(orig_repo.lookup_reference(tag).target)
        except KeyError:
            raise Exception(f"Couldn't find {tag} in {orig_repo_dir}")
    def fetch_tag(tag, depth):
        run_git(out_repo_dir,
                "fetch", f"--depth={depth}", "origin", "tag", tag)

    dtag_name = f"{distro}/{dep14_version(version)}"
    orig_dtag = get_tag(dtag_name)
    orig_dtag_obj = orig_dtag.get_object()

    # It would be nicer to use
    #
    #     remote = out_repo.remotes.create_anonymous()
    #
    # here and then call
    #
    #    remote.fetch(depth=1, refspecs=[
    #        f"refs/tags/{dtag_name}:refs/tags/{dtag_name}"])
    #
    # rather than creating and deleting a true remote, but this
    # is broken: https://github.com/libgit2/pygit2/issues/1314
    out_repo.remotes.create("origin", f"file://{orig_repo_dir}")

    # We now decide exactly what to fetch so that we get enough history.
    # What we already know how to do is include the history back to the
    # upstream tag.  If that's not enough history for our purposes, we
    # rely on the caller passing us --deepen-to arguments.
    # This is to avoid having to teach mini-git-tag-fsck about quilt modes.
    #
    # You might think that it would be possible to achieve what we want using
    # just the --depth and --shallow-exclude options to 'git fetch'.
    # Investigation led me to believe that this (slightly) more complicated
    # 'git log --ancestry-path'-based algorithm is in fact needed.
    #
    # This implementation has git shallowly fetch down each side of each merge
    # commit between the dep14 tag and each commit to which we're deepening.
    # Commonly there will be no such merge commits.  In the case that there
    # are, we might be able to pull in less as follows: walk each ancestry
    # path and collect the other side of each merge, i.e., the side we don't
    # care about.  Pass a --shallow-exclude option to 'git fetch' for each.
    required_depth = 0

    def ancestry_depth(revision_range):
        ancestry_path = run_git(
            orig_repo_dir, "log", "--ancestry-path", "--pretty=%H",
            revision_range)
        return len(ancestry_path.stdout.splitlines())

    if utag_name:
        fetch_tag(utag_name, 1)
        udepth = ancestry_depth(f"{utag_name}..{dtag_name}")
        if udepth > 0:
            required_depth = udepth
        elif ancestry_depth(f"{dtag_name}..{utag_name}") > 0:
            raise Exception("Upstream tag is descendant of dep14 tag?!")

    if deepen_to:
        for committish in deepen_to:
            cdepth = ancestry_depth(f"{committish}..{dtag_name}")
            if cdepth == 0:
                raise Exception(
                    f"Expected {committish} to be ancestor of {dtag_name}")
            elif cdepth > required_depth:
                required_depth = cdepth

    required_depth += 1
    stderrln(f">> Copying {dtag_name} to depth {required_depth}")
    fetch_tag(dtag_name, required_depth)

    out_repo.remotes.delete("origin")

    # Begin by unpacking everything.
    # There is no way to prevent the occurrence of packing when fetching, so
    # we have to do this repacking afterwards.
    pack = os.path.join(out_repo_dir, "objects", "pack")
    handles = []
    try:
        for name in os.listdir(path=pack):
            path = os.path.join(pack, name)
            if path.endswith(".pack"):
                # Binary mode: the handle is only passed on as the stdin
                # of git-unpack-objects(1).
                handles.append(open(path, "rb"))
            # We have to unlink now because git-unpack-objects(1) will
            # skip over objects already in the repository.
            os.unlink(path)
        for h in handles:
            subprocess.run(["git", "-C", out_repo_dir, "unpack-objects"],
                           stdin=h, check=True)
    finally:
        for h in handles:
            h.close()

    # Now we pack everything that 'mini-git-tag-fsck --audit' *doesn't* need.
    needed = {str(oid) for oid in
              [orig_dtag.id, orig_dtag_obj.id, orig_dtag_obj.tree_id]}
    for line in run_git(out_repo_dir, "ls-tree", "-r", "-t", "--object-only",
                        str(orig_dtag_obj.tree_id)).stdout.splitlines():
        needed.add(line.decode())
    all_objs = run_git(out_repo_dir, "cat-file",
                       "--batch-check=%(objectname)", "--batch-all-objects",
                       "--unordered").stdout.splitlines()
    def delegate(builder):
        for obj in all_objs:
            obj = obj.decode()
            if obj not in needed:
                builder.add(pygit2.Oid(hex=obj))
    stderr(">> Re-packing ... ")
    packed = out_repo.pack(pack_delegate=delegate, n_threads=0) # ncpus
    stderrln(f"packed {packed} objects not required for --audit")
    run_git(out_repo_dir, "prune-packed")

    tar = f"{source}_{fn_version}.git.tar.xz"
    # NOTE(review): this default is relative to the current directory, not
    # to REPO; the two coincide only when called from inside REPO.
    default_out_tar = os.path.abspath(f"../{tar}")
    out_tar = out_tar or default_out_tar

    stderr(">> Starting compression ... ")
    # Closing the TarFile writes the end-of-archive blocks and flushes the
    # xz stream; don't leave that to the garbage collector.
    with tarfile.open(name=out_tar, mode="w:xz") as tf:
        tf.add(out_repo_dir, arcname=pkg_dir_name)
    stderrln("done")

    # Write to d/files such that dpkg-genchanges will pick up our .git.tar.xz.
    # Note that this is ignored by 'dpkg-genchanges -S'.
    if dfiles:
        dctrl = os.path.join(orig_repo_dir, "debian", "control")
        dfiles_path = os.path.join(orig_repo_dir, "debian", "files")
        with open(dctrl, "r") as f:
            control = email.message_from_file(f)
        with open(dfiles_path, "a") as f:
            line = f'{tar} {control["Section"]} {control["Priority"]}'
            stderrln(f">> Appending '{line}' to debian/files")
            print(line, file=f)

def dpkg_parsechangelog(dir):
    """Return the (Source, Version) pair for the Debian package in DIR.

    Runs dpkg-parsechangelog(1) with DIR as its working directory and
    parses its RFC 822-style output.  A field that is absent from the output
    comes back as None.  Raises subprocess.CalledProcessError if the command
    fails.
    """
    completed = subprocess.run("dpkg-parsechangelog", cwd=dir,
                               stdout=subprocess.PIPE, check=True)
    fields = email.message_from_bytes(completed.stdout)
    return fields["Source"], fields["Version"]

def dep14_version(ver):
    """Mangle Debian version VER into the form DEP-14 uses in tag names.

    ':' becomes '%' and '~' becomes '_'.  A '#' is inserted after any '.'
    that is followed by another '.', ends the string, or is followed by a
    final 'lock'.
    """
    mangled = ver.replace(":", "%").replace("~", "_")
    return re.sub(r"\.(?=\.|$|lock$)", ".#", mangled)

# This is for when pygit2 doesn't wrap something.
def run_git(repo, *args):
    """Run 'git -C REPO ARGS...' and return the CompletedProcess.

    Standard output is captured as bytes.  Raises
    subprocess.CalledProcessError if git exits non-zero.
    """
    argv = ["git", "-C", repo]
    argv.extend(args)
    return subprocess.run(argv, stdout=subprocess.PIPE, check=True)

def stderr(arg):
    """Write ARG to standard error with no newline, and flush immediately."""
    stream = sys.stderr
    stream.write(arg)
    stream.flush()

def stderrln(*args):
    """Print ARGS to standard error, space-separated and newline-terminated."""
    print(*args, sep=" ", end="\n", file=sys.stderr)

if __name__ == "__main__":
    # Command-line entry point: parse options, warn about options that the
    # chosen mode ignores, then dispatch to audit() or prepare().
    parser = argparse.ArgumentParser(
        description="\
        Handling of tag2upload SOURCE_VERSION.git.tar.xz archives")

    # Can't pass a title= directly to add_mutually_exclusive_group.
    desc = parser.add_argument_group(title="mode of operation")
    mode = desc.add_mutually_exclusive_group(required=True)
    mode.add_argument(
        "--audit", metavar="TARBALL",
        help="print 'git ls-files -z'-style output, extract tag text & sig")
    mode.add_argument(
        "--prepare", metavar="REPO",
        help="produce .git.tar.xz for upload (called by t2u server-side)",
        nargs="?", const=".")

    files = parser.add_argument_group(
        title="optional arguments for output file names")
    files.add_argument(
        "--out-txt", metavar="FILE",
        help="in --audit mode, file for the tag's text")
    files.add_argument(
        "--out-sig", metavar="FILE",
        help="in --audit mode, file for the tag's signature")
    files.add_argument(
        "--out-tar", metavar="FILE",
        help="in --prepare mode, file for the .git.tar.xz")

    prepare_args = parser.add_argument_group(
        title="optional arguments for --prepare mode")
    prepare_args.add_argument(
        "--deepen-to", metavar="COMMITTISH", nargs="*",
        help="deepen the clone down to each COMMITTISH")
    prepare_args.add_argument(
        "--upstream", metavar="TAG",
        help="the upstream= value from the maintainer's tag")
    prepare_args.add_argument(
        "--upstream-commit", metavar="COMMIT",
        help="the upstream-commit= value from the maintainer's tag")
    # A plain on/off flag.  Without store_true, argparse would demand a
    # value after --not-for-upload.
    prepare_args.add_argument(
        "--not-for-upload", action="store_true",
        help="do not write to debian/files")

    other_args = parser.add_argument_group(title="other optional arguments")
    other_args.add_argument(
        "--distro", metavar="DISTRO",
        help="the distro name: 'debian' by default",
        nargs="?", default="debian", const="debian")

    args = parser.parse_args()

    if (args.upstream is None) != (args.upstream_commit is None):
        raise Exception("--upstream without --upstream-commit or vice-versa")
    if args.audit:
        if args.out_tar:
            stderrln("--out-tar argument ignored in --audit mode")
        if args.deepen_to:
            stderrln("--deepen-to argument ignored in --audit mode")
        if args.upstream:
            stderrln("--upstream argument ignored in --audit mode")
        if args.upstream_commit:
            stderrln("--upstream-commit argument ignored in --audit mode")
        if args.not_for_upload:
            stderrln("--not-for-upload argument ignored in --audit mode")
        audit(args.audit, args.distro, args.out_txt, args.out_sig)
    elif args.prepare:
        if args.out_txt:
            stderrln("--out-txt argument ignored in --prepare mode")
        if args.out_sig:
            stderrln("--out-sig argument ignored in --prepare mode")
        prepare(args.prepare, args.distro, args.out_tar, args.deepen_to,
                args.upstream, args.upstream_commit, not args.not_for_upload)
