Add flake8 comma rule #8535
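This PR mechanically adds a trailing comma after the last element of multi-line calls, signatures, and literals throughout the dspy codebase, the style enforced by the flake8-commas plugin as rule COM812 ("missing trailing comma"). The linter configuration change itself is not visible in this excerpt; as a minimal sketch, assuming the rule is wired up through flake8 with the flake8-commas plugin installed, the enabling configuration might look like this (the setup.cfg location is an assumption):

# setup.cfg (hypothetical; the PR's actual config change is not shown in this excerpt)
# requires the plugin: pip install flake8-commas
[flake8]
# COM812: missing trailing comma on the last element of a multi-line construct
extend-select = COM812

Under Ruff, the equivalent would be extend-select = ["COM812"] in the [tool.ruff.lint] table of pyproject.toml.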

18 changes: 9 additions & 9 deletions dspy/adapters/base.py
@@ -43,7 +43,7 @@ def _call_preprocess(
             raise ValueError(
                 f"You provided an output field {tool_call_output_field_name} to receive the tool calls information, "
                 "but did not provide any tools as the input. Please provide a list of tools as the input by adding an "
-                "input field with type `list[dspy.Tool]`."
+                "input field with type `list[dspy.Tool]`.",
             )
 
         if tool_call_output_field_name and litellm.supports_function_calling(model=lm.model):
@@ -58,7 +58,7 @@ def _call_preprocess(
 
             signature_for_native_function_calling = signature.delete(tool_call_output_field_name)
             signature_for_native_function_calling = signature_for_native_function_calling.delete(
-                tool_call_input_field_name
+                tool_call_input_field_name,
             )
 
             return signature_for_native_function_calling
@@ -344,15 +344,15 @@ def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]])
                 {
                     "role": "user",
                     "content": self.format_user_message_content(signature, demo, prefix=incomplete_demo_prefix),
-                }
+                },
             )
             messages.append(
                 {
                     "role": "assistant",
                     "content": self.format_assistant_message_content(
-                        signature, demo, missing_field_message="Not supplied for this particular example. "
+                        signature, demo, missing_field_message="Not supplied for this particular example. ",
                     ),
-                }
+                },
             )
 
         for demo in complete_demos:
@@ -361,9 +361,9 @@ def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]])
                 {
                     "role": "assistant",
                     "content": self.format_assistant_message_content(
-                        signature, demo, missing_field_message="Not supplied for this conversation history message. "
+                        signature, demo, missing_field_message="Not supplied for this conversation history message. ",
                     ),
-                }
+                },
             )
 
         return messages
@@ -419,13 +419,13 @@ def format_conversation_history(
                 {
                     "role": "user",
                     "content": self.format_user_message_content(signature, message),
-                }
+                },
             )
             messages.append(
                 {
                     "role": "assistant",
                     "content": self.format_assistant_message_content(signature, message),
-                }
+                },
             )
 
         # Remove the history field from the inputs
4 changes: 2 additions & 2 deletions dspy/adapters/chat_adapter.py
@@ -238,10 +238,10 @@ def format_finetune_data(
             wrapped in a dictionary with a "messages" key.
         """
         system_user_messages = self.format(  # returns a list of dicts with the keys "role" and "content"
-            signature=signature, demos=demos, inputs=inputs
+            signature=signature, demos=demos, inputs=inputs,
         )
         assistant_message_content = self.format_assistant_message_content(  # returns a string, without the role
-            signature=signature, outputs=outputs
+            signature=signature, outputs=outputs,
         )
         assistant_message = {"role": "assistant", "content": assistant_message_content}
         messages = system_user_messages + [assistant_message]
6 changes: 3 additions & 3 deletions dspy/adapters/json_adapter.py
@@ -72,7 +72,7 @@ def __call__(
 
         try:
             structured_output_model = _get_structured_outputs_response_format(
-                signature, self.use_native_function_calling
+                signature, self.use_native_function_calling,
             )
             lm_kwargs["response_format"] = structured_output_model
             return super().__call__(lm, lm_kwargs, signature, demos, inputs)
@@ -201,7 +201,7 @@ def format_field_with_value(self, fields_with_values: dict[FieldInfoWithName, An
         return json.dumps(serialize_for_json(d), indent=2)
 
     def format_finetune_data(
-        self, signature: type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]
+        self, signature: type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any],
     ) -> dict[str, list[Any]]:
         # TODO: implement format_finetune_data method in JSONAdapter
         raise NotImplementedError
@@ -225,7 +225,7 @@ def _get_structured_outputs_response_format(
         annotation = field.annotation
         if get_origin(annotation) is dict:
             raise ValueError(
-                f"Field '{name}' has an open-ended mapping type which is not supported by Structured Outputs."
+                f"Field '{name}' has an open-ended mapping type which is not supported by Structured Outputs.",
             )
 
     fields = {}
2 changes: 1 addition & 1 deletion dspy/adapters/two_step_adapter.py
@@ -46,7 +46,7 @@ def __init__(self, extraction_model: LM, **kwargs):
         self.extraction_model = extraction_model
 
     def format(
-        self, signature: type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]
+        self, signature: type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any],
     ) -> list[dict[str, Any]]:
         """
         Format a prompt for the first stage with the main LM.
6 changes: 3 additions & 3 deletions dspy/adapters/types/audio.py
@@ -35,8 +35,8 @@ def format(self) -> list[dict[str, Any]]:
             "type": "input_audio",
             "input_audio": {
                 "data": data,
-                "format": self.audio_format
-            }
+                "format": self.audio_format,
+            },
         }]
 
 
@@ -85,7 +85,7 @@ def from_file(cls, file_path: str) -> "Audio":
 
     @classmethod
     def from_array(
-        cls, array: Any, sampling_rate: int, format: str = "wav"
+        cls, array: Any, sampling_rate: int, format: str = "wav",
     ) -> "Audio":
         """
         Process numpy-like array and encode it as base64. Uses sampling rate and audio format for encoding.
4 changes: 2 additions & 2 deletions dspy/adapters/types/tool.py
@@ -179,7 +179,7 @@ def __call__(self, **kwargs):
         else:
             raise ValueError(
                 "You are calling `__call__` on an async tool, please use `acall` instead or set "
-                "`allow_async=True` to run the async tool in sync mode."
+                "`allow_async=True` to run the async tool in sync mode.",
             )
         return result
 
@@ -306,7 +306,7 @@ def format(self) -> list[dict[str, Any]]:
                     }
                     for tool_call in self.tool_calls
                 ],
-            }
+            },
         ]
 
 
2 changes: 1 addition & 1 deletion dspy/clients/__init__.py
@@ -47,7 +47,7 @@ def configure_cache(
     if enable_disk_cache and enable_litellm_cache:
         raise ValueError(
             "Cannot enable both LiteLLM and DSPy on-disk cache, please set at most one of `enable_disk_cache` or "
-            "`enable_litellm_cache` to True."
+            "`enable_litellm_cache` to True.",
         )
 
     if enable_litellm_cache:
46 changes: 23 additions & 23 deletions dspy/clients/databricks.py
@@ -32,7 +32,7 @@ def status(self):
         except ImportError:
             raise ImportError(
                 "To use Databricks finetuning, please install the databricks_genai package via "
-                "`pip install databricks_genai`."
+                "`pip install databricks_genai`.",
             )
         run = fm.get(self.finetuning_run)
         return run.status
@@ -84,7 +84,7 @@ def deploy_finetuned_model(
         model_name = model.replace(".", "_")
 
         get_endpoint_response = requests.get(
-            url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}", json={"name": model_name}, headers=headers
+            url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}", json={"name": model_name}, headers=headers,
         )
 
         if get_endpoint_response.status_code == 200:
@@ -98,8 +98,8 @@ def deploy_finetuned_model(
                         "entity_version": model_version,
                         "min_provisioned_throughput": min_provisioned_throughput,
                         "max_provisioned_throughput": max_provisioned_throughput,
-                    }
-                ]
+                    },
+                ],
             }
 
             response = requests.put(
@@ -120,23 +120,23 @@ def deploy_finetuned_model(
                             "entity_version": model_version,
                             "min_provisioned_throughput": min_provisioned_throughput,
                             "max_provisioned_throughput": max_provisioned_throughput,
-                        }
-                    ]
+                        },
+                    ],
                 },
             }
 
             response = requests.post(url=f"{databricks_host}/api/2.0/serving-endpoints", json=data, headers=headers)
 
         if response.status_code == 200:
             logger.info(
-                f"Successfully started creating/updating serving endpoint {model_name} on Databricks model serving!"
+                f"Successfully started creating/updating serving endpoint {model_name} on Databricks model serving!",
             )
         else:
             raise ValueError(f"Failed to create serving endpoint: {response.json()}.")
 
         logger.info(
             f"Waiting for serving endpoint {model_name} to be ready, this might take a few minutes... You can check "
-            f"the status of the endpoint at {databricks_host}/ml/endpoints/{model_name}"
+            f"the status of the endpoint at {databricks_host}/ml/endpoints/{model_name}",
         )
         from openai import OpenAI
 
@@ -150,7 +150,7 @@ def deploy_finetuned_model(
             try:
                 if data_format == TrainDataFormat.CHAT:
                     client.chat.completions.create(
-                        messages=[{"role": "user", "content": "hi"}], model=model_name, max_tokens=1
+                        messages=[{"role": "user", "content": "hi"}], model=model_name, max_tokens=1,
                     )
                 elif data_format == TrainDataFormat.COMPLETION:
                     client.completions.create(prompt="hi", model=model_name, max_tokens=1)
@@ -161,7 +161,7 @@ def deploy_finetuned_model(
 
         raise ValueError(
             f"Failed to create serving endpoint {model_name} on Databricks model serving platform within "
-            f"{deploy_timeout} seconds."
+            f"{deploy_timeout} seconds.",
         )
 
     @staticmethod
@@ -179,22 +179,22 @@ def finetune(
                 train_data_format = TrainDataFormat.COMPLETION
             else:
                 raise ValueError(
-                    f"String `train_data_format` must be one of 'chat' or 'completion', but received: {train_data_format}."
+                    f"String `train_data_format` must be one of 'chat' or 'completion', but received: {train_data_format}.",
                 )
 
         if "train_data_path" not in train_kwargs:
             raise ValueError("The `train_data_path` must be provided to finetune on Databricks.")
         # Add the file name to the directory path.
         train_kwargs["train_data_path"] = DatabricksProvider.upload_data(
-            train_data, train_kwargs["train_data_path"], train_data_format
+            train_data, train_kwargs["train_data_path"], train_data_format,
         )
 
         try:
             from databricks.model_training import foundation_model as fm
         except ImportError:
             raise ImportError(
                 "To use Databricks finetuning, please install the databricks_genai package via "
-                "`pip install databricks_genai`."
+                "`pip install databricks_genai`.",
             )
 
         if "register_to" not in train_kwargs:
@@ -224,7 +224,7 @@ def finetune(
             elif job.run.status.display_name == "Failed":
                 raise ValueError(
                     f"Finetuning run failed with status: {job.run.status.display_name}. Please check the Databricks "
-                    f"workspace for more details. Finetuning job's metadata: {job.run}."
+                    f"workspace for more details. Finetuning job's metadata: {job.run}.",
                 )
             else:
                 time.sleep(60)
@@ -236,7 +236,7 @@ def finetune(
         model_to_deploy = train_kwargs.get("register_to")
         job.endpoint_name = model_to_deploy.replace(".", "_")
         DatabricksProvider.deploy_finetuned_model(
-            model_to_deploy, train_data_format, databricks_host, databricks_token, deploy_timeout
+            model_to_deploy, train_data_format, databricks_host, databricks_token, deploy_timeout,
         )
         job.launch_completed = True
         # The finetuned model name should be in the format: "databricks/<endpoint_name>".
@@ -266,7 +266,7 @@ def _get_workspace_client() -> "WorkspaceClient":
     except ImportError:
         raise ImportError(
             "To use Databricks finetuning, please install the databricks-sdk package via "
-            "`pip install databricks-sdk`."
+            "`pip install databricks-sdk`.",
         )
     return WorkspaceClient()
 
@@ -277,7 +277,7 @@ def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databric
    if not match:
        raise ValueError(
            f"Databricks Unity Catalog path must be in the format '/Volumes/<catalog>/<schema>/<volume>/...', but "
-            f"received: {databricks_unity_catalog_path}."
+            f"received: {databricks_unity_catalog_path}.",
        )
 
    catalog = match.group("catalog")
@@ -290,7 +290,7 @@ def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databric
    except Exception:
        raise ValueError(
            f"Databricks Unity Catalog volume does not exist: {volume_path}, please create it on the Databricks "
-            "workspace."
+            "workspace.",
        )
 
    try:
@@ -326,32 +326,32 @@ def _validate_chat_data(data: dict[str, Any]):
    if "messages" not in data:
        raise ValueError(
            "Each finetuning data must be a dict with a 'messages' key when `task=CHAT_COMPLETION`, but "
-            f"received: {data}"
+            f"received: {data}",
        )
 
    if not isinstance(data["messages"], list):
        raise ValueError(
            "The value of the 'messages' key in each finetuning data must be a list of dicts with keys 'role' and "
-            f"'content' when `task=CHAT_COMPLETION`, but received: {data['messages']}"
+            f"'content' when `task=CHAT_COMPLETION`, but received: {data['messages']}",
        )
 
    for message in data["messages"]:
        if "role" not in message:
            raise ValueError(f"Each message in the 'messages' list must contain a 'role' key, but received: {message}.")
        if "content" not in message:
            raise ValueError(
-                f"Each message in the 'messages' list must contain a 'content' key, but received: {message}."
+                f"Each message in the 'messages' list must contain a 'content' key, but received: {message}.",
            )
 
 
 def _validate_completion_data(data: dict[str, Any]):
    if "prompt" not in data:
        raise ValueError(
            "Each finetuning data must be a dict with a 'prompt' key when `task=INSTRUCTION_FINETUNE`, but "
-            f"received: {data}"
+            f"received: {data}",
        )
    if "response" not in data and "completion" not in data:
        raise ValueError(
            "Each finetuning data must be a dict with a 'response' or 'completion' key when "
-            f"`task=INSTRUCTION_FINETUNE`, but received: {data}"
+            f"`task=INSTRUCTION_FINETUNE`, but received: {data}",
        )
4 changes: 2 additions & 2 deletions dspy/clients/lm.py
@@ -136,7 +136,7 @@ def forward(self, prompt=None, messages=None, **kwargs):
                 "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
                 "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
                 f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
-                " if the reason for truncation is repetition."
+                " if the reason for truncation is repetition.",
             )
 
         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
@@ -166,7 +166,7 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
                 "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
                 "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
                 f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
-                " if the reason for truncation is repetition."
+                " if the reason for truncation is repetition.",
             )
 
         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
6 changes: 3 additions & 3 deletions dspy/clients/lm_local.py
@@ -32,7 +32,7 @@ def launch(lm: "LM", launch_kwargs: dict[str, Any] | None = None):
    except ImportError:
        raise ImportError(
            "For local model launching, please install sglang."
-            "Navigate to https://docs.sglang.ai/start/install.html for the latest installation instructions!"
+            "Navigate to https://docs.sglang.ai/start/install.html for the latest installation instructions!",
        )
 
    if hasattr(lm, "process"):
@@ -195,7 +195,7 @@ def train_sft_locally(model_name, train_data, train_kwargs):
    except ImportError:
        raise ImportError(
            "For local finetuning, please install torch, transformers, and trl "
-            "by running `pip install -U torch transformers accelerate trl peft`"
+            "by running `pip install -U torch transformers accelerate trl peft`",
        )
 
    device = train_kwargs.get("device", None)
@@ -220,7 +220,7 @@ def train_sft_locally(model_name, train_data, train_kwargs):
    if "max_seq_length" not in train_kwargs:
        train_kwargs["max_seq_length"] = 4096
        logger.info(
-            f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}"
+            f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}",
        )
 
    from datasets import Dataset
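The practical payoff of the rule is version-control friendliness: once every multi-line argument list already ends with a comma, appending an argument is a one-line addition rather than a two-line diff. A minimal illustration with hypothetical code (connect and its parameters are invented for this example, not taken from the repository):

# Without a trailing comma, adding a new argument also rewrites the line
# above it, because `retries=3` must become `retries=3,` in the same change:
connect(
    host="localhost",
    retries=3
)

# With COM812 satisfied, a later `timeout=10,` lands as a clean one-line addition:
connect(
    host="localhost",
    retries=3,
)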