#!/usr/bin/env python3
""" tst_ofdm_demod_check

    Testing for tst_ofdm_demod_* tests

    Usage tst_ofdm_demod_check <dummy_test_name> quick|ideal|AWGN|fade|profile|ldpc|ldpc_AWGN|ldpc_fade

    Checks are different for each option, but similar

    - Convert stm32 output to octave text format
      (stm32 does not have memory for this)

    - ...

    """

import numpy as np
import math
import argparse
import struct
import os
import sys

if ("UNITTEST_BASE" in os.environ):
    sys.path.append(os.environ["UNITTEST_BASE"] + "/lib/python")
else:
    sys.path.append("../../lib/python")	# assume in test run dir

import sum_profiles

# Frame length used by the output comparison to turn a byte offset in the
# demod output files into a (frame, byte-within-frame) pair for messages.
Nbitsperframe = 238

##############################################################################
# Read Octave text file
##############################################################################

def read_octave_text(f):
    """Parse the matrices held in an Octave text-format stream.

    Only variables of type "matrix" and "complex matrix" are read.  Other
    variable types, and matrices declared with zero columns, are skipped.

    Args:
        f: Iterator over text lines (for example an open text file).  It is
            advanced with next() while a variable is being read.

    Returns:
        dict mapping each variable name to a (rows, cols) numpy array:
        float32 for "matrix", complex64 for "complex matrix".

    Raises:
        ValueError: a data row does not hold exactly `cols` values.
        Exception: any other error hit while reading a row is re-raised
            after a message naming the row and variable is printed.
    """
    if (args.verbose): print('read_octave_text()')
    data = {}
    for line in f:
        if not line.startswith("# name: "):
            continue
        var = line.split()[2]
        if (args.verbose): print('  var "{}"'.format(var))

        # The type line decides both the element type and the row syntax.
        line = next(f)
        if line.startswith("# type: matrix"):
            dtype = np.float32
        elif line.startswith("# type: complex matrix"):
            dtype = np.complex64
        else:
            continue

        # "# rows: N" then "# columns: M"
        rows = int(next(f).split()[2])
        cols = int(next(f).split()[2])
        if cols <= 0:
            continue

        matrix = np.empty((rows, cols), dtype)
        data[var] = matrix
        # Read rows one at a time
        for row in range(rows):
            try:
                line = next(f)
                if dtype is np.float32:
                    values = np.fromstring(line, np.float32, cols, " ")
                    # A short row would otherwise broadcast (1 value) or
                    # fail with a less helpful shape error.
                    if values.size != cols:
                        raise ValueError(
                            "expected {} values, found {}".format(
                                cols, values.size))
                    matrix[row] = values
                else:
                    # " (r,i) (r,i) ..."
                    tokens = line.split()
                    # The array comes from np.empty(), so a short row would
                    # silently leave uninitialized values behind.
                    if len(tokens) != cols:
                        raise ValueError(
                            "expected {} values, found {}".format(
                                cols, len(tokens)))
                    for col, tpl in enumerate(tokens):
                        real, imag = tpl.strip("(,)").split(",")
                        matrix[row][col] = complex(float(real), float(imag))
            except Exception:
                print("Error reading row {} of var {}".format(row, var))
                raise
        # end for line in f
    return(data)


##############################################################################
# Read stm32 diag data, syms, amps for each frame
##############################################################################

def read_tgt_syms(f, max_frames=100, nsyms=112):
    """Read per-frame symbols and amplitudes from a binary diag stream.

    The stream is a sequence of frames.  Each frame holds `nsyms` complex
    symbols stored as interleaved little-endian float32 (real, imag) pairs,
    followed by `nsyms` little-endian float32 amplitudes.  Reading stops at
    the first short read or after `max_frames` frames.

    Args:
        f: Binary file-like object providing read(size).
        max_frames: Maximum number of frames to read (rows of the result).
        nsyms: Number of symbols per frame (columns of the result).

    Returns:
        (syms, amps): a complex64 array and a float32 array, both of shape
        (max_frames, nsyms).  Rows that were not read stay zero.
    """
    syms = np.zeros((max_frames, nsyms), np.complex64)
    amps = np.zeros((max_frames, nsyms), np.float32)
    syms_bytes = nsyms * 8      # two float32 per complex symbol
    amps_bytes = nsyms * 4      # one float32 per amplitude
    for row in range(max_frames):
        # syms
        buf = f.read(syms_bytes)
        if (len(buf) < syms_bytes): break
        # Interpret the bytes directly as complex64 rather than unpacking
        # floats and reassigning the array's dtype in place.
        syms[row] = np.frombuffer(buf, dtype="<c8")
        # amps
        buf = f.read(amps_bytes)
        if (len(buf) < amps_bytes): break
        amps[row] = np.frombuffer(buf, dtype="<f4")
    return(syms, amps)
    # end read_tgt_syms()


