SignalAdaptiveScaleSelector.java

package com.morphiqlabs.wavelet.cwt;

import com.morphiqlabs.wavelet.api.ContinuousWavelet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Signal-adaptive scale selector that analyzes signal characteristics to optimize scale selection.
 *
 * <p>Before choosing scales, the selector analyzes a window of at most 1024 samples taken from
 * the centre of the signal and derives:
 * <ul>
 *   <li>the frequency band holding the bulk of the spectral energy (5%&ndash;95% cumulative
 *       energy), unless the configuration supplies an explicit band;</li>
 *   <li>dominant spectral peaks and their share of the total energy;</li>
 *   <li>the spectral centroid, spectral spread and effective bandwidth.</li>
 * </ul>
 *
 * <p>A logarithmic base grid covers the selected band. Extra scales are inserted at and around
 * each dominant peak, with a density proportional to that peak's energy share, so more scales
 * land where the signal has significant frequency content. When the result exceeds
 * {@code config.getMaxScales()}, the scales closest (in log-scale distance) to strong peaks are
 * kept.</p>
 *
 * <p>The selector has no instance fields; every call works only on its arguments.</p>
 */
public class SignalAdaptiveScaleSelector implements AdaptiveScaleSelector {

    /** A spectral peak must exceed this fraction of the total spectral energy (1%) to count. */
    private static final double DEFAULT_ENERGY_THRESHOLD = 0.01;
    /** Maximum number of samples, taken from the signal centre, used for spectral analysis. */
    private static final int DEFAULT_SPECTRAL_ANALYSIS_SIZE = 1024;
    /** Multiplier on scales-per-octave used when inserting extra scales around a peak. */
    private static final double DEFAULT_SCALE_DENSITY_FACTOR = 1.5;
    /**
     * Lower frequency bound for automatically derived bands (same unit as the sampling rate).
     * It is relaxed to the FFT bin width when it would leave an empty band, which happens for
     * sampling rates of 2 or less or for signals whose energy sits below this bound.
     */
    private static final double MIN_AUTO_FREQUENCY = 1.0;
    /** Standard deviation, in natural-log scale units, of the peak-proximity weighting. */
    private static final double PRIORITY_LOG_SIGMA = 0.5;

    // Default scale parameters for all-zero signals. Under the scale/frequency mapping used
    // below (f = centerFrequency * samplingRate / scale), scale 1 corresponds to
    // centerFrequency * samplingRate and scale 100 to one hundredth of that.
    private static final double ZERO_SIGNAL_MIN_SCALE = 1.0;
    private static final double ZERO_SIGNAL_MAX_SCALE = 100.0;
    private static final int ZERO_SIGNAL_NUM_SCALES = 32;

    /**
     * Creates a new SignalAdaptiveScaleSelector.
     */
    public SignalAdaptiveScaleSelector() {
        // No state to initialize; all analysis is performed per call.
    }

    /**
     * Selects scales for {@code signal} using an adaptive default configuration.
     *
     * <p>The configuration uses {@code ScaleSpacing.ADAPTIVE}, signal adaptation, and
     * 12 scales per octave at the given sampling rate.</p>
     *
     * @param signal       the signal to analyze; must be non-null and non-empty
     * @param wavelet      wavelet whose center frequency maps frequencies to scales; non-null
     * @param samplingRate sampling rate of {@code signal}; must be positive
     * @return the selected scales in ascending order
     * @throws IllegalArgumentException if {@code samplingRate} is not positive, or if the
     *                                  signal or wavelet is invalid
     */
    @Override
    public double[] selectScales(double[] signal, ContinuousWavelet wavelet, double samplingRate) {
        if (samplingRate <= 0) {
            throw new IllegalArgumentException("Sampling rate must be positive");
        }
        ScaleSelectionConfig config = ScaleSelectionConfig.builder(samplingRate)
            .spacing(ScaleSpacing.ADAPTIVE)
            .useSignalAdaptation(true)
            .scalesPerOctave(12) // Higher density for adaptive selection
            .build();
        return selectScales(signal, wavelet, config);
    }

