//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      GUI/Model/Descriptor/DistributionItems.cpp
//! @brief     Implements class DistributionItem and several subclasses
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2022
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "GUI/Model/Descriptor/DistributionItems.h"
#include "GUI/Support/XML/UtilXML.h"
#include "Param/Distrib/Distributions.h"
namespace {
//! XML element names used by the writeTo()/readFrom() methods in this file.
//! They define the serialized format: readFrom() skips elements whose name it does not
//! recognize, so renaming a tag means data written with the old name is silently ignored.
namespace Tag {

const QString NumberOfSamples("NumberOfSamples");
const QString RelSamplingWidth("RelSamplingWidth");
const QString Mean("Mean");
const QString Minimum("Minimum");
const QString Maximum("Maximum");
const QString HWHM("HWHM");
const QString StandardDeviation("StandardDeviation");
const QString Median("Median");
const QString ScaleParameter("ScaleParameter");
// Used by DistributionCosineItem to store its HWHM value (presumably kept under this
// name for compatibility with existing files).
const QString Sigma("Sigma");
const QString Center("Center");
const QString LeftWidth("LeftWidth");
const QString MiddleWidth("MiddleWidth");
const QString RightWidth("RightWidth");
// Wraps the data written by the writeTo() of a base class.
const QString BaseData("BaseData");

} // namespace Tag
} // namespace

using std::variant;

DistributionItem::DistributionItem() = default;

//! Enables the "relative sampling width" parameter (initial value 2, unitless).
//! Subclasses whose core distribution takes a relative sampling width call this from
//! their constructor.
void DistributionItem::initRelSamplingWidth()
{
    const double initialRelWidth = 2.0;
    m_relSamplingWidth.init("Rel. sampling width", "", initialRelWidth, Unit::unitless,
                            "relSamplingWidth");
}

//! Returns whether initRelSamplingWidth() has been called for this item.
bool DistributionItem::hasRelSamplingWidth() const
{
    const bool initialized = m_relSamplingWidth.isInitialized();
    return initialized;
}

//! Serializes the number of samples and, if present, the relative sampling width.
//! Units changed with setUnit() are NOT serialized; the owner of this item has to set
//! them again after reading.
void DistributionItem::writeTo(QXmlStreamWriter* w) const
{
    XML::writeAttribute(w, XML::Attrib::version, uint(1));

    w->writeStartElement(Tag::NumberOfSamples);
    XML::writeAttribute(w, XML::Attrib::value, m_numberOfSamples);
    w->writeEndElement();

    // Only distributions that enabled it carry a relative sampling width.
    if (!hasRelSamplingWidth())
        return;

    w->writeStartElement(Tag::RelSamplingWidth);
    m_relSamplingWidth.writeTo(w);
    w->writeEndElement();
}

//! Reads back what writeTo() wrote. Unknown child elements are skipped.
void DistributionItem::readFrom(QXmlStreamReader* r)
{
    // Only version 1 exists so far; the version is read but not evaluated.
    [[maybe_unused]] const uint version = XML::readUIntAttribute(r, XML::Attrib::version);

    while (r->readNextStartElement()) {
        QString tag = r->name().toString();

        if (tag == Tag::NumberOfSamples)
            XML::readAttribute(r, XML::Attrib::value, &m_numberOfSamples);
        else if (tag == Tag::RelSamplingWidth)
            m_relSamplingWidth.readFrom(r);
        else {
            r->skipCurrentElement();
            continue;
        }
        XML::gotoEndElementOfTag(r, tag);
    }
}

// --------------------------------------------------------------------------------------------- //

//! Creates the item with a single, unbounded "mean" parameter.
//! @param mean       initial value of the mean
//! @param decimals   number of decimals of the mean parameter
//! @param meanLabel  label of the mean parameter
SymmetricResolutionItem::SymmetricResolutionItem(double mean, int decimals,
                                                 const QString& meanLabel)
{
    const RealLimits unbounded = RealLimits::limitless();
    m_mean.init(meanLabel, "", mean, Unit::unitless, decimals, unbounded, "mean");
}

//! Sets the unit of the mean; subclasses extend this to their own parameters.
void SymmetricResolutionItem::setUnit(const variant<QString, Unit>& unit)
{
    m_mean.setUnit(unit);
}

