diff --git a/CHANGELOG.md b/CHANGELOG.md
index a71620c1..cfa1a2e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
 
 ## [Unreleased]
 
+### Added
+
+- Add `--create-missing-collections` option to load_queryables to automatically create collections that don't exist
+
 ## [v0.9.6]
 
 ### Added
diff --git a/src/pypgstac/src/pypgstac/pypgstac.py b/src/pypgstac/src/pypgstac/pypgstac.py
index e0720850..7d073843 100644
--- a/src/pypgstac/src/pypgstac/pypgstac.py
+++ b/src/pypgstac/src/pypgstac/pypgstac.py
@@ -126,6 +126,7 @@ def load_queryables(
         collection_ids: Optional[list[str]] = None,
         delete_missing: Optional[bool] = False,
         index_fields: Optional[list[str]] = None,
+        create_missing_collections: Optional[bool] = False,
     ) -> None:
         """Load queryables from a JSON file.
 
@@ -139,6 +140,9 @@ def load_queryables(
             index_fields: List of field names to create indexes for. If not provided,
                          no indexes will be created. Creating too many indexes can
                          negatively impact performance.
+            create_missing_collections: If True and collection_ids is specified,
+                                automatically create empty collections for any
+                                collection IDs that don't exist.
         """
 
         # Read the queryables JSON file