    /**
     * Selects scales for {@code signal} according to {@code config}.
     *
     * <p>An all-zero signal skips analysis. It receives up to 32 logarithmically spaced scales
     * between 1 and 100, capped by {@code config.getMaxScales()}.</p>
     *
     * @param signal  the signal to analyze; must be non-null and non-empty
     * @param wavelet wavelet whose center frequency maps frequencies to scales; non-null
     * @param config  selection configuration; non-null with a positive sampling rate
     * @return the selected scales in ascending order, at most {@code config.getMaxScales()}
     * @throws IllegalArgumentException if any argument is null/empty or the sampling rate is
     *                                  not positive
     */
    @Override
    public double[] selectScales(double[] signal, ContinuousWavelet wavelet, ScaleSelectionConfig config) {
        if (signal == null || signal.length == 0) {
            throw new IllegalArgumentException("Signal cannot be null or empty");
        }
        if (wavelet == null) {
            throw new IllegalArgumentException("Wavelet cannot be null");
        }
        if (config == null) {
            throw new IllegalArgumentException("Config cannot be null");
        }
        if (config.getSamplingRate() <= 0) {
            throw new IllegalArgumentException("Sampling rate must be positive");
        }

        // Early check for an all-zero signal to avoid the spectral analysis entirely.
        boolean isZeroSignal = Arrays.stream(signal).noneMatch(value -> value != 0.0);
        if (isZeroSignal) {
            return ScaleSpace.logarithmic(
                ZERO_SIGNAL_MIN_SCALE,
                ZERO_SIGNAL_MAX_SCALE,
                Math.min(ZERO_SIGNAL_NUM_SCALES, config.getMaxScales())
            ).getScales();
        }

        SignalCharacteristics characteristics = analyzeSignal(signal, config.getSamplingRate());

        // Frequency band of interest (explicit from config, or derived from the spectrum).
        double[] frequencyRange = determineFrequencyRange(characteristics, config);
        double minFreq = frequencyRange[0];
        double maxFreq = frequencyRange[1];

        // Scale is inversely proportional to frequency: the highest frequency gives the
        // smallest scale.
        double centerFreq = wavelet.centerFrequency();
        double minScale = centerFreq * config.getSamplingRate() / maxFreq;
        double maxScale = centerFreq * config.getSamplingRate() / minFreq;

        List<Double> scales = generateAdaptiveScales(
            minScale, maxScale, characteristics, config, wavelet);

        if (scales.size() > config.getMaxScales()) {
            scales = prioritizeScales(scales, characteristics, config.getMaxScales(), wavelet);
        }

        scales.sort(Double::compareTo);
        return scales.stream().mapToDouble(Double::doubleValue).toArray();
    }

    /**
     * Analyzes the signal to extract the characteristics used for scale selection.
     *
     * <p>The spectrum is computed from at most {@code DEFAULT_SPECTRAL_ANALYSIS_SIZE} samples
     * taken from the centre of the signal, which keeps the cost bounded. The summary
     * statistics use the whole signal.</p>
     */
    private SignalCharacteristics analyzeSignal(double[] signal, double samplingRate) {
        double[] analysisSegment;
        if (signal.length <= DEFAULT_SPECTRAL_ANALYSIS_SIZE) {
            analysisSegment = signal;
        } else {
            analysisSegment = new double[DEFAULT_SPECTRAL_ANALYSIS_SIZE];
            int startIdx = (signal.length - DEFAULT_SPECTRAL_ANALYSIS_SIZE) / 2;
            System.arraycopy(signal, startIdx, analysisSegment, 0, DEFAULT_SPECTRAL_ANALYSIS_SIZE);
        }

        SpectralAnalysis spectral = computeSpectralAnalysis(analysisSegment, samplingRate);
        List<DominantFrequency> dominantFreqs = findDominantFrequencies(spectral);
        double[] stats = computeSignalStatistics(signal);
        double[] bandwidthInfo = estimateBandwidth(spectral);

        return new SignalCharacteristics(
            spectral,
            dominantFreqs,
            stats[0], // mean
            stats[1], // variance
            stats[2], // skewness
            stats[3], // kurtosis
            bandwidthInfo[0], // effective bandwidth
            bandwidthInfo[1], // spectral centroid
            bandwidthInfo[2]  // spectral spread
        );
    }