##############################################################################
# Write out in octave text format as 2 matrices
##############################################################################

def write_syms_as_octave(syms, amps):
    """Write symbols and amplitudes to "ofdm_demod_log.txt" as Octave text.

    Two variables are written: "payload_syms_log_stm32" (complex matrix)
    from `syms`, then, after a blank line, "payload_amps_log_stm32" (matrix)
    from `amps`.  Each matrix row becomes one line of space-prefixed values.

    Args:
        syms: 2-D array of complex values.
        amps: 2-D array of real values.
    """
    with open("ofdm_demod_log.txt", "w") as out:
        # syms: one " (real,imag)" entry per element
        out.writelines([
            "# name: payload_syms_log_stm32\n",
            "# type: complex matrix\n",
            "# rows: {}\n".format(syms.shape[0]),
            "# columns: {}\n".format(syms.shape[1]),
        ])
        for sym_row in syms:
            entries = [" ({},{})".format(value.real, value.imag)
                       for value in sym_row]
            out.write("".join(entries))
            out.write("\n")

        # amps: one " value" entry per element, after a separating blank line
        out.write("\n")
        out.writelines([
            "# name: payload_amps_log_stm32\n",
            "# type: matrix\n",
            "# rows: {}\n".format(amps.shape[0]),
            "# columns: {}\n".format(amps.shape[1]),
        ])
        for amp_row in amps:
            out.write("".join(" {}".format(value) for value in amp_row))
            out.write("\n")
    # end write_syms_as_octave()


##############################################################################
# Main
##############################################################################

#### Options
argparser = argparse.ArgumentParser()
argparser.add_argument("-v", "--verbose", action="store_true")
argparser.add_argument("test", action="store")
argparser.add_argument("test_opt", action="store",
                        choices=["quick", "ideal", "AWGN", "fade", "profile",
                                 "ldpc", "ldpc_AWGN", "ldpc_fade" ])
args = argparser.parse_args()

# Use ENV value of UNITTEST_BASE from upper level shell script (default to .)
if ('UNITTEST_BASE' in os.environ):
    run_dir = os.environ['UNITTEST_BASE'] + "/test_run/"
    run_dir += args.test + "_" + args.test_opt
    print(run_dir)
    os.chdir(run_dir)

#### Settings
# Defaults, (for tests without channel degradation, results should be close to ideal)
max_ber         = 0.001  # Max BER value in Target
max_ber2        = 0.001  # Max Coded BER value in Target
compare_ber     = 1      # Compare Target to Reference?
# Used if compare_ber:
tolerance_ber   = 0.001  # Difference from reference for BER
tolerance_ber2  = 0.001  # Difference from reference for Coded BER
tolerance_tbits = 0
tolerance_terrs = 1
#
compare_output  = 1        # Compare Target to Reference?
# Used if compare_output:
tolerance_output_differences = 0
tolerance_syms = 0.01
tolerance_amps = 0.01
#
# Per test settings: only the values that differ from the defaults are set.
# "quick", "ideal" and "ldpc" run with the defaults unchanged.
if args.test_opt in ("AWGN", "fade"):
    # Still close enough to compare BERs loosely
    max_ber         = 0.1
    max_ber2        = 0.1
    tolerance_ber   = 0.01
    tolerance_ber2  = 0.005
    tolerance_tbits = 1000
    compare_output  = 0
    if args.test_opt == "AWGN":
        tolerance_terrs = 50
        tolerance_output_differences = 2
    else:
        tolerance_terrs = 200
        tolerance_output_differences = 5
