#!/usr/bin/env python3
# pylint: disable=C0103,C0114,C0115,C0116,C0123,C0209,C0301,C0302,R0902,R0904,R0913,R0914,R0912,R0915,W0511,W0621
######################################################################

import argparse
import glob
import os
import re
import sys
import textwrap

# from pprint import pprint, pformat


# This class is used to represent both AstNode and DfgVertex sub-types
class Node:
    """One class of a type hierarchy, linked to its super-class and sub-classes.

    A hierarchy is first built with the pre-completion methods ('addSubClass',
    'addOp', 'addPtr'), then frozen by calling 'complete' on the root. Most
    derived properties assert that completion has happened before being read.
    """

    def __init__(self, name, superClass, file=None, lineno=None):
        # Identity and position in the hierarchy
        self._name = name
        self._superClass = superClass
        self._subClasses = []  # A list while building, a tuple once completed
        # Location of the definition
        self._file = file  # File this class is defined in
        self._lineno = lineno  # Line this class is defined on
        # Values assigned by 'complete'
        self._typeId = None  # Concrete type identifier number (leaf classes only)
        self._ordIdx = None  # Ordering index of this class
        # Caches, filled in on first use after completion
        self._allSuperClasses = None
        self._allSubClasses = None
        self._typeIdMin = None  # Lowest type identifier number for class
        self._typeIdMax = None  # Highest type identifier number for class
        # Operands and pointer members
        self._arity = -1  # Arity of node
        self._ops = {}  # Operand number -> (name, monad, kind, legals)
        self._ptrs = []  # Pointer members of node, one dict each
        self._disableBase = False  # Skip base class const checking

    @property
    def name(self):
        return self._name

    @property
    def superClass(self):
        return self._superClass

    @property
    def isRoot(self):
        # The root is the only class without a super-class
        return self._superClass is None

    @property
    def isCompleted(self):
        # 'complete' converts the sub-class list to a tuple, which is the marker
        return isinstance(self._subClasses, tuple)

    @property
    def disableBase(self):
        return self._disableBase

    def disableBaseSet(self):
        self._disableBase = True

    @property
    def file(self):
        return self._file

    @property
    def lineno(self):
        return self._lineno

    @property
    def ptrs(self):
        assert self.isCompleted
        return self._ptrs

    # Pre completion methods
    def addSubClass(self, subClass):
        """Register 'subClass' as a direct sub-class. Only legal before completion."""
        assert not self.isCompleted
        self._subClasses.append(subClass)

    def addOp(self, n, name, monad, kind, legals):
        """Define operand number 'n' (1 to 4) on this class and raise the arity to at least 'n'."""
        assert 1 <= n <= 4
        self._ops[n] = (name, monad, kind, legals)
        if n > self._arity:
            self._arity = n

    def getOp(self, n):
        """Return the (name, monad, kind, legals) tuple of operand 'n'.

        The nearest definition wins: this class first, then each super-class
        in turn. Returns None if no class on the path to the root defines it.
        """
        assert 1 <= n <= 4
        cls = self
        while cls is not None:
            found = cls._ops.get(n)  # pylint: disable=protected-access
            if found is not None:
                return found
            cls = cls.superClass
        return None

    def addPtr(self, name, monad, kind, legals):
        """Record a pointer member; a leading 'm_' is dropped from the stored name."""
        self._ptrs.append({
            'name': re.sub(r'^m_', '', name),
            'monad': monad,
            'kind': kind,
            'legals': legals,
        })

    # Computes derived properties over entire class hierarchy.
    # No more changes to the hierarchy are allowed once this was called
    def complete(self, typeId=0, ordIdx=0):
        """Freeze this class and, recursively, all of its sub-classes.

        'typeId' is the next free leaf type identifier and 'ordIdx' the next
        free ordering index. Returns the updated (typeId, ordIdx) pair.
        """
        assert not self.isCompleted

        def sortKey(node):
            # Sub-classes without children of their own first, then by name
            return (bool(node._subClasses), node.name)  # pylint: disable=protected-access

        # Converting to a tuple is what marks this class as completed
        self._subClasses = tuple(sorted(self._subClasses, key=sortKey))

        self._ordIdx = ordIdx
        nextOrdIdx = ordIdx + 1

        # The arity is never lower than that of the (already completed) super-class
        self._arity = 0 if self.isRoot else max(self._arity, self._superClass.arity)

        # Leaves consume one type identifier
        if self.isLeaf:
            self._typeId = typeId
            return typeId + 1, nextOrdIdx

        # Non-leaves thread both counters through their sub-classes
        nextTypeId = typeId
        for child in self._subClasses:
            nextTypeId, nextOrdIdx = child.complete(nextTypeId, nextOrdIdx)
        return nextTypeId, nextOrdIdx

    # Post completion methods
    @property
    def subClasses(self):
        assert self.isCompleted
        return self._subClasses

    @property
    def isLeaf(self):
        assert self.isCompleted
        return len(self.subClasses) == 0

    @property
    def allSuperClasses(self):
        """Tuple of all super-classes, ordered from the root down to the direct super-class."""
        assert self.isCompleted
        if self._allSuperClasses is None:
            parent = self.superClass
            self._allSuperClasses = () if parent is None else parent.allSuperClasses + (parent, )
        return self._allSuperClasses

    @property
    def allSubClasses(self):
        """Tuple of the direct sub-classes followed by the descendants of each of them."""
        assert self.isCompleted
        if self._allSubClasses is None:
            collected = list(self.subClasses)
            for child in self.subClasses:
                collected.extend(child.allSubClasses)
            self._allSubClasses = tuple(collected)
        return self._allSubClasses

    @property
    def typeId(self):
        assert self.isCompleted
        assert self.isLeaf
        return self._typeId

    @property
    def typeIdMin(self):
        """Lowest type identifier among this class (if a leaf) or its descendants."""
        assert self.isCompleted
        if self.isLeaf:
            return self.typeId
        if self._typeIdMin is None:
            self._typeIdMin = min(sub.typeIdMin for sub in self.allSubClasses)
        return self._typeIdMin

    @property
    def typeIdMax(self):
        """Highest type identifier among this class (if a leaf) or its descendants."""
        assert self.isCompleted
        if self.isLeaf:
            return self.typeId
        if self._typeIdMax is None:
            self._typeIdMax = max(sub.typeIdMax for sub in self.allSubClasses)
        return self._typeIdMax

    @property
    def ordIdx(self):
        assert self.isCompleted
        return self._ordIdx

    @property
    def arity(self):
        assert self.isCompleted
        return self._arity

    def isSubClassOf(self, other):
        """Return True if this class is 'other' itself or one of its descendants."""
        assert self.isCompleted
        return self is other or self in other.allSubClasses


# Ast node classes, looked up by class name without the 'Ast' prefix
# (see the AstNodes[typen] lookups in Cpt.tree_line)
AstNodes = {}
# Iterated as the sequence of all Ast Node objects by Cpt.tree_base and
# write_report; still None here, the assignment is not in this section
AstNodeList = None

# NOTE(review): presumably the DfgVertex counterparts of the two registries
# above; they are not referenced in this section -- confirm against the callers
DfgVertices = {}
DfgVertexList = None

# 'Ast*' class name -> {'newed': {...}, 'used': {...}}, where each inner dict is
# keyed by the base name of a source file that mentions it (filled by read_refs)
ClassRefs = {}
# Stage source file name ('<Name>.cpp') -> ordering number (filled by read_stages)
Stages = {}


