#!/usr/bin/env python
import sys
import os
import pickle
import math
import numpy
from os.path import join as pjoin
import optparse
from glob import glob

from ufl.algorithms import load_forms

import ufc
import ufc_benchmark

#import sfc
#import logging
#sfc.set_logging_level(logging.DEBUG)

# --- Basic Utilities ---

# IEEE-754 special values, used as placeholders for "not computed yet".
inf = float("inf")
nan = float("nan")

# Taken from http://ivory.idyll.org/blog/mar-07/replacing-commands-with-subprocess
from subprocess import Popen, PIPE, STDOUT
def get_status_output(cmd, input=None, cwd=None, env=None):
    """Run ``cmd`` in a shell and return ``(status, output)``.

    ``output`` contains the command's stdout and stderr merged together.
    If ``input`` is given it is written to the command's stdin; stdin is only
    connected to a pipe in that case, otherwise it is inherited.
    ``cmd`` is interpreted by the shell, so it must never contain untrusted text.
    """
    # communicate() can only deliver input through a pipe; without stdin=PIPE
    # the input would be silently dropped.
    stdin = PIPE if input is not None else None
    pipe = Popen(cmd, shell=True, cwd=cwd, env=env,
                 stdin=stdin, stdout=PIPE, stderr=STDOUT)
    (output, errout) = pipe.communicate(input=input)
    # stderr is redirected into stdout, so nothing arrives on a separate stream.
    assert not errout
    return (pipe.returncode, output)

def runcmd(cmd):
    """Run ``cmd`` in a shell and return ``(status, output)``.

    The result of get_status_output is returned so that callers can detect
    failing commands; callers that ignore the return value are unaffected.
    """
    return get_status_output(cmd)

def write_file(filename, text):
    "Write text to a file, closing it even if the write fails, and report it."
    with open(filename, "w") as f:
        f.write(text)
    print("Wrote file '%s'" % filename)

# --- Option Parsing Data ---

# Usage text handed to optparse; shown by --help and by parser.print_usage().
usage = "\n".join([
    "Compile UFL forms, compute element tensors, and compare with reference values.",
    "",
    "Examples:",
    "",
    "  FIXME: Write usage examples.",
    "",
])

def opt(long, short, t, default, help):
    """Return an optparse "store" option reachable as --<long> and -<short>.

    The parsed value is stored under the attribute named ``long``, with the
    given type ``t``, ``default`` value and ``help`` text.
    """
    flags = ("--%s" % long, "-%s" % short)
    settings = {
        "action": "store",
        "type": t,
        "dest": long,
        "default": default,
        "help": help,
    }
    return optparse.make_option(*flags, **settings)

# Command line options. Each opt() call takes (long name, short flag, type,
# default, help); the long name is also the attribute on the parsed options.
option_list = [
    # Directories:
    opt("ufldir",            "u",   "str",    "ufl",                "Input directory with .ufl files."),
    opt("outputdir",         "o",   "str",    "output",             "Output directory to write .ref files to."),
    opt("referencedir",      "r",   "str",    "reference",          "Reference directory to read .ref files from."),
    opt("cachedir",          "c",   "str",    "cache",              "Cache directory for jit-compiled forms."),
    # Main behaviour options:
    opt("skip",              "s",   "str",    "",                   "Comma-separated list of ufl files to skip."),
    opt("write",             "w",   "int",    0,                    "Write new reference files."),
    opt("debug",             "d",   "int",    0,                    "Print debugging info for this script."),
    # Form compiler options:
    opt("jit_options",       "j",   "str",    "options/default.py", "Python file containing jit options."),
    # Comparison options:
    opt("tolerance",         "t",   "float",  1e-10,                "Compare norm of data difference to this tolerance."),
    opt("norm",              "n",   "int",    1,                    "Compare data with references using the L2 norm of tensor difference."),
    opt("eig",               "e",   "int",    1,                    "Compare data with references using the eigenvalues."),
    opt("random_cell",       "a",   "int",    1,                    "Use a (predefined) random cell instead of reference cell."),
    opt("benchmark",         "b",   "int",    1,                    "Measure the time to call tabulate_tensor."),
    ]

