Skip to content

Commit ba77958

Browse files
redgoldlaceelprans
authored andcommitted
Add support for WHERE clause in copy_to methods
1 parent bf74e88 commit ba77958

File tree

3 files changed

+64
-12
lines changed

3 files changed

+64
-12
lines changed

asyncpg/connection.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ async def copy_to_table(self, table_name, *, source,
826826
delimiter=None, null=None, header=None,
827827
quote=None, escape=None, force_quote=None,
828828
force_not_null=None, force_null=None,
829-
encoding=None):
829+
encoding=None, where=None):
830830
"""Copy data to the specified table.
831831
832832
:param str table_name:
@@ -845,6 +845,15 @@ async def copy_to_table(self, table_name, *, source,
845845
:param str schema_name:
846846
An optional schema name to qualify the table.
847847
848+
:param str where:
849+
An optional condition used to filter rows when copying.
850+
851+
.. note::
852+
853+
Usage of this parameter requires support for the
854+
``COPY FROM ... WHERE`` syntax, introduced in
855+
PostgreSQL version 12.
856+
848857
:param float timeout:
849858
Optional timeout value in seconds.
850859
@@ -872,6 +881,9 @@ async def copy_to_table(self, table_name, *, source,
872881
https://www.postgresql.org/docs/current/static/sql-copy.html
873882
874883
.. versionadded:: 0.11.0
884+
885+
.. versionadded:: 0.27.0
886+
Added ``where`` parameter.
875887
"""
876888
tabname = utils._quote_ident(table_name)
877889
if schema_name:
@@ -883,21 +895,22 @@ async def copy_to_table(self, table_name, *, source,
883895
else:
884896
cols = ''
885897

898+
cond = self._format_copy_where(where)
886899
opts = self._format_copy_opts(
887900
format=format, oids=oids, freeze=freeze, delimiter=delimiter,
888901
null=null, header=header, quote=quote, escape=escape,
889902
force_not_null=force_not_null, force_null=force_null,
890903
encoding=encoding
891904
)
892905

893-
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
894-
tab=tabname, cols=cols, opts=opts)
906+
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
907+
tab=tabname, cols=cols, opts=opts, cond=cond)
895908

896909
return await self._copy_in(copy_stmt, source, timeout)
897910

898911
async def copy_records_to_table(self, table_name, *, records,
899912
columns=None, schema_name=None,
900-
timeout=None):
913+
timeout=None, where=None):
901914
"""Copy a list of records to the specified table using binary COPY.
902915
903916
:param str table_name:
@@ -914,6 +927,16 @@ async def copy_records_to_table(self, table_name, *, records,
914927
:param str schema_name:
915928
An optional schema name to qualify the table.
916929
930+
:param str where:
931+
An optional condition used to filter rows when copying.
932+
933+
.. note::
934+
935+
Usage of this parameter requires support for the
936+
``COPY FROM ... WHERE`` syntax, introduced in
937+
PostgreSQL version 12.
938+
939+
917940
:param float timeout:
918941
Optional timeout value in seconds.
919942
@@ -958,6 +981,9 @@ async def copy_records_to_table(self, table_name, *, records,
958981
959982
.. versionchanged:: 0.24.0
960983
The ``records`` argument may be an asynchronous iterable.
984+
985+
.. versionadded:: 0.27.0
986+
Added ``where`` parameter.
961987
"""
962988
tabname = utils._quote_ident(table_name)
963989
if schema_name:
@@ -975,14 +1001,27 @@ async def copy_records_to_table(self, table_name, *, records,
9751001

9761002
intro_ps = await self._prepare(intro_query, use_cache=True)
9771003

1004+
cond = self._format_copy_where(where)
9781005
opts = '(FORMAT binary)'
9791006

980-
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
981-
tab=tabname, cols=cols, opts=opts)
1007+
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
1008+
tab=tabname, cols=cols, opts=opts, cond=cond)
9821009

9831010
return await self._protocol.copy_in(
9841011
copy_stmt, None, None, records, intro_ps._state, timeout)
9851012

