import { DataFrame, FieldType } from '@grafana/data';

import { clone } from 'lodash';

import { Band, DataFrameOutlierIntervals, OutlierResults } from 'api/types';

import { thdmedian, thdnanmedian } from './trimmedQuantile';

/**
 * Converts a detector sensitivity into the MAD-score threshold above which a point is an outlier.
 *
 * The threshold moves along the square root of `sensitivity` from the strictest value
 * (sensitivity 0) down to the loosest one (sensitivity 1). Higher sensitivity therefore means a
 * lower threshold and more points flagged.
 *
 * TODO: scale this estimation better with sensitivity.
 *
 * @param sensitivity - presumably expected in [0, 1]; the value is not clamped, so a negative
 *   input yields NaN and an input above 1 yields a threshold below the loosest one.
 * @returns the z-score-like threshold to compare MAD scores against.
 */
function sensitivityToThreshold(sensitivity: number): number {
  // Standard-normal z-score with an upper-tail probability of about 1e-15 (percentile ~0.999999999999999).
  const strictestThreshold = 7.941444487;
  // Standard-normal z-score at percentile 0.80.
  const loosestThreshold = 0.841621234;

  const span = strictestThreshold - loosestThreshold;
  return strictestThreshold - span * Math.sqrt(sensitivity);
}

// Scale factor k that makes `MADk * MAD` a consistent estimate of the standard deviation when the
// data are normally distributed (k = 1 / Phi^-1(3/4) ≈ 1.4826, Phi = standard normal CDF).
// See https://en.wikipedia.org/wiki/Median_absolute_deviation
const MADk = 1.4826;

// A single data point as passed to the MAD helpers in this file; `undefined` entries are treated
// as missing (skipped when computing the medians and scored as NaN).
type NumericOrUndefined = number | undefined;

/**
 * Centre and one-sided spreads used by the double-MAD detector, in the form returned by
 * `MAD.computeDoubleMedian`.
 */
export interface MADDoubleMedianData {
  /** Median absolute deviation from `globalMedian` of the values at or below `globalMedian`. */
  lowerMedian: number;
  /** Median of all values (computed with `thdnanmedian`). */
  globalMedian: number;
  /** Median absolute deviation from `globalMedian` of the values at or above `globalMedian`. */
  upperMedian: number;
}

/** Double-median statistics together with the MAD scores that `MAD.calculateMAD` derived from them. */
export interface MADScoresData extends MADDoubleMedianData {
  /**
   * Scores indexed as `madScores[series][point]`: the absolute deviation from `globalMedian`
   * divided by `MADk` times the half-median on the point's side (0 when the point equals
   * `globalMedian`). NaN for missing or non-finite values, and where that division would give
   * Infinity.
   */
  madScores: number[][];
}

/** Input of `MAD.run`: the scores and medians plus the timestamps the scored points belong to. */
export interface MADData extends MADScoresData {
  /** Timestamps from the frame's time field; `madScores[s][t]` is the score at `timestamps[t]`. */
  timestamps: number[];
}

/**
 * Outlier detection based on the double median absolute deviation (double MAD).
 *
 * The detector computes:
 * - a robust centre, the global median;
 * - two robust spreads, one from the values at or below the centre and one from the values at
 *   or above it.
 *
 * Each point is scored by its distance from the centre divided by `MADk` times the spread on its
 * side. Points whose score exceeds a sensitivity-derived threshold are reported as outliers.
 */
