-
Notifications
You must be signed in to change notification settings - Fork 6
Fix several bugs in the build-time AST rewrite #57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
76f10fa
d61a9ef
7094325
755e683
d0bba0b
5eb330d
0b0f891
fa123e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,6 @@ | |
|
|
||
| import gast | ||
|
|
||
|
|
||
| ANNOTATIONS = ["original_lineno"] | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,6 @@ | |
|
|
||
| from blqs import statement | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| import blqs # coverage: ignore | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,6 @@ | |
|
|
||
| from blqs import _stack | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| import blqs # coverage: ignore | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -96,9 +96,15 @@ def _build(func: Callable, build_config: Optional[BuildConfig] = None) -> Callab | |||||
|
|
||||||
| This method is not intended to be called directly, use build or build_with_config above. | ||||||
| """ | ||||||
|
|
||||||
| @functools.wraps(func) | ||||||
| def wrapper(*args, **kwargs): | ||||||
| # The AST rewrite depends only on `func` and `build_config`, so build it once | ||||||
| # and cache it. Doing this lazily (not at decoration time) keeps rewrite errors | ||||||
| # at call time, which some callers rely on (e.g. a decorator stacked on top of | ||||||
| # `@blqs.build` raises `ValueError` on call, not at import). | ||||||
| cache: dict = {} | ||||||
|
|
||||||
| def _ensure_built(): | ||||||
| if cache: | ||||||
| return | ||||||
| # Get source. | ||||||
| source_code = textwrap.dedent(inspect.getsource(func)) | ||||||
|
|
||||||
|
|
@@ -132,16 +138,29 @@ def wrapper(*args, **kwargs): | |||||
| # Get the outer function, and call it, returning the inner function. | ||||||
| new_func = getattr(module, outer_fn_name)() # pylint: disable=not-callable | ||||||
| # Set this inner function up with the correct globals and closure. | ||||||
| final_func = types.FunctionType( | ||||||
| cache["final_func"] = types.FunctionType( | ||||||
| code=new_func.__code__, globals=func.__globals__, closure=func.__closure__ | ||||||
| ) | ||||||
| # Defer line-map construction to the exception handler: it's only needed on | ||||||
| # errors, and `construct_line_map` has a known assertion that can fire on | ||||||
| # otherwise well-formed rewrites (matching the original behavior). | ||||||
| cache["transformed_gast"] = transformed_gast | ||||||
| cache["transformed_source_code"] = transformed_source_code | ||||||
| cache["filename"] = filename | ||||||
|
|
||||||
| @functools.wraps(func) | ||||||
| def wrapper(*args, **kwargs): | ||||||
| _ensure_built() | ||||||
| try: | ||||||
| return final_func(*args, **kwargs) # pylint: disable=not-callable | ||||||
| return cache["final_func"](*args, **kwargs) # pylint: disable=not-callable | ||||||
| except Exception as e: | ||||||
| # If there is an exception, chain the exception in such a way as to indicated | ||||||
| # the original file and line number is given. | ||||||
| line_map = _ast.construct_line_map(transformed_gast, transformed_source_code) | ||||||
| exceptions._raise_with_line_mapping(e, func, line_map, filename) | ||||||
| if "line_map" not in cache: | ||||||
| cache["line_map"] = _ast.construct_line_map( | ||||||
| cache["transformed_gast"], cache["transformed_source_code"] | ||||||
| ) | ||||||
| exceptions._raise_with_line_mapping(e, func, cache["line_map"], cache["filename"]) | ||||||
|
|
||||||
| return wrapper | ||||||
|
|
||||||
|
|
@@ -258,12 +277,16 @@ def visit_For(self, node): | |||||
| if not self._build_config.support_for: | ||||||
| return node | ||||||
|
|
||||||
| # Bind the iterable to a local once. The original re-evaluated `iter` in each | ||||||
| # slot (is_iterable / For() / loop_vars() / the loop), double-running side | ||||||
| # effects on generators and the like. | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if I understand correctly, we're binding here to copy the object, perhaps good to mention that. That'd clarify a bit further the mention of double-running side effects.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. agreed i reworded to "Evaluate the iterable once." |
||||||
| template = """ | ||||||
| is_iterable = blqs.is_iterable(iter) | ||||||
| for_statement = blqs.For(iter) if is_iterable else None | ||||||
| loop_vars = blqs.loop_vars(iter) if is_iterable else None | ||||||
| iter_value = iter | ||||||
| is_iterable = blqs.is_iterable(iter_value) | ||||||
| for_statement = blqs.For(iter_value) if is_iterable else None | ||||||
| loop_vars = blqs.loop_vars(iter_value) if is_iterable else None | ||||||
| for target in ([loop_vars if len(loop_vars) > 1 else loop_vars[0]] | ||||||
| if is_iterable else iter): | ||||||
| if is_iterable else iter_value): | ||||||
| with for_statement.loop_block() if for_statement else contextlib.nullcontext(): | ||||||
| loop_body | ||||||
| else: | ||||||
|
|
@@ -272,6 +295,7 @@ def visit_For(self, node): | |||||
| """ | ||||||
| new_nodes = _template.replace( | ||||||
| template, | ||||||
| iter_value=self._namer.new_name("iter_value"), | ||||||
| is_iterable=self._namer.new_name("is_iterable"), | ||||||
| for_statement=self._namer.new_name("for_statement"), | ||||||
| loop_vars=self._namer.new_name("loop_vars"), | ||||||
|
|
@@ -287,22 +311,29 @@ def visit_While(self, node): | |||||
| if not self._build_config.support_while: | ||||||
| return node | ||||||
|
|
||||||
| # Track the exit reason with a flag instead of re-evaluating `test` after the | ||||||
| # loop. The original `if not test or is_readable` re-ran `test`, doubling side | ||||||
| # effects for non-trivial conditions. | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This part of the comment won't make sense to readers of the proposed code because
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i rephrased it so it no longer quotes the removed condition, and added a test that the while test isn't re-evaluated after the loop :) |
||||||
| template = """ | ||||||
| is_readable = blqs.is_readable(test) | ||||||
| while_statement = blqs.While(test) if is_readable else None | ||||||
| loop_exited_normally = False | ||||||
| while test or is_readable: | ||||||
| with while_statement.loop_block() if while_statement else contextlib.nullcontext(): | ||||||
| loop_body | ||||||
| if is_readable: | ||||||
| break | ||||||
| if not test or is_readable: | ||||||
| else: | ||||||
| loop_exited_normally = True | ||||||
| if loop_exited_normally or is_readable: | ||||||
| with while_statement.else_block() if while_statement else contextlib.nullcontext(): | ||||||
| else_body | ||||||
| """ | ||||||
| new_nodes = _template.replace( | ||||||
| template, | ||||||
| is_readable=self._namer.new_name("is_readable"), | ||||||
| while_statement=self._namer.new_name("while_statement"), | ||||||
| loop_exited_normally=self._namer.new_name("loop_exited_normally"), | ||||||
| test=node.test, | ||||||
| loop_body=node.body, | ||||||
| else_body=node.orelse if node.orelse else gast.Pass(), | ||||||
|
|
@@ -314,6 +345,11 @@ def visit_Assign(self, node): | |||||
| if not self._build_config.support_assign: | ||||||
| return node | ||||||
|
|
||||||
| target_names = self._target_names(node.targets) | ||||||
| if target_names is None: | ||||||
| # Non-Name targets (e.g. `obj.x = ...`); leave alone for native semantics. | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Minor:) What does leave alone for native semantics refer to? maybe good to clarify
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i clarified it to "skip the rewrite, run natively" (same for del) and added attribute assign + del tests |
||||||
| return node | ||||||
|
|
||||||
| template = """ | ||||||
| temp_value = value | ||||||
| readable_targets = blqs.readable_targets(temp_value) | ||||||
|
|
@@ -325,53 +361,95 @@ def visit_Assign(self, node): | |||||
| else: | ||||||
| targets = temp_value | ||||||
| """ | ||||||
| assign_names = self._target_names(node.targets) | ||||||
| new_nodes = _template.replace( | ||||||
| template, | ||||||
| temp_value=self._namer.new_name("temp_value"), | ||||||
| value=node.value, | ||||||
| targets=node.targets, | ||||||
| readable_targets=self._namer.new_name("readable_targets"), | ||||||
| assign_names=assign_names, | ||||||
| assign_names=target_names, | ||||||
| ) | ||||||
| return new_nodes | ||||||
|
|
||||||
| def _target_names(self, targets): | ||||||
| """Return a `gast.Tuple` of the target names, or None. | ||||||
|
|
||||||
| None means a target is not a `Name` (or a `Tuple`/`List` of `Name`s) — | ||||||
| e.g. `obj.x` or `arr[0]` — and callers should skip the rewrite. | ||||||
| """ | ||||||
| names = [] | ||||||
| for target in targets: | ||||||
| if isinstance(target, gast.Name): | ||||||
| names.append(gast.Constant(target.id, None)) | ||||||
| elif isinstance(target, gast.Tuple): | ||||||
| names.extend(gast.Constant(t.id, None) for t in target.elts) | ||||||
| elif isinstance(target, gast.List): | ||||||
| names.extend(gast.Constant(t.id, None) for t in target.elts) | ||||||
| elif isinstance(target, (gast.Tuple, gast.List)): | ||||||
| for t in target.elts: | ||||||
| if not isinstance(t, gast.Name): | ||||||
| return None | ||||||
| names.append(gast.Constant(t.id, None)) | ||||||
| else: | ||||||
| raise ValueError("Invalid target type: this should not happen") # coverage: ignore | ||||||
| return None | ||||||
| return gast.Tuple(names, gast.Load()) | ||||||
|
|
||||||
| def _flat_targets(self, targets): | ||||||
| """Flatten Tuple/List targets to a single list of Name nodes. | ||||||
|
|
||||||
| Pre: `_target_names(targets)` returned non-None, so every leaf is a Name. | ||||||
| """ | ||||||
| flat = [] | ||||||
| for target in targets: | ||||||
| if isinstance(target, gast.Name): | ||||||
| flat.append(target) | ||||||
| elif isinstance(target, (gast.Tuple, gast.List)): | ||||||
| flat.extend(target.elts) | ||||||
| return flat | ||||||
|
|
||||||
| def visit_Delete(self, node): | ||||||
| node = self.generic_visit(node) | ||||||
| if not self._build_config.support_delete: | ||||||
| return node | ||||||
|
|
||||||
| target_names = self._target_names(node.targets) | ||||||
| target_tuple = gast.Tuple(node.targets, gast.Load()) | ||||||
| template = """ | ||||||
| temp_value = target_tuple | ||||||
| standard_targets = tuple(val for val in temp_value if not blqs.is_deletable(val)) | ||||||
| if len(standard_targets) > 0: | ||||||
| del standard_targets | ||||||
| deletable_names = tuple(name for val, name in zip(temp_value, target_names) | ||||||
| if target_names is None: | ||||||
| # Non-Name targets (e.g. `del obj.x`); leave alone for native semantics. | ||||||
| return node | ||||||
|
|
||||||
| flat_targets = self._flat_targets(node.targets) | ||||||
| target_tuple = gast.Tuple(flat_targets, gast.Load()) | ||||||
| target_values_name = self._namer.new_name("target_values") | ||||||
|
|
||||||
| # Capture the targets' values first: the per-target `del` below may unbind | ||||||
| # some of them, and the is_deletable filter that builds `Delete(...)` needs | ||||||
| # the values afterward. | ||||||
| new_nodes = list( | ||||||
| _template.replace( | ||||||
| "target_values = target_tuple", | ||||||
| target_values=target_values_name, | ||||||
| target_tuple=target_tuple, | ||||||
| ) | ||||||
| ) | ||||||
|
|
||||||
| # Natively `del` each non-deletable value. The original emitted | ||||||
| # `del standard_targets`, deleting the helper tuple rather than the user's | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reworded along those lines, kept it past tense since it's describing the old behavior :) |
||||||
| # bindings — so `del a` (a plain int) never actually unbound `a`. | ||||||
| for target in flat_targets: | ||||||
| check_template = """ | ||||||
| if not blqs.is_deletable(target): | ||||||
| del target | ||||||
| """ | ||||||
| new_nodes.extend(_template.replace(check_template, target=target)) | ||||||
|
|
||||||
| delete_template = """ | ||||||
| deletable_names = tuple(name for val, name in zip(target_values, target_names) | ||||||
| if blqs.is_deletable(val)) | ||||||
| if len(deletable_names) > 0: | ||||||
| blqs.Delete(deletable_names) | ||||||
| """ | ||||||
| new_nodes = _template.replace( | ||||||
| template, | ||||||
| temp_value=self._namer.new_name("temp_value"), | ||||||
| targets=node.targets, | ||||||
| standard_targets=self._namer.new_name("standard_targets"), | ||||||
| target_names=target_names, | ||||||
| target_tuple=target_tuple, | ||||||
| new_nodes.extend( | ||||||
| _template.replace( | ||||||
| delete_template, | ||||||
| target_values=target_values_name, | ||||||
| deletable_names=self._namer.new_name("deletable_names"), | ||||||
| target_names=target_names, | ||||||
| ) | ||||||
| ) | ||||||
| return new_nodes | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,6 @@ | |
|
|
||
| from blqs import block, protocols, statement | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| import blqs # coverage: ignore | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,6 @@ | |
|
|
||
| from blqs import instruction | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| import blqs # coverage: ignore | ||
|
|
||
|
|
@@ -46,7 +45,7 @@ def __call__(self, *targets) -> blqs.Instruction: | |
| return instruction.Instruction(self, *targets) | ||
|
|
||
| def __eq__(self, other): | ||
| if not isinstance(self, type(other)): | ||
| if type(self) is not type(other): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Minor as unsure how much testing we'd like here): is there a test case for this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good idea! i added a subclass test that isn't equal to a base Op of the same name |
||
| return NotImplemented | ||
| return self._name == other._name | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,17 +95,24 @@ def _build(func: Callable, build_config: Optional[BuildConfig] = None) -> Callab | |
| """ | ||
| build_config = build_config or BuildConfig() | ||
|
|
||
| @functools.wraps(func) | ||
| def wrapper(*args, **kwargs): | ||
| import blqs_cirq as __blqs_cirq | ||
|
|
||
| blqs_build_config = build_config.blqs_build_config or blqs.BuildConfig() | ||
| blqs_build_config.additional_decorator_specs = [ | ||
| # Build the inner blqs config once. `dataclasses.replace` returns a fresh object | ||
| # rather than mutating the user's BuildConfig, and running `build_with_config` | ||
| # here (at decoration time) keeps the AST rewrite off the per-call path. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (minor:) 3 parts to this comment, perhaps good to split it and place each part right before each line of code it belongs to (e.g.,
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. definitely i split it into 2, 1 per line! thank you :) |
||
| import blqs_cirq as __blqs_cirq | ||
|
|
||
| user_blqs_config = build_config.blqs_build_config or blqs.BuildConfig() | ||
| blqs_build_config = dataclasses.replace( | ||
| user_blqs_config, | ||
| additional_decorator_specs=[ | ||
| blqs.DecoratorSpec(module=__blqs_cirq, method=build), | ||
| blqs.DecoratorSpec(module=__blqs_cirq, method=build_with_config), | ||
| *blqs_build_config.additional_decorator_specs, | ||
| ] | ||
| blqs_func = blqs.build_with_config(blqs_build_config)(func) | ||
| *user_blqs_config.additional_decorator_specs, | ||
| ], | ||
| ) | ||
| blqs_func = blqs.build_with_config(blqs_build_config)(func) | ||
|
|
||
| @functools.wraps(func) | ||
| def wrapper(*args, **kwargs): | ||
| program = blqs_func(*args, **kwargs) | ||
| return _build_circuit(program, build_config) if build_config.output_circuit else program | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we're caching the line map too - could it happen that we have two distinct exceptions for the same inputs (and so using the same line map through caching outputs invalid line nums for the 2nd exception)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no the line map only maps generated line numbers to original ones, so it's the same for every exception. i added a comment noting that :)