1013+
def _format_copy_where(self, where):
1014+
if where and not self._server_caps.sql_copy_from_where:
1015+
raise exceptions.UnsupportedServerFeatureError(
1016+
'the `where` parameter requires PostgreSQL 12 or later')
1017+
1018+
if where:
1019+
where_clause = 'WHERE ' + where
1020+
else:
1021+
where_clause = ''
1022+
1023+
return where_clause
1024+
9861025
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
9871026
delimiter=None, null=None, header=None, quote=None,
9881027
escape=None, force_quote=None, force_not_null=None,
@@ -2326,7 +2365,7 @@ class _ConnectionProxy:
23262365
ServerCapabilities = collections.namedtuple(
23272366
'ServerCapabilities',
23282367
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
2329-
'sql_close_all'])
2368+
'sql_close_all', 'sql_copy_from_where'])
23302369
ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.'
23312370

23322371

@@ -2338,27 +2377,31 @@ def _detect_server_capabilities(server_version, connection_settings):
23382377
plpgsql = False
23392378
sql_reset = True
23402379
sql_close_all = False
2380+
sql_copy_from_where = False
23412381
elif hasattr(connection_settings, 'crdb_version'):
23422382
# CockroachDB detected.
23432383
advisory_locks = False
23442384
notifications = False
23452385
plpgsql = False
23462386
sql_reset = False
23472387
sql_close_all = False
2388+
sql_copy_from_where = False
23482389
elif hasattr(connection_settings, 'crate_version'):
23492390
# CrateDB detected.
23502391
advisory_locks = False
23512392
notifications = False
23522393
plpgsql = False
23532394
sql_reset = False
23542395
sql_close_all = False
2396+
sql_copy_from_where = False
23552397
else:
23562398
# Standard PostgreSQL server assumed.
23572399
advisory_locks = True
23582400
notifications = True
23592401
plpgsql = True
23602402
sql_reset = True
23612403
sql_close_all = True
2404+
sql_copy_from_where = server_version.major >= 12
23622405

23632406
return ServerCapabilities(
23642407
advisory_locks=advisory_locks,

asyncpg/exceptions/_base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
1414
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
1515
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
16-
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched')
16+
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
17+
'UnsupportedServerFeatureError')
1718

1819

1920
def _is_asyncpg_class(cls):
@@ -228,6 +229,10 @@ class UnsupportedClientFeatureError(InterfaceError):
228229
"""Requested feature is unsupported by asyncpg."""
229230

230231

232+
class UnsupportedServerFeatureError(InterfaceError):
233+
"""Requested feature is unsupported by PostgreSQL server."""
234+
235+
231236
class InterfaceWarning(InterfaceMessage, UserWarning):
232237
"""A warning caused by an improper use of asyncpg API."""
233238

asyncpg/pool.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,8 @@ async def copy_to_table(
739739
force_quote=None,
740740
force_not_null=None,
741741
force_null=None,
742-
encoding=None
742+
encoding=None,
743+
where=None
743744
):
744745
"""Copy data to the specified table.
745746
@@ -768,7 +769,8 @@ async def copy_to_table(
768769
force_quote=force_quote,
769770
force_not_null=force_not_null,
770771
force_null=force_null,
771-
encoding=encoding
772+
encoding=encoding,
773+
where=where
772774
)
773775

774776
async def copy_records_to_table(
@@ -778,7 +780,8 @@ async def copy_records_to_table(
778780
records,
779781
columns=None,
780782
schema_name=None,
781-
timeout=None
783+
timeout=None,
784+
where=None
782785
):
783786
"""Copy a list of records to the specified table using binary COPY.
784787
@@ -795,7 +798,8 @@ async def copy_records_to_table(
795798
records=records,
796799
columns=columns,
797800
schema_name=schema_name,
798-
timeout=timeout
801+
timeout=timeout,
802+
where=where
799803
)
800804

801805
def acquire(self, *, timeout=None):

0 commit comments

Comments
 (0)