export abstract class MAD {
  /**
   * Extracts the numeric series and the timestamps from an aligned data frame and scores every point.
   *
   * @param alignedDataFrame - frame whose `number` fields are the series to score and whose `time`
   *   field supplies the timestamps (if there are several `time` fields, the last one wins).
   * @param precalculatedDoubleMedian - medians to reuse. They are ignored, and recomputed from the
   *   frame, unless all three values are finite.
   * @returns the per-series MAD scores, the medians used to compute them, and the timestamps.
   * @throws Error if the medians must be recomputed and either half-median is zero.
   */
  static preprocess(alignedDataFrame: DataFrame, precalculatedDoubleMedian: MADDoubleMedianData): MADData {
    const rawData: number[][] = [];
    let timestamps = [];
    // Collect one row per numeric field; remember the time field's values as the timestamps.
    for (const field of alignedDataFrame.fields) {
      if (field.type === FieldType.number) {
        rawData.push(field.values);
      } else if (field.type === FieldType.time) {
        timestamps = field.values;
      }
    }

    // MAD - Median Absolute Deviation
    const { madScores, lowerMedian, globalMedian, upperMedian } = MAD.calculateMAD(rawData, precalculatedDoubleMedian);
    return { madScores, lowerMedian, globalMedian, upperMedian, timestamps };
  }

  /**
   * Computes the centre and the two one-sided spreads of the double-MAD detector.
   *
   * @param X - values indexed as `X[series][point]`. `undefined`, `null`, NaN and ±Infinity
   *   entries are treated as missing.
   * @returns
   *   - `globalMedian`: the median of all values (via `thdnanmedian`).
   *   - `lowerMedian`: the median absolute deviation from `globalMedian` of the finite values at
   *     or below it.
   *   - `upperMedian`: the same for the finite values at or above it.
   *   Values equal to `globalMedian` count toward both halves.
   * @throws Error when either half-median is zero, because scoring would divide by zero.
   */
  static computeDoubleMedian(X: NumericOrUndefined[][]): MADDoubleMedianData {
    const globalMedian = thdnanmedian(X.flat());

    // Split the absolute deviations from the global median into a lower and an upper half.
    // Only finite values may contribute. `Number.isFinite` rejects undefined, null, NaN and
    // ±Infinity alike. Checking for `undefined` alone let null (which compares as 0) and ±Infinity
    // through, and that pushed NaN deviations into `thdmedian`.
    const absDeviationsLower: number[] = [];
    const absDeviationsUpper: number[] = [];
    for (const series of X) {
      for (let j = 0; j < series.length; j++) {
        const x = series[j];
        if (x === undefined || !Number.isFinite(x)) {
          continue;
        }
        const absDeviation = Math.abs(x - globalMedian);
        if (x <= globalMedian) {
          absDeviationsLower.push(absDeviation);
        }
        if (x >= globalMedian) {
          absDeviationsUpper.push(absDeviation);
        }
      }
    }

    const lowerMedian = thdmedian(absDeviationsLower);
    const upperMedian = thdmedian(absDeviationsUpper);

    if (lowerMedian === 0 || upperMedian === 0) {
      throw new Error(
        'MAD Detector: Unavoidable divide by zero encountered (the median of the absolute ' +
          'difference of the data from its median is zero, usually happens when a majority ' +
          'of the data points are identical)'
      );
    }
    return { lowerMedian, globalMedian, upperMedian };
  }

  /**
   * Scores every point by its distance from the global median, measured in one-sided spreads.
   *
   * @param X - values indexed as `X[series][point]`; missing or non-finite entries get score NaN.
   * @param precalculatedMedianData - medians to reuse. They are used only when all three values are
   *   finite; otherwise the medians are recomputed from `X` with `computeDoubleMedian`.
   * @returns the scores (same shape as `X`) and the medians that produced them.
   * @throws Error if the medians are recomputed and either half-median is zero.
   */
  static calculateMAD(X: NumericOrUndefined[][], precalculatedMedianData?: MADDoubleMedianData): MADScoresData {
    const { lowerMedian, globalMedian, upperMedian } =
      precalculatedMedianData &&
      Number.isFinite(precalculatedMedianData.lowerMedian) &&
      Number.isFinite(precalculatedMedianData.globalMedian) &&
      Number.isFinite(precalculatedMedianData.upperMedian)
        ? precalculatedMedianData
        : MAD.computeDoubleMedian(X);

    const madScores: number[][] = new Array(X.length);
    for (let i = 0; i < X.length; i++) {
      const series = X[i];
      const scores: number[] = new Array(series.length);
      for (let j = 0; j < series.length; j++) {
        const x = series[j];
        let score: number;
        if (x === undefined || !Number.isFinite(x)) {
          // Missing, null, NaN or ±Infinity values get no score.
          score = NaN;
        } else if (x < globalMedian) {
          score = Math.abs(x - globalMedian) / (MADk * lowerMedian);
        } else if (x === globalMedian) {
          score = 0;
        } else {
          score = Math.abs(x - globalMedian) / (MADk * upperMedian);
        }
        // A zero half-median (possible with precalculated medians) divides by zero; report NaN
        // instead of Infinity.
        scores[j] = score === Infinity ? NaN : score;
      }
      madScores[i] = scores;
    }

    return { madScores, lowerMedian, globalMedian, upperMedian };
  }

