/*
  This file is part of CDO. CDO is a collection of Operators to manipulate and analyse Climate model Data.

  Author: Uwe Schulzweida

*/

#include <cdi.h>

#include "process_int.h"
#include "field_functions.h"

// Aborts unless every z-axis of the vlist has the same number of levels as the
// first z-axis (index 0).
static void
check_unique_zaxis(int vlistID)
{
  const auto refLevels = zaxisInqSize(vlistZaxis(vlistID, 0));
  const auto numZaxes = vlistNzaxis(vlistID);

  for (int i = 1; i < numZaxes; ++i)
    if (zaxisInqSize(vlistZaxis(vlistID, i)) != refLevels) cdo_abort("Number of level differ!");
}

// Aborts unless every horizontal grid of the vlist has the same number of
// points as the first grid (index 0).
// Only the grid sizes are compared, not the grid definitions themselves.
static void
check_unique_gridsize(int vlistID)
{
  auto ngrids = vlistNgrids(vlistID);
  auto gridID = vlistGrid(vlistID, 0);
  auto gridsize = gridInqSize(gridID);
  // Start at 1: grid 0 is the reference and need not be compared with itself
  // (consistent with check_unique_zaxis).
  for (int index = 1; index < ngrids; ++index)
    {
      if (gridsize != gridInqSize(vlistGrid(vlistID, index))) cdo_abort("Horizontal gridsize differ!");
    }
}

// Defines the metadata of the single output variable varID2 in vlistID2.
//
// If all input variables share the same parameter code AND the same name, the
// output keeps that name and additionally inherits the parameter code (when
// non-negative), long name and units (when non-empty) of the first variable.
// Otherwise the output variable is named after the operator (e.g. "varsmean").
// The missing value is always taken from the first input variable.
static void
set_attributes(const VarList &varList1, int vlistID2, int varID2, int operatorID)
{
  const auto &firstVar = varList1[0];

  auto allSame = true;
  int numVars = varList1.size();
  for (int varID = 1; allSame && varID < numVars; ++varID)
    {
      const auto &other = varList1[varID];
      allSame = (other.param == firstVar.param) && (other.name == firstVar.name);
    }

  auto outName = firstVar.name;
  if (!allSame) outName = cdo_operator_name(operatorID);

  cdiDefKeyString(vlistID2, varID2, CDI_KEY_NAME, outName.c_str());
  vlistDefVarMissval(vlistID2, varID2, firstVar.missval);

  if (!allSame) return;

  if (firstVar.param >= 0) vlistDefVarParam(vlistID2, varID2, firstVar.param);
  if (!firstVar.longname.empty()) cdiDefKeyString(vlistID2, varID2, CDI_KEY_LONGNAME, firstVar.longname.c_str());
  if (!firstVar.units.empty()) cdiDefKeyString(vlistID2, varID2, CDI_KEY_UNITS, firstVar.units.c_str());
}

class Varsstat : public Process
{
public:
  using Process::Process;
  inline static CdoModule module = {
    .name = "Varsstat",
    .operators = { { "varsrange", FieldFunc_Range, 0, VarsstatHelp },
                   { "varsmin", FieldFunc_Min, 0, VarsstatHelp },
                   { "varsmax", FieldFunc_Max, 0, VarsstatHelp },
                   { "varssum", FieldFunc_Sum, 0, VarsstatHelp },
                   { "varsmean", FieldFunc_Mean, 0, VarsstatHelp },
                   { "varsavg", FieldFunc_Avg, 0, VarsstatHelp },
                   { "varsstd", FieldFunc_Std, 0, VarsstatHelp },
                   { "varsstd1", FieldFunc_Std1, 0, VarsstatHelp },
                   { "varsvar", FieldFunc_Var, 0, VarsstatHelp },
                   { "varsvar1", FieldFunc_Var1, 0, VarsstatHelp } },
    .aliases = {},
    .mode = EXPOSED,     // Module mode: 0:intern 1:extern
    .number = CDI_REAL,  // Allowed number type
    .constraints = { 1, 1, NoRestriction },
  };
  inline static RegisterEntry<Varsstat> registration = RegisterEntry<Varsstat>(module);
  CdoStreamID streamID1;
  int taxisID1;

  CdoStreamID streamID2;
  int taxisID2;
  int vlistID2;

  bool lrange;
  bool lvarstd;
  bool lmean;
  bool lstd;
  double divisor;

  int nlevels;

  int operfunc;

