Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions bytetracker/byte_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def re_activate(self, new_track, frame_id, new_id=False):
if new_id:
self.track_id = self.next_id()
self.score = new_track.score
self.cls = new_track.cls
# Only update class if new class is not 5 (unknown/ambiguous)
# This preserves the known class when object becomes temporarily ambiguous
if new_track.cls != 5:
self.cls = new_track.cls

def update(self, new_track, frame_id):
"""
Expand All @@ -118,7 +121,6 @@ def update(self, new_track, frame_id):
"""
self.frame_id = frame_id
self.tracklet_len += 1
# self.cls = cls

new_tlwh = new_track.tlwh
self.mean, self.covariance = self.kalman_filter.update(
Expand All @@ -128,6 +130,10 @@ def update(self, new_track, frame_id):
self.is_activated = True

self.score = new_track.score
# Only update class if new class is not 5 (unknown/ambiguous)
# This preserves the known class when object becomes temporarily ambiguous
if new_track.cls != 5:
self.cls = new_track.cls

@property
# @jit(nopython=True)
Expand Down Expand Up @@ -186,7 +192,7 @@ def __repr__(self):

class BYTETracker(object):
def __init__(
self, track_thresh=0.45, track_buffer=25, match_thresh=0.8, frame_rate=30
self, track_thresh=0.45, track_buffer=25, match_thresh=0.8, frame_rate=30, det_thresh=None, second_match_thresh=None, unconfirmed_match_thresh=None
):
self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
Expand All @@ -198,11 +204,23 @@ def __init__(
self.track_thresh = track_thresh
self.match_thresh = match_thresh
# self.det_thresh = track_thresh
self.det_thresh = track_thresh + 0.1
if det_thresh is None:
self.det_thresh = track_thresh + 0.1
else:
self.det_thresh = det_thresh
# self.det_thresh = track_thresh + 0.1

# Make second association and unconfirmed track thresholds configurable
# Default to 0.5 for backward compatibility, but allow override for sensitive mode
self.second_match_thresh = second_match_thresh if second_match_thresh is not None else 0.5
self.unconfirmed_match_thresh = unconfirmed_match_thresh if unconfirmed_match_thresh is not None else 0.7

self.buffer_size = int(frame_rate / 30.0 * track_buffer)
self.max_time_lost = self.buffer_size
self.kalman_filter = KalmanFilter()

print(f"[Tracker] Initialized with track_thresh={track_thresh}, track_buffer={track_buffer}, match_thresh={match_thresh}, frame_rate={frame_rate}, det_thresh={self.det_thresh}, second_match_thresh={self.second_match_thresh}, unconfirmed_match_thresh={self.unconfirmed_match_thresh}")

def update(self, dets, _ = None):
self.frame_id += 1
activated_starcks = []
Expand Down Expand Up @@ -231,7 +249,7 @@ def update(self, dets, _ = None):
confs = confs

remain_inds = confs > self.track_thresh
inds_low = confs > 0.1
inds_low = confs > self.det_thresh
inds_high = confs < self.track_thresh

inds_second = np.logical_and(inds_low, inds_high)
Expand All @@ -243,7 +261,7 @@ def update(self, dets, _ = None):
scores_second = confs[inds_second]

clss_keep = classes[remain_inds]
clss_second = classes[remain_inds]
clss_second = classes[inds_second]

if len(dets) > 0:
"""Detections"""
Expand Down Expand Up @@ -300,7 +318,7 @@ def update(self, dets, _ = None):
]
dists = matching.iou_distance(r_tracked_stracks, detections_second)
matches, u_track, u_detection_second = matching.linear_assignment(
dists, thresh=0.5
dists, thresh=self.second_match_thresh
)
for itracked, idet in matches:
track = r_tracked_stracks[itracked]
Expand All @@ -324,7 +342,7 @@ def update(self, dets, _ = None):
# if not self.args.mot20:
dists = matching.fuse_score(dists, detections)
matches, u_unconfirmed, u_detection = matching.linear_assignment(
dists, thresh=0.7
dists, thresh=self.unconfirmed_match_thresh
)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
Expand Down Expand Up @@ -366,11 +384,10 @@ def update(self, dets, _ = None):
outputs = []
for t in output_stracks:
output = []
tlwh = t.tlwh
# Use tlbr property which correctly converts tlwh to xyxy format
# tlbr = top-left-bottom-right = [x1, y1, x2, y2]
xyxy = t.tlbr
tid = t.track_id
tlwh = np.expand_dims(tlwh, axis=0)
xyxy = xywh2xyxy(tlwh)
xyxy = np.squeeze(xyxy, axis=0)
output.extend(xyxy)
output.append(tid)
output.append(t.cls)
Expand Down