Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- name: prek check
uses: j178/prek-action@v1
with:
extra-args: --all-files --skip ruff --skip ruff-format --skip ty --skip mypy
extra-args: --all-files --skip ruff --skip ruff-format

lint:
runs-on: ubuntu-latest
Expand All @@ -44,7 +44,7 @@ jobs:
--group docs \
--group test
uv pip freeze
- name: Lint with mypy and ruff
- name: Lint with ty and ruff
run: |
uv run make lint
- name: Build documentation
Expand Down
16 changes: 3 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,10 @@ repos:
language: system
files: "(.py$)|(.*.ipynb$)"

- id: mypy
name: mypy
language: python
entry: mypy --install-types --non-interactive
files: ^numpyro/

- repo: https://github.com/astral-sh/ty-pre-commit
rev: v0.0.49
hooks:
- id: ty
name: ty check
language: python
entry: ty
args: [check, --config-file, ty.toml]
pass_filenames: false
always_run: true
additional_dependencies: [ty]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
Expand Down
3 changes: 1 addition & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ lint: FORCE
ruff check .
ruff format . --check
python scripts/update_headers.py --check
mypy --install-types --non-interactive numpyro
ty check -vvv --config-file ty.toml --exit-zero
ty check

license: FORCE
python scripts/update_headers.py
Expand Down
5 changes: 4 additions & 1 deletion numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ def scan_enum(
trace as packed_trace,
)

if substitute_stack is None:
substitute_stack = []

# amount number of steps to unroll
history = min(history, length)
unroll_steps = min(2 * history - 1, length)
Expand Down Expand Up @@ -189,7 +192,7 @@ def body_fn(wrapped_carry, x, prefix=None):
# store shape of new_carry at a global variable
if len(carry_shapes) < (history + 1):
carry_shapes.append(
[jnp.shape(x) for x in jax.tree.flatten(new_carry)[0]]
[jnp.shape(x) for x in jax.tree.flatten(new_carry)[0]] # ty: ignore[invalid-argument-type]
)
# make new_carry have the same shape as carry
# FIXME: is this rigorous?
Expand Down
2 changes: 1 addition & 1 deletion numpyro/contrib/funsor/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def process_message(self, msg):
plate_to_scale[self.name] = (
self.size / self.subsample_size if self.subsample_size else 1
)
return OrigPlateMessenger.process_message(self, msg)
return OrigPlateMessenger.process_message(self, msg) # ty: ignore[invalid-argument-type]

def postprocess_message(self, msg):
if msg["type"] in ["to_funsor", "to_data"]:
Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/funsor/infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def plate_to_enum_plate():

"""
try:
numpyro.plate.__new__ = lambda cls, *args, **kwargs: enum_plate(*args, **kwargs)
numpyro.plate.__new__ = lambda cls, *args, **kwargs: enum_plate(*args, **kwargs) # ty: ignore[invalid-assignment]
yield
finally:
numpyro.plate.__new__ = lambda *args, **kwargs: object.__new__(numpyro.plate)
numpyro.plate.__new__ = lambda *args, **kwargs: object.__new__(numpyro.plate) # ty: ignore[invalid-assignment]


def _config_enumerate_fn(site, default):
Expand Down
11 changes: 9 additions & 2 deletions numpyro/contrib/hsgp/spectral_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ def diag_spectral_density_squared_exponential(

def _spectral_density(w):
return spectral_density_squared_exponential(
dim=dim, w=w, alpha=alpha, length=length
dim=dim,
w=w,
alpha=alpha,
length=length, # ty: ignore[invalid-argument-type]
)

sqrt_eigenvalues_ = sqrt_eigenvalues(ell=ell, m=m, dim=dim) # dim x m
Expand Down Expand Up @@ -329,7 +332,11 @@ def diag_spectral_density_rational_quadratic(

def _spectral_density(w):
return spectral_density_rational_quadratic(
dim=dim, w=w, alpha=alpha, length=length, scale_mixture=scale_mixture
dim=dim,
w=w,
alpha=alpha,
length=length, # ty: ignore[invalid-argument-type]
scale_mixture=scale_mixture,
)

sqrt_eigenvalues_ = sqrt_eigenvalues(ell=ell, m=m, dim=dim)
Expand Down
2 changes: 1 addition & 1 deletion numpyro/contrib/stochastic_support/dcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def _run_inference(
Run MCMC on the model conditioned on the given branching trace.
"""
slp_model = condition(self.model, data=branching_trace)
kernel = self.kernel_cls(slp_model) # type: ignore[call-arg]
kernel = self.kernel_cls(slp_model) # ty: ignore[too-many-positional-arguments]
mcmc = MCMC(kernel, **self.mcmc_kwargs)
mcmc.run(rng_key, *args, **kwargs)

