#! /usr/bin/python3

import getopt
import sys
import os.path
import re

# Multiple that the computed minimum action lengths are rounded up to.
OFP_ACTION_ALIGN = 8

# Map from OpenFlow version number to version ID used in ofp_header.
version_map = {
    "1.0": 0x01,
    "1.1": 0x02,
    "1.2": 0x03,
    "1.3": 0x04,
    "1.4": 0x05,
    "1.5": 0x06,
}

# Inverse of version_map: from version ID back to OpenFlow version number.
version_reverse_map = {number: name for name, number in version_map.items()}

# Map from vendor name to a (vendor ID, length of the action header) pair.
vendor_map = {
    "OF": (0x00000000, 4),
    "ONF": (0x4f4e4600, 10),
    "NX": (0x00002320, 10),
}

# Basic types used in action arguments.  For each type: its size, its
# alignment, and the names of the byte-order conversion functions to apply
# (None when no conversion is needed).
types = {
    'uint8_t': {"size": 1, "align": 1, "ntoh": None, "hton": None},
    'ovs_be16': {"size": 2, "align": 2, "ntoh": "ntohs", "hton": "htons"},
    'ovs_be32': {"size": 4, "align": 4, "ntoh": "ntohl", "hton": "htonl"},
    'ovs_be64': {"size": 8, "align": 8, "ntoh": "ntohll", "hton": "htonll"},
    'uint16_t': {"size": 2, "align": 2, "ntoh": None, "hton": None},
    'uint32_t': {"size": 4, "align": 4, "ntoh": None, "hton": None},
    'uint64_t': {"size": 8, "align": 8, "ntoh": None, "hton": None},
}

# Text of the line most recently read by get_line().
line = ""

# Struct types seen as action arguments, filled in by extract_ofp_actions().
arg_structs = set()

def round_up(x, y):
    """Return the smallest multiple of 'y' that is greater than or equal to 'x'.

    Both arguments are expected to be integers, with 'y' positive.  Integer
    floor division is used so that the result is exact for arbitrarily large
    values, instead of going through a float.
    """
    return (x + (y - 1)) // y * y

def open_file(fn):
    """Open 'fn' for reading and make it the current input file.

    Sets the module-level 'file_name' and 'input_file', and resets
    'line_number' to 0.
    """
    global file_name, input_file, line_number
    file_name = fn
    input_file = open(fn)
    line_number = 0

def get_line():
    """Read the next line of the current input file and return it.

    The text is also stored in the module-level 'line', and 'line_number' is
    advanced by one.  Reaching end of input is reported through fatal().
    """
    global input_file, line, line_number
    line = input_file.readline()
    line_number += 1
    if not line:
        # readline() yields an empty string only at end of file.
        fatal("unexpected end of input")
    return line

# Number of problems reported through error() so far.
n_errors = 0

def error(msg):
    """Write 'msg' to stderr, prefixed with the current file name and line
    number, and count it in the module-level 'n_errors'."""
    global n_errors
    report = "%s:%d: %s\n" % (file_name, line_number, msg)
    sys.stderr.write(report)
    n_errors += 1

def fatal(msg):
    """Report 'msg' through error(), then terminate with exit status 1."""
    error(msg)
    raise SystemExit(1)

def usage():
    """Print the help text to stdout and exit with status 0."""
    program = os.path.basename(sys.argv[0])
    help_text = '''\
%(argv0)s, for extracting OpenFlow action data
usage: %(argv0)s [prototypes | definitions] OFP-ACTIONS.c

Commands:

  prototypes OFP-ACTIONS.C
    Reads ofp-actions.c and prints a set of prototypes to #include early in
    ofp-actions.c.

  definitions OFP-ACTIONS.C
    Reads ofp-actions.c and prints a set of definitions to #include late in
    ofp-actions.c.
''' % {"argv0": program}
    print(help_text)
    sys.exit(0)

