Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .clabot
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"contributors": [ "dabacon", "mjk", "splch", "guenp" ],
"contributors": [ "dabacon", "mjk", "splch", "antalszava" ],
"message": "We require contributors to sign our Contributor License Agreement, and we don't have one on file for your GitHub username. In order for us to review and merge your code, please contact @mjk or opensource@ionq.com to sign the CLA."
}
1 change: 0 additions & 1 deletion blqs/blqs/_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import gast


ANNOTATIONS = ["original_lineno"]


Expand Down
1 change: 0 additions & 1 deletion blqs/blqs/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from blqs import statement


if TYPE_CHECKING:
import blqs # coverage: ignore

Expand Down
1 change: 0 additions & 1 deletion blqs/blqs/block_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from blqs import _stack


if TYPE_CHECKING:
import blqs # coverage: ignore

Expand Down
137 changes: 103 additions & 34 deletions blqs/blqs/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,13 @@ def _build(func: Callable, build_config: Optional[BuildConfig] = None) -> Callab

This method is not intended to be called directly, use build or build_with_config above.
"""
# Build the rewrite once and cache it; it depends only on `func` and
# `build_config`. Done lazily so errors surface on call, not at import.
cache: dict = {}

@functools.wraps(func)
def wrapper(*args, **kwargs):
def _ensure_built():
if cache:
return
# Get source.
source_code = textwrap.dedent(inspect.getsource(func))

Expand Down Expand Up @@ -132,16 +136,27 @@ def wrapper(*args, **kwargs):
# Get the outer function, and call it, returning the inner function.
new_func = getattr(module, outer_fn_name)() # pylint: disable=not-callable
# Set this inner function up with the correct globals and closure.
final_func = types.FunctionType(
cache["final_func"] = types.FunctionType(
code=new_func.__code__, globals=func.__globals__, closure=func.__closure__
)
# Stash what the line map needs; build it lazily (only on error).
cache["transformed_gast"] = transformed_gast
cache["transformed_source_code"] = transformed_source_code
cache["filename"] = filename

@functools.wraps(func)
def wrapper(*args, **kwargs):
_ensure_built()
try:
return final_func(*args, **kwargs) # pylint: disable=not-callable
return cache["final_func"](*args, **kwargs) # pylint: disable=not-callable
except Exception as e:
# If there is an exception, chain the exception in such a way as to indicated
# the original file and line number is given.
line_map = _ast.construct_line_map(transformed_gast, transformed_source_code)
exceptions._raise_with_line_mapping(e, func, line_map, filename)
# Re-raise pointing at the original source. The line map is the same
# for every exception, so build it once.
if "line_map" not in cache:
cache["line_map"] = _ast.construct_line_map(
cache["transformed_gast"], cache["transformed_source_code"]
)
Comment on lines +155 to +158

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're caching the line map too - could it happen that we have two distinct exceptions for the same inputs (and so using the same line map through caching outputs invalid line nums for the 2nd exception)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no the line map only maps generated line numbers to original ones, so it's the same for every exception. i added a comment noting that :)

exceptions._raise_with_line_mapping(e, func, cache["line_map"], cache["filename"])

return wrapper

Expand Down Expand Up @@ -258,12 +273,15 @@ def visit_For(self, node):
if not self._build_config.support_for:
return node

# Evaluate the iterable once. The original re-ran `iter` in each slot,
# doubling side effects (e.g. on generators).
template = """
is_iterable = blqs.is_iterable(iter)
for_statement = blqs.For(iter) if is_iterable else None
loop_vars = blqs.loop_vars(iter) if is_iterable else None
iter_value = iter
is_iterable = blqs.is_iterable(iter_value)
for_statement = blqs.For(iter_value) if is_iterable else None
loop_vars = blqs.loop_vars(iter_value) if is_iterable else None
for target in ([loop_vars if len(loop_vars) > 1 else loop_vars[0]]
if is_iterable else iter):
if is_iterable else iter_value):
with for_statement.loop_block() if for_statement else contextlib.nullcontext():
loop_body
else:
Expand All @@ -272,6 +290,7 @@ def visit_For(self, node):
"""
new_nodes = _template.replace(
template,
iter_value=self._namer.new_name("iter_value"),
is_iterable=self._namer.new_name("is_iterable"),
for_statement=self._namer.new_name("for_statement"),
loop_vars=self._namer.new_name("loop_vars"),
Expand All @@ -287,22 +306,28 @@ def visit_While(self, node):
if not self._build_config.support_while:
return node

# Flag the loop's exit reason so the else block doesn't re-run `test`,
# which would double its side effects.
template = """
is_readable = blqs.is_readable(test)
while_statement = blqs.While(test) if is_readable else None
loop_exited_normally = False
while test or is_readable:
with while_statement.loop_block() if while_statement else contextlib.nullcontext():
loop_body
if is_readable:
break
if not test or is_readable:
else:
loop_exited_normally = True
if loop_exited_normally or is_readable:
with while_statement.else_block() if while_statement else contextlib.nullcontext():
else_body
"""
new_nodes = _template.replace(
template,
is_readable=self._namer.new_name("is_readable"),
while_statement=self._namer.new_name("while_statement"),
loop_exited_normally=self._namer.new_name("loop_exited_normally"),
test=node.test,
loop_body=node.body,
else_body=node.orelse if node.orelse else gast.Pass(),
Expand All @@ -314,6 +339,11 @@ def visit_Assign(self, node):
if not self._build_config.support_assign:
return node

target_names = self._target_names(node.targets)
if target_names is None:
# Non-Name target (e.g. `obj.x`): skip the rewrite, run natively.
return node

template = """
temp_value = value
readable_targets = blqs.readable_targets(temp_value)
Expand All @@ -325,53 +355,92 @@ def visit_Assign(self, node):
else:
targets = temp_value
"""
assign_names = self._target_names(node.targets)
new_nodes = _template.replace(
template,
temp_value=self._namer.new_name("temp_value"),
value=node.value,
targets=node.targets,
readable_targets=self._namer.new_name("readable_targets"),
assign_names=assign_names,
assign_names=target_names,
)
return new_nodes

def _target_names(self, targets):
"""Return a `gast.Tuple` of the target names, or None.