Expand Down
4 changes: 3 additions & 1 deletion numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from collections import OrderedDict
from itertools import product
from typing import Union
from typing import Union, cast

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -257,6 +257,7 @@ def summary(
samples = {
"Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
}
samples = cast(dict[str, np.ndarray], samples)

summary_dict = {}
for name, value in samples.items():
Expand Down Expand Up @@ -311,6 +312,7 @@ def print_summary(
samples = {
"Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
}
samples = cast(dict[str, np.ndarray], samples)
summary_dict = summary(samples, prob, group_by_chain=True)
if not summary_dict:
return
Expand Down
34 changes: 21 additions & 13 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
]

import math
from typing import Generic, Optional, cast
from typing import ClassVar, Generic, Optional, cast

import numpy as np

Expand Down Expand Up @@ -93,6 +93,9 @@ def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)

def tree_flatten(self):
raise NotImplementedError

def __call__(self, x: NumLikeT) -> ArrayLike:
raise NotImplementedError

Expand Down Expand Up @@ -144,6 +147,8 @@ class _SingletonConstraint(ParameterFreeConstraint[NumLikeT]):
and unlike constraints.interval.
"""

_instance: ClassVar["_SingletonConstraint"]

def __new__(cls):
if (not hasattr(cls, "_instance")) or (type(cls._instance) is not cls):
# Do not use the singleton instance of a superclass of cls.
Expand Down Expand Up @@ -212,20 +217,23 @@ class _Dependent(Constraint[NumLike]):
"""

def __init__(
self, *, is_discrete: bool = NotImplemented, event_dim: int = NotImplemented
self,
*,
is_discrete: bool = NotImplemented, # ty: ignore[invalid-parameter-default]
event_dim: int = NotImplemented, # ty: ignore[invalid-parameter-default]
):
self._is_discrete = is_discrete
self._event_dim = event_dim
super().__init__()

@property
def is_discrete(self) -> bool: # type: ignore[override]
def is_discrete(self) -> bool:
if self._is_discrete is NotImplemented:
raise NotImplementedError(".is_discrete cannot be determined statically")
return self._is_discrete

