// StatisticalPaddingStrategy.java

package com.morphiqlabs.wavelet.padding;

import com.morphiqlabs.wavelet.exception.InvalidArgumentException;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import java.util.Random;

/**
 * Statistical padding strategy that extends signals based on statistical properties.
 *
 * <p>This strategy analyzes the signal's statistical characteristics and extends
 * it accordingly. Different methods are suitable for different signal types.
 * Ideal for:</p>
 * <ul>
 *   <li>Stochastic signals with known statistical properties</li>
 *   <li>Financial time series with mean-reverting behavior</li>
 *   <li>Signals where preserving statistical properties is important</li>
 *   <li>Noisy signals where simple extrapolation may amplify noise</li>
 * </ul>
 *
 * <p>Available statistical methods:</p>
 * <ul>
 *   <li><b>MEAN</b>: Pad with signal mean (good for stationary signals)</li>
 *   <li><b>MEDIAN</b>: Pad with signal median (robust to outliers)</li>
 *   <li><b>WEIGHTED_MEAN</b>: Values nearest the padded edge weighted higher (trending signals)</li>
 *   <li><b>TREND</b>: Linear trend extrapolation based on entire signal</li>
 *   <li><b>VARIANCE_MATCHED</b>: Pseudo-random values matching signal mean and variance;
 *       the generator is seeded from those statistics, so the same signal always
 *       produces the same padding</li>
 *   <li><b>LOCAL_MEAN</b>: Mean of the values nearest the padded edge (smooth transitions)</li>
 * </ul>
 *
 * <p>The input signal is never modified; {@link #pad(double[], int)} always returns a
 * newly allocated array.</p>
 *
 * @param method statistical method; must not be {@code null}
 * @param windowSize number of samples next to the padded edge used by
 *        {@code WEIGHTED_MEAN} and {@code LOCAL_MEAN}. Passing {@code 0} selects a default:
 *        10 for those two methods and {@link Integer#MAX_VALUE} (entire signal) for all
 *        others. The accessor returns the resolved value, never {@code 0}.
 * @param mode where to apply padding (LEFT/RIGHT/SYMMETRIC); must not be {@code null}
 */