    /**
     * Computes a one-sided power spectrum of the Hann-windowed, zero-padded segment.
     *
     * <p>The result holds bins {@code 0 .. fftSize/2 - 1}; the Nyquist bin is not included.
     * Values are scaled by {@code 1 / (samplingRate * n)}. Downstream code only uses them
     * relative to each other.</p>
     */
    private SpectralAnalysis computeSpectralAnalysis(double[] signal, double samplingRate) {
        int n = signal.length;
        int fftSize = nextPowerOfTwo(n);

        double[] windowed = applyHannWindow(signal);
        Complex[] fft = computeFFT(windowed, fftSize);

        double[] psd = new double[fftSize / 2];
        double[] frequencies = new double[fftSize / 2];
        for (int i = 0; i < psd.length; i++) {
            psd[i] = fft[i].magnitude2() / (samplingRate * n);
            frequencies[i] = i * samplingRate / fftSize;
        }

        return new SpectralAnalysis(frequencies, psd, samplingRate);
    }

    /**
     * Applies a symmetric Hann window to reduce spectral leakage.
     *
     * <p>A single-sample input uses the conventional length-1 window {@code [1]}. The general
     * formula would evaluate {@code 0/0} there and yield NaN.</p>
     */
    private double[] applyHannWindow(double[] signal) {
        int n = signal.length;
        if (n == 1) {
            return new double[]{signal[0]};
        }
        double[] windowed = new double[n];
        for (int i = 0; i < n; i++) {
            double window = 0.5 * (1 - Math.cos(2 * Math.PI * i / (n - 1)));
            windowed[i] = signal[i] * window;
        }
        return windowed;
    }

    /**
     * Finds dominant spectral peaks.
     *
     * <p>A bin counts as a peak when it exceeds {@code DEFAULT_ENERGY_THRESHOLD} of the total
     * energy and is strictly greater than its two neighbours on each side. The two outermost
     * bins at each end are therefore never reported. Peaks are returned strongest first.
     * The list stops once the cumulative share of the returned peaks reaches 90%.</p>
     */
    private List<DominantFrequency> findDominantFrequencies(SpectralAnalysis spectral) {
        double[] psd = spectral.psd;
        double[] frequencies = spectral.frequencies;

        List<DominantFrequency> dominantFreqs = new ArrayList<>();

        double totalEnergy = Arrays.stream(psd).sum();
        if (totalEnergy <= 0) {
            return dominantFreqs; // No energy, no peaks.
        }

        double energyThreshold = totalEnergy * DEFAULT_ENERGY_THRESHOLD;

        for (int i = 2; i < psd.length - 2; i++) {
            if (psd[i] > energyThreshold &&
                psd[i] > psd[i-1] && psd[i] > psd[i+1] &&
                psd[i] > psd[i-2] && psd[i] > psd[i+2]) {

                double bandwidth = estimateLocalBandwidth(psd, i);
                double relativeStrength = psd[i] / totalEnergy;

                dominantFreqs.add(new DominantFrequency(
                    frequencies[i], psd[i], bandwidth, relativeStrength));
            }
        }

        dominantFreqs.sort((a, b) -> Double.compare(b.energy, a.energy));

        // Keep the strongest peaks until they account for 90% of the total energy.
        double cumulativeEnergy = 0;
        List<DominantFrequency> significantFreqs = new ArrayList<>();
        for (DominantFrequency freq : dominantFreqs) {
            significantFreqs.add(freq);
            cumulativeEnergy += freq.relativeStrength;
            if (cumulativeEnergy >= 0.9) break;
        }

        return significantFreqs;
    }