//! Changes the number of decimals of the mean parameter.
void SymmetricResolutionItem::setMeanDecimals(uint decimals)
{
    m_mean.setDecimals(decimals);
}

//! Serializes the DistributionItem data (wrapped in a BaseData element) and the mean.
void SymmetricResolutionItem::writeTo(QXmlStreamWriter* w) const
{
    XML::writeAttribute(w, XML::Attrib::version, uint(1));

    w->writeStartElement(Tag::BaseData);
    DistributionItem::writeTo(w);
    w->writeEndElement();

    w->writeStartElement(Tag::Mean);
    m_mean.writeTo(w);
    w->writeEndElement();
}

//! Reads back what writeTo() wrote. Unknown child elements are skipped.
void SymmetricResolutionItem::readFrom(QXmlStreamReader* r)
{
    // Only version 1 exists so far; the version is read but not evaluated.
    [[maybe_unused]] const uint version = XML::readUIntAttribute(r, XML::Attrib::version);

    while (r->readNextStartElement()) {
        QString tag = r->name().toString();

        if (tag == Tag::BaseData)
            DistributionItem::readFrom(r);
        else if (tag == Tag::Mean)
            m_mean.readFrom(r);
        else {
            r->skipCurrentElement();
            continue;
        }
        XML::gotoEndElementOfTag(r, tag);
    }
}

// --------------------------------------------------------------------------------------------- //

//! A "distribution" consisting of a single fixed value (initially 0.1, 3 decimals,
//! labelled "Value").
DistributionNoneItem::DistributionNoneItem()
    : SymmetricResolutionItem(0.1, 3, "Value")
{
}

//! There is no core distribution for a fixed value: always returns nullptr.
std::unique_ptr<IDistribution1D> DistributionNoneItem::createDistribution(double /*scale*/) const
{
    return {};
}

//! A fixed value has no spread: always 0.
double DistributionNoneItem::deviation(double /*scale*/) const
{
    return 0.;
}

//! Sets the fixed value.
void DistributionNoneItem::initDistribution(double value)
{
    setMean(value);
}

//! Returns the value property if @p withMean is set, otherwise an empty list.
DoubleProperties DistributionNoneItem::distributionValues(bool withMean)
{
    if (!withMean)
        return DoubleProperties{};
    return DoubleProperties{&m_mean};
}

// --------------------------------------------------------------------------------------------- //

DistributionGateItem::DistributionGateItem()
{
    m_minimum.init("Min", "", 0.0, Unit::unitless, 3 /* decimals */, RealLimits::limitless(),
                   "min");
    m_maximum.init("Max", "", 1.0, Unit::unitless, 3 /* decimals */, RealLimits::limitless(),
                   "max");
}

void DistributionGateItem::setUnit(const variant<QString, Unit>& unit)
{
    m_minimum.setUnit(unit);
    m_maximum.setUnit(unit);
}

std::unique_ptr<IDistribution1D> DistributionGateItem::createDistribution(double scale) const
{
    return std::make_unique<DistributionGate>(scale * m_minimum.value(), scale * m_maximum.value(),
                                              m_numberOfSamples);
}

void DistributionGateItem::initDistribution(double value)
{
    double sigma(0.1 * std::abs(value));
    if (sigma == 0.0)
        sigma = 0.1;
    m_minimum.setValue(value - sigma);
    m_maximum.setValue(value + sigma);
}

void DistributionGateItem::setRange(double min, double max)
{
    m_minimum.setValue(min);
    m_maximum.setValue(max);
}

//! Serializes the DistributionItem data (wrapped in a BaseData element), then minimum
//! and maximum.
void DistributionGateItem::writeTo(QXmlStreamWriter* w) const
{
    XML::writeAttribute(w, XML::Attrib::version, uint(1));

    // Writes one property wrapped in an element with the given tag.
    const auto writeProperty = [w](const QString& tag, const auto& property) {
        w->writeStartElement(tag);
        property.writeTo(w);
        w->writeEndElement();
    };

    w->writeStartElement(Tag::BaseData);
    DistributionItem::writeTo(w);
    w->writeEndElement();

    writeProperty(Tag::Minimum, m_minimum);
    writeProperty(Tag::Maximum, m_maximum);
}

