summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Vincent Rischmann2022-06-25 01:35:34 +0200
committerGravatar Vincent Rischmann2022-07-14 17:04:24 +0200
commit9d011583eaf01be6bdc8bd3b777fd7e3e2bbf572 (patch)
tree3c19103447f17f87f16d56ecb79c32de836e7891
parentreadme: fix allocator usage (diff)
downloadzig-sqlite-9d011583eaf01be6bdc8bd3b777fd7e3e2bbf572.tar.gz
zig-sqlite-9d011583eaf01be6bdc8bd3b777fd7e3e2bbf572.tar.xz
zig-sqlite-9d011583eaf01be6bdc8bd3b777fd7e3e2bbf572.zip
add a way to get the aggregate context with createAggregateFunction
The old way of working was that we always passed the user context as first argument to both `step` and `finalize` functions and the caller had no way of getting the aggregate context from SQLite (http://www3.sqlite.org/c3ref/aggregate_context.html). Now both `step` and `finalize` functions must have a first argument of type `FunctionContext`: fn step(fctx: FunctionContext, input: u32) void { var ctx = fctx.aggregateContext(*u32) orelse return; ctx.* += input; } fn finalize(ctx: *u32) u32 { var ctx = fctx.aggregateContext(*u32) orelse return 0; return ctx.sum; } Fixes #89
Diffstat (limited to '')
-rw-r--r--sqlite.zig214
1 files changed, 161 insertions, 53 deletions
diff --git a/sqlite.zig b/sqlite.zig
index ba117c4..e4757a2 100644
--- a/sqlite.zig
+++ b/sqlite.zig
@@ -716,6 +716,8 @@ pub const Db = struct {
716 716
717 /// Creates an aggregate SQLite function with the given name. 717 /// Creates an aggregate SQLite function with the given name.
718 /// 718 ///
719 /// `step_func` and `finalize_func` must be two functions. The first argument of both functions _must_ be of the type FunctionContext.
720 ///
719 /// When the SQLite function is called in a statement, `step_func` will be called for each row with the input arguments. 721 /// When the SQLite function is called in a statement, `step_func` will be called for each row with the input arguments.
720 /// Each SQLite argument is converted to a Zig value according to the following rules: 722 /// Each SQLite argument is converted to a Zig value according to the following rules:
721 /// * TEXT values can be either sqlite.Text or []const u8 723 /// * TEXT values can be either sqlite.Text or []const u8
@@ -724,47 +726,33 @@ pub const Db = struct {
724 /// * REAL values can be any Zig float 726 /// * REAL values can be any Zig float
725 /// 727 ///
726 /// The final result of the SQL function call will be what `finalize_func` returns. 728 /// The final result of the SQL function call will be what `finalize_func` returns.
727 /// 729 pub fn createAggregateFunction(self: *Self, comptime name: [:0]const u8, user_ctx: anytype, comptime step_func: anytype, comptime finalize_func: anytype, comptime create_flags: CreateFunctionFlag) Error!void {
728 /// The context `my_ctx` contains the state necessary to perform the aggregation, both `step_func` and `finalize_func` must have at least the first argument of type ContextType. 730 // Validate the functions
729 ///
730 pub fn createAggregateFunction(self: *Self, func_name: [:0]const u8, my_ctx: anytype, comptime step_func: anytype, comptime finalize_func: anytype, comptime create_flags: CreateFunctionFlag) Error!void {
731 // Check that the context type is usable
732 const ContextPtrType = @TypeOf(my_ctx);
733 const ContextType = switch (@typeInfo(ContextPtrType)) {
734 .Pointer => |ptr_info| switch (ptr_info.size) {
735 .One => ptr_info.child,
736 else => @compileError("cannot use context of type " ++ @typeName(ContextPtrType) ++ ", must be a single-item pointer"),
737 },
738 else => @compileError("cannot use context of type " ++ @typeName(ContextPtrType) ++ ", must be a single-item pointer"),
739 };
740
741 // Validate the step function
742
743 const StepFuncType = @TypeOf(step_func);
744 731
745 const step_fn_info = switch (@typeInfo(StepFuncType)) { 732 const step_fn_info = switch (@typeInfo(@TypeOf(step_func))) {
746 .Fn => |fn_info| fn_info, 733 .Fn => |fn_info| fn_info,
747 else => @compileError("cannot use step_fn, expecting a function"), 734 else => @compileError("cannot use func, expecting a function"),
748 }; 735 };
749 if (step_fn_info.is_generic) @compileError("step_fn function can't be generic"); 736 if (step_fn_info.is_generic) @compileError("step function can't be generic");
750 if (step_fn_info.is_var_args) @compileError("step_fn function can't be variadic"); 737 if (step_fn_info.is_var_args) @compileError("step function can't be variadic");
751
752 const StepFuncArgTuple = std.meta.ArgsTuple(StepFuncType);
753
754 // subtract one because the user-provided function always takes an additional context
755 const step_func_args_len = step_fn_info.args.len - 1;
756 738
757 // Validate the finalize function 739 const finalize_fn_info = switch (@typeInfo(@TypeOf(finalize_func))) {
758
759 const FinalizeFuncType = @TypeOf(finalize_func);
760
761 const finalize_fn_info = switch (@typeInfo(FinalizeFuncType)) {
762 .Fn => |fn_info| fn_info, 740 .Fn => |fn_info| fn_info,
763 else => @compileError("cannot use finalize_fn, expecting a function"), 741 else => @compileError("cannot use func, expecting a function"),
764 }; 742 };
765 if (finalize_fn_info.args.len != 1) @compileError("finalize_fn must take exactly one argument"); 743 if (finalize_fn_info.args.len != 1) @compileError("finalize function must take exactly one argument");
744 if (finalize_fn_info.is_generic) @compileError("finalize function can't be generic");
745 if (finalize_fn_info.is_var_args) @compileError("finalize function can't be variadic");
746
747 if (step_fn_info.args[0].arg_type.? != finalize_fn_info.args[0].arg_type.?) {
748 @compileError("both step and finalize functions must have the same first argument and it must be a FunctionContext");
749 }
750 if (step_fn_info.args[0].arg_type.? != FunctionContext) {
751 @compileError("both step and finalize functions must have a first argument of type FunctionContext");
752 }
766 753
767 const FinalizeFuncArgTuple = std.meta.ArgsTuple(FinalizeFuncType); 754 // subtract the context argument
755 const real_args_len = step_fn_info.args.len - 1;
768 756
769 // 757 //
770 758
@@ -772,42 +760,43 @@ pub const Db = struct {
772 760
773 const result = c.sqlite3_create_function_v2( 761 const result = c.sqlite3_create_function_v2(
774 self.db, 762 self.db,
775 func_name, 763 name,
776 step_func_args_len, 764 real_args_len,
777 flags, 765 flags,
778 my_ctx, 766 user_ctx,
779 null, // xFunc 767 null, // xFunc
780 struct { 768 struct {
781 fn xStep(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { 769 fn xStep(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void {
782 debug.assert(argc == step_func_args_len); 770 debug.assert(argc == real_args_len);
783 771
784 const sqlite_args = argv.?[0..step_func_args_len]; 772 const sqlite_args = argv.?[0..real_args_len];
785 773
786 var fn_args: StepFuncArgTuple = undefined; 774 var args: std.meta.ArgsTuple(@TypeOf(step_func)) = undefined;
787 775
788 // First argument is always the user-provided context 776 // Pass the function context
789 fn_args[0] = @ptrCast(ContextPtrType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx))); 777 args[0] = FunctionContext{ .ctx = ctx };
790 778
791 comptime var i: usize = 0; 779 comptime var i: usize = 0;
792 inline while (i < step_func_args_len) : (i += 1) { 780 inline while (i < real_args_len) : (i += 1) {
793 // we add 1 because we need to ignore the first argument which is the user-provided context, not a SQLite value. 781 // Remember the firt argument is always the function context
794 const arg = step_fn_info.args[i + 1]; 782 const arg = step_fn_info.args[i + 1];
795 const arg_ptr = &fn_args[i + 1]; 783 const arg_ptr = &args[i + 1];
796 784
797 const ArgType = arg.arg_type.?; 785 const ArgType = arg.arg_type.?;
798 setFunctionArgument(ArgType, arg_ptr, sqlite_args[i].?); 786 setFunctionArgument(ArgType, arg_ptr, sqlite_args[i].?);
799 } 787 }
800 788
801 @call(.{}, step_func, fn_args); 789 @call(.{}, step_func, args);
802 } 790 }
803 }.xStep, 791 }.xStep,
804 struct { 792 struct {
805 fn xFinal(ctx: ?*c.sqlite3_context) callconv(.C) void { 793 fn xFinal(ctx: ?*c.sqlite3_context) callconv(.C) void {
806 var fn_args: FinalizeFuncArgTuple = undefined; 794 var args: std.meta.ArgsTuple(@TypeOf(finalize_func)) = undefined;
807 // Only one argument, the user-provided context 795
808 fn_args[0] = @ptrCast(ContextPtrType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx))); 796 // Pass the function context
797 args[0] = FunctionContext{ .ctx = ctx };
809 798
810 const result = @call(.{}, finalize_func, fn_args); 799 const result = @call(.{}, finalize_func, args);
811 800
812 setFunctionResult(ctx, result); 801 setFunctionResult(ctx, result);
813 } 802 }
@@ -905,6 +894,62 @@ pub const Db = struct {
905 } 894 }
906}; 895};
907 896
897/// FunctionContext is the context passed as first parameter in the `step` and `finalize` functions used with `createAggregateFunction`.
898/// It provides two functions:
899/// * userContext to retrieve the user provided context
900/// * aggregateContext to create or retrieve the aggregate context
901///
902/// Both functions take a type as parameter and take care of casting so the caller doesn't have to do it.
903pub const FunctionContext = struct {
904 ctx: ?*c.sqlite3_context,
905
906 pub fn userContext(self: FunctionContext, comptime Type: type) ?Type {
907 const Types = splitPtrTypes(Type);
908
909 if (c.sqlite3_user_data(self.ctx)) |value| {
910 return @ptrCast(
911 Types.PointerType,
912 @alignCast(@alignOf(Types.ValueType), value),
913 );
914 }
915 return null;
916 }
917
918 pub fn aggregateContext(self: FunctionContext, comptime Type: type) ?Type {
919 const Types = splitPtrTypes(Type);
920
921 if (c.sqlite3_aggregate_context(self.ctx, @sizeOf(Types.ValueType))) |value| {
922 return @ptrCast(
923 Types.PointerType,
924 @alignCast(@alignOf(Types.ValueType), value),
925 );
926 }
927 return null;
928 }
929
930 const SplitPtrTypes = struct {
931 ValueType: type,
932 PointerType: type,
933 };
934
935 fn splitPtrTypes(comptime Type: type) SplitPtrTypes {
936 switch (@typeInfo(Type)) {
937 .Pointer => |ptr_info| switch (ptr_info.size) {
938 .One => return SplitPtrTypes{
939 .ValueType = ptr_info.child,
940 .PointerType = Type,
941 },
942 else => @compileError("cannot use type " ++ @typeName(Type) ++ ", must be a single-item pointer"),
943 },
944 .Void => return SplitPtrTypes{
945 .ValueType = void,
946 .PointerType = undefined,
947 },
948 else => @compileError("cannot use type " ++ @typeName(Type) ++ ", must be a single-item pointer"),
949 }
950 }
951};
952
908/// Savepoint is a helper type for managing savepoints. 953/// Savepoint is a helper type for managing savepoints.
909/// 954///
910/// A savepoint creates a transaction like BEGIN/COMMIT but they're named and can be nested. 955/// A savepoint creates a transaction like BEGIN/COMMIT but they're named and can be nested.
@@ -3758,7 +3803,7 @@ test "sqlite: create scalar function" {
3758 } 3803 }
3759} 3804}
3760 3805
3761test "sqlite: create aggregate function" { 3806test "sqlite: create aggregate function with no aggregate context" {
3762 var db = try getTestDb(); 3807 var db = try getTestDb();
3763 defer db.deinit(); 3808 defer db.deinit();
3764 3809
@@ -3775,12 +3820,14 @@ test "sqlite: create aggregate function" {
3775 "mySum", 3820 "mySum",
3776 &my_ctx, 3821 &my_ctx,
3777 struct { 3822 struct {
3778 fn step(ctx: *MyContext, input: u32) void { 3823 fn step(fctx: FunctionContext, input: u32) void {
3824 var ctx = fctx.userContext(*MyContext) orelse return;
3779 ctx.sum += input; 3825 ctx.sum += input;
3780 } 3826 }
3781 }.step, 3827 }.step,
3782 struct { 3828 struct {
3783 fn finalize(ctx: *MyContext) u32 { 3829 fn finalize(fctx: FunctionContext) u32 {
3830 var ctx = fctx.userContext(*MyContext) orelse return 0;
3784 return ctx.sum; 3831 return ctx.sum;
3785 } 3832 }
3786 }.finalize, 3833 }.finalize,
@@ -3816,6 +3863,67 @@ test "sqlite: create aggregate function" {
3816 try testing.expectEqual(@as(usize, exp), result.?); 3863 try testing.expectEqual(@as(usize, exp), result.?);
3817} 3864}
3818 3865
3866test "sqlite: create aggregate function with an aggregate context" {
3867 var db = try getTestDb();
3868 defer db.deinit();
3869
3870 var rand = std.rand.DefaultPrng.init(@intCast(u64, std.time.milliTimestamp()));
3871
3872 try db.createAggregateFunction(
3873 "mySum",
3874 null,
3875 struct {
3876 fn step(fctx: FunctionContext, input: u32) void {
3877 var ctx = fctx.aggregateContext(*u32) orelse return;
3878 ctx.* += input;
3879 }
3880 }.step,
3881 struct {
3882 fn finalize(fctx: FunctionContext) u32 {
3883 var ctx = fctx.aggregateContext(*u32) orelse return 0;
3884 return ctx.*;
3885 }
3886 }.finalize,
3887 .{},
3888 );
3889
3890 // Initialize some data
3891
3892 try db.exec("CREATE TABLE view(id integer PRIMARY KEY, a integer, b integer)", .{}, .{});
3893 var i: usize = 0;
3894 var exp_a: usize = 0;
3895 var exp_b: usize = 0;
3896 while (i < 20) : (i += 1) {
3897 const val1 = rand.random().intRangeAtMost(u32, 0, 5205905);
3898 exp_a += val1;
3899
3900 const val2 = rand.random().intRangeAtMost(u32, 0, 310455);
3901 exp_b += val2;
3902
3903 try db.exec("INSERT INTO view(a, b) VALUES(?{u32}, ?{u32})", .{}, .{ val1, val2 });
3904 }
3905
3906 // Get the sum and check the result
3907
3908 var diags = Diagnostics{};
3909 const result = db.one(
3910 struct {
3911 a_sum: usize,
3912 b_sum: usize,
3913 },
3914 "SELECT mySum(a), mySum(b) FROM view",
3915 .{ .diags = &diags },
3916 .{},
3917 ) catch |err| {
3918 debug.print("err: {}\n", .{diags});
3919 return err;
3920 };
3921
3922 try testing.expect(result != null);
3923 try testing.expectEqual(@as(usize, exp_a), result.?.a_sum);
3924 try testing.expectEqual(@as(usize, exp_b), result.?.b_sum);
3925}
3926
3819test "sqlite: empty slice" { 3927test "sqlite: empty slice" {
3820 var arena = std.heap.ArenaAllocator.init(testing.allocator); 3928 var arena = std.heap.ArenaAllocator.init(testing.allocator);
3821 defer arena.deinit(); 3929 defer arena.deinit();