class Cpt:
    """Expands TREEOP / TREE_SKIP_VISIT directives found in a C++ source file.

    'process' copies the input file to the output file. Each directive line is
    commented out and parsed by 'tree_line'; the generated match functions and
    visitors are emitted at the position of the first TREEOP directive.
    """

    def __init__(self):
        self.did_out_tree = False  # Generation of match/visit functions was scheduled
        self.in_filename = ""
        self.in_linenum = 1  # Input line of the directive being parsed (for errors)
        self.out_filename = ""
        self.out_linenum = 1  # Advanced by one for every newline written to the output
        self.out_lines = []  # Output items: strings, or callables that print more output
        self.tree_skip_visit = {}  # Type names listed in TREE_SKIP_VISIT directives
        self.treeop = {}  # Type name -> list of parsed TREEOP descriptions (dicts)
        self._exec_nsyms = 0  # Number of '$' symbols numbered so far
        self._exec_syms = {}  # '$' symbol -> name of the local variable that holds it

    def error(self, txt):
        """Exit the program with 'txt', prefixed by the current input file and line."""
        sys.exit("%%Error: %s:%d: %s" % (self.in_filename, self.in_linenum, txt))

    def print(self, txt):
        """Queue the string 'txt' for output."""
        self.out_lines.append(txt)

    def output_func(self, func):
        """Queue a callable; it is invoked with this object when the output is written."""
        self.out_lines.append(func)

    def _output_line(self):
        # Emit a '#line' directive referring back to the generated output file
        self.print("#line " + str(self.out_linenum + 2) + " \"" + self.out_filename + "\"\n")

    def process(self, in_filename, out_filename):
        """Read 'in_filename', expand its directives and write 'out_filename'.

        Exits via 'error' on a malformed directive.
        """
        self.in_filename = in_filename
        self.out_filename = out_filename
        ln = 0
        didln = False

        # Read the file and parse into list of functions that generate output
        with open(self.in_filename, "r", encoding="utf8") as fhi:
            for line in fhi:
                ln += 1
                if not didln:
                    # (Re-)synchronize line numbers with the input file
                    self.print("#line " + str(ln) + " \"" + self.in_filename + "\"\n")
                    didln = True
                match = re.match(r'^\s+(TREE.*)$', line)
                if match:
                    func = match.group(1)
                    self.in_linenum = ln
                    self.print("//" + line)
                    self.output_func(lambda self: self._output_line())
                    self.tree_line(func)
                    didln = False
                elif not re.match(r'^\s*(#define|/[/\*])\s*TREE', line) and re.search(
                        r'\s+TREE', line):
                    self.error("Unknown astgen line: " + line)
                else:
                    self.print(line)

        # Put out the resultant file, if the list has a reference to a
        # function, then call that func to generate output
        with open_file(self.out_filename) as fho:
            togen = self.out_lines
            for line in togen:
                if isinstance(line, str):
                    self.out_lines = [line]
                else:
                    # The callable prints into the (temporarily emptied) out_lines
                    self.out_lines = []
                    line(self)  # lambda call
                for out in self.out_lines:
                    self.out_linenum += out.count("\n")
                    fho.write(out)

    def tree_line(self, func):
        """Parse one directive, given as the text starting at 'TREE'.

        A TREEOP variant is recorded in 'self.treeop', a TREE_SKIP_VISIT in
        'self.tree_skip_visit'. Anything else exits via 'error'.
        """
        func = re.sub(r'\s*//.*$', '', func)
        func = re.sub(r'\s*;\s*$', '', func)

        # doflag "S" indicates an op specifying short-circuiting for a type.
        match = re.search(
            #       1   2                 3                  4
            r'TREEOP(1?)([ACSV]?)\s*\(\s*\"([^\"]*)\"\s*,\s*\"([^\"]*)\"\s*\)',
            func)
        match_skip = re.search(r'TREE_SKIP_VISIT\s*\(\s*\"([^\"]*)\"\s*\)', func)

        if match:
            order = match.group(1)
            doflag = match.group(2)
            fromn = match.group(3)
            to = match.group(4)
            # self.print("// $fromn $to\n")
            if not self.did_out_tree:
                # The generated code goes where the first TREEOP was seen
                self.did_out_tree = True
                self.output_func(lambda self: self.tree_match_base())
            match = re.search(r'Ast([a-zA-Z0-9]+)\s*\{(.*)\}\s*$', fromn)
            if not match:
                self.error("Can't parse from function: " + func)
            typen = match.group(1)
            subnodes = match.group(2)
            # Check membership first, so an unknown type is reported as an
            # error rather than surfacing as a KeyError
            if typen not in AstNodes or AstNodes[typen].isRoot:
                self.error("Unknown AstNode typen: " + typen + ": in " + func)

            mif = ""
            if doflag == '':
                mif = "m_doNConst"
            elif doflag == 'A':
                mif = ""
            elif doflag == 'C':
                mif = "m_doCpp"
            elif doflag == 'S':
                mif = "m_doNConst"  # Not just for m_doGenerate
            elif doflag == 'V':
                mif = "m_doV"
            else:
                self.error("Unknown flag: " + doflag)

            # ',,' is an escaped comma that does not separate sub-nodes
            subnodes = re.sub(r',,', '__ESCAPEDCOMMA__', subnodes)
            for subnode in re.split(r'\s*,\s*', subnodes):
                subnode = re.sub(r'__ESCAPEDCOMMA__', ',', subnode)
                if re.match(r'^\$([a-zA-Z0-9]+)$', subnode):
                    continue  # "$lhs" is just a comment that this op has a lhs
                if re.search(r'DISABLE_BASE', subnode):
                    AstNodes[typen].disableBaseSet()
                    return
                subnodeif = subnode
                subnodeif = re.sub(r'\$([a-zA-Z0-9]+)\.cast([A-Z][A-Za-z0-9]+)$',
                                   r'VN_IS(nodep->\1(),\2)', subnodeif)
                subnodeif = re.sub(r'\$([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)$', r'nodep->\1()->\2()',
                                   subnodeif)
                subnodeif = self.add_nodep(subnodeif)
                if mif != "" and subnodeif != "":
                    mif += " && "
                mif += subnodeif

            exec_func = self.treeop_exec_func(to)
            exec_func = re.sub(r'([-()a-zA-Z0-9_>]+)->cast([A-Z][A-Za-z0-9]+)\(\)',
                               r'VN_CAST(\1,\2)', exec_func)

            if typen not in self.treeop:
                self.treeop[typen] = []
            n = len(self.treeop[typen])
            typefunc = {
                'order': order,
                'comment': func,
                'match_func': "match_" + typen + "_" + str(n),
                'match_if': mif,
                'exec_func': exec_func,
                'uinfo': re.sub(r'[ \t\"\{\}]+', ' ', func),
                'uinfo_level': (0 if re.match(r'^!', to) else 7),
                'short_circuit': (doflag == 'S'),
            }
            self.treeop[typen].append(typefunc)

        elif match_skip:
            typen = match_skip.group(1)
            self.tree_skip_visit[typen] = 1
            if typen not in AstNodes:
                self.error("Unknown node type: " + typen)

        else:
            self.error("Unknown astgen op: " + func)

    @staticmethod
    def add_nodep(strg):
        """Replace every '$name' in 'strg' with the accessor call 'nodep->name()'."""
        strg = re.sub(r'\$([a-zA-Z0-9]+)', r'nodep->\1()', strg)
        return strg

    def _exec_syms_recurse(self, aref):
        # Number each distinct '$' symbol of the nested list, in order of first use
        for sym in aref:
            if isinstance(sym, list):
                self._exec_syms_recurse(sym)
            elif re.search(r'^\$.*', sym):
                if sym not in self._exec_syms:
                    self._exec_nsyms += 1
                    self._exec_syms[sym] = "arg" + str(self._exec_nsyms) + "p"

    def _exec_new_recurse(self, aref):
        # aref is [type name, argument...]; arguments are '$' symbols, literal
        # text, or nested lists of the same shape
        out = "new " + aref[0] + "(nodep->fileline()"
        first = True
        for sym in aref:
            if first:
                first = False
                continue
            out += ", "
            if isinstance(sym, list):
                out += self._exec_new_recurse(sym)
            elif re.match(r'^\$.*', sym):
                out += self._exec_syms[sym]
            else:
                out += sym
        return out + ")"

    def treeop_exec_func(self, func):
        """Return the C++ statements implementing the 'to' part of a TREEOP.

        'func' is a function call, an 'AstType{...}' replacement tree, 'NEVER'
        or 'DONE'; a leading '!' is ignored here. Other formats exit via 'error'.
        """
        out = ""
        func = re.sub(r'^!', '', func)

        if re.match(r'^\s*[a-zA-Z0-9]+\s*\(', func):  # Function call
            outl = re.sub(r'\$([a-zA-Z0-9]+)', r'nodep->\1()', func)
            out += outl + ";"
        elif re.match(r'^\s*Ast([a-zA-Z0-9]+)\s*\{\s*(.*)\s*\}$', func):
            aref = None
            # Recursive array with structure to form
            astack = []
            forming = ""
            argtext = func + "\000"  # EOF character
            for tok in argtext:
                if tok == "\000":
                    pass
                elif re.match(r'\s+', tok):
                    pass
                elif tok == "{":
                    # Open a nested list headed by the type name collected so far
                    newref = [forming]
                    if not aref:
                        aref = []
                    aref.append(newref)
                    astack.append(aref)
                    aref = newref
                    forming = ""
                elif tok == "}":
                    if forming:
                        aref.append(forming)
                    if len(astack) == 0:
                        self.error("Too many } in execution function: " + func)
                    aref = astack.pop()
                    forming = ""
                elif tok == ",":
                    if forming:
                        aref.append(forming)
                    forming = ""
                else:
                    forming += tok
            if not (aref and len(aref) == 1):
                self.error("Badly formed execution function: " + func)
            aref = aref[0]

            # Assign numbers to each $ symbol
            self._exec_syms = {}
            self._exec_nsyms = 0
            self._exec_syms_recurse(aref)

            for sym in sorted(self._exec_syms.keys(), key=lambda val: self._exec_syms[val]):
                argnp = self._exec_syms[sym]
                arg = self.add_nodep(sym)
                out += "AstNodeExpr* " + argnp + " = " + arg + "->unlinkFrBack();\n"

            out += "AstNodeExpr* newp = " + self._exec_new_recurse(aref) + ";\n"
            out += "nodep->replaceWith(newp);"
            out += "VL_DO_DANGLING(nodep->deleteTree(), nodep);"
        elif func == "NEVER":
            out += "nodep->v3fatalSrc(\"Executing transform that was NEVERed\");"
        elif func == "DONE":
            pass
        else:
            self.error("Unknown execution function format: " + func + "\n")
        return out

    def tree_match_base(self):
        """Emit all match functions followed by all visitors."""
        self.tree_match()
        self.tree_base()

    def tree_match(self):
        """Emit one C++ 'match_<Type>_<n>' function per recorded TREEOP."""
        self.print("    // TREEOP functions, each return true if they matched & transformed\n")
        for base in sorted(self.treeop.keys()):
            for typefunc in self.treeop[base]:
                self.print("    // Generated by astgen\n")
                self.print("    bool " + typefunc['match_func'] + "(Ast" + base + "* nodep) {\n")
                self.print("\t// " + typefunc['comment'] + "\n")
                self.print("\tif (" + typefunc['match_if'] + ") {\n")
                self.print("\t    UINFO(" + str(typefunc['uinfo_level']) + ", cvtToHex(nodep)" +
                           " << \" " + typefunc['uinfo'] + "\");\n")
                self.print("\t    " + typefunc['exec_func'] + "\n")
                self.print("\t    return true;\n")
                self.print("\t}\n")
                self.print("\treturn false;\n")
                self.print("    }\n")

    def tree_base(self):
        """Emit a C++ visitor for every node type that has applicable match functions."""
        self.print("    // TREEOP visitors, call each base type's match\n")
        self.print("    // Bottom class up, as more simple transforms are generally better\n")
        for node in AstNodeList:
            out_for_type_sc = []
            out_for_type = []
            if node.disableBase:
                classes = []
                out_for_type.append("        // DISABLE_BASE\n")
            else:
                classes = list(node.allSuperClasses)

            classes.append(node)
            for base in classes:
                base = base.name
                if base not in self.treeop:
                    continue
                for typefunc in self.treeop[base]:
                    lines = ["        if (" + typefunc['match_func'] + "(nodep)) return;\n"]
                    if typefunc['short_circuit']:  # short-circuit match fn
                        out_for_type_sc.extend(lines)
                    else:  # Standard match fn
                        if typefunc['order']:  # TREEOP1's go in front of others
                            out_for_type = lines + out_for_type
                        else:
                            out_for_type.extend(lines)

            # We need to deal with two cases. For short circuited functions we
            # evaluate the LHS, then apply the short-circuit matches, then
            # evaluate the RHS and possibly THS (ternary operators may
            # short-circuit) and apply all the other matches.

            # For types without short-circuits, we just use iterateChildren, which
            # saves one comparison.
            if out_for_type_sc:  # Short-circuited types
                self.print("    // Generated by astgen with short-circuiting\n" +
                           "    void visit(Ast" + node.name + "* nodep) override {\n" +
                           "      iterateAndNextNull(nodep->{op1}());\n".format(
                               op1=node.getOp(1)[0]) + "".join(out_for_type_sc))
                # Always finish the function: the remaining operands still need
                # visiting and the body needs its closing brace even when there
                # are no further (non short-circuit) matches for this type
                self.print(
                    "      iterateAndNextNull(nodep->{op2}());\n".format(op2=node.getOp(2)[0]))
                if node.isSubClassOf(AstNodes["NodeTriop"]):
                    self.print("      iterateAndNextNull(nodep->{op3}());\n".format(
                        op3=node.getOp(3)[0]))
                self.print("".join(out_for_type) + "    }\n")
            elif out_for_type:  # Other types with something to print
                skip = node.name in self.tree_skip_visit
                gen = "Gen" if skip else ""
                virtual = "virtual " if skip else ""
                override = "" if skip else " override"
                self.print("    // Generated by astgen\n" + "    " + virtual + "void visit" + gen +
                           "(Ast" + node.name + "* nodep)" + override + " {\n" +
                           ("" if skip else "        iterateChildren(nodep);\n") +
                           ''.join(out_for_type) + "    }\n")


