Skip to content
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ research_dir/*
state_saves/*
__pycache__/*
Figure*.png
testrun.py
testrun.py
data/*
projects/*
58 changes: 51 additions & 7 deletions ai_lab_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,10 +446,13 @@ def literature_review(self):
@return: (bool) whether to repeat the phase
"""
arx_eng = ArxivSearch()
max_tries = self.max_steps * 5 # lit review often requires extra steps
max_tries = self.max_steps * 5 # lit review often requires extra steps

# get initial response from PhD agent
resp = self.phd.inference(self.research_topic, "literature review", step=0, temp=0.8)
if self.verbose: print(resp, "\n~~~~~~~~~~~")
if self.verbose:
print(resp, "\n~~~~~~~~~~~")

# iterate until max num tries to complete task is exhausted
for _i in range(max_tries):
feedback = str()
Expand All @@ -463,40 +466,71 @@ def literature_review(self):
# grab full text from arxiv ID
elif "```FULL_TEXT" in resp:
query = extract_prompt(resp, "FULL_TEXT")
# expiration timer so that paper does not remain in context too long
arxiv_paper = f"```EXPIRATION {self.arxiv_paper_exp_time}\n" + arx_eng.retrieve_full_paper_text(query) + "```"
try:
# fetch the full paper text; it is wrapped in an EXPIRATION block below so that the paper does not remain in context too long
full_text_content = arx_eng.retrieve_full_paper_text(query)
except Exception as e:
# Catch any unexpected errors from arxiv.Client()
err_msg = f"[ERROR] Could not retrieve paper. Possibly invalid arXiv ID. Error: {e}"
full_text_content = err_msg

# In case retrieve_full_paper_text reports a failure by returning an
# "[ERROR] ..." string rather than raising, treat it like the exception path above
if full_text_content.startswith("[ERROR]"):
# We won't crash; just pass that as feedback
arxiv_paper = f"```EXPIRATION {self.arxiv_paper_exp_time}\n{full_text_content}```"
else:
# normal successful retrieval
arxiv_paper = f"```EXPIRATION {self.arxiv_paper_exp_time}\n{full_text_content}```"

feedback = arxiv_paper

# if add paper, extract and add to lit review, provide feedback
elif "```ADD_PAPER" in resp:
query = extract_prompt(resp, "ADD_PAPER")
feedback, text = self.phd.add_review(query, arx_eng)
# If we want to store reference text for later usage
if len(self.reference_papers) < self.num_ref_papers:
self.reference_papers.append(text)

# completion condition
if len(self.phd.lit_review) >= self.num_papers_lit_review:
# generate formal review
lit_review_sum = self.phd.format_review()

# if human in loop -> check if human is happy with the produced review
if self.human_in_loop_flag["literature review"]:
retry = self.human_in_loop("literature review", lit_review_sum)
# if not happy, repeat the process with human feedback
if retry:
self.phd.lit_review = []
return retry

# otherwise, return lit review and move on to next stage
if self.verbose: print(self.phd.lit_review_sum)
if self.verbose:
print(self.phd.lit_review_sum)
# set agent
self.set_agent_attr("lit_review_sum", lit_review_sum)
# reset agent state
self.reset_agents()
self.statistics_per_phase["literature review"]["steps"] = _i
return False
resp = self.phd.inference(self.research_topic, "literature review", feedback=feedback, step=_i + 1, temp=0.8)
if self.verbose: print(resp, "\n~~~~~~~~~~~")

# Move on to the next iteration with new feedback
resp = self.phd.inference(
self.research_topic,
"literature review",
feedback=feedback,
step=_i + 1,
temp=0.8
)
if self.verbose:
print(resp, "\n~~~~~~~~~~~")

# If we exceed max_tries:
raise Exception("Max tries during phase: Literature Review")


def human_in_loop(self, phase, phase_prod):
"""
Get human feedback for phase output
Expand Down Expand Up @@ -611,6 +645,7 @@ def parse_arguments():
help='Total number of paper-solver steps'
)

parser.add_argument('--file-path', type=str, default=None)

return parser.parse_args()

Expand All @@ -622,6 +657,8 @@ def parse_arguments():
human_mode = args.copilot_mode.lower() == "true"
compile_pdf = args.compile_latex.lower() == "true"
load_existing = args.load_existing.lower() == "true"
file_path = args.file_path

try:
num_papers_lit_review = int(args.num_papers_lit_review.lower())
except Exception:
Expand Down Expand Up @@ -654,6 +691,13 @@ def parse_arguments():
else:
research_topic = args.research_topic

if file_path and "{FILE}" in research_topic:
with open(file_path, 'r', encoding='utf-8') as f:
file_content = f.read()
# Replace the placeholder with the entire file text
research_topic = research_topic.replace("{FILE}", file_content)


task_notes_LLM = [
{"phases": ["plan formulation"],
"note": f"You should come up with a plan for TWO experiments."},
Expand Down
5 changes: 4 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from openai import OpenAI
import openai
import os, anthropic, json
from utils import clip_tokens

TOKENS_IN = dict()
TOKENS_OUT = dict()
Expand Down Expand Up @@ -29,7 +30,7 @@ def curr_cost_est():
}
return sum([costmap_in[_]*TOKENS_IN[_] for _ in TOKENS_IN]) + sum([costmap_out[_]*TOKENS_OUT[_] for _ in TOKENS_OUT])

def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic_api_key=None, tries=5, timeout=5.0, temp=None, print_cost=True, version="1.5"):
def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic_api_key=None, tries=5, timeout=5.0, temp=None, print_cost=True, version="1.5", max_context_tokens=128000):
preloaded_api = os.getenv('OPENAI_API_KEY')
if openai_api_key is None and preloaded_api is not None:
openai_api_key = preloaded_api
Expand All @@ -47,6 +48,8 @@ def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}]

messages = clip_tokens(messages, model=model_str, max_tokens=max_context_tokens)
if version == "0.28":
if temp is None:
completion = openai.ChatCompletion.create(
Expand Down
4 changes: 2 additions & 2 deletions tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def retrieve_full_paper_text(self, query):
import io
import traceback

def execute_code(code_str, timeout=180):
def execute_code(code_str, timeout=600):
if "load_dataset('pubmed" in code_str:
return "pubmed Download took way too long. Program terminated"

Expand Down Expand Up @@ -349,7 +349,7 @@ def run_code(queue):
import traceback


def execute_code(code_str, timeout=60, MAX_LEN=1000):
def execute_code(code_str, timeout=600, MAX_LEN=1000):
#print(code_str)

# prevent plotting errors
Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def save_to_file(location, filename, data):
print(f"Error saving file {filename}: {e}")


def clip_tokens(messages, model="gpt-4", max_tokens=100000):
def clip_tokens(messages, model="o1-mini", max_tokens=128000):
enc = tiktoken.encoding_for_model(model)
total_tokens = sum([len(enc.encode(message["content"])) for message in messages])

Expand Down