#!/usr/bin/env python3
# Copyright 2016-2025 Google Inc.
# Copyright 2025 AFLplusplus Project. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import array
import base64
import collections
import ctypes
import errno
import glob
import hashlib
import itertools
import logging
import multiprocessing
import os
import random
import shutil
import subprocess
import sys
import uuid


logger = logging.getLogger(__name__)


# Vendored from the more-itertools "batched" recipe:
# https://more-itertools.readthedocs.io/en/stable/_modules/more_itertools/recipes.html#batched
def _batched(iterable, n, *, strict=False):
    """Batch data into tuples of length *n*. If the number of items in
    *iterable* is not divisible by *n*:
    * The last batch will be shorter if *strict* is ``False``.
    * :exc:`ValueError` will be raised if *strict* is ``True``.

    >>> list(batched('ABCDEFG', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

    On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
    """
    if n < 1:
        raise ValueError("n must be at least one")
    iterator = iter(iterable)
    while batch := tuple(itertools.islice(iterator, n)):
        if strict and len(batch) != n:
            raise ValueError("batched(): incomplete batch")
        yield batch


# itertools.batched() gained the ``strict`` argument in Python 3.13
# (0x30D00A2 == 3.13.0a2); prefer the stdlib version when available.
if sys.hexversion >= 0x30D00A2:  # pragma: no cover
    from itertools import batched as itertools_batched

    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    batched.__doc__ = _batched.__doc__

else:
    batched = _batched

try:
    from tqdm import tqdm
except ImportError:
    print('Hint: install the python module "tqdm" to show a progress bar')

    class tqdm:

        def __init__(self, data=None, *args, **argd):
            self.data = data

        def __iter__(self):
            yield from self.data

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            pass

        def update(self, *args):
            pass


def init_logger(args):
    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        level=log_level, format="%(asctime)s - %(levelname)s - %(message)s"
    )


class HelpFormatter(argparse.HelpFormatter):
    def __init__(self, prog, *args, **kargs):
        super().__init__(prog, *args, **kargs)
        self.add_text("corpus minimization tool for AFL++ (python version)")
        self.add_text("")
        self.add_text("%s" % prog)


def init_args():
    parser = argparse.ArgumentParser(formatter_class=HelpFormatter)

    cpu_count = multiprocessing.cpu_count()
    group = parser.add_argument_group("Required parameters")
    group.add_argument(
        "-i",
        dest="input",
        action="append",
        metavar="dir",
        required=True,
        help="input directory with the starting corpus",
    )
    group.add_argument(
        "-o",
        dest="output",
        metavar="dir",
        required=True,
        help="output directory for minimized files",
    )

    group = parser.add_argument_group("Execution control settings")
    group.add_argument(
        "-f",
        dest="stdin_file",
        metavar="file",
        help="location read by the fuzzed program (stdin)",
    )
    group.add_argument(
        "-m",
        dest="memory_limit",
        default="none",
        metavar="megs",
        type=lambda x: x if x == "none" else int(x),
        help="memory limit for child process (default: %(default)s)",
    )
    group.add_argument(
        "-t",
        dest="time_limit",
        default=5000,
        metavar="msec",
        type=lambda x: x if x == "none" else int(x),
        help="timeout for each run (default: %(default)s)",
    )
    group.add_argument(
        "-O",
        dest="frida_mode",
        action="store_true",
        default=False,
        help="use binary-only instrumentation (FRIDA mode)",
    )
    group.add_argument(
        "-Q",
        dest="qemu_mode",
        action="store_true",
        default=False,
        help="use binary-only instrumentation (QEMU mode)",
    )
    group.add_argument(
        "-U",
        dest="unicorn_mode",
        action="store_true",
        default=False,
        help="use unicorn-based instrumentation (Unicorn mode)",
    )
    group.add_argument(
        "-X", dest="nyx_mode", action="store_true", default=False, help="use Nyx mode"
    )

    group = parser.add_argument_group("Minimization settings")
    group.add_argument(
        "--crash-dir",
        dest="crash_dir",
        metavar="dir",
        default=None,
        help="move crashes to a separate dir, always deduplicated",
    )
    group.add_argument(
        "-A",
        dest="allow_any",
        action="store_true",
        help="allow crashes and timeouts (not recommended)",
    )
    group.add_argument(
        "-C",
        dest="crash_only",
        action="store_true",
        help="keep crashing inputs, reject everything else",
    )
    group.add_argument(
        "-e",
        dest="edge_mode",
        action="store_true",
        default=False,
        help="solve for edge coverage only, ignore hit counts",
    )

    group = parser.add_argument_group("Misc")
    group.add_argument(
        "-T",
        dest="workers",
        type=lambda x: cpu_count if x == "all" else int(x),
        default=1,
        help="number of concurrent worker (default: %(default)d)",
    )
    group.add_argument(
        "--as_queue",
        action="store_true",
        help='output file name like "id:000000,hash:value"',
    )
    group.add_argument(
        "--no-dedup",
        action="store_true",
        help="skip deduplication step for corpus files",
    )
    group.add_argument("--debug", action="store_true")

    parser.add_argument("exe", metavar="/path/to/target_app")
    parser.add_argument("args", nargs="*")
    return parser.parse_args()


def get_asan_options():
    asan_options = "abort_on_error=1:symbolize=0:detect_leaks=0"
    user_options = os.environ.get("ASAN_OPTIONS")
    if user_options:
        asan_options += ":" + user_options
    return asan_options


def search_binary(name):
    searches = [
        None,
        os.path.dirname(__file__),
        os.getcwd(),
    ]
    if os.environ.get("AFL_PATH"):
        searches.append(os.environ["AFL_PATH"])

    for search in searches:
        binary = shutil.which(name, path=search)
        if binary:
            return binary
    logger.critical("cannot find %s, please set AFL_PATH", name)
    sys.exit(1)


def init(args):
    if args.stdin_file and args.workers > 1:
        logger.error("-f is only supported with one worker (-T 1)")
        sys.exit(1)

    if args.memory_limit != "none" and args.memory_limit < 5:
        logger.error("dangerously low memory limit")
        sys.exit(1)

    if args.time_limit != "none" and args.time_limit < 10:
        logger.error("dangerously low timeout")
        sys.exit(1)

    if not args.nyx_mode and not os.path.isfile(args.exe):
        logger.error('binary "%s" not found or not regular file', args.exe)
        sys.exit(1)

    if not os.environ.get("AFL_SKIP_BIN_CHECK") and not any(
        [args.qemu_mode, args.frida_mode, args.unicorn_mode, args.nyx_mode]
    ):
        if b"__AFL_SHM_ID" not in open(args.exe, "rb").read():
            logger.error("binary '%s' doesn't appear to be instrumented", args.exe)
            sys.exit(1)

    for dn in args.input:
        if not os.path.isdir(dn) and not glob.glob(dn):
            logger.error('directory "%s" not found', dn)
            sys.exit(1)

    trace_dir = os.path.join(args.output, ".traces")
    shutil.rmtree(trace_dir, ignore_errors=True)
    try:
        os.rmdir(args.output)
    except OSError:
        pass
    if os.path.exists(args.output):
        logger.error(
            'directory "%s" exists and is not empty - delete it first', args.output
        )
        sys.exit(1)
    if args.crash_dir and not os.path.exists(args.crash_dir):
        os.makedirs(args.crash_dir)
    os.makedirs(trace_dir)

    logger.info("use %d workers (-T)", args.workers)


def detect_type_code(size):
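    # Choose the smallest unsigned array type code ("B", "H", "I", "L", "Q")
    # whose value range exceeds `size`; e.g. detect_type_code(1000) == "H".
    # Falls through to None if nothing fits (practically unreachable).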
    for type_code in ["B", "H", "I", "L", "Q"]:
        if 256 ** array.array(type_code).itemsize > size:
            return type_code


def get_nyx_map_size(target_dir):
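    """Query libnyx for the coverage bitmap size by spinning up a throwaway
    standalone Nyx runner and immediately shutting it down."""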
    libnyx_path = search_binary("libnyx.so")
    libnyx = ctypes.CDLL(libnyx_path)

    NYX_ROLE_StandAlone = 0

    target_dir_c = target_dir.encode("utf-8")

    dummy_workdir_path = "/tmp/_afl_cmin_nyx_work_dir_%s" % (str(uuid.uuid4()))
    dummy_workdir_path_c = dummy_workdir_path.encode("utf-8")

    # nyx_config_load
    libnyx.nyx_config_load.argtypes = [ctypes.c_char_p]
    libnyx.nyx_config_load.restype = ctypes.c_void_p
    nyx_config = libnyx.nyx_config_load(target_dir_c)

    # nyx_config_set_workdir_path
    libnyx.nyx_config_set_workdir_path.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    libnyx.nyx_config_set_workdir_path.restype = None
    libnyx.nyx_config_set_workdir_path(nyx_config, dummy_workdir_path_c)

    # nyx_config_set_process_role
    libnyx.nyx_config_set_process_role.argtypes = [ctypes.c_void_p, ctypes.c_int]
    libnyx.nyx_config_set_process_role.restype = None
    libnyx.nyx_config_set_process_role(nyx_config, NYX_ROLE_StandAlone)

    # nyx_new
    libnyx.nyx_new.argtypes = [ctypes.c_void_p, ctypes.c_int]
    libnyx.nyx_new.restype = ctypes.c_void_p
    nyx_runner = libnyx.nyx_new(nyx_config, 0)

    # nyx_get_bitmap_buffer_size
    libnyx.nyx_get_bitmap_buffer_size.argtypes = [ctypes.c_void_p]
    libnyx.nyx_get_bitmap_buffer_size.restype = ctypes.c_int
    map_size = libnyx.nyx_get_bitmap_buffer_size(nyx_runner)

    # nyx_shutdown
    libnyx.nyx_shutdown.argtypes = [ctypes.c_void_p]
    libnyx.nyx_shutdown.restype = None
    libnyx.nyx_shutdown(nyx_runner)

    # nyx_remove_work_dir
    libnyx.nyx_remove_work_dir.argtypes = [ctypes.c_char_p]
    libnyx.nyx_remove_work_dir.restype = None
    libnyx.nyx_remove_work_dir(dummy_workdir_path_c)

    return map_size


def afl_showmap(
    args,
    afl_showmap_bin,
    tuple_index_type_code,
    input_path=None,
    batch=None,
    afl_map_size=None,
    first=False,
):
    assert input_path or batch
    # yapf: disable
    cmd = [
        afl_showmap_bin,
        "-m", str(args.memory_limit),
        "-t", str(args.time_limit),
        "-Z",  # cmin mode
    ]
    # yapf: enable
    placeholder = os.environ.get("AFL_INPUT_PLACEHOLDER", "@@")
    found_atat = any(placeholder in arg for arg in args.args)

    if args.stdin_file:
        assert args.workers == 1
        input_from_file = True
        stdin_file = args.stdin_file
        cmd += ["-H", stdin_file]
    elif found_atat:
        input_from_file = True
        stdin_file = os.path.join(args.output, f".input.{os.getpid()}")
        cmd += ["-H", stdin_file]
    else:
        input_from_file = False

    if batch:
        input_from_file = True
        filelist = os.path.join(args.output, f".filelist.{os.getpid()}")
        temp_dir = os.path.join(args.output, f".filelist.{os.getpid()}.d")
        os.makedirs(temp_dir, exist_ok=True)
        temp_entries = []
        with open(filelist, "w") as f:
            for idx, path in batch:
                base = os.path.basename(path)
                unique = f"{random.getrandbits(32):08x}_{base}"
                temp_path = os.path.join(temp_dir, unique)
                try:
                    # hard-link into the temp dir; fall back to a copy
                    # (e.g. across filesystems)
                    os.link(path, temp_path)
                except OSError:
                    shutil.copy(path, temp_path)
                temp_entries.append((idx, unique))
                f.write(temp_path + "\n")
        cmd += ["-I", filelist]
        output_path = os.path.join(args.output, f".showmap.{os.getpid()}")
        cmd += ["-o", output_path]
    else:
        if input_from_file:
            shutil.copy(input_path, stdin_file)
        cmd += ["-o", "-"]

    if args.frida_mode:
        cmd += ["-O"]
    if args.qemu_mode:
        cmd += ["-Q"]
    if args.unicorn_mode:
        cmd += ["-U"]
    if args.nyx_mode:
        cmd += ["-X"]
    if args.edge_mode:
        cmd += ["-e"]
    cmd += ["--", args.exe] + args.args

    env = os.environ.copy()
    env["AFL_QUIET"] = "1"
    env["ASAN_OPTIONS"] = get_asan_options()
    if first:
        logger.debug("run command line: %s", subprocess.list2cmdline(cmd))
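        # accept any exit status on the first run so the instrumentation
        # check still collects a map even if the seed input crashes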
        env["AFL_CMIN_ALLOW_ANY"] = "1"
    if afl_map_size:
        env["AFL_MAP_SIZE"] = str(afl_map_size)
    if args.crash_only:
        env["AFL_CMIN_CRASHES_ONLY"] = "1"
    if args.allow_any:
        env["AFL_CMIN_ALLOW_ANY"] = "1"

    if input_from_file:
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env, bufsize=1048576)
    else:
        p = subprocess.Popen(
            cmd,
            stdin=open(input_path, "rb"),
            stdout=subprocess.PIPE,
            env=env,
            bufsize=1048576,
        )
    out = p.stdout.read()
    p.wait()

    if batch:
        result = []
        for idx, unique in temp_entries:
            values = []
            try:
                trace_file = os.path.join(output_path, unique)
                with open(trace_file, "r") as f:
                    values = list(map(int, f))
                crashed = len(values) == 0
                os.unlink(trace_file)
            except FileNotFoundError:
                crashed = True
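            # afl-showmap -Z prints each tuple as edge_id * 1000 + hit-count
            # class; fold that into the dense index edge_id * 9 + class.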
            values = [(t // 1000) * 9 + t % 1000 for t in values]
            a = array.array(tuple_index_type_code, values)
            result.append((idx, a, crashed))
        os.unlink(filelist)
        os.rmdir(output_path)
        for _, unique in temp_entries:
            try:
                os.unlink(os.path.join(temp_dir, unique))
            except FileNotFoundError:
                pass
        os.rmdir(temp_dir)
        return result
    else:
        values = []
        # split by newline to avoid issues with Nyx mode
        for line in out.split(b"\n"):
            if not line.isdigit():
                continue
            values.append(int(line))
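        # same edge_id * 1000 + class folding as in the batch path above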
        values = [(t // 1000) * 9 + t % 1000 for t in values]
        a = array.array(tuple_index_type_code, values)
        crashed = p.returncode in [2, 3]
        if input_from_file and stdin_file != args.stdin_file:
            os.unlink(stdin_file)
        return a, crashed


class JobDispatcher(multiprocessing.Process):
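    """Feed work batches (including the per-worker None sentinels) into the
    shared job queue from a separate process, while the main process drains
    the progress queue."""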

    def __init__(self, args, job_queue, jobs):
        super().__init__()
        self.args = args
        self.job_queue = job_queue
        self.jobs = jobs

    def run(self):
        init_logger(self.args)
        for job in self.jobs:
            self.job_queue.put(job)
        self.job_queue.close()


class Worker(multiprocessing.Process):
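    """Run afl-showmap over batches of input files, tracking per tuple the
    smallest covering file index, global hit counts, and crashing inputs."""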

    def __init__(
        self,
        args,
        idx,
        afl_map_size,
        q_in,
        p_out,
        r_out,
        file_index_type_code,
        tuple_index_type_code,
        afl_showmap_bin,
    ):
        super().__init__()
        self.args = args
        self.idx = idx
        self.afl_map_size = afl_map_size
        self.q_in = q_in
        self.p_out = p_out
        self.r_out = r_out
        self.file_index_type_code = file_index_type_code
        self.tuple_index_type_code = tuple_index_type_code
        self.afl_showmap_bin = afl_showmap_bin

    def run(self):
        init_logger(self.args)
        map_size = self.afl_map_size or 65536
        max_tuple = map_size * 9
        max_file_index = 256 ** array.array(self.file_index_type_code).itemsize - 1
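        # m[t] holds the smallest file index seen covering tuple t (files are
        # sorted by size, so a smaller index means a smaller file);
        # max_file_index doubles as the "not seen yet" sentinel.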
        m = array.array(self.file_index_type_code, [max_file_index] * max_tuple)
        counter = collections.Counter()
        crashes = []

        pack_name = os.path.join(self.args.output, ".traces", f"{self.idx}.pack")
        pack_pos = 0
        with open(pack_name, "wb") as trace_pack:
            while True:
                batch = self.q_in.get()
                if batch is None:
                    break

                for idx, r, crash in afl_showmap(
                    self.args,
                    self.afl_showmap_bin,
                    batch=batch,
                    afl_map_size=self.afl_map_size,
                    tuple_index_type_code=self.tuple_index_type_code,
                ):
                    counter.update(r)

                    used = False

                    if crash:
                        crashes.append(idx)

                    # If we aren't saving crashes to a separate dir, handle them
                    # the same as other inputs. However, unless AFL_CMIN_ALLOW_ANY=1,
                    # afl_showmap will not return any coverage for crashes so they will
                    # never be retained.
                    if not crash or not self.args.crash_dir:
                        for t in r:
                            if idx < m[t]:
                                m[t] = idx
                                used = True

                    if used:
                        tuple_count = len(r)
                        r.tofile(trace_pack)
                        self.p_out.put((idx, self.idx, pack_pos, tuple_count))
                        pack_pos += tuple_count * r.itemsize
                    else:
                        self.p_out.put(None)

        self.r_out.put((self.idx, m, counter, crashes))


class CombineTraceWorker(multiprocessing.Process):
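    """Re-read the packed traces of the files already chosen as mandatory and
    merge their tuples into one set, returned via the result queue."""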

    def __init__(self, args, pack_name, jobs, r_out, tuple_index_type_code):
        super().__init__()
        self.args = args
        self.pack_name = pack_name
        self.jobs = jobs
        self.r_out = r_out
        self.tuple_index_type_code = tuple_index_type_code

    def run(self):
        init_logger(self.args)
        already_have = set()
        with open(self.pack_name, "rb") as f:
            for pos, tuple_count in self.jobs:
                f.seek(pos)
                result = array.array(self.tuple_index_type_code)
                result.fromfile(f, tuple_count)
                already_have.update(result)
        self.r_out.put(already_have)


def hash_file(path):
    m = hashlib.sha1()
    with open(path, "rb") as f:
        m.update(f.read())
    return m.digest()


def dedup(args, files):
    seen_hash = set()
    result = []
    hash_list = []
    if args.workers <= 1:
        for i, h in enumerate(
            tqdm(
                map(hash_file, files),
                desc="dedup",
                total=len(files),
                ncols=0,
                leave=(len(files) > 100000),
            )
        ):
            if h in seen_hash:
                continue
            seen_hash.add(h)
            result.append(files[i])
            hash_list.append(h)
        return result, hash_list
    with multiprocessing.Pool(
        args.workers, initializer=init_logger, initargs=(args,)
    ) as pool:
        # use large chunksize to reduce multiprocessing overhead
        chunksize = max(1, min(256, len(files) // args.workers))
        for i, h in enumerate(
            tqdm(
                pool.imap(hash_file, files, chunksize),
                desc="dedup",
                total=len(files),
                ncols=0,
                leave=(len(files) > 100000),
            )
        ):
            if h in seen_hash:
                continue
            seen_hash.add(h)
            result.append(files[i])
            hash_list.append(h)
        return result, hash_list


def is_afl_dir(dirnames, filenames):
    return (
        "queue" in dirnames
        and "hangs" in dirnames
        and "crashes" in dirnames
        and "fuzzer_setup" in filenames
    )


def collect_files(args):
    paths = []
    for s in args.input:
        paths += glob.glob(s)

    files = []
    with tqdm(desc="search", unit=" files", ncols=0) as pbar:
        for path in paths:
            for root, dirnames, filenames in os.walk(path, followlinks=True):
                # prune hidden directories in place so os.walk skips them
                # (removing items while iterating the list would skip entries)
                dirnames[:] = [d for d in dirnames if not d.startswith(".")]

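                # Files at the top level of an AFL++ output directory are
                # metadata (fuzzer_setup etc.), not test cases; skip them
                # unless -C is set. os.walk still descends into queue/,
                # crashes/ and hangs/.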
                if not args.crash_only and is_afl_dir(dirnames, filenames):
                    continue

                for filename in filenames:
                    if filename.startswith("."):
                        continue
                    full_path = os.path.join(root, filename)
                    if not os.path.isfile(full_path):
                        continue
                    pbar.update(1)
                    files.append(full_path)
    return files


def main():
    afl_showmap_bin = None
    file_index_type_code = None
    tuple_index_type_code = "I"

    args = init_args()

    init_logger(args)
    init(args)

    afl_showmap_bin = search_binary("afl-showmap")

    files = collect_files(args)
    if len(files) == 0:
        logger.error("no inputs in the target directory - nothing to be done")
        sys.exit(1)
    logger.info("Found %d input files in %d directories", len(files), len(args.input))

    if not args.no_dedup:
        files, hash_list = dedup(args, files)
        logger.info("Remain %d files after dedup", len(files))
    else:
        logger.info("Skipping file deduplication.")

    file_index_type_code = detect_type_code(len(files))

    logger.info("Sorting files.")
    if args.workers <= 1:
        size_list = list(map(os.path.getsize, files))
    else:
        with multiprocessing.Pool(
            args.workers, initializer=init_logger, initargs=(args,)
        ) as pool:
            chunksize = max(1, min(512, len(files) // args.workers))
            size_list = list(pool.map(os.path.getsize, files, chunksize))
    idxes = sorted(range(len(files)), key=lambda x: size_list[x])
    files = [files[idx] for idx in idxes]
    if not args.no_dedup:
        hash_list = [hash_list[idx] for idx in idxes]

    afl_map_size = None
    if "AFL_MAP_SIZE" in os.environ:
        afl_map_size = int(os.environ["AFL_MAP_SIZE"])
    elif args.nyx_mode:
        afl_map_size = get_nyx_map_size(args.exe)
        logger.info("Setting AFL_MAP_SIZE=%d", afl_map_size)
    elif b"AFL_DUMP_MAP_SIZE" in open(args.exe, "rb").read():
        output = subprocess.run(
            [args.exe],
            capture_output=True,
            env={
                **os.environ,
                "AFL_DUMP_MAP_SIZE": "1",
                "ASAN_OPTIONS": get_asan_options(),
            },
            check=False,
        ).stdout
        afl_map_size = int(output)
        logger.info("Setting AFL_MAP_SIZE=%d", afl_map_size)

    if afl_map_size:
        tuple_index_type_code = detect_type_code(afl_map_size * 9)

    logger.info("Testing the target binary")
    tuples, _ = afl_showmap(
        args,
        afl_showmap_bin,
        input_path=files[0],
        afl_map_size=afl_map_size,
        first=True,
        tuple_index_type_code=tuple_index_type_code,
    )
    if tuples:
        logger.info("ok, %d tuples recorded", len(tuples))
    else:
        logger.error("no instrumentation output detected")
        sys.exit(1)

    job_queue = multiprocessing.Queue()
    progress_queue = multiprocessing.Queue()
    result_queue = multiprocessing.Queue()

    workers = []
    for i in range(args.workers):
        p = Worker(
            args,
            i,
            afl_map_size,
            job_queue,
            progress_queue,
            result_queue,
            file_index_type_code,
            tuple_index_type_code,
            afl_showmap_bin,
        )
        p.start()
        workers.append(p)

    chunk = max(1, min(128, len(files) // args.workers))
    jobs = list(batched(enumerate(files), chunk))
    jobs += [None] * args.workers  # one stop sentinel per worker

    dispatcher = JobDispatcher(args, job_queue, jobs)
    dispatcher.start()

    logger.info("Processing traces")
    effective = 0
    trace_info = {}
    for _ in tqdm(files, ncols=0, smoothing=0.01):
        r = progress_queue.get()
        if r is not None:
            idx, worker_idx, pos, tuple_count = r
            trace_info[idx] = worker_idx, pos, tuple_count
            effective += 1
    dispatcher.join()

    logger.info("Obtaining trace results")
    ms = []
    crashes = []
    counter = collections.Counter()
    for _ in tqdm(range(args.workers), ncols=0):
        idx, m, c, crs = result_queue.get()
        ms.append(m)
        counter.update(c)
        crashes.extend(crs)
        workers[idx].join()
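    # Element-wise minimum across the workers' maps yields, for every tuple,
    # the globally smallest file index that covers it.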
    best_idxes = list(map(min, zip(*ms)))

    if not args.crash_dir:
        logger.info(
            "Found %d unique tuples across %d files (%d effective)",
            len(counter),
            len(files),
            effective,
        )
    else:
        logger.info(
            "Found %d unique tuples across %d files (%d effective, %d crashes)",
            len(counter),
            len(files),
            effective,
            len(crashes),
        )
    all_unique = counter.most_common()

    logger.info("Processing candidates and writing output")
    already_have = set()
    count = 0
    use_sha1_filenames = bool(os.environ.get("AFL_SHA1_FILENAMES"))
    hash_cache = {}
    used_output_names = set()

    def unique_output_path(base_name):
        output_path = os.path.join(args.output, base_name)
        if output_path not in used_output_names and not os.path.exists(output_path):
            used_output_names.add(output_path)
            return output_path
        for _ in range(10000):
            prefix = f"{random.getrandbits(32):08x}"
            candidate = os.path.join(args.output, f"{prefix}_{base_name}")
            if candidate not in used_output_names and not os.path.exists(candidate):
                used_output_names.add(candidate)
                return candidate
        raise RuntimeError(f'Unable to find unique output name for "{base_name}"')

    def get_sha1(idx, input_path):
        if not args.no_dedup:
            return hash_list[idx]
        if idx in hash_cache:
            return hash_cache[idx]
        h = hash_file(input_path)
        hash_cache[idx] = h
        return h

    def save_file(idx):
        input_path = files[idx]
        if use_sha1_filenames:
            fn = base64.b16encode(get_sha1(idx, input_path)).decode("utf8").lower()
        else:
            fn = os.path.basename(input_path)
        if args.as_queue:
            if use_sha1_filenames and not args.no_dedup:
                fn = "id:%06d,hash:%s" % (count, fn)
            else:
                fn = "id:%06d,orig:%s" % (count, fn)
        output_path = os.path.join(args.output, fn)
        use_orig_name = not use_sha1_filenames and not args.as_queue
        if use_orig_name:
            output_path = unique_output_path(fn)
        while True:
            try:
                os.link(input_path, output_path)
                break
            except OSError as exc:
                if use_orig_name and exc.errno == errno.EEXIST:
                    output_path = unique_output_path(fn)
                    continue
                if use_orig_name and os.path.exists(output_path):
                    output_path = unique_output_path(fn)
                shutil.copy(input_path, output_path)
                break

    jobs = [[] for _ in range(args.workers)]
    saved = set()
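    # Pass 1: any tuple hit by exactly one file makes that file mandatory.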
    for t, c in all_unique:
        if c != 1:
            continue
        idx = best_idxes[t]
        if idx in saved:
            continue
        save_file(idx)
        saved.add(idx)
        count += 1

        worker_idx, pos, tuple_count = trace_info[idx]
        job = (pos, tuple_count)
        jobs[worker_idx].append(job)

    trace_packs = []
    workers = []
    for i in range(args.workers):
        pack_name = os.path.join(args.output, ".traces", f"{i}.pack")
        trace_f = open(pack_name, "rb")
        trace_packs.append(trace_f)

        p = CombineTraceWorker(
            args, pack_name, jobs[i], result_queue, tuple_index_type_code
        )
        p.start()
        workers.append(p)

    for _ in range(args.workers):
        result = result_queue.get()
        already_have.update(result)

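    # Pass 2: walk the remaining tuples from rarest to most common, greedily
    # keeping the smallest file for each tuple not yet covered.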
    for t, c in tqdm(list(reversed(all_unique)), ncols=0):
        if t in already_have:
            continue

        idx = best_idxes[t]
        save_file(idx)
        count += 1

        worker_idx, pos, tuple_count = trace_info[idx]
        trace_pack = trace_packs[worker_idx]
        trace_pack.seek(pos)
        result = array.array(tuple_index_type_code)
        result.fromfile(trace_pack, tuple_count)

        already_have.update(result)

    for f in trace_packs:
        f.close()

    if args.crash_dir:
        logger.info("Saving crashes to %s", args.crash_dir)
        crash_files = [files[c] for c in crashes]

        if args.no_dedup:
            # The corpus was never deduplicated, so dedup the crash files now.
            crash_files, crash_hashes = dedup(args, crash_files)
        else:
            # Reuse the hashes from the corpus dedup; `crashes` holds indices
            # into the sorted corpus list, not into crash_files.
            crash_hashes = [hash_list[c] for c in crashes]

        for idx, crash_path in enumerate(crash_files):
            if use_sha1_filenames:
                fn = base64.b16encode(crash_hashes[idx]).decode("utf8").lower()
            else:
                fn = os.path.basename(crash_path)
            output_path = os.path.join(args.crash_dir, fn)
            try:
                os.link(crash_path, output_path)
            except OSError:
                try:
                    shutil.copy(crash_path, output_path)
                except shutil.Error:
                    # This error happens when src and dest are hardlinks of the
                    # same file. We have nothing to do in this case, but handle
                    # it gracefully.
                    pass

    if count == 1:
        logger.warning("all test cases had the same traces, check syntax!")
    logger.info('narrowed down to %s files, saved in "%s"', count, args.output)
    if not os.environ.get("AFL_KEEP_TRACES"):
        logger.info("Deleting trace files")
        trace_dir = os.path.join(args.output, ".traces")
        shutil.rmtree(trace_dir, ignore_errors=True)


if __name__ == "__main__":
    main()