######################################################################
######################################################################


def partitionAndStrip(string, separator):
    """Split 'string' at the first 'separator' and strip whitespace from each part.

    Returns an iterator over (before, separator, after); the last two parts
    are empty strings if the separator does not occur.
    """
    parts = string.partition(separator)
    return map(str.strip, parts)


def parseOpType(string):
    """Parse an operand type description.

    Accepted forms are 'AstKind', 'AstKind<legal kinds>', and either of those
    wrapped once in 'Optional[...]' or 'List[...]'.

    Returns a tuple (monad, kind, legals), where monad is '', 'Optional' or
    'List' and the 'Ast' prefixes are removed from kind and legals.
    Returns None if 'string' has none of the accepted forms.
    """
    # Wrapped form: Optional[...] or List[...]
    wrapped = re.match(r'^(\w+)\[(.*?)\]$', string)
    if wrapped:
        monad = wrapped.group(1)
        if monad not in ("Optional", "List"):
            return None
        inner = parseOpType(wrapped.group(2))
        # The wrapped type must parse, and must not itself be wrapped
        if not inner or inner[0]:
            return None
        return monad, inner[1], inner[2]

    # Kind with an explicit list of legal kinds
    restricted = re.match(r'^Ast(\w+)<(.*?)>$', string)
    if restricted:
        return "", restricted.group(1), re.sub(r'\bAst', '', restricted.group(2))

    # Plain kind, which is its own legal kind
    plain = re.match(r'^Ast(\w+)$', string)
    if plain:
        return "", plain.group(1), plain.group(1)

    return None


