diff --git a/CHANGELOG.md b/CHANGELOG.md index 2635aff..86b217a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ Types of changes: ### Removed ### Fixed +- Fixed barrier unrolling to preserve multi-qubit barrier statements instead of splitting into individual per-qubit barriers. ([#295](https://github.com/qBraid/pyqasm/pull/295)) - Added support for physical qubit identifiers (`$0`, `$1`, …) in plain QASM 3 programs, including gates, barriers, measurements, and duplicate-qubit detection. ([#291](https://github.com/qBraid/pyqasm/pull/291)) - Updated CI to use `macos-15-intel` image due to deprecation of `macos-13` image. ([#283](https://github.com/qBraid/pyqasm/pull/283)) diff --git a/src/pyqasm/transformer.py b/src/pyqasm/transformer.py index 385c2b6..6f0fe2a 100644 --- a/src/pyqasm/transformer.py +++ b/src/pyqasm/transformer.py @@ -479,35 +479,37 @@ def _get_pyqasm_device_qubit_index( return _offsets[reg] + idx if isinstance(unrolled_stmts, QuantumBarrier): - _qubit_id = cast(Identifier, unrolled_stmts.qubits[0]) # type: ignore[union-attr] - if not isinstance(_qubit_id, IndexedIdentifier): - _start = _get_pyqasm_device_qubit_index( - _qubit_id.name, 0, qubit_register_offsets, global_qreg_size_map - ) - _end = _get_pyqasm_device_qubit_index( - _qubit_id.name, - global_qreg_size_map[_qubit_id.name] - 1, - qubit_register_offsets, - global_qreg_size_map, - ) - if _start == 0: - _qubit_id.name = f"__PYQASM_QUBITS__[:{_end+1}]" - elif _end == device_qubits - 1: - _qubit_id.name = f"__PYQASM_QUBITS__[{_start}:]" + for _qubit_id in unrolled_stmts.qubits: # type: ignore[union-attr] + _qubit_id = cast(Identifier, _qubit_id) + if not isinstance(_qubit_id, IndexedIdentifier): + _start = _get_pyqasm_device_qubit_index( + _qubit_id.name, 0, qubit_register_offsets, global_qreg_size_map + ) + _end = _get_pyqasm_device_qubit_index( + _qubit_id.name, + global_qreg_size_map[_qubit_id.name] - 1, + qubit_register_offsets, + global_qreg_size_map, + ) + if _start == 0: + _qubit_id.name = f"__PYQASM_QUBITS__[:{_end+1}]" + elif _end == device_qubits - 1: + _qubit_id.name = f"__PYQASM_QUBITS__[{_start}:]" + else: + _qubit_id.name = f"__PYQASM_QUBITS__[{_start}:{_end+1}]" else: - _qubit_id.name = f"__PYQASM_QUBITS__[{_start}:{_end+1}]" - else: - _qubit_str = cast(str, unrolled_stmts.qubits[0].name) # type: ignore[union-attr] - _qubit_ind = cast( - list, unrolled_stmts.qubits[0].indices - ) # type: ignore[union-attr] - for multi_ind in _qubit_ind: - for ind in multi_ind: - pyqasm_ind = _get_pyqasm_device_qubit_index( - _qubit_str.name, ind.value, qubit_register_offsets, global_qreg_size_map - ) - ind.value = pyqasm_ind - _qubit_str.name = "__PYQASM_QUBITS__" + _qubit_str = cast(str, _qubit_id.name) # type: ignore[union-attr] + _qubit_ind = cast(list, _qubit_id.indices) # type: ignore[union-attr] + for multi_ind in _qubit_ind: + for ind in multi_ind: + pyqasm_ind = _get_pyqasm_device_qubit_index( + _qubit_str.name, + ind.value, + qubit_register_offsets, + global_qreg_size_map, + ) + ind.value = pyqasm_ind + _qubit_str.name = "__PYQASM_QUBITS__" if isinstance(unrolled_stmts, list): # pylint: disable=too-many-nested-blocks if isinstance(unrolled_stmts[0], QuantumMeasurementStatement): diff --git a/src/pyqasm/visitor.py b/src/pyqasm/visitor.py index c18e49c..e4c5dec 100644 --- a/src/pyqasm/visitor.py +++ b/src/pyqasm/visitor.py @@ -761,6 +761,42 @@ def _visit_reset(self, statement: qasm3_ast.QuantumReset) -> list[qasm3_ast.Quan return unrolled_resets + def _expand_barrier_ranges( + self, + barrier: qasm3_ast.QuantumBarrier, + barrier_qubits: list[qasm3_ast.IndexedIdentifier | qasm3_ast.Identifier], + ) -> list: + """Replace RangeDefinition-containing qubits in a barrier with their + expanded IndexedIdentifier equivalents so that consolidate_qubit_registers + only sees IntegerLiteral indices.""" + consolidated_qubits: list = [] + expanded_idx = 0 + for op_qubit in barrier.qubits: + # Expand this single operand to find how many bits it produces. + temp_barrier = qasm3_ast.QuantumBarrier(qubits=[op_qubit]) + temp_barrier.span = barrier.span + op_expanded = self._get_op_bits(temp_barrier, qubits=True) + num_bits = len(op_expanded) + + if isinstance(op_qubit, qasm3_ast.IndexedIdentifier): + has_range = any( + isinstance(ind, qasm3_ast.RangeDefinition) + for dim in op_qubit.indices + for ind in dim # type: ignore[union-attr] + ) + if has_range: + consolidated_qubits.extend( + barrier_qubits[expanded_idx : expanded_idx + num_bits] + ) + else: + consolidated_qubits.append(op_qubit) + else: + # Bare Identifier — keep as-is for compact slice notation + consolidated_qubits.append(op_qubit) + + expanded_idx += num_bits + return consolidated_qubits + def _visit_barrier( # pylint: disable=too-many-locals, too-many-branches self, barrier: qasm3_ast.QuantumBarrier ) -> list[qasm3_ast.QuantumBarrier]: @@ -834,10 +870,14 @@ def _visit_barrier( # pylint: disable=too-many-locals, too-many-branches if not self._unroll_barriers: if self._consolidate_qubits: + consolidated_qubits = self._expand_barrier_ranges(barrier, barrier_qubits) + expanded = qasm3_ast.QuantumBarrier( + qubits=consolidated_qubits # type: ignore[arg-type] + ) barrier = cast( qasm3_ast.QuantumBarrier, Qasm3Transformer.consolidate_qubit_registers( - barrier, + expanded, self._qubit_register_offsets, self._global_qreg_size_map, self._module._device_qubits, @@ -845,18 +885,23 @@ def _visit_barrier( # pylint: disable=too-many-locals, too-many-branches ) return [barrier] + # Keep barrier as a single multi-qubit statement with expanded qubit + # references (e.g. q -> q[0], q[1], q[2]) instead of splitting into + # individual per-qubit barriers. + expanded_barrier = qasm3_ast.QuantumBarrier(qubits=barrier_qubits) # type: ignore[arg-type] + if self._consolidate_qubits: - unrolled_barriers = cast( - list[qasm3_ast.QuantumBarrier], + expanded_barrier = cast( + qasm3_ast.QuantumBarrier, Qasm3Transformer.consolidate_qubit_registers( - unrolled_barriers, + expanded_barrier, self._qubit_register_offsets, self._global_qreg_size_map, self._module._device_qubits, ), ) - return unrolled_barriers + return [expanded_barrier] def _get_op_parameters(self, operation: qasm3_ast.QuantumGate) -> list[float]: """Get the parameters for the operation. diff --git a/tests/qasm2/test_operations.py b/tests/qasm2/test_operations.py index 441b6bc..8c14ec9 100644 --- a/tests/qasm2/test_operations.py +++ b/tests/qasm2/test_operations.py @@ -51,8 +51,7 @@ def test_whitelisted_ops(): include 'qelib1.inc'; qreg q[2]; creg c[2]; - barrier q[0]; - barrier q[1]; + barrier q[0], q[1]; reset q[0]; reset q[1]; measure q[0] -> c[0]; diff --git a/tests/qasm3/test_barrier.py b/tests/qasm3/test_barrier.py index fa3456c..6798777 100644 --- a/tests/qasm3/test_barrier.py +++ b/tests/qasm3/test_barrier.py @@ -47,23 +47,9 @@ def test_barrier(): qubit[2] q1; qubit[3] q2; qubit[1] q3; - barrier q1[0]; - barrier q1[1]; - barrier q2[0]; - barrier q2[1]; - barrier q2[2]; - barrier q3[0]; - barrier q1[0]; - barrier q1[1]; - barrier q2[0]; - barrier q2[1]; - barrier q2[2]; - barrier q3[0]; - barrier q1[0]; - barrier q1[1]; - barrier q2[0]; - barrier q2[1]; - barrier q3[0]; + barrier q1[0], q1[1], q2[0], q2[1], q2[2], q3[0]; + barrier q1[0], q1[1], q2[0], q2[1], q2[2], q3[0]; + barrier q1[0], q1[1], q2[0], q2[1], q3[0]; """ module = loads(qasm_str) module.unroll() @@ -87,10 +73,7 @@ def my_function(qubit[4] a) { expected_qasm = """OPENQASM 3.0; include "stdgates.inc"; qubit[4] q; - barrier q[0]; - barrier q[1]; - barrier q[2]; - barrier q[3]; + barrier q[0], q[1], q[2], q[3]; """ module = loads(qasm_str) module.unroll() diff --git a/tests/qasm3/test_device_qubits.py b/tests/qasm3/test_device_qubits.py index 9261e00..d3c0eec 100644 --- a/tests/qasm3/test_device_qubits.py +++ b/tests/qasm3/test_device_qubits.py @@ -57,9 +57,7 @@ def test_barrier(): expected_qasm = """OPENQASM 3.0; qubit[5] __PYQASM_QUBITS__; include "stdgates.inc"; - barrier __PYQASM_QUBITS__[2]; - barrier __PYQASM_QUBITS__[3]; - barrier __PYQASM_QUBITS__[4]; + barrier __PYQASM_QUBITS__[2], __PYQASM_QUBITS__[3], __PYQASM_QUBITS__[4]; barrier __PYQASM_QUBITS__[1]; """ result = loads(qasm, device_qubits=5) @@ -91,6 +89,25 @@ def test_unrolled_barrier(): check_unrolled_qasm(dumps(result), expected_qasm) +def test_unrolled_barrier_with_range(): + qasm = """OPENQASM 3.0; + include "stdgates.inc"; + qubit[4] q; + qubit[2] q2; + barrier q[0:2]; + barrier q2[0:2]; + """ + expected_qasm = """OPENQASM 3.0; + qubit[6] __PYQASM_QUBITS__; + include "stdgates.inc"; + barrier __PYQASM_QUBITS__[0], __PYQASM_QUBITS__[1]; + barrier __PYQASM_QUBITS__[4], __PYQASM_QUBITS__[5]; + """ + result = loads(qasm, device_qubits=6) + result.unroll(unroll_barriers=False, consolidate_qubits=True) + check_unrolled_qasm(dumps(result), expected_qasm) + + def test_measurement(): qasm = """OPENQASM 3.0; include "stdgates.inc"; diff --git a/tests/visualization/test_mpl_draw.py b/tests/visualization/test_mpl_draw.py index a76e6d5..bd281f4 100644 --- a/tests/visualization/test_mpl_draw.py +++ b/tests/visualization/test_mpl_draw.py @@ -117,6 +117,32 @@ def test_draw_qasm2_simple(): _check_fig(circ, fig) +def test_draw_barriers(): + """Test drawing barriers with various qubit patterns.""" + qasm = """ + OPENQASM 3.0; + include "stdgates.inc"; + qubit[3] q; + qubit[2] r; + barrier q[0], q[2]; + barrier q[0:2]; + barrier r; + barrier r[0], q[1]; + """ + circ = loads(qasm) + fig = mpl_draw(circ) + _check_fig(circ, fig) + + from matplotlib.patches import Rectangle + + ax = fig.axes[0] + # Barriers are drawn as Rectangle patches (one per qubit line per barrier) + # and dashed vlines (added to collections). 4 barriers touching 2+2+2+2=8 lines total. + barrier_patches = [p for p in ax.patches if isinstance(p, Rectangle)] + assert len(barrier_patches) == 8 + assert len(ax.collections) > 0 + + @pytest.mark.mpl_image_compare(baseline_dir="images", filename="bell.png") def test_draw_bell(): """Test drawing a simple Bell state circuit."""