From 9d011583eaf01be6bdc8bd3b777fd7e3e2bbf572 Mon Sep 17 00:00:00 2001 From: Vincent Rischmann Date: Sat, 25 Jun 2022 01:35:34 +0200 Subject: add a way to get the aggregate context with createAggregateFunction The old way of working was that we always passed the user context as first argument to both `step` and `finalize` functions and the caller had no way of getting the aggregate context from SQLite (http://www3.sqlite.org/c3ref/aggregate_context.html). Now both `step` and `finalize` functions must have a first argument of type `FunctionContext`: fn step(fctx: FunctionContext, input: u32) void { var ctx = fctx.aggregateContext(*u32) orelse return; ctx.* += input; } fn finalize(ctx: *u32) u32 { var ctx = fctx.aggregateContext(*u32) orelse return 0; return ctx.sum; } Fixes #89 --- sqlite.zig | 214 ++++++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 161 insertions(+), 53 deletions(-) (limited to 'sqlite.zig') diff --git a/sqlite.zig b/sqlite.zig index ba117c4..e4757a2 100644 --- a/sqlite.zig +++ b/sqlite.zig @@ -716,6 +716,8 @@ pub const Db = struct { /// Creates an aggregate SQLite function with the given name. /// + /// `step_func` and `finalize_func` must be two functions. The first argument of both functions _must_ be of the type FunctionContext. + /// /// When the SQLite function is called in a statement, `step_func` will be called for each row with the input arguments. /// Each SQLite argument is converted to a Zig value according to the following rules: /// * TEXT values can be either sqlite.Text or []const u8 @@ -724,47 +726,33 @@ pub const Db = struct { /// * REAL values can be any Zig float /// /// The final result of the SQL function call will be what `finalize_func` returns. - /// - /// The context `my_ctx` contains the state necessary to perform the aggregation, both `step_func` and `finalize_func` must have at least the first argument of type ContextType. - /// - pub fn createAggregateFunction(self: *Self, func_name: [:0]const u8, my_ctx: anytype, comptime step_func: anytype, comptime finalize_func: anytype, comptime create_flags: CreateFunctionFlag) Error!void { - // Check that the context type is usable - const ContextPtrType = @TypeOf(my_ctx); - const ContextType = switch (@typeInfo(ContextPtrType)) { - .Pointer => |ptr_info| switch (ptr_info.size) { - .One => ptr_info.child, - else => @compileError("cannot use context of type " ++ @typeName(ContextPtrType) ++ ", must be a single-item pointer"), - }, - else => @compileError("cannot use context of type " ++ @typeName(ContextPtrType) ++ ", must be a single-item pointer"), - }; - - // Validate the step function - - const StepFuncType = @TypeOf(step_func); + pub fn createAggregateFunction(self: *Self, comptime name: [:0]const u8, user_ctx: anytype, comptime step_func: anytype, comptime finalize_func: anytype, comptime create_flags: CreateFunctionFlag) Error!void { + // Validate the functions - const step_fn_info = switch (@typeInfo(StepFuncType)) { + const step_fn_info = switch (@typeInfo(@TypeOf(step_func))) { .Fn => |fn_info| fn_info, - else => @compileError("cannot use step_fn, expecting a function"), + else => @compileError("cannot use func, expecting a function"), }; - if (step_fn_info.is_generic) @compileError("step_fn function can't be generic"); - if (step_fn_info.is_var_args) @compileError("step_fn function can't be variadic"); - - const StepFuncArgTuple = std.meta.ArgsTuple(StepFuncType); - - // subtract one because the user-provided function always takes an additional context - const step_func_args_len = step_fn_info.args.len - 1; + if (step_fn_info.is_generic) @compileError("step function can't be generic"); + if (step_fn_info.is_var_args) @compileError("step function can't be variadic"); - // Validate the finalize function - - const FinalizeFuncType = @TypeOf(finalize_func); - - const finalize_fn_info = switch (@typeInfo(FinalizeFuncType)) { + const finalize_fn_info = switch (@typeInfo(@TypeOf(finalize_func))) { .Fn => |fn_info| fn_info, - else => @compileError("cannot use finalize_fn, expecting a function"), + else => @compileError("cannot use func, expecting a function"), }; - if (finalize_fn_info.args.len != 1) @compileError("finalize_fn must take exactly one argument"); + if (finalize_fn_info.args.len != 1) @compileError("finalize function must take exactly one argument"); + if (finalize_fn_info.is_generic) @compileError("finalize function can't be generic"); + if (finalize_fn_info.is_var_args) @compileError("finalize function can't be variadic"); + + if (step_fn_info.args[0].arg_type.? != finalize_fn_info.args[0].arg_type.?) { + @compileError("both step and finalize functions must have the same first argument and it must be a FunctionContext"); + } + if (step_fn_info.args[0].arg_type.? != FunctionContext) { + @compileError("both step and finalize functions must have a first argument of type FunctionContext"); + } - const FinalizeFuncArgTuple = std.meta.ArgsTuple(FinalizeFuncType); + // subtract the context argument + const real_args_len = step_fn_info.args.len - 1; // @@ -772,42 +760,43 @@ pub const Db = struct { const result = c.sqlite3_create_function_v2( self.db, - func_name, - step_func_args_len, + name, + real_args_len, flags, - my_ctx, + user_ctx, null, // xFunc struct { fn xStep(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { - debug.assert(argc == step_func_args_len); + debug.assert(argc == real_args_len); - const sqlite_args = argv.?[0..step_func_args_len]; + const sqlite_args = argv.?[0..real_args_len]; - var fn_args: StepFuncArgTuple = undefined; + var args: std.meta.ArgsTuple(@TypeOf(step_func)) = undefined; - // First argument is always the user-provided context - fn_args[0] = @ptrCast(ContextPtrType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx))); + // Pass the function context + args[0] = FunctionContext{ .ctx = ctx }; comptime var i: usize = 0; - inline while (i < step_func_args_len) : (i += 1) { - // we add 1 because we need to ignore the first argument which is the user-provided context, not a SQLite value. + inline while (i < real_args_len) : (i += 1) { + // Remember the firt argument is always the function context const arg = step_fn_info.args[i + 1]; - const arg_ptr = &fn_args[i + 1]; + const arg_ptr = &args[i + 1]; const ArgType = arg.arg_type.?; setFunctionArgument(ArgType, arg_ptr, sqlite_args[i].?); } - @call(.{}, step_func, fn_args); + @call(.{}, step_func, args); } }.xStep, struct { fn xFinal(ctx: ?*c.sqlite3_context) callconv(.C) void { - var fn_args: FinalizeFuncArgTuple = undefined; - // Only one argument, the user-provided context - fn_args[0] = @ptrCast(ContextPtrType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx))); + var args: std.meta.ArgsTuple(@TypeOf(finalize_func)) = undefined; + + // Pass the function context + args[0] = FunctionContext{ .ctx = ctx }; - const result = @call(.{}, finalize_func, fn_args); + const result = @call(.{}, finalize_func, args); setFunctionResult(ctx, result); } @@ -905,6 +894,62 @@ pub const Db = struct { } }; +/// FunctionContext is the context passed as first parameter in the `step` and `finalize` functions used with `createAggregateFunction`. +/// It provides two functions: +/// * userContext to retrieve the user provided context +/// * aggregateContext to create or retrieve the aggregate context +/// +/// Both functions take a type as parameter and take care of casting so the caller doesn't have to do it. +pub const FunctionContext = struct { + ctx: ?*c.sqlite3_context, + + pub fn userContext(self: FunctionContext, comptime Type: type) ?Type { + const Types = splitPtrTypes(Type); + + if (c.sqlite3_user_data(self.ctx)) |value| { + return @ptrCast( + Types.PointerType, + @alignCast(@alignOf(Types.ValueType), value), + ); + } + return null; + } + + pub fn aggregateContext(self: FunctionContext, comptime Type: type) ?Type { + const Types = splitPtrTypes(Type); + + if (c.sqlite3_aggregate_context(self.ctx, @sizeOf(Types.ValueType))) |value| { + return @ptrCast( + Types.PointerType, + @alignCast(@alignOf(Types.ValueType), value), + ); + } + return null; + } + + const SplitPtrTypes = struct { + ValueType: type, + PointerType: type, + }; + + fn splitPtrTypes(comptime Type: type) SplitPtrTypes { + switch (@typeInfo(Type)) { + .Pointer => |ptr_info| switch (ptr_info.size) { + .One => return SplitPtrTypes{ + .ValueType = ptr_info.child, + .PointerType = Type, + }, + else => @compileError("cannot use type " ++ @typeName(Type) ++ ", must be a single-item pointer"), + }, + .Void => return SplitPtrTypes{ + .ValueType = void, + .PointerType = undefined, + }, + else => @compileError("cannot use type " ++ @typeName(Type) ++ ", must be a single-item pointer"), + } + } +}; + /// Savepoint is a helper type for managing savepoints. /// /// A savepoint creates a transaction like BEGIN/COMMIT but they're named and can be nested. @@ -3758,7 +3803,7 @@ test "sqlite: create scalar function" { } } -test "sqlite: create aggregate function" { +test "sqlite: create aggregate function with no aggregate context" { var db = try getTestDb(); defer db.deinit(); @@ -3775,12 +3820,14 @@ test "sqlite: create aggregate function" { "mySum", &my_ctx, struct { - fn step(ctx: *MyContext, input: u32) void { + fn step(fctx: FunctionContext, input: u32) void { + var ctx = fctx.userContext(*MyContext) orelse return; ctx.sum += input; } }.step, struct { - fn finalize(ctx: *MyContext) u32 { + fn finalize(fctx: FunctionContext) u32 { + var ctx = fctx.userContext(*MyContext) orelse return 0; return ctx.sum; } }.finalize, @@ -3816,6 +3863,67 @@ test "sqlite: create aggregate function" { try testing.expectEqual(@as(usize, exp), result.?); } +test "sqlite: create aggregate function with an aggregate context" { + var db = try getTestDb(); + defer db.deinit(); + + var rand = std.rand.DefaultPrng.init(@intCast(u64, std.time.milliTimestamp())); + + try db.createAggregateFunction( + "mySum", + null, + struct { + fn step(fctx: FunctionContext, input: u32) void { + var ctx = fctx.aggregateContext(*u32) orelse return; + ctx.* += input; + } + }.step, + struct { + fn finalize(fctx: FunctionContext) u32 { + var ctx = fctx.aggregateContext(*u32) orelse return 0; + return ctx.*; + } + }.finalize, + .{}, + ); + + // Initialize some data + + try db.exec("CREATE TABLE view(id integer PRIMARY KEY, a integer, b integer)", .{}, .{}); + var i: usize = 0; + var exp_a: usize = 0; + var exp_b: usize = 0; + while (i < 20) : (i += 1) { + const val1 = rand.random().intRangeAtMost(u32, 0, 5205905); + exp_a += val1; + + const val2 = rand.random().intRangeAtMost(u32, 0, 310455); + exp_b += val2; + + try db.exec("INSERT INTO view(a, b) VALUES(?{u32}, ?{u32})", .{}, .{ val1, val2 }); + } + + // Get the sum and check the result + + var diags = Diagnostics{}; + const result = db.one( + struct { + a_sum: usize, + b_sum: usize, + }, + "SELECT mySum(a), mySum(b) FROM view", + .{ .diags = &diags }, + .{}, + ) catch |err| { + debug.print("err: {}\n", .{diags}); + return err; + }; + + try testing.expect(result != null); + try testing.expectEqual(@as(usize, exp_a), result.?.a_sum); + try testing.expectEqual(@as(usize, exp_b), result.?.b_sum); +} + test "sqlite: empty slice" { var arena = std.heap.ArenaAllocator.init(testing.allocator); defer arena.deinit(); -- cgit v1.2.3