# --- Main routine ---

def main(args):
    """Handle commandline arguments and orchestrate the tests.

    Parses ``args`` (typically ``sys.argv[1:]``), runs handle_file on every
    ``*.ufl`` file in ``--ufldir`` not matched by ``--skip``, then prints the
    per-file summaries and the lists of passing and failing files.

    Returns -1 if unknown positional arguments are given, 1 if at least one
    file failed and 0 otherwise, so the value can be passed to sys.exit and
    the process exit status reflects the test outcome.
    """
    # Parse commandline arguments
    parser = optparse.OptionParser(usage=usage, option_list=option_list)
    (options, args) = parser.parse_args(args=args)
    if args:
        print("ERROR: Got additional unknown arguments:  %s" % (args,))
        parser.print_usage()
        return -1

    # Read input directory and filter filenames. Empty entries (from the
    # default "" or a trailing comma) are dropped so they never match.
    skip = set(s.strip() for s in options.skip.split(","))
    skip.discard("")

    def skipmatch(fn):
        "Return True if fn is skipped by full path, file name, or base name."
        name = os.path.basename(fn)
        basename = os.path.splitext(name)[0]
        return fn in skip or name in skip or basename in skip

    uflfiles = sorted(f for f in glob(pjoin(options.ufldir, "*.ufl"))
                      if not skipmatch(f))

    if options.debug:
        print("." * 40)
        print("Got uflfiles =")
        print("\n".join("  " + f for f in uflfiles))
    if not uflfiles:
        print("*** Warning: no .ufl files found in '%s'." % (options.ufldir,))

    # Handle each .ufl file separately
    fails = []
    passes = []
    summaries = []
    for filename in uflfiles:
        summary, ok = handle_file(filename, options)
        summaries.append(summary)
        if ok:
            passes.append(filename)
        else:
            fails.append(filename)

    # Print summaries
    print("")
    print("=" * 80)
    sep = "\n\n" + "-" * 60 + "\n"
    print("Summaries: " + sep + sep.join(summaries))

    # Print files that passed and failed
    if passes:
        print("=" * 80)
        print("The following files passed:")
        print("\n".join("  " + f for f in sorted(passes)))
    if fails:
        print("=" * 80)
        print("The following files failed:")
        print("\n".join("  " + f for f in sorted(fails)))

    # Non-zero exit status when anything failed, so scripts/CI can detect it.
    return 1 if fails else 0

def import_options_iterator(name):
    """Import the options iterator from a python file.

    ``name`` is the path of a python module (with or without the ``.py``
    extension). Its ``options`` attribute is returned; handle_file calls it
    and iterates over the ``(jit, jit_options)`` pairs it yields.

    Returns None if the module does not define ``options``.
    Raises IOError if ``name`` does not exist and ValueError if it has an
    extension other than ``.py``.
    """
    if not os.path.exists(name):
        raise IOError("Options file '%s' does not exist." % name)

    path, filename = os.path.split(name)
    basename, dotpy = os.path.splitext(filename)
    if dotpy not in (".py", ""):
        raise ValueError("Expecting a python file, not '%s'." % name)

    # Make the module's directory importable only for the duration of the import.
    sys.path.insert(0, path)
    try:
        options_module = __import__(basename)
    finally:
        # Remove our entry by value: the imported module may itself have
        # modified sys.path, so index 0 is not necessarily ours any more.
        try:
            sys.path.remove(path)
        except ValueError:
            pass

    return getattr(options_module, "options", None)

