Line data Source code
1 : #include "audio/engine/audio_graph.h"
2 : #include <algorithm>
3 : #include <unordered_map>
4 : #include <unordered_set>
5 :
6 : namespace Amplitron {
7 :
8 4650 : int AudioGraph::add_node(const std::string &name, NodeRoutingType type,
9 : std::shared_ptr<Effect> pedal, int num_inputs) {
10 4650 : DSPNode node;
11 4650 : node.id = next_id_++; // Uses your unified member counter
12 4650 : node.name = name;
13 4650 : node.routing_type = type;
14 4650 : node.pedal = pedal;
15 :
16 : // Dynamically configure pin structures using the same unified ID pool
17 4650 : if (type == NodeRoutingType::Mixer || type == NodeRoutingType::MergeSum) {
18 1267 : int inputs_to_create = std::clamp((num_inputs > 0) ? num_inputs : 2, 2, 8);
19 2874 : for (int i = 0; i < inputs_to_create; ++i) {
20 1920 : node.input_pin_ids.push_back(next_id_++);
21 640 : }
22 954 : node.input_gains.assign(inputs_to_create, 1.0f);
23 954 : node.output_pin_ids.push_back(next_id_++); // 1 Output Pin
24 4332 : } else if (type == NodeRoutingType::Splitter) {
25 954 : node.input_pin_ids.push_back(next_id_++); // 1 Input Pin
26 954 : node.output_pin_ids.push_back(next_id_++); // Output Pin Branch A
27 954 : node.output_pin_ids.push_back(next_id_++); // Output Pin Branch B
28 318 : } else {
29 2742 : node.input_pin_ids.push_back(next_id_++);
30 2742 : node.output_pin_ids.push_back(next_id_++);
31 : }
32 :
33 4650 : nodes_.push_back(node);
34 :
35 : // Auto-recompile topology order whenever a structural block changes
36 4650 : rebuild_topology();
37 :
38 4650 : return node.id;
39 4650 : }
40 :
41 3132 : int AudioGraph::add_link(int source_pin_id, int dest_pin_id) {
42 : // Prevent duplicate connections between the exact same pair of pins
43 6966 : for (const auto& existing_link : links_) {
44 3840 : if (existing_link.source_pin_id == source_pin_id && existing_link.dest_pin_id == dest_pin_id) {
45 6 : printf("add_link failed: duplicate link\n");
46 6 : return existing_link.id;
47 : }
48 : }
49 :
50 : // Enforce that each input pin can only have ONE incoming link
51 6954 : for (const auto& existing_link : links_) {
52 3834 : if (existing_link.dest_pin_id == dest_pin_id) {
53 6 : printf("add_link failed: Input pin %d already in use!\n", dest_pin_id);
54 6 : return -1; // Pin already in use!
55 : }
56 : }
57 :
58 : // Enforce that ALL pins can only have ONE outgoing link
59 3120 : int source_node_id = get_node_from_pin(source_pin_id);
60 3120 : if (source_node_id != -1) {
61 4160 : const DSPNode *src_node = find_node(source_node_id);
62 3120 : if (src_node) {
63 : // Count existing outgoing links from this specific pin
64 2080 : int out_count = 0;
65 6948 : for (const auto &existing_link : links_) {
66 3828 : if (existing_link.source_pin_id == source_pin_id) {
67 9 : printf("add_link failed: Output pin %d already has an outgoing connection!\n", source_pin_id);
68 9 : out_count++;
69 3 : }
70 : }
71 3120 : if (out_count >= 1) {
72 6 : return -1; // Each output pin can only have 1 outgoing connection!
73 : }
74 1037 : }
75 1037 : }
76 :
77 1037 : GraphLink link;
78 3111 : link.id = next_id_++; // Uses your unified member counter
79 3111 : link.source_pin_id = source_pin_id;
80 3111 : link.dest_pin_id = dest_pin_id;
81 :
82 3111 : links_.push_back(link);
83 :
84 : // Validate if the new patch wire forms an impossible audio loop feedback cycle
85 3111 : if (!rebuild_topology()) {
86 3 : printf("add_link failed: forms an impossible audio loop feedback cycle!\n");
87 : // If a feedback loop is detected, pop the dangerous link back off to keep the engine safe
88 3 : links_.pop_back();
89 3 : rebuild_topology();
90 3 : return -1;
91 : }
92 :
93 2072 : return link.id;
94 1044 : }
95 :
96 1338 : void AudioGraph::set_node_as_input(int node_id, bool is_input) {
97 1416 : for (auto &node : nodes_) {
98 1416 : if (node.id == node_id) {
99 1338 : node.is_graph_input = is_input;
100 1338 : rebuild_topology();
101 892 : break;
102 : }
103 : }
104 1338 : }
105 :
106 1353 : void AudioGraph::set_node_as_output(int node_id, bool is_output) {
107 4122 : for (auto &node : nodes_) {
108 4122 : if (node.id == node_id) {
109 1353 : node.is_graph_output = is_output;
110 1353 : rebuild_topology();
111 902 : break;
112 : }
113 : }
114 1353 : }
115 :
116 1308 : void AudioGraph::set_node_position(int node_id, float x, float y) {
117 4674 : for (auto &node : nodes_) {
118 4674 : if (node.id == node_id) {
119 1308 : node.x = x;
120 1308 : node.y = y;
121 1308 : break;
122 : }
123 : }
124 1308 : }
125 :
126 46239 : int AudioGraph::get_node_from_pin(int pin_id) const {
127 : // Search through all nodes to find which one owns the given Pin ID
128 136260 : for (const auto &node : nodes_) {
129 241200 : for (int p : node.input_pin_ids) {
130 138621 : if (p == pin_id)
131 33669 : return node.id;
132 : }
133 221829 : for (int p : node.output_pin_ids) {
134 131808 : if (p == pin_id)
135 12558 : return node.id;
136 : }
137 : }
138 8 : return -1; // Pin ID not found in any registered node
139 15413 : }
140 :
141 11772 : bool AudioGraph::rebuild_topology() {
142 : // Kahn's algorithm or DFS to topologically sort the nodes based on links.
143 : // Since your test suite cases are already passing, we can use a basic
144 : // Kahn's sort dependency tracker to map links to execution order.
145 :
146 11772 : sorted_node_ids_.clear();
147 :
148 : // 1. Forward Reachability BFS
149 11772 : std::unordered_set<int> forward_reachable;
150 11772 : std::vector<int> queue;
151 46218 : for (const auto &node : nodes_) {
152 34446 : if (node.is_graph_input) {
153 7971 : queue.push_back(node.id);
154 19453 : forward_reachable.insert(node.id);
155 2657 : }
156 : }
157 7848 : size_t head = 0;
158 30174 : while (head < queue.size()) {
159 18402 : int curr = queue[head++];
160 24536 : auto it = std::find_if(nodes_.begin(), nodes_.end(),
161 48962 : [&](const DSPNode &n) { return n.id == curr; });
162 18402 : if (it != nodes_.end()) {
163 42066 : for (int out_pin : it->output_pin_ids) {
164 82494 : for (const auto &link : links_) {
165 58830 : if (link.source_pin_id == out_pin) {
166 11067 : int dest = get_node_from_pin(link.dest_pin_id);
167 18445 : if (dest != -1 &&
168 14756 : forward_reachable.find(dest) == forward_reachable.end()) {
169 10431 : forward_reachable.insert(dest);
170 10431 : queue.push_back(dest);
171 3477 : }
172 3689 : }
173 : }
174 : }
175 6134 : }
176 : }
177 :
178 : // 2. Backward Reachability BFS
179 11772 : std::unordered_set<int> backward_reachable;
180 11772 : queue.clear();
181 46218 : for (const auto &node : nodes_) {
182 34446 : if (node.is_graph_output) {
183 5004 : queue.push_back(node.id);
184 16486 : backward_reachable.insert(node.id);
185 1668 : }
186 : }
187 7848 : head = 0;
188 22374 : while (head < queue.size()) {
189 10602 : int curr = queue[head++];
190 14136 : auto it = std::find_if(nodes_.begin(), nodes_.end(),
191 30252 : [&](const DSPNode &n) { return n.id == curr; });
192 10602 : if (it != nodes_.end()) {
193 25530 : for (int in_pin : it->input_pin_ids) {
194 48147 : for (const auto &link : links_) {
195 33219 : if (link.dest_pin_id == in_pin) {
196 6231 : int src = get_node_from_pin(link.source_pin_id);
197 10385 : if (src != -1 &&
198 8308 : backward_reachable.find(src) == backward_reachable.end()) {
199 5598 : backward_reachable.insert(src);
200 5598 : queue.push_back(src);
201 1866 : }
202 2077 : }
203 : }
204 : }
205 3534 : }
206 : }
207 :
208 : // 3. Update is_reachable for all nodes
209 46218 : for (auto &node : nodes_) {
210 58196 : node.is_reachable = (forward_reachable.count(node.id) > 0 &&
211 18402 : backward_reachable.count(node.id) > 0);
212 : }
213 :
214 11772 : std::unordered_map<int, int> in_degree;
215 :
216 : // Initialize in-degree count for all active nodes
217 46218 : for (const auto &node : nodes_) {
218 34446 : in_degree[node.id] = 0;
219 : }
220 :
221 : // Calculate how many incoming cables are hooked up to each node
222 23040 : for (const auto &link : links_) {
223 11268 : int dest_node = get_node_from_pin(link.dest_pin_id);
224 11268 : if (dest_node != -1) {
225 11268 : in_degree[dest_node]++;
226 3756 : }
227 : }
228 :
229 : // Gather all source nodes that have 0 dependencies
230 11772 : std::vector<int> process_queue;
231 46218 : for (const auto &node : nodes_) {
232 34446 : if (in_degree[node.id] == 0) {
233 23826 : process_queue.push_back(node.id);
234 7942 : }
235 : }
236 :
237 : // Topologically extract nodes from the dependency queue
238 7848 : head = 0;
239 46212 : while (head < process_queue.size()) {
240 34440 : int current_node_id = process_queue[head++];
241 34440 : sorted_node_ids_.push_back(current_node_id);
242 :
243 : // Decrement dependencies for downstream targets linked to this node
244 165072 : for (const auto &node : nodes_) {
245 130632 : if (node.id != current_node_id)
246 96192 : continue;
247 :
248 77442 : for (int out_pin : node.output_pin_ids) {
249 106032 : for (const auto &link : links_) {
250 63030 : if (link.source_pin_id == out_pin) {
251 11262 : int target_node = get_node_from_pin(link.dest_pin_id);
252 11262 : if (target_node != -1) {
253 11262 : in_degree[target_node]--;
254 11262 : if (in_degree[target_node] == 0) {
255 10614 : process_queue.push_back(target_node);
256 3538 : }
257 3754 : }
258 3754 : }
259 : }
260 : }
261 : }
262 : }
263 :
264 : // If the sorted list length doesn't match total nodes, an impossible feedback
265 : // loop exists!
266 11772 : if (sorted_node_ids_.size() != nodes_.size()) {
267 3 : return false; // Rejects connection modifications to protect engine
268 : // stability
269 : }
270 :
271 7846 : return true; // Topology built successfully!
272 11772 : }
273 :
274 336 : bool AudioGraph::remove_node(int node_id) {
275 112 : auto it =
276 448 : std::find_if(nodes_.begin(), nodes_.end(),
277 1117 : [node_id](const DSPNode &n) { return n.id == node_id; });
278 :
279 336 : if (it != nodes_.end()) {
280 : // 1. Destroy all cables attached to this node's Input Pins
281 660 : for (int pin : it->input_pin_ids) {
282 550 : links_.erase(std::remove_if(links_.begin(), links_.end(),
283 173 : [pin](const GraphLink &l) {
284 63 : return l.dest_pin_id == pin;
285 : }),
286 440 : links_.end());
287 : }
288 : // 2. Destroy all cables attached to this node's Output Pins
289 669 : for (int pin : it->output_pin_ids) {
290 565 : links_.erase(std::remove_if(links_.begin(), links_.end(),
291 134 : [pin](const GraphLink &l) {
292 21 : return l.source_pin_id == pin;
293 : }),
294 452 : links_.end());
295 : }
296 :
297 : // 3. Erase the node and recompile the audio thread topology
298 330 : nodes_.erase(it);
299 330 : rebuild_topology();
300 330 : return true;
301 : }
302 4 : return false;
303 112 : }
304 27 : bool AudioGraph::remove_link(int link_id) {
305 9 : auto it =
306 36 : std::remove_if(links_.begin(), links_.end(),
307 36 : [link_id](const GraphLink &l) { return l.id == link_id; });
308 27 : if (it != links_.end()) {
309 18 : links_.erase(it, links_.end());
310 18 : rebuild_topology();
311 18 : return true;
312 : }
313 6 : return false;
314 9 : }
315 :
316 3495 : const DSPNode *AudioGraph::find_node(int node_id) const {
317 6801 : for (const auto &node : nodes_) {
318 6792 : if (node.id == node_id)
319 2449 : return &node;
320 : }
321 6 : return nullptr;
322 1165 : }
323 :
324 6 : void AudioGraph::restore_node(const DSPNode& node) {
325 6 : nodes_.push_back(node);
326 6 : if (node.id >= next_id_) next_id_ = node.id + 1;
327 12 : for (int pin : node.input_pin_ids) {
328 6 : if (pin >= next_id_) next_id_ = pin + 1;
329 : }
330 18 : for (int pin : node.output_pin_ids) {
331 12 : if (pin >= next_id_) next_id_ = pin + 1;
332 : }
333 6 : rebuild_topology();
334 6 : }
335 :
336 9 : void AudioGraph::restore_link(const GraphLink& link) {
337 9 : int prev_next_id = next_id_;
338 9 : links_.push_back(link);
339 9 : if (link.id >= next_id_) next_id_ = link.id + 1;
340 9 : if (!rebuild_topology()) {
341 0 : links_.pop_back();
342 0 : next_id_ = prev_next_id;
343 0 : rebuild_topology();
344 0 : }
345 9 : }
346 24 : bool AudioGraph::add_input_pin(int node_id) {
347 27 : for (auto &node : nodes_) {
348 27 : if (node.id == node_id && node.routing_type == NodeRoutingType::Mixer) {
349 24 : if (node.input_pin_ids.size() < 8) {
350 21 : node.input_pin_ids.push_back(next_id_++);
351 21 : node.input_gains.push_back(1.0f);
352 : // Do not necessarily need to rebuild topology if we just added an unconnected pin
353 22 : return true;
354 : }
355 2 : return false;
356 : }
357 : }
358 0 : return false;
359 8 : }
360 :
361 24 : bool AudioGraph::remove_input_pin(int node_id, int pin_id) {
362 24 : for (auto &node : nodes_) {
363 24 : if (node.id == node_id && (node.routing_type == NodeRoutingType::Mixer || node.routing_type == NodeRoutingType::MergeSum)) {
364 24 : if (node.input_pin_ids.size() > 2) {
365 18 : if (node.input_gains.size() < node.input_pin_ids.size()) {
366 0 : node.input_gains.resize(node.input_pin_ids.size(), 1.0f);
367 0 : }
368 18 : int index_to_remove = -1;
369 18 : if (pin_id == -1) {
370 18 : index_to_remove = node.input_pin_ids.size() - 1;
371 6 : } else {
372 0 : for (size_t i = 0; i < node.input_pin_ids.size(); ++i) {
373 0 : if (node.input_pin_ids[i] == pin_id) {
374 0 : index_to_remove = i;
375 0 : break;
376 : }
377 0 : }
378 : }
379 18 : if (index_to_remove != -1) {
380 18 : int pin_to_remove = node.input_pin_ids[index_to_remove];
381 : // Prevent removal if the pin is linked
382 36 : for (const auto &link : links_) {
383 18 : if (link.dest_pin_id == pin_to_remove) {
384 2 : return false;
385 : }
386 : }
387 18 : node.input_pin_ids.erase(node.input_pin_ids.begin() + index_to_remove);
388 18 : node.input_gains.erase(node.input_gains.begin() + index_to_remove);
389 18 : return true;
390 : }
391 0 : }
392 6 : return false;
393 : }
394 : }
395 0 : return false;
396 8 : }
397 :
398 0 : void AudioGraph::restore_input_pin(int node_id, int pin_id, int index, float gain) {
399 0 : for (auto &node : nodes_) {
400 0 : if (node.id == node_id) {
401 0 : if (index >= 0) {
402 0 : size_t idx = static_cast<size_t>(index);
403 0 : if (idx <= node.input_pin_ids.size()) {
404 0 : node.input_pin_ids.insert(node.input_pin_ids.begin() + idx, pin_id);
405 0 : while (node.input_gains.size() < idx) {
406 0 : node.input_gains.push_back(1.0f);
407 : }
408 0 : node.input_gains.insert(node.input_gains.begin() + idx, gain);
409 0 : } else {
410 0 : node.input_pin_ids.push_back(pin_id);
411 0 : node.input_gains.push_back(gain);
412 : }
413 0 : } else {
414 0 : node.input_pin_ids.push_back(pin_id);
415 0 : node.input_gains.push_back(gain);
416 : }
417 0 : if (pin_id >= next_id_) next_id_ = pin_id + 1;
418 0 : break;
419 : }
420 : }
421 0 : }
422 :
423 9 : void AudioGraph::set_mixer_input_gain(int node_id, size_t pin_index, float gain) {
424 18 : for (auto &node : nodes_) {
425 18 : if (node.id == node_id && node.routing_type == NodeRoutingType::Mixer) {
426 9 : if (pin_index < node.input_gains.size()) {
427 9 : node.input_gains[pin_index] = std::clamp(gain, 0.0f, 2.0f);
428 3 : }
429 6 : break;
430 : }
431 : }
432 9 : }
433 :
434 0 : bool AudioGraph::add_output_pin(int node_id) {
435 0 : for (auto &node : nodes_) {
436 0 : if (node.id == node_id && node.routing_type == NodeRoutingType::Splitter) {
437 0 : if (node.output_pin_ids.size() < 8) {
438 0 : node.output_pin_ids.push_back(next_id_++);
439 0 : return true;
440 : }
441 0 : return false;
442 : }
443 : }
444 0 : return false;
445 0 : }
446 :
447 0 : bool AudioGraph::remove_output_pin(int node_id, int pin_id) {
448 0 : for (auto &node : nodes_) {
449 0 : if (node.id == node_id && node.routing_type == NodeRoutingType::Splitter) {
450 0 : if (node.output_pin_ids.size() > 2) {
451 0 : int index_to_remove = -1;
452 0 : if (pin_id == -1) {
453 0 : index_to_remove = node.output_pin_ids.size() - 1;
454 0 : } else {
455 0 : for (size_t i = 0; i < node.output_pin_ids.size(); ++i) {
456 0 : if (node.output_pin_ids[i] == pin_id) {
457 0 : index_to_remove = i;
458 0 : break;
459 : }
460 0 : }
461 : }
462 0 : if (index_to_remove != -1) {
463 0 : int pin_to_remove = node.output_pin_ids[index_to_remove];
464 0 : for (const auto &link : links_) {
465 0 : if (link.source_pin_id == pin_to_remove) {
466 0 : return false;
467 : }
468 : }
469 0 : node.output_pin_ids.erase(node.output_pin_ids.begin() + index_to_remove);
470 0 : return true;
471 : }
472 0 : }
473 0 : return false;
474 : }
475 : }
476 0 : return false;
477 0 : }
478 :
479 0 : void AudioGraph::restore_output_pin(int node_id, int pin_id, int index) {
480 0 : for (auto &node : nodes_) {
481 0 : if (node.id == node_id && node.routing_type == NodeRoutingType::Splitter) {
482 0 : if (index >= 0 && index <= node.output_pin_ids.size()) {
483 0 : node.output_pin_ids.insert(node.output_pin_ids.begin() + index, pin_id);
484 0 : } else {
485 0 : node.output_pin_ids.push_back(pin_id);
486 : }
487 0 : if (pin_id >= next_id_) next_id_ = pin_id + 1;
488 0 : break;
489 : }
490 : }
491 0 : }
492 :
493 : } // namespace Amplitron
|