LCOV - code coverage report
Current view: top level - src/audio/dsp - convolution_engine.cpp (source / functions) Coverage Total Hit
Test: merged.info Lines: 99.4 % 171 170
Test Date: 2026-06-07 15:51:50 Functions: 100.0 % 12 12

            Line data    Source code
       1              : #include "audio/dsp/convolution_engine.h"
       2              : 
       3              : #include <algorithm>
       4              : #include <cmath>
       5              : #include <cstring>
       6              : 
       7              : #include "kiss_fft.h"
       8              : 
       9              : namespace Amplitron {
      10              : 
      11              : // Helper: complex multiply-accumulate (accum += a * b)
      12           12 : static void complex_multiply_accumulate(kiss_fft_cpx* accum, const kiss_fft_cpx* a,
      13              :                                         const kiss_fft_cpx* b, int n) {
      14         6156 :     for (int i = 0; i < n; ++i) {
      15         6144 :         accum[i].r += a[i].r * b[i].r - a[i].i * b[i].i;
      16         6144 :         accum[i].i += a[i].r * b[i].i + a[i].i * b[i].r;
      17         2048 :     }
      18           12 : }
      19              : 
      20              : // =============================================================================
      21              : // ConvolutionKernel
      22              : // =============================================================================
      23              : 
      24          104 : ConvolutionKernel::ConvolutionKernel(const std::vector<float>& ir_samples, int block_size)
      25           52 :     : block_size_(block_size),
      26           52 :       fft_size_(block_size * 2),
      27           78 :       ir_length_(static_cast<int>(ir_samples.size())),
      28           78 :       ir_time_(ir_samples) {
      29           78 :     if (ir_samples.empty() || block_size <= 0) {
      30            3 :         num_partitions_ = 0;
      31            3 :         return;
      32              :     }
      33              : 
      34              :     // Number of partitions = ceil(ir_length / block_size)
      35           75 :     num_partitions_ = (ir_length_ + block_size_ - 1) / block_size_;
      36              : 
      37              :     // Allocate forward FFT config
      38           75 :     fft_cfg_ = kiss_fft_alloc(fft_size_, 0, nullptr, nullptr);
      39              : 
      40              :     // Pre-compute frequency-domain representation of each partition
      41          100 :     std::vector<kiss_fft_cpx> time_buf(static_cast<size_t>(fft_size_));
      42           75 :     std::vector<kiss_fft_cpx> freq_buf(static_cast<size_t>(fft_size_));
      43              : 
      44           75 :     partitions_freq_.resize(static_cast<size_t>(num_partitions_));
      45              : 
      46          162 :     for (int p = 0; p < num_partitions_; ++p) {
      47              :         // Zero the time buffer
      48           87 :         std::fill(time_buf.begin(), time_buf.end(), kiss_fft_cpx{0.0f, 0.0f});
      49              : 
      50              :         // Copy this partition's IR samples into the real part
      51           87 :         int offset = p * block_size_;
      52           87 :         int count = std::min(block_size_, ir_length_ - offset);
      53         4830 :         for (int i = 0; i < count; ++i) {
      54         4743 :             time_buf[static_cast<size_t>(i)].r = ir_samples[static_cast<size_t>(offset + i)];
      55         4743 :             time_buf[static_cast<size_t>(i)].i = 0.0f;
      56         1581 :         }
      57              : 
      58              :         // Forward FFT
      59           87 :         kiss_fft(fft_cfg_, time_buf.data(), freq_buf.data());
      60              : 
      61              :         // Store the frequency-domain partition
      62           87 :         size_t byte_size = sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size_);
      63           87 :         partitions_freq_[static_cast<size_t>(p)].resize(byte_size);
      64           87 :         std::memcpy(partitions_freq_[static_cast<size_t>(p)].data(), freq_buf.data(), byte_size);
      65           29 :     }
      66           76 : }
      67              : 
      68          104 : ConvolutionKernel::~ConvolutionKernel() {
      69           78 :     if (fft_cfg_) kiss_fft_free(fft_cfg_);
      70          104 : }
      71              : 
      72           21 : const void* ConvolutionKernel::partition_freq(int index) const {
      73           17 :     if (index < 0 || index >= num_partitions_) return nullptr;
      74           15 :     return partitions_freq_[static_cast<size_t>(index)].data();
      75            7 : }
      76              : 
      77              : // =============================================================================
      78              : // ConvolutionEngine
      79              : // =============================================================================
      80              : 
      81          180 : ConvolutionEngine::ConvolutionEngine() = default;
      82              : 
      83          180 : ConvolutionEngine::~ConvolutionEngine() { cleanup_fft(); }
      84              : 
      85           33 : void ConvolutionEngine::init_fft(int fft_size) {
      86           33 :     cleanup_fft();
      87           33 :     fft_cfg_ = kiss_fft_alloc(fft_size, 0, nullptr, nullptr);   // forward
      88           33 :     ifft_cfg_ = kiss_fft_alloc(fft_size, 1, nullptr, nullptr);  // inverse
      89           33 :     current_fft_size_ = fft_size;
      90           33 : }
      91              : 
      92          210 : void ConvolutionEngine::cleanup_fft() {
      93          210 :     if (fft_cfg_) {
      94           33 :         kiss_fft_free(fft_cfg_);
      95           33 :         fft_cfg_ = nullptr;
      96           11 :     }
      97          210 :     if (ifft_cfg_) {
      98           33 :         kiss_fft_free(ifft_cfg_);
      99           33 :         ifft_cfg_ = nullptr;
     100           11 :     }
     101          210 :     current_fft_size_ = 0;
     102          210 : }
     103              : 
     104           36 : void ConvolutionEngine::set_kernel(const ConvolutionKernel* kernel) {
     105           36 :     kernel_ = kernel;
     106           36 :     reset();
     107           36 : }
     108              : 
     109           78 : void ConvolutionEngine::reset() {
     110           78 :     if (!kernel_) {
     111           42 :         cleanup_fft();
     112           42 :         fdl_.clear();
     113           42 :         overlap_.clear();
     114           42 :         direct_input_.clear();
     115           42 :         direct_overlap_.clear();
     116           42 :         input_cpx_.clear();
     117           42 :         accum_cpx_.clear();
     118           42 :         ifft_out_cpx_.clear();
     119           42 :         fdl_index_ = 0;
     120           42 :         return;
     121              :     }
     122              : 
     123           36 :     int fft_size = kernel_->fft_size();
     124           36 :     int num_parts = kernel_->num_partitions();
     125              : 
     126              :     // Initialize FFT if size changed
     127           36 :     if (current_fft_size_ != fft_size) {
     128           33 :         init_fft(fft_size);
     129           11 :     }
     130              : 
     131              :     // Initialize frequency-domain delay line
     132           36 :     if (fft_size <= 0 || fft_size > 65536) return;  // Sanity check
     133              : 
     134           36 :     size_t cpx_bytes = sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size);
     135           36 :     fdl_.resize(static_cast<size_t>(num_parts));
     136           81 :     for (auto& buf : fdl_) {
     137           45 :         buf.assign(cpx_bytes, 0);
     138              :     }
     139           36 :     fdl_index_ = 0;
     140              : 
     141              :     // Initialize overlap buffer
     142           36 :     overlap_.assign(static_cast<size_t>(kernel_->block_size()), 0.0f);
     143              : 
     144              :     // Initialize direct convolution overlap
     145           36 :     int ir_len = kernel_->ir_length();
     146           36 :     if (ir_len > 0) {
     147           36 :         direct_overlap_.assign(static_cast<size_t>(ir_len - 1), 0.0f);
     148           12 :     } else {
     149            0 :         direct_overlap_.clear();
     150              :     }
     151              : 
     152              :     // Scratch input copy for direct convolution fallback
     153           36 :     direct_input_.assign(static_cast<size_t>(kernel_->block_size()), 0.0f);
     154              : 
     155              :     // Initialize FFT workspace buffers (allocation-free during process())
     156           36 :     input_cpx_.assign(cpx_bytes, 0);
     157           36 :     accum_cpx_.assign(cpx_bytes, 0);
     158           36 :     ifft_out_cpx_.assign(cpx_bytes, 0);
     159           26 : }
     160              : 
     161         3033 : void ConvolutionEngine::process_direct(float* buffer, int num_samples) {
     162         3033 :     const auto& ir = kernel_->ir_time_domain();
     163         3033 :     int ir_len = static_cast<int>(ir.size());
     164         3033 :     if (ir_len == 0) return;
     165              : 
     166              :     // Output length = num_samples + ir_len - 1.
     167              :     // We output num_samples and carry over the tail (overlap-add) in direct_overlap_.
     168         3033 :     const int tail_len = ir_len - 1;
     169         3033 :     if (tail_len <= 0) return;
     170              : 
     171              :     // direct_overlap_ should be pre-sized in reset(); avoid allocations here.
     172         3027 :     if (static_cast<int>(direct_overlap_.size()) != tail_len) return;
     173              : 
     174              :     // Need original input; avoid per-call allocations by using direct_input_ scratch.
     175         3027 :     if (static_cast<int>(direct_input_.size()) < num_samples) return;
     176         3027 :     std::memcpy(direct_input_.data(), buffer, sizeof(float) * static_cast<size_t>(num_samples));
     177              : 
     178              :     // First num_samples samples: convolution + previous overlap
     179       776883 :     for (int n = 0; n < num_samples; ++n) {
     180       773856 :         float y = (n < tail_len) ? direct_overlap_[static_cast<size_t>(n)] : 0.0f;
     181              : 
     182              :         // y += sum_{j=0}^{ir_len-1} x[n-j] * h[j]
     183       773856 :         int j0 = std::max(0, n - (num_samples - 1));
     184       773856 :         int j1 = std::min(ir_len - 1, n);
     185      3087390 :         for (int j = j0; j <= j1; ++j) {
     186      2313534 :             y += direct_input_[static_cast<size_t>(n - j)] * ir[static_cast<size_t>(j)];
     187       771178 :         }
     188       773856 :         buffer[n] = y;
     189       257952 :     }
     190              : 
     191              :     // Tail samples for next block: y[num_samples .. num_samples+tail_len-1]
     192         9081 :     for (int t = 0; t < tail_len; ++t) {
     193         6054 :         const int n = num_samples + t;
     194         6054 :         float y = 0.0f;
     195              : 
     196              :         // y += sum_{j=0}^{ir_len-1} x[n-j] * h[j], where (n-j) in [0, num_samples-1]
     197         6054 :         int j0 = std::max(0, n - (num_samples - 1));
     198         6054 :         int j1 = std::min(ir_len - 1, n);
     199        15144 :         for (int j = j0; j <= j1; ++j) {
     200         9090 :             int xi = n - j;
     201         9090 :             if (xi >= 0 && xi < num_samples) {
     202         9090 :                 y += direct_input_[static_cast<size_t>(xi)] * ir[static_cast<size_t>(j)];
     203         3030 :             }
     204         3030 :         }
     205         6054 :         direct_overlap_[static_cast<size_t>(t)] = y;
     206         2018 :     }
     207         1011 : }
     208              : 
     209         3036 : void ConvolutionEngine::process(float* buffer, int num_samples) {
     210         3036 :     if (!kernel_ || kernel_->ir_length() == 0) return;
     211              : 
     212         3036 :     int block_size = kernel_->block_size();
     213         3036 :     int fft_size = kernel_->fft_size();
     214         3036 :     int num_parts = kernel_->num_partitions();
     215              : 
     216              :     // If block size doesn't match, fall back to direct convolution
     217         3036 :     if (num_samples != block_size) {
     218            3 :         process_direct(buffer, num_samples);
     219            3 :         return;
     220              :     }
     221              : 
     222              :     // Also use direct convolution for very short IRs (1 partition, IR <= block_size)
     223         3033 :     if (num_parts == 1 && kernel_->ir_length() <= block_size) {
     224         3030 :         process_direct(buffer, num_samples);
     225         3030 :         return;
     226              :     }
     227              : 
     228              :     // --- Partitioned overlap-add convolution ---
     229              : 
     230              :     // 1. Prepare input: zero-pad to fft_size
     231            3 :     auto* input_cpx = reinterpret_cast<kiss_fft_cpx*>(input_cpx_.data());
     232          771 :     for (int i = 0; i < block_size; ++i) {
     233          768 :         input_cpx[static_cast<size_t>(i)].r = buffer[i];
     234          768 :         input_cpx[static_cast<size_t>(i)].i = 0.0f;
     235          256 :     }
     236          771 :     for (int i = block_size; i < fft_size; ++i) {
     237          768 :         input_cpx[static_cast<size_t>(i)].r = 0.0f;
     238          768 :         input_cpx[static_cast<size_t>(i)].i = 0.0f;
     239          256 :     }
     240              : 
     241              :     // 2. Forward FFT of input -> store in FDL at current index
     242            3 :     auto* fdl_data = reinterpret_cast<kiss_fft_cpx*>(fdl_[static_cast<size_t>(fdl_index_)].data());
     243            3 :     kiss_fft(fft_cfg_, input_cpx, fdl_data);
     244              : 
     245              :     // 3. Complex multiply-accumulate across all partitions
     246            3 :     auto* accum = reinterpret_cast<kiss_fft_cpx*>(accum_cpx_.data());
     247            3 :     std::memset(accum, 0, sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size));
     248              : 
     249           15 :     for (int k = 0; k < num_parts; ++k) {
     250              :         // FDL index for partition k (circular)
     251           12 :         int fdl_idx = (fdl_index_ - k + num_parts) % num_parts;
     252            8 :         const auto* fdl_block =
     253           12 :             reinterpret_cast<const kiss_fft_cpx*>(fdl_[static_cast<size_t>(fdl_idx)].data());
     254           12 :         const auto* ir_block = reinterpret_cast<const kiss_fft_cpx*>(kernel_->partition_freq(k));
     255              : 
     256           12 :         complex_multiply_accumulate(accum, fdl_block, ir_block, fft_size);
     257            4 :     }
     258              : 
     259              :     // 4. Inverse FFT
     260            3 :     auto* ifft_out = reinterpret_cast<kiss_fft_cpx*>(ifft_out_cpx_.data());
     261            3 :     kiss_fft(ifft_cfg_, accum, ifft_out);
     262              : 
     263              :     // kiss_fft inverse does NOT normalize -- divide by fft_size
     264            3 :     float norm = 1.0f / static_cast<float>(fft_size);
     265              : 
     266              :     // 5. Output = first block_size samples + overlap from previous block
     267          771 :     for (int i = 0; i < block_size; ++i) {
     268          768 :         buffer[i] = ifft_out[static_cast<size_t>(i)].r * norm + overlap_[static_cast<size_t>(i)];
     269          256 :     }
     270              : 
     271              :     // 6. Store new overlap (last block_size samples of IFFT result)
     272          771 :     for (int i = 0; i < block_size; ++i) {
     273          768 :         overlap_[static_cast<size_t>(i)] = ifft_out[static_cast<size_t>(block_size + i)].r * norm;
     274          256 :     }
     275              : 
     276              :     // 7. Advance FDL index
     277            3 :     fdl_index_ = (fdl_index_ + 1) % num_parts;
     278         1012 : }
     279              : 
     280              : }  // namespace Amplitron
        

Generated by: LCOV version 2.0-1