def read_types(filename, Nodes, prefix):
    """Read class definitions and '@astgen' directives from the file 'filename'.

    Every class/struct whose public super-class name contains 'prefix' becomes
    a Node, registered in 'Nodes' under its name with the leading 'prefix'
    removed, and added as a sub-class of its super-class (which must already
    be present in 'Nodes'). '@astgen' directives are only handled when
    'prefix' is "Ast".

    Problems are reported on stderr with file and line; if any were reported,
    the program exits once the whole file has been read.
    """
    hasErrors = False

    def error(lineno, message):
        nonlocal hasErrors
        print(filename + ":" + str(lineno) + ": %Error: " + message, file=sys.stderr)
        hasErrors = True

    node = None  # Node whose definition is currently being read
    hasAstgenMembers = False  # ASTGEN_MEMBERS_* was seen for the current node

    def checkFinishedNode(node):
        # Called when a definition ends: it must have contained ASTGEN_MEMBERS_*
        nonlocal hasAstgenMembers
        if not node:
            return
        if not hasAstgenMembers:
            error(
                node.lineno,
                "'{p}{n}' does not contain 'ASTGEN_MEMBERS_{p}{n};'".format(p=prefix, n=node.name))
        hasAstgenMembers = False

    with open(filename, "r", encoding="utf8") as fh:
        for (lineno, line) in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue

            match = re.search(r'^\s*(class|struct)\s*(\S+)', line)
            if match:
                classn = match.group(2)
                match = re.search(r':\s*public\s+(\S+)', line)
                supern = match.group(1) if match else ""
                if re.search(prefix, supern):
                    classn = re.sub(r'^' + prefix, '', classn)
                    supern = re.sub(r'^' + prefix, '', supern)
                    if not supern:
                        sys.exit("%Error: '{p}{c}' has no super-class".format(p=prefix, c=classn))
                    checkFinishedNode(node)
                    superClass = Nodes[supern]
                    node = Node(classn, superClass, filename, lineno)
                    superClass.addSubClass(node)
                    Nodes[classn] = node
            if not node:
                continue

            if re.match(r'^\s*ASTGEN_MEMBERS_' + prefix + node.name + ';', line):
                hasAstgenMembers = True

            if prefix != "Ast":
                continue

            match = re.match(r'^\s*//\s*@astgen\s+(.*)$', line)
            if match:
                decl = re.sub(r'//.*$', '', match.group(1))
                what, sep, rest = partitionAndStrip(decl, ":=")
                what = re.sub(r'\s+', ' ', what)
                if not sep:
                    error(
                        lineno,
                        "Malformed '@astgen' directive (expecting '<keywords> := <description>'): "
                        + decl)
                elif what in ("op1", "op2", "op3", "op4"):
                    n = int(what[-1])
                    ident, sep, kind = partitionAndStrip(rest, ":")
                    if not sep or not re.match(r'^\w+$', ident):
                        error(
                            lineno, "Malformed '@astgen " + what + "' directive (expecting '" +
                            what + " := <identifier> : <type>': " + decl)
                    else:
                        kind = parseOpType(kind)
                        if not kind:
                            error(
                                lineno, "Bad type for '@astgen " + what +
                                "' (expecting Ast*, Optional[Ast*], or List[Ast*]):" + decl)
                        elif node.getOp(n) is not None:
                            error(lineno, "Already defined " + what + " for " + node.name)
                        else:
                            node.addOp(n, ident, *kind)
                elif what in ("alias op1", "alias op2", "alias op3", "alias op4"):
                    n = int(what[-1])
                    ident = rest.strip()
                    if not re.match(r'^\w+$', ident):
                        error(
                            lineno, "Malformed '@astgen " + what + "' directive (expecting '" +
                            what + " := <identifier>': " + decl)
                    else:
                        op = node.getOp(n)
                        if op is None:
                            error(lineno, "Aliased op" + str(n) + " is not defined")
                        else:
                            # Same operand under a new name; keep monad/kind/legals
                            node.addOp(n, ident, *op[1:])
                elif what == "ptr":
                    ident, sep, kind = partitionAndStrip(rest, ":")
                    kind = parseOpType(kind)
                    if not kind:
                        error(
                            lineno, "Bad type for '@astgen " + what +
                            "' (expecting Ast*, Optional[Ast*], or List[Ast*]):" + decl)
                    if not re.match(r'^m_(\w+)$', ident):
                        error(
                            lineno, "Malformed '@astgen ptr'"
                            " identifier (expecting m_ in '" + ident + "')")
                    elif kind:
                        # Only register the pointer when the type parsed; a bad
                        # type was already reported above and must not be unpacked
                        node.addPtr(ident, *kind)
                else:
                    error(
                        lineno, "Malformed @astgen what (expecting 'op1'..'op4'," +
                        " 'alias op1'.., 'ptr'): " + what)
            else:
                line = re.sub(r'//.*$', '', line)
                if re.match(r'.*[Oo]p[1-9].*', line):
                    error(lineno, "Use generated accessors to access op<N> operands")

            if re.match(r'^\s*Ast[A-Z][A-Za-z0-9_]+\s*\*(\s*const)?\s+m_[A-Za-z0-9_]+\s*;', line):
                error(lineno, "Use '@astgen ptr' for Ast pointer members: " + line)

        checkFinishedNode(node)
    if hasErrors:
        sys.exit("%Error: Stopping due to errors reported above")


def check_types(sortedTypes, prefix, abstractPrefix):
    """Check the naming convention and the definition order of the given classes.

    A class must be a leaf exactly when its name does not start with
    'abstractPrefix'; the first violation exits the program. Unless 'prefix'
    is "Dfg", the definitions within each file must also appear in the
    expected order (non-leaves before leaves, each group by ordering index);
    all violations are printed to stderr before the program exits.
    """
    baseClass = prefix + abstractPrefix
    abstractNamePattern = r'^' + abstractPrefix

    # Check all leaf types are not AstNode* and non-leaves are AstNode*
    for node in sortedTypes:
        namedAbstract = re.match(abstractNamePattern, node.name) is not None
        isLeaf = node.isLeaf
        if namedAbstract and isLeaf:
            sys.exit("%Error: Final {b} subclasses must not be named {b}*: {p}{n}".format(
                b=baseClass, p=prefix, n=node.name))
        if not namedAbstract and not isLeaf:
            sys.exit("%Error: Non-final {b} subclasses must be named {b}*: {p}{n}".format(
                b=baseClass, p=prefix, n=node.name))

    # Skip ordering check for Dfg
    if prefix == "Dfg":
        return

    def predecessors(ordered):
        # Map each element to the one right before it (None for the first)
        return dict(zip(ordered, (None, ) + ordered[:-1]))

    # Check ordering of node definitions, one file at a time
    foundOutOfOrder = False
    fileNames = sorted({node.file for node in sortedTypes if node.file is not None})

    for fileName in fileNames:
        nodes = tuple(node for node in sortedTypes if node.file == fileName)
        expect = predecessors(tuple(sorted(nodes, key=lambda _: (_.isLeaf, _.ordIdx))))
        actual = predecessors(tuple(sorted(nodes, key=lambda _: _.lineno)))
        for node in nodes:
            pred = expect[node]
            if pred != actual[node]:
                foundOutOfOrder = True
                if pred:
                    where = "right after '" + prefix + pred.name + "'"
                else:
                    where = "first in file"
                message = "{file}:{lineno}: %Error: Definition of '{p}{n}' is out of order. Should be {where}.".format(
                    file=fileName, lineno=node.lineno, p=prefix, n=node.name, where=where)
                print(message, file=sys.stderr)

    if foundOutOfOrder:
        sys.exit("%Error: Stopping due to out of order definitions listed above")


def read_stages(filename):
    """Record in 'Stages' the stages referenced by the file 'filename'.

    Every '<Name>::' reference preceded by whitespace (comments excluded, at
    most one per line) names a stage '<Name>.cpp'. Stages not yet in 'Stages'
    are numbered in order of first appearance, starting at 100.
    """
    nextOrder = 100
    with open(filename, "r", encoding="utf8") as fh:
        for rawLine in fh:
            code = re.sub(r'//.*$', '', rawLine)
            # Nothing to find on lines that are blank once comments are removed
            if not code.strip():
                continue
            found = re.search(r'\s([A-Za-z0-9]+)::', code)
            if not found:
                continue
            stage = found.group(1) + ".cpp"
            if stage in Stages:
                continue
            Stages[stage] = nextOrder
            nextOrder += 1


def read_refs(filename):
    """Record in 'ClassRefs' which 'Ast*' classes the file 'filename' mentions.

    Comments are ignored. A class following 'new' is marked as 'newed'; any
    'Ast*' identifier, as well as the type named in a VN_IS/VN_AS/VN_CAST
    call, is marked as 'used'. Marks are stored per file base name.
    """
    basename = re.sub(r'.*/', '', filename)

    def mark(className, how):
        # Create the entry on first sight, then flag this file in it
        entry = ClassRefs.setdefault(className, {'newed': {}, 'used': {}})
        entry[how][basename] = 1

    with open(filename, "r", encoding="utf8") as fh:
        for rawLine in fh:
            code = re.sub(r'//.*$', '', rawLine)
            for className in re.findall(r'\bnew\s*(Ast[A-Za-z0-9_]+)', code):
                mark(className, 'newed')
            for className in re.findall(r'\b(Ast[A-Za-z0-9_]+)', code):
                mark(className, 'used')
            for _, typeName in re.findall(r'(VN_IS|VN_AS|VN_CAST)\([^.]+, ([A-Za-z0-9_]+)',
                                          code):
                mark("Ast" + typeName, 'used')


def open_file(filename):
    """Open 'filename' for writing and write the 'Generated by astgen' header line.

    Files named '*.txt' get a plain header; all others get one that also
    carries an editor mode line for C++. Returns the open file object, which
    the caller is responsible for closing.
    """
    fh = open(filename, "w", encoding="utf8")  # pylint: disable=consider-using-with
    isText = re.search(r'\.txt$', filename)
    header = ("// Generated by astgen\n" if isText else
              '// Generated by astgen // -*- mode: C++; c-file-style: "cc-mode" -*-' + "\n")
    fh.write(header)
    return fh


# ---------------------------------------------------------------------