def extract_ofp_actions(fn, definitions):
    """Read "enum ofp_raw_action_type" from 'fn' and print generated C code.

    'fn' names the C source file to scan.  Inside the enum, every enumerator
    must be preceded by an indented comment of the form
    "DST[, DST]...: ARGTYPE", where:

      - Each DST is a vendor name from 'vendor_map' followed by an OpenFlow
        version from 'version_map', optionally "+" (through the latest
        version) or "-VERSION" (through that version), then "(TYPE)" with a
        decimal type number, and optionally " is deprecated (REASON)".

      - ARGTYPE is "void", a key of 'types', or "struct NAME".  A following
        ", ..." marks a variable-length action, and "VLMFF" adds the
        vl_mff_map and tlv_bitmap arguments to the decode function.

    Text in square brackets within the comment is ignored.  Blank lines and
    comments that start in column 0 are skipped.

    If 'definitions' is true, prints the struct size assertions, the
    all_raw_instances[] table, the put_*() function definitions and
    ofpact_decode().  Otherwise prints the put_*() and decode_*() prototypes
    and the ofpact_decode() prototype.

    Syntax problems are reported through fatal(), which exits.  If the same
    (vendor, type, version) is assigned to two enumerators, each conflict is
    reported and the program exits with status 1 before printing anything.
    """
    # domain[vendor ID][type][version ID] -> info dict for one instance.
    domain = {code: {} for code, _ in vendor_map.values()}
    # Enumerator names seen so far, to detect duplicates.
    names = set()
    # Enumerator name -> list of info dicts, in order of appearance.
    enums = {}

    # Conflicts found by this function.  Kept separately from the
    # module-level n_errors counter that error() updates.
    n_conflicts = 0

    open_file(fn)

    # Skip ahead to the start of the enum.
    while True:
        get_line()
        if re.match(r'enum ofp_raw_action_type {', line):
            break

    while True:
        get_line()
        if line.startswith('/*') or not line or line.isspace():
            # Blank line, or a comment starting in column 0: not an action.
            continue
        elif re.match('}', line):
            # End of the enum.
            break

        if not line.lstrip().startswith('/*'):
            fatal("unexpected syntax between actions")

        # Join a possibly multi-line comment into a single string.
        comment = line.lstrip()[2:].strip()
        while not comment.endswith('*/'):
            get_line()
            if line.startswith('/*') or not line or line.isspace():
                fatal("unexpected syntax within action")
            comment += ' %s' % line.lstrip('* \t').rstrip(' \t\r\n')
        comment = re.sub(r'\[[^]]*\]', '', comment)
        comment = comment[:-2].rstrip()

        m = re.match(r'([^:]+):\s+(.*)$', comment)
        if not m:
            fatal("unexpected syntax between actions")

        dsts = m.group(1)
        # Drop the first '.' in the argument description.  For "TYPE, ..."
        # this leaves "TYPE, ..", which is what the ellipsis check below
        # strips.
        argtypes = m.group(2).strip().replace('.', '', 1)

        arg_vl_mff_map = 'VLMFF' in argtypes
        argtype = argtypes.replace('VLMFF', '', 1).rstrip()

        # The enumerator itself must be on the line after the comment.
        get_line()
        m = re.match(r'\s+(([A-Z]+)_RAW([0-9]*)_([A-Z0-9_]+)),?', line)
        if not m:
            fatal("syntax error expecting enum value")

        enum = m.group(1)
        if enum in names:
            fatal("%s specified twice" % enum)

        names.add(enum)

        for dst in dsts.split(', '):
            m = re.match(r'([A-Z]+)([0-9.]+)(\+|-[0-9.]+)?(?:\((\d+)\))(?: is deprecated \(([^)]+)\))?$', dst)
            if not m:
                fatal("%r: syntax error in destination" % dst)
            vendor_name = m.group(1)
            version1_name = m.group(2)
            version2_name = m.group(3)
            type_ = int(m.group(4))
            deprecation = m.group(5)

            if vendor_name not in vendor_map:
                fatal("%s: unknown vendor" % vendor_name)
            vendor = vendor_map[vendor_name][0]

            if version1_name not in version_map:
                fatal("%s: unknown OpenFlow version" % version1_name)
            v1 = version_map[version1_name]

            # Work out the last version in the range.
            if version2_name is None:
                v2 = v1
            elif version2_name == "+":
                v2 = max(version_map.values())
            elif version2_name[1:] not in version_map:
                fatal("%s: unknown OpenFlow version" % version2_name[1:])
            else:
                v2 = version_map[version2_name[1:]]

            if v2 < v1:
                fatal("%s%s: %s precedes %s"
                      % (version1_name, version2_name,
                         version2_name, version1_name))

            for version in range(v1, v2 + 1):
                domain[vendor].setdefault(type_, {})
                if version in domain[vendor][type_]:
                    # This (vendor, type, version) already belongs to
                    # another enumerator.
                    v = domain[vendor][type_][version]
                    msg = "%#x,%d in OF%s means both %s and %s" % (
                        vendor, type_, version_reverse_map[version],
                        v["enum"], enum)
                    error("%s: %s." % (dst, msg))
                    sys.stderr.write("%s:%d: %s: Here is the location "
                                     "of the previous definition.\n"
                                     % (v["file_name"], v["line_number"],
                                        dst))
                    n_conflicts += 1
                else:
                    header_len = vendor_map[vendor_name][1]

                    base_argtype = argtype.replace(', ..', '', 1)
                    if base_argtype in types:
                        # Scalar argument placed after the header at its
                        # natural alignment.
                        arg_align = types[base_argtype]['align']
                        arg_len = types[base_argtype]['size']
                        arg_ofs = round_up(header_len, arg_align)
                        min_length = round_up(arg_ofs + arg_len,
                                              OFP_ACTION_ALIGN)
                    elif base_argtype == 'void':
                        min_length = round_up(header_len, OFP_ACTION_ALIGN)
                        arg_len = 0
                        arg_ofs = 0
                    elif re.match(r'struct [a-zA-Z0-9_]+$', base_argtype):
                        # The length is left for the C compiler to compute.
                        min_length = 'sizeof(%s)' % base_argtype
                        arg_structs.add(base_argtype)
                        arg_len = 0
                        arg_ofs = 0
                        # should also emit OFP_ACTION_ALIGN assertion
                    else:
                        fatal("bad argument type %s" % argtype)

                    ellipsis = argtype != base_argtype
                    if ellipsis:
                        max_length = '65536 - OFP_ACTION_ALIGN'
                    else:
                        max_length = min_length

                    info = {"enum": enum,
                            "deprecation": deprecation,
                            "file_name": file_name,
                            "line_number": line_number,
                            "min_length": min_length,
                            "max_length": max_length,
                            "arg_ofs": arg_ofs,
                            "arg_len": arg_len,
                            "base_argtype": base_argtype,
                            "arg_vl_mff_map": arg_vl_mff_map,
                            "version": version,
                            "type": type_}
                    domain[vendor][type_][version] = info

                    enums.setdefault(enum, []).append(info)

    input_file.close()

    if n_conflicts:
        sys.exit(1)

    print("""\
/* Generated automatically; do not modify!     -*- buffer-read-only: t -*- */
""")

    if definitions:
        print("/* Verify that structs used as actions are reasonable sizes. */")
        for s in sorted(arg_structs):
            print("BUILD_ASSERT_DECL(sizeof(%s) %% OFP_ACTION_ALIGN == 0);" % s)

        print("\nstatic struct ofpact_raw_instance all_raw_instances[] = {")
        for vendor in domain:
            for type_ in domain[vendor]:
                for version in domain[vendor][type_]:
                    d = domain[vendor][type_][version]
                    print("    { { 0x%08x, %2d, 0x%02x }, " % (
                        vendor, type_, version))
                    print("      %s," % d["enum"])
                    print("      HMAP_NODE_NULL_INITIALIZER,")
                    print("      HMAP_NODE_NULL_INITIALIZER,")
                    print("      %s," % d["min_length"])
                    print("      %s," % d["max_length"])
                    print("      %s," % d["arg_ofs"])
                    print("      %s," % d["arg_len"])
                    print("      \"%s\"," % re.sub(r'_RAW[0-9]*', '', d["enum"],
                                                   count=1))
                    if d["deprecation"]:
                        # Escape quotes and backslashes for the C literal.
                        print("      \"%s\"," % re.sub(r'(["\\])', r'\\\1', d["deprecation"]))
                    else:
                        print("      NULL,")
                    print("    },")
        print("};")

    # put_*() helpers: prototypes, or full definitions if requested.
    for versions in enums.values():
        # The helper needs an explicit version argument only if the encoding
        # differs among the versions that share this enumerator.
        need_ofp_version = False
        for v in versions:
            assert v["arg_len"] == versions[0]["arg_len"]
            assert v["base_argtype"] == versions[0]["base_argtype"]
            if (v["min_length"] != versions[0]["min_length"] or
                v["arg_ofs"] != versions[0]["arg_ofs"] or
                v["type"] != versions[0]["type"]):
                need_ofp_version = True
        base_argtype = versions[0]["base_argtype"]

        decl = "static inline "
        if base_argtype.startswith('struct'):
            decl += "%s *" % base_argtype
        else:
            decl += "void"
        decl += "\nput_%s(struct ofpbuf *openflow" % versions[0]["enum"].replace('_RAW', '', 1)
        if need_ofp_version:
            decl += ", enum ofp_version version"
        if base_argtype != 'void' and not base_argtype.startswith('struct'):
            decl += ", %s arg" % base_argtype
        decl += ")"
        if definitions:
            decl += "{\n"
            decl += "    "
            if base_argtype.startswith('struct'):
                decl += "return "
            decl += "ofpact_put_raw(openflow, "
            if need_ofp_version:
                decl += "version"
            else:
                decl += "%s" % versions[0]["version"]
            decl += ", %s, " % versions[0]["enum"]
            if base_argtype.startswith('struct') or base_argtype == 'void':
                decl += "0"
            else:
                ntoh = types[base_argtype]['ntoh']
                if ntoh:
                    decl += "%s(arg)" % ntoh
                else:
                    decl += "arg"
            decl += ");\n"
            decl += "}"
        else:
            decl += ";"
        print(decl)
        print("")

    if definitions:
        print("""\
static enum ofperr
ofpact_decode(const struct ofp_action_header *a, enum ofp_raw_action_type raw,
              enum ofp_version version, uint64_t arg,
              const struct vl_mff_map *vl_mff_map,
              uint64_t *tlv_bitmap, struct ofpbuf *out)
{
    switch (raw) {\
""")
        for versions in enums.values():
            enum = versions[0]["enum"]
            print("    case %s:" % enum)
            base_argtype = versions[0]["base_argtype"]
            arg_vl_mff_map = versions[0]["arg_vl_mff_map"]
            if base_argtype == 'void':
                print("        return decode_%s(out);" % enum)
            else:
                if base_argtype.startswith('struct'):
                    arg = "ALIGNED_CAST(const %s *, a)" % base_argtype
                else:
                    hton = types[base_argtype]['hton']
                    if hton:
                        arg = "%s(arg)" % hton
                    else:
                        arg = "arg"
                if arg_vl_mff_map:
                    print("        return decode_%s(%s, version, vl_mff_map, tlv_bitmap, out);" % (enum, arg))
                else:
                    print("        return decode_%s(%s, version, out);" % (enum, arg))
            print("")
        print("""\
    default:
        OVS_NOT_REACHED();
    }
}\
""")
    else:
        for versions in enums.values():
            enum = versions[0]["enum"]
            prototype = "static enum ofperr decode_%s(" % enum
            base_argtype = versions[0]["base_argtype"]
            arg_vl_mff_map = versions[0]["arg_vl_mff_map"]
            if base_argtype != 'void':
                if base_argtype.startswith('struct'):
                    prototype += "const %s *, enum ofp_version, " % base_argtype
                else:
                    prototype += "%s, enum ofp_version, " % base_argtype
                if arg_vl_mff_map:
                    prototype += 'const struct vl_mff_map *, uint64_t *, '
            prototype += "struct ofpbuf *);"
            print(prototype)

        print("""
static enum ofperr ofpact_decode(const struct ofp_action_header *,
                                 enum ofp_raw_action_type raw,
                                 enum ofp_version version,
                                 uint64_t arg, const struct vl_mff_map *vl_mff_map,
                                 uint64_t *tlv_bitmap, struct ofpbuf *out);
""")

