package it.neckar.open.math

import kotlin.js.JsExport
import kotlin.math.exp

/**
 * Gompertz sigmoid functions and simple gradient-descent curve fitting.
 *
 * Two parameterizations are provided:
 * - [gompertzFunction]: `a * exp(-exp(-b - c * (x - d)))`, taken from
 *   https://de.mathworks.com/matlabcentral/answers/1754060-sigmoid-function-shaping-and-fitting-by-curve-fitting-toolbox
 * - [wikiGompertz]: `a * exp(-b * exp(-c * x))`, taken from https://en.wikipedia.org/wiki/Gompertz_function
 *
 * The gradient functions return the gradient of the sum-of-squared-errors loss
 * `L = 1/2 * sum_i (f(x_i) - y_i)^2` with respect to the model parameters.
 * The fit functions run plain fixed-step gradient descent on that loss. There is
 * no convergence check, so exactly `maxIterations` steps are performed.
 */
@JsExport
object GompertzSigmoid {

  /**
   * Evaluates the Gompertz function `a * exp(-exp(-b - c * (x - d)))`.
   *
   * @param x the x value
   * @param a the scale; the function approaches `a` as x grows (for c > 0)
   * @param b offset inside the inner exponent
   * @param c the growth rate
   * @param d the shift along the x-axis
   */
  fun gompertzFunction(x: Double, a: Double, b: Double, c: Double, d: Double): Double {
    return a * exp(-exp(-b - c * (x - d)))
  }

  /**
   * Computes the gradient of `L = 1/2 * sum_i (gompertzFunction(x_i) - y_i)^2`
   * with respect to the parameters `[a, b, c, d]`.
   *
   * With `e = exp(-b - c * (x - d))` and `f = a * exp(-e)`, the partial derivatives are:
   * - `df/da = exp(-e)`
   * - `df/db = f * e`
   * - `df/dc = f * e * (x - d)`
   * - `df/dd = -f * e * c`
   *
   * @param parameters the current parameters `[a, b, c, d]`; extra entries are ignored
   * @param x the sample x values
   * @param y the sample y values, same size as [x]
   * @return the gradient `[dL/da, dL/db, dL/dc, dL/dd]`
   * @throws IllegalArgumentException if [parameters] has fewer than 4 entries or [x] and [y] differ in size
   */
  fun computeGompertzGradients(parameters: DoubleArray, x: DoubleArray, y: DoubleArray): DoubleArray {
    // Explicit checks: on Kotlin/JS an out-of-bounds read yields NaN instead of throwing,
    // and mismatched sizes would otherwise be silently truncated.
    require(parameters.size >= 4) { "Expected parameters [a, b, c, d] but got ${parameters.size} values" }
    require(x.size == y.size) { "x and y must have the same size (${x.size} != ${y.size})" }

    val a = parameters[0]
    val b = parameters[1]
    val c = parameters[2]
    val d = parameters[3]

    var gradientA = 0.0
    var gradientB = 0.0
    var gradientC = 0.0
    var gradientD = 0.0

    for (i in x.indices) {
      val shifted = x[i] - d
      val inner = exp(-b - c * shifted)
      // exp(-inner) is both the partial derivative w.r.t. a and f / a (computed directly so a == 0 is safe)
      val outer = exp(-inner)
      val predicted = a * outer
      val error = predicted - y[i]
      // Common factor f * e shared by the b, c and d partials
      val common = predicted * inner

      gradientA += error * outer
      gradientB += error * common
      gradientC += error * common * shifted
      gradientD -= error * common * c
    }

    return doubleArrayOf(gradientA, gradientB, gradientC, gradientD)
  }

  /**
   * Fits [gompertzFunction] to the samples using fixed-step gradient descent on the
   * squared-error loss (see [computeGompertzGradients]).
   *
   * @param x the sample x values
   * @param y the sample y values, same size as [x]
   * @param a initial guess for a
   * @param b initial guess for b
   * @param c initial guess for c
   * @param d initial guess for d
   * @param learningRate the step size applied to each gradient
   * @param maxIterations the number of descent steps to perform
   * @return the fitted parameters `[a, b, c, d]`
   */
  fun fitGompertz(x: DoubleArray, y: DoubleArray, a: Double, b: Double, c: Double, d: Double, learningRate: Double, maxIterations: Int): DoubleArray {
    return gradientDescent(doubleArrayOf(a, b, c, d), learningRate, maxIterations) { parameters ->
      computeGompertzGradients(parameters, x, y)
    }
  }

