diff --git a/.gitignore b/.gitignore index ea6f4be..64d75d7 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,6 @@ research_dir/* state_saves/* __pycache__/* Figure*.png -testrun.py \ No newline at end of file +testrun.py +data/* +projects/* diff --git a/ai_lab_repo.py b/ai_lab_repo.py index dbe9541..5f1e5e0 100755 --- a/ai_lab_repo.py +++ b/ai_lab_repo.py @@ -355,7 +355,7 @@ def data_preparation(self): if self.verbose: print("#"*40, f"\nThe following is dialogue produced by the SW Engineer: {dialogue}", "\n", "#"*40) if "```SUBMIT_CODE" in resp: final_code = extract_prompt(resp, "SUBMIT_CODE") - code_resp = execute_code(final_code, timeout=60) + code_resp = execute_code(final_code) if self.verbose: print("!"*100, "\n", f"CODE RESPONSE: {code_resp}") swe_feedback += f"\nCode Response: {code_resp}\n" if "[CODE EXECUTION ERROR]" in code_resp: @@ -389,7 +389,7 @@ def data_preparation(self): if "```python" in resp: code = extract_prompt(resp, "python") code = self.ml_engineer.dataset_code + "\n" + code - code_resp = execute_code(code, timeout=120) + code_resp = execute_code(code) ml_command = f"Code produced by the ML agent:\n{code}" ml_feedback += f"\nCode Response: {code_resp}\n" if self.verbose: print("!"*100, "\n", f"CODE RESPONSE: {code_resp}") @@ -446,10 +446,13 @@ def literature_review(self): @return: (bool) whether to repeat the phase """ arx_eng = ArxivSearch() - max_tries = self.max_steps * 5 # lit review often requires extra steps + max_tries = self.max_steps * 5 # lit review often requires extra steps + # get initial response from PhD agent resp = self.phd.inference(self.research_topic, "literature review", step=0, temp=0.8) - if self.verbose: print(resp, "\n~~~~~~~~~~~") + if self.verbose: + print(resp, "\n~~~~~~~~~~~") + # iterate until max num tries to complete task is exhausted for _i in range(max_tries): feedback = str() @@ -463,14 +466,30 @@ def literature_review(self): # grab full text from arxiv ID elif "```FULL_TEXT" in resp: query = extract_prompt(resp, "FULL_TEXT") - # expiration timer so that paper does not remain in context too long - arxiv_paper = f"```EXPIRATION {self.arxiv_paper_exp_time}\n" + arx_eng.retrieve_full_paper_text(query) + "```" + try: + # expiration timer so that paper does not remain in context too long + full_text_content = arx_eng.retrieve_full_paper_text(query) + except Exception as e: + # Catch any unexpected errors from arxiv.Client() + err_msg = f"[ERROR] Could not retrieve paper. Possibly invalid arXiv ID. Error: {e}" + full_text_content = err_msg + + # In case retrieve_full_paper_text returns an error string + # or if we want to unify it, e.g. "[ERROR] ...something" + if full_text_content.startswith("[ERROR]"): + # We won't crash; just pass that as feedback + arxiv_paper = f"```EXPIRATION {self.arxiv_paper_exp_time}\n{full_text_content}```" + else: + # normal successful retrieval + arxiv_paper = f"```EXPIRATION {self.arxiv_paper_exp_time}\n{full_text_content}```" + feedback = arxiv_paper # if add paper, extract and add to lit review, provide feedback elif "```ADD_PAPER" in resp: query = extract_prompt(resp, "ADD_PAPER") feedback, text = self.phd.add_review(query, arx_eng) + # If we want to store reference text for later usage if len(self.reference_papers) < self.num_ref_papers: self.reference_papers.append(text) @@ -478,6 +497,7 @@ def literature_review(self): if len(self.phd.lit_review) >= self.num_papers_lit_review: # generate formal review lit_review_sum = self.phd.format_review() + # if human in loop -> check if human is happy with the produced review if self.human_in_loop_flag["literature review"]: retry = self.human_in_loop("literature review", lit_review_sum) @@ -485,18 +505,32 @@ def literature_review(self): if retry: self.phd.lit_review = [] return retry + # otherwise, return lit review and move on to next stage - if self.verbose: print(self.phd.lit_review_sum) + if self.verbose: + print(self.phd.lit_review_sum) # set agent self.set_agent_attr("lit_review_sum", lit_review_sum) # reset agent state self.reset_agents() self.statistics_per_phase["literature review"]["steps"] = _i return False - resp = self.phd.inference(self.research_topic, "literature review", feedback=feedback, step=_i + 1, temp=0.8) - if self.verbose: print(resp, "\n~~~~~~~~~~~") + + # Move on to the next iteration with new feedback + resp = self.phd.inference( + self.research_topic, + "literature review", + feedback=feedback, + step=_i + 1, + temp=0.8 + ) + if self.verbose: + print(resp, "\n~~~~~~~~~~~") + + # If we exceed max_tries: raise Exception("Max tries during phase: Literature Review") + def human_in_loop(self, phase, phase_prod): """ Get human feedback for phase output @@ -611,6 +645,7 @@ def parse_arguments(): help='Total number of paper-solver steps' ) + parser.add_argument('--file-path', type=str, default=None) return parser.parse_args() @@ -622,6 +657,8 @@ def parse_arguments(): human_mode = args.copilot_mode.lower() == "true" compile_pdf = args.compile_latex.lower() == "true" load_existing = args.load_existing.lower() == "true" + file_path = args.file_path + try: num_papers_lit_review = int(args.num_papers_lit_review.lower()) except Exception: @@ -654,6 +691,13 @@ def parse_arguments(): else: research_topic = args.research_topic + if file_path and "{FILE}" in research_topic: + with open(file_path, 'r', encoding='utf-8') as f: + file_content = f.read() + # Replace the placeholder with the entire file text + research_topic = research_topic.replace("{FILE}", file_content) + + task_notes_LLM = [ {"phases": ["plan formulation"], "note": f"You should come up with a plan for TWO experiments."}, diff --git a/inference.py b/inference.py index 74b6d0f..a98ce7a 100755 --- a/inference.py +++ b/inference.py @@ -2,6 +2,7 @@ from openai import OpenAI import openai import os, anthropic, json +from utils import clip_tokens TOKENS_IN = dict() TOKENS_OUT = dict() @@ -29,7 +30,7 @@ def curr_cost_est(): } return sum([costmap_in[_]*TOKENS_IN[_] for _ in TOKENS_IN]) + sum([costmap_out[_]*TOKENS_OUT[_] for _ in TOKENS_OUT]) -def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic_api_key=None, tries=5, timeout=5.0, temp=None, print_cost=True, version="1.5"): +def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic_api_key=None, tries=5, timeout=5.0, temp=None, print_cost=True, version="1.5", max_context_tokens=128000): preloaded_api = os.getenv('OPENAI_API_KEY') if openai_api_key is None and preloaded_api is not None: openai_api_key = preloaded_api @@ -47,6 +48,8 @@ def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}] + + messages = clip_tokens(messages, model=model_str, max_tokens=max_context_tokens) if version == "0.28": if temp is None: completion = openai.ChatCompletion.create( diff --git a/requirements.txt b/requirements.txt index e08992b..eed95c4 100755 --- a/requirements.txt +++ b/requirements.txt @@ -10,32 +10,41 @@ arxiv==2.1.3 astunparse==1.6.3 async-timeout==5.0.1 attrs==24.2.0 +beautifulsoup4==4.12.3 blis==1.0.1 catalogue==2.0.10 certifi==2024.8.30 charset-normalizer==3.4.0 click==8.1.7 cloudpathlib==0.20.0 +cloudpickle==3.1.1 confection==0.1.5 contourpy==1.3.0 cycler==0.12.1 cymem==2.0.10 datasets==3.1.0 diffusers==0.31.0 -dill==0.3.8 +dill==0.3.9 distro==1.9.0 +EMD-signal @ git+https://github.com/laszukdawid/PyEMD.git@4fc40017c1db8f1fceda4370a12314e1dedf8dde exceptiongroup==1.2.2 +Farama-Notifications==0.0.4 feedparser==6.0.11 filelock==3.16.1 flatbuffers==24.3.25 fonttools==4.55.0 +frozendict==2.4.6 frozenlist==1.5.0 fsspec==2024.9.0 gast==0.6.0 google-pasta==0.2.0 grpcio==1.68.0 +gym==0.26.2 +gym-notices==0.0.8 +gymnasium==1.0.0 h11==0.14.0 h5py==3.12.1 +html5lib==1.1 httpcore==1.0.7 httpx==0.27.2 huggingface-hub==0.26.2 @@ -52,6 +61,7 @@ langcodes==3.5.0 language_data==1.3.0 lazy_loader==0.4 libclang==18.1.1 +lxml==5.3.0 marisa-trie==1.2.1 Markdown==3.7 markdown-it-py==3.0.0 @@ -61,21 +71,27 @@ mdurl==0.1.2 ml-dtypes==0.4.1 mpmath==1.3.0 multidict==6.1.0 -multiprocess==0.70.16 +multiprocess==0.70.17 +multitasking==0.0.11 murmurhash==1.0.11 namex==0.0.8 nest-asyncio==1.6.0 networkx==3.2.1 nltk==3.9.1 -numpy==2.0.2 +numpy==1.26.4 openai==1.55.1 opt_einsum==3.4.0 optree==0.13.1 packaging==24.2 pandas==2.2.3 +pathos==0.3.3 patsy==1.0.1 +peewee==3.17.8 pillow==11.0.0 +platformdirs==4.3.6 plotly==5.24.1 +pox==0.3.5 +ppft==1.7.6.9 preshed==3.0.9 propcache==0.2.0 protobuf==5.28.3 @@ -83,11 +99,13 @@ psutil==6.1.0 pyarrow==18.1.0 pydantic==2.10.2 pydantic_core==2.27.1 +pyemd==1.0.0 Pygments==2.18.0 pyparsing==3.2.0 pypdf==5.1.0 python-dateutil==2.9.0.post0 pytz==2024.2 +PyWavelets==1.8.0 PyYAML==6.0.2 regex==2024.11.6 requests==2.32.3 @@ -104,10 +122,12 @@ shellingham==1.5.4 six==1.16.0 smart-open==7.0.5 sniffio==1.3.1 +soupsieve==2.6 spacy==3.8.2 spacy-legacy==3.0.12 spacy-loggers==1.0.5 srsly==2.4.8 +stable_baselines3==2.4.1 statsmodels==0.14.4 sympy==1.13.1 tenacity==9.0.0 @@ -130,8 +150,10 @@ tzdata==2024.2 urllib3==2.2.3 wasabi==1.1.3 weasel==0.4.1 +webencodings==0.5.1 Werkzeug==3.1.3 wrapt==1.17.0 xxhash==3.5.0 yarl==1.18.0 +yfinance==0.2.51 zipp==3.21.0 diff --git a/tools.py b/tools.py index 5d0d4a9..0b592b8 100755 --- a/tools.py +++ b/tools.py @@ -16,14 +16,15 @@ import traceback import concurrent.futures - +import psutil +import subprocess class HFDataSearch: def __init__(self, like_thr=3, dwn_thr=50) -> None: """ Class for finding relevant huggingface datasets - :param like_thr: - :param dwn_thr: + :param like_thr: threshold of 'likes' + :param dwn_thr: threshold of 'downloads' """ self.dwn_thr = dwn_thr self.like_thr = like_thr @@ -103,21 +104,24 @@ def retrieve_ds(self, query, N=10, sim_w=1.0, like_w=0.0, dwn_w=0.0): cosine_similarities = linear_kernel(query_vector, self.description_vectors).flatten() # Normalize cosine similarities cosine_similarities_norm = self._normalize(cosine_similarities) + # Compute final scores final_scores = ( - sim_w * cosine_similarities_norm + - like_w * self.likes_norm + - dwn_w * self.downloads_norm + sim_w * cosine_similarities_norm + + like_w * self.likes_norm + + dwn_w * self.downloads_norm ) + # Get top N indices top_indices = final_scores.argsort()[-N:][::-1] # Convert indices to Python ints top_indices = [int(i) for i in top_indices] top_datasets = [self.ds[i] for i in top_indices] - # check if dataset has a test & train set - has_test_set = list() - has_train_set = list() - ds_size_info = list() + + # Check if dataset has a test & train set; gather size info + has_test_set = [] + has_train_set = [] + ds_size_info = [] for i in top_indices: try: dbuilder = load_dataset_builder(self.ds[i]["id"], trust_remote_code=True).info @@ -132,10 +136,11 @@ def retrieve_ds(self, query, N=10, sim_w=1.0, like_w=0.0, dwn_w=0.0): has_train_set.append(False) ds_size_info.append((None, None, None, None)) continue - # Print number of examples for + has_test, has_train = "test" in dbuilder.splits, "train" in dbuilder.splits has_test_set.append(has_test) has_train_set.append(has_train) + test_dwn_size, test_elem_size = None, None train_dwn_size, train_elem_size = None, None if has_test: @@ -144,7 +149,10 @@ def retrieve_ds(self, query, N=10, sim_w=1.0, like_w=0.0, dwn_w=0.0): if has_train: train_dwn_size = bytes2human(dbuilder.splits["train"].num_bytes) train_elem_size = dbuilder.splits["train"].num_examples + ds_size_info.append((test_dwn_size, test_elem_size, train_dwn_size, train_elem_size)) + + # Attach metadata to the top_datasets for _i in range(len(top_datasets)): top_datasets[_i]["has_test_set"] = has_test_set[_i] top_datasets[_i]["has_train_set"] = has_train_set[_i] @@ -152,6 +160,7 @@ def retrieve_ds(self, query, N=10, sim_w=1.0, like_w=0.0, dwn_w=0.0): top_datasets[_i]["test_element_size"] = ds_size_info[_i][1] top_datasets[_i]["train_download_size"] = ds_size_info[_i][2] top_datasets[_i]["train_element_size"] = ds_size_info[_i][3] + return top_datasets def results_str(self, results): @@ -160,7 +169,7 @@ def results_str(self, results): :param results: (list(dict)) list of results from search :return: (list(str)) list of results in human-readable format """ - result_strs = list() + result_strs = [] for result in results: res_str = f"Dataset ID: {result['id']}\n" res_str += f"Description: {result['description']}\n" @@ -175,71 +184,89 @@ def results_str(self, results): result_strs.append(res_str) return result_strs - class SemanticScholarSearch: def __init__(self): self.sch_engine = SemanticScholar(retry=False) def find_papers_by_str(self, query, N=10): - paper_sums = list() - results = self.sch_engine.search_paper(query, limit=N, min_citation_count=3, open_access_pdf=True) + """ + Finds top-N papers from semantic scholar + :param query: str + :param N: number of results + :return: list of string summaries + """ + paper_sums = [] + results = self.sch_engine.search_paper( + query, + limit=N, + min_citation_count=3, + open_access_pdf=True + ) for _i in range(len(results)): - paper_sum = f'Title: {results[_i].title}\n' - paper_sum += f'Abstract: {results[_i].abstract}\n' - paper_sum += f'Citations: {results[_i].citationCount}\n' - paper_sum += f'Release Date: year {results[_i].publicationDate.year}, month {results[_i].publicationDate.month}, day {results[_i].publicationDate.day}\n' - paper_sum += f'Venue: {results[_i].venue}\n' - paper_sum += f'Paper ID: {results[_i].externalIds["DOI"]}\n' + paper_sum = f"Title: {results[_i].title}\n" + paper_sum += f"Abstract: {results[_i].abstract}\n" + paper_sum += f"Citations: {results[_i].citationCount}\n" + paper_sum += ( + f"Release Date: year {results[_i].publicationDate.year}, " + f"month {results[_i].publicationDate.month}, " + f"day {results[_i].publicationDate.day}\n" + ) + paper_sum += f"Venue: {results[_i].venue}\n" + paper_sum += f"Paper ID: {results[_i].externalIds['DOI']}\n" paper_sums.append(paper_sum) return paper_sums def retrieve_full_paper_text(self, query): + """ + NOTE: Not implemented in this example + """ pass - class ArxivSearch: def __init__(self): # Construct the default API client. self.sch_engine = arxiv.Client() def _process_query(self, query: str) -> str: - """Process query string to fit within MAX_QUERY_LENGTH while preserving as much information as possible""" + """ + Process query string to fit within MAX_QUERY_LENGTH + while preserving as much info as possible + """ MAX_QUERY_LENGTH = 300 - if len(query) <= MAX_QUERY_LENGTH: return query - + # Split into words words = query.split() processed_query = [] current_length = 0 - + # Add words while staying under the limit - # Account for spaces between words for word in words: - # +1 for the space that will be added between words if current_length + len(word) + 1 <= MAX_QUERY_LENGTH: processed_query.append(word) current_length += len(word) + 1 else: break - return ' '.join(processed_query) def find_papers_by_str(self, query, N=20): + """ + Finds top-N relevant arXiv papers + """ processed_query = self._process_query(query) max_retries = 3 retry_count = 0 - + while retry_count < max_retries: try: search = arxiv.Search( query="abs:" + processed_query, max_results=N, - sort_by=arxiv.SortCriterion.Relevance) + sort_by=arxiv.SortCriterion.Relevance + ) - paper_sums = list() - # `results` is a generator; you can iterate over its elements one by one... + paper_sums = [] for r in self.sch_engine.results(search): paperid = r.pdf_url.split("/")[-1] pubdate = str(r.published).split(" ")[0] @@ -251,148 +278,179 @@ def find_papers_by_str(self, query, N=20): paper_sums.append(paper_sum) time.sleep(2.0) return "\n".join(paper_sums) - + except Exception as e: retry_count += 1 if retry_count < max_retries: - # 递增延时 + # Exponential-ish back-off time.sleep(2 * retry_count) continue - + + # If unsuccessful return None def retrieve_full_paper_text(self, query): - pdf_text = str() + """ + Download and extract full text from arXiv PDF + """ + pdf_text = "" + # Attempt to get single result with the provided paper ID paper = next(arxiv.Client().results(arxiv.Search(id_list=[query]))) - # Download the PDF to the PWD with a custom filename. + + # Download the PDF to a local file paper.download_pdf(filename="downloaded-paper.pdf") - # creating a pdf reader object - reader = PdfReader('downloaded-paper.pdf') - # Iterate over all the pages + + # Create a pdf reader object + reader = PdfReader("downloaded-paper.pdf") + + # Iterate over pages for page_number, page in enumerate(reader.pages, start=1): - # Extract text from the page try: text = page.extract_text() - except Exception as e: + except Exception: os.remove("downloaded-paper.pdf") time.sleep(2.0) return "EXTRACTION FAILED" - # Do something with the text (e.g., print it) pdf_text += f"--- Page {page_number} ---" pdf_text += text pdf_text += "\n" + os.remove("downloaded-paper.pdf") time.sleep(2.0) return pdf_text -""" -import multiprocessing -import sys -import io -import traceback -def execute_code(code_str, timeout=180): +def execute_code( + code_str, + max_total_time=7200, # Hard limit on total runtime (seconds) + max_idle_time=60, # If no CPU usage / no prints for this many secs, kill + max_stdout_len=2000 +): + """ + Execute code in a separate subprocess with: + 1) absolute time limit (max_total_time) + 2) idle time limit (max_idle_time) based on CPU usage + 3) capture stdout (limited by max_stdout_len) + 4) gracefully handle zombie processes or processes that vanish + """ + import matplotlib + matplotlib.use('Agg') # Use a non-interactive backend + import matplotlib.pyplot as plt + + # Basic checks if "load_dataset('pubmed" in code_str: - return "pubmed Download took way too long. Program terminated" + return "[CODE EXECUTION ERROR] pubmed Download took way too long. Program terminated" + if "exit(" in code_str: + return "[CODE EXECUTION ERROR] The exit() command is not allowed; please remove it." + + temp_filename = "temp_script.py" + with open(temp_filename, "w", encoding="utf-8") as f: + f.write(code_str) + + start_time = time.time() + output_capture = io.StringIO() + + process = subprocess.Popen( + [sys.executable, temp_filename], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True + ) + + proc_psutil = psutil.Process(process.pid) + last_output_time = time.time() + last_cpu_check_time = time.time() + kill_reason = None - def run_code(queue): - # Redirect stdout to capture print outputs - output_capture = io.StringIO() - sys.stdout = output_capture + # For versions of psutil which do have STATUS_ZOMBIE: + # We'll prefer that, else fallback to the literal "zombie". + ZOMBIE_STRING = getattr(psutil, "STATUS_ZOMBIE", "zombie").lower() + while True: + # 1) If the process ended, break the loop + if process.poll() is not None: + break + + # 2) Read any line from stdout try: - exec_globals = {} - exec(code_str, exec_globals) - except Exception as e: - output_capture.write(f"[CODE EXECUTION ERROR]: {str(e)}\n") - traceback.print_exc(file=output_capture) - finally: - # Put the output in the queue - queue.put(output_capture.getvalue()) - # Restore stdout - sys.stdout = sys.__stdout__ - - # Create a multiprocessing Queue to capture the output - queue = multiprocessing.Queue() - # Create a new Process - process = multiprocessing.Process(target=run_code, args=(queue,)) - process.start() - # Wait for the process to finish or timeout - process.join(timeout) - - if process.is_alive(): - process.terminate() - process.join() - return f"[CODE EXECUTION ERROR]: Code execution exceeded the timeout limit of {timeout} seconds. You must reduce the time complexity of your code." - else: - # Retrieve the output from the queue - output = queue.get() - return output - -""" - -import io -import sys -import traceback -import concurrent.futures + line = process.stdout.readline() + except Exception: + line = None + if line: + output_capture.write(line) + last_output_time = time.time() + # 3) CPU usage check every 2 seconds + if (time.time() - last_cpu_check_time) > 2.0: + last_cpu_check_time = time.time() -import multiprocessing -import io -import sys -import traceback -import multiprocessing -import io -import sys -import traceback + # (a) If process is not running => presumably done or zombie + if not proc_psutil.is_running(): + kill_reason = "Process ended or unknown state" + break + # (b) Attempt to get the string status + try: + # e.g. "running", "sleeping", "zombie", etc. + status_str = proc_psutil.status().lower() + except psutil.Error: + # If we can't retrieve status, assume it's gone + kill_reason = "Could not retrieve process status" + process.kill() + break -def execute_code(code_str, timeout=60, MAX_LEN=1000): - #print(code_str) + # If status is "zombie", kill it + if status_str == ZOMBIE_STRING: + kill_reason = "Process is zombie" + process.kill() + break - # prevent plotting errors - import matplotlib - matplotlib.use('Agg') # Use the non-interactive Agg backend - import matplotlib.pyplot as plt + # (c) Try CPU usage + try: + cpu_usage = proc_psutil.cpu_percent(interval=0.1) + except (psutil.NoSuchProcess, psutil.AccessDenied): + kill_reason = "Process is gone or cannot be accessed" + process.kill() + break - # Preventing execution of certain resource-intensive datasets - if "load_dataset('pubmed" in code_str: - return "[CODE EXECUTION ERROR] pubmed Download took way too long. Program terminated" - if "exit(" in code_str: - return "[CODE EXECUTION ERROR] The exit() command is not allowed you must remove this." - #print(code_str) - # Capturing the output - output_capture = io.StringIO() - sys.stdout = output_capture + # If usage > 1%, consider it active + if cpu_usage > 1.0: + last_output_time = time.time() - # Create a new global context for exec - exec_globals = globals() + # 4) Idle check + if (time.time() - last_output_time) > max_idle_time: + kill_reason = f"Idle for {max_idle_time} seconds." + process.kill() + break - def run_code(): - try: - # Executing the code in the global namespace - exec(code_str, exec_globals) - except Exception as e: - output_capture.write(f"[CODE EXECUTION ERROR]: {str(e)}\n") - traceback.print_exc(file=output_capture) - - try: - # Running code in a separate thread with a timeout - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_code) - future.result(timeout=timeout) - except concurrent.futures.TimeoutError: - return f"[CODE EXECUTION ERROR]: Code execution exceeded the timeout limit of {timeout} seconds. You must reduce the time complexity of your code." - except Exception as e: - return f"[CODE EXECUTION ERROR]: {str(e)}" - finally: - # Restoring standard output - sys.stdout = sys.__stdout__ - - # Returning the captured output - return output_capture.getvalue()[:MAX_LEN] + # 5) Total time check + if (time.time() - start_time) > max_total_time: + kill_reason = f"Exceeded total runtime of {max_total_time} seconds." + process.kill() + break + + time.sleep(0.05) # Short pause + + # 6) Drain leftover stdout + if process.poll() is not None: + leftover = process.stdout.read() + if leftover: + output_capture.write(leftover) + + process.stdout.close() + process.wait() + + # 7) Remove temp file if desired + if os.path.exists(temp_filename): + os.remove(temp_filename) + + # 8) Attach kill_reason to logs if any + output = output_capture.getvalue() + if kill_reason: + output += f"\n[CODE EXECUTION STOPPED]: {kill_reason}\n" + return output[:max_stdout_len] diff --git a/utils.py b/utils.py index a163273..47c4faa 100755 --- a/utils.py +++ b/utils.py @@ -74,7 +74,7 @@ def save_to_file(location, filename, data): print(f"Error saving file {filename}: {e}") -def clip_tokens(messages, model="gpt-4", max_tokens=100000): +def clip_tokens(messages, model="o1-mini", max_tokens=128000): enc = tiktoken.encoding_for_model(model) total_tokens = sum([len(enc.encode(message["content"])) for message in messages])