None means a target is not a `Name` (or a `Tuple`/`List` of `Name`s) —
e.g. `obj.x` or `arr[0]` — and callers should skip the rewrite.
"""
names = []
for target in targets:
if isinstance(target, gast.Name):
names.append(gast.Constant(target.id, None))
elif isinstance(target, gast.Tuple):
names.extend(gast.Constant(t.id, None) for t in target.elts)
elif isinstance(target, gast.List):
names.extend(gast.Constant(t.id, None) for t in target.elts)
elif isinstance(target, (gast.Tuple, gast.List)):
for t in target.elts:
if not isinstance(t, gast.Name):
return None
names.append(gast.Constant(t.id, None))
else:
raise ValueError("Invalid target type: this should not happen") # coverage: ignore
return None
return gast.Tuple(names, gast.Load())

def _flat_targets(self, targets):
"""Flatten Tuple/List targets to a single list of Name nodes.

Pre: `_target_names(targets)` returned non-None, so every leaf is a Name.
"""
flat = []
for target in targets:
if isinstance(target, gast.Name):
flat.append(target)
elif isinstance(target, (gast.Tuple, gast.List)):
flat.extend(target.elts)
return flat

def visit_Delete(self, node):
node = self.generic_visit(node)
if not self._build_config.support_delete:
return node

target_names = self._target_names(node.targets)
target_tuple = gast.Tuple(node.targets, gast.Load())
template = """
temp_value = target_tuple
standard_targets = tuple(val for val in temp_value if not blqs.is_deletable(val))
if len(standard_targets) > 0:
del standard_targets
deletable_names = tuple(name for val, name in zip(temp_value, target_names)
if target_names is None:
# Non-Name target (e.g. `del obj.x`): skip the rewrite, run natively.
return node

flat_targets = self._flat_targets(node.targets)
target_tuple = gast.Tuple(flat_targets, gast.Load())
target_values_name = self._namer.new_name("target_values")

# Capture values first: the `del`s below unbind names the filter needs.
new_nodes = list(
_template.replace(
"target_values = target_tuple",
target_values=target_values_name,
target_tuple=target_tuple,
)
)

# Natively `del` non-deletable values. The original deleted the helper
# tuple, so `del a` (a plain int) never unbound `a`.
for target in flat_targets:
check_template = """
if not blqs.is_deletable(target):
del target
"""
new_nodes.extend(_template.replace(check_template, target=target))

delete_template = """
deletable_names = tuple(name for val, name in zip(target_values, target_names)
if blqs.is_deletable(val))
if len(deletable_names) > 0:
blqs.Delete(deletable_names)
"""
new_nodes = _template.replace(
template,
temp_value=self._namer.new_name("temp_value"),
targets=node.targets,
standard_targets=self._namer.new_name("standard_targets"),
target_names=target_names,
target_tuple=target_tuple,
new_nodes.extend(
_template.replace(
delete_template,
target_values=target_values_name,
deletable_names=self._namer.new_name("deletable_names"),
target_names=target_names,
)
)
return new_nodes
75 changes: 75 additions & 0 deletions blqs/blqs/build_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,81 @@ def fn():
assert transformed_fn() == blqs.Program.of()


def test_build_delete_native_actually_unbinds():
def fn():
a = 1
del a
return a

transformed_fn = blqs.build(fn)
with pytest.raises((NameError, UnboundLocalError)):
transformed_fn()


def test_build_for_evaluates_iterable_once():
calls = []

def make_iter():
calls.append(1)
return [0]

def fn():
for _ in make_iter():
blqs.Op("H")(0)

blqs.build(fn)()
assert calls == [1]


def test_build_while_does_not_reevaluate_test_after_loop():
state = {"i": 0}
reads = []

def cond():
reads.append(state["i"])
return state["i"] < 2

def fn():
while cond():
blqs.Op("H")(0)
state["i"] += 1

blqs.build(fn)()
# `is_readable(cond())` checks once up front (i=0), then the loop tests at
# i=0,1,2. The fixed code does not evaluate `cond()` again after the loop; the
# old code re-ran it once more, appending a trailing 2.
assert reads == [0, 0, 1, 2]


def test_build_assign_to_attribute_is_native():
class C:
pass

c = C()

def fn():
c.x = 5

# Attribute targets aren't rewritten; the assignment runs as plain Python.
blqs.build(fn)()
assert c.x == 5


def test_build_delete_attribute_is_native():
class C:
pass

c = C()
c.x = 5

def fn():
del c.x

# Attribute targets aren't rewritten; the `del` runs as plain Python.
blqs.build(fn)()
assert not hasattr(c, "x")


def test_build_with_config_support_if():
def if_fn():
if blqs.Register("a"):
Expand Down
1 change: 0 additions & 1 deletion blqs/blqs/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from blqs import block, protocols, statement


if TYPE_CHECKING:
import blqs # coverage: ignore

Expand Down
5 changes: 2 additions & 3 deletions blqs/blqs/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from blqs import block, protocols, statement


if TYPE_CHECKING:
import blqs # coverage: ignore

Expand Down Expand Up @@ -84,8 +83,8 @@ def else_block(self) -> blqs.Block:
return self._else_block

def __str__(self):
loop_str = f"while {self._condition}:\n{self._loop_block}\n"
else_str = f"else:\n{self._else_block}"
loop_str = f"while {self._condition}:\n{self._loop_block}"
else_str = f"\nelse:\n{self._else_block}"
return loop_str + else_str if self._else_block else loop_str

def __eq__(self, other):
Expand Down
2 changes: 1 addition & 1 deletion blqs/blqs/loops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_while_str():
with loop.loop_block():
op = blqs.Op("MOV")
op(0, 1)
assert str(loop) == "while R(a):\n MOV 0, 1\n"
assert str(loop) == "while R(a):\n MOV 0, 1"
with loop.else_block():
op = blqs.Op("H")
op(0)
Expand Down
3 changes: 1 addition & 2 deletions blqs/blqs/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from blqs import instruction


if TYPE_CHECKING:
import blqs # coverage: ignore

Expand Down Expand Up @@ -46,7 +45,7 @@ def __call__(self, *targets) -> blqs.Instruction:
return instruction.Instruction(self, *targets)

def __eq__(self, other):
if not isinstance(self, type(other)):
if type(self) is not type(other):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Minor as unsure how much testing we'd like here): is there a test case for this?

@splch splch Jun 4, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea! i added a subclass test that isn't equal to a base Op of the same name

return NotImplemented
return self._name == other._name

Expand Down
Loading
Loading