  /**
   * a*exp(-b*exp(-c*x))
   * taken from https://en.wikipedia.org/wiki/Gompertz_function
   * @param x the x value
   * @param a the asymptote
   * @param b the displacement on the x-axis
   * @param c the growth rate (y scaling)
   */
  fun wikiGompertz(x: Double, a: Double, b: Double, c: Double): Double {
    return a * exp(-b * exp(-c * x))
  }

  /**
   * Computes the gradient of `L = 1/2 * sum_i (wikiGompertz(x_i) - y_i)^2`
   * with respect to the parameters `[a, b, c]`.
   *
   * With `E = exp(-c * x)` and `g = a * exp(-b * E)`, the partial derivatives are:
   * - `dg/da = exp(-b * E)`
   * - `dg/db = -g * E`
   * - `dg/dc = g * b * x * E`
   *
   * @param parameters the current parameters `[a, b, c]`; extra entries are ignored
   * @param x the sample x values
   * @param y the sample y values, same size as [x]
   * @return the gradient `[dL/da, dL/db, dL/dc]`
   * @throws IllegalArgumentException if [parameters] has fewer than 3 entries or [x] and [y] differ in size
   */
  fun computeWikiGompertzGradients(parameters: DoubleArray, x: DoubleArray, y: DoubleArray): DoubleArray {
    require(parameters.size >= 3) { "Expected parameters [a, b, c] but got ${parameters.size} values" }
    require(x.size == y.size) { "x and y must have the same size (${x.size} != ${y.size})" }

    val a = parameters[0]
    val b = parameters[1]
    val c = parameters[2]

    var gradientA = 0.0
    var gradientB = 0.0
    var gradientC = 0.0

    for (i in x.indices) {
      val xValue = x[i]
      val decay = exp(-c * xValue)
      // exp(-b * E) is both the partial derivative w.r.t. a and g / a (computed directly so a == 0 is safe)
      val outer = exp(-b * decay)
      val predicted = a * outer
      val error = predicted - y[i]
      // Common factor g * E shared by the b and c partials
      val common = predicted * decay

      gradientA += error * outer
      gradientB -= error * common
      gradientC += error * common * b * xValue
    }

    return doubleArrayOf(gradientA, gradientB, gradientC)
  }

  /**
   * Fits [wikiGompertz] to the samples using fixed-step gradient descent on the
   * squared-error loss (see [computeWikiGompertzGradients]).
   *
   * @param x the sample x values
   * @param y the sample y values, same size as [x]
   * @param a initial guess for a
   * @param b initial guess for b
   * @param c initial guess for c
   * @param learningRate the step size applied to each gradient
   * @param maxIterations the number of descent steps to perform
   * @return the fitted parameters `[a, b, c]`
   */
  fun fitWikiGompertz(x: DoubleArray, y: DoubleArray, a: Double, b: Double, c: Double, learningRate: Double, maxIterations: Int): DoubleArray {
    return gradientDescent(doubleArrayOf(a, b, c), learningRate, maxIterations) { parameters ->
      computeWikiGompertzGradients(parameters, x, y)
    }
  }

  /**
   * Runs [maxIterations] fixed-step gradient-descent updates `p -= learningRate * gradient(p)`
   * on a copy of [initial] and returns the result. [gradient] must return one entry per parameter.
   */
  private fun gradientDescent(initial: DoubleArray, learningRate: Double, maxIterations: Int, gradient: (DoubleArray) -> DoubleArray): DoubleArray {
    val parameters = initial.copyOf()
    repeat(maxIterations) {
      // The full gradient is computed before any parameter is updated.
      val step = gradient(parameters)
      for (i in parameters.indices) {
        parameters[i] -= learningRate * step[i]
      }
    }
    return parameters
  }

}
