Open
Changes from 23 commits
57 commits
4a39381
Update actor registry to track GPU tasks
wk9874 Jun 1, 2026
fe7e19a
Designate actors to all be GPU if GPU is enabled
wk9874 Jun 2, 2026
68a55f0
Make use_gpus static based on if they are available
wk9874 Jun 2, 2026
91b5b0c
Incorporate some of abdullah's fixes, add gpu status to health endpoint
wk9874 Jun 3, 2026
b2a0fe5
Add abdullah's fixes to model
wk9874 Jun 3, 2026
d327d86
Add query param to train and predict for use_gpu
wk9874 Jun 3, 2026
0281622
Add use gpu toggle
wk9874 Jun 3, 2026
6ccd8bd
Fix query param
wk9874 Jun 4, 2026
2684900
Rebuild static
wk9874 Jun 4, 2026
4eed25c
Move detection to main
wk9874 Jun 11, 2026
e6acf2f
Improve handling of cpu and gpu actors
wk9874 Jun 11, 2026
1e3763c
Handle gpu killing better
wk9874 Jun 11, 2026
b3a582b
Add if
wk9874 Jun 11, 2026
1921b1d
Add try except
wk9874 Jun 11, 2026
cc25de0
Log messages
wk9874 Jun 11, 2026
c68b7f7
Allow overriding
wk9874 Jun 11, 2026
a3aa119
Check validated samples and anns before creating model
wk9874 Jun 12, 2026
b4b4e9f
Add use GPU toggle to form
wk9874 Jun 12, 2026
9ceb94e
Add contextual help and docs
wk9874 Jun 12, 2026
6e6091c
Merge branch 'wk9874/models/local_load_support' into wk9874/models/gp…
wk9874 Jun 15, 2026
de5f778
Fix tests
wk9874 Jun 15, 2026
172e451
Move use GPU to schema form, improve typing, add e2e test
wk9874 Jun 15, 2026
53b05bd
Resolve conflict
wk9874 Jun 15, 2026
3ff8ce8
change to use_gpu in update_actors
wk9874 Jun 18, 2026
a4a5592
change to use_gpu in update_actors
wk9874 Jun 18, 2026
c862a68
Fix update model test
wk9874 Jun 18, 2026
0f58352
Fix typing of load_model
wk9874 Jun 18, 2026
2407d33
Address PR comments
wk9874 Jun 18, 2026
7d8d83d
Merge branch 'dev' into wk9874/models/gpu_support
wk9874 Jun 18, 2026
57a8eb4
chore: update build output [skip ci]
Jun 18, 2026
22da441
Fix type
wk9874 Jun 18, 2026
ab0a1e2
Merge branch 'wk9874/models/gpu_support' of github.com:ukaea/toktagge…
wk9874 Jun 18, 2026
de3fd15
Fix max_actors calculation
wk9874 Jun 18, 2026
dc940a6
Reduce num GPU in tests to 1
wk9874 Jun 18, 2026
307bbca
Change Use GPU to Allocate
wk9874 Jun 19, 2026
82eaeb7
Change Use GPU to Allocate
wk9874 Jun 19, 2026
435e0dd
Rebuild
wk9874 Jun 19, 2026
f682b08
chore: update build output [skip ci]
Jun 19, 2026
0424abe
Fix mistake in conftest
wk9874 Jun 19, 2026
34b4ee2
Merge branch 'wk9874/models/gpu_support' of github.com:ukaea/toktagge…
wk9874 Jun 19, 2026
4d0f37b
Try setting env var
wk9874 Jun 22, 2026
1dad8c9
Rebuild
wk9874 Jun 22, 2026
7479b47
Merge remote-tracking branch 'origin/main' into wk9874/models/gpu_sup…
Jun 22, 2026
56e9280
chore: update build output [skip ci]
Jun 22, 2026
bcdfb66
Change env var to str
wk9874 Jun 23, 2026
e95cee8
Merge branch 'wk9874/models/gpu_support' of github.com:ukaea/toktagge…
wk9874 Jun 23, 2026
4f85974
Merge branch 'dev' into wk9874/models/gpu_support
wk9874 Jun 23, 2026
1d74b95
Increase timeout, rebuild
wk9874 Jun 23, 2026
4a6a1fc
chore: update build output [skip ci]
Jun 23, 2026
d6ab68a
Merge branch 'dev' into wk9874/models/gpu_support
wk9874 Jun 23, 2026
08c01af
Merge branch 'wk9874/models/gpu_support' of github.com:ukaea/toktagge…
wk9874 Jun 23, 2026
ae9637f
Fix gpu check in disruption cnn
wk9874 Jun 24, 2026
a79c2fb
Merge branch 'dev' into wk9874/models/gpu_support
wk9874 Jun 24, 2026
8058095
chore: update build output [skip ci]
Jun 24, 2026
3e3f596
Fix config default script for bools, update package.lock, rebuild
wk9874 Jun 25, 2026
e702171
Merge branch 'wk9874/models/gpu_support' of github.com:ukaea/toktagge…
wk9874 Jun 25, 2026
8787a17
chore: update build output [skip ci]
Jun 25, 2026
46 changes: 46 additions & 0 deletions docs/custom_models.md
@@ -227,6 +227,52 @@ self.log_progress(
- `"failed"` - Training encountered an error
- `"aborted"` - Training was manually stopped

## GPU Usage
TokTagger can reserve GPU nodes for ML model tasks if the user requests this via the UI. To use a GPU for training or prediction in your model, check whether a GPU is currently enabled on the model's worker node with `self.gpu_available()`:
```py
class MyModel(Model):
    def train(self, samples, annotations, params):
        # True when a GPU has been allocated to this model's worker node
        if self.gpu_available():
            ...
```

Write your model so that it is agnostic to the environment in which it runs. One way to do this is to include a `device` parameter in your training and prediction parameters, e.g. for a PyTorch model:
```py
import typing
import pydantic

class MyTrainParams(pydantic.BaseModel):
    device: typing.Literal["cpu", "cuda", "mps", "xpu"] = "cpu"
```
The exact format of the inputs required may depend on the ML framework you are using. You can then use these parameters to set up your model correctly for the given environment, for example:
```py
import torch

class MyModel(Model):
    def train(
        self,
        samples: list[Sample],
        annotations: list[list[Annotation]],
        params: MyTrainParams,
    ):
        # Reject a GPU request when no GPU is allocated to this worker node
        if not self.gpu_available() and params.device != "cpu":
            raise ValueError(
                "Only CPU available on current worker node, but non-CPU device requested!"
            )
        device = torch.device(params.device)
        self.model.to(device)
```
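
The device check above can also be factored into a small helper that is independent of any ML framework. The following is a minimal sketch, not part of TokTagger: `resolve_device` and its `strict` flag are hypothetical names, and `gpu_allocated` stands in for the value returned by `self.gpu_available()`. It does not special-case Apple silicon, where the GPU may be usable even when no GPU is reported.

```py
import typing

Device = typing.Literal["cpu", "cuda", "mps", "xpu"]


def resolve_device(requested: Device, gpu_allocated: bool, strict: bool = True) -> Device:
    """Return the device to run on (hypothetical helper, not a TokTagger API).

    `gpu_allocated` is assumed to be the result of `self.gpu_available()`.
    """
    if requested == "cpu":
        # A CPU is present on every worker node, so this request always succeeds.
        return "cpu"
    if gpu_allocated:
        return requested
    if strict:
        raise ValueError(
            f"Device {requested!r} requested, but no GPU is allocated to this worker node"
        )
    # Non-strict mode: silently fall back to the CPU instead of failing the task.
    return "cpu"
```

Raising by default keeps a misconfigured task from silently running on slower hardware; passing `strict=False` trades that safety for a task that always completes.
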
Note that a CPU is always available on every worker node, even if a GPU is requested. To reduce startup time, if a model already exists on a worker node with a GPU but a prediction is requested on a CPU node, the task will still be executed on the existing worker node, which has GPU access. If a task for a new model requires a GPU and no free GPUs are available, the least recently used model will be stopped and its worker node will be made available to the new model.
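
The eviction rule described above follows a least-recently-used policy. As an illustration only, the sketch below models that policy with a toy pool; `GpuActorPool` and `acquire` are hypothetical names, and this is not how TokTagger's actor registry is actually implemented.

```py
from collections import OrderedDict
from typing import Optional


class GpuActorPool:
    """Toy model of a fixed-size pool of GPU workers with LRU eviction."""

    def __init__(self, max_gpu_actors: int) -> None:
        if max_gpu_actors < 1:
            raise ValueError("max_gpu_actors must be at least 1")
        self.max_gpu_actors = max_gpu_actors
        # Keys are model names, ordered from least to most recently used.
        self._actors = OrderedDict()

    def acquire(self, model_name: str) -> Optional[str]:
        """Record a GPU task for `model_name`; return the evicted model, if any."""
        if model_name in self._actors:
            # The model already holds a GPU worker: just mark it as recently used.
            self._actors.move_to_end(model_name)
            return None
        evicted = None
        if len(self._actors) >= self.max_gpu_actors:
            # No free GPU: stop the least recently used model to make room.
            evicted, _ = self._actors.popitem(last=False)
        self._actors[model_name] = None
        return evicted
```

With a pool of two GPUs, acquiring `"a"`, `"b"`, `"a"` and then `"c"` evicts `"b"`, because `"a"` was used more recently.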

!!! note
    TokTagger uses Ray for task scheduling and management. Ray will detect NVIDIA and most AMD GPUs reliably and make them available if requested, but detection of Intel GPUs may be less reliable.

    For Macs with Apple silicon chips (such as the M1, M2, M3 series), the GPU is integrated and appears similarly to a CPU, and Ray will not detect it. Therefore it should be available to every worker node, regardless of the Use GPU setting, and regardless of what `self.gpu_available()` reports.

If Ray detects that a GPU is available on your machine, TokTagger will automatically start with GPU support enabled and will use all the GPUs that Ray is able to detect. To limit the number of GPUs which TokTagger uses in parallel, set the `MAX_GPU_ACTORS` environment variable. To disable GPUs, set `MAX_GPU_ACTORS` to 0.

!!! tip
    As mentioned above, Ray may fail to detect your GPU, especially on Mac devices. If you know how many GPU cores you can use in parallel on your machine, you can set `MAX_GPU_ACTORS` to that number. You should also set the environment variable `FORCE_NUM_GPUS` to `True`, as otherwise TokTagger will throw an error during startup when it tries to allocate more cores than it thinks are available.

    Note that this only affects task scheduling, and has no impact on the underlying resources available to each task. However, the `Use GPU` button in the UI will now be enabled, and `self.gpu_available()` will correctly report whether the user requested GPU usage for the given task, meaning that you can assign tasks to GPU or CPU hardware accordingly.
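
To make the interaction between the two environment variables concrete, here is a rough sketch of the startup check described above. It is an assumption about the behaviour, not TokTagger's actual code: `plan_gpu_actors` is a hypothetical function, `detected_gpus` stands in for the count reported by Ray, and the case-insensitive comparison against `"true"` is a guess at how `FORCE_NUM_GPUS` is parsed.

```py
import os
from typing import Mapping, Optional


def plan_gpu_actors(detected_gpus: int, env: Optional[Mapping[str, str]] = None) -> int:
    """Return how many GPU actors to run in parallel (hypothetical sketch)."""
    env = os.environ if env is None else env
    raw_max = env.get("MAX_GPU_ACTORS")
    if raw_max is None:
        # No override: use every GPU that was detected.
        return detected_gpus
    requested = int(raw_max)
    if requested < 0:
        raise ValueError("MAX_GPU_ACTORS must not be negative")
    # Assumed parsing: any capitalisation of "true" enables the override.
    force = env.get("FORCE_NUM_GPUS", "").strip().lower() == "true"
    if requested > detected_gpus and not force:
        raise RuntimeError(
            f"MAX_GPU_ACTORS={requested} exceeds the {detected_gpus} detected GPU(s); "
            "set FORCE_NUM_GPUS=True to override"
        )
    return requested
```

For example, on a Mac where no GPU is detected, `plan_gpu_actors(0, {"MAX_GPU_ACTORS": "4", "FORCE_NUM_GPUS": "True"})` returns 4, while omitting `FORCE_NUM_GPUS` raises an error.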


## Complete Example: Random Forest Classification Model

Here's a complete example using a scikit-learn RandomForest for time series classification:
8 changes: 5 additions & 3 deletions tests/api/routers/test_base.py
@@ -3,14 +3,15 @@

 @pytest.mark.asyncio
 @pytest.mark.models_enabled
-async def test_health_models_enabled(api_client, setup_db):
-    response = await api_client.get("/health")
+async def test_health_models_enabled(models_api_client, setup_db):
+    response = await models_api_client.get("/health")
     assert response.status_code == 200
     data = response.json()
     assert data["name"] == "TokTagger"
     assert data.get("version")  # Won't check its contents here
     assert data.get("db_connected")
     assert data.get("models_enabled")
+    assert data.get("gpu_available")  # Forced to be 2 GPUs in conftest setup


 @pytest.mark.asyncio
@@ -22,4 +23,5 @@ async def test_health_models_disabled(api_client, setup_db):
     assert data["name"] == "TokTagger"
     assert data.get("version")  # Won't check its contents here
     assert data.get("db_connected")
-    assert not data.get("models_enabled")
+    assert data.get("models_enabled") is False
+    assert data.get("gpu_available") is False