elif args.test_opt == "profile":
    tolerance_output_differences = 1
elif args.test_opt in ("ldpc_AWGN", "ldpc_fade"):
    # Only the maximum BER limits are checked for these two options.
    max_ber         = 0.1
    max_ber2        = 0.01
    compare_ber     = 0
    compare_output  = 0
elif args.test_opt not in ("quick", "ideal", "ldpc"):
    print("Error: Test {} not recognized".format(args.test_opt))
    sys.exit(1)

#### Check that we are in the test directory:
#### TODO:::


#### Read test configuration - a file of '0' or '1' characters
with open("stm_cfg.txt", "r") as f:
    config = f.read(8)

# The first five characters are flags, in this order; '1' means enabled.
(config_verbose,
 config_testframes,
 config_ldpc_en,
 config_log_payload_syms,
 config_profile) = [config[pos] == '1' for pos in range(5)]


#### Count of failed checks; the final status and exit code depend on it
fails = 0


def _read_ber_log(path):
    """Echo and parse the BER summary lines found in a log file.

    Recognised lines start with "BER2", "BER" or "Coded BER".  Each holds a
    bit error rate followed by "Tbits:" and "Terrs:" counts.  Fields after
    the Terrs value are ignored, so the reference and target logs may carry
    different trailing fields.  If a line type occurs more than once, the
    last occurrence wins.

    Args:
        path: Name of the log file to read.

    Returns:
        dict with keys "ber", "tbits", "terrs" (from the "BER" line) and
        "ber2", "tbits2", "terrs2" (from the "BER2" or "Coded BER" line).
        A key is absent when its line was not found.
    """
    stats = {}
    with open(path, "r") as log:
        for line in log:
            # "BER2" must be tested before "BER", which is a prefix of it.
            if line.startswith("BER2"):
                suffix, first = "2", 1
            elif line.startswith("BER"):
                suffix, first = "", 1
            elif line.startswith("Coded BER"):
                # Two-word label, so the values sit one field further along.
                suffix, first = "2", 2
            else:
                continue
            print(line, end="")
            fields = line.split()
            stats["ber" + suffix] = float(fields[first])
            stats["tbits" + suffix] = int(fields[first + 2])
            stats["terrs" + suffix] = int(fields[first + 4])
    return stats


if (config_testframes):
    #### BER checks - log output looks like this, for non-ldpc:
    # BER......: 0.1951 Tbits: 14994 Terrs:  2926
    # BER2.....: 0.2001 Tbits: 10234 Terrs:  2048
    #
    # Or this, for ldpc:
    # BER......: 0.0000 Tbits: 15008 Terrs:     0
    # Coded BER: 0.0000 Tbits:  7504 Terrs:     0
    #
    # HACK: store "Coded BER" info as BER2.

    print("\nBER checks")

    print("Reference")
    ref_stats = _read_ber_log("ref_gen_log.txt")
    print("Target")
    tgt_stats = _read_ber_log("stdout.log")

    ber_keys = ("ber", "tbits", "terrs", "ber2", "tbits2", "terrs2")
    missing = ["{} {}".format(source, key)
               for source, stats in (("reference", ref_stats),
                                     ("target", tgt_stats))
               for key in ber_keys if key not in stats]

    if missing:
        # Report a log without the expected BER lines as a failure rather
        # than crashing on an undefined name.
        fails += 1
        print("FAIL: BER results missing from log: {}".format(
            ", ".join(missing)))
    else:
        (ref_ber, ref_tbits, ref_terrs,
         ref_ber2, ref_tbits2, ref_terrs2) = [ref_stats[k] for k in ber_keys]
        (tgt_ber, tgt_tbits, tgt_terrs,
         tgt_ber2, tgt_tbits2, tgt_terrs2) = [tgt_stats[k] for k in ber_keys]

        # simple hack to tolerate zero bits > NAN
        if (math.isnan(ref_ber2)): ref_ber2 = 0
        if (math.isnan(tgt_ber2)): tgt_ber2 = 0

        ## Max BER values
        if ((tgt_ber > max_ber) or (tgt_ber2 > max_ber2)):
            fails += 1
            print("FAIL: max BER")
        else:
            print("PASS: max BER")

        ## Compare BER values
        if (compare_ber):
            # (name used in the message, measured difference, allowed limit)
            checks = [
                ("tolerance_ber", abs(ref_ber - tgt_ber), tolerance_ber),
                ("tolerance_tbits", abs(ref_tbits - tgt_tbits),
                 tolerance_tbits),
                ("tolerance_terrs", abs(ref_terrs - tgt_terrs),
                 tolerance_terrs),
            ]
            # NOTE(review): the second set of values is only compared when
            # the reference reports zero Tbits for it; this looks as if it
            # may be inverted - confirm the intent before changing it.
            if (ref_tbits2 == 0):
                checks += [
                    ("tolerance_ber2", abs(ref_ber2 - tgt_ber2),
                     tolerance_ber2),
                    ("tolerance_tbits2", abs(ref_tbits2 - tgt_tbits2),
                     tolerance_tbits),
                    ("tolerance_terrs2", abs(ref_terrs2 - tgt_terrs2),
                     tolerance_terrs),
                ]

            passes = True
            for check_name, difference, limit in checks:
                if (difference > limit):
                    print("fail {} {} > {}".format(
                        check_name, difference, limit))
                    passes = False

            if (passes):
                print("PASS: BER compare")
            else:
                fails += 1
                print("FAIL: BER compare")

            # end Compare BER
    # end BER checks

