ScalarFunctionExpr protobuf 序列化/反序列化最小例子
目录
ScalarFunctionExpr protobuf 序列化/反序列化最小例子
什么是 ScalarFunctionExpr
ScalarFunctionExpr 是 Datafusion 物理执行层对一次标量函数调用的完整描述。
它描述的是
- 函数名称是什么
- 函数参数有哪些
- 函数返回,包括返回类型是什么,是否可空等等
- 执行过程中的配置信息
比如 sqrt(a),平方根 它的物理表达式大概长
ScalarFunctionExpr
function: sqrt 的实现
name: "sqrt"
args:
- Column("a", index=0)
return_type: Float64
nullable: false/true
标量函数和聚合函数
标量函数 参数和值一一对应,比如 sqrt、abs、lower、upper、trim、length、concat 等等
聚合函数 多个输入对一一个输出,比如 sum、avg、count、max、min 等等
什么是 protobuf
protobuf是一种格式 是一种面向高效传输和跨语言通信的序列化格式。
JSON:
文本格式
人能直接看懂
调试方便
体积相对大
解析通常慢一点
不强制 schema
protobuf:
二进制格式
人不能直接看懂
体积小
解析快
强依赖 .proto schema
更适合高性能/跨语言 RPC
如何写最小例子
- 定义表达式
- 编码
- 解码
- 验证
完整例子代码
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! See `main.rs` for how to run it.
//!
//! This example demonstrates the smallest useful round trip for a physical
//! [`ScalarFunctionExpr`]:
//!
//! 1. Build a physical expression for `sqrt(a)`.
//! 2. Serialize it to a protobuf `PhysicalExprNode`.
//! 3. Deserialize it back to a physical expression.
//! 4. Evaluate both expressions against the same batch.
use std::sync::Arc;
use arrow::array::Float64Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::common::config::ConfigOptions;
use datafusion::common::{DataFusionError, Result};
use datafusion::physical_expr::ScalarFunctionExpr;
use datafusion::physical_plan::PhysicalExpr;
use datafusion::physical_plan::expressions::Column;
use datafusion::prelude::SessionContext;
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
use datafusion_proto::protobuf::physical_expr_node::ExprType;
pub async fn scalar_function_expr() -> Result<()> {
println!("=== ScalarFunctionExpr Proto Round Trip Example ===\n");
// 定义输入数据长什么样
// 我们有一张输入表/输入 batch,它只有一列:
// 列名: a
// 类型: Float64
// 是否允许 NULL: false
// 也就是类似 SQL 里的:
// CREATE TABLE t (
// a DOUBLE NOT NULL
// );
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)]));
// 对输入 batch 里的第 0 列 a 求平方根。
let expr = Arc::new(ScalarFunctionExpr::try_new(
datafusion::functions::math::sqrt(),
vec![Arc::new(Column::new("a", 0))],
schema.as_ref(),
Arc::new(ConfigOptions::new()),
)?) as Arc<dyn PhysicalExpr>;
println!("Step 1: Built physical expression: {expr}");
// 这段是在把刚刚造好的物理表达式 sqrt(a@0) 转成 proto 结构,序列化
let codec = DefaultPhysicalExtensionCodec {};
let proto = serialize_physical_expr(&expr, &codec)?;
let Some(ExprType::ScalarUdf(scalar_udf)) = proto.expr_type.as_ref() else {
return Err(DataFusionError::Execution(
"Expected ScalarUdf proto node".to_string(),
));
};
println!(
"Step 2: Serialized to proto: name={}, args={}, has_fun_definition={}",
scalar_udf.name,
scalar_udf.args.len(),
scalar_udf.fun_definition.is_some()
);
// 反序列化
let ctx = SessionContext::new();
let decoded_expr = parse_physical_expr(&proto, &ctx.task_ctx(), &schema, &codec)?;
println!("Step 3: Deserialized expression: {decoded_expr}");
// 验证反序列化的表达式、执行结果是不是和原来一样
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(Float64Array::from(vec![4.0, 9.0, 16.0]))],
)?;
let original = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
let decoded = decoded_expr
.evaluate(&batch)?
.into_array(batch.num_rows())?;
let original = original
.as_any()
.downcast_ref::<Float64Array>()
.ok_or_else(|| {
DataFusionError::Execution("Expected Float64 result array".to_string())
})?;
let decoded = decoded
.as_any()
.downcast_ref::<Float64Array>()
.ok_or_else(|| {
DataFusionError::Execution("Expected Float64 result array".to_string())
})?;
assert_eq!(original, decoded);
println!("Step 4: Evaluated both expressions successfully");
println!(" input: [4.0, 9.0, 16.0]");
println!(" output: {decoded:?}");
Ok(())
}
执行流程
ScalarFunctionExpr(sqrt(a))
↓ serialize_physical_expr
PhysicalExprNode::ScalarUdf
↓ parse_physical_expr
ScalarFunctionExpr(sqrt(a))
↓ evaluate
Float64Array [2.0, 3.0, 4.0]
运行命令和输出
/Users/zhengpeng/.cargo/bin/cargo run --color=always --example proto --profile dev --manifest-path /Users/zhengpeng/Source/Code/Rust-Code/Github/datafusion/datafusion-examples/Cargo.toml -- scalar_function_expr
Finished `dev` profile [unoptimized + debuginfo] target(s) in 0.19s
Running `target/debug/examples/proto scalar_function_expr`
Usage: cargo run --example proto -- [all|composed_extension_codec|expression_deduplication|scalar_function_expr]
=== ScalarFunctionExpr Proto Round Trip Example ===
Step 1: Built physical expression: sqrt(a@0)
Step 2: Serialized to proto: name=sqrt, args=1, has_fun_definition=false
Step 3: Deserialized expression: sqrt(a@0)
Step 4: Evaluated both expressions successfully
input: [4.0, 9.0, 16.0]
output: PrimitiveArray<Float64>
[
2.0,
3.0,
4.0,
]
Process finished with exit code 0
说明
例子里的 serialize_physical_expr 是把 Rust 里的 PhysicalExpr 转成 protobuf 的 Rust struct,还没有进一步 encode 成二进制 bytes; 如果要网络传输,还需要 prost::Message::encode。