From cdc08a9aaba7eb39d0d5c62005f0469c2ea4f929 Mon Sep 17 00:00:00 2001 From: Darren Shepherd Date: Wed, 14 Aug 2024 09:06:28 -0700 Subject: [PATCH] feat: add load method --- gptscript/__init__.py | 2 +- gptscript/frame.py | 9 ++++++--- gptscript/gptscript.py | 7 ++++++- tests/test_gptscript.py | 7 +++++++ 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/gptscript/__init__.py b/gptscript/__init__.py index ef3adbe..9b5de34 100644 --- a/gptscript/__init__.py +++ b/gptscript/__init__.py @@ -1,6 +1,6 @@ from gptscript.gptscript import GPTScript from gptscript.confirm import AuthResponse -from gptscript.frame import RunFrame, CallFrame, PromptFrame +from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program from gptscript.opts import GlobalOptions from gptscript.prompt import PromptResponse from gptscript.run import Run, RunBasicCommand, Options diff --git a/gptscript/frame.py b/gptscript/frame.py index f42ee08..bdaec4f 100644 --- a/gptscript/frame.py +++ b/gptscript/frame.py @@ -47,9 +47,12 @@ def __init__(self, self.name = name self.entryToolId = entryToolId self.toolSet = toolSet - for tool in toolSet: - if isinstance(self.toolSet[tool], dict): - self.toolSet[tool] = Tool(**self.toolSet[tool]) + if self.toolSet is None: + self.toolSet = {} + else: + for tool in toolSet: + if isinstance(self.toolSet[tool], dict): + self.toolSet[tool] = Tool(**self.toolSet[tool]) class RunFrame: diff --git a/gptscript/gptscript.py b/gptscript/gptscript.py index 8f20a2f..bb3bb57 100644 --- a/gptscript/gptscript.py +++ b/gptscript/gptscript.py @@ -9,7 +9,7 @@ import requests from gptscript.confirm import AuthResponse -from gptscript.frame import RunFrame, CallFrame, PromptFrame +from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program from gptscript.opts import GlobalOptions from gptscript.prompt import PromptResponse from gptscript.run import Run, RunBasicCommand, Options @@ -94,6 +94,11 @@ def run( "" if opts is None else opts.input ) + async def load(self, file_path: str) -> Program: + out = await self._run_basic_command("load", {"file": file_path}) + parsed_nodes = json.loads(out) + return Program(**parsed_nodes.get("program", {})) + async def parse(self, file_path: str, disable_cache: bool = False) -> list[Text | Tool]: out = await self._run_basic_command("parse", {"file": file_path, "disableCache": disable_cache}) parsed_nodes = json.loads(out) diff --git a/tests/test_gptscript.py b/tests/test_gptscript.py index 69dcf6e..b56d7f6 100644 --- a/tests/test_gptscript.py +++ b/tests/test_gptscript.py @@ -248,6 +248,13 @@ async def test_eval_with_context(gptscript): assert "Acorn Labs" == await run.text(), "Unexpected output from eval using context" +@pytest.mark.asyncio +async def test_load_simple_file(gptscript): + wd = os.getcwd() + prg = await gptscript.load(wd + "/tests/fixtures/test.gpt") + assert prg.toolSet[prg.entryToolId].instructions == "Who was the president of the United States in 1986?", \ + "Unexpected output from parsing simple file" + @pytest.mark.asyncio async def test_parse_simple_file(gptscript):