Add GPU Support for ML Models #291

Open
wk9874 wants to merge 57 commits into dev from wk9874/models/gpu_support
wk9874 (Collaborator) commented Jun 15, 2026

Summary

Adds GPU support for ML models. For Train and Predict tasks, the user can now choose, via a switch in the UI, whether to reserve a GPU worker node for the task to run on. The choice is passed to the backend and into Ray, which allocates Actors to hardware accordingly. Note that how, and whether, the hardware is used is down to the model implementation.

Changes

  • Adds code to main.py that determines how many CPU and GPU nodes are available, and lets the user override these via env vars
  • ActorRegistry tracks GPU actors separately from CPU ones, and compares these to the number of GPU nodes available before comparing total numbers, removing stale GPU actors first
  • Adds use_gpu query params to train and predict endpoints, and passes these into worker tasks
  • If a GPU is requested, the Worker checks whether existing Actors have access to a GPU, stopping and restarting them on a GPU node if not
  • Adds fixes to the disruption model so that it handles device correctly
  • Adds a 'Use GPU' switch and contextual help to the train and predict modals via SchemaForm to set the use_gpu query params
  • Adds tests and refactors the models tests to make them more reliable
  • Adds documentation

Closes #264

Don't merge until #273 is merged

wk9874 changed the title from "Wk9874/models/gpu support" to "Add GPU Support for ML Models" Jun 16, 2026

samueljackson92 (Contributor) left a comment

Code Review

Bugs

1. ActorRegistry.update_actors() always stores self.gpu_enabled, not the actual GPU status of the actor (toktagger/api/models/base.py)

# new actor path:
self.actors[actor_name] = self.gpu_enabled  # always the server-wide flag

The actors dict is supposed to map actor_name → is_gpu_actor, but update_actors doesn't receive use_gpu. When a CPU-only prediction is dispatched on a GPU-enabled server, it gets stored as True, inflating gpu_count and causing legitimate GPU actors to be incorrectly evicted. The update_actors call sites in the router need to pass use_gpu through.
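
The suggested fix can be sketched as follows. This is a minimal illustration under assumptions, not the project's code: ActorRegistrySketch is a hypothetical, stripped-down stand-in for ActorRegistry, and the use_gpu parameter on update_actors is the proposed addition.

```python
class ActorRegistrySketch:
    """Hypothetical, simplified stand-in for ActorRegistry."""

    def __init__(self, gpu_enabled: bool) -> None:
        self.gpu_enabled = gpu_enabled  # server-wide capability flag
        self.actors: dict[str, bool] = {}  # actor_name -> is_gpu_actor

    def update_actors(self, actor_name: str, use_gpu: bool) -> None:
        # Record what this actor was actually launched with,
        # never more than the server can provide.
        self.actors[actor_name] = use_gpu and self.gpu_enabled

    @property
    def gpu_count(self) -> int:
        # True counts as 1, so this is the number of GPU actors.
        return sum(self.actors.values())


registry = ActorRegistrySketch(gpu_enabled=True)
registry.update_actors("cpu_predict", use_gpu=False)
registry.update_actors("gpu_train", use_gpu=True)
```

With the flag stored per actor, a CPU-only prediction on a GPU-enabled server no longer adds to gpu_count.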


2. create_sample_predictions ignores the use_gpu query param (toktagger/api/routers/models.py)

task = get_predictions.remote(
    ...
    use_gpu=task_registry.gpu_enabled,  # should be: use_gpu=use_gpu
)

The use_gpu query param is validated at the top of the function (raising 409 if not available), but then the worker is always launched with the server-wide gpu_enabled flag rather than the user's choice.


3. int(cluster_resources.get("CPU")) will raise TypeError if "CPU" is absent (toktagger/api/models/base.py)

cpus_available = int(cluster_resources.get("CPU")) or os.cpu_count()

int(None) raises TypeError. Should be int(cluster_resources.get("CPU", 0)).
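
A small sketch of that fix as a standalone helper. The function name is hypothetical, and the final fallback to 1 is an extra assumption added here because os.cpu_count() can itself return None.

```python
import os


def cpus_available(cluster_resources: dict[str, float]) -> int:
    """Return the CPU count Ray reports, else the local count, else 1."""
    # A default of 0 keeps int() away from None when "CPU" is missing.
    reported = int(cluster_resources.get("CPU", 0))
    # `or` falls through on 0 (key absent) and on None (count unknown).
    return reported or os.cpu_count() or 1
```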


4. max_gpu_actors calculation with 1 GPU silently disables GPU support (toktagger/api/models/base.py)

max_gpu_actors = int(cluster_resources.get("GPU", 0)) - 1

With 1 GPU: 1 - 1 = 0 → gpu_enabled = False. If this subtraction is intentional (reserving one GPU for something else), it needs a comment. As written it makes GPU support impossible on single-GPU systems.
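
One way to make the intent explicit is sketched below. Whether a GPU should be held back at all is not known from the PR, so reserved_gpus is a hypothetical parameter, defaulting to no reservation.

```python
def max_gpu_actors(cluster_resources: dict[str, float], reserved_gpus: int = 0) -> int:
    """GPUs left for model actors after setting aside `reserved_gpus`."""
    total = int(cluster_resources.get("GPU", 0))
    # Clamp at zero so an over-large reservation never goes negative.
    return max(total - reserved_gpus, 0)


single_gpu = max_gpu_actors({"GPU": 1.0})
single_gpu_reserved = max_gpu_actors({"GPU": 1.0}, reserved_gpus=1)
```

Here single_gpu is 1, so a single-GPU machine keeps GPU support unless a reservation is requested explicitly.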


5. Wrong model ID checked in test_update_model (tests/api/crud/test_utils.py)

await utils.update_model(db_client, model_id=setup_model_db["model_id_3"], updates=...)
# then checks:
model_updated = await db_client.get_document_by_id("models", ObjectId(setup_model_db["model_id_1"]))

The test updates model_id_3 but asserts the result on model_id_1. This is a pre-existing bug carried forward from the old setup_db fixture.


6. load_model return type annotation is wrong (toktagger/api/worker.py)

def load_model(...) -> tuple[str, str | None]:
    ...
    return {"project_id": ..., "model_id": ..., "message": ...}

Annotated as returning a tuple, actually returns a dict. The caller in the router uses result.get(...) so it works at runtime, but the annotation is misleading.


7. test_model_load_local_disabled is missing @pytest.mark.models_enabled (tests/api/routers/test_models.py)

Every other test in that file has this mark, but test_model_load_local_disabled does not. Without it, the check_models_status autouse fixture won't skip it when Ray is absent, and it will fail with an unexpected error.


Code Quality / Medium Issues

8. get_actor() uses exception-as-control-flow and a private Ray API (toktagger/api/worker.py)

ray.get(ml_model.__ray_terminate__.remote())
raise ValueError("Actor has no GPU, but GPU has been requested.")
# falls through to:
except ValueError:
    # "Actor not alive, so load from weights"
    ml_model = ray.remote(num_gpus=1 if use_gpu else 0)(model_type)...

Using __ray_terminate__ is fragile (private internal API). Raising ValueError to trigger actor recreation is confusing — the log message "Actor not alive, so load from weights" fires even though the actor was alive. A cleaner approach would be a flag variable or an explicit condition check.
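
The explicit-condition approach might look like the sketch below. It is deliberately Ray-free so it runs anywhere: the four callables are hypothetical hooks standing in for the actor lookup, the GPU check, the stop call, and actor creation. In real code the stop hook could wrap ray.kill(actor), the public call for terminating an actor, though that is a forced kill and may not suit every case.

```python
from typing import Callable, Optional, TypeVar

A = TypeVar("A")


def get_or_create_actor(
    find: Callable[[], Optional[A]],
    has_gpu: Callable[[A], bool],
    stop: Callable[[A], None],
    create: Callable[[bool], A],
    use_gpu: bool,
) -> tuple[A, str]:
    """Return (actor, reason); all four callables are hypothetical hooks."""
    actor = find()
    if actor is None:
        return create(use_gpu), "actor not alive, loaded from weights"
    if use_gpu and not has_gpu(actor):
        stop(actor)  # e.g. a wrapper around ray.kill(actor)
        return create(use_gpu), "actor had no GPU, restarted on GPU node"
    return actor, "reused existing actor"


# Usage with plain strings standing in for actor handles.
stopped: list[str] = []
actor, reason = get_or_create_actor(
    find=lambda: "cpu_actor",
    has_gpu=lambda a: a.startswith("gpu"),
    stop=stopped.append,
    create=lambda gpu: "gpu_actor" if gpu else "cpu_actor",
    use_gpu=True,
)
```

Each branch carries its own reason string, so the log message matches what actually happened.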


9. get_load_model_status() return type is -> bool but returns mixed types (toktagger/api/routers/models.py)

The function returns JSONResponse (202), True (200), or raises. FastAPI will serialize True as true in the 200 response, which is a valid but unusual API shape. The misleading annotation may cause confusion for future callers. The return type should be JSONResponse | bool.


10. VideoCNN is missing @ray.remote decorator (toktagger/api/models/temp.py)

All other model classes that serve as Ray actors are decorated with @ray.remote. VideoCNN lacks it and cannot be used as an actor. If this is intentionally incomplete (the filename is temp.py), it should at least have a TODO noting this.


Minor

11. Leftover print() in check_models_status fixture (tests/conftest.py)

def check_models_status(request):
    print()  # leftover debug call

12. Redundant ternary in gpu_available() (toktagger/api/models/base.py)

return True if assigned_resources.get("GPU") else False
# cleaner:
return bool(assigned_resources.get("GPU"))

Summary

| Severity | Count | Key items |
|----------|-------|-----------|
| Bug | 7 | GPU flag not tracked per-actor (#1), use_gpu param ignored in sample predict (#2), potential TypeError (#3), wrong model checked in test (#5), wrong return type (#6), missing test mark (#7) |
| Medium | 3 | Exception-as-control-flow + private Ray API (#8), mixed return types (#9), VideoCNN missing @ray.remote (#10) |
| Minor | 2 | Leftover print() (#11), redundant ternary (#12) |

The most critical issues are #1 and #2 — the GPU tracking data structure is inconsistent with the actual GPU status of each actor, and the sample prediction endpoint ignores the user's use_gpu choice.

samueljackson92 added the enhancement label Jun 18, 2026
wk9874 (Collaborator, Author) commented Jun 23, 2026

Finally passes the CI

@abdullah-ukaea & @praksharma reminder to give this a quick functionality test when you get a chance pls!

abdullah-ukaea (Collaborator) commented

> Finally passes the CI
>
> @abdullah-ukaea & @praksharma reminder to give this a quick functionality test when you get a chance pls!

LGTM! Tested on an NVIDIA GeForce RTX 4090. Model training and prediction work as expected.

wk9874 requested a review from samueljackson92 June 26, 2026 08:36
praksharma (Member) commented

Tested on M4 Max. A GPU was not discovered as expected.
[image]

But PyTorch detects the GPU/MPS for both training and prediction.

(YoloVideoDetectionModel pid=17952) {'model': 'yolo26n.pt', 'epochs': 1, 'batch': 5, 'imgsz': 640, 'workers': 0, 'device': 'mps', 'save': False, 'plots': False, 'val': False, 'close_mosaic': 0}
(YoloVideoDetectionModel pid=17952) Ultralytics 8.4.33 🚀 Python-3.12.2 torch-2.11.0 MPS (Apple M4 Max)