def handle_file(filename, options):
    """Handle all aspects of testing a single .ufl file.

    Loads the forms in ``filename`` and jit-compiles them with every
    ``(jit, jit_options)`` pair yielded by the options iterator named by
    ``options.jit_options``. For each form it computes tensor data (and, if
    ``options.benchmark``, timings) and pickles them to ``.ref``/``.timing``
    files. With ``options.write`` these go to ``options.referencedir``;
    otherwise they go to ``options.outputdir`` and are compared with the
    references.

    Returns ``(summary, ok)``. ``ok`` is False if the file could not be
    loaded, the options module has no ``options``, or any comparison failed.
    Jit compilation errors are reported and re-raised.
    """
    # Split filename
    uflfilename = filename
    path, name = os.path.split(uflfilename)
    basename, ext = os.path.splitext(name)
    if ext != ".ufl":
        msg = "Expecting a .ufl file, not %s" % uflfilename
        return (msg, False)

    if options.debug:
        print("." * 40)
        print("In handle_file, filename parsed as:")
        print("filename    = %s" % (filename,))
        print("uflfilename = %s" % (uflfilename,))
        print("path        = %s" % (path,))
        print("name        = %s" % (name,))
        print("basename    = %s" % (basename,))
        print("ext         = %s" % (ext,))

    # Load forms from this file. Catch Exception (not a bare except) so that
    # KeyboardInterrupt still aborts the whole run.
    try:
        forms = load_forms(uflfilename)
    except Exception:
        msg = ("Failed to load file, try running\n\n"
               "  ufl-analyse %s\n\n"
               "to find the bug in the form file or in UFL." % uflfilename)
        return (msg, False)

    formnames = [form.form_data().name for form in forms]

    if options.debug:
        print("." * 40)
        print("In handle_file, forms loaded:  %s" % (", ".join(formnames),))

    # Iterate over a given set of jit compilers and options:
    iter_jit_options = import_options_iterator(options.jit_options)
    if iter_jit_options is None:
        msg = "Failed to import options module '%s'." % options.jit_options
        return (msg, False)

    # New results go to the reference directory when (re)generating the
    # references, and to the output directory otherwise.
    if options.write:
        outputdir = options.referencedir
    else:
        outputdir = options.outputdir
    if options.debug:
        print("." * 60)
        print("outputdir = %s" % (outputdir,))
    # Create it up front so write_data does not fail on a fresh checkout.
    if not os.path.isdir(outputdir):
        os.makedirs(outputdir)

    total_ok = True
    reports = []
    for jit, jit_options in iter_jit_options():

        #jit_options.cache_dir = options.cache_dir # FIXME

        # Compile forms with given compiler and options; say which file
        # failed before propagating the error.
        try:
            jit_result = jit(forms, jit_options)
        except Exception:
            print("*** Failed to jit-compile forms in file '%s'." % uflfilename)
            raise

        if options.debug:
            print(":" * 60)
            print("Jit result:")
            print(jit_result)

        # Compute some data for each form from result of jit compilation.
        # The jit returns (compiled_forms, module, form_datas).
        compiled_forms, _, form_datas = jit_result
        data = {}
        for i, form in enumerate(forms):
            form_data = form.form_data()
            assert form_data is form_datas[i]
            data[form_data.name] = compute_data(compiled_forms[i], form_data, options.random_cell)

        # Benchmark the generated tabulate_tensor implementations
        benchmark_data = {}
        if options.benchmark:
            for i, form in enumerate(forms):
                form_data = form.form_data()
                assert form_data is form_datas[i]
                benchmark_data[form_data.name] = ufc_benchmark.benchmark_forms([compiled_forms[i]], False)

        # Store results for future referencing
        for formname in formnames:
            outputfilename = pjoin(outputdir, "%s_%s.ref" % (basename, formname))
            if options.debug:
                print("Writing to output filename:  %s" % (outputfilename,))
            if not data[formname].any():
                print("*** Warning: reference tensor only contains zeros!")
            write_data(outputfilename, data[formname])

        if options.benchmark:
            for formname in formnames:
                outputfilename = pjoin(outputdir, "%s_%s.timing" % (basename, formname))
                if options.debug:
                    print("Writing to output filename:  %s" % (outputfilename,))
                write_data(outputfilename, benchmark_data[formname])

        # Compare to references unless we're writing the references
        if options.write:
            continue

        # Read reference files (read_data returns None for missing files)
        reference = {}
        for formname in formnames:
            referencefilename = pjoin(options.referencedir, "%s_%s.ref" % (basename, formname))
            if options.debug:
                print("Read reference filename:  %s" % (referencefilename,))
            reference[formname] = read_data(referencefilename)

        benchmark_reference = {}
        if options.benchmark:
            for formname in formnames:
                referencefilename = pjoin(options.referencedir, "%s_%s.timing" % (basename, formname))
                if options.debug:
                    print("Read reference filename:  %s" % (referencefilename,))
                benchmark_reference[formname] = read_data(referencefilename)

        # Compare fresh data and reference data
        results = []
        for formname in formnames:
            msg, ok = compare_data(data[formname], reference[formname], options)

            if options.debug:
                print("msg = %s" % (msg,))
                print("ok =  %s" % (ok,))

            if not ok:
                total_ok = False
                results.append("--- Errors in form %s:\n" % formname)
                results.append(msg)
            else:
                results.append("--- Form %s is ok.\n" % formname)

            if options.benchmark:
                bench_ref = benchmark_reference[formname]
                if bench_ref is None:
                    # Missing timing reference: nothing to compare against.
                    results.append("No benchmark reference to compare to.")
                else:
                    results.extend(compare_benchmark(benchmark_data[formname], bench_ref, options))

        # Keep each jit configuration's report separate.
        reports.append("\n".join(results))

    # Return final report
    if options.write:
        summary = "Wrote reference files for %s." % uflfilename
    elif total_ok:
        summary = "%s passed everything." % uflfilename
    else:
        summary = ("%s has problems:\n" % uflfilename) + "\n".join(reports)
    return (summary, total_ok)