//! Reads back what writeTo() wrote. Unknown child elements are skipped.
void DistributionGateItem::readFrom(QXmlStreamReader* r)
{
    // Only version 1 exists so far; the version is read but not evaluated.
    [[maybe_unused]] const uint version = XML::readUIntAttribute(r, XML::Attrib::version);

    while (r->readNextStartElement()) {
        QString tag = r->name().toString();

        if (tag == Tag::BaseData)
            DistributionItem::readFrom(r);
        else if (tag == Tag::Minimum)
            m_minimum.readFrom(r);
        else if (tag == Tag::Maximum)
            m_maximum.readFrom(r);
        else {
            r->skipCurrentElement();
            continue;
        }
        XML::gotoEndElementOfTag(r, tag);
    }
}

//! Returns minimum and maximum. @p withMean is ignored: a gate has no mean property.
DoubleProperties DistributionGateItem::distributionValues(bool /*withMean*/)
{
    DoubleProperties boundaries{&m_minimum, &m_maximum};
    return boundaries;
}

// --------------------------------------------------------------------------------------------- //

DistributionLorentzItem::DistributionLorentzItem()
    : SymmetricResolutionItem(1.0)
{
    initRelSamplingWidth();

    m_hwhm.init("HWHM", "", 1.0, Unit::unitless, "hwhm");
}

void DistributionLorentzItem::setUnit(const variant<QString, Unit>& unit)
{
    SymmetricResolutionItem::setUnit(unit);
    m_hwhm.setUnit(unit);
}

std::unique_ptr<IDistribution1D> DistributionLorentzItem::createDistribution(double scale) const
{
    return std::make_unique<DistributionLorentz>(scale * m_mean.value(), scale * m_hwhm.value(),
                                                 m_numberOfSamples, m_relSamplingWidth);
}

//! Returns the HWHM multiplied by @p scale as the measure of spread.
//! Reads the value explicitly via value(), consistent with the Gaussian and Cosine
//! items, instead of relying on DoubleProperty's implicit conversion.
double DistributionLorentzItem::deviation(double scale) const
{
    return m_hwhm.value() * scale;
}

//! Centers the distribution at @p value with a HWHM of 10% of |value|
//! (or 0.1 if that is zero).
void DistributionLorentzItem::initDistribution(double value)
{
    const double relative = 0.1 * std::abs(value);
    const double halfWidth = (relative != 0.0) ? relative : 0.1;

    setMean(value);
    setHwhm(halfWidth);
}

//! Serializes the SymmetricResolutionItem data (wrapped in a BaseData element) and the
//! HWHM.
void DistributionLorentzItem::writeTo(QXmlStreamWriter* w) const
{
    XML::writeAttribute(w, XML::Attrib::version, uint(1));

    w->writeStartElement(Tag::BaseData);
    SymmetricResolutionItem::writeTo(w);
    w->writeEndElement();

    w->writeStartElement(Tag::HWHM);
    m_hwhm.writeTo(w);
    w->writeEndElement();
}

//! Reads back what writeTo() wrote. Unknown child elements are skipped.
void DistributionLorentzItem::readFrom(QXmlStreamReader* r)
{
    // Only version 1 exists so far; the version is read but not evaluated.
    [[maybe_unused]] const uint version = XML::readUIntAttribute(r, XML::Attrib::version);

    while (r->readNextStartElement()) {
        QString tag = r->name().toString();

        if (tag == Tag::BaseData)
            SymmetricResolutionItem::readFrom(r);
        else if (tag == Tag::HWHM)
            m_hwhm.readFrom(r);
        else {
            r->skipCurrentElement();
            continue;
        }
        XML::gotoEndElementOfTag(r, tag);
    }
}

//! Returns HWHM and relative sampling width, preceded by the mean if @p withMean is set.
DoubleProperties DistributionLorentzItem::distributionValues(bool withMean)
{
    if (withMean)
        return {&m_mean, &m_hwhm, &m_relSamplingWidth};
    return {&m_hwhm, &m_relSamplingWidth};
}

// --------------------------------------------------------------------------------------------- //

