diff --git a/.bazelrc b/.bazelrc index 21a4ea2a..bd2a357a 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,7 +1,11 @@ # Needed to work with ZetaSQL dependency. +# Zetasql is removed. +# This is a candidate for removal build --cxxopt="-std=c++17" # Needed to avoid zetasql proto error. +# Zetasql is removed. +# This is a candidate for removal build --protocopt=--experimental_allow_proto3_optional # icu@: In create_linking_context: in call to create_linking_context(), diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index af6ea0d1..b8a65fd3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,7 +29,9 @@ jobs: - name: Install built wheel shell: bash run: | - pip install dist/*.whl['test'] + PYTHON_VERSION_TAG="cp$(echo ${{ matrix.python-version }} | sed 's/\.//')" + WHEEL_FILE=$(ls dist/*${PYTHON_VERSION_TAG}*.whl) + pip install "${WHEEL_FILE}[test]" - name: Run Test run: | diff --git a/WORKSPACE b/WORKSPACE index 71db771c..0f237431 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -71,18 +71,6 @@ http_archive( ], ) -# Needed by abseil-py by zetasql. -http_archive( - name = "six_archive", - build_file = "//third_party:six.BUILD", - sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", - strip_prefix = "six-1.10.0", - urls = [ - "http://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", - "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", - ], -) - load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") protobuf_deps() @@ -112,6 +100,16 @@ http_archive( url = "https://github.com/abseil/abseil-cpp/archive/%s.tar.gz" % COM_GOOGLE_ABSL_COMMIT, ) + +# re2 required for google tests +http_archive( + name = "com_googlesource_code_re2", + # build_file = "//third_party:re2.BUILD", + sha256 = "b90430b2a9240df4459108b3e291be80ae92c68a47bc06ef2dc419c5724de061", + strip_prefix = "re2-a276a8c738735a0fe45a6ee590fe2df69bcf4502", + urls = ["https://github.com/google/re2/archive/a276a8c738735a0fe45a6ee590fe2df69bcf4502.tar.gz"], +) + # Will be loaded by workspace.bzl from head # TFMD_COMMIT = "404805761e614561cceedc429e67c357c62be26d" # 1.17.1 @@ -218,46 +216,6 @@ load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies") #, "go_repository") gazelle_dependencies() -################################################################################ -# ZetaSQL # -################################################################################ - -ZETASQL_COMMIT = "a516c6b26d183efc4f56293256bba92e243b7a61" # 11/01/2024 - -http_archive( - name = "com_google_zetasql", - patch_args = ["-p1"], - patches = ["//third_party:zetasql.patch"], - sha256 = "1afc2210d4aad371eff0a6bfdd8417ba99e02183a35dff167af2fa6097643f26", - strip_prefix = "zetasql-%s" % ZETASQL_COMMIT, - urls = ["https://github.com/google/zetasql/archive/%s.tar.gz" % ZETASQL_COMMIT], -) - -load("@com_google_zetasql//bazel:zetasql_deps_step_1.bzl", "zetasql_deps_step_1") - -zetasql_deps_step_1() - -load("@com_google_zetasql//bazel:zetasql_deps_step_2.bzl", "zetasql_deps_step_2") - -zetasql_deps_step_2( - analyzer_deps = True, - evaluator_deps = True, - java_deps = False, - testing_deps = False, - tools_deps = False, -) - -# No need to run zetasql_deps_step_3 and zetasql_deps_step_4 since all necessary dependencies are -# already installed. 
- -# load("@com_google_zetasql//bazel:zetasql_deps_step_3.bzl", "zetasql_deps_step_3") - -# zetasql_deps_step_3() - -# load("@com_google_zetasql//bazel:zetasql_deps_step_4.bzl", "zetasql_deps_step_4") - -# zetasql_deps_step_4() - _PLATFORMS_VERSION = "0.0.6" http_archive( diff --git a/tensorflow_data_validation/BUILD b/tensorflow_data_validation/BUILD index 7fc5d094..198c42b7 100644 --- a/tensorflow_data_validation/BUILD +++ b/tensorflow_data_validation/BUILD @@ -31,7 +31,6 @@ sh_binary( srcs = ["move_generated_files.sh"], data = select({ "//conditions:default": [ - "//tensorflow_data_validation/anomalies/proto:custom_validation_config_proto_py_pb2", "//tensorflow_data_validation/anomalies/proto:validation_config_proto_py_pb2", "//tensorflow_data_validation/anomalies/proto:validation_metadata_proto_py_pb2", "//tensorflow_data_validation/pywrap:tensorflow_data_validation_extension.so", diff --git a/tensorflow_data_validation/anomalies/BUILD b/tensorflow_data_validation/anomalies/BUILD index a26012bc..7d97f36e 100644 --- a/tensorflow_data_validation/anomalies/BUILD +++ b/tensorflow_data_validation/anomalies/BUILD @@ -425,38 +425,6 @@ cc_test( ], ) -cc_library( - name = "custom_validation", - srcs = ["custom_validation.cc"], - hdrs = ["custom_validation.h"], - deps = [ - ":path", - ":schema", - ":status_util", - "//tensorflow_data_validation/anomalies/proto:custom_validation_config_proto_cc_pb2", - "@com_github_tensorflow_metadata//tensorflow_metadata/proto/v0:metadata_v0_proto_cc_pb2", - "@com_github_tfx_bsl//tfx_bsl/cc/statistics:sql_util", - "@com_google_absl//absl/base:log_severity", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - ], -) - -cc_test( - name = "custom_validation_test", - srcs = ["custom_validation_test.cc"], - deps = [ - ":custom_validation", - ":test_util", - "@com_github_tensorflow_metadata//tensorflow_metadata/proto/v0:metadata_v0_proto_cc_pb2", - "@com_google_absl//absl/types:optional", - "@com_google_googletest//:gtest_main", - ], -) - cc_library( name = "telemetry", srcs = ["telemetry.cc"], diff --git a/tensorflow_data_validation/anomalies/custom_validation.cc b/tensorflow_data_validation/anomalies/custom_validation.cc deleted file mode 100644 index bd35cdb5..00000000 --- a/tensorflow_data_validation/anomalies/custom_validation.cc +++ /dev/null @@ -1,306 +0,0 @@ -/* Copyright 2022 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow_data_validation/anomalies/custom_validation.h" - -#include "absl/base/log_severity.h" -#include "absl/container/flat_hash_map.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "tensorflow_data_validation/anomalies/path.h" -#include "tensorflow_data_validation/anomalies/schema_util.h" -#include "tensorflow_data_validation/anomalies/status_util.h" -#include "tfx_bsl/cc/statistics/sql_util.h" -#include "tensorflow_metadata/proto/v0/anomalies.pb.h" -#include "tensorflow_metadata/proto/v0/path.pb.h" -#include "tensorflow_metadata/proto/v0/statistics.pb.h" - -namespace tensorflow { -namespace data_validation { - -namespace { - -using ::tensorflow::metadata::v0::Anomalies; -using ::tensorflow::metadata::v0::AnomalyInfo; -using ::tensorflow::metadata::v0::DatasetFeatureStatisticsList; -using ::tensorflow::metadata::v0::FeatureNameStatistics; - -constexpr char kDefaultSlice[] = "All Examples"; - -// TODO(b/208881543): Update this type alias if representation of slice keys -// changes. -using SliceKey = std::string; - -absl::flat_hash_map> -BuildNamedStatisticsMap(const DatasetFeatureStatisticsList& statistics) { - absl::flat_hash_map> - named_statistics; - for (const auto& dataset : statistics.datasets()) { - for (const auto& feature : dataset.features()) { - const metadata::v0::Path& feature_path = feature.path(); - const std::string serialized_feature_path = - Path(feature_path).Serialize(); - named_statistics[dataset.name()][serialized_feature_path] = feature; - } - } - return named_statistics; -} - -absl::Status GetFeatureStatistics( - const absl::flat_hash_map< - SliceKey, absl::flat_hash_map>& - named_statistics, - const std::string& dataset_name, const metadata::v0::Path& feature_path, - FeatureNameStatistics* statistics) { - auto named_feature_statistics = named_statistics.find(dataset_name); - if (named_feature_statistics == named_statistics.end()) { - if (dataset_name.empty()) { - // If no matching stats are found and no dataset name is specified, use - // the default slice. 
- named_feature_statistics = named_statistics.find(kDefaultSlice); - } - if (named_feature_statistics == named_statistics.end()) { - return absl::InvalidArgumentError(absl::StrCat( - "Dataset ", dataset_name, - " specified in validation config not found in statistics.")); - } - } - const std::string serialized_feature_path = Path(feature_path).Serialize(); - const auto& feature_statistics = - named_feature_statistics->second.find(serialized_feature_path); - if (feature_statistics == named_feature_statistics->second.end()) { - return absl::InvalidArgumentError(absl::StrCat( - "Feature ", serialized_feature_path, - " specified in validation config not found in statistics.")); - } - *statistics = feature_statistics->second; - return absl::OkStatus(); -} - -absl::Status MergeAnomalyInfos(const AnomalyInfo& anomaly_info, - const std::string& key, - AnomalyInfo* existing_anomaly_info) { - if (Path(anomaly_info.path()).Compare(Path(existing_anomaly_info->path())) != - 0) { - return absl::AlreadyExistsError( - absl::StrCat("Anomaly info map includes entries for ", key, - " which do not have the same path.")); - } - if (anomaly_info.severity() != existing_anomaly_info->severity()) { - existing_anomaly_info->set_severity(MaxSeverity( - anomaly_info.severity(), existing_anomaly_info->severity())); - LOG(WARNING) - << "Anomaly entry for " << key - << " has conflicting severities. The higher severity will be used."; - } - for (const auto& reason : anomaly_info.reason()) { - AnomalyInfo::Reason* new_reason = existing_anomaly_info->add_reason(); - new_reason->CopyFrom(reason); - } - return absl::OkStatus(); -} - -// TODO(b/239095455): Populate top-level descriptions if needed for -// visualization. -absl::Status UpdateAnomalyResults( - const metadata::v0::Path& path, const std::string& test_dataset, - const absl::optional base_dataset, - const absl::optional base_path, - const Validation& validation, Anomalies* results) { - AnomalyInfo anomaly_info; - AnomalyInfo::Reason reason; - reason.set_type(AnomalyInfo::CUSTOM_VALIDATION); - reason.set_short_description(validation.description()); - std::string anomaly_source_description = - absl::StrCat("Query: ", validation.sql_expression(), " Test dataset: "); - if (test_dataset.empty()) { - absl::StrAppend(&anomaly_source_description, "default slice"); - } else { - absl::StrAppend(&anomaly_source_description, test_dataset); - } - if (base_dataset.has_value()) { - absl::StrAppend(&anomaly_source_description, - " Base dataset: ", base_dataset.value(), " "); - } - if (base_path.has_value()) { - absl::StrAppend(&anomaly_source_description, - "Base path: ", Path(base_path.value()).Serialize()); - } - reason.set_description(absl::StrCat("Custom validation triggered anomaly. ", - anomaly_source_description)); - anomaly_info.mutable_path()->CopyFrom(path); - anomaly_info.set_severity(validation.severity()); - anomaly_info.add_reason()->CopyFrom(reason); - const std::string& feature_name = Path(path).Serialize(); - const auto& insert_result = - results->mutable_anomaly_info()->insert({feature_name, anomaly_info}); - // feature_name already existed in anomaly_info. 
- if (insert_result.second == false) { - AnomalyInfo existing_anomaly_info = - results->anomaly_info().at(feature_name); - TFDV_RETURN_IF_ERROR( - MergeAnomalyInfos(anomaly_info, feature_name, &existing_anomaly_info)); - results->mutable_anomaly_info() - ->at(feature_name) - .CopyFrom(existing_anomaly_info); - } - return absl::OkStatus(); -} - -bool InCurrentEnvironment(Validation validation, - const absl::optional& environment) { - if (validation.in_environment_size() == 0) { - return true; - } - if (environment.has_value()) { - const std::string& environment_value = environment.value(); - for (const auto& each : validation.in_environment()) { - if (each == environment_value) { - return true; - } - } - } - return false; -} - -} // namespace - -absl::Status CustomValidateStatistics( - const metadata::v0::DatasetFeatureStatisticsList& test_statistics, - const metadata::v0::DatasetFeatureStatisticsList* base_statistics, - const CustomValidationConfig& validations, - const absl::optional environment, - metadata::v0::Anomalies* result) { - absl::flat_hash_map> - named_test_statistics = BuildNamedStatisticsMap(test_statistics); - for (const auto& feature_validation : validations.feature_validations()) { - FeatureNameStatistics test_statistics; - TFDV_RETURN_IF_ERROR(GetFeatureStatistics( - named_test_statistics, feature_validation.dataset_name(), - feature_validation.feature_path(), &test_statistics)); - for (const auto& validation : feature_validation.validations()) { - if (InCurrentEnvironment(validation, environment)) { - absl::StatusOr query_result = - tfx_bsl::statistics::EvaluatePredicate(test_statistics, - validation.sql_expression()); - if (!query_result.ok()) { - return absl::InternalError(absl::StrCat( - "Attempt to run query '", validation.sql_expression(), - "' failed with error: ", query_result.status().ToString())); - } else if (!query_result.value()) { - // If the sql_expression evaluates to False, there is an anomaly. 
- TFDV_RETURN_IF_ERROR(UpdateAnomalyResults( - feature_validation.feature_path(), - feature_validation.dataset_name(), absl::nullopt, absl::nullopt, - validation, result)); - } - } - } - } - if (validations.feature_pair_validations_size() > 0) { - if (base_statistics == nullptr) { - return absl::InvalidArgumentError( - "Feature pair validations are included in the CustomValidationConfig " - "but base_statistics have not been specified."); - } - absl::flat_hash_map> - named_base_statistics = BuildNamedStatisticsMap(*base_statistics); - for (const auto& feature_pair_validation : - validations.feature_pair_validations()) { - FeatureNameStatistics test_statistics; - FeatureNameStatistics base_statistics; - TFDV_RETURN_IF_ERROR(GetFeatureStatistics( - named_test_statistics, feature_pair_validation.dataset_name(), - feature_pair_validation.feature_test_path(), &test_statistics)); - TFDV_RETURN_IF_ERROR(GetFeatureStatistics( - named_base_statistics, feature_pair_validation.base_dataset_name(), - feature_pair_validation.feature_base_path(), &base_statistics)); - for (const auto& validation : feature_pair_validation.validations()) { - if (InCurrentEnvironment(validation, environment)) { - absl::StatusOr query_result = - tfx_bsl::statistics::EvaluatePredicate( - base_statistics, test_statistics, - validation.sql_expression()); - if (!query_result.ok()) { - return absl::InternalError(absl::StrCat( - "Attempt to run query: ", validation.sql_expression(), - " failed with the following error: ", - query_result.status().ToString())); - } else if (!query_result.value()) { - // If the sql_expression evaluates to False, there is an anomaly. - TFDV_RETURN_IF_ERROR(UpdateAnomalyResults( - feature_pair_validation.feature_test_path(), - feature_pair_validation.dataset_name(), - feature_pair_validation.base_dataset_name(), - feature_pair_validation.feature_base_path(), validation, - result)); - } - } - } - } - } - return absl::OkStatus(); -} - -absl::Status CustomValidateStatisticsWithSerializedInputs( - const std::string& serialized_test_statistics, - const std::string& serialized_base_statistics, - const std::string& serialized_validations, - const std::string& serialized_environment, - std::string* serialized_anomalies_proto) { - metadata::v0::DatasetFeatureStatisticsList test_statistics; - metadata::v0::DatasetFeatureStatisticsList base_statistics; - metadata::v0::DatasetFeatureStatisticsList* base_statistics_ptr = nullptr; - if (!test_statistics.ParseFromString(serialized_test_statistics)) { - return absl::InvalidArgumentError( - "Failed to parse DatasetFeatureStatistics proto."); - } - if (!serialized_base_statistics.empty()) { - if (!base_statistics.ParseFromString(serialized_base_statistics)) { - return absl::InvalidArgumentError( - "Failed to parse DatasetFeatureStatistics proto."); - } - base_statistics_ptr = &base_statistics; - } - CustomValidationConfig validations; - if (!validations.ParseFromString(serialized_validations)) { - return absl::InvalidArgumentError( - "Failed to parse CustomValidationConfig proto."); - } - absl::optional environment = absl::nullopt; - if (!serialized_environment.empty()) { - environment = serialized_environment; - } - metadata::v0::Anomalies anomalies; - const absl::Status status = - CustomValidateStatistics(test_statistics, base_statistics_ptr, - validations, environment, &anomalies); - if (!status.ok()) { - return absl::InternalError( - absl::StrCat("Failed to run custom validations: ", status.message())); - } - if 
(!anomalies.SerializeToString(serialized_anomalies_proto)) { - return absl::InternalError( - "Failed to serialize Anomalies output proto to string."); - } - return absl::OkStatus(); -} - -} // namespace data_validation -} // namespace tensorflow diff --git a/tensorflow_data_validation/anomalies/custom_validation.h b/tensorflow_data_validation/anomalies/custom_validation.h deleted file mode 100644 index 574e9dcc..00000000 --- a/tensorflow_data_validation/anomalies/custom_validation.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2022 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef THIRD_PARTY_PY_TENSORFLOW_DATA_VALIDATION_ANOMALIES_CUSTOM_VALIDATION_H_ -#define THIRD_PARTY_PY_TENSORFLOW_DATA_VALIDATION_ANOMALIES_CUSTOM_VALIDATION_H_ - -#include "absl/status/status.h" -#include "absl/types/optional.h" -#include "tensorflow_data_validation/anomalies/proto/custom_validation_config.pb.h" -#include "tensorflow_metadata/proto/v0/anomalies.pb.h" -#include "tensorflow_metadata/proto/v0/statistics.pb.h" - -namespace tensorflow { -namespace data_validation { - -// Validates `test_statistics` (either alone or by comparing it to -// `base_statistics`) by running the SQL queries specified in `validations`. If -// a validation query returns False, a corresponding anomaly is added to -// `result`. -absl::Status CustomValidateStatistics( - const metadata::v0::DatasetFeatureStatisticsList& test_statistics, - const metadata::v0::DatasetFeatureStatisticsList* base_statistics, - const CustomValidationConfig& validations, - const absl::optional environment, - metadata::v0::Anomalies* result); - -// Like CustomValidateStatistics but with serialized inputs. Used for doing -// custom validation in Python. -absl::Status CustomValidateStatisticsWithSerializedInputs( - const std::string& serialized_test_statistics, - const std::string& serialized_base_statistics, - const std::string& serialized_validations, - const std::string& serialized_environment, - std::string* serialized_anomalies_proto); - -} // namespace data_validation -} // namespace tensorflow - -#endif // THIRD_PARTY_PY_TENSORFLOW_DATA_VALIDATION_ANOMALIES_CUSTOM_VALIDATION_H_ diff --git a/tensorflow_data_validation/anomalies/custom_validation_test.cc b/tensorflow_data_validation/anomalies/custom_validation_test.cc deleted file mode 100644 index d5e5b504..00000000 --- a/tensorflow_data_validation/anomalies/custom_validation_test.cc +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright 2022 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow_data_validation/anomalies/custom_validation.h" - -#include -#include "absl/types/optional.h" -#include "tensorflow_data_validation/anomalies/test_util.h" -#include "tensorflow_metadata/proto/v0/statistics.pb.h" - -namespace tensorflow { -namespace data_validation { - -namespace { - -using testing::EqualsProto; -using testing::ParseTextProtoOrDie; - -TEST(CustomValidationTest, TestSingleStatisticsDefaultSliceValidation) { - metadata::v0::DatasetFeatureStatisticsList test_statistics = - ParseTextProtoOrDie( - R"pb(datasets { - name: "All Examples" - num_examples: 10 - features { - path { step: 'test_feature' } - type: INT - num_stats { num_zeros: 5 max: 25 } - } - })pb"); - CustomValidationConfig validations = - ParseTextProtoOrDie( - R"pb(feature_validations { - feature_path { step: 'test_feature' } - validations { - sql_expression: 'feature.num_stats.num_zeros < 3' - severity: ERROR - description: 'Feature has too many zeros.' - } - validations { - sql_expression: 'feature.num_stats.max > 10' - severity: ERROR - description: 'Maximum value is too low.' - } - })pb"); - metadata::v0::Anomalies expected_anomalies = ParseTextProtoOrDie< - metadata::v0::Anomalies>( - R"pb(anomaly_info { - key: 'test_feature' - value: { - path { step: 'test_feature' } - severity: ERROR - reason { - type: CUSTOM_VALIDATION - short_description: 'Feature has too many zeros.' - description: 'Custom validation triggered anomaly. Query: feature.num_stats.num_zeros < 3 Test dataset: default slice' - } - } - })pb"); - metadata::v0::Anomalies result; - CHECK_OK(CustomValidateStatistics(test_statistics, - /*base_statistics=*/nullptr, validations, - /*environment=*/absl::nullopt, &result)); - EXPECT_THAT(result, EqualsProto(expected_anomalies)); -} - -TEST(CustomValidationTest, TestSingleStatisticsSpecifiedSliceValidation) { - metadata::v0::DatasetFeatureStatisticsList test_statistics = - ParseTextProtoOrDie( - R"pb(datasets { - name: "All Examples" - num_examples: 10 - features { - path { step: 'test_feature' } - type: INT - num_stats { num_zeros: 5 max: 25 } - } - } - datasets { - name: "some_slice" - num_examples: 5 - features { - path { step: 'test_feature' } - type: INT - num_stats { num_zeros: 0 max: 5 } - } - })pb"); - CustomValidationConfig validations = - ParseTextProtoOrDie( - R"pb(feature_validations { - dataset_name: 'some_slice' - feature_path { step: 'test_feature' } - validations { - sql_expression: 'feature.num_stats.num_zeros < 3' - severity: ERROR - description: 'Feature has too many zeros.' - } - validations { - sql_expression: 'feature.num_stats.max > 10' - severity: ERROR - description: 'Maximum value is too low.' - } - })pb"); - metadata::v0::Anomalies expected_anomalies = ParseTextProtoOrDie< - metadata::v0::Anomalies>( - R"pb(anomaly_info { - key: 'test_feature' - value: { - path { step: 'test_feature' } - severity: ERROR - reason { - type: CUSTOM_VALIDATION - short_description: 'Maximum value is too low.' - description: 'Custom validation triggered anomaly. 
Query: feature.num_stats.max > 10 Test dataset: some_slice' - } - } - })pb"); - metadata::v0::Anomalies result; - CHECK_OK(CustomValidateStatistics(test_statistics, - /*base_statistics=*/nullptr, validations, - /*environment=*/absl::nullopt, &result)); - EXPECT_THAT(result, EqualsProto(expected_anomalies)); -} - -TEST(CustomValidationTest, TestPairValidation) { - metadata::v0::DatasetFeatureStatisticsList test_statistics = - ParseTextProtoOrDie( - R"pb(datasets { - name: "slice_1" - num_examples: 10 - features { - path { step: 'test_feature' } - type: INT - num_stats { num_zeros: 5 max: 25 } - } - })pb"); - metadata::v0::DatasetFeatureStatisticsList base_statistics = - ParseTextProtoOrDie( - R"pb(datasets { - name: "slice_2" - num_examples: 10 - features { - path { step: 'test_feature' } - type: INT - num_stats { num_zeros: 1 max: 1 } - } - })pb"); - CustomValidationConfig validations = ParseTextProtoOrDie< - CustomValidationConfig>( - R"pb(feature_pair_validations { - dataset_name: 'slice_1' - feature_test_path { step: 'test_feature' } - base_dataset_name: 'slice_2' - feature_base_path { step: 'test_feature' } - validations { - sql_expression: 'feature_test.num_stats.num_zeros < feature_base.num_stats.num_zeros' - severity: ERROR - description: 'Test feature has too many zeros.' - } - validations { - sql_expression: 'feature_test.num_stats.num_zeros > feature_base.num_stats.num_zeros' - severity: ERROR - description: 'Base feature has too few zeros.' - } - })pb"); - metadata::v0::Anomalies expected_anomalies = ParseTextProtoOrDie< - metadata::v0::Anomalies>( - R"pb(anomaly_info { - key: 'test_feature' - value: { - path { step: 'test_feature' } - severity: ERROR - reason { - type: CUSTOM_VALIDATION - short_description: 'Test feature has too many zeros.' - description: 'Custom validation triggered anomaly. Query: feature_test.num_stats.num_zeros < feature_base.num_stats.num_zeros Test dataset: slice_1 Base dataset: slice_2 Base path: test_feature' - } - } - })pb"); - metadata::v0::Anomalies result; - CHECK_OK(CustomValidateStatistics(test_statistics, &base_statistics, - validations, - /*environment=*/absl::nullopt, &result)); - EXPECT_THAT(result, EqualsProto(expected_anomalies)); -} - -TEST(CustomValidationTest, TestSpecifiedFeatureNotFound) { - metadata::v0::DatasetFeatureStatisticsList test_statistics = - ParseTextProtoOrDie( - R"pb(datasets { - name: "All Examples" - num_examples: 10 - features { - path { step: 'test_feature' } - type: INT - num_stats { num_zeros: 5 max: 25 } - } - })pb"); - CustomValidationConfig validations = - ParseTextProtoOrDie( - R"pb(feature_validations { - feature_path { step: 'other_feature' } - validations { - sql_expression: 'feature.num_stats.num_zeros < 3' - severity: ERROR - description: 'Feature has too many zeros.' 
- } - })pb"); - metadata::v0::Anomalies result; - auto error = - CustomValidateStatistics(test_statistics, - /*base_statistics=*/nullptr, validations, - /*environment=*/absl::nullopt, &result); - EXPECT_TRUE(absl::IsInvalidArgument(error)); -} - -TEST(CustomValidationTest, TestSpecifiedDatasetNotFound) { - metadata::v0::DatasetFeatureStatisticsList test_statistics = - ParseTextProtoOrDie( - R"pb(datasets { - name: "some_slice" - num_examples: 10 - features { - path { step: 'test_feature' } - type: INT - num_stats { num_zeros: 5 max: 25 } - } - })pb"); - CustomValidationConfig validations = - ParseTextProtoOrDie( - R"pb(feature_validations { - dataset_name: 'other_slice' - feature_path { step: 'other_feature' } - validations { - sql_expression: 'feature.num_stats.num_zeros < 3' - severity: ERROR - description: 'Feature has too many zeros.' - } - })pb"); - metadata::v0::Anomalies result; - auto error = - CustomValidateStatistics(test_statistics, - /*base_statistics=*/nullptr, validations, - /*environment=*/absl::nullopt, &result); - EXPECT_TRUE(absl::IsInvalidArgument(error)); -} - -TEST(CustomValidationTest, TestMultipleAnomalyReasons) { - metadata::v0::DatasetFeatureStatisticsList test_statistics = - ParseTextProtoOrDie( - R"pb(datasets { - name: "All Examples" - num_examples: 10 - features { - path { step: 'test_feature' } - type: INT - num_stats { num_zeros: 5 max: 25 } - } - })pb"); - CustomValidationConfig validations = - ParseTextProtoOrDie( - R"pb(feature_validations { - feature_path { step: 'test_feature' } - validations { - sql_expression: 'feature.num_stats.num_zeros < 3' - severity: WARNING - description: 'Feature has too many zeros.' - } - validations { - sql_expression: 'feature.num_stats.max > 100' - severity: ERROR - description: 'Maximum value is too low.' - } - })pb"); - metadata::v0::Anomalies expected_anomalies = ParseTextProtoOrDie< - metadata::v0::Anomalies>( - R"pb(anomaly_info { - key: 'test_feature' - value: { - path { step: 'test_feature' } - severity: ERROR - reason { - type: CUSTOM_VALIDATION - short_description: 'Feature has too many zeros.' - description: 'Custom validation triggered anomaly. Query: feature.num_stats.num_zeros < 3 Test dataset: default slice' - } - reason { - type: CUSTOM_VALIDATION - short_description: 'Maximum value is too low.' - description: 'Custom validation triggered anomaly. Query: feature.num_stats.max > 100 Test dataset: default slice' - } - } - })pb"); - metadata::v0::Anomalies result; - CHECK_OK(CustomValidateStatistics(test_statistics, - /*base_statistics=*/nullptr, validations, - /*environment=*/absl::nullopt, &result)); - EXPECT_THAT(result, EqualsProto(expected_anomalies)); -} - -TEST(CustomValidationTest, TestPairValidationsConfiguredButNoBaselineStats) { - metadata::v0::DatasetFeatureStatisticsList test_statistics = - ParseTextProtoOrDie( - R"pb(datasets { - name: "some_slice" - num_examples: 10 - features { - path { step: 'test_feature' } - type: INT - num_stats { num_zeros: 5 max: 25 } - } - })pb"); - CustomValidationConfig validations = ParseTextProtoOrDie< - CustomValidationConfig>( - R"pb(feature_pair_validations { - dataset_name: 'some_slice' - feature_test_path { step: 'test_feature' } - base_dataset_name: 'slice_2' - feature_base_path { step: 'test_feature' } - validations { - sql_expression: 'feature_test.num_stats.num_zeros < feature_base.num_stats.num_zeros' - severity: ERROR - description: 'Test feature has too many zeros.' 
- } - })pb"); - metadata::v0::Anomalies result; - auto error = - CustomValidateStatistics(test_statistics, - /*base_statistics=*/nullptr, validations, - /*environment=*/absl::nullopt, &result); - EXPECT_TRUE(absl::IsInvalidArgument(error)); -} - -TEST(CustomValidationTest, TestEnvironmentFiltering) { - metadata::v0::DatasetFeatureStatisticsList test_statistics = - ParseTextProtoOrDie( - R"pb(datasets { - name: "All Examples" - num_examples: 10 - features { - path { step: 'test_feature' } - type: INT - num_stats { num_zeros: 5 max: 1 } - } - })pb"); - CustomValidationConfig validations = - ParseTextProtoOrDie( - R"pb(feature_validations { - feature_path { step: 'test_feature' } - validations { - sql_expression: 'feature.num_stats.num_zeros < 3' - severity: ERROR - description: 'Feature has too many zeros.' - } - } - feature_validations { - feature_path { step: 'test_feature' } - validations { - sql_expression: 'feature.num_stats.max > 2' - severity: ERROR - description: 'Maximum value is wrong.' - in_environment: 'not_this_environment' - } - } - feature_validations { - feature_path { step: 'test_feature' } - validations { - sql_expression: 'feature.num_stats.max > 10' - severity: ERROR - description: 'Maximum value is too low.' - in_environment: 'some_environment' - } - })pb"); - metadata::v0::Anomalies expected_anomalies = ParseTextProtoOrDie< - metadata::v0::Anomalies>( - R"pb(anomaly_info { - key: 'test_feature' - value: { - path { step: 'test_feature' } - severity: ERROR - reason { - type: CUSTOM_VALIDATION - short_description: 'Feature has too many zeros.' - description: 'Custom validation triggered anomaly. Query: feature.num_stats.num_zeros < 3 Test dataset: default slice' - } - reason { - type: CUSTOM_VALIDATION - short_description: 'Maximum value is too low.' - description: 'Custom validation triggered anomaly. Query: feature.num_stats.max > 10 Test dataset: default slice' - } - } - })pb"); - metadata::v0::Anomalies result; - CHECK_OK(CustomValidateStatistics(test_statistics, - /*base_statistics=*/nullptr, validations, - "some_environment", &result)); - EXPECT_THAT(result, EqualsProto(expected_anomalies)); -} - -} // namespace -} // namespace data_validation -} // namespace tensorflow diff --git a/tensorflow_data_validation/anomalies/proto/BUILD b/tensorflow_data_validation/anomalies/proto/BUILD index 3a4634df..616f2344 100644 --- a/tensorflow_data_validation/anomalies/proto/BUILD +++ b/tensorflow_data_validation/anomalies/proto/BUILD @@ -36,14 +36,3 @@ tfdv_proto_library_py( name = "validation_metadata_proto_py_pb2", deps = [":validation_metadata_proto"], ) - -tfdv_proto_library( - name = "custom_validation_config_proto", - srcs = ["custom_validation_config.proto"], - deps = ["@com_github_tensorflow_metadata//tensorflow_metadata/proto/v0:metadata_v0_proto"], -) - -tfdv_proto_library_py( - name = "custom_validation_config_proto_py_pb2", - deps = [":custom_validation_config_proto"], -) diff --git a/tensorflow_data_validation/anomalies/proto/custom_validation_config.proto b/tensorflow_data_validation/anomalies/proto/custom_validation_config.proto deleted file mode 100644 index 3b19832b..00000000 --- a/tensorflow_data_validation/anomalies/proto/custom_validation_config.proto +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -syntax = "proto2"; - -package tensorflow.data_validation; - -import "tensorflow_metadata/proto/v0/anomalies.proto"; -import "tensorflow_metadata/proto/v0/path.proto"; - -// Use this proto to configure custom validations in TFDV. -// Example usages follow. -// ----------------------------------------------------------------------------- -// Example Single-Feature Validation -// Statistics - // datasets { - // name: "All Examples" - // num_examples: 10 - // features { - // path { step: 'test_feature' } - // type: INT - // num_stats { num_zeros: 5 max: 25 } - // } - // } -// CustomValidationConfig - // feature_validations { - // feature_path { step: 'test_feature' } - // validations { - // sql_expression: 'feature.num_stats.num_zeros < 3' - // severity: ERROR - // description: 'Feature has too many zeros.' - // } - // validations { - // sql_expression: 'feature.num_stats.max > 10' - // severity: ERROR - // description: 'Maximum value is too low.' - // } - // } -// Anomalies - // anomaly_info { - // key: 'test_feature' - // value: { - // path { step: 'test_feature' } - // severity: ERROR - // reason { - // type: CUSTOM_VALIDATION - // short_description: 'Feature has too many zeros.' - // description: 'Custom validation triggered anomaly. Query: feature.num_stats.num_zeros < 3 Test dataset: default slice' - // } - // } - // } -// ----------------------------------------------------------------------------- -// Example Feature Pair Validation -// Statistics -// Test statistics - // datasets { - // name: "slice_1" - // num_examples: 10 - // features { - // path { step: 'test_feature' } - // type: INT - // num_stats { num_zeros: 5 max: 25 } - // } - // } -// Base statistics - // datasets { - // name: "slice_2" - // num_examples: 10 - // features { - // path { step: 'test_feature' } - // type: INT - // num_stats { num_zeros: 1 max: 1 } - // } - // } -// CustomValidationConfig - // feature_pair_validations { - // dataset_name: 'slice_1' - // feature_test_path { step: 'test_feature' } - // base_dataset_name: 'slice_2' - // feature_base_path { step: 'test_feature' } - // validations { - // sql_expression: 'feature_test.num_stats.num_zeros < feature_base.num_stats.num_zeros' - // severity: ERROR - // description: 'Test feature has too many zeros.' - // } - // } -// Anomalies - // anomaly_info { - // key: 'test_feature' - // value: { - // path { step: 'test_feature' } - // severity: ERROR - // reason { - // type: CUSTOM_VALIDATION - // short_description: 'Test feature has too many zeros.' - // description: 'Custom validation triggered anomaly. Query: feature_test.num_stats.num_zeros < feature_base.num_stats.num_zeros Test dataset: slice_1 Base dataset: slice_2 Base path: test_feature' - // } - // } - // } -// ============================================================================= - -message Validation { - // Expression to evaluate. If the expression returns false, the anomaly is - // returned. - // For single feature validations, the feature statistics are bound to - // `feature`. 
For feature pair validations, the test feature statistics are - // bound to `feature_test` and the base feature statistics are bound to - // `feature_base`. - optional string sql_expression = 1; - optional tensorflow.metadata.v0.AnomalyInfo.Severity severity = 2; - optional string description = 3; - // Use this to limit the data on which the validation runs to only the - // specified environments. If this field is not specified, the validation - // will always run. - repeated string in_environment = 4; -} - -message FeatureValidation { - // The name of the dataset (i.e., slice) to validate. You do not need to - // specify this if using the default slice, provided there is no empty-named - // slice in the input statistics. - optional string dataset_name = 1; - optional tensorflow.metadata.v0.Path feature_path = 2; - repeated Validation validations = 3; -} - -message FeaturePairValidation { - // The name of the dataset (i.e., slice) to validate. You do not need to - // specify this if using the default slice, provided there is no empty-named - // slice in the input statistics. - optional string dataset_name = 1; - optional string base_dataset_name = 2; - optional tensorflow.metadata.v0.Path feature_test_path = 3; - optional tensorflow.metadata.v0.Path feature_base_path = 4; - repeated Validation validations = 5; -} - -message CustomValidationConfig { - repeated FeatureValidation feature_validations = 1; - repeated FeaturePairValidation feature_pair_validations = 2; -} diff --git a/tensorflow_data_validation/api/validation_api.py b/tensorflow_data_validation/api/validation_api.py index 39e19656..030ec222 100644 --- a/tensorflow_data_validation/api/validation_api.py +++ b/tensorflow_data_validation/api/validation_api.py @@ -25,7 +25,6 @@ from tensorflow_data_validation import constants, types from tensorflow_data_validation.anomalies.proto import ( - custom_validation_config_pb2, validation_config_pb2, validation_metadata_pb2, ) @@ -222,45 +221,12 @@ def _merge_descriptions( return " ".join(descriptions) -def _merge_custom_anomalies( - anomalies: anomalies_pb2.Anomalies, custom_anomalies: anomalies_pb2.Anomalies -) -> anomalies_pb2.Anomalies: - """Merges custom_anomalies with anomalies.""" - for key, custom_anomaly_info in custom_anomalies.anomaly_info.items(): - if key in anomalies.anomaly_info: - # If the key is found in in both inputs, we know it has multiple errors. - anomalies.anomaly_info[key].short_description = _MULTIPLE_ERRORS - anomalies.anomaly_info[key].description = _merge_descriptions( - anomalies.anomaly_info[key], custom_anomaly_info - ) - anomalies.anomaly_info[key].severity = max( - anomalies.anomaly_info[key].severity, custom_anomaly_info.severity - ) - anomalies.anomaly_info[key].reason.extend(custom_anomaly_info.reason) - else: - anomalies.anomaly_info[key].CopyFrom(custom_anomaly_info) - # Also populate top-level descriptions. 
- anomalies.anomaly_info[key].description = _merge_descriptions( - custom_anomaly_info, None - ) - if len(anomalies.anomaly_info[key].reason) > 1: - anomalies.anomaly_info[key].short_description = _MULTIPLE_ERRORS - else: - anomalies.anomaly_info[ - key - ].short_description = custom_anomaly_info.reason[0].short_description - return anomalies - - def validate_statistics( statistics: statistics_pb2.DatasetFeatureStatisticsList, schema: schema_pb2.Schema, environment: Optional[str] = None, previous_statistics: Optional[statistics_pb2.DatasetFeatureStatisticsList] = None, serving_statistics: Optional[statistics_pb2.DatasetFeatureStatisticsList] = None, - custom_validation_config: Optional[ - custom_validation_config_pb2.CustomValidationConfig - ] = None, ) -> anomalies_pb2.Anomalies: """Validates the input statistics against the provided input schema. @@ -313,14 +279,6 @@ def validate_statistics( distribution skew between current data and serving data. Configuration for skew detection can be done by specifying a `skew_comparator` in the schema. - custom_validation_config: An optional config that can be used to specify - custom validations to perform. If doing single-feature validations, - the test feature will come from `statistics` and will be mapped to - `feature` in the SQL query. If doing feature pair validations, the test - feature will come from `statistics` and will be mapped to `feature_test` - in the SQL query, and the base feature will come from - `previous_statistics` and will be mapped to `feature_base` in the SQL - query. Returns: ------- @@ -354,7 +312,6 @@ def validate_statistics( None, None, False, - custom_validation_config, ) @@ -371,9 +328,6 @@ def validate_statistics_internal( ] = None, validation_options: Optional[vo.ValidationOptions] = None, enable_diff_regions: bool = False, - custom_validation_config: Optional[ - custom_validation_config_pb2.CustomValidationConfig - ] = None, ) -> anomalies_pb2.Anomalies: """Validates the input statistics against the provided input schema. @@ -431,14 +385,6 @@ def validate_statistics_internal( enable_diff_regions: Specifies whether to include a comparison between the existing schema and the fixed schema in the Anomalies protocol buffer output. - custom_validation_config: An optional config that can be used to specify - custom validations to perform. If doing single-feature validations, - the test feature will come from `statistics` and will be mapped to - `feature` in the SQL query. If doing feature pair validations, the test - feature will come from `statistics` and will be mapped to `feature_test` - in the SQL query, and the base feature will come from - `previous_statistics` and will be mapped to `feature_base` in the SQL - query. 
Returns: ------- @@ -570,78 +516,6 @@ def validate_statistics_internal( result = anomalies_pb2.Anomalies() result.ParseFromString(anomalies_proto_string) - if custom_validation_config is not None: - serialized_previous_statistics = ( - previous_span_statistics.SerializeToString() - if previous_span_statistics is not None - else "" - ) - custom_anomalies_string = ( - pywrap_tensorflow_data_validation.CustomValidateStatistics( - tf.compat.as_bytes(statistics.SerializeToString()), - tf.compat.as_bytes(serialized_previous_statistics), - tf.compat.as_bytes(custom_validation_config.SerializeToString()), - tf.compat.as_bytes(environment), - ) - ) - custom_anomalies = anomalies_pb2.Anomalies() - custom_anomalies.ParseFromString(custom_anomalies_string) - result = _merge_custom_anomalies(result, custom_anomalies) - - return result - - -def custom_validate_statistics( - statistics: statistics_pb2.DatasetFeatureStatisticsList, - validations: custom_validation_config_pb2.CustomValidationConfig, - baseline_statistics: Optional[statistics_pb2.DatasetFeatureStatisticsList] = None, - environment: Optional[str] = None, -) -> anomalies_pb2.Anomalies: - """Validates the input statistics with the user-supplied SQL queries. - - If the SQL query from a user-supplied validation returns False, TFDV will - return an anomaly for that validation. In single feature valdiations, the test - feature will be mapped to `feature` in the SQL query. In two feature - validations, the test feature will be mapped to `feature_test` in the SQL - query, and the base feature will be mapped to `feature_base`. - - If an optional `environment` is supplied, TFDV will run validations with - that environment specified and validations with no environment specified. - - Args: - ---- - statistics: A DatasetFeatureStatisticsList protocol buffer that holds the - statistics to validate. - validations: Configuration that specifies the dataset(s) and feature(s) to - validate and the SQL query to use for the validation. The SQL query must - return a boolean value. - baseline_statistics: An optional DatasetFeatureStatisticsList protocol - buffer that holds the baseline statistics used when validating feature - pairs. - environment: If supplied, TFDV will run validations with that - environment specified and validations with no environment specified. If - not supplied, TFDV will run all validations. - - Returns: - ------- - An Anomalies protocol buffer. 
- """ - serialized_statistics = statistics.SerializeToString() - serialized_baseline_statistics = ( - baseline_statistics.SerializeToString() - if baseline_statistics is not None - else "" - ) - serialized_validations = validations.SerializeToString() - environment = "" if environment is None else environment - serialized_anomalies = pywrap_tensorflow_data_validation.CustomValidateStatistics( - tf.compat.as_bytes(serialized_statistics), - tf.compat.as_bytes(serialized_baseline_statistics), - tf.compat.as_bytes(serialized_validations), - tf.compat.as_bytes(environment), - ) - result = anomalies_pb2.Anomalies() - result.ParseFromString(serialized_anomalies) return result diff --git a/tensorflow_data_validation/api/validation_api_test.py b/tensorflow_data_validation/api/validation_api_test.py index b2f781c2..158cb68b 100644 --- a/tensorflow_data_validation/api/validation_api_test.py +++ b/tensorflow_data_validation/api/validation_api_test.py @@ -34,7 +34,6 @@ import tensorflow_data_validation as tfdv from tensorflow_data_validation import types -from tensorflow_data_validation.anomalies.proto import custom_validation_config_pb2 from tensorflow_data_validation.api import validation_api, validation_options from tensorflow_data_validation.skew.protos import feature_skew_results_pb2 from tensorflow_data_validation.statistics import stats_options @@ -2232,86 +2231,6 @@ def test_validate_stats_invalid_previous_version_stats_multiple_datasets(self): previous_version_statistics=previous_version_stats, ) - def test_validate_stats_with_custom_validations(self): - statistics = text_format.Parse( - """ - datasets{ - num_examples: 10 - features { - path { step: 'annotated_enum' } - type: STRING - string_stats { - common_stats { - num_missing: 3 - num_non_missing: 7 - min_num_values: 1 - max_num_values: 1 - } - unique: 3 - rank_histogram { - buckets { - label: "D" - sample_count: 1 - } - } - } - } - } - """, - statistics_pb2.DatasetFeatureStatisticsList(), - ) - schema = text_format.Parse( - """ - feature { - name: 'annotated_enum' - type: BYTES - unique_constraints { - min: 4 - max: 4 - } - } - """, - schema_pb2.Schema(), - ) - validation_config = text_format.Parse( - """ - feature_validations { - feature_path { step: 'annotated_enum' } - validations { - sql_expression: 'feature.string_stats.common_stats.num_missing < 3' - severity: WARNING - description: 'Feature has too many missing.' - } - } - """, - custom_validation_config_pb2.CustomValidationConfig(), - ) - expected_anomalies = { - "annotated_enum": text_format.Parse( - """ - path { step: 'annotated_enum' } - short_description: 'Multiple errors' - description: 'Expected at least 4 unique values but found only 3. Custom validation triggered anomaly. Query: feature.string_stats.common_stats.num_missing < 3 Test dataset: default slice' - severity: ERROR - reason { - type: FEATURE_TYPE_LOW_UNIQUE - short_description: 'Low number of unique values' - description: 'Expected at least 4 unique values but found only 3.' - } - reason { - type: CUSTOM_VALIDATION - short_description: 'Feature has too many missing.' - description: 'Custom validation triggered anomaly. 
Query: feature.string_stats.common_stats.num_missing < 3 Test dataset: default slice' - } - """, - anomalies_pb2.AnomalyInfo(), - ) - } - anomalies = validation_api.validate_statistics( - statistics, schema, None, None, None, validation_config - ) - self._assert_equal_anomalies(anomalies, expected_anomalies) - def test_validate_stats_internal_with_previous_version_stats(self): statistics = text_format.Parse( """ @@ -2591,219 +2510,6 @@ def test_validate_stats_internal_with_validation_options_set(self): # pylint: enable=line-too-long - def test_custom_validate_statistics_single_feature(self): - statistics = text_format.Parse( - """ - datasets{ - num_examples: 10 - features { - path { step: 'annotated_enum' } - type: STRING - string_stats { - common_stats { - num_missing: 3 - num_non_missing: 7 - min_num_values: 1 - max_num_values: 1 - } - unique: 3 - rank_histogram { - buckets { - label: "D" - sample_count: 1 - } - } - } - } - } - """, - statistics_pb2.DatasetFeatureStatisticsList(), - ) - config = text_format.Parse( - """ - feature_validations { - feature_path { step: 'annotated_enum' } - validations { - sql_expression: 'feature.string_stats.common_stats.num_missing < 3' - severity: ERROR - description: 'Feature has too many missing.' - } - } - """, - custom_validation_config_pb2.CustomValidationConfig(), - ) - expected_anomalies = { - "annotated_enum": text_format.Parse( - """ - path { step: 'annotated_enum' } - severity: ERROR - reason { - type: CUSTOM_VALIDATION - short_description: 'Feature has too many missing.' - description: 'Custom validation triggered anomaly. Query: feature.string_stats.common_stats.num_missing < 3 Test dataset: default slice' - } - """, - anomalies_pb2.AnomalyInfo(), - ) - } - anomalies = validation_api.custom_validate_statistics(statistics, config) - self._assert_equal_anomalies(anomalies, expected_anomalies) - - def test_custom_validate_statistics_two_features(self): - test_statistics = text_format.Parse( - """ - datasets{ - num_examples: 10 - features { - path { step: 'annotated_enum' } - type: STRING - string_stats { - common_stats { - num_missing: 3 - num_non_missing: 7 - min_num_values: 1 - max_num_values: 1 - } - unique: 10 - rank_histogram { - buckets { - label: "D" - sample_count: 1 - } - } - } - } - } - """, - statistics_pb2.DatasetFeatureStatisticsList(), - ) - base_statistics = text_format.Parse( - """ - datasets{ - num_examples: 10 - features { - path { step: 'annotated_enum' } - type: STRING - string_stats { - common_stats { - num_missing: 3 - num_non_missing: 7 - min_num_values: 1 - max_num_values: 1 - } - unique: 5 - rank_histogram { - buckets { - label: "D" - sample_count: 1 - } - } - } - } - } - """, - statistics_pb2.DatasetFeatureStatisticsList(), - ) - config = text_format.Parse( - """ - feature_pair_validations { - feature_test_path { step: 'annotated_enum' } - feature_base_path { step: 'annotated_enum' } - validations { - sql_expression: 'feature_test.string_stats.unique = feature_base.string_stats.unique' - severity: ERROR - description: 'Test and base do not have same number of uniques.' - } - } - """, - custom_validation_config_pb2.CustomValidationConfig(), - ) - expected_anomalies = { - "annotated_enum": text_format.Parse( - """ - path { step: 'annotated_enum' } - severity: ERROR - reason { - type: CUSTOM_VALIDATION - short_description: 'Test and base do not have same number of uniques.' - description: 'Custom validation triggered anomaly. 
Query: feature_test.string_stats.unique = feature_base.string_stats.unique Test dataset: default slice Base dataset: Base path: annotated_enum' - } - """, - anomalies_pb2.AnomalyInfo(), - ) - } - anomalies = validation_api.custom_validate_statistics( - test_statistics, config, base_statistics - ) - self._assert_equal_anomalies(anomalies, expected_anomalies) - - def test_custom_validate_statistics_environment(self): - statistics = text_format.Parse( - """ - datasets{ - num_examples: 10 - features { - path { step: 'some_feature' } - type: STRING - string_stats { - common_stats { - num_missing: 3 - num_non_missing: 7 - min_num_values: 1 - max_num_values: 1 - } - unique: 10 - rank_histogram { - buckets { - label: "D" - sample_count: 1 - } - } - } - } - } - """, - statistics_pb2.DatasetFeatureStatisticsList(), - ) - config = text_format.Parse( - """ - feature_validations { - feature_path { step: 'some_feature' } - validations { - sql_expression: 'feature.string_stats.common_stats.num_missing < 1' - severity: ERROR - description: 'Too many missing' - in_environment: 'TRAINING' - } - validations { - sql_expression: 'feature.string_stats.common_stats.num_missing > 5' - severity: ERROR - description: 'Too few missing' - in_environment: 'SERVING' - } - } - """, - custom_validation_config_pb2.CustomValidationConfig(), - ) - expected_anomalies = { - "some_feature": text_format.Parse( - """ - path { step: 'some_feature' } - severity: ERROR - reason { - type: CUSTOM_VALIDATION - short_description: 'Too many missing' - description: 'Custom validation triggered anomaly. Query: feature.string_stats.common_stats.num_missing < 1 Test dataset: default slice' - } - """, - anomalies_pb2.AnomalyInfo(), - ) - } - anomalies = validation_api.custom_validate_statistics( - statistics, config, None, "TRAINING" - ) - self._assert_equal_anomalies(anomalies, expected_anomalies) - def test_validate_instance(self): instance = pa.RecordBatch.from_arrays([pa.array([["D"]])], ["annotated_enum"]) schema = text_format.Parse( diff --git a/tensorflow_data_validation/constants.py b/tensorflow_data_validation/constants.py index 28bef2e6..aa966b2f 100644 --- a/tensorflow_data_validation/constants.py +++ b/tensorflow_data_validation/constants.py @@ -17,9 +17,7 @@ from tfx_bsl.telemetry import util # Name of the default slice containing all examples. -# LINT.IfChange DEFAULT_SLICE_KEY = "All Examples" -# LINT.ThenChange(../anomalies/custom_validation.cc) # Name of the invalid slice containing all examples in the RecordBatch. 
INVALID_SLICE_KEY = "Invalid Slice" diff --git a/tensorflow_data_validation/move_generated_files.sh b/tensorflow_data_validation/move_generated_files.sh index 99629c07..08ce5abe 100755 --- a/tensorflow_data_validation/move_generated_files.sh +++ b/tensorflow_data_validation/move_generated_files.sh @@ -25,8 +25,6 @@ function tfdv::move_generated_files() { RUNFILES_DIR=$(pwd) cp -f ${RUNFILES_DIR}/tensorflow_data_validation/skew/protos/feature_skew_results_pb2.py \ ${BUILD_WORKSPACE_DIRECTORY}/tensorflow_data_validation/skew/protos - cp -f ${RUNFILES_DIR}/tensorflow_data_validation/anomalies/proto/custom_validation_config_pb2.py \ - ${BUILD_WORKSPACE_DIRECTORY}/tensorflow_data_validation/anomalies/proto cp -f ${RUNFILES_DIR}/tensorflow_data_validation/anomalies/proto/validation_config_pb2.py \ ${BUILD_WORKSPACE_DIRECTORY}/tensorflow_data_validation/anomalies/proto cp -f ${RUNFILES_DIR}/tensorflow_data_validation/anomalies/proto/validation_metadata_pb2.py \ diff --git a/tensorflow_data_validation/pywrap/BUILD b/tensorflow_data_validation/pywrap/BUILD index 77ae1277..95ef6ea8 100644 --- a/tensorflow_data_validation/pywrap/BUILD +++ b/tensorflow_data_validation/pywrap/BUILD @@ -41,7 +41,6 @@ cc_library( ], features = ["-use_header_modules"], deps = [ - "//tensorflow_data_validation/anomalies:custom_validation", "//tensorflow_data_validation/anomalies:feature_statistics_validator", "@pybind11", ], diff --git a/tensorflow_data_validation/pywrap/validation_submodule.cc b/tensorflow_data_validation/pywrap/validation_submodule.cc index 67770709..3eb71050 100644 --- a/tensorflow_data_validation/pywrap/validation_submodule.cc +++ b/tensorflow_data_validation/pywrap/validation_submodule.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "tensorflow_data_validation/pywrap/validation_submodule.h" -#include "tensorflow_data_validation/anomalies/custom_validation.h" #include "tensorflow_data_validation/anomalies/feature_statistics_validator.h" #include "include/pybind11/pybind11.h" @@ -77,22 +76,6 @@ void DefineValidationSubmodule(py::module main_module) { } return py::bytes(anomalies_proto_string); }); - m.def("CustomValidateStatistics", - [](const std::string& test_statistics_string, - const std::string& base_statistics_string, - const std::string& validations_string, - const std::string& environment_string) -> py::object { - std::string anomalies_proto_string; - const absl::Status status = - CustomValidateStatisticsWithSerializedInputs( - test_statistics_string, base_statistics_string, - validations_string, environment_string, - &anomalies_proto_string); - if (!status.ok()) { - throw std::runtime_error(status.ToString()); - } - return py::bytes(anomalies_proto_string); - }); } } // namespace data_validation diff --git a/tensorflow_data_validation/statistics/stats_impl.py b/tensorflow_data_validation/statistics/stats_impl.py index 7754eb0e..a5a2d1ff 100644 --- a/tensorflow_data_validation/statistics/stats_impl.py +++ b/tensorflow_data_validation/statistics/stats_impl.py @@ -98,16 +98,11 @@ def expand( ) if self._options.slicing_config: - slice_fns, slice_sqls = ( - slicing_util.convert_slicing_config_to_slice_functions_and_sqls( - self._options.slicing_config - ) + slice_fns = slicing_util.convert_slicing_config_to_slice_functions( + self._options.slicing_config ) else: - slice_fns, slice_sqls = ( - self._options.experimental_slice_functions, - self._options.experimental_slice_sqls, - ) + slice_fns = self._options.experimental_slice_functions if slice_fns: # Add default slicing function. 
@@ -116,10 +111,6 @@ def expand( dataset = dataset | "GenerateSliceKeys" >> beam.FlatMap( slicing_util.generate_slices, slice_functions=slice_functions ) - elif slice_sqls: - dataset = dataset | "GenerateSlicesSql" >> beam.ParDo( - slicing_util.GenerateSlicesSqlDoFn(slice_sqls=slice_sqls) - ) else: dataset = dataset | "KeyWithVoid" >> beam.Map(lambda v: (None, v)) _ = dataset | "TrackDistinctSliceKeys" >> _TrackDistinctSliceKeys() # pylint: disable=no-value-for-parameter @@ -234,15 +225,14 @@ def __init__( ---- options: `tfdv.StatsOptions` for generating data statistics. is_slicing_enabled: Whether to include slice keys in the resulting proto, - even if slice functions or slicing SQL queries are not provided in - `options`. If slice functions or slicing SQL queries are provided in + even if slice functions are not provided in + `options`. If slice functions are provided in `options`, slice keys are included regardless of this value. """ self._options = options self._is_slicing_enabled = ( is_slicing_enabled or bool(self._options.experimental_slice_functions) - or bool(self._options.experimental_slice_sqls) or bool(self._options.slicing_config) ) diff --git a/tensorflow_data_validation/statistics/stats_impl_test.py b/tensorflow_data_validation/statistics/stats_impl_test.py index e994e7ed..dbd9b38c 100644 --- a/tensorflow_data_validation/statistics/stats_impl_test.py +++ b/tensorflow_data_validation/statistics/stats_impl_test.py @@ -2057,56 +2057,6 @@ def extract_output(self, accumulator): ] -_SLICING_SQL_TESTS = [ - { - "testcase_name": "feature_value_slicing_slice_sqls", - "record_batches": [ - pa.RecordBatch.from_arrays( - [ - pa.array([[1.0, 2.0]], type=pa.list_(pa.float32())), - pa.array([[b"a"]], type=pa.list_(pa.binary())), - pa.array([np.linspace(1, 500, 500, dtype=np.int64)]), - ], - ["a", "b", "c"], - ), - pa.RecordBatch.from_arrays( - [ - pa.array([[3.0, 4.0, np.nan, 5.0]], type=pa.list_(pa.float32())), - pa.array([[b"a", b"b"]], type=pa.list_(pa.binary())), - pa.array([np.linspace(501, 1250, 750, dtype=np.int64)]), - ], - ["a", "b", "c"], - ), - pa.RecordBatch.from_arrays( - [ - pa.array([[1.0]], type=pa.list_(pa.float32())), - pa.array([[b"b"]], type=pa.list_(pa.binary())), - pa.array([np.linspace(1251, 3000, 1750, dtype=np.int64)]), - ], - ["a", "b", "c"], - ), - ], - "options": stats_options.StatsOptions( - experimental_slice_sqls=[ - """ - SELECT - STRUCT(b) - FROM - example.b - """ - ], - num_top_values=2, - num_rank_histogram_buckets=2, - num_values_histogram_buckets=2, - num_histogram_buckets=2, - num_quantiles_histogram_buckets=2, - enable_semantic_domain_stats=True, - ), - "expected_result_proto_text": _SLICED_STATS_TEST_RESULT, - }, -] - - def _get_singleton_dataset( statistics: statistics_pb2.DatasetFeatureStatisticsList, ) -> statistics_pb2.DatasetFeatureStatistics: @@ -2209,126 +2159,6 @@ def test_stats_impl( ), ) - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") - def test_stats_impl_slicing_sql(self): - record_batches = [ - pa.RecordBatch.from_arrays( - [ - pa.array([[1.0, 2.0]], type=pa.list_(pa.float32())), - pa.array([[b"a"]], type=pa.list_(pa.binary())), - pa.array([np.linspace(1, 500, 500, dtype=np.int64)]), - ], - ["a", "b", "c"], - ), - pa.RecordBatch.from_arrays( - [ - pa.array([[3.0, 4.0, np.nan, 5.0]], type=pa.list_(pa.float32())), - pa.array([[b"a", b"b"]], type=pa.list_(pa.binary())), - pa.array([np.linspace(501, 1250, 750, dtype=np.int64)]), - ], - ["a", "b", "c"], - ), - pa.RecordBatch.from_arrays( - [ - 
pa.array([[1.0]], type=pa.list_(pa.float32())), - pa.array([[b"b"]], type=pa.list_(pa.binary())), - pa.array([np.linspace(1251, 3000, 1750, dtype=np.int64)]), - ], - ["a", "b", "c"], - ), - ] - options = stats_options.StatsOptions( - experimental_slice_sqls=[ - """ - SELECT - STRUCT(b) - FROM - example.b - """ - ], - num_top_values=2, - num_rank_histogram_buckets=2, - num_values_histogram_buckets=2, - num_histogram_buckets=2, - num_quantiles_histogram_buckets=2, - enable_semantic_domain_stats=True, - ) - expected_result = text_format.Parse( - _SLICED_STATS_TEST_RESULT, statistics_pb2.DatasetFeatureStatisticsList() - ) - with beam.Pipeline() as p: - result = ( - p - | beam.Create(record_batches, reshuffle=False) - | stats_impl.GenerateStatisticsImpl(options) - ) - util.assert_that( - result, - test_util.make_dataset_feature_stats_list_proto_equal_fn( - self, expected_result, check_histograms=False - ), - ) - - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") - def test_stats_impl_slicing_sql_in_config(self): - record_batches = [ - pa.RecordBatch.from_arrays( - [ - pa.array([[1.0, 2.0]], type=pa.list_(pa.float32())), - pa.array([[b"a"]], type=pa.list_(pa.binary())), - pa.array([np.linspace(1, 500, 500, dtype=np.int64)]), - ], - ["a", "b", "c"], - ), - pa.RecordBatch.from_arrays( - [ - pa.array([[3.0, 4.0, np.nan, 5.0]], type=pa.list_(pa.float32())), - pa.array([[b"a", b"b"]], type=pa.list_(pa.binary())), - pa.array([np.linspace(501, 1250, 750, dtype=np.int64)]), - ], - ["a", "b", "c"], - ), - pa.RecordBatch.from_arrays( - [ - pa.array([[1.0]], type=pa.list_(pa.float32())), - pa.array([[b"b"]], type=pa.list_(pa.binary())), - pa.array([np.linspace(1251, 3000, 1750, dtype=np.int64)]), - ], - ["a", "b", "c"], - ), - ] - options = stats_options.StatsOptions( - slicing_config=text_format.Parse( - """ - slicing_specs { - slice_keys_sql: "SELECT STRUCT(b) FROM example.b" - } - """, - slicing_spec_pb2.SlicingConfig(), - ), - num_top_values=2, - num_rank_histogram_buckets=2, - num_values_histogram_buckets=2, - num_histogram_buckets=2, - num_quantiles_histogram_buckets=2, - enable_semantic_domain_stats=True, - ) - expected_result = text_format.Parse( - _SLICED_STATS_TEST_RESULT, statistics_pb2.DatasetFeatureStatisticsList() - ) - with beam.Pipeline() as p: - result = ( - p - | beam.Create(record_batches, reshuffle=False) - | stats_impl.GenerateStatisticsImpl(options) - ) - util.assert_that( - result, - test_util.make_dataset_feature_stats_list_proto_equal_fn( - self, expected_result, check_histograms=False - ), - ) - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") def test_nld_features(self): record_batches = [pa.RecordBatch.from_arrays([pa.array([[1]])], ["f1"])] diff --git a/tensorflow_data_validation/statistics/stats_options.py b/tensorflow_data_validation/statistics/stats_options.py index 7f597b8f..55426efa 100644 --- a/tensorflow_data_validation/statistics/stats_options.py +++ b/tensorflow_data_validation/statistics/stats_options.py @@ -16,14 +16,11 @@ import copy import json -import logging import types as python_types from typing import Dict, List, Optional, Union from google.protobuf import json_format from tensorflow_metadata.proto.v0 import schema_pb2 -from tfx_bsl.arrow import sql_util -from tfx_bsl.coders import example_coder from tfx_bsl.public.proto import slicing_spec_pb2 from tensorflow_data_validation import types @@ -31,7 +28,6 @@ from tensorflow_data_validation.utils import ( example_weight_map, schema_util, - slicing_util, 
) _SCHEMA_JSON_KEY = "schema_json" @@ -81,7 +77,6 @@ def __init__( experimental_use_sketch_based_topk_uniques: Optional[bool] = None, use_sketch_based_topk_uniques: Optional[bool] = None, experimental_slice_functions: Optional[List[types.SliceFunction]] = None, - experimental_slice_sqls: Optional[List[str]] = None, experimental_result_partitions: int = 1, experimental_num_feature_partitions: int = 1, slicing_config: Optional[slicing_spec_pb2.SlicingConfig] = None, @@ -166,25 +161,8 @@ def __init__( pyarrow.RecordBatch as input and return an Iterable[Tuple[Text, pyarrow.RecordBatch]]. Each tuple contains the slice key and the corresponding sliced RecordBatch. Only one of - experimental_slice_functions or experimental_slice_sqls must be + experimental_slice_functions or slicing_config may be specified. - experimental_slice_sqls: List of slicing SQL queries. The query must have - the following pattern: "SELECT STRUCT({feature_name} [AS {slice_key}]) - [FROM example.feature_name [, example.feature_name, ... ] [WHERE ... ]]" - The “example.feature_name” inside the FROM statement is used to flatten - the repeated fields. For non-repeated fields, you can directly write the - query as follows: “SELECT STRUCT(non_repeated_feature_a, - non_repeated_feature_b)” In the query, the “example” is a key word that - binds to each input "row". The semantics of this variable will depend on - the decoding of the input data to the Arrow representation (e.g., for - tf.Example, each key is decoded to a separate column). Thus, structured - data can be readily accessed by iterating/unnesting the fields of the - "example" variable. Example 1: Slice on each value of a feature "SELECT - STRUCT(gender) FROM example.gender" Example 2: Slice on each value of - one feature and a specified value of another. "SELECT STRUCT(gender, - country) FROM example.gender, example.country WHERE country = 'USA'" - Only one of experimental_slice_functions or experimental_slice_sqls must - be specified. experimental_result_partitions: The number of feature partitions to combine output DatasetFeatureStatisticsLists into. If set to 1 (default) output is globally combined. If set to value greater than one, up to @@ -195,7 +173,7 @@ number of features in a dataset, and never more than the available beam parallelism. slicing_config: an optional SlicingConfig. SlicingConfig includes - slicing_specs specified with feature keys, feature values or slicing SQL + slicing_specs specified with feature keys or feature values rather than SQL queries. 
experimental_filter_read_paths: If provided, tries to push down either paths passed via feature_allowlist or via the schema (in that priority) @@ -246,7 +224,6 @@ def __init__( self.use_sketch_based_topk_uniques = True else: self.use_sketch_based_topk_uniques = False - self.experimental_slice_sqls = experimental_slice_sqls self.experimental_num_feature_partitions = experimental_num_feature_partitions self.experimental_result_partitions = experimental_result_partitions self.slicing_config = slicing_config @@ -424,8 +401,6 @@ def experimental_slice_functions(self) -> Optional[List[types.SliceFunction]]: def experimental_slice_functions( self, slice_functions: Optional[List[types.SliceFunction]] ) -> None: - if hasattr(self, "experimental_slice_sqls"): - _validate_slicing_options(slice_functions, self.experimental_slice_sqls) if slice_functions is not None: if not isinstance(slice_functions, list): raise TypeError( @@ -439,19 +414,6 @@ def experimental_slice_functions( ) self._slice_functions = slice_functions - @property - def experimental_slice_sqls(self) -> Optional[List[str]]: - return self._slice_sqls - - @experimental_slice_sqls.setter - def experimental_slice_sqls(self, slice_sqls: Optional[List[str]]) -> None: - if hasattr(self, "experimental_slice_functions"): - _validate_slicing_options(self.experimental_slice_functions, slice_sqls) - if slice_sqls and self.schema: - for slice_sql in slice_sqls: - _validate_sql(slice_sql, self.schema) - self._slice_sqls = slice_sqls - @property def slicing_config(self) -> Optional[slicing_spec_pb2.SlicingConfig]: return self._slicing_config @@ -460,18 +422,11 @@ def slicing_config(self) -> Optional[slicing_spec_pb2.SlicingConfig]: def slicing_config( self, slicing_config: Optional[slicing_spec_pb2.SlicingConfig] ) -> None: - _validate_slicing_config(slicing_config) - if slicing_config and self.experimental_slice_functions: raise ValueError( "Specify only one of slicing_config or experimental_slice_functions." ) - if slicing_config and self.experimental_slice_sqls: - raise ValueError( - "Specify only one of slicing_config or experimental_slice_sqls." - ) - self._slicing_config = slicing_config @property @@ -638,67 +593,3 @@ def per_feature_stats_config( self, features_config: types.PerFeatureStatsConfig ) -> None: self._per_feature_stats_config = features_config - - -def _validate_sql(sql_query: str, schema: schema_pb2.Schema): - arrow_schema = example_coder.ExamplesToRecordBatchDecoder( - schema.SerializeToString() - ).ArrowSchema() - formatted_query = slicing_util.format_slice_sql_query(sql_query) - try: - sql_util.RecordBatchSQLSliceQuery(formatted_query, arrow_schema) - except Exception as e: # pylint: disable=broad-except - # The schema passed to TFDV initially may be incomplete, so we can't crash - # on what may be an error caused by missing features. - logging.error( - "One of the slice SQL query %s raised an exception: %s.", sql_query, repr(e) - ) - - -def _validate_slicing_options( - slice_fns: Optional[List[types.SliceFunction]] = None, - slice_sqls: Optional[List[str]] = None, -): - if slice_fns and slice_sqls: - raise ValueError( - "Only one of experimental_slice_functions or " - "experimental_slice_sqls must be specified." - ) - - -def _validate_slicing_config(slicing_config: Optional[slicing_spec_pb2.SlicingConfig]): - """Validates slicing config. - - Args: - ---- - slicing_config: an optional list of slicing specifications. Slicing - specifications can be provided by feature keys, feature values or slicing - SQL queries. 
- - Returns: - ------- - None if slicing_config is None. - - Raises: - ------ - ValueError: If both slicing functions and slicing sql queries are specified - in the slicing config. - """ - if slicing_config is None: - return - - has_slice_fns, has_slice_sqls = False, False - - for slicing_spec in slicing_config.slicing_specs: - if (not has_slice_fns) and ( - slicing_spec.feature_keys or slicing_spec.feature_values - ): - has_slice_fns = True - if (not has_slice_sqls) and slicing_spec.slice_keys_sql: - has_slice_sqls = True - - if has_slice_fns and has_slice_sqls: - raise ValueError( - "Only one of slicing features or slicing sql queries can be " - "specified in the slicing config." - ) diff --git a/tensorflow_data_validation/statistics/stats_options_test.py b/tensorflow_data_validation/statistics/stats_options_test.py index f9c601e7..5ea5d735 100644 --- a/tensorflow_data_validation/statistics/stats_options_test.py +++ b/tensorflow_data_validation/statistics/stats_options_test.py @@ -172,15 +172,6 @@ "use_sketch_based_topk_uniques" ), }, - { - "testcase_name": "both_slice_fns_and_slice_sqls_specified", - "stats_options_kwargs": { - "experimental_slice_functions": [lambda x: (None, x)], - "experimental_slice_sqls": [""], - }, - "exception_type": ValueError, - "error_message": "Only one of experimental_slice_functions or", - }, { "testcase_name": "both_slicing_config_and_slice_fns_specified", "stats_options_kwargs": { @@ -197,41 +188,6 @@ "exception_type": ValueError, "error_message": "Specify only one of slicing_config or experimental_slice_functions.", }, - { - "testcase_name": "both_slicing_config_and_slice_sqls_specified", - "stats_options_kwargs": { - "experimental_slice_sqls": [""], - "slicing_config": text_format.Parse( - """ - slicing_specs { - feature_keys: ["country", "city"] - } - """, - slicing_spec_pb2.SlicingConfig(), - ), - }, - "exception_type": ValueError, - "error_message": "Specify only one of slicing_config or experimental_slice_sqls.", - }, - { - "testcase_name": "both_functions_and_sqls_in_slicing_config", - "stats_options_kwargs": { - "slicing_config": text_format.Parse( - """ - slicing_specs { - feature_keys: ["country", "city"] - } - slicing_specs { - slice_keys_sql: "SELECT STRUCT(education) FROM example.education" - } - """, - slicing_spec_pb2.SlicingConfig(), - ), - }, - "exception_type": ValueError, - "error_message": "Only one of slicing features or slicing sql queries can be " - "specified in the slicing config.", - }, ] @@ -241,30 +197,6 @@ def test_stats_options(self, stats_options_kwargs, exception_type, error_message with self.assertRaisesRegex(exception_type, error_message): stats_options.StatsOptions(**stats_options_kwargs) - def test_stats_options_invalid_slicing_sql_query(self): - schema = schema_pb2.Schema( - feature=[ - schema_pb2.Feature(name="feat1", type=schema_pb2.BYTES), - schema_pb2.Feature(name="feat3", type=schema_pb2.INT), - ], - ) - experimental_slice_sqls = [ - """ - SELECT - STRUCT(feat1, feat2) - FROM - example.feat1, example.feat2 - """ - ] - with self.assertLogs(level="ERROR") as log_output: - stats_options.StatsOptions( - experimental_slice_sqls=experimental_slice_sqls, schema=schema - ) - self.assertLen(log_output.records, 1) - self.assertRegex( - log_output.records[0].message, "One of the slice SQL query .*" - ) - def test_valid_stats_options_json_round_trip(self): feature_allowlist = ["a"] schema = schema_pb2.Schema(feature=[schema_pb2.Feature(name="f")]) @@ -426,7 +358,6 @@ def test_stats_options_from_json( 
"_per_feature_weight_override": null, "_add_default_generators": true, "_use_sketch_based_topk_uniques": false, - "_slice_sqls": null, "_experimental_result_partitions": 1, "_experimental_num_feature_partitions": 1, "_slicing_config": null, diff --git a/tensorflow_data_validation/utils/slicing_util.py b/tensorflow_data_validation/utils/slicing_util.py index b50ee8af..eae11d3c 100644 --- a/tensorflow_data_validation/utils/slicing_util.py +++ b/tensorflow_data_validation/utils/slicing_util.py @@ -13,13 +13,11 @@ # limitations under the License. """Utility function for generating slicing functions.""" -import collections import functools import logging from collections import abc -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Union -import apache_beam as beam import numpy as np import pandas as pd @@ -29,7 +27,7 @@ import pyarrow as pa import six from tensorflow_metadata.proto.v0 import statistics_pb2 -from tfx_bsl.arrow import array_util, sql_util, table_util +from tfx_bsl.arrow import array_util, table_util from tfx_bsl.public.proto import slicing_spec_pb2 from tensorflow_data_validation import constants, types @@ -155,41 +153,41 @@ def feature_value_slicer( "the provided slice values are not valid integers." ) from e - flattened, value_parent_indices = array_util.flatten_nested( - feature_array, True - ) - non_missing_values = np.asarray(flattened) - # Create dataframe with feature value and parent index. - df = pd.DataFrame( - { - feature_name: non_missing_values, - _PARENT_INDEX_COLUMN: value_parent_indices, - } - ) - df = df.drop_duplicates() - # Filter based on slice values - if values is not None: - df = df.loc[df[feature_name].isin(values)] - per_feature_parent_indices.append(df) - # If there are no features to slice on, yield no output. - # TODO(b/200081813): Produce output with an appropriate placeholder key. - if not per_feature_parent_indices: - return - # Join dataframes based on parent indices. - # Note that we want the parent indices per slice key to be sorted in the - # merged dataframe. The individual dataframes have the parent indices in - # sorted order. We use "inner" join type to preserve the order of the left - # keys (also note that same parent index rows would be consecutive). Hence - # we expect the merged dataframe to have sorted parent indices per - # slice key. - merged_df = functools.reduce( - lambda base, update: base.merge( - update, - how="inner", # pylint: disable=g-long-lambda - on=_PARENT_INDEX_COLUMN, - ), - per_feature_parent_indices, - ) + flattened, value_parent_indices = array_util.flatten_nested( + feature_array, True + ) + non_missing_values = np.asarray(flattened) + # Create dataframe with feature value and parent index. + df = pd.DataFrame( + { + feature_name: non_missing_values, + _PARENT_INDEX_COLUMN: value_parent_indices, + } + ) + df = df.drop_duplicates() + # Filter based on slice values + if values is not None: + df = df.loc[df[feature_name].isin(values)] + per_feature_parent_indices.append(df) + # If there are no features to slice on, yield no output. + # TODO(b/200081813): Produce output with an appropriate placeholder key. + if not per_feature_parent_indices: + return + # Join dataframes based on parent indices. + # Note that we want the parent indices per slice key to be sorted in the + # merged dataframe. The individual dataframes have the parent indices in + # sorted order. 
We use "inner" join type to preserve the order of the left + # keys (also note that same parent index rows would be consecutive). Hence + # we expect the merged dataframe to have sorted parent indices per + # slice key. + merged_df = functools.reduce( + lambda base, update: base.merge( + update, + how="inner", # pylint: disable=g-long-lambda + on=_PARENT_INDEX_COLUMN, + ), + per_feature_parent_indices, + ) # Construct a new column in the merged dataframe with the slice keys. merged_df[_SLICE_KEY_COLUMN] = "" @@ -267,42 +265,26 @@ def generate_slices( ) -def format_slice_sql_query(slice_sql_query: str) -> str: - return f""" - SELECT - ARRAY( - {slice_sql_query} - ) as slice_key - FROM Examples as example;""" - - -def convert_slicing_config_to_slice_functions_and_sqls( +def convert_slicing_config_to_slice_functions( slicing_config: Optional[slicing_spec_pb2.SlicingConfig], -) -> Tuple[List[types.SliceFunction], List[str]]: - """Convert slicing config to a tuple of slice functions and sql queries. +) -> List[types.SliceFunction]: + """Convert slicing config to a list of slice functions. Args: ---- slicing_config: an optional list of slicing specifications. Slicing - specifications can be provided by feature keys, feature values or slicing - SQL queries. + specifications can be provided by feature keys, or feature values Returns: ------- - A tuple consisting of a list of slice functions and a list of slice sql - queries. + A list of slice functions. """ if not slicing_config: - return [], [] + return [] slice_function_list = [] - slice_keys_sql_list = [] for slicing_spec in slicing_config.slicing_specs: # checking overall slice - if ( - not slicing_spec.feature_keys - and not slicing_spec.feature_values - and not slicing_spec.slice_keys_sql - ): + if not slicing_spec.feature_keys and not slicing_spec.feature_values: logging.info("The entire dataset is already included as a slice.") continue @@ -315,83 +297,4 @@ def convert_slicing_config_to_slice_functions_and_sqls( if slice_spec_dict: slice_function_list.append(get_feature_value_slicer(slice_spec_dict)) - if slicing_spec.slice_keys_sql: - slice_keys_sql_list.append(slicing_spec.slice_keys_sql) - - return slice_function_list, slice_keys_sql_list - - -class GenerateSlicesSqlDoFn(beam.DoFn): - """A DoFn that extracts slice keys in batch based on input SQL.""" - - def __init__(self, slice_sqls: List[str]): - self._sqls = [format_slice_sql_query(slice_sql) for slice_sql in slice_sqls] - self._sql_slicer_schema_cache_hits = beam.metrics.Metrics.distribution( - constants.METRICS_NAMESPACE, "sql_slicer_schema_cache_hits" - ) - self._sql_slicer_schema_cache_misses = beam.metrics.Metrics.distribution( - constants.METRICS_NAMESPACE, "sql_slicer_schema_cache_misses" - ) - - def setup(self): - def _generate_queries( - schema: pa.Schema, - ) -> List[sql_util.RecordBatchSQLSliceQuery]: - queries = [] - for sql in self._sqls: - try: - queries.append(sql_util.RecordBatchSQLSliceQuery(sql, schema)) - except RuntimeError as error: - # We can't crash on errors caused by missing features/values. - # Instead failed slicing sqls will create a Invalid Slice. - logging.warning("Failed to parse SQL query %r: %r", sql, error) - queries.append(None) - return queries - - # A cache for compiled sql queries, keyed by record batch schemas. - # This way we can work with record batches of different schemas. 
- self._get_queries_for_schema = functools.lru_cache(maxsize=3)(_generate_queries) - - def process( - self, record_batch: pa.RecordBatch - ) -> Iterable[types.SlicedRecordBatch]: - # Keep track of row indices per slice key. - per_slice_indices = collections.defaultdict(set) - if record_batch.schema.metadata is not None: - # record_batch may have unhashable schema metadata if derived features are - # being used, so we construct a new schema that strips that information. - cache_schema = pa.schema( - zip(record_batch.schema.names, record_batch.schema.types) - ) - else: - cache_schema = record_batch.schema - for query in self._get_queries_for_schema(cache_schema): - # Example of result with batch size = 3: - # result = [[[('feature', 'value_1')]], - # [[('feature', 'value_2')]], - # [] - # ] - if query is None: - yield (constants.INVALID_SLICE_KEY, record_batch) - continue - - result = query.Execute(record_batch) - for i, per_row_slices in enumerate(result): - for slice_tuples in per_row_slices: - slice_key = "_".join(map("_".join, slice_tuples)) - per_slice_indices[slice_key].add(i) - - yield (constants.DEFAULT_SLICE_KEY, record_batch) - for slice_key, row_indices in per_slice_indices.items(): - yield ( - slice_key, - table_util.RecordBatchTake(record_batch, pa.array(row_indices)), - ) - - def teardown(self): - self._sql_slicer_schema_cache_hits.update( - self._get_queries_for_schema.cache_info().hits - ) - self._sql_slicer_schema_cache_misses.update( - self._get_queries_for_schema.cache_info().misses - ) + return slice_function_list diff --git a/tensorflow_data_validation/utils/slicing_util_test.py b/tensorflow_data_validation/utils/slicing_util_test.py index 7675b0d5..7995bf9e 100644 --- a/tensorflow_data_validation/utils/slicing_util_test.py +++ b/tensorflow_data_validation/utils/slicing_util_test.py @@ -13,15 +13,11 @@ # limitations under the License. 
"""Tests for the slicing utilities.""" -import apache_beam as beam import pyarrow as pa -import pytest from absl.testing import absltest -from apache_beam.testing import util from google.protobuf import text_format from tfx_bsl.public.proto import slicing_spec_pb2 -from tensorflow_data_validation import constants from tensorflow_data_validation.utils import slicing_util @@ -233,26 +229,7 @@ def test_get_feature_value_slicer_non_utf8_slice_key(self): slicing_util.get_feature_value_slicer(features)(input_record_batch) ) - def test_convert_slicing_config_to_fns_and_sqls(self): - slicing_config = text_format.Parse( - """ - slicing_specs { - slice_keys_sql: "SELECT STRUCT(education) FROM example.education" - } - """, - slicing_spec_pb2.SlicingConfig(), - ) - - slicing_fns, slicing_sqls = ( - slicing_util.convert_slicing_config_to_slice_functions_and_sqls( - slicing_config - ) - ) - self.assertEqual(slicing_fns, []) - self.assertEqual( - slicing_sqls, ["SELECT STRUCT(education) FROM example.education"] - ) - + def test_convert_slicing_config_to_fns(self): slicing_config = text_format.Parse( """ slicing_specs {} @@ -267,13 +244,10 @@ def test_convert_slicing_config_to_fns_and_sqls(self): slicing_spec_pb2.SlicingConfig(), ) - slicing_fns, slicing_sqls = ( - slicing_util.convert_slicing_config_to_slice_functions_and_sqls( - slicing_config - ) + slicing_fns = slicing_util.convert_slicing_config_to_slice_functions( + slicing_config ) self.assertLen(slicing_fns, 2) - self.assertEqual(slicing_sqls, []) slicing_config = text_format.Parse( """ @@ -298,14 +272,12 @@ def test_convert_slicing_config_to_fns_and_sqls(self): ), ), ] - slicing_fns, slicing_sqls = ( - slicing_util.convert_slicing_config_to_slice_functions_and_sqls( - slicing_config - ) + slicing_fns = slicing_util.convert_slicing_config_to_slice_functions( + slicing_config ) self._check_results(slicing_fns[0](input_record_batch), expected_result) - def test_convert_slicing_config_to_fns_and_sqls_on_int_field(self): + def test_convert_slicing_config_to_fns_on_int_field(self): slicing_config = text_format.Parse( """ slicing_specs { @@ -329,14 +301,12 @@ def test_convert_slicing_config_to_fns_and_sqls_on_int_field(self): ), ), ] - slicing_fns, _ = ( - slicing_util.convert_slicing_config_to_slice_functions_and_sqls( - slicing_config - ) + slicing_fns = slicing_util.convert_slicing_config_to_slice_functions( + slicing_config ) self._check_results(slicing_fns[0](input_record_batch), expected_result) - def test_convert_slicing_config_to_fns_and_sqls_on_int_invalid(self): + def test_convert_slicing_config_to_fns_on_int_invalid(self): slicing_config = text_format.Parse( """ slicing_specs { @@ -361,10 +331,8 @@ def test_convert_slicing_config_to_fns_and_sqls_on_int_invalid(self): ), ), ] - slicing_fns, _ = ( - slicing_util.convert_slicing_config_to_slice_functions_and_sqls( - slicing_config - ) + slicing_fns = slicing_util.convert_slicing_config_to_slice_functions( + slicing_config ) with self.assertRaisesRegex( @@ -372,291 +340,6 @@ def test_convert_slicing_config_to_fns_and_sqls_on_int_invalid(self): ): self._check_results(slicing_fns[0](input_record_batch), expected_result) - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") - def test_generate_slices_sql(self): - input_record_batches = [ - pa.RecordBatch.from_arrays( - [ - pa.array([[1], [2, 1], [3], [2, 1, 1], [3]]), - pa.array([["dog"], ["cat"], ["wolf"], ["dog", "wolf"], ["wolf"]]), - ], - ["a", "b"], - ), - pa.RecordBatch.from_arrays( - [pa.array([[1]]), 
pa.array([["dog"]]), pa.array([[1]])], ["a", "b", "c"] - ), - pa.RecordBatch.from_arrays( - [pa.array([[1]]), pa.array([["cat"]]), pa.array([[1]])], ["a", "b", "d"] - ), - pa.RecordBatch.from_arrays( - [pa.array([[1]]), pa.array([["cat"]]), pa.array([[1]])], ["a", "b", "e"] - ), - pa.RecordBatch.from_arrays( - [pa.array([[1]]), pa.array([["cat"]]), pa.array([[1]])], ["a", "b", "f"] - ), - ] - record_batch_with_metadata = pa.RecordBatch.from_arrays( - [pa.array([[1]]), pa.array([["cat"]])], ["a", "b"] - ) - record_batch_with_metadata = pa.RecordBatch.from_arrays( - arrays=record_batch_with_metadata.columns, - schema=record_batch_with_metadata.schema.with_metadata({b"foo": "bar"}), - ) - input_record_batches.append(record_batch_with_metadata) - slice_sql = """ - SELECT - STRUCT(a, b) - FROM - example.a, example.b - """ - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | "Create" >> beam.Create(input_record_batches, reshuffle=False) - | "GenerateSlicesSql" - >> beam.ParDo( - slicing_util.GenerateSlicesSqlDoFn(slice_sqls=[slice_sql]) - ) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self.assertLen(got, 18) - expected_slice_keys = [ - "a_1_b_dog", - "a_1_b_cat", - "a_2_b_cat", - "a_2_b_dog", - "a_1_b_wolf", - "a_2_b_wolf", - "a_3_b_wolf", - "a_1_b_dog", - "a_1_b_cat", - "a_1_b_cat", - "a_1_b_cat", - "a_1_b_cat", - ] + [constants.DEFAULT_SLICE_KEY] * 6 - actual_slice_keys = [slice_key for (slice_key, _) in got] - self.assertCountEqual(expected_slice_keys, actual_slice_keys) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") - def test_generate_slices_sql_assert_record_batches(self): - input_record_batches = [ - pa.RecordBatch.from_arrays( - [ - pa.array([[1], [2, 1], [3], [2, 1, 1], [3]]), - pa.array([["dog"], ["cat"], ["wolf"], ["dog", "wolf"], ["wolf"]]), - ], - ["a", "b"], - ), - ] - slice_sql = """ - SELECT - STRUCT(a, b) - FROM - example.a, example.b - """ - expected_result = [ - ( - "a_1_b_dog", - pa.RecordBatch.from_arrays( - [pa.array([[1], [2, 1, 1]]), pa.array([["dog"], ["dog", "wolf"]])], - ["a", "b"], - ), - ), - ( - "a_1_b_cat", - pa.RecordBatch.from_arrays( - [pa.array([[2, 1]]), pa.array([["cat"]])], ["a", "b"] - ), - ), - ( - "a_2_b_cat", - pa.RecordBatch.from_arrays( - [pa.array([[2, 1]]), pa.array([["cat"]])], ["a", "b"] - ), - ), - ( - "a_2_b_dog", - pa.RecordBatch.from_arrays( - [pa.array([[2, 1, 1]]), pa.array([["dog", "wolf"]])], ["a", "b"] - ), - ), - ( - "a_1_b_wolf", - pa.RecordBatch.from_arrays( - [pa.array([[2, 1, 1]]), pa.array([["dog", "wolf"]])], ["a", "b"] - ), - ), - ( - "a_2_b_wolf", - pa.RecordBatch.from_arrays( - [pa.array([[2, 1, 1]]), pa.array([["dog", "wolf"]])], ["a", "b"] - ), - ), - ( - "a_3_b_wolf", - pa.RecordBatch.from_arrays( - [pa.array([[3], [3]]), pa.array([["wolf"], ["wolf"]])], ["a", "b"] - ), - ), - (constants.DEFAULT_SLICE_KEY, input_record_batches[0]), - ] - - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | "Create" >> beam.Create(input_record_batches, reshuffle=False) - | "GenerateSlicesSql" - >> beam.ParDo( - slicing_util.GenerateSlicesSqlDoFn(slice_sqls=[slice_sql]) - ) - ) - - # pylint: enable=no-value-for-parameter - - def check_result(got): - try: - self._check_results(got, expected_result) - - except AssertionError as err: - raise 
util.BeamAssertException(err) - - util.assert_that(result, check_result) - - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") - def test_generate_slices_sql_invalid_slice(self): - input_record_batches = [ - pa.RecordBatch.from_arrays( - [ - pa.array([[1], [2, 1], [3], [2, 1, 1], [3]]), - pa.array([[], [], [], [], []]), - ], - ["a", "b"], - ), - ] - slice_sql1 = """ - SELECT - STRUCT(a, b) - FROM - example.a, example.b - """ - - expected_result = [ - (constants.INVALID_SLICE_KEY, input_record_batches[0]), - (constants.DEFAULT_SLICE_KEY, input_record_batches[0]), - ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | "Create" >> beam.Create(input_record_batches, reshuffle=False) - | "GenerateSlicesSql" - >> beam.ParDo( - slicing_util.GenerateSlicesSqlDoFn(slice_sqls=[slice_sql1]) - ) - ) - - def check_result(got): - try: - self._check_results(got, expected_result) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - - @pytest.mark.xfail(run=False, reason="This test fails and needs to be fixed.") - def test_generate_slices_sql_multiple_queries(self): - input_record_batches = [ - pa.RecordBatch.from_arrays( - [ - pa.array([[1], [2, 1], [3], [2, 1, 1], [3]]), - pa.array([[], [], [], [], []]), - ], - ["a", "b"], - ), - ] - slice_sql1 = """ - SELECT - STRUCT(c) - FROM - example.a, example.b - """ - - slice_sql2 = """ - SELECT - STRUCT(a) - FROM - example.a - """ - - expected_result = [ - ( - "a_1", - pa.RecordBatch.from_arrays( - [ - pa.array([[1], [2, 1], [2, 1, 1]]), - pa.array([[], [], []]), - ], - ["a", "b"], - ), - ), - ( - "a_2", - pa.RecordBatch.from_arrays( - [ - pa.array([[2, 1], [2, 1, 1]]), - pa.array([[], []]), - ], - ["a", "b"], - ), - ), - ( - "a_3", - pa.RecordBatch.from_arrays( - [ - pa.array([[3], [3]]), - pa.array([[], []]), - ], - ["a", "b"], - ), - ), - (constants.INVALID_SLICE_KEY, input_record_batches[0]), - (constants.DEFAULT_SLICE_KEY, input_record_batches[0]), - ] - with beam.Pipeline() as pipeline: - # pylint: disable=no-value-for-parameter - result = ( - pipeline - | "Create" >> beam.Create(input_record_batches, reshuffle=False) - | "GenerateSlicesSql" - >> beam.ParDo( - slicing_util.GenerateSlicesSqlDoFn( - slice_sqls=[slice_sql1, slice_sql2] - ) - ) - ) - - def check_result(got): - try: - self._check_results(got, expected_result) - - except AssertionError as err: - raise util.BeamAssertException(err) - - util.assert_that(result, check_result) - if __name__ == "__main__": absltest.main() diff --git a/third_party/farmhash.BUILD b/third_party/farmhash.BUILD index a8005735..fa912eb2 100644 --- a/third_party/farmhash.BUILD +++ b/third_party/farmhash.BUILD @@ -19,6 +19,8 @@ cc_library( "//conditions:default": [], }), # Required by ZetaSQL. 
+ # ZetaSQL is removed + # This is a candidate for deletion defines = ["NAMESPACE_FOR_HASH_FUNCTIONS=farmhash"], includes = ["src/."], visibility = ["//visibility:public"], diff --git a/third_party/zetasql.patch b/third_party/zetasql.patch deleted file mode 100644 index c90d3b21..00000000 --- a/third_party/zetasql.patch +++ /dev/null @@ -1,674 +0,0 @@ -diff --git a/zetasql/analyzer/BUILD b/zetasql/analyzer/BUILD -index 590f1be1..3ca15df4 100644 ---- a/zetasql/analyzer/BUILD -+++ b/zetasql/analyzer/BUILD -@@ -18,7 +18,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") - load(":builddefs.bzl", "gen_analyzer_test") - - package( -- default_visibility = ["//zetasql/base:zetasql_implementation"], -+ default_visibility = ["//visibility:public"], - ) - - filegroup( -diff --git a/zetasql/analyzer/expr_resolver_helper.cc b/zetasql/analyzer/expr_resolver_helper.cc -index 93c3654d..8fb2256e 100644 ---- a/zetasql/analyzer/expr_resolver_helper.cc -+++ b/zetasql/analyzer/expr_resolver_helper.cc -@@ -357,7 +357,8 @@ ExprResolutionInfo::ExprResolutionInfo( - : ExprResolutionInfo( - query_resolution_info_in, name_scope_in, aggregate_name_scope_in, - analytic_name_scope_in, -- {.allows_aggregation = allows_aggregation_in, -+ ExprResolutionInfoOptions{ -+ .allows_aggregation = allows_aggregation_in, - .allows_analytic = allows_analytic_in, - .use_post_grouping_columns = use_post_grouping_columns_in, - .clause_name = clause_name_in, -diff --git a/zetasql/analyzer/name_scope.cc b/zetasql/analyzer/name_scope.cc -index b9a3176f..c1cf274a 100644 ---- a/zetasql/analyzer/name_scope.cc -+++ b/zetasql/analyzer/name_scope.cc -@@ -1549,7 +1549,7 @@ NameList::AddRangeVariableInWrappingNameList( - // variables, including for value tables, so we use `flatten_to_table` - // which drops range variables. 
- ZETASQL_RETURN_IF_ERROR(range_variable_name_list->MergeFrom( -- *original_name_list, ast_location, {.flatten_to_table = true})); -+ *original_name_list, ast_location, MergeOptions{.flatten_to_table = true})); - - auto wrapper_name_list = std::make_shared(); - ZETASQL_RETURN_IF_ERROR( -# diff --git a/bazel/zetasql_deps_step_2.bzl b/bazel/zetasql_deps_step_2.bzl -# index 6873dbe9..872ffd5e 100644 -# --- a/bazel/zetasql_deps_step_2.bzl -# +++ b/bazel/zetasql_deps_step_2.bzl -# @@ -477,7 +477,6 @@ alias( -# flex_register_toolchains(version = "2.6.4") -# bison_register_toolchains(version = "3.3.2") -# go_rules_dependencies() -# - go_register_toolchains(version = "1.21.6") -# gazelle_dependencies() -# textmapper_dependencies() - -diff --git a/zetasql/analyzer/resolver_expr.cc b/zetasql/analyzer/resolver_expr.cc -index 6116b4f7..70e8c9fd 100644 ---- a/zetasql/analyzer/resolver_expr.cc -+++ b/zetasql/analyzer/resolver_expr.cc -@@ -5586,7 +5586,8 @@ absl::Status Resolver::ResolveAnalyticFunctionCall( - { - ExprResolutionInfo analytic_arg_resolution_info( - expr_resolution_info, -- {.name_scope = expr_resolution_info->analytic_name_scope, -+ ExprResolutionInfoOptions{ -+ .name_scope = expr_resolution_info->analytic_name_scope, - .allows_analytic = expr_resolution_info->allows_analytic, - .clause_name = expr_resolution_info->clause_name}); - ZETASQL_RETURN_IF_ERROR(ResolveExpressionArguments( - -diff --git a/zetasql/base/BUILD b/zetasql/base/BUILD -index aa1f00da..7d4c3b3a 100644 ---- a/zetasql/base/BUILD -+++ b/zetasql/base/BUILD -@@ -15,7 +15,7 @@ - - licenses(["notice"]) - --package(default_visibility = [":zetasql_implementation"]) -+package(default_visibility = ["//visibility:public"]) - - package_group( - name = "zetasql_implementation", -diff --git a/zetasql/base/testing/BUILD b/zetasql/base/testing/BUILD -index 10596497..239c670f 100644 ---- a/zetasql/base/testing/BUILD -+++ b/zetasql/base/testing/BUILD -@@ -16,7 +16,7 @@ - - licenses(["notice"]) - --package(default_visibility = ["//zetasql/base:zetasql_implementation"]) -+package(default_visibility = ["//visibility:public"]) - - # A drop in replacement for gtest_main that parsers absl flags - cc_library( -diff --git a/zetasql/common/BUILD b/zetasql/common/BUILD -index cdafb15e..761e13cd 100644 ---- a/zetasql/common/BUILD -+++ b/zetasql/common/BUILD -@@ -14,7 +14,7 @@ - # limitations under the License. 
- - package( -- default_visibility = ["//zetasql/base:zetasql_implementation"], -+ default_visibility = ["//visibility:public"], - features = ["parse_headers"], - ) - -diff --git a/zetasql/common/internal_value.h b/zetasql/common/internal_value.h -index 770333d2..617ef628 100644 ---- a/zetasql/common/internal_value.h -+++ b/zetasql/common/internal_value.h -@@ -116,7 +116,7 @@ class InternalValue { - static std::string FormatInternal(const Value& x, - bool include_array_ordereness - ) { -- return x.FormatInternal({ -+ return x.FormatInternal(Type::FormatValueContentOptions{ - .force_type_at_top_level = true, - .include_array_ordereness = include_array_ordereness, - .indent = 0, -diff --git a/zetasql/parser/BUILD b/zetasql/parser/BUILD -index 433cf157..4fa4417c 100644 ---- a/zetasql/parser/BUILD -+++ b/zetasql/parser/BUILD -@@ -26,7 +26,7 @@ load("//bazel:textmapper.bzl", "tm_syntax") - load(":builddefs.bzl", "gen_parser_test") - - package( -- default_visibility = ["//zetasql/base:zetasql_implementation"], -+ default_visibility = ["//visibility:public"], - ) - - genrule( -diff --git a/zetasql/public/types/BUILD b/zetasql/public/types/BUILD -index 2b42fdcb..19ff2a4e 100644 ---- a/zetasql/public/types/BUILD -+++ b/zetasql/public/types/BUILD -@@ -14,7 +14,7 @@ - # limitations under the License. - # - --package(default_visibility = ["//zetasql/base:zetasql_implementation"]) -+package(default_visibility = ["//visibility:public"]) - - cc_library( - name = "types", - -diff --git a/zetasql/public/value.cc b/zetasql/public/value.cc -index 7aeffb01..c9f9f9dc 100644 ---- a/zetasql/public/value.cc -+++ b/zetasql/public/value.cc -@@ -1067,7 +1067,7 @@ std::string Value::DebugString(bool verbose) const { - - // Format will wrap arrays and structs. - std::string Value::Format(bool print_top_level_type) const { -- return FormatInternal( -+ return FormatInternal(Type::FormatValueContentOptions - {.force_type_at_top_level = print_top_level_type, .indent = 0}); - } - -@@ -1335,7 +1335,7 @@ std::string Value::FormatInternal( - std::vector element_strings(elements().size()); - for (int i = 0; i < elements().size(); ++i) { - element_strings[i] = -- elements()[i].FormatInternal(options.IncreaseIndent()); -+ elements()[i].FormatInternal(Type::FormatValueContentOptions{options.IncreaseIndent()}); - } - // Sanitize any '$' characters before creating substitution template. "$$" - // is replaced by "$" in the output from absl::Substitute. -@@ -1377,7 +1377,7 @@ std::string Value::FormatInternal( - const StructType* struct_type = type()->AsStruct(); - std::vector field_strings(struct_type->num_fields()); - for (int i = 0; i < struct_type->num_fields(); i++) { -- field_strings[i] = fields()[i].FormatInternal(options.IncreaseIndent()); -+ field_strings[i] = fields()[i].FormatInternal(Type::FormatValueContentOptions{options.IncreaseIndent()}); - } - // Sanitize any '$' characters before creating substitution template. "$$" - // is replaced by "$" in the output from absl::Substitute. -@@ -1423,9 +1423,9 @@ std::string Value::FormatInternal( - } - std::vector boundaries_strings; - boundaries_strings.push_back( -- start().FormatInternal(options.IncreaseIndent())); -+ start().FormatInternal(Type::FormatValueContentOptions{options.IncreaseIndent()})); - boundaries_strings.push_back( -- end().FormatInternal(options.IncreaseIndent())); -+ end().FormatInternal(Type::FormatValueContentOptions{options.IncreaseIndent()})); - // Sanitize any '$' characters before creating substitution template. 
"$$" - // is replaced by "$" in the output from absl::Substitute. - std::string templ = -diff --git a/zetasql/reference_impl/algebrizer.cc b/zetasql/reference_impl/algebrizer.cc -index 2e1258ab..48a3d7f4 100644 ---- a/zetasql/reference_impl/algebrizer.cc -+++ b/zetasql/reference_impl/algebrizer.cc -@@ -6738,7 +6738,7 @@ absl::StatusOr> Algebrizer::AlgebrizeTvfScan( - ZETASQL_RET_CHECK(tvf_scan->signature()->argument(i).is_scalar()); - ZETASQL_ASSIGN_OR_RETURN(auto expr_argument, - AlgebrizeExpression(argument->expr())); -- arguments.push_back({.value = std::move(expr_argument)}); -+ arguments.push_back(TVFOp::TVFOpArgument{.value = std::move(expr_argument)}); - continue; - } - -@@ -6767,14 +6767,14 @@ absl::StatusOr> Algebrizer::AlgebrizeTvfScan( - columns.push_back({relation_signature_column.name, - argument_column.type(), input_variable}); - } -- arguments.push_back({.relation = TVFOp::TvfInputRelation{ -+ arguments.push_back(TVFOp::TVFOpArgument{.relation = TVFOp::TvfInputRelation{ - std::move(relation), std::move(columns)}}); - continue; - } - - if (argument->model() != nullptr) { - ZETASQL_RET_CHECK(tvf_scan->signature()->argument(i).is_model()); -- arguments.push_back({.model = argument->model()->model()}); -+ arguments.push_back(TVFOp::TVFOpArgument{.model = argument->model()->model()}); - continue; - } - -diff --git a/zetasql/reference_impl/relational_op.cc b/zetasql/reference_impl/relational_op.cc -index 1619590a..a18a733f 100644 ---- a/zetasql/reference_impl/relational_op.cc -+++ b/zetasql/reference_impl/relational_op.cc -@@ -835,11 +835,11 @@ absl::StatusOr> TVFOp::CreateIterator( - } - ZETASQL_RET_CHECK_EQ(columns.size(), tuple_indexes.size()); - input_arguments.push_back( -- {.relation = {std::make_unique( -+ TableValuedFunction::TvfEvaluatorArg{.relation = {std::make_unique( - std::move(columns), std::move(tuple_indexes), context, - std::move(tuple_iterator))}}); - } else if (argument.model) { -- input_arguments.push_back({.model = argument.model}); -+ input_arguments.push_back(TableValuedFunction::TvfEvaluatorArg{.model = argument.model}); - } else { - ZETASQL_RET_CHECK_FAIL() << "Unexpected TVFOpArgument"; - } - -diff --git a/bazel/zetasql_deps_step_2.bzl b/bazel/zetasql_deps_step_2.bzl -index 6873dbe9..223f8dbd 100644 ---- a/bazel/zetasql_deps_step_2.bzl -+++ b/bazel/zetasql_deps_step_2.bzl -@@ -19,7 +19,6 @@ - load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") - load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") - load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") --load("@llvm_toolchain//:toolchains.bzl", "llvm_register_toolchains") - load("@rules_bison//bison:bison.bzl", "bison_register_toolchains") - load("@rules_flex//flex:flex.bzl", "flex_register_toolchains") - load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies") -@@ -29,7 +28,6 @@ load("@rules_proto//proto:setup.bzl", "rules_proto_setup") - load("@rules_proto//proto:toolchains.bzl", "rules_proto_toolchains") - - def _load_deps_from_step_1(): -- llvm_register_toolchains() - rules_foreign_cc_dependencies() - - def textmapper_dependencies(): -@@ -49,21 +47,29 @@ def textmapper_dependencies(): - go_repository( - name = "dev_lsp_go_jsonrpc2", - importpath = "go.lsp.dev/jsonrpc2", -+ remote = "https://github.com/go-language-server/jsonrpc2", -+ vcs = "git", - commit = "8c68d4fd37cd4bd06b62b3243f0d2292c681d164", - ) - go_repository( - name = "dev_lsp_go_protocol", - importpath = "go.lsp.dev/protocol", 
-+ remote = "https://github.com/go-language-server/protocol", -+ vcs = "git", - commit = "da30f9ae0326cc45b76adc5cd8920ac1ffa14a15", - ) - go_repository( - name = "dev_lsp_go_uri", - importpath = "go.lsp.dev/uri", -+ remote = "https://github.com/go-language-server/uri", -+ vcs = "git", - commit = "63eaac75cc850f596be19073ff6d4ec198603779", - ) - go_repository( - name = "dev_lsp_go_pkg", - importpath = "go.lsp.dev/pkg", -+ remote = "https://github.com/go-language-server/pkg", -+ vcs = "git", - commit = "384b27a52fb2b5d74d78cfe89c7738e9a3e216a5", - ) - go_repository( -@@ -477,7 +483,6 @@ alias( - flex_register_toolchains(version = "2.6.4") - bison_register_toolchains(version = "3.3.2") - go_rules_dependencies() -- go_register_toolchains(version = "1.21.6") - gazelle_dependencies() - textmapper_dependencies() - - -diff --git a/bazel/zetasql_deps_step_1.bzl b/bazel/zetasql_deps_step_1.bzl -index 825bf8ea..7edd1352 100644 ---- a/bazel/zetasql_deps_step_1.bzl -+++ b/bazel/zetasql_deps_step_1.bzl -@@ -22,25 +22,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") - # but depend on them being something different. So we have to override them both - # by defining the repo first. - load("@com_google_zetasql//bazel:zetasql_bazel_version.bzl", "zetasql_bazel_version") --load("@toolchains_llvm//toolchain:deps.bzl", "bazel_toolchain_dependencies") --load("@toolchains_llvm//toolchain:rules.bzl", "llvm_toolchain") - - def zetasql_deps_step_1(add_bazel_version = True): - if add_bazel_version: - zetasql_bazel_version() - -- bazel_toolchain_dependencies() -- llvm_toolchain( -- name = "llvm_toolchain", -- llvm_versions = { -- "": "16.0.0", -- # The LLVM repo stops providing pre-built binaries for the MacOS x86_64 -- # architecture for versions >= 16.0.0: https://github.com/llvm/llvm-project/releases, -- # but our Kokoro MacOS tests are still using x86_64 (ventura). -- # TODO: Upgrade the MacOS version to sonoma-slcn. -- "darwin-x86_64": "15.0.7", -- }, -- ) - - http_archive( - name = "io_bazel_rules_go", - -diff --git a/bazel/grpc_extra_deps.patch b/bazel/grpc_extra_deps.patch -index 771761b3..9c1b1cee 100644 ---- a/bazel/grpc_extra_deps.patch -+++ b/bazel/grpc_extra_deps.patch -@@ -13,3 +13,41 @@ index 4d8afa3131..b42224501f 100644 - # Pull-in the go 3rd party dependencies for protoc_gen_validate, which is - # needed for building C++ xDS protos - go_third_party() -+ -+ diff --git a/BUILD b/BUILD -+ index 3b5d7e5e3c..c5d61e6e4c 100644 -+ --- a/BUILD -+ +++ b/BUILD -+ @@ -544,6 +544,7 @@ grpc_cc_library( -+ defines = ["GRPC_NO_XDS"], -+ external_deps = [ -+ "absl/base:core_headers", -+ + "absl/status", -+ + "absl/strings", -+ ], -+ language = "c++", -+ public_hdrs = GRPC_PUBLIC_HDRS, -+ -+ diff --git a/include/grpcpp/impl/status.h b/include/grpcpp/impl/status.h -+ index 95436ab8fb..fe9f44adf0 100644 -+ --- a/include/grpcpp/impl/status.h -+ +++ b/include/grpcpp/impl/status.h -+ @@ -23,6 +23,7 @@ -+ -+ #include -+ -+ +#include "absl/status/status.h" -+ #include -+ #include -+ #include -+ @@ -99,6 +100,10 @@ class GRPC_MUST_USE_RESULT_WHEN_USE_STRICT_WARNING Status { -+ Status(StatusCode code, const std::string& error_message) -+ : code_(code), error_message_(error_message) {} -+ -+ + operator absl::Status() const& { -+ + return absl::Status(static_cast(code_), error_message_); -+ + } -+ + -+ /// Construct an instance with \a code, \a error_message and -+ /// \a error_details. 
It is an error to construct an OK status with non-empty -+ /// \a error_message and/or \a error_details. -+ - -diff --git a/bazel/icu4c-64_2.patch b/bazel/icu4c-64_2.patch -index 69d12b63..a23bdcaf 100644 ---- a/bazel/icu4c-64_2.patch -+++ b/bazel/icu4c-64_2.patch -@@ -5,7 +5,7 @@ - CXX = @CXX@ - AR = @AR@ - -ARFLAGS = @ARFLAGS@ r --+ARFLAGS = @ARFLAGS@ -crs -++ARFLAGS = @ARFLAGS@ - RANLIB = @RANLIB@ - COMPILE_LINK_ENVVAR = @COMPILE_LINK_ENVVAR@ - UCLN_NO_AUTO_CLEANUP = @UCLN_NO_AUTO_CLEANUP@ - - diff --git a/bazel/icu.BUILD b/bazel/icu.BUILD -index be36d7de..f61d8f3c 100644 ---- a/bazel/icu.BUILD -+++ b/bazel/icu.BUILD -@@ -35,20 +35,17 @@ filegroup( - configure_make( - name = "icu", - configure_command = "source/configure", -- args = select({ -- # AR is overridden to be libtool by rules_foreign_cc. It does not support ar style arguments -- # like "r". We need to prevent the icu make rules from adding unsupported parameters by -- # forcing ARFLAGS to keep the rules_foreign_cc value in this parameter -- "@platforms//os:macos": [ -- "ARFLAGS=\"-static -o\"", -- "MAKE=gnumake", -- ], -- "//conditions:default": [], -- }), -- env = { -- "CXXFLAGS": "-fPIC", # For JNI -- "CFLAGS": "-fPIC", # For JNI -- }, -+ env = select({ -+ "@platforms//os:macos": { -+ "AR": "", -+ "CXXFLAGS": "-fPIC", # For JNI -+ "CFLAGS": "-fPIC", # For JNI -+ }, -+ "//conditions:default": { -+ "CXXFLAGS": "-fPIC", # For JNI -+ "CFLAGS": "-fPIC", # For JNI -+ }, -+ }), - configure_options = [ - "--enable-option-checking", - "--enable-static", - - -diff --git a/zetasql/public/constant.h b/zetasql/public/constant.h -index 946183b0..03ac17e0 100644 ---- a/zetasql/public/constant.h -+++ b/zetasql/public/constant.h -@@ -80,7 +80,7 @@ class Constant { - const std::vector& name_path() const { return name_path_; } - - // Returns the type of this Constant. -- virtual const Type* type() const = 0; -+ virtual const zetasql::Type* type() const = 0; - - // Returns whether or not this Constant is a specific constant interface or - // implementation. - -diff --git a/zetasql/public/property_graph.h b/zetasql/public/property_graph.h -index 53ccca23..0eefe780 100644 ---- a/zetasql/public/property_graph.h -+++ b/zetasql/public/property_graph.h -@@ -348,7 +348,7 @@ class GraphPropertyDeclaration { - return ::zetasql::FullName(PropertyGraphNamePath(), Name()); - } - -- virtual const Type* Type() const = 0; -+ virtual const zetasql::Type* Type() const = 0; - - // Returns whether or not this GraphPropertyDeclaration is a specific - // interface or implementation. 
- -diff --git a/zetasql/analyzer/resolver_expr.cc b/zetasql/analyzer/resolver_expr.cc -index 51d095ab..8ba1eefc 100644 ---- a/zetasql/analyzer/resolver_expr.cc -+++ b/zetasql/analyzer/resolver_expr.cc -@@ -2996,7 +2996,7 @@ class SystemVariableConstant final : public Constant { - const Type* type) - : Constant(name_path), type_(type) {} - -- const Type* type() const override { return type_; } -+ const zetasql::Type* type() const override { return type_; } - std::string DebugString() const override { return FullName(); } - std::string ConstantValueDebugString() const override { return ""; } - - -diff --git a/zetasql/public/coercer.cc b/zetasql/public/coercer.cc -index dc4961dd..80d26183 100644 ---- a/zetasql/public/coercer.cc -+++ b/zetasql/public/coercer.cc -@@ -154,7 +154,7 @@ class TypeSuperTypes { - return false; - } - -- const Type* type() const { return type_; } -+ const zetasql::Type* type() const { return type_; } - TypeListView supertypes() const { return supertypes_; } - - std::vector ToVector() const { - -diff --git a/zetasql/public/function_signature.h b/zetasql/public/function_signature.h -index 29886cc2..5436071c 100644 ---- a/zetasql/public/function_signature.h -+++ b/zetasql/public/function_signature.h -@@ -702,7 +702,7 @@ class FunctionArgumentType { - // Returns NULL if kind_ is not ARG_TYPE_FIXED or ARG_TYPE_LAMBDA. If kind_ is - // ARG_TYPE_LAMBDA, returns the type of lambda body type, which could be NULL - // if the body type is templated. -- const Type* type() const { return type_; } -+ const zetasql::Type* type() const { return type_; } - - SignatureArgumentKind kind() const { return kind_; } - -diff --git a/zetasql/public/input_argument_type.h b/zetasql/public/input_argument_type.h -index f2098787..55b416e3 100644 ---- a/zetasql/public/input_argument_type.h -+++ b/zetasql/public/input_argument_type.h -@@ -81,7 +81,7 @@ class InputArgumentType { - ~InputArgumentType() {} - - // This may return nullptr (such as for lambda). -- const Type* type() const { return type_; } -+ const zetasql::Type* type() const { return type_; } - - const std::vector& field_types() const { - return field_types_; - -diff --git a/zetasql/public/simple_catalog.h b/zetasql/public/simple_catalog.h -index 76a94d43..a0d81b9d 100644 ---- a/zetasql/public/simple_catalog.h -+++ b/zetasql/public/simple_catalog.h -@@ -1202,7 +1202,7 @@ class SimpleConstant : public Constant { - const SimpleConstantProto& simple_constant_proto, - const TypeDeserializer& type_deserializer); - -- const Type* type() const override { return value_.type(); } -+ const zetasql::Type* type() const override { return value_.type(); } - - const Value& value() const { return value_; } - - -diff --git a/zetasql/public/sql_constant.h b/zetasql/public/sql_constant.h -index fa88344f..69defd3b 100644 ---- a/zetasql/public/sql_constant.h -+++ b/zetasql/public/sql_constant.h -@@ -60,7 +60,7 @@ class SQLConstant : public Constant { - - // Returns the Type of the resolved Constant based on its resolved - // expression type. -- const Type* type() const override { -+ const zetasql::Type* type() const override { - return constant_expression()->type(); - } - - -diff --git a/zetasql/public/value.h b/zetasql/public/value.h -index 49b60aec..86688538 100644 ---- a/zetasql/public/value.h -+++ b/zetasql/public/value.h -@@ -122,7 +122,7 @@ class Value { - ~Value(); - - // Returns the type of the value. -- const Type* type() const; -+ const zetasql::Type* type() const; - - // Returns the type kind of the value. 
Same as type()->type_kind() but in some - // cases can be a bit more efficient. -@@ -1152,7 +1152,7 @@ class Value { - - // Returns a pointer to Value's Type. Requires is_valid(). If TypeKind is - // stored in the Metadata, Type pointer is obtained from static TypeFactory. -- const Type* type() const; -+ const zetasql::Type* type() const; - - // Returns true, if instance stores pointer to a Type and false if type's - // kind. - -diff --git a/zetasql/public/value_inl.h b/zetasql/public/value_inl.h -index e917a97a..f324276f 100644 ---- a/zetasql/public/value_inl.h -+++ b/zetasql/public/value_inl.h -@@ -1077,7 +1077,7 @@ class Value::Metadata::ContentLayout<4> { - has_type_(false), - value_extended_content_(value_extended_content) {} - -- const Type* type() const { return type_; } -+ const zetasql::Type* type() const { return type_; } - int32_t value_extended_content() const { return value_extended_content_; } - bool is_null() const { return is_null_; } - bool preserves_order() const { return preserves_order_; } -@@ -1157,7 +1157,7 @@ class Value::Metadata::ContentLayout<8> { - // TODO: wait for fixed clang-format - // clang-format on - -- const Type* type() const { -+ const zetasql::Type* type() const { - return reinterpret_cast(type_ & kTypeMask); - } - int32_t value_extended_content() const { return value_extended_content_; } - -diff --git a/zetasql/reference_impl/operator.h b/zetasql/reference_impl/operator.h -index 24f0ddac..7adb701d 100644 ---- a/zetasql/reference_impl/operator.h -+++ b/zetasql/reference_impl/operator.h -@@ -240,7 +240,7 @@ class ExprArg : public AlgebraArg { - - ~ExprArg() override = default; - -- const Type* type() const { return type_; } -+ const zetasql::Type* type() const { return type_; } - - private: - const Type* type_; - -diff --git a/zetasql/resolved_ast/resolved_column.h b/zetasql/resolved_ast/resolved_column.h -index 912b3ca4..2e613f2a 100644 ---- a/zetasql/resolved_ast/resolved_column.h -+++ b/zetasql/resolved_ast/resolved_column.h -@@ -119,7 +119,7 @@ class ResolvedColumn { - IdString table_name_id() const { return table_name_; } - IdString name_id() const { return name_; } - -- const Type* type() const { return annotated_type_.type; } -+ const zetasql::Type* type() const { return annotated_type_.type; } - - const AnnotationMap* type_annotation_map() const { - return annotated_type_.annotation_map; - -diff --git a/zetasql/testing/test_value.h b/zetasql/testing/test_value.h -index 0412873e..d2d8c3e8 100644 ---- a/zetasql/testing/test_value.h -+++ b/zetasql/testing/test_value.h -@@ -106,7 +106,7 @@ class ValueConstructor { - : v_(v) {} - - const Value& get() const { return v_; } -- const Type* type() const { return v_.type(); } -+ const zetasql::Type* type() const { return v_.type(); } - - static std::vector ToValues(absl::Span slice) { - std::vector values; - - -diff --git a/zetasql/base/logging.h b/zetasql/base/logging.h -index 730ccdcb..46fe06b0 100644 ---- a/zetasql/base/logging.h -+++ b/zetasql/base/logging.h -@@ -59,6 +59,17 @@ inline void ZetaSqlMakeCheckOpValueString(std::ostream *os, const T &v) { - (*os) << v; - } - -+// This overloading is implemented to address the compilation issue when trying to log unique_ptr types -+// At the moment, we are not providing any specific implementation for handling unique_ptr types. 
-+template -+inline void ZetaSqlMakeCheckOpValueString(std::ostream* os, const std::unique_ptr& v) { -+ if (v == nullptr) { -+ (*os) << "nullptr"; -+ } else { -+ (*os) << v.get(); -+ } -+} -+ - // Overrides for char types provide readable values for unprintable - // characters. - template <> - - - -diff --git a/zetasql/base/testing/BUILD b/zetasql/base/testing/BUILD -index 10596497..a9b69be7 100644 ---- a/zetasql/base/testing/BUILD -+++ b/zetasql/base/testing/BUILD -@@ -55,6 +55,7 @@ cc_library( - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", -+ "@com_github_grpc_grpc//:grpc++", - ], - ) - -@@ -69,6 +70,7 @@ cc_test( - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", -+ "@com_github_grpc_grpc//:grpc++", - ], - ) - -