def compute_diff_norm(data, reference, options):
    """Return the relative L2 (Frobenius) norm of ``data - reference``.

    The result is ||data - reference|| / ||reference||. When the reference is
    identically zero the relative norm is undefined (0/0 or x/0), so the
    absolute norm ||data - reference|| is returned instead; in particular two
    all-zero tensors compare as equal (0.0). Prints debugging output when
    ``options.debug`` is set.
    """
    if options.debug:
        print(":" * 40)
        print("In compute_diff_norm, comparing:")
        print("data =")
        print(data)
        print("reference =")
        print(reference)

    # Compute difference
    diff = data - reference

    # Compute sums of squares
    d2 = numpy.sum(diff**2)
    r2 = numpy.sum(reference**2)
    n2 = numpy.sum(data**2)

    # Compute normalized norm; float() avoids integer division under Python 2.
    if r2 > 0:
        norm = math.sqrt(float(d2) / float(r2))
    else:
        norm = math.sqrt(float(d2))

    if options.debug:
        print("diff =")
        print(diff)
        print("d2, r2, n2=")
        print("%s %s %s" % (d2, r2, n2))

    return norm

def compute_eigenvalues(data):
    """Return the eigenvalues of the square matrix ``data`` in sorted order.

    Eigenvalues from scipy.linalg.eig are complex in general; numpy.sort
    orders complex values lexicographically (real part, then imaginary part),
    so no Python-level comparison of complex numbers is needed.

    Raises ValueError if ``data`` is not a square 2-D array.
    """
    sh = numpy.shape(data)
    if len(sh) != 2 or sh[0] != sh[1]:
        raise ValueError("Expecting a square matrix, got shape %s." % (sh,))
    from scipy.linalg import eig
    eigenvalues, _ = eig(data)
    return numpy.sort(eigenvalues)

