From 64848442f900f56b06c3953ee5b3cc6cd97b9bc7 Mon Sep 17 00:00:00 2001 From: Vincent Rischmann Date: Sun, 17 Apr 2022 00:52:31 +0200 Subject: work on supporting aggregate SQL functions --- sqlite.zig | 337 +++++++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 259 insertions(+), 78 deletions(-) diff --git a/sqlite.zig b/sqlite.zig index 52a28dd..ad92cb9 100644 --- a/sqlite.zig +++ b/sqlite.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const builtin = @import("builtin"); const build_options = @import("build_options"); const debug = std.debug; const io = std.io; @@ -601,6 +602,89 @@ pub const Db = struct { return Savepoint.init(self, name); } + // Helpers functions to implement SQLite functions. + + fn sliceFromValue(sqlite_value: *c.sqlite3_value) []const u8 { + const size = @intCast(usize, c.sqlite3_value_bytes(sqlite_value)); + + const value = c.sqlite3_value_text(sqlite_value); + debug.assert(value != null); // TODO(vincent): how do we handle this properly ? + + return value.?[0..size]; + } + + /// Sets the result of a function call in the context `ctx`. + /// + /// Determines at compile time which sqlite3_result_XYZ function to use based on the type of `result`. + fn setFunctionResult(ctx: ?*c.sqlite3_context, result: anytype) void { + const ResultType = @TypeOf(result); + + switch (ResultType) { + Text => c.sqlite3_result_text(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT), + Blob => c.sqlite3_result_blob(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT), + else => switch (@typeInfo(ResultType)) { + .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) { + c.sqlite3_result_int(ctx, result); + } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) { + c.sqlite3_result_int64(ctx, result); + } else { + @compileError("integer " ++ @typeName(ResultType) ++ " is not representable in sqlite"); + }, + .Float => c.sqlite3_result_double(ctx, result), + .Bool => c.sqlite3_result_int(ctx, if (result) 1 else 0), + .Array => |arr| switch (arr.child) { + u8 => c.sqlite3_result_blob(ctx, &result, arr.len, c.SQLITE_TRANSIENT), + else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), + }, + .Pointer => |ptr| switch (ptr.size) { + .Slice => switch (ptr.child) { + u8 => c.sqlite3_result_text(ctx, result.ptr, @intCast(c_int, result.len), c.SQLITE_TRANSIENT), + else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), + }, + else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), + }, + else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), + }, + } + } + + /// Sets a function argument using the provided value. + /// + /// Determines at compile time which sqlite3_value_XYZ function to use based on the type `ArgType`. + fn setFunctionArgument(comptime ArgType: type, arg: *ArgType, sqlite_value: *c.sqlite3_value) void { + switch (ArgType) { + Text => arg.*.data = sliceFromValue(sqlite_value), + Blob => arg.*.data = sliceFromValue(sqlite_value), + else => switch (@typeInfo(ArgType)) { + .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) { + const value = c.sqlite3_value_int(sqlite_value); + arg.* = @intCast(ArgType, value); + } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) { + const value = c.sqlite3_value_int64(sqlite_value); + arg.* = @intCast(ArgType, value); + } else { + @compileError("integer " ++ @typeName(ArgType) ++ " is not representable in sqlite"); + }, + .Float => { + const value = c.sqlite3_value_double(sqlite_value); + arg.* = @floatCast(ArgType, value); + }, + .Bool => { + const value = c.sqlite3_value_int(sqlite_value); + arg.* = value > 0; + }, + .Pointer => |ptr| switch (ptr.size) { + .Slice => switch (ptr.child) { + u8 => arg.* = sliceFromValue(sqlite_value), + else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), + }, + else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), + }, + else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), + }, + } + } + /// CreateFunctionFlag controls the flags used when creating a custom SQL function. /// See https://sqlite.org/c3ref/c_deterministic.html. /// @@ -614,6 +698,117 @@ pub const Db = struct { direct_only: bool = true, }; + /// Creates an aggregate SQLite function with the given name. + /// + /// When the SQLite function is called in a statement, `step_func` will be called for each row with the input arguments. + /// Each SQLite argument is converted to a Zig value according to the following rules: + /// * TEXT values can be either sqlite.Text or []const u8 + /// * BLOB values can be either sqlite.Blob or []const u8 + /// * INTEGER values can be any Zig integer + /// * REAL values can be any Zig float + /// + /// The final result of the SQL function call will be what `finalize_func` returns. + /// + /// The context `my_ctx` contains the state necessary to perform the aggregation, both `step_func` and `finalize_func` must have at least the first argument of type ContextType. + /// + pub fn createAggregateFunction(self: *Self, func_name: [:0]const u8, my_ctx: anytype, comptime step_func: anytype, comptime finalize_func: anytype, comptime create_flags: CreateFunctionFlag) Error!void { + // Check that the context type is usable + const ContextType = @TypeOf(my_ctx); + switch (@typeInfo(ContextType)) { + .Pointer => |ptr_info| switch (ptr_info.size) { + .One => {}, + else => @compileError("cannot use context of type " ++ @typeName(ContextType) ++ ", must be a single-item pointer"), + }, + else => @compileError("cannot use context of type " ++ @typeName(ContextType) ++ ", must be a single-item pointer"), + } + + // Validate the step function + + const StepFuncType = @TypeOf(step_func); + + const step_fn_info = switch (@typeInfo(StepFuncType)) { + .Fn => |fn_info| fn_info, + else => @compileError("cannot use step_fn, expecting a function"), + }; + if (step_fn_info.is_generic) @compileError("step_fn function can't be generic"); + if (step_fn_info.is_var_args) @compileError("step_fn function can't be variadic"); + + const StepFuncArgTuple = std.meta.ArgsTuple(StepFuncType); + + // subtract one because the user-provided function always takes an additional context + const step_func_args_len = step_fn_info.args.len - 1; + + // Validate the finalize function + + const FinalizeFuncType = @TypeOf(finalize_func); + + const finalize_fn_info = switch (@typeInfo(FinalizeFuncType)) { + .Fn => |fn_info| fn_info, + else => @compileError("cannot use finalize_fn, expecting a function"), + }; + if (finalize_fn_info.args.len != 1) @compileError("finalize_fn must take exactly one argument"); + + const FinalizeFuncArgTuple = std.meta.ArgsTuple(FinalizeFuncType); + + // + + var flags: c_int = c.SQLITE_UTF8; + if (create_flags.deterministic) { + flags |= c.SQLITE_DETERMINISTIC; + } + if (create_flags.direct_only) { + flags |= c.SQLITE_DIRECTONLY; + } + + const result = c.sqlite3_create_function_v2( + self.db, + func_name, + step_func_args_len, + flags, + my_ctx, + null, // xFunc + struct { + fn xStep(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { + debug.assert(argc == step_func_args_len); + + const sqlite_args = argv.?[0..step_func_args_len]; + + var fn_args: StepFuncArgTuple = undefined; + + // First argument is always the user-provided context + fn_args[0] = @ptrCast(ContextType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx))); + + comptime var i: usize = 0; + inline while (i < step_func_args_len) : (i += 1) { + // we add 1 because we need to ignore the first argument which is the user-provided context, not a SQLite value. + const arg = step_fn_info.args[i + 1]; + const arg_ptr = &fn_args[i + 1]; + + const ArgType = arg.arg_type.?; + setFunctionArgument(ArgType, arg_ptr, sqlite_args[i].?); + } + + @call(.{}, step_func, fn_args); + } + }.xStep, + struct { + fn xFinal(ctx: ?*c.sqlite3_context) callconv(.C) void { + var fn_args: FinalizeFuncArgTuple = undefined; + // Only one argument, the user-provided context + fn_args[0] = @ptrCast(ContextType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx))); + + const result = @call(.{}, finalize_func, fn_args); + + setFunctionResult(ctx, result); + } + }.xFinal, + null, + ); + if (result != c.SQLITE_OK) { + return errors.errorFromResultCode(result); + } + } + /// Creates a scalar SQLite function with the given name. /// /// When the SQLite function is called in a statement, `func` will be called with the input arguments. @@ -652,81 +847,6 @@ pub const Db = struct { flags, null, struct { - fn sliceFromValue(sqlite_value: *c.sqlite3_value) []const u8 { - const size = @intCast(usize, c.sqlite3_value_bytes(sqlite_value)); - - const value = c.sqlite3_value_text(sqlite_value); - debug.assert(value != null); // TODO(vincent): how do we handle this properly ? - - return value.?[0..size]; - } - - fn bindValue(comptime ArgType: type, arg: *ArgType, sqlite_value: *c.sqlite3_value) void { - switch (ArgType) { - Text => arg.*.data = sliceFromValue(sqlite_value), - Blob => arg.*.data = sliceFromValue(sqlite_value), - else => switch (@typeInfo(ArgType)) { - .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) { - const value = c.sqlite3_value_int(sqlite_value); - arg.* = @intCast(ArgType, value); - } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) { - const value = c.sqlite3_value_int64(sqlite_value); - arg.* = @intCast(ArgType, value); - } else { - @compileError("integer " ++ @typeName(ArgType) ++ " is not representable in sqlite"); - }, - .Float => { - const value = c.sqlite3_value_double(sqlite_value); - arg.* = @floatCast(ArgType, value); - }, - .Bool => { - const value = c.sqlite3_value_int(sqlite_value); - arg.* = value > 0; - }, - .Pointer => |ptr| switch (ptr.size) { - .Slice => switch (ptr.child) { - u8 => arg.* = sliceFromValue(sqlite_value), - else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), - }, - else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), - }, - else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), - }, - } - } - - fn setResult(ctx: ?*c.sqlite3_context, result: anytype) void { - const ResultType = @TypeOf(result); - - switch (ResultType) { - Text => c.sqlite3_result_text(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT), - Blob => c.sqlite3_result_blob(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT), - else => switch (@typeInfo(ResultType)) { - .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) { - c.sqlite3_result_int(ctx, result); - } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) { - c.sqlite3_result_int64(ctx, result); - } else { - @compileError("integer " ++ @typeName(ResultType) ++ " is not representable in sqlite"); - }, - .Float => c.sqlite3_result_double(ctx, result), - .Bool => c.sqlite3_result_int(ctx, if (result) 1 else 0), - .Array => |arr| switch (arr.child) { - u8 => c.sqlite3_result_blob(ctx, &result, arr.len, c.SQLITE_TRANSIENT), - else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), - }, - .Pointer => |ptr| switch (ptr.size) { - .Slice => switch (ptr.child) { - u8 => c.sqlite3_result_text(ctx, result.ptr, @intCast(c_int, result.len), c.SQLITE_TRANSIENT), - else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), - }, - else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), - }, - else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), - }, - } - } - fn xFunc(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { debug.assert(argc == fn_info.args.len); @@ -735,13 +855,12 @@ pub const Db = struct { var fn_args: ArgTuple = undefined; inline for (fn_info.args) |arg, i| { const ArgType = arg.arg_type.?; - - bindValue(ArgType, &fn_args[i], sqlite_args[i].?); + setFunctionArgument(ArgType, &fn_args[i], sqlite_args[i].?); } const result = @call(.{}, func, fn_args); - setResult(ctx, result); + setFunctionResult(ctx, result); } }.xFunc, null, @@ -3477,6 +3596,68 @@ test "sqlite: create scalar function" { } } +test "sqlite: create aggregate function" { + // TODO(vincent): fix this, panics on incorrect pointer alignment when casting the SQLite user data to the context type + // in the xStep function. + if (builtin.cpu.arch.isAARCH64()) return error.SkipZigTest; + + var db = try getTestDb(); + defer db.deinit(); + + var rand = std.rand.DefaultPrng.init(@intCast(u64, std.time.milliTimestamp())); + + // Create an aggregate function working with a MyContext + + const MyContext = struct { + sum: u32, + }; + var my_ctx = MyContext{ .sum = 0 }; + + try db.createAggregateFunction( + "mySum", + &my_ctx, + struct { + fn step(ctx: *MyContext, input: u32) void { + ctx.sum += input; + } + }.step, + struct { + fn finalize(ctx: *MyContext) u32 { + return ctx.sum; + } + }.finalize, + .{}, + ); + + // Initialize some data + + try db.exec("CREATE TABLE view(id integer PRIMARY KEY, nb integer)", .{}, .{}); + var i: usize = 0; + var exp: usize = 0; + while (i < 20) : (i += 1) { + const val = rand.random().intRangeAtMost(u32, 0, 5205905); + exp += val; + + try db.exec("INSERT INTO view(nb) VALUES(?{u32})", .{}, .{val}); + } + + // Get the sum and check the result + + var diags = Diagnostics{}; + const result = db.one( + usize, + "SELECT mySum(nb) FROM view", + .{ .diags = &diags }, + .{}, + ) catch |err| { + debug.print("err: {}\n", .{diags}); + return err; + }; + + try testing.expect(result != null); + try testing.expectEqual(@as(usize, exp), result.?); +} + test "sqlite: empty slice" { var arena = std.heap.ArenaAllocator.init(testing.allocator); defer arena.deinit(); -- cgit v1.2.3