    /**
     * Estimates the width, in bins, of a spectral peak between its half-power points.
     * The search stops at the array ends.
     */
    private double estimateLocalBandwidth(double[] psd, int peakIndex) {
        double halfPeak = psd[peakIndex] / 2.0;

        int leftIdx = peakIndex;
        while (leftIdx > 0 && psd[leftIdx] > halfPeak) {
            leftIdx--;
        }

        int rightIdx = peakIndex;
        while (rightIdx < psd.length - 1 && psd[rightIdx] > halfPeak) {
            rightIdx++;
        }

        return rightIdx - leftIdx; // In frequency bins
    }

    /**
     * Computes the mean, population variance, skewness and excess kurtosis of the signal.
     * Skewness and kurtosis are 0 when the variance is 0.
     *
     * @return {@code {mean, variance, skewness, excessKurtosis}}
     */
    private double[] computeSignalStatistics(double[] signal) {
        double mean = Arrays.stream(signal).average().orElse(0.0);

        double variance = Arrays.stream(signal)
            .map(x -> Math.pow(x - mean, 2))
            .average().orElse(0.0);

        double stdDev = Math.sqrt(variance);

        double skewness = 0;
        double kurtosis = 0;
        if (stdDev > 0) {
            skewness = Arrays.stream(signal)
                .map(x -> Math.pow((x - mean) / stdDev, 3))
                .average().orElse(0.0);
            kurtosis = Arrays.stream(signal)
                .map(x -> Math.pow((x - mean) / stdDev, 4))
                .average().orElse(0.0) - 3.0; // Excess kurtosis
        }

        return new double[]{mean, variance, skewness, kurtosis};
    }

    /**
     * Estimates spectral shape descriptors.
     *
     * <p>Effective bandwidth is the width of the strongest bins that together hold 90% of the
     * energy. Centroid is the energy-weighted mean frequency. Spread is the energy-weighted
     * standard deviation around the centroid. For a spectrum with no energy, fixed fractions
     * of the sampling rate are returned.</p>
     *
     * @return {@code {effectiveBandwidth, centroid, spread}}
     */
    private double[] estimateBandwidth(SpectralAnalysis spectral) {
        double[] psd = spectral.psd;
        double[] frequencies = spectral.frequencies;

        double totalEnergy = Arrays.stream(psd).sum();
        if (totalEnergy <= 0) {
            double defaultBandwidth = spectral.samplingRate / 4; // Nyquist / 2
            double defaultCentroid = spectral.samplingRate / 4;
            double defaultSpread = spectral.samplingRate / 8;
            return new double[]{defaultBandwidth, defaultCentroid, defaultSpread};
        }

        double centroid = 0;
        for (int i = 0; i < psd.length; i++) {
            centroid += frequencies[i] * psd[i];
        }
        centroid /= totalEnergy;

        double spread = 0;
        for (int i = 0; i < psd.length; i++) {
            spread += Math.pow(frequencies[i] - centroid, 2) * psd[i];
        }
        spread = Math.sqrt(spread / totalEnergy);

        // Find the power level at which the strongest bins reach 90% of the total energy.
        double[] sortedPsd = psd.clone();
        Arrays.sort(sortedPsd);

        double targetEnergy = totalEnergy * 0.9;
        double cumulativeEnergy = 0;
        double powerThreshold = 0;
        for (int i = sortedPsd.length - 1; i >= 0; i--) {
            cumulativeEnergy += sortedPsd[i];
            if (cumulativeEnergy >= targetEnergy) {
                powerThreshold = sortedPsd[i];
                break;
            }
        }

        int significantBins = 0;
        for (double power : psd) {
            if (power >= powerThreshold) significantBins++;
        }

        // Bin width is samplingRate / fftSize, and fftSize == 2 * psd.length.
        double effectiveBandwidth = significantBins * spectral.samplingRate / (2 * psd.length);

        return new double[]{effectiveBandwidth, centroid, spread};
    }