def compare_data(data, reference, options):
    """Compare freshly computed tensor data with reference data.

    Checks enabled through ``options``:
      * ``options.norm``: the relative L2 norm of the difference (see
        compute_diff_norm) must be below ``options.tolerance``.
      * ``options.eig``: for 0-d or 1-d data the sorted entries are compared,
        for square matrices the sorted eigenvalues; the sum of squared
        absolute differences must be below ``options.tolerance``. For other
        shapes only a warning is added, which counts as a failure only when
        the norm check is disabled.

    Returns ``(msg, ok)``: "Data compared ok." on success, otherwise the
    failure reasons followed by the data and the reference.
    """
    msg = ""
    if reference is None:
        total_ok = False
        msg += "No reference to compare to."
    elif data.shape != reference.shape:
        total_ok = False
        msg += "\n  ERROR: Data shape is %s, reference shape is %s." % (data.shape, reference.shape)
    else:
        total_ok = True

        if options.norm:
            norm = compute_diff_norm(data, reference, options)
            # "not <" rather than ">=" so that a nan norm counts as a failure.
            if not norm < options.tolerance:
                total_ok = False
                msg += "\n  norm = %e >= %e = tol" % (norm, options.tolerance)

        if options.eig:
            sh = data.shape
            if len(sh) <= 1:
                # Scalar or vector: compare the sorted entries.
                eig1 = numpy.sort(numpy.ravel(data))
                eig2 = numpy.sort(numpy.ravel(reference))
            elif len(sh) == 2 and sh[0] == sh[1]:
                # Square matrix: compare the sorted eigenvalues.
                eig1 = compute_eigenvalues(data)
                eig2 = compute_eigenvalues(reference)
            else:
                eig1 = eig2 = None
                # If the norm test ran it already judged the data; don't mark
                # it failed just because eigenvalues are unavailable.
                if not options.norm:
                    total_ok = False
                msg += "\n  WARNING: Not computing eigenvalues of data with shape %s" % repr(sh)

            if eig1 is not None:
                # abs() keeps the error real when eigenvalues are complex.
                eig = numpy.sum(numpy.abs(eig1 - eig2)**2)
                if not eig < options.tolerance:
                    total_ok = False
                    msg += "\n  eig = %e >= %e = tol" % (eig, options.tolerance)

    # Make and return summary
    if total_ok:
        msg = "Data compared ok."
    else:
        msg = "Failed because:%s\n\ndata is\n%r\n\nreference is\n%r" % (msg, data, reference)
    return (msg, total_ok)

def compare_benchmark(data, reference, options):
    """Compare new benchmark results with reference results.

    ``data`` and ``reference`` are nested sequences indexed as
    [form][integral type][integral] (presumably timings from
    ufc_benchmark.benchmark_forms -- TODO confirm units). For each pair the
    ratio f = new / reference is reported, so f > 1 means the new code is
    slower. A missing reference (None) or a zero reference entry yields an
    explanatory message instead of an exception.

    Returns a list of message strings. ``options`` is currently unused.
    """
    if reference is None:
        return ["No benchmark reference to compare to."]
    msg = []
    # For each form
    for (b1, b2) in zip(data, reference):
        # For each integral type
        for (x, y) in zip(b1, b2):
            # For each integral
            for (r, s) in zip(x, y):
                if not s:
                    msg.append("Cannot compare timings, reference time is zero (new = %s)." % (r,))
                    continue
                # float() avoids integer division under Python 2.
                f = float(r) / s
                if f > 1:
                    msg.append("The reference is faster than the new, f = %.2f" % f)
                else:
                    msg.append("The new is faster than the reference, f = %.2f" % f)
    return msg