@@ -147,6 +151,69 @@ def load_queryables(
             queryables_data = item
             break  # We only need the first item
 
+        # Create missing collections if requested. Skip this when the file
+        # yielded no data: the check below raises, and a failed load should
+        # not leave newly created collections behind.
+        if create_missing_collections and collection_ids and queryables_data:
+            conn = self._db.connect()
+            with conn.cursor() as cur:
+                # Get list of existing collections
+                cur.execute(
+                    "SELECT id FROM collections WHERE id = ANY(%s);",
+                    [collection_ids],
+                )
+                existing_collections = {r[0] for r in cur.fetchall()}
+
+                # Create empty collections for any that don't exist;
+                # dict.fromkeys drops repeated IDs (keeping order) so the
+                # same collection is never inserted twice.
+                missing_collections = [
+                    cid
+                    for cid in dict.fromkeys(collection_ids)
+                    if cid not in existing_collections
+                ]
+                if missing_collections:
+                    with conn.transaction():
+                        # Create a temporary table for bulk insert
+                        cur.execute(
+                            """
+                            DROP TABLE IF EXISTS tmp_collections;
+                            CREATE TEMP TABLE tmp_collections
+                            (content jsonb) ON COMMIT DROP;
+                            """,
+                        )
+                        # Insert collection records into temp table
+                        with cur.copy(
+                            "COPY tmp_collections (content) FROM stdin;",
+                        ) as copy:
+                            for cid in missing_collections:
+                                # Minimal STAC Collection; "type" and
+                                # "links" are required by the spec.
+                                empty_collection = {
+                                    "type": "Collection",
+                                    "id": cid,
+                                    "stac_version": "1.0.0",
+                                    "description": "Automatically created collection"
+                                    + f" for {cid}",
+                                    "license": "proprietary",
+                                    "extent": {
+                                        "spatial": {"bbox": [[-180, -90, 180, 90]]},
+                                        "temporal": {"interval": [[None, None]]},
+                                    },
+                                    "links": [],
+                                }
+                                copy.write_row(
+                                    (orjson.dumps(empty_collection).decode(),),
+                                )
+
+                        # Insert from temp table to collections
+                        cur.execute(
+                            """
+                            INSERT INTO collections (content)
+                            SELECT content FROM tmp_collections;
+                            """,
+                        )
+
         if not queryables_data:
             raise ValueError(f"No valid JSON data found in {file}")
 
diff --git a/src/pypgstac/tests/test_queryables.py b/src/pypgstac/tests/test_queryables.py
index 31241818..48a12e73 100644
--- a/src/pypgstac/tests/test_queryables.py
+++ b/src/pypgstac/tests/test_queryables.py
@@ -514,6 +514,192 @@ def test_load_queryables_delete_missing_with_collections(
     partial_props_file.unlink()
 
 
+def test_load_queryables_create_missing_collections(db: PgstacDB) -> None:
+    """Test loading queryables with create_missing_collections flag."""
+    # Create a CLI instance
+    cli = PgstacCLI(dsn=db.dsn)
+
+    # Try to load queryables for non-existent collections without the flag
+    non_existent_collections = ["test_collection_1", "test_collection_2"]
+    with pytest.raises(Exception) as exc_info:
+        cli.load_queryables(
+            str(TEST_QUERYABLES_JSON),
+            collection_ids=non_existent_collections,
+        )
+    assert "do not exist" in str(exc_info.value)
+
+    # Load queryables with create_missing_collections flag
+    cli.load_queryables(
+        str(TEST_QUERYABLES_JSON),
+        collection_ids=non_existent_collections,
+        create_missing_collections=True,
+    )
+
+    # Verify that the collections were created
+    result = db.query(
+        """
+        SELECT id, content
+        FROM collections
+        WHERE id = ANY(%s)
+        ORDER BY id;
+        """,
+        [non_existent_collections],
+    )
+
+    # Convert result to a list of dictionaries
+    collections = [{"id": row[0], "content": row[1]} for row in result]
+
+    # Check that both collections were created
+    assert len(collections) == 2
+    for collection in collections:
+        assert collection["id"] in non_existent_collections
+        content = collection["content"]
+        # Verify required STAC fields
+        assert content["type"] == "Collection"
+        assert content["stac_version"] == "1.0.0"
+        assert "description" in content
+        assert content["license"] == "proprietary"
+        assert "extent" in content
+        assert "spatial" in content["extent"]
+        assert "temporal" in content["extent"]
+        assert content["extent"]["spatial"]["bbox"] == [[-180, -90, 180, 90]]
+        assert content["extent"]["temporal"]["interval"] == [[None, None]]
+        assert content["links"] == []
+
+    # Verify that queryables were loaded for these collections
+    result = db.query(
+        """
+        SELECT name, collection_ids
+        FROM queryables
+        WHERE name LIKE 'test:%%'
+        AND collection_ids = %s::text[]
+        ORDER BY name;
+        """,
+        [non_existent_collections],
+    )
+
+    # Convert result to a list of dictionaries
+    queryables = [{"name": row[0], "collection_ids": row[1]} for row in result]
+
+    # Check that queryables were created and associated with the collections
+    assert len(queryables) == 5  # All test properties
+    for queryable in queryables:
+        assert set(queryable["collection_ids"]) == set(non_existent_collections)
+
+
+def test_load_queryables_with_multiple_hyphenated_collections(db: PgstacDB) -> None:
+    """Test loading queryables for multiple collections with hyphenated names."""
+    # Create a CLI instance
+    cli = PgstacCLI(dsn=db.dsn)
+
+    # Create collections with hyphenated names
+    hyphenated_collections = [
+        "test-collection-1",
+        "my-hyphenated-collection-2",
+        "another-test-collection-3",
+    ]
+    cli.load_queryables(
+        str(TEST_QUERYABLES_JSON),
+        collection_ids=hyphenated_collections,
+        create_missing_collections=True,
+        index_fields=["test:string_prop", "test:number_prop"],
+    )
+
+    # Verify that all collections were created
+    result = db.query(
+        """
+        SELECT id FROM collections WHERE id = ANY(%s);
+        """,
+        [hyphenated_collections],
+    )
+    collections = [row[0] for row in result]
+    assert len(collections) == len(hyphenated_collections)
+    assert set(collections) == set(hyphenated_collections)
+
+    # Verify that queryables were loaded for all collections
+    result = db.query(
+        """
+        SELECT name, collection_ids, property_index_type
+        FROM queryables
+        WHERE name LIKE 'test:%%'
+        AND collection_ids @> %s
+        ORDER BY name;
+        """,
+        [hyphenated_collections],
+    )
+
+    # Convert result to a list of dictionaries
+    queryables = [
+        {"name": row[0], "collection_ids": row[1], "property_index_type": row[2]}
+        for row in result
+    ]
+
+    # Check that all queryables were created and associated with all collections
+    assert len(queryables) == 5  # All test properties should be present
+    for queryable in queryables:
+        # Verify all collections are associated with each queryable
+        assert set(hyphenated_collections).issubset(set(queryable["collection_ids"]))
+        # Check that only specified properties have indexes
+        if queryable["name"] in ["test:string_prop", "test:number_prop"]:
+            assert queryable["property_index_type"] == "BTREE"
+        else:
+            assert queryable["property_index_type"] is None
+
+
+def test_load_queryables_with_hyphenated_collection(db: PgstacDB) -> None:
+    """Test loading queryables for a collection with a hyphenated name."""
+    # Create a CLI instance
+    cli = PgstacCLI(dsn=db.dsn)
+
+    # Create a collection with a hyphenated name
+    hyphenated_collection = "test-collection-with-hyphens"
+    cli.load_queryables(
+        str(TEST_QUERYABLES_JSON),
+        collection_ids=[hyphenated_collection],
+        create_missing_collections=True,
+        index_fields=["test:string_prop"],
+    )
+
+    # Verify that the collection was created
+    result = db.query(
+        """
+        SELECT id FROM collections WHERE id = %s;
+        """,
+        [hyphenated_collection],
+    )
+    collections = [row[0] for row in result]
+    assert len(collections) == 1
+    assert collections[0] == hyphenated_collection
+
+    # Verify that queryables were loaded for this collection
+    result = db.query(
+        """
+        SELECT name, collection_ids, property_index_type
+        FROM queryables
+        WHERE name LIKE 'test:%%'
+        AND %s = ANY(collection_ids)
+        ORDER BY name;
+        """,
+        [hyphenated_collection],
+    )
+
+    # Convert result to a list of dictionaries
+    queryables = [
+        {"name": row[0], "collection_ids": row[1], "property_index_type": row[2]}
+        for row in result
+    ]
+
+    # Check that all queryables were created and associated with the collection
+    assert len(queryables) == 5  # All test properties should be present
+    for queryable in queryables:
+        assert hyphenated_collection in queryable["collection_ids"]
+        # Check that only test:string_prop has an index
+        if queryable["name"] == "test:string_prop":
+            assert queryable["property_index_type"] == "BTREE"
+        else:
+            assert queryable["property_index_type"] is None
+
+
 def test_load_queryables_no_properties(db: PgstacDB) -> None:
     """Test loading queryables with no properties."""
     # Create a CLI instance