diff options
| author | 2022-06-25 01:35:34 +0200 | |
|---|---|---|
| committer | 2022-07-14 17:04:24 +0200 | |
| commit | 9d011583eaf01be6bdc8bd3b777fd7e3e2bbf572 (patch) | |
| tree | 3c19103447f17f87f16d56ecb79c32de836e7891 /sqlite.zig | |
| parent | readme: fix allocator usage (diff) | |
| download | zig-sqlite-9d011583eaf01be6bdc8bd3b777fd7e3e2bbf572.tar.gz zig-sqlite-9d011583eaf01be6bdc8bd3b777fd7e3e2bbf572.tar.xz zig-sqlite-9d011583eaf01be6bdc8bd3b777fd7e3e2bbf572.zip | |
add a way to get the aggregate context with createAggregateFunction
The old way of working was that we always passed the user context as
first argument to both `step` and `finalize` functions and the caller
had no way of getting the aggregate context from
SQLite (http://www3.sqlite.org/c3ref/aggregate_context.html).
Now both `step` and `finalize` functions must have a first argument of type `FunctionContext`:
fn step(fctx: FunctionContext, input: u32) void {
var ctx = fctx.aggregateContext(*u32) orelse return;
ctx.* += input;
}
fn finalize(ctx: *u32) u32 {
var ctx = fctx.aggregateContext(*u32) orelse return 0;
return ctx.sum;
}
Fixes #89
Diffstat (limited to 'sqlite.zig')
| -rw-r--r-- | sqlite.zig | 214 |
1 files changed, 161 insertions, 53 deletions
| @@ -716,6 +716,8 @@ pub const Db = struct { | |||
| 716 | 716 | ||
| 717 | /// Creates an aggregate SQLite function with the given name. | 717 | /// Creates an aggregate SQLite function with the given name. |
| 718 | /// | 718 | /// |
| 719 | /// `step_func` and `finalize_func` must be two functions. The first argument of both functions _must_ be of the type FunctionContext. | ||
| 720 | /// | ||
| 719 | /// When the SQLite function is called in a statement, `step_func` will be called for each row with the input arguments. | 721 | /// When the SQLite function is called in a statement, `step_func` will be called for each row with the input arguments. |
| 720 | /// Each SQLite argument is converted to a Zig value according to the following rules: | 722 | /// Each SQLite argument is converted to a Zig value according to the following rules: |
| 721 | /// * TEXT values can be either sqlite.Text or []const u8 | 723 | /// * TEXT values can be either sqlite.Text or []const u8 |
| @@ -724,47 +726,33 @@ pub const Db = struct { | |||
| 724 | /// * REAL values can be any Zig float | 726 | /// * REAL values can be any Zig float |
| 725 | /// | 727 | /// |
| 726 | /// The final result of the SQL function call will be what `finalize_func` returns. | 728 | /// The final result of the SQL function call will be what `finalize_func` returns. |
| 727 | /// | 729 | pub fn createAggregateFunction(self: *Self, comptime name: [:0]const u8, user_ctx: anytype, comptime step_func: anytype, comptime finalize_func: anytype, comptime create_flags: CreateFunctionFlag) Error!void { |
| 728 | /// The context `my_ctx` contains the state necessary to perform the aggregation, both `step_func` and `finalize_func` must have at least the first argument of type ContextType. | 730 | // Validate the functions |
| 729 | /// | ||
| 730 | pub fn createAggregateFunction(self: *Self, func_name: [:0]const u8, my_ctx: anytype, comptime step_func: anytype, comptime finalize_func: anytype, comptime create_flags: CreateFunctionFlag) Error!void { | ||
| 731 | // Check that the context type is usable | ||
| 732 | const ContextPtrType = @TypeOf(my_ctx); | ||
| 733 | const ContextType = switch (@typeInfo(ContextPtrType)) { | ||
| 734 | .Pointer => |ptr_info| switch (ptr_info.size) { | ||
| 735 | .One => ptr_info.child, | ||
| 736 | else => @compileError("cannot use context of type " ++ @typeName(ContextPtrType) ++ ", must be a single-item pointer"), | ||
| 737 | }, | ||
| 738 | else => @compileError("cannot use context of type " ++ @typeName(ContextPtrType) ++ ", must be a single-item pointer"), | ||
| 739 | }; | ||
| 740 | |||
| 741 | // Validate the step function | ||
| 742 | |||
| 743 | const StepFuncType = @TypeOf(step_func); | ||
| 744 | 731 | ||
| 745 | const step_fn_info = switch (@typeInfo(StepFuncType)) { | 732 | const step_fn_info = switch (@typeInfo(@TypeOf(step_func))) { |
| 746 | .Fn => |fn_info| fn_info, | 733 | .Fn => |fn_info| fn_info, |
| 747 | else => @compileError("cannot use step_fn, expecting a function"), | 734 | else => @compileError("cannot use func, expecting a function"), |
| 748 | }; | 735 | }; |
| 749 | if (step_fn_info.is_generic) @compileError("step_fn function can't be generic"); | 736 | if (step_fn_info.is_generic) @compileError("step function can't be generic"); |
| 750 | if (step_fn_info.is_var_args) @compileError("step_fn function can't be variadic"); | 737 | if (step_fn_info.is_var_args) @compileError("step function can't be variadic"); |
| 751 | |||
| 752 | const StepFuncArgTuple = std.meta.ArgsTuple(StepFuncType); | ||
| 753 | |||
| 754 | // subtract one because the user-provided function always takes an additional context | ||
| 755 | const step_func_args_len = step_fn_info.args.len - 1; | ||
| 756 | 738 | ||
| 757 | // Validate the finalize function | 739 | const finalize_fn_info = switch (@typeInfo(@TypeOf(finalize_func))) { |
| 758 | |||
| 759 | const FinalizeFuncType = @TypeOf(finalize_func); | ||
| 760 | |||
| 761 | const finalize_fn_info = switch (@typeInfo(FinalizeFuncType)) { | ||
| 762 | .Fn => |fn_info| fn_info, | 740 | .Fn => |fn_info| fn_info, |
| 763 | else => @compileError("cannot use finalize_fn, expecting a function"), | 741 | else => @compileError("cannot use func, expecting a function"), |
| 764 | }; | 742 | }; |
| 765 | if (finalize_fn_info.args.len != 1) @compileError("finalize_fn must take exactly one argument"); | 743 | if (finalize_fn_info.args.len != 1) @compileError("finalize function must take exactly one argument"); |
| 744 | if (finalize_fn_info.is_generic) @compileError("finalize function can't be generic"); | ||
| 745 | if (finalize_fn_info.is_var_args) @compileError("finalize function can't be variadic"); | ||
| 746 | |||
| 747 | if (step_fn_info.args[0].arg_type.? != finalize_fn_info.args[0].arg_type.?) { | ||
| 748 | @compileError("both step and finalize functions must have the same first argument and it must be a FunctionContext"); | ||
| 749 | } | ||
| 750 | if (step_fn_info.args[0].arg_type.? != FunctionContext) { | ||
| 751 | @compileError("both step and finalize functions must have a first argument of type FunctionContext"); | ||
| 752 | } | ||
| 766 | 753 | ||
| 767 | const FinalizeFuncArgTuple = std.meta.ArgsTuple(FinalizeFuncType); | 754 | // subtract the context argument |
| 755 | const real_args_len = step_fn_info.args.len - 1; | ||
| 768 | 756 | ||
| 769 | // | 757 | // |
| 770 | 758 | ||
| @@ -772,42 +760,43 @@ pub const Db = struct { | |||
| 772 | 760 | ||
| 773 | const result = c.sqlite3_create_function_v2( | 761 | const result = c.sqlite3_create_function_v2( |
| 774 | self.db, | 762 | self.db, |
| 775 | func_name, | 763 | name, |
| 776 | step_func_args_len, | 764 | real_args_len, |
| 777 | flags, | 765 | flags, |
| 778 | my_ctx, | 766 | user_ctx, |
| 779 | null, // xFunc | 767 | null, // xFunc |
| 780 | struct { | 768 | struct { |
| 781 | fn xStep(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { | 769 | fn xStep(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { |
| 782 | debug.assert(argc == step_func_args_len); | 770 | debug.assert(argc == real_args_len); |
| 783 | 771 | ||
| 784 | const sqlite_args = argv.?[0..step_func_args_len]; | 772 | const sqlite_args = argv.?[0..real_args_len]; |
| 785 | 773 | ||
| 786 | var fn_args: StepFuncArgTuple = undefined; | 774 | var args: std.meta.ArgsTuple(@TypeOf(step_func)) = undefined; |
| 787 | 775 | ||
| 788 | // First argument is always the user-provided context | 776 | // Pass the function context |
| 789 | fn_args[0] = @ptrCast(ContextPtrType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx))); | 777 | args[0] = FunctionContext{ .ctx = ctx }; |
| 790 | 778 | ||
| 791 | comptime var i: usize = 0; | 779 | comptime var i: usize = 0; |
| 792 | inline while (i < step_func_args_len) : (i += 1) { | 780 | inline while (i < real_args_len) : (i += 1) { |
| 793 | // we add 1 because we need to ignore the first argument which is the user-provided context, not a SQLite value. | 781 | // Remember the firt argument is always the function context |
| 794 | const arg = step_fn_info.args[i + 1]; | 782 | const arg = step_fn_info.args[i + 1]; |
| 795 | const arg_ptr = &fn_args[i + 1]; | 783 | const arg_ptr = &args[i + 1]; |
| 796 | 784 | ||
| 797 | const ArgType = arg.arg_type.?; | 785 | const ArgType = arg.arg_type.?; |
| 798 | setFunctionArgument(ArgType, arg_ptr, sqlite_args[i].?); | 786 | setFunctionArgument(ArgType, arg_ptr, sqlite_args[i].?); |
| 799 | } | 787 | } |
| 800 | 788 | ||
| 801 | @call(.{}, step_func, fn_args); | 789 | @call(.{}, step_func, args); |
| 802 | } | 790 | } |
| 803 | }.xStep, | 791 | }.xStep, |
| 804 | struct { | 792 | struct { |
| 805 | fn xFinal(ctx: ?*c.sqlite3_context) callconv(.C) void { | 793 | fn xFinal(ctx: ?*c.sqlite3_context) callconv(.C) void { |
| 806 | var fn_args: FinalizeFuncArgTuple = undefined; | 794 | var args: std.meta.ArgsTuple(@TypeOf(finalize_func)) = undefined; |
| 807 | // Only one argument, the user-provided context | 795 | |
| 808 | fn_args[0] = @ptrCast(ContextPtrType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx))); | 796 | // Pass the function context |
| 797 | args[0] = FunctionContext{ .ctx = ctx }; | ||
| 809 | 798 | ||
| 810 | const result = @call(.{}, finalize_func, fn_args); | 799 | const result = @call(.{}, finalize_func, args); |
| 811 | 800 | ||
| 812 | setFunctionResult(ctx, result); | 801 | setFunctionResult(ctx, result); |
| 813 | } | 802 | } |
| @@ -905,6 +894,62 @@ pub const Db = struct { | |||
| 905 | } | 894 | } |
| 906 | }; | 895 | }; |
| 907 | 896 | ||
| 897 | /// FunctionContext is the context passed as first parameter in the `step` and `finalize` functions used with `createAggregateFunction`. | ||
| 898 | /// It provides two functions: | ||
| 899 | /// * userContext to retrieve the user provided context | ||
| 900 | /// * aggregateContext to create or retrieve the aggregate context | ||
| 901 | /// | ||
| 902 | /// Both functions take a type as parameter and take care of casting so the caller doesn't have to do it. | ||
| 903 | pub const FunctionContext = struct { | ||
| 904 | ctx: ?*c.sqlite3_context, | ||
| 905 | |||
| 906 | pub fn userContext(self: FunctionContext, comptime Type: type) ?Type { | ||
| 907 | const Types = splitPtrTypes(Type); | ||
| 908 | |||
| 909 | if (c.sqlite3_user_data(self.ctx)) |value| { | ||
| 910 | return @ptrCast( | ||
| 911 | Types.PointerType, | ||
| 912 | @alignCast(@alignOf(Types.ValueType), value), | ||
| 913 | ); | ||
| 914 | } | ||
| 915 | return null; | ||
| 916 | } | ||
| 917 | |||
| 918 | pub fn aggregateContext(self: FunctionContext, comptime Type: type) ?Type { | ||
| 919 | const Types = splitPtrTypes(Type); | ||
| 920 | |||
| 921 | if (c.sqlite3_aggregate_context(self.ctx, @sizeOf(Types.ValueType))) |value| { | ||
| 922 | return @ptrCast( | ||
| 923 | Types.PointerType, | ||
| 924 | @alignCast(@alignOf(Types.ValueType), value), | ||
| 925 | ); | ||
| 926 | } | ||
| 927 | return null; | ||
| 928 | } | ||
| 929 | |||
| 930 | const SplitPtrTypes = struct { | ||
| 931 | ValueType: type, | ||
| 932 | PointerType: type, | ||
| 933 | }; | ||
| 934 | |||
| 935 | fn splitPtrTypes(comptime Type: type) SplitPtrTypes { | ||
| 936 | switch (@typeInfo(Type)) { | ||
| 937 | .Pointer => |ptr_info| switch (ptr_info.size) { | ||
| 938 | .One => return SplitPtrTypes{ | ||
| 939 | .ValueType = ptr_info.child, | ||
| 940 | .PointerType = Type, | ||
| 941 | }, | ||
| 942 | else => @compileError("cannot use type " ++ @typeName(Type) ++ ", must be a single-item pointer"), | ||
| 943 | }, | ||
| 944 | .Void => return SplitPtrTypes{ | ||
| 945 | .ValueType = void, | ||
| 946 | .PointerType = undefined, | ||
| 947 | }, | ||
| 948 | else => @compileError("cannot use type " ++ @typeName(Type) ++ ", must be a single-item pointer"), | ||
| 949 | } | ||
| 950 | } | ||
| 951 | }; | ||
| 952 | |||
| 908 | /// Savepoint is a helper type for managing savepoints. | 953 | /// Savepoint is a helper type for managing savepoints. |
| 909 | /// | 954 | /// |
| 910 | /// A savepoint creates a transaction like BEGIN/COMMIT but they're named and can be nested. | 955 | /// A savepoint creates a transaction like BEGIN/COMMIT but they're named and can be nested. |
| @@ -3758,7 +3803,7 @@ test "sqlite: create scalar function" { | |||
| 3758 | } | 3803 | } |
| 3759 | } | 3804 | } |
| 3760 | 3805 | ||
| 3761 | test "sqlite: create aggregate function" { | 3806 | test "sqlite: create aggregate function with no aggregate context" { |
| 3762 | var db = try getTestDb(); | 3807 | var db = try getTestDb(); |
| 3763 | defer db.deinit(); | 3808 | defer db.deinit(); |
| 3764 | 3809 | ||
| @@ -3775,12 +3820,14 @@ test "sqlite: create aggregate function" { | |||
| 3775 | "mySum", | 3820 | "mySum", |
| 3776 | &my_ctx, | 3821 | &my_ctx, |
| 3777 | struct { | 3822 | struct { |
| 3778 | fn step(ctx: *MyContext, input: u32) void { | 3823 | fn step(fctx: FunctionContext, input: u32) void { |
| 3824 | var ctx = fctx.userContext(*MyContext) orelse return; | ||
| 3779 | ctx.sum += input; | 3825 | ctx.sum += input; |
| 3780 | } | 3826 | } |
| 3781 | }.step, | 3827 | }.step, |
| 3782 | struct { | 3828 | struct { |
| 3783 | fn finalize(ctx: *MyContext) u32 { | 3829 | fn finalize(fctx: FunctionContext) u32 { |
| 3830 | var ctx = fctx.userContext(*MyContext) orelse return 0; | ||
| 3784 | return ctx.sum; | 3831 | return ctx.sum; |
| 3785 | } | 3832 | } |
| 3786 | }.finalize, | 3833 | }.finalize, |
| @@ -3816,6 +3863,67 @@ test "sqlite: create aggregate function" { | |||
| 3816 | try testing.expectEqual(@as(usize, exp), result.?); | 3863 | try testing.expectEqual(@as(usize, exp), result.?); |
| 3817 | } | 3864 | } |
| 3818 | 3865 | ||
| 3866 | test "sqlite: create aggregate function with an aggregate context" { | ||
| 3867 | var db = try getTestDb(); | ||
| 3868 | defer db.deinit(); | ||
| 3869 | |||
| 3870 | var rand = std.rand.DefaultPrng.init(@intCast(u64, std.time.milliTimestamp())); | ||
| 3871 | |||
| 3872 | try db.createAggregateFunction( | ||
| 3873 | "mySum", | ||
| 3874 | null, | ||
| 3875 | struct { | ||
| 3876 | fn step(fctx: FunctionContext, input: u32) void { | ||
| 3877 | var ctx = fctx.aggregateContext(*u32) orelse return; | ||
| 3878 | ctx.* += input; | ||
| 3879 | } | ||
| 3880 | }.step, | ||
| 3881 | struct { | ||
| 3882 | fn finalize(fctx: FunctionContext) u32 { | ||
| 3883 | var ctx = fctx.aggregateContext(*u32) orelse return 0; | ||
| 3884 | return ctx.*; | ||
| 3885 | } | ||
| 3886 | }.finalize, | ||
| 3887 | .{}, | ||
| 3888 | ); | ||
| 3889 | |||
| 3890 | // Initialize some data | ||
| 3891 | |||
| 3892 | try db.exec("CREATE TABLE view(id integer PRIMARY KEY, a integer, b integer)", .{}, .{}); | ||
| 3893 | var i: usize = 0; | ||
| 3894 | var exp_a: usize = 0; | ||
| 3895 | var exp_b: usize = 0; | ||
| 3896 | while (i < 20) : (i += 1) { | ||
| 3897 | const val1 = rand.random().intRangeAtMost(u32, 0, 5205905); | ||
| 3898 | exp_a += val1; | ||
| 3899 | |||
| 3900 | const val2 = rand.random().intRangeAtMost(u32, 0, 310455); | ||
| 3901 | exp_b += val2; | ||
| 3902 | |||
| 3903 | try db.exec("INSERT INTO view(a, b) VALUES(?{u32}, ?{u32})", .{}, .{ val1, val2 }); | ||
| 3904 | } | ||
| 3905 | |||
| 3906 | // Get the sum and check the result | ||
| 3907 | |||
| 3908 | var diags = Diagnostics{}; | ||
| 3909 | const result = db.one( | ||
| 3910 | struct { | ||
| 3911 | a_sum: usize, | ||
| 3912 | b_sum: usize, | ||
| 3913 | }, | ||
| 3914 | "SELECT mySum(a), mySum(b) FROM view", | ||
| 3915 | .{ .diags = &diags }, | ||
| 3916 | .{}, | ||
| 3917 | ) catch |err| { | ||
| 3918 | debug.print("err: {}\n", .{diags}); | ||
| 3919 | return err; | ||
| 3920 | }; | ||
| 3921 | |||
| 3922 | try testing.expect(result != null); | ||
| 3923 | try testing.expectEqual(@as(usize, exp_a), result.?.a_sum); | ||
| 3924 | try testing.expectEqual(@as(usize, exp_b), result.?.b_sum); | ||
| 3925 | } | ||
| 3926 | |||
| 3819 | test "sqlite: empty slice" { | 3927 | test "sqlite: empty slice" { |
| 3820 | var arena = std.heap.ArenaAllocator.init(testing.allocator); | 3928 | var arena = std.heap.ArenaAllocator.init(testing.allocator); |
| 3821 | defer arena.deinit(); | 3929 | defer arena.deinit(); |