def write_report(filename):
    """Write a report of the stages and of every Ast class to 'filename'.

    For each class in 'AstNodeList' the report lists its arity, its non-root
    super-classes, all of its sub-classes and, if the class is in 'ClassRefs',
    the files that new or use it, ordered by stage number (files that are not
    a known stage come first).
    """

    def stageOrder(stage):
        return Stages.get(stage, -1)

    with open_file(filename) as fh:
        emit = fh.write

        emit("Processing stages (approximate, based on order in Verilator.cpp):\n")
        for classn in sorted(Stages, key=Stages.get):
            emit("  " + classn + "\n")

        emit("\nClasses:\n")
        for node in AstNodeList:
            emit("  class Ast%-17s\n" % node.name)
            emit("    arity:  {}\n".format(node.arity))

            emit("    parent: ")
            for superClass in node.allSuperClasses:
                if superClass.isRoot:
                    continue
                emit("Ast%-12s " % superClass.name)
            emit("\n")

            emit("    childs:  ")
            for subClass in node.allSubClasses:
                emit("Ast%-12s " % subClass.name)
            emit("\n")

            refs = ClassRefs.get("Ast" + node.name)
            if refs is not None:
                for label, how in (("    newed:  ", 'newed'), ("    used:   ", 'used')):
                    emit(label)
                    for stage in sorted(refs[how], key=stageOrder):
                        emit(stage + "  ")
                    emit("\n")
            emit("\n")


################################################################################
# Common code generation
################################################################################


def write_forward_class_decls(prefix, nodeList):
    """Write 'V3<prefix>__gen_forward_class_decls.h'.

    The file holds one forward class declaration per node, followed by a
    comment listing all of its super-classes.
    """
    with open_file("V3" + prefix + "__gen_forward_class_decls.h") as fh:
        for node in nodeList:
            supers = "".join(prefix + superClass.name.ljust(12) + " "
                             for superClass in node.allSuperClasses)
            # The ';' is padded together with the name to keep the comments aligned
            fh.write("class " + prefix + (node.name + ";").ljust(17) + " // " + supers + "\n")


def write_visitor_decls(prefix, nodeList):
    """Write V3<prefix>__gen_visitor_decls.h.

    Emits a 'virtual void visit(<prefix><name>*);' declaration for every
    node in nodeList except the root.
    """
    with open_file("V3{}__gen_visitor_decls.h".format(prefix)) as fh:
        for node in nodeList:
            if node.isRoot:
                continue
            fh.write("virtual void visit(" + prefix + node.name + "*);\n")


def write_visitor_defns(prefix, nodeList, visitor):
    """Write V3<prefix>__gen_visitor_defns.h.

    For every node in nodeList that has a super class, emits a definition of
    '<visitor>::visit' that forwards to the 'visit' overload of the super class.
    """
    template = "void {c}::visit({p}{n}* {v}) {{ visit(static_cast<{p}{b}*>({v})); }}\n"
    with open_file("V3{}__gen_visitor_defns.h".format(prefix)) as fh:
        # Name of the visit() parameter in the generated code
        if prefix == "Ast":
            argName = "nodep"
        else:
            argName = "vtxp"
        for node in nodeList:
            parent = node.superClass
            if parent is None:
                continue
            fh.write(template.format(c=visitor, p=prefix, n=node.name, b=parent.name, v=argName))


def write_type_enum(prefix, nodeList):
    """Write V3<prefix>__gen_type_enum.h.

    Emits the 'en' enum holding the typeId of every leaf node, the NUM_TYPES()
    function (typeIdMax of the root, plus one), the 'bounds' enum holding the
    first/last type id of every non-leaf node, and the ascii() function mapping
    an enum value to the upper-cased leaf node name.
    """
    root = next(node for node in nodeList if node.isRoot)
    with open_file("V3{}__gen_type_enum.h".format(prefix)) as fh:
        # Leaf nodes in type id order, shared by the enum and the name table
        leaves = sorted((node for node in nodeList if node.isLeaf),
                        key=lambda node: node.typeId)

        fh.write("    enum en : uint16_t {\n")
        for leaf in leaves:
            fh.write("        {t} = {n},\n".format(t=leaf.name, n=leaf.typeId))
        fh.write("    };\n")
        numTypes = root.typeIdMax + 1
        fh.write("    static constexpr size_t NUM_TYPES() {{ return {n}; }}\n".format(n=numTypes))

        fh.write("    enum bounds : uint16_t {\n")
        inner = sorted((node for node in nodeList if not node.isLeaf),
                       key=lambda node: node.typeIdMin)
        for group in inner:
            fh.write("        first{t} = {n},\n".format(t=group.name, n=group.typeIdMin))
            fh.write("        last{t}  = {n},\n".format(t=group.name, n=group.typeIdMax))
        fh.write("    };\n")

        fh.write("    const char* ascii() const VL_MT_SAFE {\n")
        fh.write("        static const char* const names[] = {\n")
        for leaf in leaves:
            fh.write('            "{T}",\n'.format(T=leaf.name.upper()))
        fh.write("        };\n")
        fh.write("        return names[m_e];\n")
        fh.write("    }\n")


def write_type_tests(prefix, nodeList):
    """Write V3<prefix>__gen_type_tests.h.

    Emits one 'privateTypeTest' specialization per node in nodeList: always
    true for the root, a first/last range check for non-leaf nodes, and an
    exact type comparison for leaf nodes.
    """
    # prefix -> (base class, parameter name, type enum) used in the generated code
    naming = {
        "Ast": ("AstNode", "nodep", "VNType"),
        "Dfg": ("DfgVertex", "vtxp", "VDfgType"),
    }
    header = "template<> inline bool {b}::privateTypeTest<{p}{n}>(const {b}* {v}) {{ "
    rangeTest = ("return static_cast<int>({v}->type()) >= static_cast<int>({e}::first{t})"
                 " && static_cast<int>({v}->type()) <= static_cast<int>({e}::last{t});")
    exactTest = "return {v}->type() == {e}::{t};"

    with open_file("V3{}__gen_type_tests.h".format(prefix)) as fh:
        fh.write("// For internal use. They assume argument is not nullptr.\n")
        base, variable, enum = naming.get(prefix, (None, None, None))
        for node in nodeList:
            if node.isRoot:
                body = "return true;"
            elif node.isLeaf:
                body = exactTest.format(v=variable, e=enum, t=node.name)
            else:
                body = rangeTest.format(v=variable, e=enum, t=node.name)
            fh.write(header.format(b=base, p=prefix, n=node.name, v=variable))
            fh.write(body)
            fh.write(" }\n")


################################################################################
# Ast code generation
################################################################################


def write_ast_type_info(filename):
    """Write one type info initializer per leaf AstNode type, in typeId order.

    Each entry holds the class name, the usage kind of op1..op4 (as VNTypeInfo
    constants), the operand names ('op<N>p' for unused operands), and the
    sizeof of the class.
    """
    # Usage kind for operands that have a monad
    monadKinds = {"Optional": "OP_OPTIONAL", "List": "OP_LIST"}
    entry = '    {{ "Ast{name}", {{{opTypeStr}}}, {{{opNameStr}}}, sizeof(Ast{name}) }},\n'

    with open_file(filename) as fh:
        leaves = sorted((node for node in AstNodeList if node.isLeaf),
                        key=lambda node: node.typeId)
        for node in leaves:
            opTypeList = []
            opNameList = []
            for n in range(1, 5):
                op = node.getOp(n)
                if not op:
                    opTypeList.append('OP_UNUSED')
                    opNameList.append('op{0}p'.format(n))
                    continue
                name, monad, _, _ = op
                # An unrecognized monad contributes a name but no usage kind
                opType = monadKinds.get(monad) if monad else 'OP_USED'
                if opType is not None:
                    opTypeList.append(opType)
                opNameList.append(name)
            fh.write(
                entry.format(name=node.name,
                             opTypeStr=', '.join('VNTypeInfo::' + s for s in opTypeList),
                             opNameStr=', '.join('"' + s + '"' for s in opNameList)))


