Skip to content

Type-checking cleanups #962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ dev = [
"psutil>=5.9.3,<6",
"pydocstyle>=6.3.0,<7",
"pydoctor>=24.11.1,<25",
"pyright==1.1.402",
"pyright==1.1.403",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Asked similarly at nexus-rpc/sdk-python#13 (comment), should we fixate to a specific Pyright version? Will this ensure we miss new things that may effect our users in later versions? I see we use a fixed mypy version too. I don't necessarily mind it so long as we keep on top of updating regularly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, let's consider changing to something like >= 1.1 in a future PR.

"pytest~=7.4",
"pytest-asyncio>=0.21,<0.22",
"pytest-timeout~=2.2",
Expand All @@ -69,14 +69,16 @@ lint = [
{cmd = "uv run ruff check --select I"},
{cmd = "uv run ruff format --check"},
{ref = "lint-types"},
{cmd = "uv run pyright"},
{ref = "lint-docs"},
]
bridge-lint = { cmd = "cargo clippy -- -D warnings", cwd = "temporalio/bridge" }
# TODO(cretz): Why does pydocstyle complain about @overload missing docs after
# https://github.com/PyCQA/pydocstyle/pull/511?
lint-docs = "uv run pydocstyle --ignore-decorators=overload"
lint-types = "uv run mypy --namespace-packages --check-untyped-defs ."
lint-types = [
{ cmd = "uv run pyright"},
{ cmd = "uv run mypy --namespace-packages --check-untyped-defs ."},
]
run-bench = "uv run python scripts/run_bench.py"
test = "uv run pytest"

