Skip to content

Zachmu/insert into select #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ var queries = []struct {
"SELECT i FROM mytable;",
[]sql.Row{{int64(1)}, {int64(2)}, {int64(3)}},
},
{
"SELECT s,i FROM mytable;",
[]sql.Row{
{"first row", int64(1)},
{"second row", int64(2)},
{"third row", int64(3)}},
},
{
"SELECT s,i FROM (select i,s from mytable) mt;",
[]sql.Row{
{"first row", int64(1)},
{"second row", int64(2)},
{"third row", int64(3)}},
},
{
"SELECT i + 1 FROM mytable;",
[]sql.Row{{int64(2)}, {int64(3)}, {int64(4)}},
Expand Down Expand Up @@ -2113,6 +2127,97 @@ func TestInsertInto(t *testing.T) {
"SELECT * FROM typestable WHERE id = 999;",
[]sql.Row{{int64(999), nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil}},
},
{
"INSERT INTO mytable SELECT * from mytable",
[]sql.Row{{int64(3)}},
"SELECT * FROM mytable order by i",
[]sql.Row{
{int64(1), "first row"},
{int64(1), "first row"},
{int64(2), "second row"},
{int64(2), "second row"},
{int64(3), "third row"},
{int64(3), "third row"},
},
},
{
"INSERT INTO mytable(i,s) SELECT * from mytable",
[]sql.Row{{int64(3)}},
"SELECT * FROM mytable order by i",
[]sql.Row{
{int64(1), "first row"},
{int64(1), "first row"},
{int64(2), "second row"},
{int64(2), "second row"},
{int64(3), "third row"},
{int64(3), "third row"},
},
},
{
"INSERT INTO mytable (i,s) SELECT i+10, 'new' from mytable",
[]sql.Row{{int64(3)}},
"SELECT * FROM mytable order by i",
[]sql.Row{
{int64(1), "first row"},
{int64(2), "second row"},
{int64(3), "third row"},
{int64(11), "new"},
{int64(12), "new"},
{int64(13), "new"},
},
},
{
"INSERT INTO mytable SELECT i2, s2 from othertable",
[]sql.Row{{int64(3)}},
"SELECT * FROM mytable order by i,s",
[]sql.Row{
{int64(1), "first row"},
{int64(1), "third"},
{int64(2), "second"},
{int64(2), "second row"},
{int64(3), "first"},
{int64(3), "third row"},
},
},
{
"INSERT INTO mytable (s,i) SELECT * from othertable",
[]sql.Row{{int64(3)}},
"SELECT * FROM mytable order by i,s",
[]sql.Row{
{int64(1), "first row"},
{int64(1), "third"},
{int64(2), "second"},
{int64(2), "second row"},
{int64(3), "first"},
{int64(3), "third row"},
},
},
{
"INSERT INTO mytable (s,i) SELECT concat(m.s, o.s2), m.i from othertable o join mytable m on m.i=o.i2",
[]sql.Row{{int64(3)}},
"SELECT * FROM mytable order by i,s",
[]sql.Row{
{int64(1), "first row"},
{int64(1), "first rowthird"},
{int64(2), "second row"},
{int64(2), "second rowsecond"},
{int64(3), "third row"},
{int64(3), "third rowfirst"},
},
},
{
"INSERT INTO mytable (i,s) SELECT (i + 10.0) / 10.0 + 10, concat(s, ' new') from mytable",
[]sql.Row{{int64(3)}},
"SELECT * FROM mytable order by i, s",
[]sql.Row{
{int64(1), "first row"},
{int64(2), "second row"},
{int64(3), "third row"},
{int64(11), "first row new"},
{int64(11), "second row new"},
{int64(11), "third row new"},
},
},
}

for _, insertion := range insertions {
Expand Down Expand Up @@ -2168,6 +2273,22 @@ func TestInsertIntoErrors(t *testing.T) {
"null given to non-nullable",
"INSERT INTO mytable (i, s) VALUES (null, 'y');",
},
{
"incompatible types",
"INSERT INTO mytable (i, s) select * from othertable",
},
{
"column count mismatch in select",
"INSERT INTO mytable (i) select * from othertable",
},
{
"column count mismatch in select",
"INSERT INTO mytable select s from othertable",
},
{
"column count mismatch in join select",
"INSERT INTO mytable (s,i) SELECT * from othertable o join mytable m on m.i=o.i2",
},
}