def write_ast_impl(filename):
    """Write the definitions of the generated AstNode member functions.

    For every node in AstNodeList except the root "Node", emits the C++
    definitions of brokenGen, wouldBreakGen, cloneRelinkGen (only when the
    node has pointer members), foreachLink, dumpJsonGen and dumpTreeJsonOpGen.
    """
    with open_file(filename) as fh:

        def emitBlock(pattern, **fmt):
            fh.write(pattern.format(**fmt))

        def restrictedTypes(legals):
            # Split an 'A|B' list of legal types. An empty result means any
            # AstNode is legal, so no type check needs to be generated.
            types = legals.split('|')
            return [] if types == ["Node"] else types

        def typeTests(types, expr):
            # We use privateTypeTest, as VN_IS would assert that we know
            # the type is correct, but we want to check regardless,
            # to find errors after raw node edits/replacements
            return " || ".join(
                "privateTypeTest<Ast{}>({})".format(legal, expr) for legal in types)

        for node in AstNodeList:
            if node.name == "Node":
                continue

            ptrs = node.ptrs
            base = node.superClass.name
            # Generated methods chain to the super class, unless that is the root
            hasAstBase = base != 'Node'

            # Operands in use, as (index, name, restricted legal types)
            ops = []
            for i in range(1, 5):
                op = node.getOp(i)
                if op is None:
                    continue
                name, _, _, legals = op
                ops.append((i, name, restrictedTypes(legals)))

            emitBlock("const char* Ast{t}::brokenGen() const {{\n", t=node.name)
            if hasAstBase:
                emitBlock("    BROKEN_BASE_RTN(Ast{base}::brokenGen());\n", base=base)
            for ptr in ptrs:
                if ptr['monad'] == 'Optional':
                    emitBlock("    BROKEN_RTN(m_{name} && !m_{name}->brokeExists());\n",
                              name=ptr['name'])
                else:
                    emitBlock("    BROKEN_RTN(!m_{name});\n" +
                              "    BROKEN_RTN(!m_{name}->brokeExists());\n",
                              name=ptr['name'])
                legals = restrictedTypes(ptr['legals'])
                if legals:
                    emitBlock("    BROKEN_RTN(m_{name} && !({tests}));\n",
                              name=ptr['name'],
                              tests=typeTests(legals, "m_" + ptr['name']))
            for _, name, legals in ops:
                if legals:
                    emitBlock("    BROKEN_RTN({name}() && !({tests}));\n",
                              name=name,
                              tests=typeTests(legals, name + "()"))
            # Node's broken rules can be specialized by declaring broken()
            emitBlock("    return Ast{t}::broken();\n", t=node.name)
            emitBlock("}}\n")

            emitBlock(
                "bool Ast{t}::wouldBreakGen(const AstNode* const oldp, const AstNode* const newp) const {{\n",
                t=node.name)
            for i, _, legals in ops:
                if legals:
                    # 'this' is a parent, where newp replacing oldp as op<i>p, must follow op<i>p's rules
                    # Could also be on a list, we don't check for speed reasons and as V3Broken doesn't
                    emitBlock("    if (oldp == op{i}p() && !({tests})) return true;\n",
                              i=i,
                              tests=typeTests(legals, "newp"))
            emitBlock("    return false;\n")
            emitBlock("}}\n")

            if ptrs:
                emitBlock("void Ast{t}::cloneRelinkGen() {{\n", t=node.name)
                if hasAstBase:
                    emitBlock("    Ast{base}::cloneRelinkGen();\n", base=base)
                for ptr in ptrs:
                    emitBlock(
                        "    if (m_{name} && m_{name}->clonep()) m_{name} = m_{name}->clonep();\n",
                        name=ptr['name'])
                emitBlock("}}\n")

            emitBlock("void Ast{t}::foreachLink(std::function<void(" +
                      "AstNode** linkpp, const char* namep)> f) {{\n",
                      t=node.name)
            if hasAstBase:
                emitBlock("    Ast{base}::foreachLink(f);\n", base=base)
            for ptr in ptrs:
                # One call per line, like every other generated statement
                emitBlock("    f(reinterpret_cast<AstNode**>(&m_{name}), \"{name}\");\n",
                          name=ptr['name'])
            emitBlock("}}\n")

            emitBlock("void Ast{t}::dumpJsonGen(std::ostream& str) const {{\n", t=node.name)
            if hasAstBase:
                emitBlock("    Ast{base}::dumpJson(str);\n", base=base)
            for ptr in ptrs:
                emitBlock("    dumpJsonPtr(str, \"{name}\", m_{name});\n", name=ptr['name'])
            emitBlock("}}\n")

            emitBlock(
                "void Ast{t}::dumpTreeJsonOpGen(std::ostream& str, const string& indent) const {{\n",
                t=node.name)
            for _, name, _ in ops:
                emitBlock("    dumpNodeListJson(str, {name}(), \"{name}\", indent);\n", name=name)
            emitBlock("}}\n")


def write_ast_macros(filename):
    """Write the ASTGEN_MEMBERS_Ast<name> macro of every AstNode type.

    Each macro holds: the private pointer members declared by node.ptrs, typed
    wrappers around the AstNode unlink/clone/addNext functions, declarations of
    the generated methods, accept/clone for leaf types, typed accessors for the
    operands in use, and private using-declarations hiding the super class
    accessors of those operands. Leaf types also get an ASTGEN_SUPER_<name>
    macro.
    """
    with open_file(filename) as fh:

        def emitBlock(pattern, **fmt):
            # Dedent the template, indent it as macro body, fill it in, then
            # terminate every line with a line continuation
            fh.write(
                textwrap.indent(textwrap.dedent(pattern),
                                "    ").format(**fmt).replace("\n", " \\\n"))

        # Operand accessor templates. The getter is common, the setter depends
        # on the monad of the operand.
        getter = "Ast{kind}* {name}() const VL_MT_STABLE {{ return {retrieve}; }}\n"
        monadSetters = {
            "List":
            "void add{Name}(Ast{kind}* nodep) {{ addNOp{n}p(reinterpret_cast<AstNode*>(nodep)); }}\n",
            "Optional":
            "void {name}(Ast{kind}* nodep) {{ setNOp{n}p(reinterpret_cast<AstNode*>(nodep)); }}\n",
        }
        plainSetter = "void {name}(Ast{kind}* nodep) {{ setOp{n}p(reinterpret_cast<AstNode*>(nodep)); }}\n"

        for node in AstNodeList:
            fh.write("#define ASTGEN_MEMBERS_Ast{t} \\\n".format(t=node.name))
            ptrs = node.ptrs
            if ptrs:
                fh.write("private: \\\n")
                for ptr in ptrs:
                    # The trailing newline gives each member its own continued line
                    emitBlock("Ast{kind}* m_{name} = nullptr;\n",
                              name=ptr['name'],
                              kind=ptr['kind'])
                fh.write("public: \\\n")
            # TODO pointer accessors
            # for ptr in node.ptrs:
            #     emitBlock(
            #         ("{kind}* {name}() const {{ return m_{name}; }}\n" +
            #          "void {name}({kind}* nodep) {{ m_{name} = nodep; }}"),
            #         name=ptr['name'],
            #         kind=ptr['kind'])

            emitBlock('''\
            Ast{t}* unlinkFrBack(VNRelinker* linkerp = nullptr) {{
                return static_cast<Ast{t}*>(AstNode::unlinkFrBack(linkerp));
            }}
            Ast{t}* unlinkFrBackWithNext(VNRelinker* linkerp = nullptr) {{
                return static_cast<Ast{t}*>(AstNode::unlinkFrBackWithNext(linkerp));
            }}
            Ast{t}* cloneTree(bool cloneNext, bool needPure = false) {{
                return static_cast<Ast{t}*>(AstNode::cloneTree(cloneNext, needPure));
            }}
            Ast{t}* cloneTreePure(bool cloneNext) {{
                return static_cast<Ast{t}*>(AstNode::cloneTreePure(cloneNext));
            }}
            Ast{t}* clonep() const {{ return static_cast<Ast{t}*>(AstNode::clonep()); }}
            Ast{t}* addNext(Ast{t}* nodep) {{ return static_cast<Ast{t}*>(AstNode::addNext(this, nodep)); }}
            const char* brokenGen() const override;
            void foreachLink(std::function<void(AstNode** linkpp, const char* namep)> f) override;
            bool wouldBreakGen(const AstNode* const oldp, const AstNode* const newp) const override;
            void dumpTreeJsonOpGen(std::ostream& str, const string& indent) const override;
            void dumpJsonGen(std::ostream& str) const;
            ''',
                      t=node.name)
            if ptrs:
                emitBlock('''\
                void cloneRelinkGen() override;
                ''')

            if node.isLeaf:
                emitBlock('''\
                void accept(VNVisitorConst& v) override {{ v.visit(this); }}
                AstNode* clone() override {{ return new Ast{t}(*this); }}
                ''',
                          t=node.name)

            # Super class accessors made inaccessible by this class
            hiddenMethods = []

            for n in range(1, 5):
                op = node.getOp(n)
                if not op:
                    continue
                name, monad, kind, _ = op
                # Operands of a specific kind are returned as that type
                if kind == "Node":
                    retrieve = "op{n}p()".format(n=n)
                else:
                    retrieve = "VN_DBG_AS(op{n}p(), {kind})".format(n=n, kind=kind)
                superOp = node.superClass.getOp(n)
                if superOp:
                    superName = superOp[0]
                    hiddenMethods.append(superName)
                    if monad == "List":
                        hiddenMethods.append("add" + superName[0].upper() + superName[1:])
                emitBlock(getter + monadSetters.get(monad, plainSetter),
                          kind=kind,
                          name=name,
                          Name=name[:1].upper() + name[1:],
                          n=n,
                          retrieve=retrieve)

            if hiddenMethods:
                fh.write("private: \\\n")
                for method in hiddenMethods:
                    fh.write("    using Ast{sup}::{method}; \\\n".format(sup=node.superClass.name,
                                                                         method=method))
                fh.write("public: \\\n")

            fh.write("    static_assert(true, \"\")\n")  # Swallowing the semicolon

            # Only care about leaf classes for the rest
            if node.isLeaf:
                fh.write("#define ASTGEN_SUPER_{t}(...) Ast{b}(VNType::{t}, __VA_ARGS__)\n".format(
                    t=node.name, b=node.superClass.name))
            fh.write("\n")