#### Output differences
def _report_relative_diffs(name, error_label, ref, tgt, tolerance, verbose):
    """Print statistics on the relative difference between two real arrays.

    The relative difference is |ref - tgt| / |ref| element by element, and
    is taken as 0 wherever ref is 0.

    Args:
        name: Short name used in the statistics lines ("syms" or "amps").
        error_label: Word used in the final error-count line.
        ref: 2-D array of reference magnitudes.
        tgt: 2-D array of target magnitudes, same shape as ref.
        tolerance: Relative difference above which an element is an error.
        verbose: When true, also print the largest differences (up to 10).

    Returns:
        (num_errors, num_error_rows): number of elements over tolerance and
        number of rows containing at least one such element.
    """
    delta = ref - tgt
    # With where=, positions where the condition is false are left untouched,
    # so a zeroed out= array is needed to keep them from being uninitialized.
    diffs = np.abs(np.divide(delta, ref, out=np.zeros_like(delta),
                             where=(ref != 0)))
    print("Minimum {} difference = {:.6f}".format(name, np.amin(diffs)))
    print("Maximum {} difference = {:.6f}".format(name, np.amax(diffs)))
    print("Average {} difference = {:.6f}".format(name, np.average(diffs)))
    if (verbose):  # Print top 10 differences
        order = diffs.argsort(axis=None)[::-1]
        flat_ref = ref.flatten()
        flat_tgt = tgt.flatten()
        flat_diffs = diffs.flatten()
        print(" Top 10 differences")
        # Slicing keeps this safe when there are fewer than 10 elements.
        for i, j in enumerate(order[:10]):
            print("  #{} @{}: {} <?> {} = {:.6f}".format(
                i, j, flat_ref[j], flat_tgt[j], flat_diffs[j]))

    # Errors are differences > tolerance
    errors = diffs - tolerance
    errors[errors < 0.0] = 0.0
    num_errors = np.count_nonzero(errors)
    num_error_rows = np.count_nonzero(np.amax(errors, axis=1))
    print("")
    print("{} {} errors on {} rows".format(
        num_errors, error_label, num_error_rows))
    return (num_errors, num_error_rows)


