LinearExtrapolationStrategy.java
package com.morphiqlabs.wavelet.padding;
import com.morphiqlabs.wavelet.exception.InvalidArgumentException;
import java.util.Arrays;
import java.util.Locale;
/**
 * Linear extrapolation padding strategy.
 *
 * <p>Padding values continue the signal linearly from its outermost sample on each padded side.
 * The slope is the least-squares slope of the {@code fitPoints} samples nearest that edge. When
 * the signal is shorter than {@code fitPoints}, all available samples are used. A single-sample
 * signal therefore yields a slope of zero, i.e. constant padding.
 *
 * @param fitPoints number of points used for the linear fit (at least 2)
 * @param mode where the extrapolated values are placed relative to the original signal
 */
public record LinearExtrapolationStrategy(int fitPoints, PaddingMode mode) implements PaddingStrategy {

    /**
     * Padding mode determines where padding is applied.
     */
    public enum PaddingMode {
        /** Pad only on the right side */
        RIGHT,
        /** Pad on both sides; when the padding amount is odd, the extra sample goes on the right */
        SYMMETRIC,
        /** Pad only on the left side */
        LEFT
    }

    /**
     * Creates a linear extrapolation strategy with default parameters.
     * Uses 2 fit points and RIGHT padding mode.
     */
    public LinearExtrapolationStrategy() {
        this(2, PaddingMode.RIGHT);
    }

    /**
     * Creates a linear extrapolation strategy with specified fit points.
     * Uses RIGHT padding mode.
     *
     * @param fitPoints number of points to use for slope calculation (minimum 2)
     */
    public LinearExtrapolationStrategy(int fitPoints) {
        this(fitPoints, PaddingMode.RIGHT);
    }

    /**
     * Validates the record components.
     *
     * <p>{@code fitPoints} is not adjusted here. It is clamped to the signal length at padding time.
     *
     * @throws InvalidArgumentException if {@code fitPoints < 2} or {@code mode} is null
     */
    public LinearExtrapolationStrategy {
        if (fitPoints < 2) {
            throw new InvalidArgumentException("Fit points must be at least 2, got " + fitPoints);
        }
        // Fail fast here rather than with a bare NPE later in pad/trim/name.
        if (mode == null) {
            throw new InvalidArgumentException("Padding mode cannot be null");
        }
    }

    /**
     * Pads {@code signal} to {@code targetLength} by linear extrapolation at the edge(s) selected
     * by {@link #mode()}.
     *
     * @param signal the input signal; must be non-null and non-empty
     * @param targetLength desired output length; must be at least {@code signal.length}
     * @return a new array of length {@code targetLength}; a copy of {@code signal} when no padding
     *     is needed
     * @throws InvalidArgumentException if {@code signal} is null or empty, or
     *     {@code targetLength < signal.length}
     */
    @Override
    public double[] pad(double[] signal, int targetLength) {
        if (signal == null) {
            throw new InvalidArgumentException("Signal cannot be null");
        }
        if (signal.length == 0) {
            throw new InvalidArgumentException("Signal cannot be empty");
        }
        if (targetLength < signal.length) {
            throw new InvalidArgumentException(
                    "Target length " + targetLength + " must be >= signal length " + signal.length);
        }
        if (targetLength == signal.length) {
            return signal.clone();
        }

        int padLength = targetLength - signal.length;
        int leftPad = leftPadding(padLength);
        int rightPad = padLength - leftPad;
        // Use fewer fit points if the signal is too short to supply fitPoints samples.
        int actualFitPoints = Math.min(fitPoints, signal.length);

        double[] padded = new double[targetLength];
        System.arraycopy(signal, 0, padded, leftPad, signal.length);
        if (leftPad > 0) {
            extrapolateLeft(signal, padded, leftPad, actualFitPoints);
        }
        if (rightPad > 0) {
            extrapolateRight(signal, padded, leftPad + signal.length, rightPad, actualFitPoints);
        }
        return padded;
    }

    /**
     * Returns how many of {@code totalPadding} samples sit on the left side for this mode.
     * {@link #pad} and {@link #trim} both use this, so the two operations always agree.
     */
    private int leftPadding(int totalPadding) {
        return switch (mode) {
            case RIGHT -> 0;
            case LEFT -> totalPadding;
            case SYMMETRIC -> totalPadding / 2;
        };
    }