DistributionGaussianItem::DistributionGaussianItem()
    : SymmetricResolutionItem(1.0)
{
    initRelSamplingWidth();

    m_standardDeviation.init("StdDev", "", 1.0, Unit::unitless, 3 /* decimals */,
                             RealLimits::lowerLimited(0.0), "stdDev");
}

void DistributionGaussianItem::setUnit(const variant<QString, Unit>& unit)
{
    SymmetricResolutionItem::setUnit(unit);
    m_standardDeviation.setUnit(unit);
}

std::unique_ptr<IDistribution1D> DistributionGaussianItem::createDistribution(double scale) const
{
    return std::make_unique<DistributionGaussian>(scale * m_mean.value(),
                                                  scale * m_standardDeviation.value(),
                                                  m_numberOfSamples, m_relSamplingWidth);
}

double DistributionGaussianItem::deviation(double scale) const
{
    return m_standardDeviation.value() * scale;
}

void DistributionGaussianItem::initDistribution(double value)
{
    double stddev(0.1 * std::abs(value));
    if (stddev == 0.0)
        stddev = 0.1;

    setMean(value);
    setStandardDeviation(stddev);
}

//! Serializes the SymmetricResolutionItem data (wrapped in a BaseData element) and the
//! standard deviation.
void DistributionGaussianItem::writeTo(QXmlStreamWriter* w) const
{
    XML::writeAttribute(w, XML::Attrib::version, uint(1));

    w->writeStartElement(Tag::BaseData);
    SymmetricResolutionItem::writeTo(w);
    w->writeEndElement();

    w->writeStartElement(Tag::StandardDeviation);
    m_standardDeviation.writeTo(w);
    w->writeEndElement();
}

//! Reads back what writeTo() wrote. Unknown child elements are skipped.
void DistributionGaussianItem::readFrom(QXmlStreamReader* r)
{
    // Only version 1 exists so far; the version is read but not evaluated.
    [[maybe_unused]] const uint version = XML::readUIntAttribute(r, XML::Attrib::version);

    while (r->readNextStartElement()) {
        QString tag = r->name().toString();

        if (tag == Tag::BaseData)
            SymmetricResolutionItem::readFrom(r);
        else if (tag == Tag::StandardDeviation)
            m_standardDeviation.readFrom(r);
        else {
            r->skipCurrentElement();
            continue;
        }
        XML::gotoEndElementOfTag(r, tag);
    }
}

//! Returns standard deviation and relative sampling width, preceded by the mean if
//! @p withMean is set.
DoubleProperties DistributionGaussianItem::distributionValues(bool withMean)
{
    if (withMean)
        return {&m_mean, &m_standardDeviation, &m_relSamplingWidth};
    return {&m_standardDeviation, &m_relSamplingWidth};
}

// --------------------------------------------------------------------------------------------- //

DistributionLogNormalItem::DistributionLogNormalItem()
{
    initRelSamplingWidth();

    m_median.init("Median", "", 1.0, Unit::unitless, "median");
    m_scaleParameter.init("ScaleParameter", "", 1.0, Unit::unitless, 3 /* decimals */,
                          RealLimits::lowerLimited(0.0), "scalePar");
}

void DistributionLogNormalItem::setUnit(const variant<QString, Unit>& unit)
{
    m_median.setUnit(unit);
}

std::unique_ptr<IDistribution1D> DistributionLogNormalItem::createDistribution(double scale) const
{
    return std::make_unique<DistributionLogNormal>(
        scale * m_median.value(), m_scaleParameter.value(), m_numberOfSamples, m_relSamplingWidth);
}

//! Initializes the distribution around @p value.
//!
//! The median is set to @p value. The scale parameter is the shape parameter of the
//! log-normal distribution. It is passed to DistributionLogNormal without multiplication
//! by the scale factor and carries no unit (see createDistribution() and setUnit()), so it
//! is dimensionless. It therefore must not be derived from the magnitude of @p value,
//! which depends on the unit the value is expressed in. Previously it was set to
//! 0.1*|value|, so a median of 100 gave a scale parameter of 10, which is absurdly broad.
//! For small scale parameters the relative spread (coefficient of variation) of a
//! log-normal distribution is approximately the scale parameter itself. A fixed 0.1
//! therefore matches the ~10% relative width that the other items use as their default.
void DistributionLogNormalItem::initDistribution(double value)
{
    // NOTE(review): a log-normal distribution needs a positive median; value <= 0 is
    // not guarded against here.
    const double relativeWidth = 0.1;

    setMedian(value);
    setScaleParameter(relativeWidth);
}

