Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,19 @@ def __post_init__(self):
self.stop_sequence = self.stop_sequence if self.stop_sequence is not None else ()
self.full_name = f"{self.name}|{self.num_fewshots}" # todo clefourrier: this is likely incorrect

@staticmethod
def _repr_metric_value(v):
if isinstance(v, functools.partial):
func_name = getattr(v.func, "__name__", str(v.func))
return f"partial({func_name}, ...)"
if isinstance(v, dict):
return repr({key: LightevalTaskConfig._repr_metric_value(val) for key, val in v.items()})
if isinstance(v, Callable):
return getattr(v, "__name__", repr(v))
if isinstance(v, Metric.get_allowed_types_for_metrics()):
return str(v)
return repr(v)

def __str__(self, lite: bool = False): # noqa: C901
md_writer = MarkdownTableWriter()
md_writer.headers = ["Key", "Value"]
Expand All @@ -174,16 +187,7 @@ def __str__(self, lite: bool = False): # noqa: C901
if k == "metrics":
for ix, metrics in enumerate(v):
for metric_k, metric_v in metrics.items():
if isinstance(metric_v, functools.partial):
func_name = getattr(metric_v.func, "__name__", str(metric_v.func))
repr_v = f"partial({func_name}, ...)"
elif isinstance(metric_v, Callable):
repr_v = getattr(metric_v, "__name__", repr(metric_v))
elif isinstance(metric_v, Metric.get_allowed_types_for_metrics()):
repr_v = str(metric_v)
else:
repr_v = repr(metric_v)
values.append([f"{k} {ix}: {metric_k}", repr_v])
values.append([f"{k} {ix}: {metric_k}", self._repr_metric_value(metric_v)])

else:
if isinstance(v, functools.partial):
Expand Down