    /**
     * Fills {@code padded[0, count)} by extending the first sample backwards along the slope of
     * the leading {@code fitCount} samples.
     */
    private static void extrapolateLeft(double[] signal, double[] padded, int count, int fitCount) {
        double slope = calculateSlope(signal, 0, fitCount);
        double firstValue = signal[0];
        for (int i = 0; i < count; i++) {
            padded[count - 1 - i] = firstValue - slope * (i + 1);
        }
    }

    /**
     * Fills {@code padded[offset, offset + count)} by extending the last sample forwards along
     * the slope of the trailing {@code fitCount} samples.
     */
    private static void extrapolateRight(
            double[] signal, double[] padded, int offset, int count, int fitCount) {
        double slope = calculateSlope(signal, signal.length - fitCount, fitCount);
        double lastValue = signal[signal.length - 1];
        for (int i = 0; i < count; i++) {
            padded[offset + i] = lastValue + slope * (i + 1);
        }
    }

    /**
     * Calculates the least-squares slope of {@code signal[start, start + length)} against the
     * sample index.
     *
     * <p>This uses the two-pass centered form {@code Σ(x-x̄)(y-ȳ) / Σ(x-x̄)²}. That form avoids
     * the catastrophic cancellation of the one-pass {@code n·Σxy − Σx·Σy} formula when the
     * samples share a large offset.
     *
     * @param signal the signal data
     * @param start starting index for fitting
     * @param length number of points to fit
     * @return the fitted slope per sample, or 0 when fewer than 2 points are available
     */
    private static double calculateSlope(double[] signal, int start, int length) {
        if (length < 2) {
            // A single point does not determine a slope; fall back to constant extrapolation.
            return 0.0;
        }
        if (length == 2) {
            // Exact two-point slope.
            return signal[start + 1] - signal[start];
        }
        double xMean = (length - 1) / 2.0;
        double yMean = 0.0;
        for (int i = 0; i < length; i++) {
            yMean += signal[start + i];
        }
        yMean /= length;

        double sumXY = 0.0;
        double sumXX = 0.0;
        for (int i = 0; i < length; i++) {
            double dx = i - xMean;
            sumXY += dx * (signal[start + i] - yMean);
            sumXX += dx * dx;
        }
        // x values are distinct indices, so sumXX > 0 whenever length >= 2.
        return sumXY / sumXX;
    }

    /**
     * Removes the padding added by {@link #pad}, using the same left/right split for this mode.
     *
     * @param result a padded (possibly transformed) array
     * @param originalLength the length of the signal before padding
     * @return {@code result} itself when no trimming is needed, otherwise a new array of length
     *     {@code originalLength}
     * @throws InvalidArgumentException if {@code result} is null, {@code originalLength} is
     *     negative, or {@code originalLength > result.length}
     */
    @Override
    public double[] trim(double[] result, int originalLength) {
        if (result == null) {
            throw new InvalidArgumentException("Result cannot be null");
        }
        if (originalLength < 0) {
            throw new InvalidArgumentException(
                    "Original length cannot be negative, got " + originalLength);
        }
        if (result.length == originalLength) {
            return result;
        }
        if (originalLength > result.length) {
            throw new InvalidArgumentException(
                    "Original length " + originalLength + " exceeds result length " + result.length);
        }
        int start = leftPadding(result.length - originalLength);
        return Arrays.copyOfRange(result, start, start + originalLength);
    }

    /**
     * Returns a locale-independent identifier such as {@code linear-2-right}.
     */
    @Override
    public String name() {
        return String.format("linear-%d-%s", fitPoints, modeLabel());
    }

    @Override
    public String description() {
        return String.format("Linear extrapolation padding (%d fit points, %s mode) - extrapolates based on edge slopes",
                fitPoints, modeLabel());
    }

    /**
     * Lower-cases the mode name with {@link Locale#ROOT}. The default locale could otherwise
     * alter the text (e.g. Turkish dotless 'ı' in "rıght").
     */
    private String modeLabel() {
        return mode.name().toLowerCase(Locale.ROOT);
    }
}