Commit a7e7f09

add flake8 comma rule
1 parent 2a4e7a0 commit a7e7f09

Note: large commits have some content hidden by default, so only a subset of the 79 changed files' diffs appears below.

79 files changed (+311 / -310 lines)
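For context, the "comma rule" being enabled is the flake8-commas trailing-comma check, which requires a comma after the final element of any multi-line call, collection, or parameter list. The lint-configuration change itself sits in one of the files not shown below; as a minimal sketch of what enabling it typically looks like, assuming the project runs the rule through ruff (which ports flake8-commas as its COM rules; a standalone flake8 setup would instead install the flake8-commas plugin):

# pyproject.toml -- hypothetical sketch, not the actual configuration file from this commit
[tool.ruff.lint]
# COM is ruff's port of the flake8-commas plugin; COM812 flags a missing
# trailing comma after the last item of a multi-line construct.
extend-select = ["COM"]

COM812 is auto-fixable, so a single pass such as `ruff check --fix .` can append the commas mechanically, which is consistent with the near one-for-one +311/-310 line counts of this commit.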

dspy/adapters/base.py

Lines changed: 9 additions & 9 deletions
@@ -43,7 +43,7 @@ def _call_preprocess(
            raise ValueError(
                f"You provided an output field {tool_call_output_field_name} to receive the tool calls information, "
                "but did not provide any tools as the input. Please provide a list of tools as the input by adding an "
-               "input field with type `list[dspy.Tool]`."
+               "input field with type `list[dspy.Tool]`.",
            )

        if tool_call_output_field_name and litellm.supports_function_calling(model=lm.model):
@@ -58,7 +58,7 @@ def _call_preprocess(

            signature_for_native_function_calling = signature.delete(tool_call_output_field_name)
            signature_for_native_function_calling = signature_for_native_function_calling.delete(
-               tool_call_input_field_name
+               tool_call_input_field_name,
            )

            return signature_for_native_function_calling
@@ -344,15 +344,15 @@ def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]])
                {
                    "role": "user",
                    "content": self.format_user_message_content(signature, demo, prefix=incomplete_demo_prefix),
-               }
+               },
            )
            messages.append(
                {
                    "role": "assistant",
                    "content": self.format_assistant_message_content(
-                       signature, demo, missing_field_message="Not supplied for this particular example. "
+                       signature, demo, missing_field_message="Not supplied for this particular example. ",
                    ),
-               }
+               },
            )

        for demo in complete_demos:
@@ -361,9 +361,9 @@ def format_demos(self, signature: type[Signature], demos: list[dict[str, Any]])
                {
                    "role": "assistant",
                    "content": self.format_assistant_message_content(
-                       signature, demo, missing_field_message="Not supplied for this conversation history message. "
+                       signature, demo, missing_field_message="Not supplied for this conversation history message. ",
                    ),
-               }
+               },
            )

        return messages
@@ -419,13 +419,13 @@ def format_conversation_history(
                {
                    "role": "user",
                    "content": self.format_user_message_content(signature, message),
-               }
+               },
            )
            messages.append(
                {
                    "role": "assistant",
                    "content": self.format_assistant_message_content(signature, message),
-               }
+               },
            )

        # Remove the history field from the inputs
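Every hunk in this commit follows the pattern visible above: the final element or argument of a multi-line construct gains a trailing comma. A hypothetical before/after sketch (generic Python, not taken from this diff); note that in the `raise ValueError(...)` hunks, the adjacent string literals concatenate into a single argument, so the comma lands after the last literal and runtime behavior is unchanged:

# Before: the trailing-comma rule flags the last element of the
# multi-line list
selected_rules = [
    "COM812",
    "COM818"
]

# After the autofix: trailing comma appended; adding another element
# later now touches only the new line, keeping future diffs one line long
selected_rules = [
    "COM812",
    "COM818",
]

That last point is the usual motivation for the rule: with a trailing comma always present, growing or reordering lists, calls, and signatures never rewrites a neighboring line.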

dspy/adapters/chat_adapter.py

Lines changed: 2 additions & 2 deletions
@@ -238,10 +238,10 @@ def format_finetune_data(
        wrapped in a dictionary with a "messages" key.
        """
        system_user_messages = self.format(  # returns a list of dicts with the keys "role" and "content"
-           signature=signature, demos=demos, inputs=inputs
+           signature=signature, demos=demos, inputs=inputs,
        )
        assistant_message_content = self.format_assistant_message_content(  # returns a string, without the role
-           signature=signature, outputs=outputs
+           signature=signature, outputs=outputs,
        )
        assistant_message = {"role": "assistant", "content": assistant_message_content}
        messages = system_user_messages + [assistant_message]

dspy/adapters/json_adapter.py

Lines changed: 3 additions & 3 deletions
@@ -72,7 +72,7 @@ def __call__(

        try:
            structured_output_model = _get_structured_outputs_response_format(
-               signature, self.use_native_function_calling
+               signature, self.use_native_function_calling,
            )
            lm_kwargs["response_format"] = structured_output_model
            return super().__call__(lm, lm_kwargs, signature, demos, inputs)
@@ -201,7 +201,7 @@ def format_field_with_value(self, fields_with_values: dict[FieldInfoWithName, An
        return json.dumps(serialize_for_json(d), indent=2)

    def format_finetune_data(
-       self, signature: type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]
+       self, signature: type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any],
    ) -> dict[str, list[Any]]:
        # TODO: implement format_finetune_data method in JSONAdapter
        raise NotImplementedError
@@ -225,7 +225,7 @@ def _get_structured_outputs_response_format(
        annotation = field.annotation
        if get_origin(annotation) is dict:
            raise ValueError(
-               f"Field '{name}' has an open-ended mapping type which is not supported by Structured Outputs."
+               f"Field '{name}' has an open-ended mapping type which is not supported by Structured Outputs.",
            )

    fields = {}

dspy/adapters/two_step_adapter.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def __init__(self, extraction_model: LM, **kwargs):
        self.extraction_model = extraction_model

    def format(
-       self, signature: type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]
+       self, signature: type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """
        Format a prompt for the first stage with the main LM.

dspy/adapters/types/audio.py

Lines changed: 3 additions & 3 deletions
@@ -35,8 +35,8 @@ def format(self) -> list[dict[str, Any]]:
            "type": "input_audio",
            "input_audio": {
                "data": data,
-               "format": self.audio_format
-           }
+               "format": self.audio_format,
+           },
        }]


@@ -85,7 +85,7 @@ def from_file(cls, file_path: str) -> "Audio":

    @classmethod
    def from_array(
-       cls, array: Any, sampling_rate: int, format: str = "wav"
+       cls, array: Any, sampling_rate: int, format: str = "wav",
    ) -> "Audio":
        """
        Process numpy-like array and encode it as base64. Uses sampling rate and audio format for encoding.

dspy/adapters/types/tool.py

Lines changed: 2 additions & 2 deletions
@@ -179,7 +179,7 @@ def __call__(self, **kwargs):
        else:
            raise ValueError(
                "You are calling `__call__` on an async tool, please use `acall` instead or set "
-               "`allow_async=True` to run the async tool in sync mode."
+               "`allow_async=True` to run the async tool in sync mode.",
            )
        return result

@@ -306,7 +306,7 @@ def format(self) -> list[dict[str, Any]]:
                    }
                    for tool_call in self.tool_calls
                ],
-           }
+           },
        ]



dspy/clients/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def configure_cache(
    if enable_disk_cache and enable_litellm_cache:
        raise ValueError(
            "Cannot enable both LiteLLM and DSPy on-disk cache, please set at most one of `enable_disk_cache` or "
-           "`enable_litellm_cache` to True."
+           "`enable_litellm_cache` to True.",
        )

    if enable_litellm_cache:
if enable_litellm_cache:

dspy/clients/databricks.py

Lines changed: 23 additions & 23 deletions
@@ -32,7 +32,7 @@ def status(self):
        except ImportError:
            raise ImportError(
                "To use Databricks finetuning, please install the databricks_genai package via "
-               "`pip install databricks_genai`."
+               "`pip install databricks_genai`.",
            )
        run = fm.get(self.finetuning_run)
        return run.status
@@ -84,7 +84,7 @@ def deploy_finetuned_model(
        model_name = model.replace(".", "_")

        get_endpoint_response = requests.get(
-           url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}", json={"name": model_name}, headers=headers
+           url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}", json={"name": model_name}, headers=headers,
        )

        if get_endpoint_response.status_code == 200:
@@ -98,8 +98,8 @@ def deploy_finetuned_model(
                    "entity_version": model_version,
                    "min_provisioned_throughput": min_provisioned_throughput,
                    "max_provisioned_throughput": max_provisioned_throughput,
-               }
-           ]
+               },
+           ],
        }

        response = requests.put(
@@ -120,23 +120,23 @@ def deploy_finetuned_model(
                        "entity_version": model_version,
                        "min_provisioned_throughput": min_provisioned_throughput,
                        "max_provisioned_throughput": max_provisioned_throughput,
-                   }
-               ]
+                   },
+               ],
            },
        }

        response = requests.post(url=f"{databricks_host}/api/2.0/serving-endpoints", json=data, headers=headers)

        if response.status_code == 200:
            logger.info(
-               f"Successfully started creating/updating serving endpoint {model_name} on Databricks model serving!"
+               f"Successfully started creating/updating serving endpoint {model_name} on Databricks model serving!",
            )
        else:
            raise ValueError(f"Failed to create serving endpoint: {response.json()}.")

        logger.info(
            f"Waiting for serving endpoint {model_name} to be ready, this might take a few minutes... You can check "
-           f"the status of the endpoint at {databricks_host}/ml/endpoints/{model_name}"
+           f"the status of the endpoint at {databricks_host}/ml/endpoints/{model_name}",
        )
        from openai import OpenAI

@@ -150,7 +150,7 @@ def deploy_finetuned_model(
        try:
            if data_format == TrainDataFormat.CHAT:
                client.chat.completions.create(
-                   messages=[{"role": "user", "content": "hi"}], model=model_name, max_tokens=1
+                   messages=[{"role": "user", "content": "hi"}], model=model_name, max_tokens=1,
                )
            elif data_format == TrainDataFormat.COMPLETION:
                client.completions.create(prompt="hi", model=model_name, max_tokens=1)
@@ -161,7 +161,7 @@ def deploy_finetuned_model(

        raise ValueError(
            f"Failed to create serving endpoint {model_name} on Databricks model serving platform within "
-           f"{deploy_timeout} seconds."
+           f"{deploy_timeout} seconds.",
        )

    @staticmethod
@@ -179,22 +179,22 @@ def finetune(
            train_data_format = TrainDataFormat.COMPLETION
        else:
            raise ValueError(
-               f"String `train_data_format` must be one of 'chat' or 'completion', but received: {train_data_format}."
+               f"String `train_data_format` must be one of 'chat' or 'completion', but received: {train_data_format}.",
            )

        if "train_data_path" not in train_kwargs:
            raise ValueError("The `train_data_path` must be provided to finetune on Databricks.")
        # Add the file name to the directory path.
        train_kwargs["train_data_path"] = DatabricksProvider.upload_data(
-           train_data, train_kwargs["train_data_path"], train_data_format
+           train_data, train_kwargs["train_data_path"], train_data_format,
        )

        try:
            from databricks.model_training import foundation_model as fm
        except ImportError:
            raise ImportError(
                "To use Databricks finetuning, please install the databricks_genai package via "
-               "`pip install databricks_genai`."
+               "`pip install databricks_genai`.",
            )

        if "register_to" not in train_kwargs:
@@ -224,7 +224,7 @@ def finetune(
            elif job.run.status.display_name == "Failed":
                raise ValueError(
                    f"Finetuning run failed with status: {job.run.status.display_name}. Please check the Databricks "
-                   f"workspace for more details. Finetuning job's metadata: {job.run}."
+                   f"workspace for more details. Finetuning job's metadata: {job.run}.",
                )
            else:
                time.sleep(60)
@@ -236,7 +236,7 @@ def finetune(
        model_to_deploy = train_kwargs.get("register_to")
        job.endpoint_name = model_to_deploy.replace(".", "_")
        DatabricksProvider.deploy_finetuned_model(
-           model_to_deploy, train_data_format, databricks_host, databricks_token, deploy_timeout
+           model_to_deploy, train_data_format, databricks_host, databricks_token, deploy_timeout,
        )
        job.launch_completed = True
        # The finetuned model name should be in the format: "databricks/<endpoint_name>".
@@ -266,7 +266,7 @@ def _get_workspace_client() -> "WorkspaceClient":
        except ImportError:
            raise ImportError(
                "To use Databricks finetuning, please install the databricks-sdk package via "
-               "`pip install databricks-sdk`."
+               "`pip install databricks-sdk`.",
            )
        return WorkspaceClient()

@@ -277,7 +277,7 @@ def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databric
    if not match:
        raise ValueError(
            f"Databricks Unity Catalog path must be in the format '/Volumes/<catalog>/<schema>/<volume>/...', but "
-           f"received: {databricks_unity_catalog_path}."
+           f"received: {databricks_unity_catalog_path}.",
        )

    catalog = match.group("catalog")
@@ -290,7 +290,7 @@ def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databric
    except Exception:
        raise ValueError(
            f"Databricks Unity Catalog volume does not exist: {volume_path}, please create it on the Databricks "
-           "workspace."
+           "workspace.",
        )

    try:
@@ -326,32 +326,32 @@ def _validate_chat_data(data: dict[str, Any]):
    if "messages" not in data:
        raise ValueError(
            "Each finetuning data must be a dict with a 'messages' key when `task=CHAT_COMPLETION`, but "
-           f"received: {data}"
+           f"received: {data}",
        )

    if not isinstance(data["messages"], list):
        raise ValueError(
            "The value of the 'messages' key in each finetuning data must be a list of dicts with keys 'role' and "
-           f"'content' when `task=CHAT_COMPLETION`, but received: {data['messages']}"
+           f"'content' when `task=CHAT_COMPLETION`, but received: {data['messages']}",
        )

    for message in data["messages"]:
        if "role" not in message:
            raise ValueError(f"Each message in the 'messages' list must contain a 'role' key, but received: {message}.")
        if "content" not in message:
            raise ValueError(
-               f"Each message in the 'messages' list must contain a 'content' key, but received: {message}."
+               f"Each message in the 'messages' list must contain a 'content' key, but received: {message}.",
            )


def _validate_completion_data(data: dict[str, Any]):
    if "prompt" not in data:
        raise ValueError(
            "Each finetuning data must be a dict with a 'prompt' key when `task=INSTRUCTION_FINETUNE`, but "
-           f"received: {data}"
+           f"received: {data}",
        )
    if "response" not in data and "completion" not in data:
        raise ValueError(
            "Each finetuning data must be a dict with a 'response' or 'completion' key when "
-           f"`task=INSTRUCTION_FINETUNE`, but received: {data}"
+           f"`task=INSTRUCTION_FINETUNE`, but received: {data}",
        )

dspy/clients/lm.py

Lines changed: 2 additions & 2 deletions
@@ -136,7 +136,7 @@ def forward(self, prompt=None, messages=None, **kwargs):
                "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
                "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
                f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
-               " if the reason for truncation is repetition."
+               " if the reason for truncation is repetition.",
            )

        if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
@@ -166,7 +166,7 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
                "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
                "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
                f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
-               " if the reason for truncation is repetition."
+               " if the reason for truncation is repetition.",
            )

        if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):

dspy/clients/lm_local.py

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@ def launch(lm: "LM", launch_kwargs: dict[str, Any] | None = None):
    except ImportError:
        raise ImportError(
            "For local model launching, please install sglang."
-           "Navigate to https://docs.sglang.ai/start/install.html for the latest installation instructions!"
+           "Navigate to https://docs.sglang.ai/start/install.html for the latest installation instructions!",
        )

    if hasattr(lm, "process"):
@@ -195,7 +195,7 @@ def train_sft_locally(model_name, train_data, train_kwargs):
    except ImportError:
        raise ImportError(
            "For local finetuning, please install torch, transformers, and trl "
-           "by running `pip install -U torch transformers accelerate trl peft`"
+           "by running `pip install -U torch transformers accelerate trl peft`",
        )

    device = train_kwargs.get("device", None)
@@ -220,7 +220,7 @@ def train_sft_locally(model_name, train_data, train_kwargs):
    if "max_seq_length" not in train_kwargs:
        train_kwargs["max_seq_length"] = 4096
        logger.info(
-           f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}"
+           f"The 'train_kwargs' parameter didn't include a 'max_seq_length', defaulting to {train_kwargs['max_seq_length']}",
        )

    from datasets import Dataset