//! Serializes the DistributionItem data (wrapped in a BaseData element), then median and
//! scale parameter.
void DistributionLogNormalItem::writeTo(QXmlStreamWriter* w) const
{
    XML::writeAttribute(w, XML::Attrib::version, uint(1));

    // Writes one property wrapped in an element with the given tag.
    const auto writeProperty = [w](const QString& tag, const auto& property) {
        w->writeStartElement(tag);
        property.writeTo(w);
        w->writeEndElement();
    };

    w->writeStartElement(Tag::BaseData);
    DistributionItem::writeTo(w);
    w->writeEndElement();

    writeProperty(Tag::Median, m_median);
    writeProperty(Tag::ScaleParameter, m_scaleParameter);
}

//! Reads back what writeTo() wrote. Unknown child elements are skipped.
void DistributionLogNormalItem::readFrom(QXmlStreamReader* r)
{
    // Only version 1 exists so far; the version is read but not evaluated.
    [[maybe_unused]] const uint version = XML::readUIntAttribute(r, XML::Attrib::version);

    while (r->readNextStartElement()) {
        QString tag = r->name().toString();

        if (tag == Tag::BaseData)
            DistributionItem::readFrom(r);
        else if (tag == Tag::Median)
            m_median.readFrom(r);
        else if (tag == Tag::ScaleParameter)
            m_scaleParameter.readFrom(r);
        else {
            r->skipCurrentElement();
            continue;
        }
        XML::gotoEndElementOfTag(r, tag);
    }
}

//! Returns median, scale parameter and relative sampling width. @p withMean is ignored:
//! the median is always included.
DoubleProperties DistributionLogNormalItem::distributionValues(bool /*withMean*/)
{
    DoubleProperties values{&m_median, &m_scaleParameter, &m_relSamplingWidth};
    return values;
}

// --------------------------------------------------------------------------------------------- //

DistributionCosineItem::DistributionCosineItem()
    : SymmetricResolutionItem(1.0)
{
    m_hwhm.init("HWHM", "", 1.0, Unit::unitless, 3 /* decimals */, RealLimits::lowerLimited(0.0),
                "hwhm");
}

void DistributionCosineItem::setUnit(const variant<QString, Unit>& unit)
{
    SymmetricResolutionItem::setUnit(unit);
    m_hwhm.setUnit(unit);
}

std::unique_ptr<IDistribution1D> DistributionCosineItem::createDistribution(double scale) const
{
    return std::make_unique<DistributionCosine>(scale * m_mean.value(), scale * m_hwhm.value(),
                                                m_numberOfSamples);
}

double DistributionCosineItem::deviation(double scale) const
{
    return m_hwhm.value() * scale;
}

void DistributionCosineItem::initDistribution(double value)
{
    double sigma(0.1 * std::abs(value));
    if (sigma == 0.0)
        sigma = 0.1;

    setMean(value);
    setHwhm(sigma);
}

//! Serializes the SymmetricResolutionItem data (wrapped in a BaseData element) and the
//! HWHM. Note that the HWHM is stored under the tag "Sigma", not "HWHM"; readFrom()
//! expects the same tag, so both must stay in sync.
void DistributionCosineItem::writeTo(QXmlStreamWriter* w) const
{
    XML::writeAttribute(w, XML::Attrib::version, uint(1));

    w->writeStartElement(Tag::BaseData);
    SymmetricResolutionItem::writeTo(w);
    w->writeEndElement();

    w->writeStartElement(Tag::Sigma);
    m_hwhm.writeTo(w);
    w->writeEndElement();
}

//! Reads back what writeTo() wrote (HWHM from the "Sigma" element). Unknown child
//! elements are skipped.
void DistributionCosineItem::readFrom(QXmlStreamReader* r)
{
    // Only version 1 exists so far; the version is read but not evaluated.
    [[maybe_unused]] const uint version = XML::readUIntAttribute(r, XML::Attrib::version);

    while (r->readNextStartElement()) {
        QString tag = r->name().toString();

        if (tag == Tag::BaseData)
            SymmetricResolutionItem::readFrom(r);
        else if (tag == Tag::Sigma)
            m_hwhm.readFrom(r);
        else {
            r->skipCurrentElement();
            continue;
        }
        XML::gotoEndElementOfTag(r, tag);
    }
}

