// InverseCWT.java
package com.morphiqlabs.wavelet.cwt;
import com.morphiqlabs.wavelet.api.ContinuousWavelet;
import com.morphiqlabs.wavelet.api.ComplexContinuousWavelet;
import com.morphiqlabs.wavelet.exception.InvalidArgumentException;
import com.morphiqlabs.wavelet.exception.InvalidConfigurationException;
import com.morphiqlabs.wavelet.util.SignalProcessor;
import java.util.Arrays;
/**
* Inverse Continuous Wavelet Transform for signal reconstruction.
*
* <p>This class implements the inverse CWT, allowing reconstruction of signals
* from their time-frequency representation. The reconstruction is based on the
* admissibility condition and uses the following formula:</p>
*
* <pre>
* x(t) = (1/C_ψ) ∫∫ W(a,b) ψ_{a,b}(t) da db / a²
* </pre>
*
* where:
* <ul>
* <li>W(a,b) are the CWT coefficients</li>
* <li>ψ_{a,b}(t) is the scaled and translated wavelet</li>
* <li>C_ψ is the admissibility constant</li>
* <li>a is the scale parameter</li>
* <li>b is the translation parameter</li>
* </ul>
*
* <p>Key features:</p>
* <ul>
* <li>Supports reconstruction from real CWT coefficients</li>
* <li>FFT acceleration for large-scale reconstructions</li>
* <li>Automatic admissibility constant calculation</li>
* <li>Progressive reconstruction with configurable frequency bands</li>
* </ul>
*
* <p><strong>Current Limitations:</strong></p>
* <ul>
* <li>Complex coefficient reconstruction is not yet implemented - only real
* coefficients are processed</li>
* <li>For complex wavelets, only the real part of coefficients is used in
* reconstruction</li>
* </ul>
*/
public final class InverseCWT {
// Magnitude below which a coefficient is treated as negligible and skipped during summation.
private static final double DEFAULT_TOLERANCE = 1e-10;
// NOTE(review): not referenced anywhere in this class — confirm whether it is still needed.
private static final int MIN_INTEGRATION_POINTS = 100;
// Wavelet passed to the constructor; evaluated when building reconstruction kernels.
private final ContinuousWavelet wavelet;
// Admissibility constant C_ψ, computed once in the constructor; every reconstructed sample is divided by it.
private final double admissibilityConstant;
// Enables the FFT-based path for signals of at least 128 samples; the constructor always sets it to true.
private final boolean useFFT;
/**
 * Creates an inverse CWT calculator for the given wavelet.
 *
 * <p>The admissibility constant C_ψ is computed once during construction and
 * validated: it must be a finite, strictly positive number because every
 * reconstructed sample is divided by it.</p>
 *
 * @param wavelet the continuous wavelet used in the forward transform
 * @throws InvalidArgumentException if wavelet is null
 * @throws InvalidConfigurationException if the computed admissibility constant
 *         is NaN, infinite, zero or negative
 */
public InverseCWT(ContinuousWavelet wavelet) {
    if (wavelet == null) {
        throw new InvalidArgumentException("Wavelet cannot be null");
    }
    this.wavelet = wavelet;
    this.admissibilityConstant = calculateAdmissibilityConstant(wavelet);
    // The negated comparison also rejects NaN: "NaN <= 0" is false, so a plain
    // "<= 0" test would let a NaN constant through and poison every output sample.
    if (!(admissibilityConstant > 0) || Double.isInfinite(admissibilityConstant)) {
        throw new InvalidConfigurationException(
            "Wavelet does not satisfy admissibility condition: C_ψ = " + admissibilityConstant);
    }
    // The FFT path is always enabled; it is only taken for sufficiently long signals.
    this.useFFT = true;
}
/**
 * Rebuilds a time-domain signal from the full set of scales in a CWT result.
 *
 * <p>The coefficients are read through {@code CWTResult.getCoefficients()} as a
 * real-valued {@code double[][]} and every scale of the result takes part in the
 * reconstruction. Imaginary parts are not consulted by this method.</p>
 *
 * @param cwtResult the CWT result containing scales, sample count and coefficients
 * @return reconstructed signal
 * @throws InvalidArgumentException if the result is null, has no scales, reports a
 *         non-positive sample count, or has no coefficients
 */
public double[] reconstruct(CWTResult cwtResult) {
    if (cwtResult == null) {
        throw new InvalidArgumentException("CWT result cannot be null");
    }

    final double[] scaleAxis = cwtResult.getScales();
    final boolean scalesMissing = (scaleAxis == null) || (scaleAxis.length == 0);
    if (scalesMissing) {
        throw new InvalidArgumentException("CWT result has no scales");
    }

    final int sampleCount = cwtResult.getNumSamples();
    if (sampleCount < 1) {
        throw new InvalidArgumentException("Invalid signal length: " + sampleCount);
    }

    final double[][] coefficientRows = cwtResult.getCoefficients();
    final boolean coefficientsMissing = (coefficientRows == null) || (coefficientRows.length == 0);
    if (coefficientsMissing) {
        throw new InvalidArgumentException("CWT result has no coefficients");
    }

    // Whole scale axis: indices [0, scaleAxis.length).
    return reconstructInternalReal(coefficientRows, scaleAxis, sampleCount, 0, scaleAxis.length);
}
/**
 * Reconstructs the signal using only the scales inside a given band.
 *
 * <p>The scale array is scanned front to back: the band starts at the first scale
 * that is {@code >= minScale} and ends just before the first scale that is
 * {@code > maxScale}, so a scale exactly equal to either bound is part of the band.
 * The scan assumes the scales are stored in ascending order; this is not checked
 * here. When no scale falls inside the band, an all-zero signal of the result's
 * length is returned.</p>
 *
 * <p>Only the real-valued coefficients returned by
 * {@code CWTResult.getCoefficients()} are used.</p>
 *
 * @param cwtResult the CWT result (real or complex)
 * @param minScale minimum scale (inclusive), must be positive
 * @param maxScale maximum scale (inclusive), must be greater than {@code minScale}
 * @return band-limited reconstructed signal
 * @throws InvalidArgumentException if the result is null, the scale range is invalid
 *         (including NaN bounds), or the result has no scales, a non-positive
 *         sample count, or no coefficients
 */
public double[] reconstructBand(CWTResult cwtResult, double minScale, double maxScale) {
    if (cwtResult == null) {
        throw new InvalidArgumentException("CWT result cannot be null");
    }
    // Negated comparisons so that NaN bounds are rejected instead of slipping through.
    if (!(minScale > 0) || !(maxScale > minScale)) {
        throw new InvalidArgumentException(
            "Invalid scale range: minScale=" + minScale + ", maxScale=" + maxScale);
    }

    double[] scales = cwtResult.getScales();
    if (scales == null || scales.length == 0) {
        throw new InvalidArgumentException("CWT result has no scales");
    }
    int signalLength = cwtResult.getNumSamples();
    if (signalLength <= 0) {
        throw new InvalidArgumentException("Invalid signal length: " + signalLength);
    }

    // Locate the half-open index range [startIdx, endIdx) covered by the band.
    int startIdx = -1;
    int endIdx = scales.length;
    for (int i = 0; i < scales.length; i++) {
        if (startIdx == -1 && scales[i] >= minScale) {
            startIdx = i;
        }
        if (scales[i] > maxScale) {
            endIdx = i;
            break;
        }
    }

    // startIdx == -1 means every scale lies below minScale. Falling back to index 0
    // here would silently reconstruct the entire scale range, so treat it as "empty".
    if (startIdx == -1 || startIdx >= endIdx) {
        return new double[signalLength];
    }

    double[][] realCoeffs = cwtResult.getCoefficients();
    if (realCoeffs == null || realCoeffs.length == 0) {
        throw new InvalidArgumentException("CWT result has no coefficients");
    }
    return reconstructInternalReal(realCoeffs, scales, signalLength, startIdx, endIdx);
}
/**
 * Reconstructs the signal restricted to a frequency band given in Hz.
 *
 * <p>Frequencies are converted to scales with
 * {@code scale = centerFrequency * samplingRate / frequency}, where
 * {@code centerFrequency} comes from the wavelet of this instance. Because the
 * relation is inverse, {@code minFreq} yields the upper scale bound and
 * {@code maxFreq} the lower one; a {@code minFreq} of 0 yields an unbounded upper
 * scale. The actual band selection and reconstruction are delegated to
 * {@link #reconstructBand(CWTResult, double, double)}, which includes scales equal
 * to either bound, so both frequency edges are effectively inclusive (up to
 * floating-point rounding in the conversion).</p>
 *
 * @param cwtResult the CWT result (real or complex)
 * @param samplingRate the sampling rate in Hz, must be positive
 * @param minFreq minimum frequency in Hz, must be non-negative
 * @param maxFreq maximum frequency in Hz, must be greater than {@code minFreq} and
 *        at most {@code samplingRate / 2}
 * @return frequency-band limited reconstructed signal
 * @throws InvalidArgumentException if the sampling rate or frequency range is
 *         invalid (including NaN values), or if {@code reconstructBand} rejects
 *         the result or the derived scale range
 */
public double[] reconstructFrequencyBand(CWTResult cwtResult, double samplingRate,
        double minFreq, double maxFreq) {
    // Each test is written as a negated "valid" condition so that NaN fails it;
    // with the plain "<"/"<=" forms a NaN argument would pass every check.
    if (!(samplingRate > 0)) {
        throw new InvalidArgumentException("Sampling rate must be positive");
    }
    if (!(minFreq >= 0) || !(maxFreq > minFreq) || !(maxFreq <= samplingRate / 2)) {
        throw new InvalidArgumentException(
            "Invalid frequency range: minFreq=" + minFreq + ", maxFreq=" + maxFreq);
    }

    // frequency = centerFreq * samplingRate / scale  =>  scale = centerFreq * samplingRate / frequency
    double centerFreq = wavelet.centerFrequency();
    double maxScale = centerFreq * samplingRate / minFreq; // +Infinity when minFreq == 0
    double minScale = centerFreq * samplingRate / maxFreq;
    return reconstructBand(cwtResult, minScale, maxScale);
}
/**
 * Dispatches a real-coefficient reconstruction to the FFT-based or the direct
 * implementation.
 *
 * <p>The FFT variant is chosen when FFT use is enabled and the signal has at
 * least 128 samples; otherwise the direct summation is used.</p>
 *
 * @param coefficients coefficient rows, indexed as {@code [scale][sample]}
 * @param scales scale values
 * @param signalLength number of samples to reconstruct
 * @param startScale first scale index (inclusive)
 * @param endScale last scale index (exclusive)
 * @return reconstructed signal
 */
private double[] reconstructInternalReal(double[][] coefficients, double[] scales,
        int signalLength, int startScale, int endScale) {
    final boolean takeFftPath = useFFT && signalLength >= 128;
    return takeFftPath
        ? reconstructInternalRealFFT(coefficients, scales, signalLength, startScale, endScale)
        : reconstructInternalRealDirect(coefficients, scales, signalLength, startScale, endScale);
}
/**
 * FFT-based reconstruction from real coefficients.
 *
 * <p>For every selected scale the coefficient row is zero-padded to a power-of-two
 * length, transformed, multiplied bin by bin with the transformed wavelet for that
 * scale, weighted, and accumulated. A single inverse transform of the accumulated
 * spectrum then yields the signal, whose real part is divided by the admissibility
 * constant.</p>
 *
 * @param coefficients coefficient rows, indexed as {@code [scale][sample]}
 * @param scales scale values
 * @param signalLength number of samples to reconstruct
 * @param startScale first scale index (inclusive)
 * @param endScale last scale index (exclusive)
 * @return reconstructed signal of length {@code signalLength}
 */
private double[] reconstructInternalRealFFT(double[][] coefficients, double[] scales,
        int signalLength, int startScale, int endScale) {
    final int paddedLength = nextPowerOfTwo(signalLength);

    // Every bin starts out referencing the same zero value; each bin is later
    // reassigned to the result of add(...).
    final ComplexNumber[] spectrumSum = new ComplexNumber[paddedLength];
    Arrays.fill(spectrumSum, new ComplexNumber(0, 0));

    final double[] logWeights = calculateLogScaleWeights(scales, startScale, endScale);

    int weightIndex = 0;
    for (int scaleIdx = startScale; scaleIdx < endScale; scaleIdx++, weightIndex++) {
        final double a = scales[scaleIdx];
        // d(log a) weight divided by the scale, i.e. the da / a² measure.
        final double factor = logWeights[weightIndex] / a;

        final ComplexNumber[] kernelSpectrum = createWaveletFFT(a, paddedLength);

        // Coefficient row as complex values, zero-padded beyond the signal length.
        final ComplexNumber[] rowSpectrum = new ComplexNumber[paddedLength];
        for (int k = 0; k < paddedLength; k++) {
            rowSpectrum[k] = (k < signalLength)
                ? new ComplexNumber(coefficients[scaleIdx][k], 0)
                : new ComplexNumber(0, 0);
        }
        SignalProcessor.fft(rowSpectrum);

        // Pointwise product in the frequency domain, weighted and accumulated.
        for (int k = 0; k < paddedLength; k++) {
            final ComplexNumber weighted = rowSpectrum[k].multiply(kernelSpectrum[k]).multiply(factor);
            spectrumSum[k] = spectrumSum[k].add(weighted);
        }
    }

    SignalProcessor.ifft(spectrumSum);

    // Keep only the first signalLength samples; take the real part and normalize.
    final double[] output = new double[signalLength];
    for (int k = 0; k < signalLength; k++) {
        output[k] = spectrumSum[k].real() / admissibilityConstant;
    }
    return output;
}
/**
 * Direct (non-FFT) reconstruction from real coefficients.
 *
 * <p>Each output sample is a triple sum: over the selected scales and over all
 * translation positions, of coefficient × kernel × log-scale weight / scale. The
 * total is then divided by the admissibility constant. Coefficients whose
 * magnitude is below {@code DEFAULT_TOLERANCE} are skipped.</p>
 *
 * @param coefficients coefficient rows, indexed as {@code [scale][sample]}
 * @param scales scale values
 * @param signalLength number of samples to reconstruct
 * @param startScale first scale index (inclusive)
 * @param endScale last scale index (exclusive)
 * @return reconstructed signal of length {@code signalLength}
 */
private double[] reconstructInternalRealDirect(double[][] coefficients, double[] scales,
        int signalLength, int startScale, int endScale) {
    final double[] output = new double[signalLength];
    final double[] logWeights = calculateLogScaleWeights(scales, startScale, endScale);

    for (int timeIdx = 0; timeIdx < signalLength; timeIdx++) {
        double acc = 0.0;

        for (int scaleIdx = startScale; scaleIdx < endScale; scaleIdx++) {
            final double a = scales[scaleIdx];
            final double w = logWeights[scaleIdx - startScale];

            for (int shift = 0; shift < signalLength; shift++) {
                final double c = coefficients[scaleIdx][shift];
                // Negated "<" (rather than ">=") so a NaN coefficient is still accumulated.
                if (!(Math.abs(c) < DEFAULT_TOLERANCE)) {
                    final double kernel = reconstructionKernel(timeIdx, shift, a, signalLength);
                    // W(a,b) · ψ_{a,b}(t) · d(log a) / a, since da / a² = d(log a) / a.
                    acc += c * kernel * w / a;
                }
            }
        }

        output[timeIdx] = acc / admissibilityConstant;
    }
    return output;
}
/**
 * Reconstruction from a complex coefficient matrix, using only the real parts.
 *
 * <p>This method is not called from anywhere in this class. For each sample it
 * sums, over the selected scales, real coefficient × kernel × trapezoidal scale
 * weight / scale, and divides by the admissibility constant. The imaginary part is
 * never read; the magnitude is used only to skip negligible coefficients.</p>
 *
 * <p>NOTE(review): the kernel is evaluated with the translation equal to the time
 * index, so the wavelet is always evaluated at argument 0 — confirm this is the
 * intended approximation before wiring this method in.</p>
 *
 * @param coefficients complex coefficients addressed as {@code (scale, sample)}
 * @param scales scale values
 * @param signalLength number of samples to reconstruct
 * @param startScale first scale index (inclusive)
 * @param endScale last scale index (exclusive)
 * @return reconstructed signal of length {@code signalLength}
 */
private double[] reconstructInternal(ComplexMatrix coefficients, double[] scales,
        int signalLength, int startScale, int endScale) {
    double[] reconstructed = new double[signalLength];
    // Trapezoidal weights in linear scale (da).
    double[] weights = calculateIntegrationWeights(scales, startScale, endScale);

    for (int t = 0; t < signalLength; t++) {
        double sum = 0.0;
        for (int s = startScale; s < endScale; s++) {
            // Skip negligible coefficients before doing any further work.
            if (coefficients.getMagnitude(s, t) < DEFAULT_TOLERANCE) {
                continue;
            }
            double scale = scales[s];
            double kernelValue = reconstructionKernel(t, t, scale, signalLength);
            sum += coefficients.getReal(s, t) * kernelValue * weights[s - startScale] / scale;
        }
        // Normalize by the admissibility constant.
        reconstructed[t] = sum / admissibilityConstant;
    }
    return reconstructed;
}
/**
 * Evaluates the scaled, translated wavelet ψ((t − b) / a) / √a.
 *
 * <p>When the wavelet is a {@code ComplexContinuousWavelet}, {@code psi} is invoked
 * through that type; otherwise through {@code ContinuousWavelet}. The
 * {@code signalLength} argument is not used by this method.</p>
 *
 * @param t time index
 * @param b translation index
 * @param scale scale a
 * @param signalLength signal length (unused)
 * @return kernel value at {@code t} for translation {@code b} and the given scale
 */
private double reconstructionKernel(int t, int b, double scale, int signalLength) {
    final double u = (t - b) / scale;
    final double amplitude = 1.0 / Math.sqrt(scale);

    if (!(wavelet instanceof ComplexContinuousWavelet complexWavelet)) {
        return amplitude * wavelet.psi(u);
    }
    return amplitude * complexWavelet.psi(u);
}
/**
 * Builds trapezoidal-rule weights for the scales in {@code [start, end)}.
 *
 * <p>Each interval between neighbouring scales contributes half of its width to
 * both of its end points. A range holding exactly one scale gets the single
 * weight 1.0.</p>
 *
 * @param scales scale values
 * @param start first index (inclusive)
 * @param end last index (exclusive)
 * @return one weight per scale in the range
 */
private double[] calculateIntegrationWeights(double[] scales, int start, int end) {
    final int count = end - start;
    final double[] w = new double[count];
    if (count == 1) {
        w[0] = 1.0;
        return w;
    }

    int interval = 0;
    while (interval < count - 1) {
        final int lo = start + interval;
        final double half = (scales[lo + 1] - scales[lo]) / 2.0;
        // Left end point: the first one is assigned, later ones already hold the
        // right-hand half of the previous interval and accumulate.
        w[interval] = (interval == 0) ? half : w[interval] + half;
        // Right end point of this interval.
        w[interval + 1] = half;
        interval++;
    }
    return w;
}
/**
 * Builds trapezoidal weights in log-scale, i.e. for integrating against d(log a).
 *
 * <p>The two end points get half the log-ratio to their single neighbour; interior
 * points get half the log-ratio between their two neighbours. A range holding
 * exactly one scale gets that scale value as its weight.</p>
 *
 * @param scales the array of scale values
 * @param start start index (inclusive)
 * @param end end index (exclusive)
 * @return integration weights, or an empty array if {@code start >= end}
 * @throws InvalidArgumentException if the indices fall outside the array or a scale
 *         in the range is not positive
 */
private double[] calculateLogScaleWeights(double[] scales, int start, int end) {
    final int count = end - start;
    if (count <= 0) {
        // Nothing to integrate over.
        return new double[0];
    }
    if (!(start >= 0 && end <= scales.length)) {
        throw new InvalidArgumentException("Scale indices out of bounds: start=" + start +
            ", end=" + end + ", scales.length=" + scales.length);
    }

    // log(a) requires strictly positive scales.
    for (int idx = start; idx < end; idx++) {
        if (scales[idx] <= 0) {
            throw new InvalidArgumentException("Scale values must be positive for logarithmic integration. " +
                "Found non-positive scale at index " + idx + ": " + scales[idx]);
        }
    }

    final double[] w = new double[count];
    if (count == 1) {
        w[0] = scales[start];
        return w;
    }

    final int last = count - 1;
    // End points: half of the adjacent log-interval.
    w[0] = Math.log(scales[start + 1] / scales[start]) / 2.0;
    w[last] = Math.log(scales[start + last] / scales[start + last - 1]) / 2.0;
    // Interior points: half of the log-span between both neighbours.
    for (int k = 1; k < last; k++) {
        w[k] = Math.log(scales[start + k + 1] / scales[start + k - 1]) / 2.0;
    }
    return w;
}
/**
 * Determines the admissibility constant C_ψ = ∫ |Ψ(ω)|² / |ω| dω for a wavelet,
 * where Ψ(ω) is the Fourier transform of the wavelet.
 *
 * <p>The lower-cased wavelet name is matched against a few known substrings, each
 * mapped to a value hard-coded here. The checks run in order, so the more specific
 * "mexh"/"dog2" entry is tested before the generic "dog" entry. Names that match
 * nothing fall back to numerical integration.</p>
 *
 * @param wavelet the wavelet whose constant is required
 * @return the hard-coded or numerically estimated admissibility constant
 */
private double calculateAdmissibilityConstant(ContinuousWavelet wavelet) {
    final String key = wavelet.name().toLowerCase();

    // "morl" also covers names containing "morlet".
    if (key.contains("morl")) {
        return Math.PI;
    }
    // Mexican Hat (DOG2).
    if (key.contains("mexh") || key.contains("dog2")) {
        return Math.PI / Math.sqrt(2.0);
    }
    // Remaining DOG wavelets: approximate value.
    if (key.contains("dog")) {
        return Math.PI;
    }
    if (key.contains("paul")) {
        return 2.0 * Math.PI;
    }
    if (key.contains("shannon")) {
        return Math.PI;
    }

    // Unknown wavelet: estimate numerically.
    return calculateAdmissibilityNumerical(wavelet);
}
/**
 * Numerically estimates the admissibility constant of a wavelet.
 *
 * <p>The integrand |Ψ(ω)|² / ω is sampled on a logarithmically spaced grid of
 * 10000 frequencies running from {@code maxFreq * 1e-4} to {@code maxFreq}
 * (0.01 to 100) and integrated with the composite trapezoidal rule. The result
 * is multiplied by 2π.</p>
 *
 * @param wavelet the wavelet to analyse
 * @return the estimated admissibility constant
 */
private double calculateAdmissibilityNumerical(ContinuousWavelet wavelet) {
    final int nPoints = 10000;
    final double maxFreq = 100.0; // Reasonable upper limit

    double sum = 0.0;
    double prevOmega = 0.0;
    double prevIntegrand = 0.0;

    // The grid starts at i = 0 so the lowest frequency is maxFreq * 1e-4 rather
    // than an unset 0.0, and every interval — including the first and the last —
    // contributes the average of its two end-point values times its width.
    for (int i = 0; i < nPoints; i++) {
        // Logarithmic spacing copes with the 1/ω factor near the origin.
        double omega = maxFreq * Math.pow(10.0, -4.0 + 4.0 * i / (nPoints - 1));
        double psiHat = waveletFourierTransform(wavelet, omega);
        double integrand = psiHat * psiHat / omega;
        if (i > 0) {
            sum += 0.5 * (prevIntegrand + integrand) * (omega - prevOmega);
        }
        prevOmega = omega;
        prevIntegrand = integrand;
    }
    return 2.0 * Math.PI * sum;
}
/**
 * Approximates the magnitude of the wavelet's Fourier transform at frequency ω.
 *
 * <p>The integral ∫ ψ(t) e^(−iωt) dt is replaced by a sum over 1000 equally spaced
 * samples starting at t = −20 with step 0.04; real and imaginary parts are
 * accumulated separately and combined into a magnitude.</p>
 *
 * @param wavelet the wavelet to transform
 * @param omega the angular frequency at which to evaluate
 * @return approximate |Ψ(ω)|
 */
private double waveletFourierTransform(ContinuousWavelet wavelet, double omega) {
    final int steps = 1000;
    final double halfWidth = 20.0;
    final double step = 2.0 * halfWidth / steps;

    double re = 0.0;
    double im = 0.0;
    int k = 0;
    while (k < steps) {
        final double t = -halfWidth + k * step;
        final double value = wavelet.psi(t);
        final double phase = -omega * t;
        re += value * Math.cos(phase) * step;
        im += value * Math.sin(phase) * step;
        k++;
    }
    return Math.sqrt(re * re + im * im);
}
/**
 * Returns the admissibility constant that was computed when this instance was
 * constructed.
 *
 * @return the admissibility constant C_ψ
 */
public double getAdmissibilityConstant() {
    return this.admissibilityConstant;
}
/**
 * Reports whether the stored admissibility constant is a finite, strictly
 * positive number.
 *
 * @return true if admissible
 */
public boolean isAdmissible() {
    final double c = this.admissibilityConstant;
    return Double.isFinite(c) && c > 0;
}
/**
 * Builds the frequency-domain representation of the wavelet at one scale.
 *
 * <p>The wavelet is sampled as ψ(n / a) / √a with t = 0 at index 0: indices up to
 * and including {@code fftSize / 2} are non-negative offsets, the remaining indices
 * wrap around to negative offsets. The samples are then transformed with
 * {@code SignalProcessor.fft}.</p>
 *
 * @param scale scale a
 * @param fftSize transform length
 * @return the array passed through {@code SignalProcessor.fft}
 */
private ComplexNumber[] createWaveletFFT(double scale, int fftSize) {
    final ComplexNumber[] spectrum = new ComplexNumber[fftSize];
    final int lastNonNegative = fftSize / 2;

    for (int idx = 0; idx < fftSize; idx++) {
        // Wrap the upper part of the index range onto negative sample offsets.
        final int offset = (idx <= lastNonNegative) ? idx : idx - fftSize;
        final double sample = wavelet.psi(offset / scale) / Math.sqrt(scale);
        spectrum[idx] = new ComplexNumber(sample, 0);
    }

    SignalProcessor.fft(spectrum);
    return spectrum;
}
/**
 * Finds the smallest power of two that is greater than or equal to {@code n}.
 *
 * <p>Values of {@code n} at or below 1 (including negative values) yield 1.</p>
 *
 * @param n the lower bound for the result
 * @return the smallest power of two {@code >= n}
 * @throws IllegalArgumentException if {@code n} exceeds 2^30, because the next
 *         power of two would not fit in an {@code int}
 */
private static int nextPowerOfTwo(int n) {
    if (n <= 1) {
        return 1;
    }
    // 2^30 is the largest power of two an int can hold; beyond it the result
    // would wrap around to Integer.MIN_VALUE.
    if (n > (1 << 30)) {
        throw new IllegalArgumentException("No power of two >= " + n + " fits in an int");
    }
    // The highest set bit of n - 1, doubled, is the next power of two at or above n.
    return Integer.highestOneBit(n - 1) << 1;
}
}