@@ -826,7 +826,7 @@ async def copy_to_table(self, table_name, *, source,
826
826
delimiter = None , null = None , header = None ,
827
827
quote = None , escape = None , force_quote = None ,
828
828
force_not_null = None , force_null = None ,
829
- encoding = None ):
829
+ encoding = None , where = None ):
830
830
"""Copy data to the specified table.
831
831
832
832
:param str table_name:
@@ -845,6 +845,15 @@ async def copy_to_table(self, table_name, *, source,
845
845
:param str schema_name:
846
846
An optional schema name to qualify the table.
847
847
848
+ :param str where:
849
+ An optional condition used to filter rows when copying.
850
+
851
+ .. note::
852
+
853
+ Usage of this parameter requires support for the
854
+ ``COPY FROM ... WHERE`` syntax, introduced in
855
+ PostgreSQL version 12.
856
+
848
857
:param float timeout:
849
858
Optional timeout value in seconds.
850
859
@@ -872,6 +881,9 @@ async def copy_to_table(self, table_name, *, source,
872
881
https://www.postgresql.org/docs/current/static/sql-copy.html
873
882
874
883
.. versionadded:: 0.11.0
884
+
885
+ .. versionadded:: 0.27.0
886
+ Added ``where`` parameter.
875
887
"""
876
888
tabname = utils ._quote_ident (table_name )
877
889
if schema_name :
@@ -883,21 +895,22 @@ async def copy_to_table(self, table_name, *, source,
883
895
else :
884
896
cols = ''
885
897
898
+ cond = self ._format_copy_where (where )
886
899
opts = self ._format_copy_opts (
887
900
format = format , oids = oids , freeze = freeze , delimiter = delimiter ,
888
901
null = null , header = header , quote = quote , escape = escape ,
889
902
force_not_null = force_not_null , force_null = force_null ,
890
903
encoding = encoding
891
904
)
892
905
893
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
894
- tab = tabname , cols = cols , opts = opts )
906
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
907
+ tab = tabname , cols = cols , opts = opts , cond = cond )
895
908
896
909
return await self ._copy_in (copy_stmt , source , timeout )
897
910
898
911
async def copy_records_to_table (self , table_name , * , records ,
899
912
columns = None , schema_name = None ,
900
- timeout = None ):
913
+ timeout = None , where = None ):
901
914
"""Copy a list of records to the specified table using binary COPY.
902
915
903
916
:param str table_name:
@@ -914,6 +927,16 @@ async def copy_records_to_table(self, table_name, *, records,
914
927
:param str schema_name:
915
928
An optional schema name to qualify the table.
916
929
930
+ :param str where:
931
+ An optional condition used to filter rows when copying.
932
+
933
+ .. note::
934
+
935
+ Usage of this parameter requires support for the
936
+ ``COPY FROM ... WHERE`` syntax, introduced in
937
+ PostgreSQL version 12.
938
+
939
+
917
940
:param float timeout:
918
941
Optional timeout value in seconds.
919
942
@@ -958,6 +981,9 @@ async def copy_records_to_table(self, table_name, *, records,
958
981
959
982
.. versionchanged:: 0.24.0
960
983
The ``records`` argument may be an asynchronous iterable.
984
+
985
+ .. versionadded:: 0.27.0
986
+ Added ``where`` parameter.
961
987
"""
962
988
tabname = utils ._quote_ident (table_name )
963
989
if schema_name :
@@ -975,14 +1001,27 @@ async def copy_records_to_table(self, table_name, *, records,
975
1001
976
1002
intro_ps = await self ._prepare (intro_query , use_cache = True )
977
1003
1004
+ cond = self ._format_copy_where (where )
978
1005
opts = '(FORMAT binary)'
979
1006
980
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
981
- tab = tabname , cols = cols , opts = opts )
1007
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
1008
+ tab = tabname , cols = cols , opts = opts , cond = cond )
982
1009
983
1010
return await self ._protocol .copy_in (
984
1011
copy_stmt , None , None , records , intro_ps ._state , timeout )
985
1012
1013
+ def _format_copy_where (self , where ):
1014
+ if where and not self ._server_caps .sql_copy_from_where :
1015
+ raise exceptions .UnsupportedServerFeatureError (
1016
+ 'the `where` parameter requires PostgreSQL 12 or later' )
1017
+
1018
+ if where :
1019
+ where_clause = 'WHERE ' + where
1020
+ else :
1021
+ where_clause = ''
1022
+
1023
+ return where_clause
1024
+
986
1025
def _format_copy_opts (self , * , format = None , oids = None , freeze = None ,
987
1026
delimiter = None , null = None , header = None , quote = None ,
988
1027
escape = None , force_quote = None , force_not_null = None ,
@@ -2326,7 +2365,7 @@ class _ConnectionProxy:
2326
2365
ServerCapabilities = collections .namedtuple (
2327
2366
'ServerCapabilities' ,
2328
2367
['advisory_locks' , 'notifications' , 'plpgsql' , 'sql_reset' ,
2329
- 'sql_close_all' ])
2368
+ 'sql_close_all' , 'sql_copy_from_where' ])
2330
2369
ServerCapabilities .__doc__ = 'PostgreSQL server capabilities.'
2331
2370
2332
2371
@@ -2338,27 +2377,31 @@ def _detect_server_capabilities(server_version, connection_settings):
2338
2377
plpgsql = False
2339
2378
sql_reset = True
2340
2379
sql_close_all = False
2380
+ sql_copy_from_where = False
2341
2381
elif hasattr (connection_settings , 'crdb_version' ):
2342
2382
# CockroachDB detected.
2343
2383
advisory_locks = False
2344
2384
notifications = False
2345
2385
plpgsql = False
2346
2386
sql_reset = False
2347
2387
sql_close_all = False
2388
+ sql_copy_from_where = False
2348
2389
elif hasattr (connection_settings , 'crate_version' ):
2349
2390
# CrateDB detected.
2350
2391
advisory_locks = False
2351
2392
notifications = False
2352
2393
plpgsql = False
2353
2394
sql_reset = False
2354
2395
sql_close_all = False
2396
+ sql_copy_from_where = False
2355
2397
else :
2356
2398
# Standard PostgreSQL server assumed.
2357
2399
advisory_locks = True
2358
2400
notifications = True
2359
2401
plpgsql = True
2360
2402
sql_reset = True
2361
2403
sql_close_all = True
2404
+ sql_copy_from_where = server_version .major >= 12
2362
2405
2363
2406
return ServerCapabilities (
2364
2407
advisory_locks = advisory_locks ,
0 commit comments