    /**
     * Determines the frequency band of interest.
     *
     * <p>If the configuration supplies both a positive minimum and maximum frequency, they are
     * returned unchanged. Otherwise the band spans from the bin where the cumulative energy
     * (from the low end) reaches 5% to the bin where the cumulative energy (from the high end)
     * reaches 5%. The band is clamped to {@code [MIN_AUTO_FREQUENCY, Nyquist]}.</p>
     *
     * <p>The derived band is guaranteed to be non-empty. If the {@code MIN_AUTO_FREQUENCY}
     * floor would lie at or above the upper bound, the floor is relaxed to the FFT bin width.
     * This happens for sampling rates of 2 or less, or when the energy sits below the floor.
     * If the band is still empty, the full resolvable band {@code [binWidth, Nyquist]} is used.
     * An inverted band would place {@code minScale} above {@code maxScale}. A zero upper bound
     * would produce an infinite scale.</p>
     *
     * @return {@code {minFreq, maxFreq}}
     */
    private double[] determineFrequencyRange(SignalCharacteristics characteristics,
                                           ScaleSelectionConfig config) {
        if (config.getMinFrequency() > 0 && config.getMaxFrequency() > 0) {
            return new double[]{config.getMinFrequency(), config.getMaxFrequency()};
        }

        SpectralAnalysis spectral = characteristics.spectralAnalysis;
        double[] psd = spectral.psd;
        double nyquist = spectral.samplingRate / 2;
        // Width of one FFT bin (samplingRate / fftSize with fftSize == 2 * psd.length): the
        // lowest non-zero frequency the analysis can resolve.
        double binWidth = spectral.samplingRate / (2.0 * Math.max(1, psd.length));

        double totalEnergy = Arrays.stream(psd).sum();
        double minFreq;
        double maxFreq;

        if (totalEnergy <= 0) {
            // No spectral information: use the default band.
            minFreq = MIN_AUTO_FREQUENCY;
            maxFreq = nyquist;
        } else {
            int minIdx = 0;
            int maxIdx = psd.length - 1;

            // Lower edge: 5% of the energy lies below it.
            double cumulativeEnergy = 0;
            for (int i = 0; i < psd.length; i++) {
                cumulativeEnergy += psd[i];
                if (cumulativeEnergy >= totalEnergy * 0.05) {
                    minIdx = i;
                    break;
                }
            }

            // Upper edge: 5% of the energy lies above it (i.e. the 95% point).
            cumulativeEnergy = 0;
            for (int i = psd.length - 1; i >= 0; i--) {
                cumulativeEnergy += psd[i];
                if (cumulativeEnergy >= totalEnergy * 0.05) {
                    maxIdx = i;
                    break;
                }
            }

            minFreq = Math.max(spectral.frequencies[minIdx], MIN_AUTO_FREQUENCY);
            maxFreq = Math.min(spectral.frequencies[maxIdx], nyquist);
            if (minFreq >= maxFreq) {
                // The fixed floor swallowed the whole band; fall back to the resolution limit.
                minFreq = Math.max(spectral.frequencies[minIdx], binWidth);
            }
        }

        if (!(minFreq > 0) || !(maxFreq > minFreq)) {
            // Still degenerate (e.g. DC-dominated or very short signal): use every resolvable bin.
            return new double[]{binWidth, nyquist};
        }
        return new double[]{minFreq, maxFreq};
    }

