diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 91083f4018c06..e60c2d4cb7f63 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2899,6 +2899,36 @@ def to_sql( ... conn.execute(text("SELECT * FROM users")).fetchall() [(0, 'User 6'), (1, 'User 7')] + Use ``method`` to define a callable insertion method that does nothing + if there's a primary key conflict on a table in a PostgreSQL database. + + >>> from sqlalchemy.dialects.postgresql import insert + >>> def insert_on_conflict_nothing(table, conn, keys, data_iter): + ... # "a" is the primary key in "conflict_table" + ... data = [dict(zip(keys, row)) for row in data_iter] + ... stmt = insert(table.table).values(data).on_conflict_do_nothing(index_elements=["a"]) + ... result = conn.execute(stmt) + ... return result.rowcount + >>> df_conflict.to_sql("conflict_table", conn, if_exists="append", method=insert_on_conflict_nothing) # doctest: +SKIP + 0 + + For MySQL, define a callable that updates columns ``b`` and ``c`` if there's a + conflict on a primary key. + + >>> from sqlalchemy.dialects.mysql import insert + >>> def insert_on_conflict_update(table, conn, keys, data_iter): + ... # update columns "b" and "c" on primary key conflict + ... data = [dict(zip(keys, row)) for row in data_iter] + ... stmt = ( + ... insert(table.table) + ... .values(data) + ... ) + ... stmt = stmt.on_duplicate_key_update(b=stmt.inserted.b, c=stmt.inserted.c) + ... result = conn.execute(stmt) + ... return result.rowcount + >>> df_conflict.to_sql("conflict_table", conn, if_exists="append", method=insert_on_conflict_update) # doctest: +SKIP + 2 + Specify the dtype (especially useful for integers with missing values). Notice that while pandas is forced to store the data as floating point, the database supports nullable integers. 
When fetching the data with diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 7a3f7521d4a17..4fa9836f7294b 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -817,6 +817,117 @@ def psql_insert_copy(table, conn, keys, data_iter): tm.assert_frame_equal(result, expected) +@pytest.mark.db +@pytest.mark.parametrize("conn", postgresql_connectable) +def test_insertion_method_on_conflict_do_nothing(conn, request): + # GH 15988: Example in to_sql docstring + conn = request.getfixturevalue(conn) + + from sqlalchemy.dialects.postgresql import insert + from sqlalchemy.engine import Engine + from sqlalchemy.sql import text + + def insert_on_conflict(table, conn, keys, data_iter): + data = [dict(zip(keys, row)) for row in data_iter] + stmt = ( + insert(table.table) + .values(data) + .on_conflict_do_nothing(index_elements=["a"]) + ) + result = conn.execute(stmt) + return result.rowcount + + create_sql = text( + """ + CREATE TABLE test_insert_conflict ( + a integer PRIMARY KEY, + b numeric, + c text + ); + """ + ) + if isinstance(conn, Engine): + with conn.connect() as con: + with con.begin(): + con.execute(create_sql) + else: + with conn.begin(): + conn.execute(create_sql) + + expected = DataFrame([[1, 2.1, "a"]], columns=list("abc")) + expected.to_sql("test_insert_conflict", conn, if_exists="append", index=False) + + df_insert = DataFrame([[1, 3.2, "b"]], columns=list("abc")) + inserted = df_insert.to_sql( + "test_insert_conflict", + conn, + index=False, + if_exists="append", + method=insert_on_conflict, + ) + result = sql.read_sql_table("test_insert_conflict", conn) + tm.assert_frame_equal(result, expected) + assert inserted == 0 + + # Cleanup + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_insert_conflict") + + +@pytest.mark.db +@pytest.mark.parametrize("conn", mysql_connectable) +def test_insertion_method_on_conflict_update(conn, request): + # GH 14553: Example in to_sql 
docstring + conn = request.getfixturevalue(conn) + + from sqlalchemy.dialects.mysql import insert + from sqlalchemy.engine import Engine + from sqlalchemy.sql import text + + def insert_on_conflict(table, conn, keys, data_iter): + data = [dict(zip(keys, row)) for row in data_iter] + stmt = insert(table.table).values(data) + stmt = stmt.on_duplicate_key_update(b=stmt.inserted.b, c=stmt.inserted.c) + result = conn.execute(stmt) + return result.rowcount + + create_sql = text( + """ + CREATE TABLE test_insert_conflict ( + a INT PRIMARY KEY, + b FLOAT, + c VARCHAR(10) + ); + """ + ) + if isinstance(conn, Engine): + with conn.connect() as con: + with con.begin(): + con.execute(create_sql) + else: + with conn.begin(): + conn.execute(create_sql) + + df = DataFrame([[1, 2.1, "a"]], columns=list("abc")) + df.to_sql("test_insert_conflict", conn, if_exists="append", index=False) + + expected = DataFrame([[1, 3.2, "b"]], columns=list("abc")) + inserted = expected.to_sql( + "test_insert_conflict", + conn, + index=False, + if_exists="append", + method=insert_on_conflict, + ) + result = sql.read_sql_table("test_insert_conflict", conn) + tm.assert_frame_equal(result, expected) + assert inserted == 2 + + # Cleanup + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_insert_conflict") + + def test_execute_typeerror(sqlite_iris_engine): with pytest.raises(TypeError, match="pandas.io.sql.execute requires a connection"): with tm.assert_produces_warning(