Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lighteval/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def rec(nest: dict, prefix: str, into: dict):
) # Need this for markdown
into[prefix + k + sep + str(i)] = vv.tolist() if isinstance(vv, np.ndarray) else vv
elif isinstance(v, np.ndarray):
into[prefix + k + sep + str(i)] = v.tolist()
into[prefix + k] = v.tolist()
else:
v = clean_markdown(v)
into[prefix + k] = v
Expand Down
25 changes: 24 additions & 1 deletion tests/unit/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@

import unittest

from lighteval.utils.utils import remove_reasoning_tags
import numpy as np

from lighteval.utils.utils import flatten_dict, remove_reasoning_tags


class TestRemoveReasoningTags(unittest.TestCase):
Expand Down Expand Up @@ -61,3 +63,24 @@ def test_no_closing_tag(self):
tag_pairs = [("<think>", "</think>")]
result = remove_reasoning_tags(text, tag_pairs)
self.assertEqual(result, "<think> Reasoning section. Answer section")


class TestFlattenDict(unittest.TestCase):
def test_bare_ndarray_value(self):
# A standalone ndarray previously hit the list-loop variable `i`, which
# is unbound here -> UnboundLocalError.
result = flatten_dict({"a": np.array([1, 2, 3])})
self.assertEqual(result, {"a": [1, 2, 3]})

def test_ndarray_after_list_key(self):
# A preceding list key leaks a stale `i`, which produced a bogus indexed
# key (e.g. "arr/2") for the ndarray.
result = flatten_dict({"lst": [10, 20, 30], "arr": np.array([7, 8, 9])})
self.assertEqual(
result,
{"lst/0": 10, "lst/1": 20, "lst/2": 30, "arr": [7, 8, 9]},
)

def test_list_of_ndarrays_still_indexed(self):
result = flatten_dict({"m": [np.array([1, 2]), np.array([3, 4])]})
self.assertEqual(result, {"m/0": [1, 2], "m/1": [3, 4]})