Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(udf): simplify WASM UDF implementation by moving some logic to arrow-udf-wasm #20239

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
399 changes: 213 additions & 186 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ development = ["expect-test", "pretty_assertions"]

[workspace.dependencies]
foyer = { version = "0.14.0", features = ["tracing", "nightly"] }
mixtrics ={ version = "0.0.2", features = ["prometheus"] }
mixtrics = { version = "0.0.2", features = ["prometheus"] }
apache-avro = { git = "https://github.com/risingwavelabs/avro", rev = "25113ba88234a9ae23296e981d8302c290fdaa4b", features = [
"snappy",
"zstandard",
Expand Down Expand Up @@ -165,7 +165,11 @@ opendal = "0.49"
# used only by arrow-udf-flight
arrow-flight = "53"
arrow-udf-js = "0.5"
arrow-udf-wasm = { version = "0.4", features = ["build"] }
# TODO(): will change back to released version before merging
# arrow-udf-wasm = { version = "0.4", features = ["build"] }
arrow-udf-wasm = { git = "https://github.com/arrow-udf/arrow-udf", branch = "rc/wasm-find-function", features = [
"build",
] }
arrow-udf-python = "0.4"
arrow-udf-flight = "0.4"
clap = { version = "4", features = ["cargo", "derive", "env"] }
Expand Down
3 changes: 2 additions & 1 deletion proto/catalog.proto
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@
data.DataType return_type = 6;
string language = 7;
optional string link = 8;
optional string identifier = 10;
// The function name in the runtime / on the remote side that is bound to the UDF created in RisingWave.
optional string name_in_runtime = 10;

Check failure on line 328 in proto/catalog.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "10" with name "name_in_runtime" on message "Function" changed option "json_name" from "identifier" to "nameInRuntime".

Check failure on line 328 in proto/catalog.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "10" on message "Function" changed name from "identifier" to "name_in_runtime".
// The source code of the function.
optional string body = 14;
// The zstd-compressed binary of the function.
Expand Down
26 changes: 21 additions & 5 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,20 @@ message WindowFunction {
// Note: due to historic reasons, UserDefinedFunction is a oneof variant parallel to FunctionCall,
// while UserDefinedFunctionMetadata is embedded as a field in TableFunction and AggCall.

enum UdfExprVersion {
// Versions before introducing this enum.
UDF_EXPR_VERSION_UNSPECIFIED = 0;
// Begin from this version, we re-interpret `identifier` as `name_in_runtime`.
UDF_EXPR_VERSION_NAME_IN_RUNTIME = 1;

// IMPORTANT:
// Don't forget to change `UdfExprVersion::LATEST` in `prost/src/lib.rs` to the latest version
// when adding new versions to this enum.

// Only used for tests.
UDF_EXPR_VERSION_MAX = 2147483647;
}

message UserDefinedFunction {
repeated ExprNode children = 1;
string name = 2;
Expand All @@ -605,12 +619,9 @@ message UserDefinedFunction {
string language = 4;
// The link to the external function service.
optional string link = 5;
// An unique identifier to the function.
// - If `link` is not empty, the name of the function in the external function service.
// - If `language` is `rust` or `wasm`, the name of the function in the wasm binary file.
// - If `language` is `javascript`, the name of the function.
// This is re-interpreted as `name_in_runtime`.
optional string identifier = 6;
// - If `language` is `javascript`, the source code of the function.
// - If `language` is `javascript` or `python`, the source code of the function.
optional string body = 7;
// - If `language` is `rust` or `wasm`, the zstd-compressed wasm binary.
optional bytes compressed_binary = 10;
Expand All @@ -619,6 +630,8 @@ message UserDefinedFunction {
optional string runtime = 11;
reserved 12;
reserved "function_type";

UdfExprVersion version = 1000;
}

// Additional information for user defined table/aggregate functions.
Expand All @@ -628,10 +641,13 @@ message UserDefinedFunctionMetadata {
data.DataType return_type = 13;
string language = 4;
optional string link = 5;
// This is re-interpreted as `name_in_runtime`.
optional string identifier = 6;
optional string body = 7;
optional bytes compressed_binary = 10;
optional string runtime = 11;
reserved 12;
reserved "function_type";

UdfExprVersion version = 1000;
}
31 changes: 24 additions & 7 deletions src/expr/core/src/aggregate/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ use risingwave_common::array::arrow::arrow_schema_udf::{Field, Fields, Schema, S
use risingwave_common::array::arrow::{UdfArrowConvert, UdfFromArrow, UdfToArrow};
use risingwave_common::array::Op;
use risingwave_common::bitmap::Bitmap;
use risingwave_pb::expr::PbUserDefinedFunctionMetadata;
use risingwave_pb::expr::{PbUdfExprVersion, PbUserDefinedFunctionMetadata};

use super::*;
use crate::sig::{UdfImpl, UdfKind, UdfOptions};
use crate::sig::{BuildOptions, UdfImpl, UdfKind};

#[derive(Debug)]
pub struct UserDefinedAggregateFunction {
Expand Down Expand Up @@ -123,19 +123,36 @@ pub fn new_user_defined(
return_type: &DataType,
udf: &PbUserDefinedFunctionMetadata,
) -> Result<BoxedAggregateFunction> {
let identifier = udf.get_identifier()?;
let arg_types = udf.arg_types.iter().map(|t| t.into()).collect::<Vec<_>>();
let language = udf.language.as_str();
let runtime = udf.runtime.as_deref();
let link = udf.link.as_deref();

// `identifier` field is re-interpreted as `name_in_runtime`.
if udf.version() < PbUdfExprVersion::NameInRuntime {
assert_ne!(
language, "rust",
"Rust UDAF was not supported yet before this version"
);
assert_ne!(
language, "wasm",
"WASM UDAF was not supported yet before this version"
);
}
let name_in_runtime = udf
.identifier
.as_ref()
.expect("SQL UDF won't get here, other UDFs must have `name_in_runtime`");

let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn;
let runtime = build_fn(UdfOptions {
let runtime = build_fn(BuildOptions {
kind: UdfKind::Aggregate,
body: udf.body.as_deref(),
compressed_binary: udf.compressed_binary.as_deref(),
link: udf.link.as_deref(),
identifier,
name_in_runtime,
arg_names: &udf.arg_names,
arg_types: &arg_types,
return_type,
always_retry_on_network_error: false,
})
Expand All @@ -145,9 +162,9 @@ pub fn new_user_defined(
// so we can assume that the runtime is not legacy
let arrow_convert = UdfArrowConvert::default();
let arg_schema = Arc::new(Schema::new(
udf.arg_types
arg_types
.iter()
.map(|t| arrow_convert.to_arrow_field("", &DataType::from(t)))
.map(|t| arrow_convert.to_arrow_field("", t))
.try_collect::<_, Fields, _>()?,
));

Expand Down
40 changes: 29 additions & 11 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ use risingwave_common::monitor::GLOBAL_METRICS_REGISTRY;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use risingwave_expr::expr_context::FRAGMENT_ID;
use risingwave_pb::expr::ExprNode;
use risingwave_pb::expr::{ExprNode, PbUdfExprVersion};

use super::{BoxedExpression, Build};
use crate::expr::Expression;
use crate::sig::{UdfImpl, UdfKind, UdfOptions};
use crate::sig::{BuildOptions, UdfImpl, UdfKind};
use crate::{bail, ExprError, Result};

#[derive(Debug)]
Expand Down Expand Up @@ -168,21 +168,39 @@ impl Build for UserDefinedFunction {
) -> Result<Self> {
let return_type = DataType::from(prost.get_return_type().unwrap());
let udf = prost.get_rex_node().unwrap().as_udf().unwrap();
let identifier = udf.get_identifier()?;
let name = udf.get_name();
let arg_types = udf.arg_types.iter().map(|t| t.into()).collect::<Vec<_>>();

let language = udf.language.as_str();
let runtime = udf.runtime.as_deref();
let link = udf.link.as_deref();

let name_in_runtime = if udf.version() < PbUdfExprVersion::NameInRuntime {
if language == "rust" || language == "wasm" {
// The `identifier` value of Rust and WASM UDF before `NameInRuntime`
// is not used any more. The real bound function name should be the same
// as `name`.
Some(name)
} else {
// `identifier`s of other UDFs already mean `name_in_runtime` before `NameInRuntime`.
udf.identifier.as_ref()
}
} else {
// after `PbUdfExprVersion::NameInRuntime`, `identifier` means `name_in_runtime`
udf.identifier.as_ref()
}
.expect("SQL UDF won't get here, other UDFs must have `name_in_runtime`");

// lookup UDF builder
let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn;
let runtime = build_fn(UdfOptions {
let runtime = build_fn(BuildOptions {
kind: UdfKind::Scalar,
body: udf.body.as_deref(),
compressed_binary: udf.compressed_binary.as_deref(),
link: udf.link.as_deref(),
identifier,
name_in_runtime,
arg_names: &udf.arg_names,
arg_types: &arg_types,
return_type: &return_type,
always_retry_on_network_error: udf.always_retry_on_network_error,
})
Expand All @@ -202,7 +220,7 @@ impl Build for UserDefinedFunction {
let metrics = GLOBAL_METRICS.with_label_values(
link.unwrap_or(""),
language,
identifier,
name,
// batch query does not have a fragment_id
&FRAGMENT_ID::try_with(ToOwned::to_owned)
.unwrap_or(0)
Expand All @@ -211,12 +229,12 @@ impl Build for UserDefinedFunction {

Ok(Self {
children: udf.children.iter().map(build_child).try_collect()?,
arg_types: udf.arg_types.iter().map(|t| t.into()).collect(),
arg_types,
return_type,
arg_schema,
runtime,
arrow_convert,
span: format!("udf_call({})", identifier).into(),
span: format!("udf_call({})", name).into(),
metrics,
})
}
Expand Down Expand Up @@ -348,15 +366,15 @@ impl MetricsVec {
&self,
link: &str,
language: &str,
identifier: &str,
name: &str,
fragment_id: &str,
) -> Metrics {
// generate an unique id for each instance
static NEXT_INSTANCE_ID: AtomicU64 = AtomicU64::new(0);
let instance_id = NEXT_INSTANCE_ID.fetch_add(1, Ordering::Relaxed).to_string();

let labels = &[link, language, identifier, fragment_id];
let labels5 = &[link, language, identifier, fragment_id, &instance_id];
let labels = &[link, language, name, fragment_id];
let labels5 = &[link, language, name, fragment_id, &instance_id];

Metrics {
success_count: self.success_count.with_guarded_label_values(labels),
Expand Down
20 changes: 14 additions & 6 deletions src/expr/core/src/sig/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! See expr/impl/src/udf for the implementations.

use anyhow::{bail, Context, Result};
use educe::Educe;
use enum_as_inner::EnumAsInner;
use futures::stream::BoxStream;
use risingwave_common::array::arrow::arrow_array_udf::{ArrayRef, BooleanArray, RecordBatch};
Expand Down Expand Up @@ -65,44 +66,51 @@ pub struct UdfImplDescriptor {
/// Creates a function from options.
///
/// This function will be called when `create function` statement is executed on the frontend.
pub create_fn: fn(opts: CreateFunctionOptions<'_>) -> Result<CreateFunctionOutput>,
pub create_fn: fn(opts: CreateOptions<'_>) -> Result<CreateFunctionOutput>,

/// Builds UDF runtime from verified options.
///
/// This function will be called before the UDF is executed on the backend.
pub build_fn: fn(opts: UdfOptions<'_>) -> Result<Box<dyn UdfImpl>>,
pub build_fn: fn(opts: BuildOptions<'_>) -> Result<Box<dyn UdfImpl>>,
}

/// Options for creating a function.
///
/// These information are parsed from `CREATE FUNCTION` statement.
/// Implementations should verify the options and return a `CreateFunctionOutput` in `create_fn`.
pub struct CreateFunctionOptions<'a> {
pub struct CreateOptions<'a> {
pub kind: UdfKind,
/// The function name registered in RisingWave.
pub name: &'a str,
pub arg_names: &'a [String],
pub arg_types: &'a [DataType],
pub return_type: &'a DataType,
/// The function name on the remote side / in the source code, currently only used for external UDF.
pub as_: Option<&'a str>,
pub using_link: Option<&'a str>,
pub using_base64_decoded: Option<&'a [u8]>,
}

/// Output of creating a function.
pub struct CreateFunctionOutput {
pub identifier: String,
/// The name for identifying the function in the UDF runtime.
pub name_in_runtime: String,
pub body: Option<String>,
pub compressed_binary: Option<Vec<u8>>,
}

/// Options for building a UDF runtime.
pub struct UdfOptions<'a> {
#[derive(Educe)]
#[educe(Debug)]
pub struct BuildOptions<'a> {
pub kind: UdfKind,
pub body: Option<&'a str>,
#[educe(Debug(ignore))]
pub compressed_binary: Option<&'a [u8]>,
pub link: Option<&'a str>,
pub identifier: &'a str,
pub name_in_runtime: &'a str,
pub arg_names: &'a [String],
pub arg_types: &'a [DataType],
pub return_type: &'a DataType,
pub always_retry_on_network_error: bool,
}
Expand Down
40 changes: 34 additions & 6 deletions src/expr/core/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ use risingwave_common::array::arrow::arrow_schema_udf::{Fields, Schema, SchemaRe
use risingwave_common::array::arrow::{UdfArrowConvert, UdfFromArrow, UdfToArrow};
use risingwave_common::array::I32Array;
use risingwave_common::bail;
use risingwave_pb::expr::PbUdfExprVersion;

use super::*;
use crate::sig::{UdfImpl, UdfKind, UdfOptions};
use crate::sig::{BuildOptions, UdfImpl, UdfKind};

#[derive(Debug)]
pub struct UserDefinedTableFunction {
Expand Down Expand Up @@ -123,21 +124,48 @@ impl UserDefinedTableFunction {
pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<BoxedTableFunction> {
let udf = prost.get_udf()?;

let identifier = udf.get_identifier()?;
let arg_types = udf.arg_types.iter().map(|t| t.into()).collect::<Vec<_>>();
let return_type = DataType::from(prost.get_return_type()?);

let language = udf.language.as_str();
let runtime = udf.runtime.as_deref();
let link = udf.link.as_deref();

let name_in_runtime = if udf.version() < PbUdfExprVersion::NameInRuntime {
if language == "rust" || language == "wasm" {
// The `identifier` value of Rust and WASM UDF before `NameInRuntime`
// is not used any more. And unfortunately, we don't have the original name
// in `PbUserDefinedFunctionMetadata`, so we need to extract the name from
// the old `identifier` value (e.g. `foo()->int32`).
let old_identifier = udf
.identifier
.as_ref()
.expect("Rust/WASM UDF must have identifier");
Some(
old_identifier
.split_once("(")
.expect("the old identifier must contain `(`")
.0,
)
} else {
// `identifier`s of other UDFs already mean `name_in_runtime` before `NameInRuntime`.
udf.identifier.as_deref()
}
} else {
// after `PbUdfExprVersion::NameInRuntime`, `identifier` means `name_in_runtime`
udf.identifier.as_deref()
}
.expect("SQL UDF won't get here, other UDFs must have `name_in_runtime`");

let build_fn = crate::sig::find_udf_impl(language, runtime, link)?.build_fn;
let runtime = build_fn(UdfOptions {
let runtime = build_fn(BuildOptions {
kind: UdfKind::Table,
body: udf.body.as_deref(),
compressed_binary: udf.compressed_binary.as_deref(),
link: udf.link.as_deref(),
identifier,
name_in_runtime,
arg_names: &udf.arg_names,
arg_types: &arg_types,
return_type: &return_type,
always_retry_on_network_error: false,
})
Expand All @@ -147,9 +175,9 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
legacy: runtime.is_legacy(),
};
let arg_schema = Arc::new(Schema::new(
udf.arg_types
arg_types
.iter()
.map(|t| arrow_convert.to_arrow_field("", &DataType::from(t)))
.map(|t| arrow_convert.to_arrow_field("", t))
.try_collect::<Fields>()?,
));

Expand Down
Loading
Loading