## ------------ ##
## Main Program ##
## ------------ ##

if __name__ == '__main__':
    argv0 = sys.argv[0]

    # Parse options; getopt reports malformed or unknown options itself.
    try:
        options, args = getopt.gnu_getopt(sys.argv[1:], 'h',
                                          ['help', 'ovs-version='])
    except getopt.GetoptError as geo:
        sys.stderr.write("%s: %s\n" % (argv0, geo.msg))
        sys.exit(1)

    version = None
    for key, value in options:
        if key == '-h' or key == '--help':
            usage()
        elif key == '--ovs-version':
            version = value
        else:
            sys.exit(0)

    if not args:
        sys.stderr.write("%s: missing command argument "
                         "(use --help for help)\n" % argv0)
        sys.exit(1)

    # Each command maps to (handler, number of arguments it requires).
    commands = {
        "prototypes": (lambda fn: extract_ofp_actions(fn, False), 1),
        "definitions": (lambda fn: extract_ofp_actions(fn, True), 1),
    }

    command = args[0]
    if command not in commands:
        sys.stderr.write("%s: unknown command \"%s\" "
                         "(use --help for help)\n" % (argv0, command))
        sys.exit(1)

    func, n_args = commands[command]
    n_given = len(args) - 1
    if n_given != n_args:
        sys.stderr.write("%s: \"%s\" requires %d arguments but %d "
                         "provided\n"
                         % (argv0, command, n_args, n_given))
        sys.exit(1)

    func(*args[1:])
