目录

ScalarFunctionExpr protobuf 序列化/反序列化最小例子

ScalarFunctionExpr protobuf 序列化/反序列化最小例子

什么是 ScalarFunctionExpr

ScalarFunctionExprDatafusion 物理执行层对一次标量函数调用的完整描述。

它描述的是

  • 函数名称是什么
  • 函数参数有哪些
  • 函数返回,包括返回类型是什么,是否可空等等
  • 执行过程中的配置信息

比如 sqrt(a),平方根 它的物理表达式大概长

ScalarFunctionExpr
function: sqrt 的实现
name: "sqrt"
args:
- Column("a", index=0)
return_type: Float64
nullable: false/true

标量函数和聚合函数

标量函数 参数和值一一对应,比如 sqrt、abs、lower、upper、trim、length、concat 等等

聚合函数 多个输入对一一个输出,比如 sum、avg、count、max、min 等等

什么是 protobuf

protobuf是一种格式 是一种面向高效传输和跨语言通信的序列化格式。

JSON:
文本格式
人能直接看懂
调试方便
体积相对大
解析通常慢一点
不强制 schema

protobuf:
二进制格式
人不能直接看懂
体积小
解析快
强依赖 .proto schema
更适合高性能/跨语言 RPC

如何写最小例子

  • 定义表达式
  • 编码
  • 解码
  • 验证

完整例子代码


// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! See `main.rs` for how to run it.
//!
//! This example demonstrates the smallest useful round trip for a physical
//! [`ScalarFunctionExpr`]:
//!
//! 1. Build a physical expression for `sqrt(a)`.
//! 2. Serialize it to a protobuf `PhysicalExprNode`.
//! 3. Deserialize it back to a physical expression.
//! 4. Evaluate both expressions against the same batch.

use std::sync::Arc;

use arrow::array::Float64Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::common::config::ConfigOptions;
use datafusion::common::{DataFusionError, Result};
use datafusion::physical_expr::ScalarFunctionExpr;
use datafusion::physical_plan::PhysicalExpr;
use datafusion::physical_plan::expressions::Column;
use datafusion::prelude::SessionContext;
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
use datafusion_proto::protobuf::physical_expr_node::ExprType;

pub async fn scalar_function_expr() -> Result<()> {
    println!("=== ScalarFunctionExpr Proto Round Trip Example ===\n");

    // 定义输入数据长什么样
    // 我们有一张输入表/输入 batch,它只有一列:
    // 列名: a
    // 类型: Float64
    // 是否允许 NULL: false
    // 也就是类似 SQL 里的:
    // CREATE TABLE t (
    //   a DOUBLE NOT NULL
    // );

    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)]));

    // 对输入 batch 里的第 0 列 a 求平方根。
    let expr = Arc::new(ScalarFunctionExpr::try_new(
        datafusion::functions::math::sqrt(),
        vec![Arc::new(Column::new("a", 0))],
        schema.as_ref(),
        Arc::new(ConfigOptions::new()),
    )?) as Arc<dyn PhysicalExpr>;

    println!("Step 1: Built physical expression: {expr}");

    // 这段是在把刚刚造好的物理表达式 sqrt(a@0) 转成 proto 结构,序列化
    let codec = DefaultPhysicalExtensionCodec {};
    let proto = serialize_physical_expr(&expr, &codec)?;
    let Some(ExprType::ScalarUdf(scalar_udf)) = proto.expr_type.as_ref() else {
        return Err(DataFusionError::Execution(
            "Expected ScalarUdf proto node".to_string(),
        ));
    };

    println!(
        "Step 2: Serialized to proto: name={}, args={}, has_fun_definition={}",
        scalar_udf.name,
        scalar_udf.args.len(),
        scalar_udf.fun_definition.is_some()
    );

    // 反序列化
    let ctx = SessionContext::new();
    let decoded_expr = parse_physical_expr(&proto, &ctx.task_ctx(), &schema, &codec)?;

    println!("Step 3: Deserialized expression: {decoded_expr}");

    // 验证反序列化的表达式、执行结果是不是和原来一样
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Float64Array::from(vec![4.0, 9.0, 16.0]))],
    )?;

    let original = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
    let decoded = decoded_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())?;
    let original = original
        .as_any()
        .downcast_ref::<Float64Array>()
        .ok_or_else(|| {
            DataFusionError::Execution("Expected Float64 result array".to_string())
        })?;
    let decoded = decoded
        .as_any()
        .downcast_ref::<Float64Array>()
        .ok_or_else(|| {
            DataFusionError::Execution("Expected Float64 result array".to_string())
        })?;

    assert_eq!(original, decoded);

    println!("Step 4: Evaluated both expressions successfully");
    println!("  input:  [4.0, 9.0, 16.0]");
    println!("  output: {decoded:?}");

    Ok(())
}

执行流程


ScalarFunctionExpr(sqrt(a))
↓ serialize_physical_expr
PhysicalExprNode::ScalarUdf
↓ parse_physical_expr
ScalarFunctionExpr(sqrt(a))
↓ evaluate
Float64Array [2.0, 3.0, 4.0]

运行命令和输出


/Users/zhengpeng/.cargo/bin/cargo run --color=always --example proto --profile dev --manifest-path /Users/zhengpeng/Source/Code/Rust-Code/Github/datafusion/datafusion-examples/Cargo.toml -- scalar_function_expr
    Finished `dev` profile [unoptimized + debuginfo] target(s) in 0.19s
     Running `target/debug/examples/proto scalar_function_expr`
Usage: cargo run --example proto -- [all|composed_extension_codec|expression_deduplication|scalar_function_expr]
=== ScalarFunctionExpr Proto Round Trip Example ===

Step 1: Built physical expression: sqrt(a@0)
Step 2: Serialized to proto: name=sqrt, args=1, has_fun_definition=false
Step 3: Deserialized expression: sqrt(a@0)
Step 4: Evaluated both expressions successfully
  input:  [4.0, 9.0, 16.0]
  output: PrimitiveArray<Float64>
[
  2.0,
  3.0,
  4.0,
]

Process finished with exit code 0

说明

例子里的 serialize_physical_expr 是把 Rust 里的 PhysicalExpr 转成 protobuf 的 Rust struct,还没有进一步 encode 成二进制 bytes; 如果要网络传输,还需要 prost::Message::encode。