MultiLevelMODWTStreamingTransform.java

package com.morphiqlabs.wavelet.modwt.streaming;

import com.morphiqlabs.wavelet.api.BoundaryMode;
import com.morphiqlabs.wavelet.api.Wavelet;
import com.morphiqlabs.wavelet.exception.InvalidArgumentException;
import com.morphiqlabs.wavelet.exception.InvalidSignalException;
import com.morphiqlabs.wavelet.exception.InvalidStateException;
import com.morphiqlabs.wavelet.modwt.MODWTResult;
import com.morphiqlabs.wavelet.modwt.MultiLevelMODWTResult;
import com.morphiqlabs.wavelet.modwt.MultiLevelMODWTTransform;
import com.morphiqlabs.wavelet.util.ValidationUtils;

import java.util.Arrays;
import java.util.concurrent.SubmissionPublisher;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Multi-level streaming MODWT transform implementation.
 *
 * <p>This implementation performs multi-level MODWT decomposition on streaming data,
 * providing hierarchical time-frequency analysis in real-time.</p>
 *
 * <p>Key features:</p>
 * <ul>
 *   <li>Multi-resolution analysis of streaming data</li>
 *   <li>Maintains continuity across all decomposition levels</li>
 *   <li>Efficient cascaded filtering</li>
 *   <li>Flexible buffer management per level</li>
 * </ul>
 */