if (compare_output):

    print("\nOutput checks")

    # Output is a binary file of bytes whose values are 0x00 or 0x01.
    with open("ref_demod_out.raw", "rb") as f: ref_out_bytes = f.read()
    with open("stm_out.raw", "rb") as f:       tgt_out_bytes = f.read()
    if (len(ref_out_bytes) != len(tgt_out_bytes)):
        fails += 1
        print("FAIL Output, length mismatch")
    else:
        output_diffs = 0
        for i, (ref_byte, tgt_byte) in enumerate(
                zip(ref_out_bytes, tgt_out_bytes)):
            # Frame number and byte position within that frame
            fnum, bnum = divmod(i, Nbitsperframe)
            # Both legal values??
            if (ref_byte > 1):
                print("Error: Output frame {} byte {} not 0 or 1 in reference data".format(fnum, bnum))
                fails += 1
            if (tgt_byte > 1):
                print("Error: Output frame {} byte {} not 0 or 1 in target data".format(fnum, bnum))
                fails += 1
            # Match??
            if (ref_byte != tgt_byte):
                print("Output frame {} byte {} mismatch: ref={} tgt={}".format(
                    fnum, bnum, ref_byte, tgt_byte))
                output_diffs += 1
            # end for i
        if (output_diffs > tolerance_output_differences):
            print("FAIL: Output Differences = {}".format(output_diffs))
            fails += 1
        else:
            print("PASS: Output Differences = {}".format(output_diffs))
        # end not length mismatch


    #### Syms data
    if (config_log_payload_syms):
        print("\nSyms and Amps checks")

        with open("ofdm_demod_ref_log.txt", "r") as fref:
            ref_data = read_octave_text(fref)
        with open("stm_diag.raw", "rb") as fdiag:
            (tgt_syms, tgt_amps) = read_tgt_syms(fdiag)
        write_syms_as_octave(tgt_syms, tgt_amps) # for manual debug...

        # Find smallest common subset
        hgt = min(tgt_syms.shape[0], ref_data["payload_syms_log_c"].shape[0])
        wid = min(tgt_syms.shape[1], ref_data["payload_syms_log_c"].shape[1])

        # One 2-D index trims rows and columns; chaining [:hgt][:wid] would
        # slice the rows twice and leave the columns untouched.
        ref_syms = ref_data["payload_syms_log_c"][:hgt, :wid]
        ref_amps = ref_data["payload_amps_log_c"][:hgt, :wid]
        tgt_syms = tgt_syms[:hgt, :wid]
        tgt_amps = tgt_amps[:hgt, :wid]

        # Eliminate trailing rows of all zeros.  Magnitudes are summed so
        # that complex values which cancel do not make a row look empty.
        row_mags = np.abs(ref_syms).sum(axis=1) + np.abs(tgt_syms).sum(axis=1)
        nonzero_rows = row_mags.nonzero()[0]

        if (nonzero_rows.size == 0):
            fails += 1
            print("FAIL: no non-zero symbol rows to compare")
        else:
            # stop index is 1 past the last!!
            # and use the Magnitude of the complex values
            stop = nonzero_rows[-1] + 1
            ref_syms = np.abs(ref_syms[:stop])
            ref_amps = np.abs(ref_amps[:stop])
            tgt_syms = np.abs(tgt_syms[:stop])
            tgt_amps = np.abs(tgt_amps[:stop])

            # NOTE(review): the error counts below are printed only; they do
            # not change `fails` - confirm that this is intended.

            # Differences - Syms
            (num_errors_syms, num_error_rows_syms) = _report_relative_diffs(
                "syms", "symbol", ref_syms, tgt_syms,
                tolerance_syms, args.verbose)

            # Differences - Amps
            (num_errors_amps, num_error_rows_amps) = _report_relative_diffs(
                "amps", "Amplitude", ref_amps, tgt_amps,
                tolerance_amps, args.verbose)
    # End compare_output


#### Profile
if config_profile:
    # Profiling results are taken from the target's stdout.log
    print("\nProfile:")
    with open("stdout.log", "r") as log:
        sum_profiles.sum_profiles(log, 100)

    # Echo every line that reports the maximum stack.
    # NOTE(review): this reads "stdout.txt" while the other reads of target
    # output use "stdout.log" - confirm the file name.
    print("\nStack:")
    with open("stdout.txt", "r") as log:
        for stack_line in (ln for ln in log if ln.startswith("Max stack")):
            print(stack_line)


#### Print final status message
if fails:
    print("\nTest FAILED!")
else:
    print("\nTest PASSED")

# The exit status carries the failure count, clamped to 255: a process exit
# status only keeps the low 8 bits, so an unclamped count that is a multiple
# of 256 would be reported to the caller as success.
sys.exit(min(fails, 255))