    /**
     * Generates scales in {@code [minScale, maxScale]}.
     *
     * <p>The result combines a logarithmic base grid with extra scales at and around each
     * dominant frequency. The number of extra scales grows with the peak's energy share.
     * Duplicates are removed and the result is sorted ascending.</p>
     */
    private List<Double> generateAdaptiveScales(double minScale, double maxScale,
                                               SignalCharacteristics characteristics,
                                               ScaleSelectionConfig config,
                                               ContinuousWavelet wavelet) {
        List<Double> scales = new ArrayList<>();
        double centerFreq = wavelet.centerFrequency();
        double samplingRate = characteristics.spectralAnalysis.samplingRate;

        // Size of the base grid; always at least one scale so the result is never empty.
        int baseScaleCount = Math.max(1, Math.min(config.getMaxScales(),
            estimateScaleCount(characteristics.spectralCentroid * 0.5,
                             characteristics.spectralCentroid * 2.0,
                             wavelet, samplingRate,
                             config.getScalesPerOctave())));

        double logMin = Math.log(minScale);
        double logMax = Math.log(maxScale);
        if (baseScaleCount == 1) {
            // A single scale: take the geometric midpoint. The grid formula below would
            // divide by (baseScaleCount - 1) == 0 and yield NaN.
            scales.add(Math.exp(0.5 * (logMin + logMax)));
        } else {
            for (int i = 0; i < baseScaleCount; i++) {
                double logScale = logMin + (logMax - logMin) * i / (baseScaleCount - 1);
                scales.add(Math.exp(logScale));
            }
        }

        for (DominantFrequency domFreq : characteristics.dominantFrequencies) {
            double scale = (centerFreq * samplingRate) / domFreq.frequency;
            if (scale < minScale || scale > maxScale) {
                continue;
            }

            // The peak's own scale is generally not on the logarithmic base grid, so include it
            // (j == 0) together with its neighbours spaced 1/scalesPerOctave octave apart.
            int extraScales = (int) Math.ceil(domFreq.relativeStrength *
                DEFAULT_SCALE_DENSITY_FACTOR * config.getScalesPerOctave());

            for (int j = -extraScales/2; j <= extraScales/2; j++) {
                double factor = Math.pow(2.0, j / (double) config.getScalesPerOctave());
                double adaptiveScale = scale * factor;

                if (adaptiveScale >= minScale && adaptiveScale <= maxScale) {
                    scales.add(adaptiveScale);
                }
            }
        }

        return scales.stream()
            .distinct()
            .sorted()
            .collect(ArrayList::new, ArrayList::add, ArrayList::addAll);
    }

    /**
     * Keeps the {@code maxScales} scales with the highest priority (see
     * {@link #computeScalePriority}). The returned list is in priority order, not sorted.
     */
    private List<Double> prioritizeScales(List<Double> scales, SignalCharacteristics characteristics,
                                         int maxScales, ContinuousWavelet wavelet) {
        double[] priorities = new double[scales.size()];
        for (int i = 0; i < scales.size(); i++) {
            priorities[i] = computeScalePriority(scales.get(i), characteristics, wavelet);
        }

        Integer[] indices = new Integer[scales.size()];
        for (int i = 0; i < indices.length; i++) indices[i] = i;

        // Highest priority first.
        Arrays.sort(indices, (a, b) -> Double.compare(priorities[b], priorities[a]));

        List<Double> prioritizedScales = new ArrayList<>();
        for (int i = 0; i < Math.min(maxScales, indices.length); i++) {
            prioritizedScales.add(scales.get(indices[i]));
        }

        return prioritizedScales;
    }

    /**
     * Computes a priority score for {@code scale}.
     *
     * <p>The score is 1 plus, for every dominant frequency, that peak's energy share
     * multiplied by a Gaussian weight. The Gaussian is taken over the distance between the
     * natural logs of {@code scale} and the peak's scale, with sigma
     * {@code PRIORITY_LOG_SIGMA}.</p>
     */
    private double computeScalePriority(double scale, SignalCharacteristics characteristics,
                                      ContinuousWavelet wavelet) {
        double centerFreq = wavelet.centerFrequency();
        double samplingRate = characteristics.spectralAnalysis.samplingRate;
        double logScale = Math.log(scale);

        double priority = 1.0; // Base priority
        for (DominantFrequency domFreq : characteristics.dominantFrequencies) {
            double freqScale = centerFreq * samplingRate / domFreq.frequency;
            double distance = Math.abs(logScale - Math.log(freqScale));
            double weight = Math.exp(-distance * distance
                / (2 * PRIORITY_LOG_SIGMA * PRIORITY_LOG_SIGMA));
            priority += domFreq.relativeStrength * weight;
        }

        return priority;
    }

    /**
     * Computes the FFT of {@code signal} zero-padded to {@code fftSize}. {@code fftSize} is
     * expected to be a power of two no smaller than the signal length.
     */
    private Complex[] computeFFT(double[] signal, int fftSize) {
        Complex[] data = new Complex[fftSize];

        for (int i = 0; i < signal.length && i < fftSize; i++) {
            data[i] = new Complex(signal[i], 0);
        }
        for (int i = signal.length; i < fftSize; i++) {
            data[i] = new Complex(0, 0);
        }

        return fftRecursive(data);
    }

