WaveletDenoiser.java

package com.morphiqlabs.wavelet.denoising;

import com.morphiqlabs.wavelet.modwt.MODWTResult;
import com.morphiqlabs.wavelet.modwt.MODWTTransform;
import com.morphiqlabs.wavelet.modwt.MultiLevelMODWTResult;
import com.morphiqlabs.wavelet.modwt.MultiLevelMODWTTransform;
import com.morphiqlabs.wavelet.api.BoundaryMode;
import com.morphiqlabs.wavelet.api.Wavelet;
import com.morphiqlabs.wavelet.exception.InvalidArgumentException;
import com.morphiqlabs.wavelet.exception.InvalidSignalException;
import com.morphiqlabs.wavelet.exception.InvalidStateException;
import com.morphiqlabs.wavelet.exception.ErrorCode;
import com.morphiqlabs.wavelet.exception.ErrorContext;
import com.morphiqlabs.wavelet.WaveletOperations;

/**
 * Wavelet-based signal denoising using various thresholding strategies.
 *
 * <p>This class provides comprehensive denoising capabilities for signals
 * corrupted by noise, particularly effective for financial time series data
 * where preserving important features while removing noise is critical.</p>
 *
 * <p>Now uses MODWT (Maximal Overlap Discrete Wavelet Transform) which provides:</p>
 * <ul>
 *   <li>Shift-invariant denoising (better for time series)</li>
 *   <li>Works with any signal length (not just power-of-2)</li>
 *   <li>Same-length coefficients preserve temporal alignment</li>
 *   <li>Multiple threshold selection methods (Universal, SURE, Minimax)</li>
 *   <li>Soft and hard thresholding</li>
 *   <li>Level-dependent thresholding</li>
 *   <li>Multi-level decomposition support</li>
 *   <li>SIMD-optimized thresholding operations</li>
 * </ul>
 *
 * <p>Example usage:</p>
 * <pre>{@code
 * WaveletDenoiser denoiser = new WaveletDenoiser(Daubechies.DB4, BoundaryMode.PERIODIC);
 * double[] noisySignal = ...;
 * double[] denoised = denoiser.denoise(noisySignal, ThresholdMethod.UNIVERSAL);
 * }</pre>
 * 
 * @since 1.0.0
 */
public class WaveletDenoiser {
    
    /**
     * Largest number of decomposition levels accepted by multi-level denoising.
     *
     * <p>{@code DenoisedMultiLevelResult} rejects a decomposition with more levels
     * than this by throwing an {@link InvalidArgumentException} before any
     * level-dependent threshold is computed.</p>
     *
     * <p>In practice, wavelet decomposition rarely exceeds 10-15 levels due to
     * signal length constraints and numerical stability, so this limit is far
     * above typical use.</p>
     */
    private static final int MAX_SAFE_LEVEL_FOR_SCALING = 31;
    
    /**
     * Small positive value that keeps the BayesShrink denominator away from zero.
     *
     * <p>In {@code calculateBayesThreshold} it is added to the estimated signal
     * variance before the square root is taken, so the resulting sigma_x used in
     * T = sigma^2 / sigma_x is never smaller than sqrt(1e-10) = 1e-5.</p>
     */
    private static final double BAYES_EPSILON = 1e-10;

    /** Wavelet passed to every transform created by this denoiser. */
    private final Wavelet wavelet;
    /** Boundary handling mode passed to every transform created by this denoiser. */
    private final BoundaryMode boundaryMode;
    /**
     * Whether {@code applyThreshold} delegates to {@code WaveletOperations};
     * read once from {@code WaveletOperations.getPerformanceInfo()} at construction.
     */
    private final boolean useVectorOps;

    /**
     * Builds a denoiser bound to a wavelet and a boundary handling mode.
     *
     * <p>Whether the vectorized thresholding path is used is decided once here, from
     * {@code WaveletOperations.getPerformanceInfo().vectorizationEnabled()}.</p>
     *
     * @param wavelet      wavelet used for every decomposition performed by this instance
     * @param boundaryMode boundary handling mode passed to the transforms
     * @throws InvalidArgumentException if {@code wavelet} or {@code boundaryMode} is {@code null}
     */
    public WaveletDenoiser(Wavelet wavelet, BoundaryMode boundaryMode) {
        // Arguments are checked in declaration order, so a missing wavelet is reported first.
        if (null == wavelet) {
            throw InvalidArgumentException.nullArgument("wavelet");
        } else if (null == boundaryMode) {
            throw InvalidArgumentException.nullArgument("boundaryMode");
        }

        this.useVectorOps = WaveletOperations.getPerformanceInfo().vectorizationEnabled();
        this.boundaryMode = boundaryMode;
        this.wavelet = wavelet;
    }

    /**
     * Creates a denoiser preconfigured for financial time series: the DB4 wavelet
     * with periodic boundary handling.
     *
     * <p>This factory does not fix a threshold method; the method is chosen on each
     * call to one of the {@code denoise} methods.</p>
     *
     * @return a new WaveletDenoiser instance
     */
    public static WaveletDenoiser forFinancialData() {
        Wavelet db4 = com.morphiqlabs.wavelet.api.Daubechies.DB4;
        return new WaveletDenoiser(db4, BoundaryMode.PERIODIC);
    }

    /**
     * Single-level denoising with soft thresholding.
     *
     * <p>Convenience overload that forwards to
     * {@link #denoise(double[], ThresholdMethod, ThresholdType)} with
     * {@link ThresholdType#SOFT}.</p>
     *
     * @param signal the noisy signal to denoise
     * @param method the threshold selection method
     * @return the denoised signal
     * @throws InvalidSignalException if signal is invalid
     */
    public double[] denoise(double[] signal, ThresholdMethod method) {
        final ThresholdType defaultType = ThresholdType.SOFT;
        return denoise(signal, method, defaultType);
    }

    /**
     * Single-level denoising: forward MODWT, threshold the detail coefficients,
     * inverse MODWT.
     *
     * <p>The noise level is estimated from the detail coefficients, the threshold is
     * derived from that estimate with the requested method, and the approximation
     * coefficients are passed to the inverse transform unchanged.</p>
     *
     * @param signal the noisy signal to denoise
     * @param method the threshold selection method
     * @param type   the thresholding type (soft or hard)
     * @return the denoised signal
     * @throws InvalidSignalException   if signal is invalid
     * @throws InvalidArgumentException if {@code method} is {@link ThresholdMethod#FIXED}
     */
    public double[] denoise(double[] signal, ThresholdMethod method, ThresholdType type) {
        MODWTTransform modwt = new MODWTTransform(wavelet, boundaryMode);
        MODWTResult decomposition = modwt.forward(signal);

        // Noise estimate first; the threshold depends on it.
        double noiseSigma = estimateNoiseSigma(decomposition.detailCoeffs());
        double cutoff = calculateThreshold(decomposition.detailCoeffs(), noiseSigma, method);

        // Only the detail band is shrunk.
        double[] shrunkDetails = applyThreshold(decomposition.detailCoeffs(), cutoff, type);

        return modwt.inverse(
                MODWTResult.create(decomposition.approximationCoeffs(), shrunkDetails));
    }

    /**
     * Multi-level denoising with a separate threshold for each decomposition level.
     *
     * <p>The signal is decomposed with a multi-level MODWT, the noise level is
     * estimated from the level-1 (finest) detail coefficients, and the detail
     * coefficients of every level are thresholded by {@code DenoisedMultiLevelResult}
     * before reconstruction.</p>
     *
     * @param signal the noisy signal to denoise
     * @param levels the number of decomposition levels
     * @param method the threshold selection method
     * @param type   the thresholding type
     * @return the denoised signal
     * @throws InvalidSignalException   if signal is invalid
     * @throws InvalidArgumentException if levels is invalid
     */
    public double[] denoiseMultiLevel(double[] signal, int levels,
                                      ThresholdMethod method, ThresholdType type) {
        MultiLevelMODWTTransform modwt = new MultiLevelMODWTTransform(wavelet, boundaryMode);
        MultiLevelMODWTResult decomposition = modwt.decompose(signal, levels);

        // The finest scale is used for the noise estimate.
        double[] finestDetails = decomposition.getDetailCoeffsAtLevel(1);
        double noiseSigma = estimateNoiseSigma(finestDetails);

        return modwt.reconstruct(
                new DenoisedMultiLevelResult(decomposition, noiseSigma, method, type));
    }
    
    /**
     * View over a multi-level MODWT decomposition whose detail coefficients have
     * been thresholded.
     *
     * <p>All detail levels are thresholded once, in the constructor, using the
     * enclosing denoiser's {@code calculateThreshold} and {@code applyThreshold}.
     * Approximation coefficients are taken from the wrapped result unchanged.
     * Energy queries are computed from the thresholded details.</p>
     */
    private class DenoisedMultiLevelResult implements MultiLevelMODWTResult {
        private final MultiLevelMODWTResult original;
        private final double sigma;
        private final ThresholdMethod method;
        private final ThresholdType type;
        /** Thresholded detail coefficients; index {@code level - 1}. */
        private final double[][] denoisedDetails;

        /**
         * Thresholds every detail level of {@code original}.
         *
         * @param original the decomposition to denoise
         * @param sigma    noise standard deviation estimated at the finest level
         * @param method   threshold selection method
         * @param type     thresholding type
         * @throws InvalidArgumentException if the decomposition has more than
         *                                  {@code MAX_SAFE_LEVEL_FOR_SCALING} levels
         */
        DenoisedMultiLevelResult(MultiLevelMODWTResult original, double sigma,
                                ThresholdMethod method, ThresholdType type) {
            this.original = original;
            this.sigma = sigma;
            this.method = method;
            this.type = type;

            final int levels = original.getLevels();
            this.denoisedDetails = new double[levels][];

            // Reject excessive depth before doing any per-level work.
            if (levels > MAX_SAFE_LEVEL_FOR_SCALING) {
                throw new InvalidArgumentException(
                    ErrorCode.VAL_TOO_LARGE,
                    ErrorContext.builder("Decomposition level exceeds safe limit for scale-dependent thresholds")
                        .withContext("Operation", "Multi-level denoising")
                        .withLevelInfo(levels, MAX_SAFE_LEVEL_FOR_SCALING)
                        .withContext("Threshold method", method.name())
                        .withContext("Threshold type", type.name())
                        .withSuggestion("Reduce decomposition levels to " + MAX_SAFE_LEVEL_FOR_SCALING + " or less")
                        .withSuggestion("Use level-independent threshold methods")
                        .build()
                );
            }

            for (int level = 1; level <= levels; level++) {
                double[] levelDetails = original.getDetailCoeffsAtLevel(level);

                // With textbook MODWT scaling h_j,l = h_l / 2^(j/2), noise std at level j
                // scales by 1/sqrt(2^j). The power of two is formed in long arithmetic:
                // an int shift of 1 << 31 wraps to a negative value, whose square root is
                // NaN, and level 31 is permitted by the guard above.
                double levelScale = Math.sqrt((double) (1L << level));
                double threshold = calculateThreshold(levelDetails, sigma / levelScale, method);

                denoisedDetails[level - 1] = applyThreshold(levelDetails, threshold, type);
            }
        }

        @Override
        public int getSignalLength() {
            return original.getSignalLength();
        }

        @Override
        public int getLevels() {
            return original.getLevels();
        }

        /** Returns the wrapped result's approximation coefficients (not denoised). */
        @Override
        public double[] getApproximationCoeffs() {
            return original.getApproximationCoeffs();
        }

        /**
         * Returns a copy of the thresholded detail coefficients at {@code level}.
         *
         * @throws InvalidArgumentException if {@code level} is outside {@code 1..getLevels()}
         */
        @Override
        public double[] getDetailCoeffsAtLevel(int level) {
            requireValidLevel(level,
                    "Invalid level for detail coefficient access", "getDetailCoeffsAtLevel");
            return denoisedDetails[level - 1].clone();
        }

        /**
         * Returns the sum of squares of the thresholded detail coefficients at {@code level}.
         *
         * @throws InvalidArgumentException if {@code level} is outside {@code 1..getLevels()}
         */
        @Override
        public double getDetailEnergyAtLevel(int level) {
            requireValidLevel(level,
                    "Invalid level for detail energy calculation", "getDetailEnergyAtLevel");
            return sumOfSquares(denoisedDetails[level - 1]);
        }

        /** Returns the sum of squares of the approximation coefficients. */
        @Override
        public double getApproximationEnergy() {
            return sumOfSquares(getApproximationCoeffs());
        }

        /** Returns approximation energy plus the energy of every thresholded detail level. */
        @Override
        public double getTotalEnergy() {
            double energy = getApproximationEnergy();
            for (int level = 1; level <= getLevels(); level++) {
                energy += getDetailEnergyAtLevel(level);
            }
            return energy;
        }

        /**
         * Returns each band's share of the total energy of the denoised coefficients.
         *
         * <p>Index 0 is the approximation band and index {@code j} is detail level
         * {@code j}. If the total energy is zero the entries are {@code NaN} (0/0).</p>
         */
        @Override
        public double[] getRelativeEnergyDistribution() {
            final int levels = getLevels();
            double[] distribution = new double[levels + 1];

            // Collect raw energies once, accumulating the total in the same order as
            // getTotalEnergy(), then normalise in place.
            distribution[0] = getApproximationEnergy();
            double totalEnergy = distribution[0];
            for (int level = 1; level <= levels; level++) {
                distribution[level] = sumOfSquares(denoisedDetails[level - 1]);
                totalEnergy += distribution[level];
            }
            for (int band = 0; band <= levels; band++) {
                distribution[band] /= totalEnergy;
            }

            return distribution;
        }

        @Override
        public boolean isValid() {
            return original.isValid();
        }

        /** Returns a new wrapper over a copy of the wrapped result, with the same parameters. */
        @Override
        public MultiLevelMODWTResult copy() {
            return new DenoisedMultiLevelResult(original.copy(), sigma, method, type);
        }

        /** Throws if {@code level} is not within {@code 1..getLevels()}. */
        private void requireValidLevel(int level, String message, String operation) {
            if (level < 1 || level > getLevels()) {
                throw new InvalidArgumentException(
                    ErrorCode.CFG_INVALID_DECOMPOSITION_LEVEL,
                    ErrorContext.builder(message)
                        .withContext("Operation", operation)
                        .withLevelInfo(level, getLevels())
                        .withContext("Result type", "DenoisedMultiLevelResult")
                        .withSuggestion("Level must be between 1 and " + getLevels())
                        .build()
                );
            }
        }

        /** Returns the sum of the squared entries of {@code values}. */
        private double sumOfSquares(double[] values) {
            double energy = 0.0;
            for (double val : values) {
                energy += val * val;
            }
            return energy;
        }
    }

    /**
     * Single-level denoising with a caller-supplied threshold.
     *
     * <p>The detail coefficients of a single-level MODWT are thresholded with
     * {@code threshold} as given; no noise estimation takes place. The approximation
     * coefficients are passed to the inverse transform unchanged.</p>
     *
     * @param signal    the noisy signal
     * @param threshold the threshold value
     * @param type      the thresholding type
     * @return the denoised signal
     */
    public double[] denoiseFixed(double[] signal, double threshold, ThresholdType type) {
        MODWTTransform modwt = new MODWTTransform(wavelet, boundaryMode);
        MODWTResult decomposition = modwt.forward(signal);

        double[] shrunkDetails = applyThreshold(decomposition.detailCoeffs(), threshold, type);

        return modwt.inverse(
                MODWTResult.create(decomposition.approximationCoeffs(), shrunkDetails));
    }

    /**
     * Estimates the noise standard deviation from detail coefficients.
     *
     * <p>The estimate is the median of the absolute coefficient values divided by
     * 0.6745, the scale factor for Gaussian noise. The input array is not modified.</p>
     *
     * @param detailCoeffs the detail coefficients
     * @return estimated noise standard deviation
     */
    protected double estimateNoiseSigma(double[] detailCoeffs) {
        final int count = detailCoeffs.length;

        // Magnitudes go into a scratch array so the caller's data stays untouched.
        double[] magnitudes = new double[count];
        int idx = 0;
        for (double coeff : detailCoeffs) {
            magnitudes[idx++] = Math.abs(coeff);
        }

        double medianMagnitude = calculateMedian(magnitudes);
        return medianMagnitude / 0.6745;
    }

    /**
     * Derives a threshold from the coefficients and the noise level.
     *
     * <p>{@code UNIVERSAL} uses {@code sigma * sqrt(2 * ln(n))} with {@code n} the
     * number of coefficients; {@code SURE}, {@code MINIMAX} and {@code BAYES} delegate
     * to their dedicated helpers. {@code FIXED} carries no rule to evaluate and is
     * rejected.</p>
     *
     * @param coeffs the wavelet coefficients
     * @param sigma the noise standard deviation
     * @param method the threshold selection method
     * @return the calculated threshold value
     * @throws InvalidArgumentException if {@code method} is {@code FIXED} or not recognised
     */
    protected double calculateThreshold(double[] coeffs, double sigma, ThresholdMethod method) {
        final int count = coeffs.length;
        final double threshold;

        switch (method) {
            case UNIVERSAL:
                // VisuShrink
                threshold = sigma * Math.sqrt(2.0 * Math.log(count));
                break;

            case SURE:
                threshold = calculateSUREThreshold(coeffs, sigma);
                break;

            case MINIMAX:
                threshold = calculateMinimaxThreshold(count, sigma);
                break;

            case BAYES:
                threshold = calculateBayesThreshold(coeffs, sigma);
                break;

            case FIXED:
                // A fixed threshold must be supplied by the caller, not computed here.
                throw new InvalidArgumentException(
                    ErrorCode.CFG_UNSUPPORTED_OPERATION,
                    ErrorContext.builder("Fixed threshold method requires explicit threshold value")
                        .withContext("Operation", "denoise")
                        .withContext("Threshold method", "FIXED")
                        .withSuggestion("Use denoiseFixed() method with explicit threshold value")
                        .withSuggestion("Or select an automatic threshold method: UNIVERSAL, SURE, MINIMAX, or BAYES")
                        .build()
                );

            default:
                throw new InvalidArgumentException(
                    ErrorCode.CFG_UNSUPPORTED_OPERATION,
                    ErrorContext.builder("Unknown threshold selection method")
                        .withContext("Operation", "selectThreshold")
                        .withContext("Unknown method", method.toString())
                        .withSuggestion("Supported methods: UNIVERSAL, SURE, MINIMAX, BAYES, FIXED")
                        .build()
                );
        }

        return threshold;
    }

    /**
     * Selects the threshold that minimises Stein's Unbiased Risk Estimate (SURE)
     * for soft thresholding.
     *
     * <p>Candidates are the distinct absolute values of the coefficients. For a
     * candidate {@code t} the estimate is
     * {@code n*sigma^2 - 2*sigma^2*#{|c_i| <= t} + sum(min(c_i^2, t^2))}, divided by
     * {@code n}. The candidates are visited in ascending order with a running sum of
     * squares, so the search costs one sort plus one linear pass. The result is
     * capped at the universal threshold {@code sigma * sqrt(2 * ln(n))}.</p>
     *
     * @param coeffs the wavelet coefficients (not modified)
     * @param sigma  the noise standard deviation
     * @return the selected threshold; 0 if no candidate has a finite comparable risk
     */
    private double calculateSUREThreshold(double[] coeffs, double sigma) {
        int n = coeffs.length;

        // Sort coefficients by absolute value
        double[] sortedAbs = new double[n];
        for (int i = 0; i < n; i++) {
            sortedAbs[i] = Math.abs(coeffs[i]);
        }
        java.util.Arrays.sort(sortedAbs);

        double minRisk = Double.POSITIVE_INFINITY;
        double bestThreshold = 0;
        double sumSquaresAtOrBelow = 0.0; // sum of sortedAbs[0..k]^2

        for (int k = 0; k < n; k++) {
            double t = sortedAbs[k];
            sumSquaresAtOrBelow += t * t;

            // Equal magnitudes are the same candidate. Evaluate it at the last entry
            // of the run, where k + 1 equals the number of coefficients <= t.
            if (k + 1 < n && sortedAbs[k + 1] == t) {
                continue;
            }

            double risk = calculateSURERisk(n, k + 1, sumSquaresAtOrBelow, t, sigma);
            if (risk < minRisk) {
                minRisk = risk;
                bestThreshold = t;
            }
        }

        // Never exceed the universal threshold
        double universalThreshold = sigma * Math.sqrt(2.0 * Math.log(n));
        if (bestThreshold > universalThreshold) {
            bestThreshold = universalThreshold;
        }

        return bestThreshold;
    }

    /**
     * Evaluates the per-coefficient SURE value for soft thresholding at one candidate.
     *
     * @param n                   total number of coefficients
     * @param countAtOrBelow      number of coefficients with magnitude {@code <= threshold}
     * @param sumSquaresAtOrBelow sum of squares of those coefficients
     * @param threshold           the candidate threshold
     * @param sigma               the noise standard deviation
     * @return the risk estimate divided by {@code n}
     */
    private double calculateSURERisk(int n, int countAtOrBelow, double sumSquaresAtOrBelow,
                                     double threshold, double sigma) {
        double sigma2 = sigma * sigma;

        // sum(min(c_i^2, t^2)): coefficients above the threshold each contribute t^2.
        double clippedEnergy = sumSquaresAtOrBelow
                + (double) (n - countAtOrBelow) * threshold * threshold;

        double risk = n * sigma2 - 2.0 * sigma2 * countAtOrBelow + clippedEnergy;
        return risk / n;
    }

    /**
     * Calculates the minimax threshold using the standard closed-form approximation.
     *
     * <p>For more than 32 coefficients the threshold is
     * {@code sigma * (0.3936 + 0.1829 * log2(n))}; for 32 or fewer it is 0. The
     * logarithm is base 2: the constants are fitted to the tabulated minimax
     * thresholds in that base (for example about 2.22 at {@code n = 1024}).</p>
     *
     * @param n     the number of coefficients
     * @param sigma the noise standard deviation
     * @return the minimax threshold
     */
    private double calculateMinimaxThreshold(int n, double sigma) {
        // Below this size the approximation is defined to apply no thresholding.
        if (n <= 32) {
            return 0;
        }

        double log2N = Math.log(n) / Math.log(2.0);
        return sigma * (0.3936 + 0.1829 * log2N);
    }

    /**
     * Computes the BayesShrink threshold T = sigma^2 / sigma_x.
     *
     * <p>sigma_x is the estimated signal standard deviation,
     * {@code sqrt(max(0, sigma_y^2 - sigma^2) + BAYES_EPSILON)}, where sigma_y^2 is
     * the population variance of the coefficients about their mean. The epsilon keeps
     * the denominator positive when the coefficients carry no more variance than the
     * noise.</p>
     *
     * @param coeffs the wavelet coefficients
     * @param sigma the noise standard deviation
     * @return the BayesShrink threshold
     */
    private double calculateBayesThreshold(double[] coeffs, double sigma) {
        final int count = coeffs.length;
        final double noiseVariance = sigma * sigma;

        // First pass: mean of the coefficients.
        double sum = 0.0;
        for (int i = 0; i < count; i++) {
            sum += coeffs[i];
        }
        final double mean = sum / count;

        // Second pass: population variance (sigma_y^2).
        double squaredDeviations = 0.0;
        for (int i = 0; i < count; i++) {
            double deviation = coeffs[i] - mean;
            squaredDeviations += deviation * deviation;
        }
        final double coeffVariance = squaredDeviations / count;

        // sigma_x^2 is clipped at zero; the epsilon is added before the root.
        double signalVariance = Math.max(0.0, coeffVariance - noiseVariance);
        double signalStd = Math.sqrt(signalVariance + BAYES_EPSILON);

        return noiseVariance / signalStd;
    }

    /**
     * Thresholds a coefficient array and returns the result.
     *
     * <p>When vector operations are enabled the work is delegated to
     * {@code WaveletOperations}. Otherwise a new array is filled by a scalar loop:
     * coefficients with magnitude at or below the threshold become 0; the rest are
     * shrunk toward zero by the threshold (soft) or kept as they are (hard). Any
     * {@code type} other than {@link ThresholdType#SOFT} is handled as hard
     * thresholding.</p>
     *
     * @param coeffs the wavelet coefficients
     * @param threshold the threshold value
     * @param type the thresholding type
     * @return the thresholded coefficients
     */
    protected double[] applyThreshold(double[] coeffs, double threshold, ThresholdType type) {
        final boolean soft = (type == ThresholdType.SOFT);

        if (useVectorOps) {
            // SIMD-optimized path
            if (soft) {
                return WaveletOperations.softThreshold(coeffs, threshold);
            }
            return WaveletOperations.hardThreshold(coeffs, threshold);
        }

        // Scalar path
        final int count = coeffs.length;
        double[] shrunk = new double[count];
        for (int i = 0; i < count; i++) {
            double value = coeffs[i];
            double magnitude = Math.abs(value);

            if (magnitude <= threshold) {
                shrunk[i] = 0.0;
            } else if (soft) {
                shrunk[i] = Math.signum(value) * (magnitude - threshold);
            } else {
                shrunk[i] = value;
            }
        }
        return shrunk;
    }

    /**
     * Returns the median of {@code values} without modifying the array.
     *
     * <p>For an even number of elements the result is the mean of the two middle
     * elements of the sorted copy.</p>
     */
    private double calculateMedian(double[] values) {
        double[] ordered = java.util.Arrays.copyOf(values, values.length);
        java.util.Arrays.sort(ordered);

        final int mid = ordered.length / 2;
        final boolean oddCount = (ordered.length % 2) != 0;
        if (oddCount) {
            return ordered[mid];
        }
        return (ordered[mid - 1] + ordered[mid]) / 2.0;
    }

    /**
     * Rules for choosing the threshold applied to detail coefficients.
     *
     * <p>All constants except {@link #FIXED} are evaluated by
     * {@code calculateThreshold}; {@link #FIXED} is rejected there and is served by
     * {@code denoiseFixed} instead.</p>
     */
    public enum ThresholdMethod {
        /**
         * Universal threshold (VisuShrink): sqrt(2 * log(N)) * sigma
         * Conservative, tends to oversmooth but ensures noise removal.
         */
        UNIVERSAL,

        /**
         * SURE (Stein's Unbiased Risk Estimate) threshold, capped at the
         * universal threshold.
         * Adapts to signal characteristics, good for smooth signals.
         */
        SURE,

        /**
         * Minimax threshold: optimal for worst-case MSE.
         * Good compromise between smoothing and feature preservation.
         */
        MINIMAX,

        /**
         * BayesShrink threshold: sigma^2 / sqrt(max(var(X) - sigma^2, 0))
         * Adaptive method that minimizes Bayesian risk, good for signals with varying SNR.
         */
        BAYES,

        /**
         * Fixed threshold: user-specified value, supplied through
         * {@code denoiseFixed}.
         */
        FIXED
    }

    /**
     * Shape of the thresholding function applied to each coefficient.
     */
    public enum ThresholdType {
        /**
         * Soft thresholding: coefficients at or below the threshold become zero,
         * larger ones are shrunk towards zero by the threshold.
         * Produces smoother results with fewer artifacts.
         */
        SOFT,

        /**
         * Hard thresholding: coefficients at or below the threshold become zero,
         * larger ones are kept unchanged.
         * Better preserves signal features but may introduce artifacts.
         */
        HARD
    }
}