  Field field;
  FieldVector vars1, samp1, vars2;
  VarList varList1;

public:
  void
  init()
  {

    auto operatorID = cdo_operator_id();
    operfunc = cdo_operator_f1(operatorID);

    operator_check_argc(0);

    lrange = (operfunc == FieldFunc_Range);
    lmean = (operfunc == FieldFunc_Mean || operfunc == FieldFunc_Avg);
    lstd = (operfunc == FieldFunc_Std || operfunc == FieldFunc_Std1);
    lvarstd = (lstd || operfunc == FieldFunc_Var || operfunc == FieldFunc_Var1);
    auto lvars2 = (lvarstd || lrange);
    divisor = (operfunc == FieldFunc_Std1 || operfunc == FieldFunc_Var1);

    streamID1 = cdo_open_read(0);
    auto vlistID1 = cdo_stream_inq_vlist(streamID1);

    varList_init(varList1, vlistID1);

    check_unique_zaxis(vlistID1);
    auto zaxisID = vlistZaxis(vlistID1, 0);
    nlevels = zaxisInqSize(zaxisID);

    check_unique_gridsize(vlistID1);
    auto gridID = vlistGrid(vlistID1, 0);
    auto gridsize = gridInqSize(gridID);

    auto timetype = varList1[0].timetype;
    auto nvars = vlistNvars(vlistID1);
    for (int varID = 1; varID < nvars; ++varID)
      {
        if (timetype != varList1[varID].timetype) cdo_abort("Number of timesteps differ!");
      }

    vlistID2 = vlistCreate();
    vlistDefNtsteps(vlistID2, vlistNtsteps(vlistID1));

    auto varID2 = vlistDefVar(vlistID2, gridID, zaxisID, timetype);
    set_attributes(varList1, vlistID2, varID2, operatorID);

    taxisID1 = vlistInqTaxis(vlistID1);
    taxisID2 = taxisDuplicate(taxisID1);
    vlistDefTaxis(vlistID2, taxisID2);

    streamID2 = cdo_open_write(1);
    cdo_def_vlist(streamID2, vlistID2);

    vars1 = FieldVector(nlevels);
    samp1 = FieldVector(nlevels);
    if (lvars2) vars2.resize(nlevels);

    for (int levelID = 0; levelID < nlevels; ++levelID)
      {
        auto missval = varList1[0].missval;

        samp1[levelID].grid = gridID;
        samp1[levelID].missval = missval;
        samp1[levelID].memType = MemType::Double;
        vars1[levelID].grid = gridID;
        vars1[levelID].missval = missval;
        vars1[levelID].memType = MemType::Double;
        vars1[levelID].resize(gridsize);
        if (lvars2)
          {
            vars2[levelID].grid = gridID;
            vars2[levelID].missval = missval;
            vars2[levelID].memType = MemType::Double;
            vars2[levelID].resize(gridsize);
          }
      }
  }

  void
  run()
  {
    auto field2_stdvar_func = lstd ? field2_std : field2_var;
    auto fieldc_stdvar_func = lstd ? fieldc_std : fieldc_var;

    int tsID = 0;
    while (true)
      {
        auto nrecs = cdo_stream_inq_timestep(streamID1, tsID);
        if (nrecs == 0) break;

        cdo_taxis_copy_timestep(taxisID2, taxisID1);
        cdo_def_timestep(streamID2, tsID);

        for (int recID = 0; recID < nrecs; ++recID)
          {
            int varID, levelID;
            cdo_inq_record(streamID1, &varID, &levelID);

            auto &rsamp1 = samp1[levelID];
            auto &rvars1 = vars1[levelID];

            rvars1.nsamp++;
            if (lrange) vars2[levelID].nsamp++;

            if (varID == 0)
              {
                cdo_read_record(streamID1, rvars1);
                if (lrange)
                  {
                    vars2[levelID].numMissVals = rvars1.numMissVals;
                    vars2[levelID].vec_d = rvars1.vec_d;
                  }

                if (lvarstd) field2_moq(vars2[levelID], rvars1);

                if (rvars1.numMissVals || !rsamp1.empty())
                  {
                    if (rsamp1.empty()) rsamp1.resize(rvars1.size);
                    field2_vinit(rsamp1, rvars1);
                  }
              }
            else
              {
                field.init(varList1[varID]);
                cdo_read_record(streamID1, field);

                if (field.numMissVals || !rsamp1.empty())
                  {
                    if (rsamp1.empty()) rsamp1.resize(rvars1.size, rvars1.nsamp);
                    field2_vincr(rsamp1, field);
                  }

                // clang-format off
                if      (lvarstd) field2_sumsumq(rvars1, vars2[levelID], field);
                else if (lrange)  field2_maxmin(rvars1, vars2[levelID], field);
                else              field2_function(rvars1, field, operfunc);
                // clang-format on
              }
          }

        for (int levelID = 0; levelID < nlevels; ++levelID)
          {
            const auto &rsamp1 = samp1[levelID];
            auto &rvars1 = vars1[levelID];

            if (rvars1.nsamp)
              {
                if (lmean)
                  {
                    if (!rsamp1.empty())
                      field2_div(rvars1, rsamp1);
                    else
                      fieldc_div(rvars1, (double) rvars1.nsamp);
                  }
                else if (lvarstd)
                  {
                    if (!rsamp1.empty())
                      field2_stdvar_func(rvars1, vars2[levelID], rsamp1, divisor);
                    else
                      fieldc_stdvar_func(rvars1, vars2[levelID], rvars1.nsamp, divisor);
                  }
                else if (lrange) { field2_sub(rvars1, vars2[levelID]); }

                cdo_def_record(streamID2, 0, levelID);
                cdo_write_record(streamID2, rvars1);
                rvars1.nsamp = 0;
              }
          }

        tsID++;
      }
  }

  void
  close()
  {
    cdo_stream_close(streamID2);
    cdo_stream_close(streamID1);

    vlistDestroy(vlistID2);
  }
};
