#! /usr/bin/env python
# Copyright(c) 2017, Intel Corporation
#
# Redistribution  and  use  in source  and  binary  forms,  with  or  without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of  source code  must retain the  above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name  of Intel Corporation  nor the names of its contributors
#   may be used to  endorse or promote  products derived  from this  software
#   without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING,  BUT NOT LIMITED TO,  THE
# IMPLIED WARRANTIES OF  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT  SHALL THE COPYRIGHT OWNER  OR CONTRIBUTORS BE
# LIABLE  FOR  ANY  DIRECT,  INDIRECT,  INCIDENTAL,  SPECIAL,  EXEMPLARY,  OR
# CONSEQUENTIAL  DAMAGES  (INCLUDING,  BUT  NOT LIMITED  TO,  PROCUREMENT  OF
# SUBSTITUTE GOODS OR SERVICES;  LOSS OF USE,  DATA, OR PROFITS;  OR BUSINESS
# INTERRUPTION)  HOWEVER CAUSED  AND ON ANY THEORY  OF LIABILITY,  WHETHER IN
# CONTRACT,  STRICT LIABILITY,  OR TORT  (INCLUDING NEGLIGENCE  OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,  EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import argparse
import struct
import fnmatch
import glob
import sys
import tempfile
import os
import fcntl
import filecmp
import stat
import re
import time
from array import array

try:
    from intelhex import IntelHex
except ImportError:
    sys.exit('Missing intelhex. Install by: sudo pip install intelhex')


def check_rpd(ifile):
    """Validate the header of an rpd image and return its start address.

    Reads the first 0x20 bytes of ``ifile`` and unpacks them as eight
    native-endian unsigned 32-bit words.  Words 0-2 must all be
    0xffffffff and word 3 must be 0x6a6a6a6a; word 4 is returned and is
    used by the caller as the start address.

    Args:
        ifile: binary file object positioned at the start of the image.

    Returns:
        The fifth header word.

    Raises:
        Exception: if the header is truncated or the marker words do not
            match.  "invalid rpd file" is printed before raising.
    """
    data = ifile.read(0x20)

    try:
        pof_hdr = struct.unpack('IIIIIIII', data)
    except struct.error:
        # A file shorter than the header cannot be a valid image; report
        # it the same way as a bad marker instead of leaking struct.error.
        print("invalid rpd file")
        raise Exception("invalid rpd file: truncated header")

    if any(word != 0xffffffff for word in pof_hdr[:3]):
        print("invalid rpd file")
        raise Exception("invalid rpd file: bad header padding")

    if pof_hdr[3] != 0x6a6a6a6a:
        print("invalid rpd file")
        raise Exception("invalid rpd file: bad header signature")

    return pof_hdr[4]


def reverse_bits(x, n):
    """Return the low ``n`` bits of ``x`` in reversed order.

    Bit 0 of ``x`` becomes bit ``n - 1`` of the result, bit 1 becomes bit
    ``n - 2`` and so on.  Bits of ``x`` at position ``n`` and above are
    ignored.  For ``n <= 0`` the result is 0.

    Args:
        x: integer whose bits are reversed.
        n: number of low-order bits to reverse.

    Returns:
        The reversed value as a non-negative integer below ``2 ** n``.
    """
    result = 0
    # range() instead of xrange() so the helper runs on Python 2 and 3.
    for _ in range(n):
        # Shift the lowest remaining input bit into the result.
        result = (result << 1) | (x & 1)
        x >>= 1
    return result


def reverse_bits_in_file(ifile, ofile):
    """Copy ``ifile`` to ``ofile`` with the bits of every byte reversed.

    Data is read from the current position of ``ifile`` until end of file
    in 4096-byte chunks.  Neither file is closed.

    Args:
        ifile: binary file object to read from.
        ofile: binary file object to write to.
    """
    # 256-entry lookup table: table[b] is byte b with its bits mirrored.
    # bytes(bytearray(...)) yields a str on Python 2 and bytes on Python 3,
    # which is what translate() expects on either version.
    table = bytes(bytearray(reverse_bits(i, 8) for i in range(256)))

    while True:
        ichunk = ifile.read(4096)
        if not ichunk:
            break

        # One table lookup pass per chunk instead of building the output
        # one character at a time.
        ofile.write(ichunk.translate(table))