def write_ast_yystype(filename):
    """Write one 'Ast<Name>* <name>p;' member declaration per AstNode type.

    The member name is the node name with its first character lower-cased.
    """
    with open_file(filename) as fh:
        for node in AstNodeList:
            typeName = node.name
            memberName = typeName[0].lower() + typeName[1:]
            fh.write("Ast" + typeName + "* " + memberName + "p;\n")


################################################################################
# DFG code generation
################################################################################


def write_dfg_macros(filename):
    """Write the ASTGEN_MEMBERS_Dfg<name> macro of every non-root DfgVertex type.

    Leaf types get dfgType(), the Super alias and accept(). Every type gets a
    getter/setter pair per operand (operand N maps to input N-1), and leaf
    types with more than one operand also get srcName().
    """
    with open_file(filename) as fh:

        def emitLines(lines, **fmt):
            # Each line of a macro body is indented and ends in a line continuation
            for line in lines:
                fh.write("    " + line.format(**fmt) + " \\\n")

        for node in DfgVertexList:
            if node.isRoot:
                continue

            fh.write("#define ASTGEN_MEMBERS_Dfg{t} \\\n".format(t=node.name))

            if node.isLeaf:
                emitLines((
                    "static constexpr VDfgType dfgType() {{ return VDfgType::{t}; }};",
                    "using Super = Dfg{s};",
                    "void accept(DfgVisitor& v) override {{ v.visit(this); }}",
                ),
                          t=node.name,
                          s=node.superClass.name)

            operandNames = []
            for n in range(1, node.arity + 1):
                name, _, _, _ = node.getOp(n)
                operandNames.append(name)
                emitLines((
                    "DfgVertex* {name}() const {{ return inputp({n}); }}",
                    "void {name}(DfgVertex* vtxp) {{ inputp({n}, vtxp); }}",
                ),
                          name=name,
                          n=n - 1)

            if node.isLeaf and node.arity > 1:
                quoted = ['"' + name + '"' for name in operandNames]
                emitLines((
                    "std::string srcName(size_t idx) const override final {{",
                    "    static const char* names[{a}] = {{ {ns} }};",
                    "    return names[idx];",
                    "}}",
                ),
                          a=node.arity,
                          ns=", ".join(quoted))

            fh.write("    static_assert(true, \"\")\n")  # Swallowing the semicolon


def write_dfg_auto_classes(filename):
    """Write the class definition of every leaf DfgVertex type that has no file.

    Vertices with a file are skipped; only the automatically derived leaf
    types are emitted. The output ends with an empty line.
    """
    classTemplate = "\n".join((
        "class Dfg{t} final : public Dfg{s} {{",
        "public:",
        "    Dfg{t}(DfgGraph& dfg, FileLine* flp, const DfgDataType& dtype)",
        "        : Dfg{s}{{dfg, dfgType(), flp, dtype}} {{}}",
        "    ASTGEN_MEMBERS_Dfg{t};",
        "}};",
        "",
    ))
    with open_file(filename) as fh:
        for node in DfgVertexList:
            # Only generate code for automatically derived leaf nodes
            if node.file is None and node.isLeaf:
                fh.write(classTemplate.format(t=node.name, s=node.superClass.name))
        fh.write("\n")


def write_dfg_clone_cases(filename):
    """Write a 'case VDfgType::<name>' clone block per leaf DfgVertex type with no file.

    Vertices with a file are skipped; only the automatically derived leaf
    types are emitted. The output ends with an empty line.
    """
    caseTemplate = "\n".join((
        "case VDfgType::{t}: {{",
        "  Dfg{t}* const cp = new Dfg{t}{{*clonep, vtx.fileline(), vtx.dtype()}};",
        "  vtxp2clonep.emplace(&vtx, cp);",
        "  break;",
        "}}",
        "",
    ))
    with open_file(filename) as fh:
        for node in DfgVertexList:
            # Only generate code for automatically derived leaf nodes
            if node.file is None and node.isLeaf:
                fh.write(caseTemplate.format(t=node.name))
        fh.write("\n")


def write_dfg_ast_to_dfg(filename):
    """Write a 'visit(Ast<name>*)' method per leaf DfgVertex type with no file.

    The generated method bails out on unhandled nodes and data types, iterates
    op1p..op<arity>p, creates the Dfg<name> vertex with makeVertex, connects
    the vertices of the operands as its inputs, and stores it via user2p.
    """
    # Parts of the generated method that do not depend on the node
    checks = (
        '    UASSERT_OBJ(m_converting, nodep, "AstToDfg visit called without m_converting");\n'
        '    UASSERT_OBJ(!nodep->user2p(), nodep, "Already has Dfg vertex");\n\n'
        "    if (unhandled(nodep)) return;\n\n"
        "    const DfgDataType* const dtypep = DfgDataType::fromAst(nodep->dtypep());\n"
        "    if (!dtypep) {\n"
        "        m_foundUnhandled = true;\n"
        "        ++m_ctx.m_conv.nonRepDType;\n"
        "        return;\n"
        "    }\n\n")
    vertexCheck = ("    if (!vtxp) {\n"
                   "        m_foundUnhandled = true;\n"
                   "        ++m_ctx.m_conv.nonRepNode;\n"
                   "        return;\n"
                   "    }\n"
                   "    m_logicp->synth().emplace_back(vtxp);\n\n")
    # Parts emitted once per operand
    convertOperand = (
        "    iterate(nodep->op{j}p());\n"
        "    if (m_foundUnhandled) return;\n"
        '    UASSERT_OBJ(nodep->op{j}p()->user2p(), nodep, "Child {j} missing Dfg vertex");\n')
    connectOperand = "    vtxp->inputp({i}, nodep->op{j}p()->user2u().to<DfgVertex*>());\n"

    with open_file(filename) as fh:
        for node in DfgVertexList:
            # Only generate code for automatically derived leaf nodes
            if node.file is not None or not node.isLeaf:
                continue

            inputs = range(node.arity)
            chunks = ["void visit(Ast{t}* nodep) override {{\n".format(t=node.name), checks]
            chunks.extend(convertOperand.format(j=i + 1) for i in inputs)
            chunks.append("\n")
            chunks.append(
                "    Dfg{t}* const vtxp = makeVertex<Dfg{t}>(nodep, m_dfg, *dtypep);\n".format(
                    t=node.name))
            chunks.append(vertexCheck)
            chunks.extend(connectOperand.format(i=i, j=i + 1) for i in inputs)
            chunks.append("\n")
            chunks.append("    nodep->user2p(vtxp);\n")
            chunks.append("}\n")
            fh.write("".join(chunks))


def write_dfg_dfg_to_ast(filename):
    """Write a 'visit(Dfg<name>*)' method per leaf DfgVertex type with no file.

    The generated method converts every input of the vertex to an AstNodeExpr
    and passes them, after the vertex itself, to makeNode<Ast<name>>.
    """
    convertInput = (
        "    AstNodeExpr* const op{j}p = convertDfgVertexToAstNodeExpr(vtxp->inputp({i}));\n")
    with open_file(filename) as fh:
        for node in DfgVertexList:
            # Only generate code for automatically derived leaf nodes
            if node.file is not None or not node.isLeaf:
                continue

            inputs = range(node.arity)
            operands = "".join(", op{j}p".format(j=i + 1) for i in inputs)
            fh.write("void visit(Dfg{t}* vtxp) override {{\n".format(t=node.name))
            for i in inputs:
                fh.write(convertInput.format(i=i, j=i + 1))
            fh.write("    m_resultp = makeNode<Ast{t}>(vtxp".format(t=node.name) + operands + ");\n")
            fh.write("}\n")


