ParallelMultiLevelMODWT.java
package com.morphiqlabs.wavelet.modwt;
import com.morphiqlabs.wavelet.api.BoundaryMode;
import com.morphiqlabs.wavelet.api.DiscreteWavelet;
import com.morphiqlabs.wavelet.api.Wavelet;
import com.morphiqlabs.wavelet.exception.InvalidArgumentException;
import com.morphiqlabs.wavelet.exception.InvalidSignalException;
import com.morphiqlabs.wavelet.WaveletOperations;
import com.morphiqlabs.wavelet.util.ValidationUtils;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ForkJoinPool;
/**
 * Parallel implementation of multi-level MODWT with CompletableFuture chains.
 *
 * <p>This class provides an optimized parallel version of multi-level MODWT that:</p>
 * <ul>
 * <li>Uses CompletableFuture chains to handle level dependencies</li>
 * <li>Parallelizes low-pass and high-pass filtering at each level</li>
 * <li>Pre-allocates memory to avoid contention</li>
 * <li>Properly handles filter upsampling at each level</li>
 * </ul>
 *
 * <p>Each call to {@link #decompose} allocates its own working buffers; the only state shared
 * between calls is the per-instance filter cache, which is backed by a {@code ConcurrentHashMap}.
 * Instances that create their own pool (see {@link #ParallelMultiLevelMODWT(int)} and
 * {@link Builder#parallelism(int)}) shut it down in {@link #close()}.</p>
 *
 * @since 1.0.0
 */
public class ParallelMultiLevelMODWT implements AutoCloseable {
    private static final int MAX_DECOMPOSITION_LEVELS = 10;
    private static final int MAX_SAFE_SHIFT_BITS = 31;
    private final Executor executor;
    private final boolean ownsExecutor;
    private final int minParallelSignalLength;
    // Per-instance cache of upsampled/scaled analysis filters, keyed by wavelet (via its
    // equals/hashCode) and then by decomposition level.
    private final java.util.concurrent.ConcurrentHashMap<Wavelet, java.util.concurrent.ConcurrentHashMap<Integer, FilterSet>> filterCache = new java.util.concurrent.ConcurrentHashMap<>();

    /**
     * Creates a parallel multi-level MODWT using the common ForkJoinPool.
     */
    public ParallelMultiLevelMODWT() {
        this(ForkJoinPool.commonPool(), false, 0);
    }

    /**
     * Creates a parallel multi-level MODWT with a custom executor.
     * The executor is not shut down by {@link #close()}; its lifecycle stays with the caller.
     *
     * @param executor The executor to use for parallel tasks
     * @throws NullPointerException if {@code executor} is null
     */
    public ParallelMultiLevelMODWT(Executor executor) {
        this(executor, false, 0);
    }

    /**
     * Creates a parallel multi-level MODWT with a new ForkJoinPool of given parallelism.
     * If {@code parallelism <= 0}, falls back to the common pool. A pool created here is
     * owned by this instance and is shut down by {@link #close()}.
     *
     * @param parallelism desired parallelism level
     */
    public ParallelMultiLevelMODWT(int parallelism) {
        this(parallelism > 0 ? new ForkJoinPool(parallelism) : ForkJoinPool.commonPool(),
             parallelism > 0, 0);
    }

    private ParallelMultiLevelMODWT(Executor executor, boolean ownsExecutor, int minParallelSignalLength) {
        // Enforce the NullPointerException contract documented on the public constructor
        // instead of failing later inside CompletableFuture.runAsync.
        this.executor = java.util.Objects.requireNonNull(executor, "executor");
        this.ownsExecutor = ownsExecutor;
        this.minParallelSignalLength = Math.max(0, minParallelSignalLength);
    }

    /**
     * Performs parallel multi-level MODWT decomposition.
     *
     * <p>Any {@code mode} other than {@link BoundaryMode#PERIODIC} is handled with
     * zero-padding convolution. Signals shorter than the configured minimum parallel length
     * are processed sequentially on the calling thread. Runtime exceptions raised by the
     * filtering routines propagate unwrapped from both the parallel and sequential paths.</p>
     *
     * @param signal Input signal
     * @param wavelet Wavelet to use
     * @param mode Boundary mode
     * @param levels Number of decomposition levels
     * @return Multi-level MODWT result
     * @throws com.morphiqlabs.wavelet.exception.InvalidSignalException if {@code signal} is invalid
     * @throws com.morphiqlabs.wavelet.exception.InvalidArgumentException if {@code wavelet} is not discrete or {@code levels} is invalid
     */
    public MultiLevelMODWTResult decompose(double[] signal, Wavelet wavelet,
                                           BoundaryMode mode, int levels) {
        // Validate inputs
        ValidationUtils.validateFiniteValues(signal, "signal");
        if (signal.length == 0) {
            throw new InvalidSignalException("Signal cannot be empty");
        }
        if (!(wavelet instanceof DiscreteWavelet dw)) {
            throw new InvalidArgumentException("Multi-level MODWT requires a discrete wavelet");
        }
        int maxLevels = calculateMaxLevels(signal.length, dw);
        if (levels < 1 || levels > maxLevels) {
            throw new InvalidArgumentException(
                "Invalid number of levels: " + levels +
                ". Must be between 1 and " + maxLevels);
        }

        MultiLevelMODWTResultImpl result = new MultiLevelMODWTResultImpl(signal.length, levels);

        // Working buffers: currentApprox is updated in place level by level; nextApprox is
        // scratch space for the low-pass output of the level being computed.
        double[][] detailArrays = new double[levels][signal.length];
        double[] currentApprox = signal.clone();
        double[] nextApprox = new double[signal.length];
        FilterSet[] filterSets = precomputeFilterSets(dw, levels);

        if (signal.length >= minParallelSignalLength) {
            runParallel(currentApprox, nextApprox, detailArrays, filterSets, mode);
        } else {
            // Sequential path for small signals to avoid async overhead
            runSequential(currentApprox, nextApprox, detailArrays, filterSets, mode);
        }

        for (int level = 1; level <= levels; level++) {
            result.setDetailCoeffsAtLevel(level, detailArrays[level - 1]);
        }
        result.setApproximationCoeffs(currentApprox);
        return result;
    }

