#!/usr/bin/env python3
"""
Tool to check for recursive calls in toxcore C code.

Usage:

cat toxav/*.c toxcore/*.c toxencryptsave/*.c \
  | clang `pkg-config --cflags libsodium opus vpx` \
      -Itoxav -Itoxcore -Itoxencryptsave -S -emit-llvm -xc - -o- \
  | opt -analyze -print-callgraph 2>&1 \
  | other/analysis/check_recursion
"""
import collections
import fileinput
import re
import sys
import time
from typing import Dict
from typing import List
from typing import Set


def load_callgraph() -> Dict:
    """
    Builds the call graph from the output of opt -print-callgraph.

    Lines are read through fileinput, i.e. from the files named in argv or
    from stdin.

    Returns a dict mapping each calling function to the sorted list of
    functions it calls. Functions for which no outgoing call was seen get no
    entry.
    """
    node_re = re.compile("Call graph node for function: '(.*)'")
    call_re = re.compile("calls function '(.*)'")

    edges = collections.defaultdict(set)
    caller = None
    for text in fileinput.input():
        node = node_re.search(text)
        if node:
            # A new node header: following call lines belong to this function.
            caller = node.group(1)
        if not caller:
            # Call lines seen before the first function header are ignored.
            continue
        call = call_re.search(text)
        if call:
            edges[caller].add(call.group(1))

    return {name: sorted(callees) for name, callees in edges.items()}


def walk(
        visited: Set,
        callgraph: Dict,
        cycles: Set,
        stack: List,
        cur: str,
) -> None:
    """
    Depth-first search from cur that records every call cycle it closes.

    visited holds functions that were already explored as a callee; walking
    one of them again returns immediately. stack is the chain of calls
    currently being followed; it is extended with cur for the duration of the
    call and restored before returning. Each cycle is added to cycles as a
    string of the form "f -> g -> f".
    """
    if cur in visited:
        return

    stack.append(cur)
    for callee in callgraph.get(cur, ()):
        if callee in stack:
            # The callee is already on the current call chain: the part of
            # the chain starting at it, plus this call, forms a cycle.
            loop = stack[stack.index(callee):] + [callee]
            cycles.add(" -> ".join(loop))
            continue
        walk(visited, callgraph, cycles, stack, callee)
        visited.add(callee)
    stack.pop()


def get_time() -> int:
    """
    Current wall-clock time, rounded to whole milliseconds.
    """
    millis = time.time() * 1000
    return int(round(millis))


def find_recursion(expected: Set) -> None:
    """
    Main function: detects cycles and prints them.

    Loads the call graph, walks it from every function in sorted order and
    prints each cycle found, together with timing information for the two
    phases.

    expected is the set of cycles (formatted as "f -> g -> f") that are known
    and accepted. It is not modified.

    Exits the program with status 1 if any of the expected cycles was not
    found, or any unexpected cycle was found; otherwise returns normally.
    """
    start = prev = get_time()
    print("[+0000=0000] Generating callgraph")
    callgraph = load_callgraph()

    now = get_time()
    print("[+%04d=%04d] Finding recursion" % (now - prev, now - start))
    prev = now

    cycles = set()
    visited = set()
    for func in sorted(callgraph):
        walk(visited, callgraph, cycles, [], func)

    # Track outstanding expectations on a private copy so the caller's set is
    # left untouched.
    missing = set(expected)
    unexpected = []

    now = get_time()
    if cycles:
        print("[+%04d=%04d] Recursion detected:" % (now - prev, now - start))
        for cycle in sorted(cycles):
            if cycle in missing:
                print(" - " + cycle + " (expected)")
                missing.remove(cycle)
            else:
                print(" - " + cycle)
                unexpected.append(cycle)
    else:
        print("[+%04d=%04d] No recursion detected" % (now - prev, now - start))

    if missing:
        # Sorted so the message is identical from run to run.
        print("Expected recursion no longer present: " + str(sorted(missing)))
    if missing or unexpected:
        sys.exit(1)


# Only run the check when executed as a script, so that importing this module
# neither consumes stdin/argv nor exits the interpreter.
if __name__ == "__main__":
    find_recursion(expected={
        "add_to_closest -> add_to_closest",
        "add_to_list -> add_to_list",
    })