    /**
     * Recursive radix-2 decimation-in-time FFT. The input length must be a power of two.
     */
    private Complex[] fftRecursive(Complex[] x) {
        int n = x.length;
        if (n <= 1) return x;

        Complex[] even = new Complex[n/2];
        Complex[] odd = new Complex[n/2];
        for (int k = 0; k < n/2; k++) {
            even[k] = x[2*k];
            odd[k] = x[2*k + 1];
        }

        Complex[] evenFFT = fftRecursive(even);
        Complex[] oddFFT = fftRecursive(odd);

        Complex[] result = new Complex[n];
        for (int k = 0; k < n/2; k++) {
            double theta = -2 * Math.PI * k / n;
            Complex w = new Complex(Math.cos(theta), Math.sin(theta));
            Complex t = w.multiply(oddFFT[k]);
            result[k] = evenFFT[k].add(t);
            result[k + n/2] = evenFFT[k].subtract(t);
        }

        return result;
    }

    /** Smallest power of two that is {@code >= n} (1 for {@code n <= 1}). */
    private static int nextPowerOfTwo(int n) {
        if (n <= 1) return 1;
        n--;
        n |= n >> 1;
        n |= n >> 2;
        n |= n >> 4;
        n |= n >> 8;
        n |= n >> 16;
        return n + 1;
    }

    // Helper classes

    /** Results of {@link #analyzeSignal}. */
    private static class SignalCharacteristics {
        final SpectralAnalysis spectralAnalysis;
        final List<DominantFrequency> dominantFrequencies;
        final double mean;
        final double variance;
        final double skewness;
        final double kurtosis;
        final double effectiveBandwidth;
        final double spectralCentroid;
        final double spectralSpread;

        SignalCharacteristics(SpectralAnalysis spectralAnalysis,
                            List<DominantFrequency> dominantFrequencies,
                            double mean, double variance, double skewness, double kurtosis,
                            double effectiveBandwidth, double spectralCentroid, double spectralSpread) {
            this.spectralAnalysis = spectralAnalysis;
            this.dominantFrequencies = dominantFrequencies;
            this.mean = mean;
            this.variance = variance;
            this.skewness = skewness;
            this.kurtosis = kurtosis;
            this.effectiveBandwidth = effectiveBandwidth;
            this.spectralCentroid = spectralCentroid;
            this.spectralSpread = spectralSpread;
        }
    }

    /** One-sided spectrum: {@code psd[i]} is the power at {@code frequencies[i]}. */
    private static class SpectralAnalysis {
        final double[] frequencies;
        final double[] psd;
        final double samplingRate;

        SpectralAnalysis(double[] frequencies, double[] psd, double samplingRate) {
            this.frequencies = frequencies;
            this.psd = psd;
            this.samplingRate = samplingRate;
        }
    }

    /** A spectral peak; {@code relativeStrength} is its share of the total spectral energy. */
    private static class DominantFrequency {
        final double frequency;
        final double energy;
        final double bandwidth;
        final double relativeStrength;

        DominantFrequency(double frequency, double energy, double bandwidth, double relativeStrength) {
            this.frequency = frequency;
            this.energy = energy;
            this.bandwidth = bandwidth;
            this.relativeStrength = relativeStrength;
        }
    }

    /** Minimal immutable complex number used by the FFT. */
    private static class Complex {
        final double real, imag;

        Complex(double real, double imag) {
            this.real = real;
            this.imag = imag;
        }

        Complex add(Complex other) {
            return new Complex(real + other.real, imag + other.imag);
        }

        Complex subtract(Complex other) {
            return new Complex(real - other.real, imag - other.imag);
        }

        Complex multiply(Complex other) {
            return new Complex(real * other.real - imag * other.imag,
                             real * other.imag + imag * other.real);
        }

        /** Squared magnitude {@code |z|^2}. */
        double magnitude2() {
            return real * real + imag * imag;
        }
    }
}