def get_flash_size(dev):
    """Query the device node ``dev`` with the MEMGETINFO ioctl.

    A zero-filled 'BIIIIIQ' structure is handed to the ioctl and the
    third field (index 2) of the structure that comes back is returned;
    the caller reports it as the flash size.
    NOTE(review): the layout presumably mirrors the kernel's
    ``struct mtd_info_user`` - confirm against mtd-abi.h.

    Args:
        dev: path of the device node to query.

    Returns:
        Field 2 of the unpacked ioctl result.
    """

    MEMGETINFO = 0x80204d01
    info_layout = 'BIIIIIQ'

    request = struct.pack(info_layout, *([0] * 7))

    open_flags = os.O_SYNC | os.O_RDONLY
    with os.fdopen(os.open(dev, open_flags), 'r') as mtd:
        reply = fcntl.ioctl(mtd.fileno(), MEMGETINFO, request)

    info_fields = struct.unpack_from(info_layout, reply)
    return info_fields[2]


def flash_erase(dev, start, nbytes):
    """Issue the MEMERASE ioctl on ``dev`` for ``nbytes`` from ``start``.

    ``start`` and ``nbytes`` are packed as two native unsigned 32-bit
    integers, so values outside that range raise ``struct.error`` before
    the device is opened.

    Args:
        dev: path of the device node to erase.
        start: first byte offset of the region.
        nbytes: length of the region in bytes.
    """
    MEMERASE = 0x40084d02

    erase_request = struct.pack('II', start, nbytes)

    open_flags = os.O_SYNC | os.O_RDWR
    with os.fdopen(os.open(dev, open_flags), 'w') as mtd:
        fcntl.ioctl(mtd.fileno(), MEMERASE, erase_request)


def flash_write(dev, start, nbytes, ifile):
    """Copy ``nbytes`` bytes from ``ifile`` to ``dev`` starting at ``start``.

    The data is read from the current position of ``ifile`` in chunks of
    at most 4096 bytes.  Progress is tracked by the number of bytes that
    were actually read and written, so short reads and short writes
    cannot silently drop data.

    Args:
        dev: path of the device (or file) to write to.
        start: byte offset in ``dev`` at which writing begins.
        nbytes: total number of bytes to copy.
        ifile: binary file object supplying the data.

    Raises:
        Exception: if ``ifile`` runs out of data before ``nbytes`` bytes
            were copied, or if the device accepts no more data.
        OSError: if opening, seeking or writing ``dev`` fails.
    """
    # A plain descriptor is used on purpose: wrapping it with
    # os.fdopen(..., 'a') would switch the descriptor to append mode.
    fd_ = os.open(dev, os.O_SYNC | os.O_RDWR)
    try:
        os.lseek(fd_, start, os.SEEK_SET)

        remaining = nbytes
        while remaining > 0:
            ichunk = ifile.read(min(remaining, 4096))

            if not ichunk:
                raise Exception("read of input file failed")

            remaining -= len(ichunk)

            # os.write() may accept only part of the buffer; keep going
            # until the whole chunk has been handed to the device.
            while ichunk:
                written = os.write(fd_, ichunk)
                if written <= 0:
                    raise Exception("write of flash failed")
                ichunk = ichunk[written:]
    finally:
        os.close(fd_)


def flash_read(dev, start, nbytes, ofile):
    """Copy ``nbytes`` bytes from ``dev`` at offset ``start`` into ``ofile``.

    The device is read in chunks of at most 4096 bytes.  Progress is
    tracked by the number of bytes actually returned, so a short read
    does not shorten the copy.

    Args:
        dev: path of the device (or file) to read from.
        start: byte offset in ``dev`` at which reading begins.
        nbytes: total number of bytes to copy.
        ofile: binary file object receiving the data.

    Raises:
        Exception: if the device returns no data before ``nbytes`` bytes
            were copied.
        OSError: if opening, seeking or reading ``dev`` fails.
    """
    fd_ = os.open(dev, os.O_RDONLY)
    try:
        os.lseek(fd_, start, os.SEEK_SET)

        remaining = nbytes
        while remaining > 0:
            ichunk = os.read(fd_, min(remaining, 4096))

            if not ichunk:
                raise Exception("read of flash failed")

            ofile.write(ichunk)
            # Count what was really read, not what was asked for.
            remaining -= len(ichunk)
    finally:
        os.close(fd_)