@property
def event_dim(self) -> int: # type: ignore[override]
def event_dim(self) -> int:
if self._event_dim is NotImplemented:
raise NotImplementedError(".event_dim cannot be determined statically")
return self._event_dim
Expand All @@ -234,8 +242,8 @@ def __call__(
self,
x: Optional[NumLike] = None,
*,
is_discrete: bool = NotImplemented,
event_dim: int = NotImplemented,
is_discrete: bool = NotImplemented, # ty: ignore[invalid-parameter-default]
event_dim: int = NotImplemented, # ty: ignore[invalid-parameter-default]
):
if x is not None:
raise ValueError("Cannot determine validity of dependent constraint")
Expand Down Expand Up @@ -273,7 +281,7 @@ def __init__(
self._is_discrete = is_discrete
self._event_dim = event_dim

def __call__(self, x):
def __call__(self, x): # ty: ignore[invalid-method-override]
if not callable(x):
return super().__call__(x)

Expand Down Expand Up @@ -371,21 +379,21 @@ def __init__(
super().__init__()

@property
def is_discrete(self) -> bool: # type: ignore[override]
def is_discrete(self) -> bool:
return self.base_constraint.is_discrete

@property
def event_dim(self) -> int: # type: ignore[override]
def event_dim(self) -> int:
return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

def __call__(self, value: NumLikeT) -> ArrayLike:
result = self.base_constraint(value)
def __call__(self, x: NumLikeT) -> ArrayLike:
result = self.base_constraint(x)
if self.reinterpreted_batch_ndims == 0:
return result
elif jnp.ndim(result) < self.reinterpreted_batch_ndims:
expected = self.event_dim
raise ValueError(
f"Expected value.dim() >= {expected} but got {jnp.ndim(value)}"
f"Expected value.dim() >= {expected} but got {jnp.ndim(x)}"
)
result = jnp.reshape(
result,
Expand Down Expand Up @@ -837,7 +845,7 @@ def __init__(self, event_dim: int = 1) -> None:
super().__init__()

@property
def event_dim(self) -> int: # type: ignore[override]
def event_dim(self) -> int:
return self._event_dim

def __call__(self, x: NonScalarArray) -> ArrayLike:
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def infer_shapes(
if (
cls.support is not None
and hasattr(cls.support, "event_dim")
and cls.support.event_dim > 0
and cast(int, cls.support.event_dim) > 0
):
raise NotImplementedError

Expand Down
7 changes: 5 additions & 2 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)

def tree_flatten(self):
raise NotImplementedError

@property
def inv(self) -> "Transform":
inv = None
Expand Down Expand Up @@ -582,8 +585,8 @@ class CorrMatrixCholeskyTransform(CholeskyTransform):
correlation matrix.
"""

domain = constraints.corr_matrix # type: ignore[assignment]
codomain = constraints.corr_cholesky # type: ignore[assignment]
domain = constraints.corr_matrix
codomain = constraints.corr_cholesky

def log_abs_det_jacobian(
self,
Expand Down
9 changes: 6 additions & 3 deletions numpyro/examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,11 @@ def _download_with_retries(url: str, out_path: str) -> None:
type(detached_exc).__name__, delay
)
)
assert delay is not None
time.sleep(delay)
if isinstance(last_exc, HTTPError):
last_exc.close()
assert last_exc is not None
raise last_exc


Expand Down Expand Up @@ -231,6 +233,7 @@ def _download(dset: dset) -> None:
else:
if isinstance(last_exc, HTTPError):
last_exc.close()
assert last_exc is not None
raise last_exc


Expand Down Expand Up @@ -441,7 +444,7 @@ def _load_jsb_chorales() -> dict:
return processed_dataset


def _load_higgs(num_datapoints: int) -> dict:
def _load_higgs(num_datapoints: int | None) -> dict:
warnings.warn(
"Higgs is a 2.6 GB dataset",
stacklevel=find_stack_level(),
Expand Down Expand Up @@ -513,7 +516,7 @@ def _load_mortality() -> dict:
}


def _load(dset: dset, num_datapoints: int = -1) -> dict:
def _load(dset: dset, num_datapoints: int | None = -1) -> dict:
if dset == BASEBALL:
return _load_baseball()
elif dset == BOSTON_HOUSING:
Expand Down Expand Up @@ -586,7 +589,7 @@ def get_batch(i=0, idxs=idxs):
return tuple(
np.take(a, ret_idx, axis=0)
if isinstance(a, list)
else lax.index_take(a, (ret_idx,), axes=(0,))
else lax.index_take(a, (ret_idx,), axes=(0,)) # ty: ignore[invalid-argument-type]
for a in arrays
)

Expand Down
18 changes: 14 additions & 4 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def seeded_model(data):

from collections import OrderedDict
from types import TracebackType
from typing import Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Union, cast
import warnings

import numpy as np
Expand All @@ -116,6 +116,9 @@ def seeded_model(data):
)
from numpyro.util import find_stack_level, is_prng_key, not_jax_tracer

if TYPE_CHECKING:
from numpyro.infer.reparam import Reparam

__all__ = [
"block",
"collapse",
Expand Down Expand Up @@ -674,10 +677,14 @@ class reparam(Messenger):
def __init__(
self,
fn: Optional[Callable] = None,
config: Optional[Union[dict, Callable]] = None,
config: Optional[
Union[dict[str, "Reparam"], Callable[[Message], Optional["Reparam"]]]
] = None,
) -> None:
assert isinstance(config, dict) or callable(config)
self.config = config
self.config: Optional[
Union[dict[str, "Reparam"], Callable[[Message], Optional["Reparam"]]]
] = config
super().__init__(fn)

def process_message(self, msg: Message) -> None:
Expand All @@ -686,11 +693,14 @@ def process_message(self, msg: Message) -> None:

if isinstance(self.config, dict):
reparam = self.config.get(msg["name"])
else:
elif self.config is not None:
reparam = self.config(msg)
else:
return
if reparam is None:
return

reparam = cast("Reparam", reparam)
new_fn, value = reparam(msg["name"], msg["fn"], msg["value"])

if value is not None:
Expand Down
Loading
Loading