Line data Source code
1 : #include "audio/dsp/convolution_engine.h"
2 : #include "kiss_fft.h"
3 : #include <cstring>
4 : #include <cmath>
5 : #include <algorithm>
6 :
7 : namespace Amplitron {
8 :
9 : // Helper: complex multiply-accumulate (accum += a * b)
10 12 : static void complex_multiply_accumulate(kiss_fft_cpx* accum,
11 : const kiss_fft_cpx* a,
12 : const kiss_fft_cpx* b,
13 : int n) {
14 6156 : for (int i = 0; i < n; ++i) {
15 6144 : accum[i].r += a[i].r * b[i].r - a[i].i * b[i].i;
16 6144 : accum[i].i += a[i].r * b[i].i + a[i].i * b[i].r;
17 2048 : }
18 12 : }
19 :
20 : // =============================================================================
21 : // ConvolutionKernel
22 : // =============================================================================
23 :
24 104 : ConvolutionKernel::ConvolutionKernel(const std::vector<float>& ir_samples,
25 52 : int block_size)
26 52 : : block_size_(block_size)
27 52 : , fft_size_(block_size * 2)
28 78 : , ir_length_(static_cast<int>(ir_samples.size()))
29 78 : , ir_time_(ir_samples) {
30 :
31 78 : if (ir_samples.empty() || block_size <= 0) {
32 3 : num_partitions_ = 0;
33 3 : return;
34 : }
35 :
36 : // Number of partitions = ceil(ir_length / block_size)
37 75 : num_partitions_ = (ir_length_ + block_size_ - 1) / block_size_;
38 :
39 : // Allocate forward FFT config
40 75 : fft_cfg_ = kiss_fft_alloc(fft_size_, 0, nullptr, nullptr);
41 :
42 : // Pre-compute frequency-domain representation of each partition
43 100 : std::vector<kiss_fft_cpx> time_buf(static_cast<size_t>(fft_size_));
44 75 : std::vector<kiss_fft_cpx> freq_buf(static_cast<size_t>(fft_size_));
45 :
46 75 : partitions_freq_.resize(static_cast<size_t>(num_partitions_));
47 :
48 162 : for (int p = 0; p < num_partitions_; ++p) {
49 : // Zero the time buffer
50 87 : std::fill(time_buf.begin(), time_buf.end(), kiss_fft_cpx{0.0f, 0.0f});
51 :
52 : // Copy this partition's IR samples into the real part
53 87 : int offset = p * block_size_;
54 87 : int count = std::min(block_size_, ir_length_ - offset);
55 4830 : for (int i = 0; i < count; ++i) {
56 4743 : time_buf[static_cast<size_t>(i)].r = ir_samples[static_cast<size_t>(offset + i)];
57 4743 : time_buf[static_cast<size_t>(i)].i = 0.0f;
58 1581 : }
59 :
60 : // Forward FFT
61 87 : kiss_fft(fft_cfg_, time_buf.data(), freq_buf.data());
62 :
63 : // Store the frequency-domain partition
64 87 : size_t byte_size = sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size_);
65 87 : partitions_freq_[static_cast<size_t>(p)].resize(byte_size);
66 116 : std::memcpy(partitions_freq_[static_cast<size_t>(p)].data(),
67 87 : freq_buf.data(), byte_size);
68 29 : }
69 76 : }
70 :
71 104 : ConvolutionKernel::~ConvolutionKernel() {
72 78 : if (fft_cfg_) kiss_fft_free(fft_cfg_);
73 104 : }
74 :
75 21 : const void* ConvolutionKernel::partition_freq(int index) const {
76 17 : if (index < 0 || index >= num_partitions_) return nullptr;
77 15 : return partitions_freq_[static_cast<size_t>(index)].data();
78 7 : }
79 :
80 : // =============================================================================
81 : // ConvolutionEngine
82 : // =============================================================================
83 :
84 180 : ConvolutionEngine::ConvolutionEngine() = default;
85 :
86 180 : ConvolutionEngine::~ConvolutionEngine() {
87 135 : cleanup_fft();
88 180 : }
89 :
90 33 : void ConvolutionEngine::init_fft(int fft_size) {
91 33 : cleanup_fft();
92 33 : fft_cfg_ = kiss_fft_alloc(fft_size, 0, nullptr, nullptr); // forward
93 33 : ifft_cfg_ = kiss_fft_alloc(fft_size, 1, nullptr, nullptr); // inverse
94 33 : current_fft_size_ = fft_size;
95 33 : }
96 :
97 210 : void ConvolutionEngine::cleanup_fft() {
98 210 : if (fft_cfg_) { kiss_fft_free(fft_cfg_); fft_cfg_ = nullptr; }
99 210 : if (ifft_cfg_) { kiss_fft_free(ifft_cfg_); ifft_cfg_ = nullptr; }
100 210 : current_fft_size_ = 0;
101 210 : }
102 :
103 36 : void ConvolutionEngine::set_kernel(const ConvolutionKernel* kernel) {
104 36 : kernel_ = kernel;
105 36 : reset();
106 36 : }
107 :
108 78 : void ConvolutionEngine::reset() {
109 78 : if (!kernel_) {
110 42 : cleanup_fft();
111 42 : fdl_.clear();
112 42 : overlap_.clear();
113 42 : direct_input_.clear();
114 42 : direct_overlap_.clear();
115 42 : input_cpx_.clear();
116 42 : accum_cpx_.clear();
117 42 : ifft_out_cpx_.clear();
118 42 : fdl_index_ = 0;
119 42 : return;
120 : }
121 :
122 36 : int fft_size = kernel_->fft_size();
123 36 : int num_parts = kernel_->num_partitions();
124 :
125 : // Initialize FFT if size changed
126 36 : if (current_fft_size_ != fft_size) {
127 33 : init_fft(fft_size);
128 11 : }
129 :
130 : // Initialize frequency-domain delay line
131 36 : if (fft_size <= 0 || fft_size > 65536) return; // Sanity check
132 :
133 36 : size_t cpx_bytes = sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size);
134 36 : fdl_.resize(static_cast<size_t>(num_parts));
135 81 : for (auto& buf : fdl_) {
136 45 : buf.assign(cpx_bytes, 0);
137 : }
138 36 : fdl_index_ = 0;
139 :
140 : // Initialize overlap buffer
141 36 : overlap_.assign(static_cast<size_t>(kernel_->block_size()), 0.0f);
142 :
143 : // Initialize direct convolution overlap
144 36 : int ir_len = kernel_->ir_length();
145 36 : if (ir_len > 0) {
146 36 : direct_overlap_.assign(static_cast<size_t>(ir_len - 1), 0.0f);
147 12 : } else {
148 0 : direct_overlap_.clear();
149 : }
150 :
151 : // Scratch input copy for direct convolution fallback
152 36 : direct_input_.assign(static_cast<size_t>(kernel_->block_size()), 0.0f);
153 :
154 : // Initialize FFT workspace buffers (allocation-free during process())
155 36 : input_cpx_.assign(cpx_bytes, 0);
156 36 : accum_cpx_.assign(cpx_bytes, 0);
157 36 : ifft_out_cpx_.assign(cpx_bytes, 0);
158 26 : }
159 :
160 3033 : void ConvolutionEngine::process_direct(float* buffer, int num_samples) {
161 3033 : const auto& ir = kernel_->ir_time_domain();
162 3033 : int ir_len = static_cast<int>(ir.size());
163 3033 : if (ir_len == 0) return;
164 :
165 : // Output length = num_samples + ir_len - 1.
166 : // We output num_samples and carry over the tail (overlap-add) in direct_overlap_.
167 3033 : const int tail_len = ir_len - 1;
168 3033 : if (tail_len <= 0) return;
169 :
170 : // direct_overlap_ should be pre-sized in reset(); avoid allocations here.
171 3027 : if (static_cast<int>(direct_overlap_.size()) != tail_len) return;
172 :
173 : // Need original input; avoid per-call allocations by using direct_input_ scratch.
174 3027 : if (static_cast<int>(direct_input_.size()) < num_samples) return;
175 3027 : std::memcpy(direct_input_.data(), buffer, sizeof(float) * static_cast<size_t>(num_samples));
176 :
177 : // First num_samples samples: convolution + previous overlap
178 776883 : for (int n = 0; n < num_samples; ++n) {
179 773856 : float y = (n < tail_len) ? direct_overlap_[static_cast<size_t>(n)] : 0.0f;
180 :
181 : // y += sum_{j=0}^{ir_len-1} x[n-j] * h[j]
182 773856 : int j0 = std::max(0, n - (num_samples - 1));
183 773856 : int j1 = std::min(ir_len - 1, n);
184 3087390 : for (int j = j0; j <= j1; ++j) {
185 2313534 : y += direct_input_[static_cast<size_t>(n - j)] * ir[static_cast<size_t>(j)];
186 771178 : }
187 773856 : buffer[n] = y;
188 257952 : }
189 :
190 : // Tail samples for next block: y[num_samples .. num_samples+tail_len-1]
191 9081 : for (int t = 0; t < tail_len; ++t) {
192 6054 : const int n = num_samples + t;
193 6054 : float y = 0.0f;
194 :
195 : // y += sum_{j=0}^{ir_len-1} x[n-j] * h[j], where (n-j) in [0, num_samples-1]
196 6054 : int j0 = std::max(0, n - (num_samples - 1));
197 6054 : int j1 = std::min(ir_len - 1, n);
198 15144 : for (int j = j0; j <= j1; ++j) {
199 9090 : int xi = n - j;
200 9090 : if (xi >= 0 && xi < num_samples) {
201 9090 : y += direct_input_[static_cast<size_t>(xi)] * ir[static_cast<size_t>(j)];
202 3030 : }
203 3030 : }
204 6054 : direct_overlap_[static_cast<size_t>(t)] = y;
205 2018 : }
206 1011 : }
207 :
208 3036 : void ConvolutionEngine::process(float* buffer, int num_samples) {
209 3036 : if (!kernel_ || kernel_->ir_length() == 0) return;
210 :
211 3036 : int block_size = kernel_->block_size();
212 3036 : int fft_size = kernel_->fft_size();
213 3036 : int num_parts = kernel_->num_partitions();
214 :
215 : // If block size doesn't match, fall back to direct convolution
216 3036 : if (num_samples != block_size) {
217 3 : process_direct(buffer, num_samples);
218 3 : return;
219 : }
220 :
221 : // Also use direct convolution for very short IRs (1 partition, IR <= block_size)
222 3033 : if (num_parts == 1 && kernel_->ir_length() <= block_size) {
223 3030 : process_direct(buffer, num_samples);
224 3030 : return;
225 : }
226 :
227 : // --- Partitioned overlap-add convolution ---
228 :
229 : // 1. Prepare input: zero-pad to fft_size
230 3 : auto* input_cpx = reinterpret_cast<kiss_fft_cpx*>(input_cpx_.data());
231 771 : for (int i = 0; i < block_size; ++i) {
232 768 : input_cpx[static_cast<size_t>(i)].r = buffer[i];
233 768 : input_cpx[static_cast<size_t>(i)].i = 0.0f;
234 256 : }
235 771 : for (int i = block_size; i < fft_size; ++i) {
236 768 : input_cpx[static_cast<size_t>(i)].r = 0.0f;
237 768 : input_cpx[static_cast<size_t>(i)].i = 0.0f;
238 256 : }
239 :
240 : // 2. Forward FFT of input -> store in FDL at current index
241 2 : auto* fdl_data = reinterpret_cast<kiss_fft_cpx*>(
242 3 : fdl_[static_cast<size_t>(fdl_index_)].data());
243 3 : kiss_fft(fft_cfg_, input_cpx, fdl_data);
244 :
245 : // 3. Complex multiply-accumulate across all partitions
246 3 : auto* accum = reinterpret_cast<kiss_fft_cpx*>(accum_cpx_.data());
247 3 : std::memset(accum, 0, sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size));
248 :
249 15 : for (int k = 0; k < num_parts; ++k) {
250 : // FDL index for partition k (circular)
251 12 : int fdl_idx = (fdl_index_ - k + num_parts) % num_parts;
252 8 : const auto* fdl_block = reinterpret_cast<const kiss_fft_cpx*>(
253 12 : fdl_[static_cast<size_t>(fdl_idx)].data());
254 8 : const auto* ir_block = reinterpret_cast<const kiss_fft_cpx*>(
255 12 : kernel_->partition_freq(k));
256 :
257 12 : complex_multiply_accumulate(accum, fdl_block, ir_block, fft_size);
258 4 : }
259 :
260 : // 4. Inverse FFT
261 3 : auto* ifft_out = reinterpret_cast<kiss_fft_cpx*>(ifft_out_cpx_.data());
262 3 : kiss_fft(ifft_cfg_, accum, ifft_out);
263 :
264 : // kiss_fft inverse does NOT normalize -- divide by fft_size
265 3 : float norm = 1.0f / static_cast<float>(fft_size);
266 :
267 : // 5. Output = first block_size samples + overlap from previous block
268 771 : for (int i = 0; i < block_size; ++i) {
269 1024 : buffer[i] = ifft_out[static_cast<size_t>(i)].r * norm +
270 768 : overlap_[static_cast<size_t>(i)];
271 256 : }
272 :
273 : // 6. Store new overlap (last block_size samples of IFFT result)
274 771 : for (int i = 0; i < block_size; ++i) {
275 768 : overlap_[static_cast<size_t>(i)] =
276 768 : ifft_out[static_cast<size_t>(block_size + i)].r * norm;
277 256 : }
278 :
279 : // 7. Advance FDL index
280 3 : fdl_index_ = (fdl_index_ + 1) % num_parts;
281 1012 : }
282 :
283 : } // namespace Amplitron
|