for _, expectedFailure := range expectedFailures {
Expand Down
34 changes: 24 additions & 10 deletions sql/analyzer/prune_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ func pruneColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
return n, nil
}

// Skip pruning columns for insert statements. For inserts involving a select (INSERT INTO table1 SELECT a,b FROM
// table2), all columns from the select are used for the insert, and error checking for schema compatibility
// happens at execution time. Otherwise the logic below will convert a Project to a ResolvedTable for the selected
// table, which can alter the column order of the select.
if _, ok := n.(*plan.InsertInto); ok {
return n, nil
}

if describe, ok := n.(*plan.DescribeQuery); ok {
pruned, err := pruneColumns(ctx, a, describe.Child)
if err != nil {
Expand All @@ -25,16 +33,7 @@ func pruneColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
return plan.NewDescribeQuery(describe.Format, pruned), nil
}

columns := make(usedColumns)

// All the columns required for the output of the query must be mark as
// used, otherwise the schema would change.
for _, col := range n.Schema() {
if _, ok := columns[col.Source]; !ok {
columns[col.Source] = make(map[string]struct{})
}
columns[col.Source][col.Name] = struct{}{}
}
columns := findRequiredColumns(n)

findUsedColumns(columns, n)

Expand All @@ -51,6 +50,21 @@ func pruneColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
return fixRemainingFieldsIndexes(n)
}

func findRequiredColumns(n sql.Node) usedColumns {
columns := make(usedColumns)

// All the columns required for the output of the query must be mark as
// used, otherwise the schema would change.
for _, col := range n.Schema() {
if _, ok := columns[col.Source]; !ok {
columns[col.Source] = make(map[string]struct{})
}
columns[col.Source][col.Name] = struct{}{}
}

return columns
}

func pruneSubqueryColumns(
ctx *sql.Context,
a *Analyzer,
Expand Down
21 changes: 17 additions & 4 deletions sql/plan/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
p.Columns[i] = f.Name
}
} else {
err = p.validateColumns(ctx, dstSchema)
err = p.validateColumns(dstSchema)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -139,7 +139,7 @@ func (p *InsertInto) Execute(ctx *sql.Context) (int, error) {
return i, err
}

err = p.validateNullability(ctx, dstSchema, row)
err = p.validateNullability(dstSchema, row)
if err != nil {
_ = iter.Close()
return i, err
Expand Down Expand Up @@ -223,13 +223,19 @@ func (p *InsertInto) validateValueCount(ctx *sql.Context) error {
return ErrInsertIntoMismatchValueCount.New()
}
}
case *ResolvedTable:
return p.assertSchemasMatch(node.Schema())
case *Project:
return p.assertSchemasMatch(node.Schema())
case *InnerJoin:
return p.assertSchemasMatch(node.Schema())
default:
return ErrInsertIntoUnsupportedValues.New(node)
}
return nil
}

func (p *InsertInto) validateColumns(ctx *sql.Context, dstSchema sql.Schema) error {
func (p *InsertInto) validateColumns(dstSchema sql.Schema) error {
dstColNames := make(map[string]struct{})
for _, dstCol := range dstSchema {
dstColNames[dstCol.Name] = struct{}{}
Expand All @@ -248,11 +254,18 @@ func (p *InsertInto) validateColumns(ctx *sql.Context, dstSchema sql.Schema) err
return nil
}

func (p *InsertInto) validateNullability(ctx *sql.Context, dstSchema sql.Schema, row sql.Row) error {
func (p *InsertInto) validateNullability(dstSchema sql.Schema, row sql.Row) error {
for i, col := range dstSchema {
if !col.Nullable && row[i] == nil {
return ErrInsertIntoNonNullableProvidedNull.New(col.Name)
}
}
return nil
}

func (p *InsertInto) assertSchemasMatch(schema sql.Schema) error {
if len(p.Columns) != len(schema) {
return ErrInsertIntoMismatchValueCount.New()
}
return nil
}