  /**
   * Flags outliers in preprocessed data and builds the "normal" band.
   *
   * @param data - scores, medians and timestamps, as returned by `MAD.preprocess`.
   * @param sensitivity - converted to a MAD-score threshold by `sensitivityToThreshold`. Higher
   *   sensitivity gives a lower threshold (presumably expected in [0, 1]).
   * @returns
   *   - `normalBand`: `[lowerLimits, upperLimits]`, each with one constant entry per timestamp.
   *   - `outlierIntervals`: for each series that was ever an outlier, a flat list
   *     `[start, end, start, end, ...]`. `start` is the first timestamp whose score exceeds the
   *     threshold; `end` is the first later timestamp whose score does not (a NaN score also
   *     ends an interval). An interval still open at the last timestamp has no `end`, so a list
   *     can have odd length. NOTE(review): confirm consumers expect that.
   */
  static run(data: MADData, sensitivity: number): OutlierResults {
    const { madScores, lowerMedian, globalMedian, upperMedian, timestamps } = data;
    const threshold = sensitivityToThreshold(sensitivity);

    const timestampCount = timestamps?.length ?? 0;
    const seriesCount = madScores.length;

    // The limits are the values whose MAD score equals `threshold` on each side of the median.
    const upperLimit = globalMedian + MADk * upperMedian * threshold;
    const lowerLimit = globalMedian - MADk * lowerMedian * threshold;

    // Main "normal" boundary = region encompassing the data considered normal at this threshold
    const normalBand: Band = [
      Array(timestampCount).fill(lowerLimit), // min
      Array(timestampCount).fill(upperLimit), // max
    ];
    const outlierIntervals: DataFrameOutlierIntervals = {}; // {index: [ts_start, ts_end, ts_start, ts_end, ...]}
    let outliersSoFar: number[] = [];

    for (let t = 0; t < timestampCount; t++) {
      const ts = timestamps[t] as number;

      // Series whose score at this timestamp exceeds the threshold (NaN scores compare false).
      const outliers: number[] = [];
      for (let s = 0; s < seriesCount; s++) {
        if (madScores[s]![t]! > threshold) {
          outliers.push(s);
        }
      }

      // Close the interval of every series that was an outlier at the previous timestamp but is
      // not anymore. When `outliers` is empty, this closes every open interval.
      for (const stoppedIndex of outliersSoFar) {
        if (!outliers.includes(stoppedIndex)) {
          outlierIntervals[stoppedIndex]!.push(ts);
        }
      }

      // Open an interval for every series that has just become an outlier.
      for (const startedIndex of outliers) {
        if (outliersSoFar.includes(startedIndex)) {
          continue;
        }
        if (startedIndex in outlierIntervals) {
          outlierIntervals[startedIndex]!.push(ts);
        } else {
          outlierIntervals[startedIndex] = [ts];
        }
      }

      outliersSoFar = clone(outliers);
    }

    return { outlierIntervals, normalBand };
  }
}