Expand All @@ -100,7 +102,7 @@ filterwarnings = [
[tool.cibuildwheel]
before-all = "pip install protoc-wheel-0"
build = "cp39-win_amd64 cp39-manylinux_x86_64 cp39-manylinux_aarch64 cp39-macosx_x86_64 cp39-macosx_arm64"
build-verbosity = "1"
build-verbosity = 1

[tool.cibuildwheel.macos]
environment = { MACOSX_DEPLOYMENT_TARGET = "10.12" }
Expand Down Expand Up @@ -158,16 +160,24 @@ project-name = "Temporal Python"
sidebar-expand-depth = 2

[tool.pyright]
reportUnknownVariableType = "none"
reportUnknownParameterType = "none"
reportUnusedCallResult = "none"
reportImplicitStringConcatenation = "none"
reportPrivateUsage = "none"
enableTypeIgnoreComments = true
reportAny = "none"
reportCallInDefaultInitializer = "none"
reportExplicitAny = "none"
reportIgnoreCommentWithoutRule = "none"
reportImplicitOverride = "none"
reportImplicitStringConcatenation = "none"
reportImportCycles = "none"
reportMissingTypeArgument = "none"
reportAny = "none"
enableTypeIgnoreComments = true

reportPrivateUsage = "none"
reportUnannotatedClassAttribute = "none"
reportUnknownArgumentType = "none"
reportUnknownMemberType = "none"
reportUnknownParameterType = "none"
reportUnknownVariableType = "none"
reportUnnecessaryIsInstance = "none"
reportUnnecessaryTypeIgnoreComment = "none"
reportUnusedCallResult = "none"
include = ["temporalio", "tests"]
exclude = [
"temporalio/api",
Expand Down
2 changes: 1 addition & 1 deletion temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5772,7 +5772,7 @@ async def get_worker_task_reachability(


class _ClientImpl(OutboundInterceptor):
def __init__(self, client: Client) -> None:
def __init__(self, client: Client) -> None: # type: ignore
# We are intentionally not calling the base class's __init__ here
self._client = client

Expand Down
4 changes: 2 additions & 2 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from agents.models.multi_provider import MultiProvider
from typing_extensions import Required, TypedDict

from temporalio import activity, workflow
from temporalio import activity
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater


Expand Down Expand Up @@ -106,7 +106,7 @@ class ActivityModelInput(TypedDict, total=False):

model_name: Optional[str]
system_instructions: Optional[str]
input: Required[Union[str, list[TResponseInputItem]]] # type: ignore
input: Required[Union[str, list[TResponseInputItem]]]
model_settings: Required[ModelSettings]
tools: list[ToolInput]
output_schema: Optional[AgentOutputSchemaInput]
Expand Down
30 changes: 13 additions & 17 deletions temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from __future__ import annotations

import logging
from datetime import timedelta
from typing import Optional

from temporalio import workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.workflow import ActivityCancellationType, VersioningIntent

logger = logging.getLogger(__name__)

from typing import Any, AsyncIterator, Optional, Sequence, Union, cast
from typing import Any, AsyncIterator, Sequence, Union, cast

from agents import (
AgentOutputSchema,
Expand Down Expand Up @@ -57,7 +54,7 @@ def __init__(
async def get_response(
self,
system_instructions: Optional[str],
input: Union[str, list[TResponseInputItem]],
input: Union[str, list[TResponseInputItem], dict[str, str]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Optional[AgentOutputSchemaBase],
Expand All @@ -67,7 +64,9 @@ async def get_response(
previous_response_id: Optional[str],
prompt: Optional[ResponsePromptParam],
) -> ModelResponse:
def get_summary(input: Union[str, list[TResponseInputItem]]) -> str:
def get_summary(
input: Union[str, list[TResponseInputItem], dict[str, str]],
) -> str:
### Activity summary shown in the UI
try:
max_size = 100
Expand All @@ -88,21 +87,18 @@ def get_summary(input: Union[str, list[TResponseInputItem]]) -> str:
return ""

def make_tool_info(tool: Tool) -> ToolInput:
if isinstance(tool, FileSearchTool):
return cast(FileSearchTool, tool)
elif isinstance(tool, WebSearchTool):
return cast(WebSearchTool, tool)
if isinstance(tool, (FileSearchTool, WebSearchTool)):
return tool
elif isinstance(tool, ComputerTool):
raise NotImplementedError(
"Computer search preview is not supported in Temporal model"
)
elif isinstance(tool, FunctionTool):
t = cast(FunctionToolInput, tool)
return FunctionToolInput(
name=t.name,
description=t.description,
params_json_schema=t.params_json_schema,
strict_json_schema=t.strict_json_schema,
name=tool.name,
description=tool.description,
params_json_schema=tool.params_json_schema,
strict_json_schema=tool.strict_json_schema,
)
else:
raise ValueError(f"Unknown tool type: {tool.name}")
Expand Down Expand Up @@ -141,7 +137,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
activity_input = ActivityModelInput(
model_name=self.model_name,
system_instructions=system_instructions,
input=input,
input=cast(Union[str, list[TResponseInputItem]], input),
model_settings=model_settings,
tools=tool_infos,
output_schema=output_schema_input,
Expand Down Expand Up @@ -169,7 +165,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
def stream_response(
self,
system_instructions: Optional[str],
input: Union[str, list][TResponseInputItem], # type: ignore
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Optional[AgentOutputSchemaBase],
Expand Down
15 changes: 7 additions & 8 deletions temporalio/contrib/openai_agents/_trace_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Mapping, Protocol, Type, cast
from typing import Any, Mapping, Protocol, Type

from agents import CustomSpanData, custom_span, get_current_span, trace
from agents import custom_span, get_current_span, trace
from agents.tracing import (
get_trace_provider,
)
from agents.tracing.provider import DefaultTraceProvider
from agents.tracing.spans import NoOpSpan, SpanImpl
from agents.tracing.spans import NoOpSpan

import temporalio.activity
import temporalio.api.common.v1
Expand Down Expand Up @@ -116,7 +115,7 @@ class OpenAIAgentsTracingInterceptor(
worker = Worker(client, task_queue="my-task-queue", interceptors=[interceptor])
"""

def __init__(
def __init__( # type: ignore[reportMissingSuperCall]
self,
payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter,
) -> None:
Expand Down Expand Up @@ -189,7 +188,7 @@ async def start_workflow(
**({"temporal:workflowId": input.id} if input.id else {}),
}
data = {"workflowId": input.id} if input.id else None
span_name = f"temporal:startWorkflow"
span_name = "temporal:startWorkflow"
if get_trace_provider().get_current_trace() is None:
with trace(
span_name + ":" + input.workflow, metadata=metadata, group_id=input.id
Expand All @@ -208,7 +207,7 @@ async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> A
**({"temporal:workflowId": input.id} if input.id else {}),
}
data = {"workflowId": input.id, "query": input.query}
span_name = f"temporal:queryWorkflow"
span_name = "temporal:queryWorkflow"
if get_trace_provider().get_current_trace() is None:
with trace(span_name, metadata=metadata, group_id=input.id):
with custom_span(name=span_name, data=data):
Expand All @@ -227,7 +226,7 @@ async def signal_workflow(
**({"temporal:workflowId": input.id} if input.id else {}),
}
data = {"workflowId": input.id, "signal": input.signal}
span_name = f"temporal:signalWorkflow"
span_name = "temporal:signalWorkflow"
if get_trace_provider().get_current_trace() is None:
with trace(span_name, metadata=metadata, group_id=input.id):
with custom_span(name=span_name, data=data):
Expand Down
10 changes: 3 additions & 7 deletions temporalio/nexus/_operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@

import dataclasses
import logging
from collections.abc import Awaitable, Mapping, MutableMapping, Sequence
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import timedelta
from typing import (
Any,
Awaitable,
Callable,
Mapping,
MutableMapping,
Optional,
Sequence,
Type,
Union,
overload,
)
Expand Down Expand Up @@ -305,7 +301,7 @@ async def start_workflow(
args: Sequence[Any] = [],
id: str,
task_queue: Optional[str] = None,
result_type: Optional[Type[ReturnType]] = None,
result_type: Optional[type[ReturnType]] = None,
execution_timeout: Optional[timedelta] = None,
run_timeout: Optional[timedelta] = None,
task_timeout: Optional[timedelta] = None,
Expand Down Expand Up @@ -340,7 +336,7 @@ async def start_workflow(
args: Sequence[Any] = [],
id: str,
task_queue: Optional[str] = None,
result_type: Optional[Type] = None,
result_type: Optional[type] = None,
execution_timeout: Optional[timedelta] = None,
run_timeout: Optional[timedelta] = None,
task_timeout: Optional[timedelta] = None,
Expand Down
2 changes: 1 addition & 1 deletion temporalio/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _on_logs(
# We can't access logging module's start time and it's not worth
# doing difference math to get relative time right here, so
# we'll make time relative to _our_ module's start time
self.relativeCreated = (record.created - _module_start_time) * 1000
self.relativeCreated = (record.created - _module_start_time) * 1000 # type: ignore[reportUninitializedInstanceVariable]
# Log the record
self.logger.handle(record)

Expand Down
14 changes: 3 additions & 11 deletions temporalio/worker/_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,9 @@

from typing_extensions import TypeAlias, TypedDict

import temporalio.activity
import temporalio.api.common.v1
import temporalio.bridge.client
import temporalio.bridge.proto
import temporalio.bridge.proto.activity_result
import temporalio.bridge.proto.activity_task
import temporalio.bridge.proto.common
import temporalio.bridge.worker
import temporalio.client
import temporalio.converter
import temporalio.exceptions
import temporalio.common
import temporalio.runtime
import temporalio.service
from temporalio.common import (
Expand Down Expand Up @@ -578,8 +570,8 @@ def config(self) -> WorkerConfig:
Configuration, shallow-copied.
"""
config = self._config.copy()
config["activities"] = list(config["activities"])
config["workflows"] = list(config["workflows"])
config["activities"] = list(config.get("activities", []))
config["workflows"] = list(config.get("workflows", []))
return config

@property
Expand Down
6 changes: 3 additions & 3 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2514,15 +2514,15 @@ def get_debug(self) -> bool:


class _WorkflowInboundImpl(WorkflowInboundInterceptor):
def __init__(
def __init__( # type: ignore
self,
instance: _WorkflowInstanceImpl,
) -> None:
# We are intentionally not calling the base class's __init__ here
self._instance = instance

def init(self, outbound: WorkflowOutboundInterceptor) -> None:
self._outbound = outbound
self._outbound = outbound # type: ignore

async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
args = [self._instance._object] + list(input.args)
Expand Down Expand Up @@ -2572,7 +2572,7 @@ async def handle_update_handler(self, input: HandleUpdateInput) -> Any:


class _WorkflowOutboundImpl(WorkflowOutboundInterceptor):
def __init__(self, instance: _WorkflowInstanceImpl) -> None:
def __init__(self, instance: _WorkflowInstanceImpl) -> None: # type: ignore
# We are intentionally not calling the base class's __init__ here
self._instance = instance

Expand Down
Loading