public record StatisticalPaddingStrategy(
    StatMethod method,
    int windowSize,
    PaddingMode mode
) implements PaddingStrategy {

    /** Window applied to LOCAL_MEAN and WEIGHTED_MEAN when the caller passes 0. */
    private static final int DEFAULT_LOCAL_WINDOW = 10;

    /** Per-sample decay of the WEIGHTED_MEAN weights as distance from the padded edge grows. */
    private static final double WEIGHT_DECAY = 0.9;

    /** Below this magnitude the least-squares denominator is treated as zero. */
    private static final double DEGENERATE_FIT_EPSILON = 1e-10;

    /**
     * Mixed into the VARIANCE_MATCHED seed for left-side padding so that the two sides of a
     * SYMMETRIC pad do not receive the same pseudo-random sequence. The low 48 bits are
     * non-zero because {@link Random} ignores the upper 16 bits of its seed.
     */
    private static final long LEFT_SEED_SALT = 0x9E3779B97F4A7C15L;

    /**
     * Statistical method for padding.
     */
    public enum StatMethod {
        /** Pad with global mean of signal */
        MEAN,
        /** Pad with global median of signal */
        MEDIAN,
        /** Pad with weighted mean (values nearest the padded edge weighted higher) */
        WEIGHTED_MEAN,
        /** Extrapolate linear trend from entire signal */
        TREND,
        /** Pad with pseudo-random values matching signal mean and variance */
        VARIANCE_MATCHED,
        /** Pad with mean of nearby values */
        LOCAL_MEAN
    }

    /**
     * Padding mode determines where padding is applied.
     */
    public enum PaddingMode {
        /** Pad only on the right side */
        RIGHT,
        /** Pad on both sides; when the padding length is odd the right side gets the extra sample */
        SYMMETRIC,
        /** Pad only on the left side */
        LEFT
    }

    /**
     * Creates a statistical padding strategy with default MEAN method,
     * automatic window size and RIGHT padding.
     */
    public StatisticalPaddingStrategy() {
        this(StatMethod.MEAN, 0, PaddingMode.RIGHT);
    }

    /**
     * Creates a statistical padding strategy with specified method,
     * automatic window size and RIGHT padding.
     *
     * @param method the statistical method to use
     * @throws NullPointerException if {@code method} is {@code null}
     */
    public StatisticalPaddingStrategy(StatMethod method) {
        this(method, 0, PaddingMode.RIGHT);
    }

    /**
     * Validates parameters and resolves the automatic window size.
     *
     * @throws NullPointerException if {@code method} or {@code mode} is {@code null}
     * @throws InvalidArgumentException if {@code windowSize} is negative
     */
    public StatisticalPaddingStrategy {
        // Fail at construction instead of deep inside pad()/trim()/name().
        Objects.requireNonNull(method, "method");
        Objects.requireNonNull(mode, "mode");

        if (windowSize == 0) {
            // Auto-select window size based on method
            windowSize = switch (method) {
                case LOCAL_MEAN, WEIGHTED_MEAN -> DEFAULT_LOCAL_WINDOW;
                default -> Integer.MAX_VALUE; // Use entire signal
            };
        }

        if (windowSize < 1) {
            throw new InvalidArgumentException("Window size must be positive, got " + windowSize);
        }
    }

    /**
     * Extends {@code signal} to {@code targetLength} samples using the configured
     * statistical method and padding mode.
     *
     * @param signal the samples to extend; must be non-null and non-empty
     * @param targetLength the desired length; must be at least {@code signal.length}
     * @return a new array of length {@code targetLength} containing the original samples
     *         and the generated padding; a copy of {@code signal} when no padding is needed
     * @throws InvalidArgumentException if {@code signal} is null or empty, or
     *         {@code targetLength} is smaller than {@code signal.length}
     */
    @Override
    public double[] pad(double[] signal, int targetLength) {
        if (signal == null) {
            throw new InvalidArgumentException("Signal cannot be null");
        }
        if (signal.length == 0) {
            throw new InvalidArgumentException("Signal cannot be empty");
        }
        if (targetLength < signal.length) {
            throw new InvalidArgumentException(
                    "Target length " + targetLength + " must be >= signal length " + signal.length);
        }

        if (targetLength == signal.length) {
            return signal.clone();
        }

        double[] padded = new double[targetLength];
        int padLength = targetLength - signal.length;

        switch (mode) {
            case RIGHT -> {
                System.arraycopy(signal, 0, padded, 0, signal.length);
                applyStatisticalPadding(signal, padded, signal.length, padLength, true);
            }
            case LEFT -> {
                applyStatisticalPadding(signal, padded, 0, padLength, false);
                System.arraycopy(signal, 0, padded, padLength, signal.length);
            }
            case SYMMETRIC -> {
                // Odd padding puts the extra sample on the right; trim() relies on the same split.
                int leftPad = padLength / 2;
                int rightPad = padLength - leftPad;

                applyStatisticalPadding(signal, padded, 0, leftPad, false);
                System.arraycopy(signal, 0, padded, leftPad, signal.length);
                applyStatisticalPadding(signal, padded, leftPad + signal.length, rightPad, true);
            }
        }

        return padded;
    }

    /**
     * Fills {@code padded[startIdx, startIdx + length)} according to the configured method.
     *
     * @param signal the original samples the statistics are computed from
     * @param padded the output array being filled
     * @param startIdx first index of the region to fill
     * @param length number of samples to fill; nothing is computed when this is 0
     * @param rightSide {@code true} if the region follows the signal, {@code false} if it precedes it
     */
    private void applyStatisticalPadding(double[] signal, double[] padded,
                                         int startIdx, int length, boolean rightSide) {
        if (length == 0) {
            // Left half of a one-sample SYMMETRIC pad: skip the statistics entirely.
            return;
        }

        int endIdx = startIdx + length;
        switch (method) {
            case MEAN -> Arrays.fill(padded, startIdx, endIdx, calculateMean(signal));
            case MEDIAN -> Arrays.fill(padded, startIdx, endIdx, calculateMedian(signal));
            case WEIGHTED_MEAN ->
                    Arrays.fill(padded, startIdx, endIdx, calculateWeightedMean(signal, rightSide));
            case LOCAL_MEAN ->
                    Arrays.fill(padded, startIdx, endIdx, calculateLocalMean(signal, rightSide));
            case TREND -> applyTrendExtrapolation(signal, padded, startIdx, length, rightSide);
            case VARIANCE_MATCHED ->
                    applyVarianceMatchedPadding(signal, padded, startIdx, length, rightSide);
        }
    }

    /**
     * Calculate arithmetic mean of signal.
     */
    private static double calculateMean(double[] signal) {
        double sum = 0;
        for (double val : signal) {
            sum += val;
        }
        return sum / signal.length;
    }

    /**
     * Calculate median of signal (average of the two middle values for even lengths).
     * Works on a sorted copy so the caller's array keeps its order.
     */
    private static double calculateMedian(double[] signal) {
        double[] sorted = signal.clone();
        Arrays.sort(sorted);
        int n = sorted.length;
        if (n % 2 == 0) {
            return (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0;
        }
        return sorted[n / 2];
    }

    /**
     * Calculate weighted mean with exponential weighting over the window adjacent to the
     * padded edge. The sample touching the edge has weight 1 and each step further away
     * multiplies the weight by {@link #WEIGHT_DECAY}.
     *
     * @param signal the original samples
     * @param rightSide {@code true} to use the last samples, {@code false} to use the first
     */
    private double calculateWeightedMean(double[] signal, boolean rightSide) {
        int actualWindow = Math.min(windowSize, signal.length);
        double sumWeighted = 0;
        double sumWeights = 0;

        if (rightSide) {
            // Weight recent values (end of signal) more heavily
            int first = signal.length - actualWindow;
            for (int i = 0; i < actualWindow; i++) {
                double weight = Math.pow(WEIGHT_DECAY, actualWindow - i - 1);
                sumWeighted += signal[first + i] * weight;
                sumWeights += weight;
            }
        } else {
            // Weight early values (start of signal) more heavily
            for (int i = 0; i < actualWindow; i++) {
                double weight = Math.pow(WEIGHT_DECAY, i);
                sumWeighted += signal[i] * weight;
                sumWeights += weight;
            }
        }

        // The edge sample always contributes weight 1; the fallback is purely defensive.
        return sumWeights > 0 ? sumWeighted / sumWeights : calculateMean(signal);
    }

    /**
     * Calculate unweighted mean of the window adjacent to the padded edge.
     *
     * @param signal the original samples
     * @param rightSide {@code true} to use the last samples, {@code false} to use the first
     */
    private double calculateLocalMean(double[] signal, boolean rightSide) {
        int actualWindow = Math.min(windowSize, signal.length);
        int from = rightSide ? signal.length - actualWindow : 0;
        int to = from + actualWindow;

        double sum = 0;
        for (int i = from; i < to; i++) {
            sum += signal[i];
        }
        return sum / actualWindow;
    }

    /**
     * Apply trend-based extrapolation: fits a least-squares line through the whole signal
     * (sample index as x) and evaluates it at the positions the padding occupies.
     */
    private static void applyTrendExtrapolation(double[] signal, double[] padded,
                                                int startIdx, int length, boolean rightSide) {
        double[] fit = fitLinearTrend(signal);
        double intercept = fit[0];
        double slope = fit[1];

        // In signal coordinates the right pad starts at x = signal.length,
        // while the left pad covers x = -length .. -1.
        int firstX = rightSide ? signal.length : -length;
        for (int i = 0; i < length; i++) {
            padded[startIdx + i] = intercept + slope * (firstX + i);
        }
    }

    /**
     * Fit linear trend using least squares with x = 0, 1, ..., y.length - 1.
     *
     * @param y the samples to fit
     * @return {@code [intercept, slope]}; {@code [mean, 0]} when the fit is degenerate
     *         (for example a single sample)
     */
    private static double[] fitLinearTrend(double[] y) {
        int n = y.length;
        double sumX = 0;
        double sumY = 0;
        double sumXY = 0;
        double sumX2 = 0;

        for (int i = 0; i < n; i++) {
            double x = i; // sample index is the abscissa; no separate x array needed
            sumX += x;
            sumY += y[i];
            sumXY += x * y[i];
            sumX2 += x * x;
        }

        double denominator = n * sumX2 - sumX * sumX;
        if (Math.abs(denominator) < DEGENERATE_FIT_EPSILON) {
            // Degenerate case, return mean as constant
            return new double[]{sumY / n, 0};
        }

        double slope = (n * sumXY - sumX * sumY) / denominator;
        double intercept = (sumY - slope * sumX) / n;

        return new double[]{intercept, slope};
    }

    /**
     * Apply variance-matched padding with Gaussian pseudo-random values whose distribution
     * has the signal's mean and (population) standard deviation.
     *
     * <p>The generator is seeded from the signal's mean and variance so results are
     * reproducible. Left-side padding uses a salted seed so that a SYMMETRIC pad does not
     * repeat the same sequence on both sides.</p>
     */
    private static void applyVarianceMatchedPadding(double[] signal, double[] padded,
                                                    int startIdx, int length, boolean rightSide) {
        double mean = calculateMean(signal);
        double variance = calculateVariance(signal, mean);
        double stdDev = Math.sqrt(variance);

        // Use deterministic pseudo-random for reproducibility
        // Based on signal properties as seed
        long seed = Double.doubleToLongBits(mean) ^ Double.doubleToLongBits(variance);
        if (!rightSide) {
            seed ^= LEFT_SEED_SALT;
        }
        Random random = new Random(seed);

        for (int i = 0; i < length; i++) {
            padded[startIdx + i] = mean + stdDev * random.nextGaussian();
        }
    }

    /**
     * Calculate population variance of signal (divides by n, not n - 1).
     */
    private static double calculateVariance(double[] signal, double mean) {
        double sumSquaredDiff = 0;
        for (double val : signal) {
            double diff = val - mean;
            sumSquaredDiff += diff * diff;
        }
        return sumSquaredDiff / signal.length;
    }

    /**
     * Removes the padding added by {@link #pad(double[], int)}, using the same
     * left/right split for the configured mode.
     *
     * @param result the padded array
     * @param originalLength the length of the signal before padding
     * @return {@code result} itself (not a copy) when it already has {@code originalLength}
     *         samples; otherwise a new array holding the un-padded region
     * @throws NullPointerException if {@code result} is {@code null}
     * @throws InvalidArgumentException if {@code originalLength} is negative or exceeds
     *         {@code result.length}
     */
    @Override
    public double[] trim(double[] result, int originalLength) {
        Objects.requireNonNull(result, "result");
        if (originalLength < 0) {
            // Previously surfaced as a NegativeArraySizeException from the allocation below.
            throw new InvalidArgumentException(
                    "Original length must be non-negative, got " + originalLength);
        }
        if (result.length == originalLength) {
            return result;
        }
        if (originalLength > result.length) {
            throw new InvalidArgumentException(
                    "Original length " + originalLength + " exceeds result length " + result.length);
        }

        int totalPadding = result.length - originalLength;
        int offset = switch (mode) {
            case RIGHT -> 0;
            case LEFT -> totalPadding;
            case SYMMETRIC -> totalPadding / 2; // same split as pad()
        };

        double[] trimmed = new double[originalLength];
        System.arraycopy(result, offset, trimmed, 0, originalLength);
        return trimmed;
    }

    /**
     * Returns a stable identifier of the form {@code statistical-<method>-<mode>}.
     * Lower-casing uses {@link Locale#ROOT} so the identifier does not depend on the
     * default locale (for example the Turkish dotless i).
     */
    @Override
    public String name() {
        return String.format("statistical-%s-%s",
                method.name().toLowerCase(Locale.ROOT), mode.name().toLowerCase(Locale.ROOT));
    }

    /**
     * Returns a human-readable description of the configuration. A window that spans
     * the entire signal ({@link Integer#MAX_VALUE}) is reported as {@code -1}.
     */
    @Override
    public String description() {
        String methodDesc = switch (method) {
            case MEAN -> "global mean";
            case MEDIAN -> "global median";
            case WEIGHTED_MEAN -> "weighted mean";
            case TREND -> "trend extrapolation";
            case VARIANCE_MATCHED -> "variance-matched random";
            case LOCAL_MEAN -> "local mean";
        };
        int reportedWindow = windowSize == Integer.MAX_VALUE ? -1 : windowSize;
        return String.format(Locale.ROOT, "Statistical padding (%s, window=%d, %s mode)",
                methodDesc, reportedWindow, mode.name().toLowerCase(Locale.ROOT));
    }
}