def write_data(fn, data):
    """Pickle ``data`` into the file ``fn``.

    The file is opened in binary mode (pickle output is bytes) and is always
    closed, even if pickling fails. On failure a message naming the file is
    printed and the exception is re-raised.
    """
    try:
        with open(fn, "wb") as f:
            pickle.dump(data, f)
    except Exception:
        print("*** An error occured while trying to write reference file: %s" % fn)
        raise

def read_data(fn):
    """Unpickle and return the data stored in the file ``fn``.

    Returns None, after printing a hint, if the file cannot be opened or
    unpickled (e.g. the reference has not been generated yet). The file is
    always closed. Reference files are trusted local files; pickle must not
    be used on untrusted input.
    """
    try:
        with open(fn, "rb") as f:
            return pickle.load(f)
    except Exception:
        print("*** An error occured while trying to load reference file: %s" % fn)
        print("*** Maybe you need to generate the reference? Returning None.")
        return None

def make_mesh(cell_shape, random_cell):
    """Return a ``(mesh, cell)`` pair of ufc_benchmark objects for one cell.

    ``cell_shape`` must be "interval", "triangle" or "tetrahedron". If
    ``random_cell`` is true the cell uses fixed pseudo-random vertex
    coordinates (generated once with ``random.uniform(-2.5, 2.5)``),
    otherwise the reference cell vertices.

    Raises RuntimeError for any other shape (including non-string values).
    """
    # shape -> (dimension, Mesh entity counts, Cell entity counts,
    #           reference vertices, random vertices)
    shapes = {
        "interval": (1, [2, 1, 0], [2, 1, 0, 0],
                     [[0], [1]],
                     [[-1.445], [0.4713]]),
        "triangle": (2, [3, 3, 1], [3, 3, 1, 0],
                     [[0, 0], [1, 0], [1, 1]],
                     [[-2.2304, -0.88317], [1.3138, -1.0164],
                      [0.24622, 1.4431]]),
        "tetrahedron": (3, [4, 6, 4, 1], [4, 6, 4, 1],
                        [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]],
                        [[-2.2561, -1.6144, -1.7349], [-1.5612, -1.5121, -0.17675],
                         [1.6861, -1.1494, 2.4070], [0.52083, 1.1940, 1.8220]]),
    }
    if cell_shape not in shapes:
        # FIXME: Define quadrilateral and hexahedral cells
        raise RuntimeError("Unknown shape %s" % (cell_shape,))

    dim, mesh_entities, cell_entities, reference_vertices, random_vertices = shapes[cell_shape]
    vertices = random_vertices if random_cell else reference_vertices
    mesh = ufc_benchmark.Mesh(dim, dim, mesh_entities)
    cell = ufc_benchmark.Cell(dim, dim, vertices, cell_entities)
    return mesh, cell