    /**
     * Runs the level cascade on the executor. Levels are chained because each depends on the
     * previous approximation; within a level the low-pass and high-pass filters run concurrently.
     */
    private void runParallel(double[] approx, double[] scratch, double[][] details,
                             FilterSet[] filterSets, BoundaryMode mode) {
        CompletableFuture<Void> chain = CompletableFuture.completedFuture(null);
        for (int li = 0; li < filterSets.length; li++) {
            final FilterSet filters = filterSets[li];
            final double[] detail = details[li];
            chain = chain.thenCompose(ignored -> {
                CompletableFuture<Void> low = CompletableFuture.runAsync(
                        () -> convolve(approx, filters.scaledLowPass(), scratch, mode), executor);
                CompletableFuture<Void> high = CompletableFuture.runAsync(
                        () -> convolve(approx, filters.scaledHighPass(), detail, mode), executor);
                return CompletableFuture.allOf(low, high)
                        .thenRun(() -> advanceLevel(approx, scratch));
            });
        }
        try {
            chain.join();
        } catch (java.util.concurrent.CompletionException e) {
            // Surface the same exception types as the sequential path instead of the wrapper.
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException re) {
                throw re;
            }
            if (cause instanceof Error err) {
                throw err;
            }
            throw e;
        }
    }

    /**
     * Runs the level cascade on the calling thread (low-pass first, then high-pass, per level).
     */
    private void runSequential(double[] approx, double[] scratch, double[][] details,
                               FilterSet[] filterSets, BoundaryMode mode) {
        for (int li = 0; li < filterSets.length; li++) {
            FilterSet filters = filterSets[li];
            convolve(approx, filters.scaledLowPass(), scratch, mode);
            convolve(approx, filters.scaledHighPass(), details[li], mode);
            advanceLevel(approx, scratch);
        }
    }

    /**
     * Moves to the next level: copies the freshly computed approximation from {@code scratch}
     * into {@code approx}, then clears {@code scratch} for reuse.
     */
    private static void advanceLevel(double[] approx, double[] scratch) {
        System.arraycopy(scratch, 0, approx, 0, approx.length);
        java.util.Arrays.fill(scratch, 0.0);
    }

    /**
     * Applies one MODWT filter, choosing circular convolution for PERIODIC mode and
     * zero-padding convolution for every other mode.
     */
    private void convolve(double[] input, double[] filter, double[] output, BoundaryMode mode) {
        if (mode == BoundaryMode.PERIODIC) {
            WaveletOperations.circularConvolveMODWT(input, filter, output);
        } else {
            applyZeroPaddingMODWT(input, filter, output);
        }
    }

    /**
     * Pre-computes (or fetches from the per-instance cache) the filter sets for levels
     * {@code 1..levels}; element {@code i} holds the filters for level {@code i + 1}.
     */
    private FilterSet[] precomputeFilterSets(DiscreteWavelet wavelet, int levels) {
        double[] lowPass = wavelet.lowPassDecomposition();
        double[] highPass = wavelet.highPassDecomposition();
        // The per-wavelet map does not depend on the level, so resolve it once.
        var perLevel = filterCache.computeIfAbsent(wavelet, w -> new java.util.concurrent.ConcurrentHashMap<>());
        FilterSet[] filterSets = new FilterSet[levels];
        for (int level = 1; level <= levels; level++) {
            filterSets[level - 1] = perLevel.computeIfAbsent(level,
                    l -> upsampleFiltersForLevel(lowPass, highPass, l));
        }
        return filterSets;
    }

    /**
     * Upsample and scale filters per analysis stage (1/√2 per level) for parallel path.
     * NOTE(review): delegates to ScalarOps.upsampleAndScaleForIMODWTSynthesis; the name
     * suggests synthesis, so confirm it yields the intended analysis filters.
     */
    private FilterSet upsampleFiltersForLevel(double[] lowFilter, double[] highFilter, int level) {
        double[] scaledLow = com.morphiqlabs.wavelet.internal.ScalarOps
            .upsampleAndScaleForIMODWTSynthesis(lowFilter, level);
        double[] scaledHigh = com.morphiqlabs.wavelet.internal.ScalarOps
            .upsampleAndScaleForIMODWTSynthesis(highFilter, level);
        return new FilterSet(scaledLow, scaledHigh);
    }

    /**
     * Applies MODWT with zero-padding boundary handling.
     *
     * <p>This uses the dedicated zero-padding convolution routine which mirrors
     * the behavior of the sequential {@code MultiLevelMODWTTransform}.</p>
     */
    private void applyZeroPaddingMODWT(double[] input, double[] filter, double[] output) {
        WaveletOperations.zeroPaddingConvolveMODWT(input, filter, output);
    }

    /**
     * Calculates maximum decomposition levels for the signal: the number of consecutive
     * levels j (starting at 1) whose upsampled filter length (L - 1) * 2^(j-1) + 1 does not
     * exceed the signal length. Returns 0 when the signal is shorter than the filter.
     *
     * <p>NOTE(review): the loop stops once {@code maxLevel} reaches
     * {@code MAX_DECOMPOSITION_LEVELS}, so the result is capped at
     * {@code MAX_DECOMPOSITION_LEVELS - 1}; confirm this matches the sequential transform.</p>
     */
    private int calculateMaxLevels(int signalLength, DiscreteWavelet wavelet) {
        int filterLength = wavelet.lowPassDecomposition().length;
        if (signalLength < filterLength) {
            return 0;
        }
        int maxLevel = 1;
        int filterLengthMinus1 = filterLength - 1;
        while (maxLevel < MAX_DECOMPOSITION_LEVELS) {
            // Defensive guard against shifting past the int range.
            if (maxLevel - 1 >= MAX_SAFE_SHIFT_BITS) {
                break;
            }
            try {
                long scaledFilterLength = Math.addExact(
                    Math.multiplyExact((long) filterLengthMinus1, 1L << (maxLevel - 1)),
                    1L
                );
                if (scaledFilterLength > signalLength) {
                    break;
                }
            } catch (ArithmeticException e) {
                // Overflow means the scaled filter is certainly longer than the signal.
                break;
            }
            maxLevel++;
        }
        return maxLevel - 1;
    }

    /**
     * Holds a pair of scaled filters for a specific level.
     */
    private record FilterSet(double[] scaledLowPass, double[] scaledHighPass) {}

    /**
     * Releases resources if this instance owns its executor. Only pools created by this
     * class (positive parallelism) are shut down; caller-supplied executors and the common
     * pool are left running.
     */
    @Override
    public void close() {
        if (ownsExecutor && executor instanceof java.util.concurrent.ExecutorService es) {
            es.shutdown();
        }
    }

    /**
     * Builder for configuring {@link ParallelMultiLevelMODWT}.
     */
    public static final class Builder {
        private Executor executor;
        private Integer parallelism;
        private int minParallelSignalLength;

        /**
         * Creates a new Builder.
         */
        public Builder() {
            // Default constructor
        }

        /**
         * Use a specific executor for parallel tasks. When set (non-null), it takes precedence
         * over {@link #parallelism(int)} and is not shut down by {@link ParallelMultiLevelMODWT#close()}.
         *
         * @param executor an {@link Executor} to run parallel tasks
         * @return this builder
         */
        public Builder executor(Executor executor) {
            this.executor = executor;
            return this;
        }

        /**
         * Create an internal ForkJoinPool with given parallelism. Ignored if executor is set.
         * Values {@code <= 0} (or leaving this unset) select the common pool instead.
         *
         * @param parallelism desired parallelism level
         * @return this builder
         */
        public Builder parallelism(int parallelism) {
            this.parallelism = parallelism;
            return this;
        }

        /**
         * Minimum signal length required to use parallel execution. Smaller inputs run sequentially.
         * Default is 0 (always parallel as before). Negative values are treated as 0.
         *
         * @param minParallelSignalLength minimum length to trigger parallel path
         * @return this builder
         */
        public Builder minParallelSignalLength(int minParallelSignalLength) {
            this.minParallelSignalLength = Math.max(0, minParallelSignalLength);
            return this;
        }

        /**
         * Builds a configured {@link ParallelMultiLevelMODWT} instance. No parameter is
         * required: without an executor or a positive parallelism, the common pool is used.
         *
         * @return a new ParallelMultiLevelMODWT
         */
        public ParallelMultiLevelMODWT build() {
            if (executor != null) {
                return new ParallelMultiLevelMODWT(executor, false, minParallelSignalLength);
            }
            int par = parallelism != null ? parallelism : 0;
            boolean owns = par > 0;
            Executor exec = owns ? new ForkJoinPool(par) : ForkJoinPool.commonPool();
            return new ParallelMultiLevelMODWT(exec, owns, minParallelSignalLength);
        }
    }
}