//! Returns the HWHM, preceded by the mean if @p withMean is set.
DoubleProperties DistributionCosineItem::distributionValues(bool withMean)
{
    if (withMean)
        return DoubleProperties{&m_mean, &m_hwhm};
    return DoubleProperties{&m_hwhm};
}

// --------------------------------------------------------------------------------------------- //

DistributionTrapezoidItem::DistributionTrapezoidItem()
{
    m_center.init("Center", "", 1.0, Unit::unitless, 3 /* decimals */, RealLimits::limitless(),
                  "center");
    m_leftWidth.init("LeftWidth", "", 1.0, Unit::unitless, "left");
    m_middleWidth.init("MiddleWidth", "", 1.0, Unit::unitless, "middle");
    m_rightWidth.init("RightWidth", "", 1.0, Unit::unitless, "right");
}

void DistributionTrapezoidItem::setUnit(const variant<QString, Unit>& unit)
{
    m_center.setUnit(unit);
    m_leftWidth.setUnit(unit);
    m_middleWidth.setUnit(unit);
    m_rightWidth.setUnit(unit);
}

std::unique_ptr<IDistribution1D> DistributionTrapezoidItem::createDistribution(double scale) const
{
    return std::make_unique<DistributionTrapezoid>(
        scale * m_center.value(), scale * m_leftWidth.value(), scale * m_middleWidth.value(),
        scale * m_rightWidth.value(), m_numberOfSamples);
}

void DistributionTrapezoidItem::initDistribution(double value)
{
    double width(0.1 * std::abs(value));
    if (width == 0.0)
        width = 0.1;
    setCenter(value);
    setLeftWidth(width);
    setMiddleWidth(width);
    setRightWidth(width);
}

//! Serializes the DistributionItem data (wrapped in a BaseData element), then center,
//! left, middle and right width.
void DistributionTrapezoidItem::writeTo(QXmlStreamWriter* w) const
{
    XML::writeAttribute(w, XML::Attrib::version, uint(1));

    // Writes one property wrapped in an element with the given tag.
    const auto writeProperty = [w](const QString& tag, const auto& property) {
        w->writeStartElement(tag);
        property.writeTo(w);
        w->writeEndElement();
    };

    w->writeStartElement(Tag::BaseData);
    DistributionItem::writeTo(w);
    w->writeEndElement();

    writeProperty(Tag::Center, m_center);
    writeProperty(Tag::LeftWidth, m_leftWidth);
    writeProperty(Tag::MiddleWidth, m_middleWidth);
    writeProperty(Tag::RightWidth, m_rightWidth);
}

//! Reads back what writeTo() wrote. Unknown child elements are skipped.
void DistributionTrapezoidItem::readFrom(QXmlStreamReader* r)
{
    // Only version 1 exists so far; the version is read but not evaluated.
    [[maybe_unused]] const uint version = XML::readUIntAttribute(r, XML::Attrib::version);

    while (r->readNextStartElement()) {
        QString tag = r->name().toString();

        if (tag == Tag::BaseData)
            DistributionItem::readFrom(r);
        else if (tag == Tag::Center)
            m_center.readFrom(r);
        else if (tag == Tag::LeftWidth)
            m_leftWidth.readFrom(r);
        else if (tag == Tag::MiddleWidth)
            m_middleWidth.readFrom(r);
        else if (tag == Tag::RightWidth)
            m_rightWidth.readFrom(r);
        else {
            r->skipCurrentElement();
            continue;
        }
        XML::gotoEndElementOfTag(r, tag);
    }
}

//! Returns the three widths, preceded by the center if @p withMean is set.
DoubleProperties DistributionTrapezoidItem::distributionValues(bool withMean)
{
    if (withMean)
        return {&m_center, &m_leftWidth, &m_middleWidth, &m_rightWidth};
    return {&m_leftWidth, &m_middleWidth, &m_rightWidth};
}