def compute_data(compiled_form, form_data, random_cell):
    """Tabulate all integrals of ``compiled_form`` on one cell and combine them.

    Uses the single cell from make_mesh(form_data.cell.domain(), random_cell)
    and deterministic dummy coefficient values.
    - Contributions of all cell integrals, and of all exterior facet integrals
      over every facet, are summed into one element tensor A.
    - Interior facet integrals are evaluated with the same cell on both sides
      and summed over every facet pair into a macro tensor.
    - A is added into the macro tensor's leading corner.

    The numbers are not physically meaningful; they are a reproducible
    fingerprint for comparison with references.

    Returns the resulting numpy array. Errors from ufc_benchmark are reported
    with the failing domain/facet and re-raised.
    """
    # --- Initialize some variables
    rank             = form_data.rank
    num_coefficients = form_data.num_functions
    num_arguments    = num_coefficients + rank

    num_cell_integrals           = compiled_form.num_cell_integrals()
    num_exterior_facet_integrals = compiled_form.num_exterior_facet_integrals()
    num_interior_facet_integrals = compiled_form.num_interior_facet_integrals()

    # --- Initialize geometry variables
    mesh, cell = make_mesh(form_data.cell.domain(), random_cell)
    dim = form_data.geometric_dimension
    num_facets = cell.num_entities[dim - 1]

    # --- Initialize dofmaps: the first `rank` belong to the form arguments,
    # followed by one per coefficient (indexed as rank + i below).
    dof_maps = []
    for i in range(num_arguments):
        dof_map = compiled_form.create_dof_map(i)
        dof_map.init_mesh(mesh)
        dof_maps.append(dof_map)

    # --- Generate arbitrary but deterministic coefficient dofs: w for one
    # cell, macro_w (twice as long) for the two sides of an interior facet.
    w = []
    macro_w = []
    for i in range(num_coefficients):
        n = dof_maps[rank + i].local_dimension()
        w.append([1.111 + (i + j)/1.111 for j in range(n)])
        macro_w.append([1.111 + (i + j)/1.111 for j in range(2*n)])

    # --- Target variables
    A = numpy.zeros((1, 1))

    # --- Add contributions from ALL domains from cell integrals
    if num_cell_integrals:
        domain = 0
        try:
            # Tabulate once only to learn the shape of A, then accumulate.
            A = ufc_benchmark.tabulate_cell_integral(compiled_form, w, cell, domain)
            A = numpy.zeros(numpy.shape(numpy.array(A)))
            for domain in range(num_cell_integrals):
                A += ufc_benchmark.tabulate_cell_integral(compiled_form, w, cell, domain)
        except Exception:
            print("*** An error occured while calling tabulate_cell_integral() for domain %d." % domain)
            raise

    # --- Add contributions from ALL domains and facets from exterior integrals
    if num_exterior_facet_integrals:
        domain, facet = (0, 0)
        try:
            # Only re-shape A if no cell integral has contributed to it.
            if not numpy.any(A):
                A = ufc_benchmark.tabulate_exterior_facet_integral(compiled_form, w, cell, facet, domain)
                A = numpy.zeros(numpy.shape(numpy.array(A)))
            for domain in range(num_exterior_facet_integrals):
                for facet in range(num_facets):
                    A += ufc_benchmark.tabulate_exterior_facet_integral(compiled_form, w, cell, facet, domain)
        except Exception:
            print("*** An error occured while calling tabulate_exterior_facet_integral() for domain %d, facet %d." % (domain, facet))
            raise

    # --- Add contributions from ALL domains and facets from interior integrals
    # FIXME: this currently makes no sense (integrating interior facets on 1 cell)
    #        but it should be OK since we just compare numbers.
    macro_A = numpy.array([0.0])
    if num_interior_facet_integrals:
        domain, facet0, facet1 = (0, 0, 0)
        try:
            macro_A = ufc_benchmark.tabulate_interior_facet_integral(compiled_form, macro_w, cell, cell, facet0, facet1, domain)
            macro_A = numpy.zeros(numpy.shape(numpy.array(macro_A)))
            for domain in range(num_interior_facet_integrals):
                for facet0 in range(num_facets):
                    for facet1 in range(num_facets):
                        macro_A += ufc_benchmark.tabulate_interior_facet_integral(compiled_form, macro_w, cell, cell, facet0, facet1, domain)
        except Exception:
            print("*** An error occured while calling tabulate_interior_facet_integral() for domain %d, facet0 %d, facet1 %d." % (domain, facet0, facet1))
            raise

    # Add A to the leading corner of macro_A; it makes no sense physically,
    # but the numbers are reproducible. Without an interior facet
    # contribution, A alone is the result.
    if not macro_A.any():
        data = A
    else:
        data = macro_A
        if A.any():
            # One slice per axis, so vectors work as well as matrices.
            # Assumes A and macro_A have the same number of axes.
            corner = tuple(slice(0, n) for n in A.shape)
            data[corner] += A

    return data

# --- Execute! ---

if __name__ == "__main__":
    # main() returns the process exit status; hand it straight to sys.exit.
    sys.exit(main(sys.argv[1:]))
