import { MATRIX_FILL_METHODS } from "../../../../constants";
import Matrix from "ml-matrix";
import { PCA } from "ml-pca";
import { truncateNumber } from "../../../../services/utils";

/**
 * Whether a matrix cell has no usable numeric value: either the `null`
 * placeholder or a property that could not be converted to a number (NaN).
 *
 * @param {number|null} value - A matrix cell.
 * @returns {boolean}
 */
function isMissingValue(value) {
  return value === null || Number.isNaN(value);
}

/**
 * Build the observation matrix fed to the PCA: one row per plot id, one
 * column per selected (trait, date) pair.
 *
 * A feature fills the columns whose date equals `feature.date`; a column whose
 * date is falsy is filled by every feature of the plot. Within a cell, a later
 * valid number replaces an earlier value, but a later missing value never
 * erases an earlier valid one.
 *
 * @param {Array<Object>} data - Features; each is read for `id`, `date`,
 *   `properties`, `displayId` and `group`.
 * @param {Object<string, Iterable>} selectedDatesForTraits - Trait technical
 *   name -> dates selected for that trait.
 * @param {*} fillMethod - `MATRIX_FILL_METHODS.DROP` discards rows with a
 *   missing cell, `MATRIX_FILL_METHODS.MEAN` replaces missing cells with the
 *   column mean; any other value leaves missing cells as they are.
 * @returns {{matrix: Array<Array<number|null>>, columns: Array<Array>, rows: Array<Object>}}
 *   `matrix[i]` holds the values of `rows[i]`; `columns[j]` is the
 *   `[technicalName, date]` pair describing column `j`.
 */
function prepareMatrix(data, selectedDatesForTraits, fillMethod) {
  const columns = Object.entries(selectedDatesForTraits).flatMap(
    ([technicalName, dates]) =>
      Array.from(dates).map((date) => [technicalName, date])
  );

  // Null-prototype object: an id such as "constructor" must not resolve to an
  // inherited property. Object.entries keeps the same key order as `{}`.
  const plotsData = Object.create(null);

  data.forEach((feature) => {
    // One entry per plot id; cells start as null and are filled (or dropped)
    // later depending on fillMethod.
    if (!plotsData[feature.id]) {
      plotsData[feature.id] = {
        features: [],
        data: Array(columns.length).fill(null),
      };
    }
    const plot = plotsData[feature.id];
    plot.features.push(feature);

    columns.forEach(([technicalName, date], index) => {
      // Only fill columns for this feature's date; undated columns accept
      // every feature of the plot.
      if (date && feature.date !== date) return;

      const rawValue = feature.properties[technicalName];
      const value = rawValue != null ? Number(rawValue) : null;
      // Never overwrite an already valid value with a missing one.
      if (isMissingValue(plot.data[index]) || !isMissingValue(value)) {
        plot.data[index] = value;
      }
    });
  });

  const plots = Object.entries(plotsData);

  // Column means over valid cells only. Summing the converted cell values (not
  // the raw properties) prevents string properties from being concatenated,
  // and each plot contributes at most once per column.
  const means = columns.map((_, index) => {
    let sum = 0;
    let count = 0;
    plots.forEach(([, plot]) => {
      const value = plot.data[index];
      if (!isMissingValue(value)) {
        sum += value;
        count += 1;
      }
    });
    return count > 0 ? sum / count : 0;
  });

  const matrix = [];
  const rows = [];

  plots.forEach(([id, { data: cells, features }]) => {
    let filledCells = cells;
    if (cells.some(isMissingValue)) {
      // Drop rows with missing data
      if (fillMethod === MATRIX_FILL_METHODS.DROP) return;
      // Fill missing cells with the column mean
      if (fillMethod === MATRIX_FILL_METHODS.MEAN) {
        filledCells = cells.map((value, index) =>
          isMissingValue(value) ? means[index] : value
        );
      }
    }

    matrix.push(filledCells);
    rows.push({
      displayId: features[0].displayId,
      group: features[0].group,
      id,
      features,
    });
  });

  return { matrix, columns, rows };
}

/**
 * Standardize every column: subtract its mean and divide by its sample
 * standard deviation (n - 1 denominator), the same transform as ml-matrix's
 * `center("column").scale("column")`. A column holding one repeated value is
 * mapped to zeros instead, because dividing it by its zero spread would give
 * 0 / 0 = NaN and poison the whole decomposition.
 *
 * @param {number[][]} values - Rectangular matrix with at least two rows.
 * @returns {number[][]} A new standardized matrix; `values` is not modified.
 */
function standardizeColumns(values) {
  const rowCount = values.length;
  const columnCount = values[0].length;
  const standardized = values.map(() => Array(columnCount).fill(0));

  for (let j = 0; j < columnCount; j += 1) {
    const first = values[0][j];
    const isConstant = values.every((row) => row[j] === first);
    if (!isConstant) {
      const mean = values.reduce((sum, row) => sum + row[j], 0) / rowCount;
      const variance =
        values.reduce((sum, row) => sum + (row[j] - mean) ** 2, 0) /
        (rowCount - 1);
      const deviation = Math.sqrt(variance);
      values.forEach((row, i) => {
        standardized[i][j] = (row[j] - mean) / deviation;
      });
    }
  }

  return standardized;
}

/**
 * Run a principal component analysis on the selected traits and dates.
 *
 * @param {Array<Object>} data - Features, as consumed by prepareMatrix.
 * @param {Object<string, Iterable>} selectedDatesForTraits - Trait technical
 *   name -> dates selected for that trait.
 * @param {*} fillMethod - How missing cells are handled (see MATRIX_FILL_METHODS).
 * @returns {?{components: Array<Object>, plotValues: Array<Object>, columns: Array<Array>, loadings: *}}
 *   `null` when the prepared matrix has no cells or fewer than two rows.
 */
export function performPCA(data, selectedDatesForTraits, fillMethod) {
  const { matrix, columns, rows } = prepareMatrix(
    data,
    selectedDatesForTraits,
    fillMethod
  );

  // Build the ml-matrix object first so its constructor checks the prepared
  // data (e.g. ragged rows) exactly as before.
  const input = new Matrix(matrix);
  // The sample standard deviation divides by n - 1, so fewer than two rows
  // (or no data at all) cannot be analysed.
  if (input.isEmpty() || input.rows < 2) return null;

  const standardized = new Matrix(standardizeColumns(matrix));
  const pca = new PCA(standardized);

  const eigenvectors = pca.getEigenvectors();
  const explainedVariance = pca.getExplainedVariance();
  // Project every plot onto the principal axes.
  const scores = standardized.mmul(eigenvectors);
  const plotValues = rows.map((plot, index) => ({
    ...plot,
    // Keep one value per component that has an explained variance.
    values: scores.getRow(index).slice(0, explainedVariance.length),
  }));

  const components = explainedVariance.map((variance, index) => ({
    id: index,
    name: `Comp. ${index + 1} (${truncateNumber(variance * 100, 1)}%)`,
    variance,
  }));

  const loadings = pca.getLoadings().data;

  return {
    components,
    plotValues,
    columns,
    loadings,
  };
}