def parse_args():
    """Parse the command line.

    Positional arguments:
        type: one of 'user', 'factory', 'bmc_bl' or 'bmc_app'.
        file: path of the image; opened for binary reading by argparse.
        bdf:  optional PCI address of the device to program.

    Returns:
        The argparse namespace with ``type``, ``file`` and ``bdf``.
    """
    description = ('A tool to help update the flash used to configure an '
                   'Intel FPGA at power up.')
    epilog = ('example usage:\n\n'
              '    fpgaflash user new_image.rpd 0000:04:00.0\n\n')

    # Raw formatter keeps the line breaks of the usage example intact.
    parser = argparse.ArgumentParser(
        description=description,
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    update_types = ['user', 'factory', 'bmc_bl', 'bmc_app']
    parser.add_argument('type', choices=update_types,
                        help='type of flash programming')
    parser.add_argument('file', type=argparse.FileType('rb'),
                        help='file to program into flash')
    parser.add_argument(
        'bdf', nargs='?',
        help=("bdf of device to program (e.g. 04:00.0 or 0000:04:00.0)"
              " optional when one device in system"))

    return parser.parse_args()


def get_bdf_mtd_mapping():
    """Map each FPGA found under /sys/class/fpga to an mtd device node.

    For every entry in /sys/class/fpga the base name of its ``device``
    symlink target is used as the key.  The value is ``/dev/<name>`` for
    the first mtd entry below ``intel-fpga-fme.*/altr-asmip2*/mtd`` whose
    path does not end in "ro".  Entries without such an mtd are omitted.

    Returns:
        dict mapping the key described above to a /dev path.
    """
    mapping = {}

    for fpga_dir in glob.glob('/sys/class/fpga/*'):
        link_target = os.readlink(os.path.join(fpga_dir, "device"))
        bdf = os.path.basename(link_target)
        if bdf:
            mtd_pattern = os.path.join(fpga_dir, 'intel-fpga-fme.*',
                                       'altr-asmip2*', 'mtd', 'mtd*')
            # Skip the entries whose path ends in "ro".
            candidates = [path for path in glob.glob(mtd_pattern)
                          if not path.endswith("ro")]
            if candidates:
                node = os.path.basename(candidates[0])
                mapping[bdf] = os.path.join('/dev', node)

    return mapping


def print_bdf_mtd_mapping(bdf_map):
    """Print the keys of ``bdf_map`` as the list of cards, then exit.

    The keys are listed in sorted order so the output is stable.  This
    function never returns: it always ends with ``sys.exit(1)``.

    Args:
        bdf_map: dict whose keys are printed one per line.

    Raises:
        SystemExit: always, with exit status 1.
    """
    # Single-argument print() calls behave the same on Python 2 and 3.
    print("\nFPGA cards available for flashing:")

    for bdf in sorted(bdf_map):
        print("    {}".format(bdf))

    print("")

    sys.exit(1)


def normalize_bdf(bdf):
    """Return ``bdf`` in lower-case ``dddd:bb:dd.f`` form, or None.

    Accepted inputs are ``dddd:bb:dd.f`` and the short form ``bb:dd.f``,
    where every position is a hexadecimal digit.  The short form gets the
    domain "0000" prepended.  The result is lower-cased so that input
    such as "0000:AF:00.0" can be compared with lower-case keys.
    NOTE(review): assumes the keys it is compared against (taken from
    sysfs) are lower-case - confirm.

    Args:
        bdf: string given by the user.

    Returns:
        The normalized string, or None when ``bdf`` has neither form.
    """
    hexd = '[0-9a-fA-F]'
    # \Z instead of $ so that a trailing newline is not accepted.
    short_pat = hexd + '{2}:' + hexd + '{2}' + r'\.' + hexd + r'\Z'
    full_pat = hexd + '{4}:' + short_pat

    if re.match(full_pat, bdf):
        return bdf.lower()

    if re.match(short_pat, bdf):
        return "0000:{}".format(bdf.lower())

    return None


def update_flash(update_type, ifile, mtd_dev):
    """Program the image in ``ifile`` into ``mtd_dev`` and verify it.

    Steps:
      1. Read the flash size and validate the image header.
      2. Pick the start address: the value returned by ``check_rpd``, or
         0 when ``update_type`` is 'factory'.
      3. Bit-reverse the image from the start address onwards into a
         temporary file.
      4. Erase the flash from the start address to its end, write the
         temporary file, read the same range back and compare.

    ``ifile`` is closed by this function.  Both temporary files are
    removed even when a step fails.

    Args:
        update_type: update type string; only 'factory' changes behavior.
        ifile: open binary file object holding the image.
        mtd_dev: path of the mtd device node to program.

    Raises:
        Exception: if the image header is invalid, the image does not fit
            in the flash, or the read-back data differs from what was
            written.
    """

    flash_size = get_flash_size(mtd_dev)

    print("flash size is {}".format(flash_size))

    start_addr = check_rpd(ifile)

    if update_type == 'factory':
        start_addr = 0

    ofile = tempfile.NamedTemporaryFile(mode='wb', delete=False)
    vfile = None

    try:
        try:
            ifile.seek(start_addr)

            print("reversing bits")
            reverse_bits_in_file(ifile, ofile)
        finally:
            ifile.close()
            ofile.close()

        nbytes = os.path.getsize(ofile.name)

        # Refuse to touch the flash when the image cannot fit; otherwise
        # the device would be erased and the write would then fail.
        if nbytes > flash_size - start_addr:
            raise Exception(
                "image of {} bytes at offset {} does not fit in flash "
                "of {} bytes".format(nbytes, start_addr, flash_size))

        print("erasing flash")
        flash_erase(mtd_dev, start_addr, (flash_size - start_addr))

        with open(ofile.name, 'rb') as rfile:
            print("writing flash")
            flash_write(mtd_dev, start_addr, nbytes, rfile)

        vfile = tempfile.NamedTemporaryFile(mode='wb', delete=False)
        try:
            print("reading back flash")
            flash_read(mtd_dev, start_addr, nbytes, vfile)
        finally:
            vfile.close()

        print("verifying flash")

        # shallow=False forces a content comparison; the default may
        # declare files equal from their stat() signature alone.
        retval = filecmp.cmp(ofile.name, vfile.name, shallow=False)
    finally:
        os.remove(ofile.name)
        if vfile is not None:
            os.remove(vfile.name)

    if retval:
        print("flash successfully verified")
    else:
        print("failed to verify flash")
        raise Exception("failed to verify flash")


def fpga_update(utype, ifile, bdf, bdf_map):
    """Run a 'user' or 'factory' flash update for the device ``bdf``.

    Args:
        utype: update type; 'factory' asks for confirmation first.
        ifile: open binary file object holding the image.
        bdf: key into ``bdf_map`` selecting the device.
        bdf_map: dict mapping bdf strings to mtd device paths.

    Returns:
        0 after a successful update, 1 if the mtd path cannot be
        stat()ed, is not a character device, or the factory update was
        not confirmed with exactly "Yes".

    Raises:
        SystemExit: if ``bdf`` has no mtd device in ``bdf_map`` (the
            available devices are listed first).
        Exception: propagated from ``update_flash`` on failure.
    """
    # .get() so that an unknown bdf reaches the listing below instead of
    # raising KeyError.
    mtd_dev = bdf_map.get(bdf)

    if not mtd_dev:
        print_bdf_mtd_mapping(bdf_map)

    try:
        mode = os.stat(mtd_dev).st_mode
    except OSError as ex:
        print(ex)
        return 1

    if not stat.S_ISCHR(mode):
        print("{} is not a device file.".format(mtd_dev))
        return 1

    if utype == 'factory':
        msg = "Are you sure you want to perform a factory update? [Yes/No]"
        # raw_input() only exists on Python 2; input() is its Python 3
        # equivalent.
        try:
            read_line = raw_input
        except NameError:
            read_line = input
        line = read_line(msg)
        if line != "Yes":
            return 1

    update_flash(utype, ifile, mtd_dev)
    return 0


def get_dev_bmc(bdf):
    """Return the /dev path of the avmmi-bmc device of PCI device ``bdf``.

    Looks for exactly one ``avmmi-bmc.*.auto`` entry below
    /sys/bus/pci/devices/<bdf>/fpga/intel-fpga-dev.*/intel-fpga-fme.* and
    maps its base name to /dev.

    Args:
        bdf: PCI address used as the sysfs directory name.

    Returns:
        Path of the character device node.

    Raises:
        SystemExit: with status 1 if no entry or more than one entry is
            found (a message is printed first).
        OSError: if the /dev node cannot be stat()ed.
        Exception: if the /dev node is not a character device.
    """
    path = os.path.join('/sys/bus/pci/devices/', bdf,
                        'fpga/intel-fpga-dev.*/intel-fpga-fme.*',
                        'avmmi-bmc.*.auto')
    dirs = glob.glob(path)

    # Single-argument print() calls behave the same on Python 2 and 3.
    if not dirs:
        print("The avmmi-bmc driver was not found.")
        print("Driver or FIM may need to be upgraded.")
        sys.exit(1)

    if len(dirs) > 1:
        print("Catastrophic error! More than one avmmi-bmc driver found.")
        sys.exit(1)

    dev = os.path.join('/dev', os.path.basename(dirs[0]))

    mode = os.stat(dev).st_mode

    if not stat.S_ISCHR(mode):
        raise Exception("{} is not a device file.".format(dev))

    return dev


class BittwareBmc(object):
    """Update the BMC flash through the avmmi-bmc character device.

    Every request is sent with a single ioctl (``BMC_IOCTL``) that
    carries a transmit buffer and a receive buffer.  Bootloader commands
    are prefixed with ``BW_BL_CMD_HDR``; their responses start with
    ``BW_BL_MIN_RSP_LEN`` status bytes followed by the command payload.

    The image to program is an Intel HEX file; addresses missing from the
    file read back as 0xff.

    The instance owns an open file descriptor.  Call ``close()`` when
    done, or use the instance as a context manager.
    """

    BMC_IOCTL = 0xc0187600

    # Application identifiers reported by / passed to the bootloader.
    BW_ACT_APP_MAIN = 0x01
    BW_ACT_APP_BL = 0x02

    # Storage devices addressable by the read and write commands.
    BW_DEV_FLASH = 0x00
    BW_DEV_EEPROM = 0x01

    # Bytes that prefix every bootloader command.
    BW_BL_CMD_HDR = [0xB8, 0x00, 0x64, 0x18, 0x7b, 0x00]
    BW_BL_MIN_RSP_LEN = 8  # includes completion code and BL result code

    BW_BL_CMD_VER = 0
    BW_BL_CMD_JUMP = 1
    BW_BL_CMD_READ = 2
    BW_BL_CMD_WRITE = 3

    BW_BL_READ_MAX = 512
    BW_BL_WRITE_MAX = 512
    BW_BL_PAGE_SIZE = 512

    BW_BL_HDR_SIZE = 16

    # Flash layout.  'app' says which update type owns the partition;
    # a partition spans 'start' up to 'hdr_start' + BW_BL_HDR_SIZE.
    partitions = [
        {
            'name': 'main',
            'app': BW_ACT_APP_MAIN,
            'start': 0x00000,
            'hdr_start': 0x34FF0,
        },
        {
            'name': 'cfg',
            'app': BW_ACT_APP_MAIN,
            'start': 0x35000,
            'hdr_start': 0x38ff0,
        },
        {
            'name': 'bootloader',
            'app': BW_ACT_APP_BL,
            'start': 0x39000,
            'hdr_start': 0x40ff0,
        },
        {
            'name': 'bootloader_boot',
            'app': BW_ACT_APP_BL,
            'start': 0x41000,
            'hdr_start': 0x41770,
        },
        {
            'name': 'main_boot',
            'app': BW_ACT_APP_MAIN,
            'start': 0x41800,
            'hdr_start': 0x41ff0,
        },
    ]

    def __init__(self, dev_path, ifile):
        """Open ``dev_path`` and load the Intel HEX image from ``ifile``.

        Args:
            dev_path: path of the avmmi-bmc device node.
            ifile: file name or file object accepted by ``IntelHex``.
        """
        self.dev = dev_path
        self.fd_ = os.open(self.dev, os.O_SYNC | os.O_RDWR)
        try:
            self.ihex = IntelHex(ifile)
        except Exception:
            # Do not leak the descriptor when the image cannot be parsed.
            self.close()
            raise
        self.ihex.padding = 0xff

    def close(self):
        """Close the device descriptor; safe to call more than once."""
        if self.fd_ is not None:
            os.close(self.fd_)
            self.fd_ = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def verify_segments(self, utype):
        """Print the address segments found in the loaded hex image.

        Despite its name this only prints; nothing is checked and
        ``utype`` is not used.
        """
        segments = self.ihex.segments()
        print(segments)
        # NOTE(review): IntelHex documents segments() as (start, end)
        # pairs, so the second number is presumably an end address.
        for (start, end) in segments:
            print("0x%x 0x%x" % (start, end))

    def bw_xact(self, txarray, rxlen):
        """Send ``txarray`` to the BMC and receive ``rxlen`` bytes.

        Args:
            txarray: iterable of byte values to transmit.
            rxlen: size of the receive buffer in bytes.

        Returns:
            ``array('B')`` of length ``rxlen`` holding the raw response.

        Raises:
            Exception: if byte 3 of the response (the completion code)
                is not zero.
        """
        tx_buf = array('B', txarray)
        rx_buf = array('B', [0] * rxlen)

        tx_addr, tx_len = tx_buf.buffer_info()
        rx_addr, rx_len = rx_buf.buffer_info()

        # The ioctl argument carries the lengths and the addresses of
        # both buffers.  24 equals struct.calcsize('IHHQQ'); presumably
        # the driver expects the structure size in the first field.
        xact = struct.pack('IHHQQ', 24, tx_len, rx_len, tx_addr, rx_addr)

        fcntl.ioctl(self.fd_, self.BMC_IOCTL, xact)

        if rx_buf[3] != 0:
            raise Exception("bad completion code 0x%x" % rx_buf[3])

        return rx_buf

    def bw_bl_xact(self, bltx, blrx):
        """Run one bootloader command.

        Args:
            bltx: iterable of command bytes (without the common header).
            blrx: number of payload bytes expected in the response.

        Returns:
            ``array('B')`` with the ``blrx`` payload bytes that follow
            the ``BW_BL_MIN_RSP_LEN`` status bytes.

        Raises:
            Exception: if the completion code or the bootloader result
                code (last status byte) is not zero.
        """
        tx_buf = array('B', self.BW_BL_CMD_HDR)
        tx_buf.extend(bltx)

        rx_buf = self.bw_xact(tx_buf, blrx + self.BW_BL_MIN_RSP_LEN)

        result = rx_buf[self.BW_BL_MIN_RSP_LEN - 1]
        if result != 0:
            raise Exception("bad BL result code 0x%x" % result)

        # Drop the status bytes with one slice.
        return rx_buf[self.BW_BL_MIN_RSP_LEN:]

    def bl_version(self):
        """Return ``(version, active_app)`` as reported by the BMC.

        The version is a 16-bit little-endian value; the active
        application must be ``BW_ACT_APP_MAIN`` or ``BW_ACT_APP_BL``.

        Raises:
            Exception: if the active application value is unknown.
        """
        rx_buf = self.bw_bl_xact([self.BW_BL_CMD_VER], 3)

        ver = rx_buf[0] | (rx_buf[1] << 8)
        act = rx_buf[2]

        if act not in (self.BW_ACT_APP_MAIN, self.BW_ACT_APP_BL):
            raise Exception("bad active application 0x%x" % act)

        return ver, act

    def bl_jump_other(self, app):
        """Ask the BMC to run ``app`` and confirm that it did.

        Sends the jump command, waits one second and then reads the
        active application back.

        Raises:
            Exception: if ``app`` is not a known application or the BMC
                does not report ``app`` as active afterwards.
        """
        if app not in (self.BW_ACT_APP_MAIN, self.BW_ACT_APP_BL):
            raise Exception("bad app %d" % app)

        self.bw_bl_xact([self.BW_BL_CMD_JUMP, app], 0)

        time.sleep(1)

        _, active = self.bl_version()

        if active != app:
            raise Exception("failed to jump to app {}".format(app))

        print("successfully jumped to app {}".format(active))

    def _check_device(self, device):
        """Raise unless ``device`` is the flash or the EEPROM."""
        if device not in (self.BW_DEV_FLASH, self.BW_DEV_EEPROM):
            raise Exception("bad device %d" % device)

    @staticmethod
    def _rw_request(cmd, device, offset, count):
        """Build a read/write request: cmd, device, 32-bit little-endian
        offset, 16-bit little-endian count."""
        return [cmd, device,
                offset & 0xff, (offset >> 8) & 0xff,
                (offset >> 16) & 0xff, (offset >> 24) & 0xff,
                count & 0xff, (count >> 8) & 0xff]

    def bl_read(self, device, offset, count):
        """Read ``count`` bytes at ``offset`` from ``device``.

        Args:
            device: ``BW_DEV_FLASH`` or ``BW_DEV_EEPROM``.
            offset: byte offset inside the device.
            count: number of bytes, at most ``BW_BL_READ_MAX``.

        Returns:
            ``array('B')`` with the ``count`` bytes read.

        Raises:
            Exception: on a bad argument or if the returned byte count
                does not match ``count``.
        """
        if count > self.BW_BL_READ_MAX:
            raise Exception("bad count %d > %d" % (count, self.BW_BL_READ_MAX))

        self._check_device(device)

        txar = self._rw_request(self.BW_BL_CMD_READ, device, offset, count)

        rx_buf = self.bw_bl_xact(txar, 2 + count)

        # The payload starts with a 16-bit little-endian byte count.
        rx_cnt = rx_buf[0] | (rx_buf[1] << 8)
        data = rx_buf[2:]

        if (rx_cnt != count) or (rx_cnt != len(data)):
            raise Exception("bad rx_cnt 0x%x != 0x%x != 0x%x" %
                            (rx_cnt, count, len(data)))

        return data

    def bl_write(self, device, offset, txdata):
        """Write ``txdata`` at ``offset`` to ``device``.

        Args:
            device: ``BW_DEV_FLASH`` or ``BW_DEV_EEPROM``.
            offset: byte offset inside the device.
            txdata: sequence of at most ``BW_BL_WRITE_MAX`` byte values.

        Raises:
            Exception: on a bad argument or if the BMC reports a byte
                count different from ``len(txdata)``.
        """
        count = len(txdata)

        if count > self.BW_BL_WRITE_MAX:
            raise Exception("bad len %d > %d" %
                            (count, self.BW_BL_WRITE_MAX))

        self._check_device(device)

        txar = self._rw_request(self.BW_BL_CMD_WRITE, device, offset, count)
        txar.extend(txdata)

        rx_buf = self.bw_bl_xact(txar, 2)

        # The payload is the 16-bit little-endian number of bytes written.
        tx_cnt = rx_buf[0] | (rx_buf[1] << 8)

        if tx_cnt != count:
            raise Exception("bad tx_cnt 0x%x != 0x%x" %
                            (tx_cnt, count))

    def verify_partition(self, part):
        """Compare the flash contents of ``part`` with the hex image.

        The partition is read in ``BW_BL_READ_MAX``-byte steps from
        ``part['start']``; every differing byte is printed.

        Returns:
            True if every byte read matches the image, else False.
        """
        end = part['hdr_start'] + self.BW_BL_HDR_SIZE
        print("verifying %s from 0x%x to 0x%x" % (
            part['name'], part['start'], end))

        valid = True
        for offset in range(part['start'], end, self.BW_BL_READ_MAX):

            rx_buf = self.bl_read(self.BW_DEV_FLASH, offset,
                                  self.BW_BL_READ_MAX)

            for i, actual in enumerate(rx_buf):
                expected = self.ihex[offset + i]
                if actual != expected:
                    print("mismatch at offset 0x%x 0x%x != 0x%x" % (
                        offset + i, actual, expected))
                    valid = False

        return valid

    def verify_partitions(self, utype):
        """Verify every partition whose 'app' equals ``utype``.

        All matching partitions are checked even after a failure.

        Returns:
            True if all of them verify (or none matches), else False.
        """
        valid = True
        for part in self.partitions:
            if part['app'] == utype:
                if not self.verify_partition(part):
                    valid = False

        return valid

    def write_range(self, ih, start, end):
        """Write image ``ih`` to flash from ``start`` up to ``end``.

        Data is written in full ``BW_BL_WRITE_MAX``-byte blocks, so the
        last block may extend past ``end``.

        Args:
            ih: object indexable by address (an ``IntelHex``).
            start: first flash offset to write.
            end: offset at which to stop starting new blocks.
        """
        for offset in range(start, end, self.BW_BL_WRITE_MAX):
            block_end = offset + self.BW_BL_WRITE_MAX

            data = [ih[i] for i in range(offset, block_end)]

            print("    0x%x - 0x%x %d" % (offset, block_end, len(data)))
            self.bl_write(self.BW_DEV_FLASH, offset, data)

    def write_page0(self):
        """Write a 16-byte stub, padded to a full page, to flash page 0.

        NOTE(review): judging by the variable name the stub is a jump
        into the bootloader, presumably so that an interrupted update
        still boots into it - confirm.
        """
        ih = IntelHex()
        jump_bl_instrs = [0x82, 0xe0, 0x8c, 0xbf, 0xe0, 0xe0, 0xf0, 0xe0,
                          0x19, 0x94, 0x0c, 0x94, 0x00, 0x00, 0xff, 0xcf]
        for addr, instr in enumerate(jump_bl_instrs):
            ih[addr] = instr

        self.write_range(ih, 0, self.BW_BL_PAGE_SIZE)

    def write_partitions(self, utype):
        """Write every partition whose 'app' equals ``utype``.

        For a partition that starts at offset 0, page 0 is first filled
        by ``write_page0()``, the rest of the partition is written, and
        the real page 0 of the image is written last, after all other
        partitions.
        """
        wrote_page0 = False
        for part in self.partitions:
            if part['app'] != utype:
                continue

            start = part['start']
            end = part['hdr_start'] + self.BW_BL_HDR_SIZE
            print("updating %s from 0x%x to 0x%x" % (
                part['name'], start, end))

            if start == 0:
                self.write_page0()
                start = self.BW_BL_PAGE_SIZE
                wrote_page0 = True

            self.write_range(self.ihex, start, end)

        if wrote_page0:
            self.write_range(self.ihex, 0, self.BW_BL_PAGE_SIZE)


def bmc_update(utype, ifile, bdf):
    """Update the BMC bootloader or BMC application of device ``bdf``.

    For 'bmc_bl' the bootloader partitions are written while the main
    application is running; for 'bmc_app' the main partitions are
    written while the bootloader is running, after which the BMC is
    asked to jump back to the main application.  In both cases the
    written partitions are read back and compared with the image.

    Args:
        utype: 'bmc_bl' or 'bmc_app'.
        ifile: Intel HEX image (file name or file object).
        bdf: PCI address of the card.

    Returns:
        0 on success, 1 if the read-back verification found a mismatch.

    Raises:
        Exception: for an unknown ``utype`` or a failed BMC transaction.
    """
    dev = get_dev_bmc(bdf)

    bw_bmc = BittwareBmc(dev, ifile)

    ver, active = bw_bmc.bl_version()

    print("ver %d act %d" % (ver, active))

    if utype == 'bmc_bl':
        bw_bmc.verify_segments(bw_bmc.BW_ACT_APP_BL)

        # The bootloader is rewritten from the main application.
        if active != bw_bmc.BW_ACT_APP_MAIN:
            bw_bmc.bl_jump_other(bw_bmc.BW_ACT_APP_MAIN)

        bw_bmc.write_partitions(bw_bmc.BW_ACT_APP_BL)
        valid = bw_bmc.verify_partitions(bw_bmc.BW_ACT_APP_BL)
    elif utype == 'bmc_app':
        bw_bmc.verify_segments(bw_bmc.BW_ACT_APP_MAIN)

        # The main application is rewritten from the bootloader.
        if active != bw_bmc.BW_ACT_APP_BL:
            bw_bmc.bl_jump_other(bw_bmc.BW_ACT_APP_BL)
        bw_bmc.write_partitions(bw_bmc.BW_ACT_APP_MAIN)
        valid = bw_bmc.verify_partitions(bw_bmc.BW_ACT_APP_MAIN)
        bw_bmc.bl_jump_other(bw_bmc.BW_ACT_APP_MAIN)
    else:
        raise Exception("unknown utype: %s" % utype)

    # Report a failed verification instead of always claiming success.
    if not valid:
        print("failed to verify BMC flash")
        return 1

    return 0


def main():
    """Entry point: pick the target device and run the requested update.

    The device is taken from the command line, or chosen automatically
    when exactly one FPGA with a flash device is present.  The process
    exits with the status returned by the update function, or with 1
    when no usable device can be determined.
    """
    args = parse_args()

    bdf_map = get_bdf_mtd_mapping()

    if not bdf_map:
        print("No FPGA devices found")
        sys.exit(1)

    if not args.bdf:
        if len(bdf_map) > 1:
            print("Must specify a bdf. More than one device found.")
            print_bdf_mtd_mapping(bdf_map)
        else:
            # Exactly one device: use it.  next(iter()) works on both
            # Python 2 and 3, unlike keys()[0].
            bdf = next(iter(bdf_map))
    else:
        bdf = normalize_bdf(args.bdf)
        if not bdf:
            # Report what the user typed, not the failed result.
            print("{} is an invalid bdf".format(args.bdf))
            sys.exit(1)
        elif bdf not in bdf_map:
            print("Could not find flash device for {}".format(bdf))
            print_bdf_mtd_mapping(bdf_map)

    if args.type in ('bmc_bl', 'bmc_app'):
        ret = bmc_update(args.type, args.file, bdf)
    else:
        ret = fpga_update(args.type, args.file, bdf, bdf_map)

    sys.exit(ret)


# Run the tool only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