class MultiLevelMODWTStreamingTransform extends SubmissionPublisher<MODWTResult>
        implements MODWTStreamingTransform {

    private final Wavelet wavelet;
    private final BoundaryMode boundaryMode;
    private final int bufferSize;
    private final int levels;
    private final MultiLevelMODWTTransform multiLevelTransform;
    
    // Level-specific buffers and state
    private final double[][] levelBuffers;
    private final int[] levelPositions;
    private final int[] levelSamplesCount;
    
    // State management
    private final AtomicBoolean isClosed = new AtomicBoolean(false);
    
    // Statistics
    private final StreamingStatisticsImpl statistics = new StreamingStatisticsImpl();
    
    // Shared empty array to avoid unnecessary allocations
    private static final double[] EMPTY_ARRAY = new double[0];

    /**
     * Creates a new multi-level streaming MODWT transform.
     *
     * @param wavelet      the wavelet to use
     * @param boundaryMode the boundary mode
     * @param bufferSize   the buffer size
     * @param levels       number of decomposition levels
     * @throws InvalidArgumentException if parameters are invalid
     */
    public MultiLevelMODWTStreamingTransform(Wavelet wavelet, BoundaryMode boundaryMode, 
                                           int bufferSize, int levels) {
        super();

        if (wavelet == null) {
            throw new InvalidArgumentException("Wavelet cannot be null");
        }
        if (boundaryMode == null) {
            throw new InvalidArgumentException("Boundary mode cannot be null");
        }
        if (bufferSize <= 0) {
            throw new InvalidArgumentException("Buffer size must be positive, got: " + bufferSize);
        }
        if (levels < 1) {
            throw new InvalidArgumentException("Levels must be at least 1, got: " + levels);
        }

        this.wavelet = wavelet;
        this.boundaryMode = boundaryMode;
        this.bufferSize = bufferSize;
        this.levels = levels;
        
        // Initialize level-specific buffers
        this.levelBuffers = new double[levels][];
        this.levelPositions = new int[levels];
        this.levelSamplesCount = new int[levels];
        
        for (int level = 0; level < levels; level++) {
            // Each level needs its own buffer
            levelBuffers[level] = new double[bufferSize];
        }
        
        // Create multi-level MODWT transform
        this.multiLevelTransform = new MultiLevelMODWTTransform(wavelet, boundaryMode);
    }

    @Override
    public synchronized void process(double[] data) {
        if (isClosed.get()) {
            throw InvalidStateException.closed("Transform");
        }
        if (data == null || data.length == 0) {
            throw new InvalidSignalException("Data cannot be null or empty");
        }

        // Process data through the first level buffer
        for (double sample : data) {
            processSampleInternal(sample);
        }
    }

    @Override
    public synchronized void processSample(double sample) {
        if (isClosed.get()) {
            throw InvalidStateException.closed("Transform");
        }
        
        processSampleInternal(sample);
    }

    private void processSampleInternal(double sample) {
        // Add sample to first level buffer
        levelBuffers[0][levelPositions[0]] = sample;
        levelPositions[0] = (levelPositions[0] + 1) % bufferSize;
        levelSamplesCount[0]++;
        
        // Check if first level buffer is full
        if (levelSamplesCount[0] >= bufferSize) {
            processMultiLevel();
        }
        
        statistics.incrementSamplesProcessed();
    }

    private void processMultiLevel() {
        long startTime = System.nanoTime();
        
        // Extract buffer contents for processing
        double[] processingBuffer = new double[bufferSize];
        int startPos = (levelPositions[0] - bufferSize + bufferSize) % bufferSize;
        
        for (int i = 0; i < bufferSize; i++) {
            processingBuffer[i] = levelBuffers[0][(startPos + i) % bufferSize];
        }
        
        // Apply multi-level MODWT transform
        MultiLevelMODWTResult multiResult = multiLevelTransform.decompose(processingBuffer, levels);
        
        // Convert to single-level results and publish each level
        for (int level = 1; level <= levels; level++) {
            double[] details = multiResult.getDetailCoeffsAtLevel(level);
            double[] approx = getApproximationForLevel(multiResult, level);
            
            MODWTResult levelResult = new MODWTResultWrapper(approx, details);
            submit(levelResult);
        }
        
        // Update statistics
        long processingTime = System.nanoTime() - startTime;
        statistics.recordBlockProcessed(processingTime);
        
        // Reset first level counter
        levelSamplesCount[0] = 0;
    }

    @Override
    public synchronized void flush() {
        if (levelSamplesCount[0] > 0) {
            // Process remaining samples with zero padding
            double[] finalBuffer = new double[bufferSize];
            int startPos = (levelPositions[0] - levelSamplesCount[0] + bufferSize) % bufferSize;
            
            for (int i = 0; i < levelSamplesCount[0]; i++) {
                finalBuffer[i] = levelBuffers[0][(startPos + i) % bufferSize];
            }
            
            // Apply multi-level transform
            MultiLevelMODWTResult multiResult = multiLevelTransform.decompose(finalBuffer, levels);
            
            // Publish results for each level
            for (int level = 1; level <= levels; level++) {
                double[] details = multiResult.getDetailCoeffsAtLevel(level);
                double[] approx = getApproximationForLevel(multiResult, level);
                
                MODWTResult levelResult = new MODWTResultWrapper(approx, details);
                submit(levelResult);
            }
            
            // Reset all levels
            Arrays.fill(levelSamplesCount, 0);
            Arrays.fill(levelPositions, 0);
        }
    }

    @Override
    public StreamingStatistics getStatistics() {
        return statistics;
    }

    @Override
    public synchronized void reset() {
        if (isClosed.get()) {
            throw InvalidStateException.closed("Transform");
        }
        
        // Clear all buffers
        for (double[] buffer : levelBuffers) {
            Arrays.fill(buffer, 0.0);
        }
        Arrays.fill(levelPositions, 0);
        Arrays.fill(levelSamplesCount, 0);
        
        // Reset statistics
        statistics.reset();
    }

    @Override
    public int getBufferLevel() {
        return levelSamplesCount[0];
    }

    @Override
    public boolean isClosed() {
        return isClosed.get();
    }
    
    /**
     * Gets the appropriate approximation coefficients for a given level.
     * Only the final level contains actual approximation coefficients;
     * intermediate levels use an empty array.
     * 
     * @param multiResult the multi-level MODWT result
     * @param currentLevel the current level being processed
     * @return approximation coefficients for the final level, empty array otherwise
     */
    private double[] getApproximationForLevel(MultiLevelMODWTResult multiResult, int currentLevel) {
        return currentLevel == levels ? multiResult.getApproximationCoeffs() : EMPTY_ARRAY;
    }

    @Override
    public void close() {
        if (isClosed.compareAndSet(false, true)) {
            // Flush any remaining data
            flush();
            
            // Close the publisher
            super.close();
        }
    }

    /**
     * Wrapper class to convert arrays to MODWTResult.
     */
    private static class MODWTResultWrapper implements MODWTResult {
        private final double[] approx;
        private final double[] details;

        MODWTResultWrapper(double[] approx, double[] details) {
            this.approx = approx;
            this.details = details;
        }

        @Override
        public double[] approximationCoeffs() {
            return approx.clone();
        }

        @Override
        public double[] detailCoeffs() {
            return details.clone();
        }

        @Override
        public int getSignalLength() {
            return approx.length;
        }

        @Override
        public boolean isValid() {
            // Check if arrays are valid
            if (approx == null || details == null) {
                return false;
            }
            
            // Check for finite values in approximation coefficients
            for (double value : approx) {
                if (Double.isNaN(value) || Double.isInfinite(value)) {
                    return false;
                }
            }
            
            // Check for finite values in detail coefficients
            for (double value : details) {
                if (Double.isNaN(value) || Double.isInfinite(value)) {
                    return false;
                }
            }
            
            return true;
        }
    }
    
    /**
     * Implementation of streaming statistics.
     */
    private static class StreamingStatisticsImpl implements MODWTStreamingTransform.StreamingStatistics {
        private final java.util.concurrent.atomic.AtomicLong samplesProcessed = new java.util.concurrent.atomic.AtomicLong();
        private final java.util.concurrent.atomic.AtomicLong blocksProcessed = new java.util.concurrent.atomic.AtomicLong();
        private final java.util.concurrent.atomic.LongAdder totalProcessingTime = new java.util.concurrent.atomic.LongAdder();
        private final java.util.concurrent.atomic.AtomicLong maxProcessingTime = new java.util.concurrent.atomic.AtomicLong();
        private final java.util.concurrent.atomic.AtomicLong minProcessingTime = new java.util.concurrent.atomic.AtomicLong(Long.MAX_VALUE);
        private final java.util.concurrent.atomic.AtomicLong startTime = new java.util.concurrent.atomic.AtomicLong(System.nanoTime());

        void incrementSamplesProcessed() {
            samplesProcessed.incrementAndGet();
        }

        void recordBlockProcessed(long processingTimeNanos) {
            blocksProcessed.incrementAndGet();
            totalProcessingTime.add(processingTimeNanos);
            
            // Update max
            long currentMax = maxProcessingTime.get();
            while (processingTimeNanos > currentMax) {
                if (maxProcessingTime.compareAndSet(currentMax, processingTimeNanos)) {
                    break;
                }
                currentMax = maxProcessingTime.get();
            }
            
            // Update min
            long currentMin = minProcessingTime.get();
            while (processingTimeNanos < currentMin) {
                if (minProcessingTime.compareAndSet(currentMin, processingTimeNanos)) {
                    break;
                }
                currentMin = minProcessingTime.get();
            }
        }

        @Override
        public long getSamplesProcessed() {
            return samplesProcessed.get();
        }

        @Override
        public long getBlocksProcessed() {
            return blocksProcessed.get();
        }

        @Override
        public long getAverageProcessingTimeNanos() {
            long blocks = blocksProcessed.get();
            return blocks > 0 ? totalProcessingTime.sum() / blocks : 0;
        }

        @Override
        public long getMaxProcessingTimeNanos() {
            return maxProcessingTime.get();
        }

        @Override
        public long getMinProcessingTimeNanos() {
            long min = minProcessingTime.get();
            return min == Long.MAX_VALUE ? 0 : min;
        }

        @Override
        public double getThroughputSamplesPerSecond() {
            long elapsedNanos = System.nanoTime() - startTime.get();
            double elapsedSeconds = elapsedNanos / 1_000_000_000.0;
            return elapsedSeconds > 0 ? samplesProcessed.get() / elapsedSeconds : 0;
        }

        @Override
        public void reset() {
            samplesProcessed.set(0);
            blocksProcessed.set(0);
            totalProcessingTime.reset();
            maxProcessingTime.set(0);
            minProcessingTime.set(Long.MAX_VALUE);
            startTime.set(System.nanoTime());
        }
    }
}