Skip to content

Commit

Permalink
Merge pull request #170 from GoogleCloudPlatform/5789B3FF9E750C15D6B7…
Browse files Browse the repository at this point in the history
…4524C048BD98

Project import generated by Copybara.
  • Loading branch information
olavloite authored May 10, 2024
2 parents ea3a979 + a8d4028 commit d1ec59f
Show file tree
Hide file tree
Showing 40 changed files with 2,149 additions and 129 deletions.
19 changes: 9 additions & 10 deletions backend/actions/change_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,19 +216,18 @@ std::vector<const Column*> GetColumnsForDataChangeRecord(

absl::StatusOr<
std::tuple<std::vector<zetasql::Value>, std::vector<zetasql::Value>>>
GetValuesForDataChangeRecord(std::string value_capture_type,
absl::string_view mod_type,
const Table* tracked_table,
std::vector<const Column*> populated_columns,
std::vector<zetasql::Value> values,
const Key& key, ReadOnlyStore* store,
std::vector<const Column*> tracked_columns) {
GetValuesForDataChangeRecord(
absl::string_view value_capture_type, absl::string_view mod_type,
const Table* tracked_table,
const std::vector<const Column*>& populated_columns,
const std::vector<zetasql::Value>& values, const Key& key,
ReadOnlyStore* store, std::vector<const Column*> tracked_columns) {
// Get old_values
std::tuple<std::vector<zetasql::Value>, std::vector<zetasql::Value>>
new_values_and_old_values;
std::vector<std::string> tracked_columns_str;
for (const Column* col : tracked_columns) {
if (!IsPrimaryKey(tracked_table, col)) {
if (!IsPrimaryKey(tracked_table, col) && !col->is_generated()) {
tracked_columns_str.push_back(col->Name());
}
}
Expand Down Expand Up @@ -298,7 +297,7 @@ GetValuesForDataChangeRecord(std::string value_capture_type,
// Accumulate tracked column types and values for same DataChangeRecord
absl::Status LogTableMod(
const Key& key, std::vector<const Column*> columns,
std::vector<zetasql::Value> values, const Table* tracked_table,
const std::vector<zetasql::Value>& values, const Table* tracked_table,
const ChangeStream* change_stream, absl::string_view mod_type,
zetasql::Value partition_token,
absl::flat_hash_map<const ChangeStream*, std::vector<DataChangeRecord>>*
Expand Down Expand Up @@ -717,7 +716,7 @@ absl::StatusOr<std::vector<WriteOp>> BuildChangeStreamWriteOps(
// Map for change streams and their DataChangeRecords
absl::flat_hash_map<const ChangeStream*, std::vector<DataChangeRecord>>
data_change_records_in_transaction_by_change_stream;
// Map for chagne streams and their ModGroups
// Map for change streams and their ModGroups
absl::flat_hash_map<const ChangeStream*, ModGroup>
last_mod_group_by_change_stream;
for (const auto& write_op : buffered_write_ops) {
Expand Down
1 change: 1 addition & 0 deletions backend/query/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ cc_library(
":analyzer_options",
":query_context",
":query_engine_options",
":queryable_table",
"//backend/query/feature_filter:gsql_supported_functions",
"//backend/query/feature_filter:sql_feature_filter",
"//backend/query/feature_filter:sql_features_view",
Expand Down
215 changes: 215 additions & 0 deletions backend/query/ml/ml_predict_table_valued_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,194 @@ std::vector<std::string> FunctionName(bool safe) {
return {std::string(kMlFunctionNamespace), std::string(kFunctionName)};
}

class MlPredictTableValuedFunctionEvaluator
: public zetasql::EvaluatorTableIterator {
public:
MlPredictTableValuedFunctionEvaluator(
const zetasql::Model* model,
std::unique_ptr<EvaluatorTableIterator> input,
zetasql::Value parameters,
const std::vector<zetasql::TVFSchemaColumn>& output_columns)
: model_(model),
input_(std::move(input)),
parameters_(std::move(parameters)),
output_columns_(output_columns) {}

// Validates inputs and initializes evaluator's state.
absl::Status Init();

int NumColumns() const override {
return static_cast<int>(output_columns_.size());
}

std::string GetColumnName(int i) const override {
DCHECK_GE(i, 0);
DCHECK_LT(i, output_columns_.size());
return output_columns_[i].name;
}

const zetasql::Type* GetColumnType(int i) const override {
DCHECK_GE(i, 0);
DCHECK_LT(i, output_columns_.size());
return output_columns_[i].type;
}

const zetasql::Value& GetValue(int i) const override {
DCHECK_GE(i, 0);
DCHECK_LT(i, output_columns_.size());
return output_values_[i];
}

absl::Status Status() const override { return status_; }

absl::Status Cancel() override { return input_->Cancel(); }

bool NextRow() override {
// Advance input iterator, stop if there is an error.
if (!input_->NextRow()) {
status_ = input_->Status();
return false;
}

// Get all the input values and populate pass-through columns.
for (auto& input_column : input_columns_) {
*input_column.value = input_->GetValue(input_column.input_index);
}

// Invoke model evaluator to populate output values.
status_ = ModelEvaluator::Predict(model_, model_inputs_, model_outputs_);
return status_.ok();
}

private:
// The model argument of ML.PREDICT function.
const zetasql::Model* const model_;
// The relation argument of ML.PREDICT function.
std::unique_ptr<EvaluatorTableIterator> input_;
// The parameters argument of ML.PREDICT function.
const zetasql::Value parameters_;
// Selected output columns: model outputs and pass-through columns.
const std::vector<zetasql::TVFSchemaColumn> output_columns_;
// Maps input iterator column index to either input_values_ for model inputs
// or output_values_ for pass-through columns.
struct InputColumn {
// Index of the input column value to be read.
int64_t input_index;
// Pointer to the value to be set.
zetasql::Value* value;
};
std::vector<InputColumn> input_columns_;
// Model input columns sent as arguments to ModelEvaluator.
CaseInsensitiveStringMap<const ModelEvaluator::ModelColumn> model_inputs_;
// Model output columns values of which are set by ModelEvaluator.
CaseInsensitiveStringMap<ModelEvaluator::ModelColumn> model_outputs_;
// Vector of values referenced by model_inputs_.
std::vector<zetasql::Value> input_values_;
// Vector of values accessible through GetValue().
std::vector<zetasql::Value> output_values_;
// Status of the iterator.
absl::Status status_;
};

// Validates the input relation against the model's inputs and binds every
// requested output column to either a model output or a pass-through input
// column.
//
// Returns an error when a required model input is missing, when a model input
// or pass-through column matches more than one input column, or when an input
// column's type differs from the model's. Internal consistency checks fail
// with a ZETASQL_RET_CHECK error.
absl::Status MlPredictTableValuedFunctionEvaluator::Init() {
  // Renders a type's name the way the error messages below expect it.
  const auto type_name_of = [](const zetasql::Type* type) {
    return type->TypeName(zetasql::PRODUCT_EXTERNAL,
                          /*use_external_float32=*/true);
  };

  // Index the input relation: column name -> every position with that name.
  CaseInsensitiveStringMap<std::vector<int64_t>> indexes_by_name;
  for (int index = 0; index < input_->NumColumns(); ++index) {
    indexes_by_name[input_->GetColumnName(index)].emplace_back(index);
  }

  // Bind each model input to its input column and populate model_inputs_.
  // input_values_ is sized once up front so that the element addresses taken
  // below are not invalidated by a later resize.
  input_values_.resize(model_->NumInputs());
  for (int i = 0; i < model_->NumInputs(); ++i) {
    const QueryableModelColumn* model_column =
        model_->GetInput(i)->GetAs<QueryableModelColumn>();
    ZETASQL_RET_CHECK(model_column);

    // Look the model input up among the input columns by name.
    const auto input_column = indexes_by_name.find(model_column->Name());
    if (input_column == indexes_by_name.end()) {
      // A missing optional input is skipped; a missing required one is fatal.
      if (!model_column->required()) {
        continue;
      }
      return error::MlInputColumnMissing(
          model_column->Name(), type_name_of(model_column->GetType()));
    }

    // Several input columns with the same name make the binding ambiguous.
    if (input_column->second.size() > 1) {
      return error::MlInputColumnAmbiguous(model_column->Name());
    }

    ZETASQL_RET_CHECK_EQ(input_column->second.size(), 1);
    const int64_t source_index = input_column->second.front();

    // The input column must have exactly the type the model declares.
    const zetasql::Type* source_type = input_->GetColumnType(source_index);
    const zetasql::Type* expected_type = model_column->GetType();
    if (!source_type->Equals(expected_type)) {
      return error::MlInputColumnTypeMismatch(model_column->Name(),
                                              type_name_of(source_type),
                                              type_name_of(expected_type));
    }

    // The same slot is written by NextRow() and read by the model evaluator.
    zetasql::Value* slot = &input_values_[i];
    input_columns_.push_back(
        InputColumn{.input_index = source_index, .value = slot});
    model_inputs_.insert(
        {model_column->Name(),
         ModelEvaluator::ModelColumn{.model_column = model_column,
                                     .value = slot}});
  }

  // Bind each output column to a model output or to a pass-through column.
  output_values_.resize(output_columns_.size());
  for (int i = 0; i < output_columns_.size(); ++i) {
    const zetasql::TVFSchemaColumn& output_column = output_columns_[i];
    const zetasql::Type* column_type = output_column.type;
    zetasql::Value* slot = &output_values_[i];

    // A name known to the model as an output takes precedence over
    // pass-through handling.
    const zetasql::Column* model_column =
        model_->FindOutputByName(output_column.name);
    if (model_column != nullptr) {
      ZETASQL_RET_CHECK(model_column->Is<QueryableModelColumn>());
      ZETASQL_RET_CHECK(model_column->GetType()->Equals(column_type));
      model_outputs_.insert(
          {model_column->Name(),
           ModelEvaluator::ModelColumn{
               .model_column = model_column->GetAs<QueryableModelColumn>(),
               .value = slot}});
      continue;
    }

    // Otherwise the column has to be a pass-through of an input column.
    const auto input_column = indexes_by_name.find(output_column.name);
    if (input_column == indexes_by_name.end()) {
      ZETASQL_RET_CHECK_FAIL() << "Could not match ML TVF Scan column " << output_column.name
                       << ". Matches should be ensured when resolving the TVF";
    }

    if (input_column->second.size() > 1) {
      return error::MlPassThroughColumnAmbiguous(output_column.name);
    }
    ZETASQL_RET_CHECK_EQ(input_column->second.size(), 1);
    const int64_t source_index = input_column->second.front();
    const zetasql::Type* input_column_type =
        input_->GetColumnType(source_index);
    ZETASQL_RET_CHECK(column_type->Equals(input_column_type));
    input_columns_.push_back(
        InputColumn{.input_index = source_index, .value = slot});
  }

  return absl::OkStatus();
}

} // namespace

MlPredictTableValuedFunction::MlPredictTableValuedFunction(bool safe)
Expand Down Expand Up @@ -127,4 +315,31 @@ absl::Status MlPredictTableValuedFunction::Resolve(
return absl::OkStatus();
}

// Creates the row iterator that evaluates ML.PREDICT.
//
// `input_arguments` must contain two or three entries, in this order: the
// model, the input relation, and optionally a parameters value. When the
// third entry is absent a default-constructed zetasql::Value is used.
// `output_columns` lists the columns the iterator has to produce; it is
// copied into the evaluator. `function_call_signature` is not used here.
//
// Returns the initialized evaluator, or the error reported by the argument
// checks or by the evaluator's Init().
absl::StatusOr<std::unique_ptr<zetasql::EvaluatorTableIterator>>
MlPredictTableValuedFunction::CreateEvaluator(
    std::vector<TvfEvaluatorArg> input_arguments,
    const std::vector<zetasql::TVFSchemaColumn>& output_columns,
    const zetasql::FunctionSignature* function_call_signature) const {
  ZETASQL_RET_CHECK_GE(input_arguments.size(), 2);
  ZETASQL_RET_CHECK_LE(input_arguments.size(), 3);

  ZETASQL_RET_CHECK(input_arguments[0].model);
  const zetasql::Model* model = input_arguments[0].model;

  // Ownership of the input relation moves from the argument to the evaluator.
  ZETASQL_RET_CHECK(input_arguments[1].relation);
  std::unique_ptr<zetasql::EvaluatorTableIterator> input =
      std::move(input_arguments[1].relation);

  zetasql::Value parameters;
  if (input_arguments.size() >= 3) {
    ZETASQL_RET_CHECK(input_arguments[2].value);
    parameters = *input_arguments[2].value;
  }

  // `parameters` is a local that is not used again, so it is moved into the
  // constructor's by-value parameter. `output_columns` is a const reference:
  // it cannot be moved from, so it is passed as-is and copied by the
  // constructor.
  auto evaluator = std::make_unique<MlPredictTableValuedFunctionEvaluator>(
      model, std::move(input), std::move(parameters), output_columns);
  ZETASQL_RETURN_IF_ERROR(evaluator->Init());
  // The explicit move is kept from the original: the local's type differs
  // from the function's return type, so this is a converting return.
  return std::move(evaluator);
}

} // namespace google::spanner::emulator::backend
7 changes: 7 additions & 0 deletions backend/query/ml/ml_predict_table_valued_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ class MlPredictTableValuedFunction : public zetasql::TableValuedFunction {
std::shared_ptr<zetasql::TVFSignature>* output_tvf_signature)
const override;

// Creates evaluator for this function.
absl::StatusOr<std::unique_ptr<zetasql::EvaluatorTableIterator>>
CreateEvaluator(std::vector<TvfEvaluatorArg> input_arguments,
const std::vector<zetasql::TVFSchemaColumn>& output_columns,
const zetasql::FunctionSignature* function_call_signature)
const override;

private:
const bool safe_ = false;
};
Expand Down
33 changes: 21 additions & 12 deletions backend/query/query_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1314,20 +1314,29 @@ TEST_P(QueryEngineTest, TestCannotQueryChangeStreamDataTableExternally) {
}

TEST_P(QueryEngineTest, TestMlQuery) {
GTEST_SKIP();

if (GetParam() == database_api::DatabaseDialect::POSTGRESQL) {
return;
ZETASQL_ASSERT_OK_AND_ASSIGN(
QueryResult result,
query_engine().ExecuteSql(Query{R"sql(
SELECT spanner.ml_predict_row(
'test'::text,
'{"instances" : [{"string_col":"foo"}]}'::jsonb))sql"},
QueryContext{schema(), reader()}));
ASSERT_NE(result.rows, nullptr);
EXPECT_EQ(
ToString(result),
R"(ml_predict_row(PG.JSONB) : {"predictions": [{"Outcome": false}]},)");
} else {
ZETASQL_ASSERT_OK_AND_ASSIGN(
QueryResult result,
query_engine().ExecuteSql(Query{R"sql(
SELECT int64_col, Outcome
FROM ML.PREDICT(MODEL test_model, TABLE test_table))sql"},
QueryContext{model_schema(), reader()}));
ASSERT_NE(result.rows, nullptr);
EXPECT_EQ(ToString(result),
R"(int64_col,Outcome(INT64,BOOL) : 1,false,2,false,4,true,)");
}

EXPECT_THAT(
query_engine().ExecuteSql(
Query{"SELECT int64_col, Outcome "
"FROM ML.PREDICT(MODEL test_model, TABLE test_table)"},
QueryContext{model_schema(), reader()}),
StatusIs(
absl::StatusCode::kUnimplemented,
HasSubstr("TVF ML.PREDICT does not support the API in evaluator.h")));
}

INSTANTIATE_TEST_SUITE_P(
Expand Down
15 changes: 13 additions & 2 deletions backend/query/query_validator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "backend/query/feature_filter/sql_feature_filter.h"
#include "backend/query/query_context.h"
#include "backend/query/query_engine_options.h"
#include "backend/query/queryable_table.h"
#include "backend/schema/catalog/column.h"
#include "backend/schema/catalog/index.h"
#include "backend/schema/catalog/sequence.h"
Expand Down Expand Up @@ -558,8 +559,18 @@ absl::Status QueryValidator::CheckPendingCommitTimestampReads(
return absl::OkStatus();
}

std::string table_name = table_scan->table()->Name();
const Table* table = schema()->FindTable(table_name);
// Any table in the user schema will be a QueryableTable. We use this property
// to skip table scans against system tables (e.g. information_schema.tables)
// since these tables do not have corresponding backend schema nodes.
//
// Skipping these tables is safe because they are not writable and do not
// contain commit timestamps (pending or otherwise).
if (!table_scan->table()->Is<QueryableTable>()) {
return absl::OkStatus();
}

const Table* table =
table_scan->table()->GetAs<QueryableTable>()->wrapped_table();
ZETASQL_RET_CHECK(table != nullptr);
std::vector<const Column*> columns;
for (int i = 0; i < table_scan->column_index_list_size(); ++i) {
Expand Down
Loading

0 comments on commit d1ec59f

Please sign in to comment.