######################################################################
# main

parser = argparse.ArgumentParser(
    allow_abbrev=False,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="""Generate V3Ast headers to reduce C++ code duplication.""",
    epilog="""Copyright 2002-2025 by Wilson Snyder. This program is free software; you
can redistribute it and/or modify it under the terms of either the GNU
Lesser General Public License Version 3 or the Perl Artistic License
Version 2.0.

SPDX-License-Identifier: LGPL-3.0-only OR Artistic-2.0""")

# -I, --astdef and --dfgdef are used unconditionally by the rest of the script
# (they would be None if omitted), so have argparse report a usage error
# instead of failing later with a TypeError.
parser.add_argument('-I', action='store', required=True, help='source code include directory')
parser.add_argument('--astdef',
                    action='append',
                    required=True,
                    help='add AST definition file (relative to -I)')
parser.add_argument('--dfgdef',
                    action='append',
                    required=True,
                    help='add DFG definition file (relative to -I)')
parser.add_argument('--classes', action='store_true', help='makes class declaration files')
parser.add_argument('--debug', action='store_true', help='enable debug')

parser.add_argument('infiles', nargs='*', help='list of input .cpp filenames')

Args = parser.parse_args()

###############################################################################
# Read AstNode definitions
###############################################################################

# Set up the root AstNode type. It is standalone so we don't need to parse the
# sources for this.
AstNodes["Node"] = Node("Node", None)

# Parse every AST definition file given on the command line
for filename in Args.astdef:
    read_types(os.path.join(Args.I, filename), AstNodes, "Ast")

# With all types read, compute the derived properties of the whole hierarchy
AstNodes["Node"].complete()

# All AstNode types, ordered by name
AstNodeList = tuple(AstNodes[key] for key in sorted(AstNodes))

check_types(AstNodeList, "Ast", "Node")

###############################################################################
# Read and generate DfgVertex definitions
###############################################################################

# Set up the root DfgVertex type. It is standalone so we don't need to parse
# the sources for this.
DfgVertices["Vertex"] = Node("Vertex", None)

# Names of AstNode leaf types for which no DfgVertex sub-type is derived
# automatically, as they are not representable in Dfg. The entries are grouped
# by the reason they are excluded.
DfgIgnored = (
    # Floating point operations
    "AcosD",
    "AcoshD",
    "AddD",
    "AsinD",
    "AsinhD",
    "Atan2D",
    "AtanD",
    "AtanhD",
    "BitsToRealD",
    "CeilD",
    "CosD",
    "CoshD",
    "DivD",
    "EqD",
    "ExpD",
    "FloorD",
    "GtD",
    "GteD",
    "HypotD",
    "ISToRD",
    "IToRD",
    "Log10D",
    "LogD",
    "LtD",
    "LteD",
    "MulD",
    "NegateD",
    "NeqD",
    "PowD",
    "RealToBits",
    "RToIRoundS",
    "RToIS",
    "SinD",
    "SinhD",
    "SqrtD",
    "SubD",
    "TanD",
    "TanhD",
    # String operations
    "AtoN",
    "CompareNN",
    "ConcatN",
    "CvtPackString",
    "EqN",
    "GetcN",
    "GetcRefN",
    "GteN",
    "GtN",
    "LenN",
    "LteN",
    "LtN",
    "NeqN",
    "NToI",
    "PutcN",
    "ReplicateN",
    "SubstrN",
    "ToLowerN",
    "ToStringN",
    "ToUpperN",
    # Effectful
    "PostAdd",
    "PostSub",
    "PreAdd",
    "PreSub",
    # Only used after DFG
    "ShiftLOvr",
    "ShiftROvr",
    "ShiftRSOvr",
    "WordSel",
    # File operations
    "FEof",
    "FGetC",
    "FGetS",
    "FUngetC",
    # Dynamic array operations
    "AssocSel",
    "IsUnbounded",
    "WildcardSel",
    # Type comparison
    "EqT",
    "NeqT",
    # Distributions
    "DistChiSquare",
    "DistErlang",
    "DistExponential",
    "DistNormal",
    "DistPoisson",
    "DistT",
    "DistUniform",
    # Specials
    "CastDynamic",
    "CastWrap",
    "CAwait",
    "CCast",
    "CLog2",
    "CountOnes",
    "IsUnknown",
    "NullCheck",
    "OneHot",
    "OneHot0",
    "ResizeLValue",
    "Signed",
    "SliceSel",
    "TimeImport",
    "Unsigned",
    "URandomRange",
)

# Parse every DFG definition file given on the command line
for filename in Args.dfgdef:
    read_types(os.path.join(Args.I, filename), DfgVertices, "Dfg")

# Derive a DfgVertex sub-type from each suitable AstNode sub-type
for node in AstNodeList:
    # Only leaf types are considered (the hierarchy is ignored for now), and
    # neither explicitly defined vertices nor types DFG cannot handle
    if (not node.isLeaf) or (node.name in DfgVertices) or (node.name in DfgIgnored):
        continue

    # The AstNode base class determines the DfgVertex base class. Types
    # that derive from none of these get no vertex.
    for _astBaseName, _dfgBaseName in (("NodeUniop", "VertexUnary"),
                                       ("NodeBiop", "VertexBinary"),
                                       ("NodeTriop", "VertexTernary")):
        if node.isSubClassOf(AstNodes[_astBaseName]):
            base = DfgVertices[_dfgBaseName]
            break
    else:
        continue

    vertex = Node(node.name, base)
    DfgVertices[node.name] = vertex
    base.addSubClass(vertex)

    # Copy the operand names over; only plain (monad-less) operands are allowed
    for n in range(1, node.arity + 1):
        op = node.getOp(n)
        if op is None:
            continue
        name, monad, kind, legals = op
        assert monad == "", "Cannot represent AstNode as DfgVertex"
        vertex.addOp(n, name, "", "", "")

# With all types known, compute the derived properties of the whole hierarchy
DfgVertices["Vertex"].complete()

# All DfgVertex types, ordered by name
DfgVertexList = tuple(DfgVertices[key] for key in sorted(DfgVertices))

check_types(DfgVertexList, "Dfg", "Vertex")

###############################################################################
# Read additional files
###############################################################################

read_stages(Args.I + "/Verilator.cpp")

# Scan the grammar, header and implementation files, in that order
source_files = [
    path for suffix in ("y", "h", "cpp") for path in glob.glob(Args.I + "/*." + suffix)
]
for filename in source_files:
    read_refs(filename)

###############################################################################
# Generate output
###############################################################################

def _write_common_classes(prefix, nodeList, visitor):
    # Outputs generated the same way for both the Ast and the Dfg hierarchy
    write_forward_class_decls(prefix, nodeList)
    write_visitor_decls(prefix, nodeList)
    write_visitor_defns(prefix, nodeList, visitor)
    write_type_enum(prefix, nodeList)
    write_type_tests(prefix, nodeList)


if Args.classes:
    write_report("V3Ast__gen_report.txt")
    # Write Ast code
    _write_common_classes("Ast", AstNodeList, "VNVisitorConst")
    write_ast_type_info("V3Ast__gen_type_info.h")
    write_ast_impl("V3Ast__gen_impl.h")
    write_ast_macros("V3Ast__gen_macros.h")
    write_ast_yystype("V3Ast__gen_yystype.h")
    # Write Dfg code
    _write_common_classes("Dfg", DfgVertexList, "DfgVisitor")
    write_dfg_macros("V3Dfg__gen_macros.h")
    write_dfg_auto_classes("V3Dfg__gen_auto_classes.h")
    write_dfg_clone_cases("V3Dfg__gen_clone_cases.h")
    write_dfg_ast_to_dfg("V3Dfg__gen_ast_to_dfg.h")
    write_dfg_dfg_to_ast("V3Dfg__gen_dfg_to_ast.h")

# Expand each .cpp file named on the command line into <name>__gen.cpp. The
# input is read from the -I directory, the output is written relative to the
# current directory.
for cpt in Args.infiles:
    # Require a literal '.cpp' suffix. (An unescaped '.' in a regular
    # expression would also accept names such as 'fooXcpp'.)
    if not cpt.endswith(".cpp"):
        sys.exit("%Error: Expected argument to be .cpp file: " + cpt)
    cpt = cpt[:-len(".cpp")]
    Cpt().process(in_filename=Args.I + "/" + cpt + ".cpp", out_filename=cpt + "__gen.cpp")

######################################################################
# Local Variables:
# compile-command: "touch src/V3AstNodeExpr.h ; v4make"
# End:
