MODWTStreamingDenoiser.java

package com.morphiqlabs.wavelet.modwt.streaming;

import com.morphiqlabs.wavelet.api.BoundaryMode;
import com.morphiqlabs.wavelet.api.Daubechies;
import com.morphiqlabs.wavelet.api.Wavelet;
import com.morphiqlabs.wavelet.denoising.WaveletDenoiser;
import com.morphiqlabs.wavelet.denoising.WaveletDenoiser.ThresholdMethod;
import com.morphiqlabs.wavelet.denoising.WaveletDenoiser.ThresholdType;
import com.morphiqlabs.wavelet.exception.InvalidArgumentException;
import com.morphiqlabs.wavelet.modwt.MODWTResult;
import com.morphiqlabs.wavelet.modwt.MODWTTransform;
import com.morphiqlabs.wavelet.util.MathUtils;

import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Streaming denoiser based on MODWT for real-time signal denoising.
 * 
 * <p>This class provides real-time denoising capabilities using the MODWT transform,
 * which offers several advantages over DWT-based streaming denoisers:</p>
 * <ul>
 *   <li>Shift-invariance prevents artifacts at block boundaries</li>
 *   <li>Works with any buffer size (not restricted to powers of 2)</li>
 *   <li>Better preservation of signal features</li>
 *   <li>Improved noise estimation accuracy</li>
 * </ul>
 * 
 * <p>The denoiser processes data in blocks and can optionally publish results
 * to subscribers for further processing.</p>
 * 
 * @since 1.0.0
 */
public class MODWTStreamingDenoiser implements Flow.Publisher<double[]>, AutoCloseable {
    
    private final MODWTTransform transform;
    private final WaveletDenoiser denoiser;
    private final int bufferSize;
    private final ThresholdType thresholdType;
    private final ThresholdMethod thresholdMethod;
    private final double thresholdMultiplier;
    private final NoiseEstimation noiseEstimation;
    private final int noiseWindowSize;
    
    private final SubmissionPublisher<double[]> publisher;
    private final AtomicBoolean closed = new AtomicBoolean(false);
    private final AtomicLong samplesProcessed = new AtomicLong(0);
    
    // Noise estimation
    private double[] noiseWindow;
    private int noiseWindowIndex = 0;
    private double estimatedNoiseLevel = 0.0;
    
    /**
     * Noise estimation method for streaming denoising.
     */
    public enum NoiseEstimation {
        /** Median Absolute Deviation estimation */
        MAD,
        /** Standard deviation estimation */
        STD,
        /** Fixed noise level */
        FIXED
    }
    
    private MODWTStreamingDenoiser(Builder builder) {
        this.transform = new MODWTTransform(builder.wavelet, builder.boundaryMode);
        this.denoiser = new WaveletDenoiser(builder.wavelet, builder.boundaryMode);
        
        this.bufferSize = builder.bufferSize;
        this.thresholdType = builder.thresholdType;
        this.thresholdMethod = builder.thresholdMethod;
        this.thresholdMultiplier = builder.thresholdMultiplier;
        this.noiseEstimation = builder.noiseEstimation;
        this.noiseWindowSize = builder.noiseWindowSize;
        
        this.publisher = new SubmissionPublisher<>();
        
        if (noiseEstimation != NoiseEstimation.FIXED) {
            this.noiseWindow = new double[noiseWindowSize];
        }
    }
    
    /**
     * Process a block of samples and return the denoised result.
     * 
     * @param samples the input samples
     * @return the denoised samples
     * @throws IllegalStateException if the denoiser is closed
     * @throws InvalidArgumentException if samples is null or empty
     */
    public double[] denoise(double[] samples) {
        if (closed.get()) {
            throw new IllegalStateException("Denoiser is closed");
        }
        if (samples == null || samples.length == 0) {
            throw new InvalidArgumentException("Samples cannot be null or empty");
        }
        
        // Update noise estimation if needed
        if (noiseEstimation != NoiseEstimation.FIXED) {
            updateNoiseEstimation(samples);
        }
        
        // Denoise using WaveletDenoiser with threshold multiplier applied
        double[] denoised;
        if (Math.abs(thresholdMultiplier - 1.0) < 1e-10) {
            // No multiplier adjustment needed - use standard denoising
            denoised = denoiser.denoise(samples, thresholdMethod, thresholdType);
        } else {
            // Calculate threshold using the selected method, then apply multiplier
            double baseThreshold = calculateThreshold(samples);
            double adjustedThreshold = baseThreshold * thresholdMultiplier;
            denoised = denoiser.denoiseFixed(samples, adjustedThreshold, thresholdType);
        }
        
        // Update statistics
        samplesProcessed.addAndGet(samples.length);
        
        // Publish to subscribers if any
        if (publisher.hasSubscribers()) {
            publisher.submit(denoised.clone());
        }
        
        return denoised;
    }
    
    /**
     * Update noise estimation based on new samples.
     */
    private void updateNoiseEstimation(double[] samples) {
        // Transform to get detail coefficients
        MODWTResult result = transform.forward(samples);
        double[] details = result.detailCoeffs();
        
        // Update noise window with detail coefficients
        // Strategy: If we have more details than window size, use stratified sampling
        // to maintain temporal diversity in noise estimation
        if (details.length <= noiseWindowSize) {
            // Case 1: Fewer details than window size - add all
            for (double detail : details) {
                noiseWindow[noiseWindowIndex] = Math.abs(detail);
                noiseWindowIndex = (noiseWindowIndex + 1) % noiseWindowSize;
            }
        } else {
            // Case 2: More details than window size - use stratified sampling
            // Divide the signal into equal strata and sample from each
            // This ensures we capture temporal variations across the entire signal
            
            // Calculate number of samples per stratum
            int strataCount = Math.min(noiseWindowSize, 10); // Use at most 10 strata
            int samplesPerStratum = noiseWindowSize / strataCount;
            int extraSamples = noiseWindowSize % strataCount;
            int strataSize = details.length / strataCount;
            
            int sampleIndex = 0;
            
            // Sample from each stratum
            for (int stratum = 0; stratum < strataCount; stratum++) {
                int strataStart = stratum * strataSize;
                int strataEnd = (stratum == strataCount - 1) ? details.length : (stratum + 1) * strataSize;
                int samplesToTake = samplesPerStratum + (stratum < extraSamples ? 1 : 0);
                
                // Within each stratum, sample uniformly
                if (samplesToTake > 0) {
                    int strataLength = strataEnd - strataStart;
                    int strataStep = Math.max(1, strataLength / samplesToTake);
                    
                    for (int i = 0; i < samplesToTake && sampleIndex < noiseWindowSize; i++) {
                        int idx = strataStart + (i * strataStep) % strataLength;
                        if (idx < details.length) {
                            noiseWindow[noiseWindowIndex] = Math.abs(details[idx]);
                            noiseWindowIndex = (noiseWindowIndex + 1) % noiseWindowSize;
                            sampleIndex++;
                        }
                    }
                }
            }
            
            // Fill any remaining slots with samples from the end (recent data)
            int remaining = noiseWindowSize - sampleIndex;
            if (remaining > 0) {
                int startIdx = Math.max(0, details.length - remaining);
                for (int i = startIdx; i < details.length && sampleIndex < noiseWindowSize; i++) {
                    noiseWindow[noiseWindowIndex] = Math.abs(details[i]);
                    noiseWindowIndex = (noiseWindowIndex + 1) % noiseWindowSize;
                    sampleIndex++;
                }
            }
        }
        
        // Calculate noise level based on method
        switch (noiseEstimation) {
            case MAD -> estimatedNoiseLevel = calculateMAD(noiseWindow) / 0.6745;
            case STD -> estimatedNoiseLevel = calculateSTD(noiseWindow);
            default -> {
                // Fixed noise level, no update needed
            }
        }
    }
    
    /**
     * Calculates the Median Absolute Deviation (MAD) of the given values.
     * MAD is a robust measure of variability based on the median of absolute deviations.
     * 
     * @param values the array of values (must not be empty)
     * @return the median absolute deviation, or 0 if all values are zero/invalid
     * @throws IllegalArgumentException if values array is empty
     */
    private double calculateMAD(double[] values) {
        if (values == null || values.length == 0) {
            throw new IllegalArgumentException("Values array cannot be null or empty");
        }
        
        // Check if we have any non-zero, finite values
        boolean hasValidData = false;
        int validCount = 0;
        
        for (double value : values) {
            if (Double.isFinite(value)) {
                validCount++;
                if (value != 0.0) {
                    hasValidData = true;
                }
            }
        }
        
        // If no valid finite values, return 0
        if (validCount == 0) {
            return 0.0;
        }
        
        // If all valid values are zero, MAD is 0
        if (!hasValidData) {
            return 0.0;
        }
        
        return MathUtils.medianAbsoluteDeviation(values);
    }
    
    /**
     * Calculates the standard deviation of the given values.
     * 
     * @param values the array of values
     * @return the standard deviation, or 0 if insufficient valid data
     */
    private double calculateSTD(double[] values) {
        if (values == null || values.length == 0) {
            return 0.0;
        }
        
        // Count valid finite values
        int validCount = 0;
        for (double value : values) {
            if (Double.isFinite(value)) {
                validCount++;
            }
        }
        
        // Need at least 2 valid values for standard deviation
        if (validCount < 2) {
            return 0.0;
        }
        
        return MathUtils.standardDeviation(values);
    }
    
    /**
     * Calculates the base threshold using the selected method and current noise estimation.
     * This threshold can then be adjusted by the threshold multiplier.
     * 
     * @param samples the input samples to analyze
     * @return the calculated base threshold
     */
    private double calculateThreshold(double[] samples) {
        // Get current noise level estimate
        double sigma = estimatedNoiseLevel;
        
        // If no noise estimation available yet, estimate from current samples
        if (sigma <= 0.0 || noiseEstimation == NoiseEstimation.FIXED) {
            // Transform samples to get detail coefficients for noise estimation
            MODWTResult result = transform.forward(samples);
            double[] details = result.detailCoeffs();
            
            // Estimate noise using MAD (consistent with default noise estimation)
            double[] absDetails = new double[details.length];
            for (int i = 0; i < details.length; i++) {
                absDetails[i] = Math.abs(details[i]);
            }
            sigma = calculateMAD(absDetails) / 0.6745;
        }
        
        // Calculate threshold based on selected method
        int n = samples.length;
        return switch (thresholdMethod) {
            case UNIVERSAL -> 
                // Universal threshold (VisuShrink): sigma * sqrt(2 * log(n))
                sigma * Math.sqrt(2.0 * Math.log(n));
                
            case SURE -> 
                // For SURE threshold, we need the actual coefficients
                // This is more complex, so we'll approximate with a conservative factor
                sigma * Math.sqrt(2.0 * Math.log(n)) * 0.8;
                
            case MINIMAX -> {
                // Minimax threshold approximation
                double logN = Math.log(n);
                if (n <= 32) {
                    yield 0.0;
                } else if (n <= 64) {
                    yield sigma * (0.3936 + 0.1829 * logN);
                } else {
                    yield sigma * (0.4745 + 0.1148 * logN);
                }
            }
                
            case FIXED -> 
                // For fixed method, return a reasonable default that will be multiplied
                sigma;
                
            default -> throw new IllegalArgumentException("Unknown threshold method: " + thresholdMethod);
        };
    }
    
    
    /**
     * Get the current estimated noise level.
     * 
     * @return the estimated noise level
     */
    public double getEstimatedNoiseLevel() {
        return estimatedNoiseLevel;
    }
    
    /**
     * Get the total number of samples processed.
     * 
     * @return the number of samples processed
     */
    public long getSamplesProcessed() {
        return samplesProcessed.get();
    }
    
    @Override
    public void subscribe(Flow.Subscriber<? super double[]> subscriber) {
        publisher.subscribe(subscriber);
    }
    
    @Override
    public void close() {
        if (closed.compareAndSet(false, true)) {
            publisher.close();
        }
    }
    
    /**
     * Check if the denoiser is closed.
     * 
     * @return true if closed, false otherwise
     */
    public boolean isClosed() {
        return closed.get();
    }
    
    /**
     * Builder for creating MODWTStreamingDenoiser instances.
     */
    public static class Builder {
        private Wavelet wavelet = Daubechies.DB4;
        private BoundaryMode boundaryMode = BoundaryMode.PERIODIC;
        private int bufferSize = 256;
        private ThresholdType thresholdType = ThresholdType.SOFT;
        private ThresholdMethod thresholdMethod = ThresholdMethod.UNIVERSAL;
        private double thresholdMultiplier = 1.0;
        private NoiseEstimation noiseEstimation = NoiseEstimation.MAD;
        private int noiseWindowSize = 1024;
        
        /**
         * Set the wavelet to use for denoising.
         * 
         * @param wavelet the wavelet
         * @return this builder
         */
        public Builder wavelet(Wavelet wavelet) {
            if (wavelet == null) {
                throw new InvalidArgumentException("Wavelet cannot be null");
            }
            this.wavelet = wavelet;
            return this;
        }
        
        /**
         * Set the boundary mode.
         * 
         * @param boundaryMode the boundary mode
         * @return this builder
         */
        public Builder boundaryMode(BoundaryMode boundaryMode) {
            if (boundaryMode == null) {
                throw new InvalidArgumentException("Boundary mode cannot be null");
            }
            this.boundaryMode = boundaryMode;
            return this;
        }
        
        /**
         * Set the buffer size.
         * 
         * @param bufferSize the buffer size (must be positive)
         * @return this builder
         */
        public Builder bufferSize(int bufferSize) {
            if (bufferSize <= 0) {
                throw new InvalidArgumentException("Buffer size must be positive");
            }
            this.bufferSize = bufferSize;
            return this;
        }
        
        /**
         * Set the threshold type.
         * 
         * @param thresholdType the threshold type
         * @return this builder
         */
        public Builder thresholdType(ThresholdType thresholdType) {
            if (thresholdType == null) {
                throw new InvalidArgumentException("Threshold type cannot be null");
            }
            this.thresholdType = thresholdType;
            return this;
        }
        
        /**
         * Set the threshold method.
         * 
         * @param thresholdMethod the threshold method
         * @return this builder
         */
        public Builder thresholdMethod(ThresholdMethod thresholdMethod) {
            if (thresholdMethod == null) {
                throw new InvalidArgumentException("Threshold method cannot be null");
            }
            this.thresholdMethod = thresholdMethod;
            return this;
        }
        
        /**
         * Set the threshold multiplier for fine-tuning denoising aggressiveness.
         * 
         * <p>The multiplier is applied to the automatically calculated threshold:</p>
         * <ul>
         *   <li>Values {@literal >} 1.0: More aggressive denoising (removes more noise but may lose signal details)</li>
         *   <li>Values {@literal <} 1.0: Less aggressive denoising (preserves more signal details but may retain noise)</li>
         *   <li>Value = 1.0: Uses the standard threshold calculation without adjustment</li>
         * </ul>
         * 
         * @param thresholdMultiplier the threshold multiplier (must be positive)
         * @return this builder
         */
        public Builder thresholdMultiplier(double thresholdMultiplier) {
            if (thresholdMultiplier <= 0) {
                throw new InvalidArgumentException("Threshold multiplier must be positive");
            }
            this.thresholdMultiplier = thresholdMultiplier;
            return this;
        }
        
        /**
         * Set the noise estimation method.
         * 
         * @param noiseEstimation the noise estimation method
         * @return this builder
         */
        public Builder noiseEstimation(NoiseEstimation noiseEstimation) {
            if (noiseEstimation == null) {
                throw new InvalidArgumentException("Noise estimation cannot be null");
            }
            this.noiseEstimation = noiseEstimation;
            return this;
        }
        
        /**
         * Set the noise window size for estimation.
         * 
         * @param noiseWindowSize the window size (must be positive)
         * @return this builder
         */
        public Builder noiseWindowSize(int noiseWindowSize) {
            if (noiseWindowSize <= 0) {
                throw new InvalidArgumentException("Noise window size must be positive");
            }
            this.noiseWindowSize = noiseWindowSize;
            return this;
        }
        
        /**
         * Build the streaming denoiser.
         * 
         * @return a new MODWTStreamingDenoiser instance
         */
        public MODWTStreamingDenoiser build() {
            return new MODWTStreamingDenoiser(this);
        }
    }
}