diff options
| author | 2022-04-17 00:52:31 +0200 | |
|---|---|---|
| committer | 2022-04-17 01:21:08 +0200 | |
| commit | 64848442f900f56b06c3953ee5b3cc6cd97b9bc7 (patch) | |
| tree | 200fb19f3be8c5d259581d1debc27a8550131496 | |
| parent | document CreateFunctionFlag (diff) | |
| download | zig-sqlite-64848442f900f56b06c3953ee5b3cc6cd97b9bc7.tar.gz zig-sqlite-64848442f900f56b06c3953ee5b3cc6cd97b9bc7.tar.xz zig-sqlite-64848442f900f56b06c3953ee5b3cc6cd97b9bc7.zip | |
work on supporting aggregate SQL functions
Diffstat (limited to '')
| -rw-r--r-- | sqlite.zig | 337 |
1 files changed, 259 insertions, 78 deletions
| @@ -1,4 +1,5 @@ | |||
| 1 | const std = @import("std"); | 1 | const std = @import("std"); |
| 2 | const builtin = @import("builtin"); | ||
| 2 | const build_options = @import("build_options"); | 3 | const build_options = @import("build_options"); |
| 3 | const debug = std.debug; | 4 | const debug = std.debug; |
| 4 | const io = std.io; | 5 | const io = std.io; |
| @@ -601,6 +602,89 @@ pub const Db = struct { | |||
| 601 | return Savepoint.init(self, name); | 602 | return Savepoint.init(self, name); |
| 602 | } | 603 | } |
| 603 | 604 | ||
| 605 | // Helpers functions to implement SQLite functions. | ||
| 606 | |||
| 607 | fn sliceFromValue(sqlite_value: *c.sqlite3_value) []const u8 { | ||
| 608 | const size = @intCast(usize, c.sqlite3_value_bytes(sqlite_value)); | ||
| 609 | |||
| 610 | const value = c.sqlite3_value_text(sqlite_value); | ||
| 611 | debug.assert(value != null); // TODO(vincent): how do we handle this properly ? | ||
| 612 | |||
| 613 | return value.?[0..size]; | ||
| 614 | } | ||
| 615 | |||
| 616 | /// Sets the result of a function call in the context `ctx`. | ||
| 617 | /// | ||
| 618 | /// Determines at compile time which sqlite3_result_XYZ function to use based on the type of `result`. | ||
| 619 | fn setFunctionResult(ctx: ?*c.sqlite3_context, result: anytype) void { | ||
| 620 | const ResultType = @TypeOf(result); | ||
| 621 | |||
| 622 | switch (ResultType) { | ||
| 623 | Text => c.sqlite3_result_text(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT), | ||
| 624 | Blob => c.sqlite3_result_blob(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT), | ||
| 625 | else => switch (@typeInfo(ResultType)) { | ||
| 626 | .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) { | ||
| 627 | c.sqlite3_result_int(ctx, result); | ||
| 628 | } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) { | ||
| 629 | c.sqlite3_result_int64(ctx, result); | ||
| 630 | } else { | ||
| 631 | @compileError("integer " ++ @typeName(ResultType) ++ " is not representable in sqlite"); | ||
| 632 | }, | ||
| 633 | .Float => c.sqlite3_result_double(ctx, result), | ||
| 634 | .Bool => c.sqlite3_result_int(ctx, if (result) 1 else 0), | ||
| 635 | .Array => |arr| switch (arr.child) { | ||
| 636 | u8 => c.sqlite3_result_blob(ctx, &result, arr.len, c.SQLITE_TRANSIENT), | ||
| 637 | else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), | ||
| 638 | }, | ||
| 639 | .Pointer => |ptr| switch (ptr.size) { | ||
| 640 | .Slice => switch (ptr.child) { | ||
| 641 | u8 => c.sqlite3_result_text(ctx, result.ptr, @intCast(c_int, result.len), c.SQLITE_TRANSIENT), | ||
| 642 | else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), | ||
| 643 | }, | ||
| 644 | else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), | ||
| 645 | }, | ||
| 646 | else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), | ||
| 647 | }, | ||
| 648 | } | ||
| 649 | } | ||
| 650 | |||
| 651 | /// Sets a function argument using the provided value. | ||
| 652 | /// | ||
| 653 | /// Determines at compile time which sqlite3_value_XYZ function to use based on the type `ArgType`. | ||
| 654 | fn setFunctionArgument(comptime ArgType: type, arg: *ArgType, sqlite_value: *c.sqlite3_value) void { | ||
| 655 | switch (ArgType) { | ||
| 656 | Text => arg.*.data = sliceFromValue(sqlite_value), | ||
| 657 | Blob => arg.*.data = sliceFromValue(sqlite_value), | ||
| 658 | else => switch (@typeInfo(ArgType)) { | ||
| 659 | .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) { | ||
| 660 | const value = c.sqlite3_value_int(sqlite_value); | ||
| 661 | arg.* = @intCast(ArgType, value); | ||
| 662 | } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) { | ||
| 663 | const value = c.sqlite3_value_int64(sqlite_value); | ||
| 664 | arg.* = @intCast(ArgType, value); | ||
| 665 | } else { | ||
| 666 | @compileError("integer " ++ @typeName(ArgType) ++ " is not representable in sqlite"); | ||
| 667 | }, | ||
| 668 | .Float => { | ||
| 669 | const value = c.sqlite3_value_double(sqlite_value); | ||
| 670 | arg.* = @floatCast(ArgType, value); | ||
| 671 | }, | ||
| 672 | .Bool => { | ||
| 673 | const value = c.sqlite3_value_int(sqlite_value); | ||
| 674 | arg.* = value > 0; | ||
| 675 | }, | ||
| 676 | .Pointer => |ptr| switch (ptr.size) { | ||
| 677 | .Slice => switch (ptr.child) { | ||
| 678 | u8 => arg.* = sliceFromValue(sqlite_value), | ||
| 679 | else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), | ||
| 680 | }, | ||
| 681 | else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), | ||
| 682 | }, | ||
| 683 | else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), | ||
| 684 | }, | ||
| 685 | } | ||
| 686 | } | ||
| 687 | |||
| 604 | /// CreateFunctionFlag controls the flags used when creating a custom SQL function. | 688 | /// CreateFunctionFlag controls the flags used when creating a custom SQL function. |
| 605 | /// See https://sqlite.org/c3ref/c_deterministic.html. | 689 | /// See https://sqlite.org/c3ref/c_deterministic.html. |
| 606 | /// | 690 | /// |
| @@ -614,6 +698,117 @@ pub const Db = struct { | |||
| 614 | direct_only: bool = true, | 698 | direct_only: bool = true, |
| 615 | }; | 699 | }; |
| 616 | 700 | ||
| 701 | /// Creates an aggregate SQLite function with the given name. | ||
| 702 | /// | ||
| 703 | /// When the SQLite function is called in a statement, `step_func` will be called for each row with the input arguments. | ||
| 704 | /// Each SQLite argument is converted to a Zig value according to the following rules: | ||
| 705 | /// * TEXT values can be either sqlite.Text or []const u8 | ||
| 706 | /// * BLOB values can be either sqlite.Blob or []const u8 | ||
| 707 | /// * INTEGER values can be any Zig integer | ||
| 708 | /// * REAL values can be any Zig float | ||
| 709 | /// | ||
| 710 | /// The final result of the SQL function call will be what `finalize_func` returns. | ||
| 711 | /// | ||
| 712 | /// The context `my_ctx` contains the state necessary to perform the aggregation, both `step_func` and `finalize_func` must have at least the first argument of type ContextType. | ||
| 713 | /// | ||
| 714 | pub fn createAggregateFunction(self: *Self, func_name: [:0]const u8, my_ctx: anytype, comptime step_func: anytype, comptime finalize_func: anytype, comptime create_flags: CreateFunctionFlag) Error!void { | ||
| 715 | // Check that the context type is usable | ||
| 716 | const ContextType = @TypeOf(my_ctx); | ||
| 717 | switch (@typeInfo(ContextType)) { | ||
| 718 | .Pointer => |ptr_info| switch (ptr_info.size) { | ||
| 719 | .One => {}, | ||
| 720 | else => @compileError("cannot use context of type " ++ @typeName(ContextType) ++ ", must be a single-item pointer"), | ||
| 721 | }, | ||
| 722 | else => @compileError("cannot use context of type " ++ @typeName(ContextType) ++ ", must be a single-item pointer"), | ||
| 723 | } | ||
| 724 | |||
| 725 | // Validate the step function | ||
| 726 | |||
| 727 | const StepFuncType = @TypeOf(step_func); | ||
| 728 | |||
| 729 | const step_fn_info = switch (@typeInfo(StepFuncType)) { | ||
| 730 | .Fn => |fn_info| fn_info, | ||
| 731 | else => @compileError("cannot use step_fn, expecting a function"), | ||
| 732 | }; | ||
| 733 | if (step_fn_info.is_generic) @compileError("step_fn function can't be generic"); | ||
| 734 | if (step_fn_info.is_var_args) @compileError("step_fn function can't be variadic"); | ||
| 735 | |||
| 736 | const StepFuncArgTuple = std.meta.ArgsTuple(StepFuncType); | ||
| 737 | |||
| 738 | // subtract one because the user-provided function always takes an additional context | ||
| 739 | const step_func_args_len = step_fn_info.args.len - 1; | ||
| 740 | |||
| 741 | // Validate the finalize function | ||
| 742 | |||
| 743 | const FinalizeFuncType = @TypeOf(finalize_func); | ||
| 744 | |||
| 745 | const finalize_fn_info = switch (@typeInfo(FinalizeFuncType)) { | ||
| 746 | .Fn => |fn_info| fn_info, | ||
| 747 | else => @compileError("cannot use finalize_fn, expecting a function"), | ||
| 748 | }; | ||
| 749 | if (finalize_fn_info.args.len != 1) @compileError("finalize_fn must take exactly one argument"); | ||
| 750 | |||
| 751 | const FinalizeFuncArgTuple = std.meta.ArgsTuple(FinalizeFuncType); | ||
| 752 | |||
| 753 | // | ||
| 754 | |||
| 755 | var flags: c_int = c.SQLITE_UTF8; | ||
| 756 | if (create_flags.deterministic) { | ||
| 757 | flags |= c.SQLITE_DETERMINISTIC; | ||
| 758 | } | ||
| 759 | if (create_flags.direct_only) { | ||
| 760 | flags |= c.SQLITE_DIRECTONLY; | ||
| 761 | } | ||
| 762 | |||
| 763 | const result = c.sqlite3_create_function_v2( | ||
| 764 | self.db, | ||
| 765 | func_name, | ||
| 766 | step_func_args_len, | ||
| 767 | flags, | ||
| 768 | my_ctx, | ||
| 769 | null, // xFunc | ||
| 770 | struct { | ||
| 771 | fn xStep(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { | ||
| 772 | debug.assert(argc == step_func_args_len); | ||
| 773 | |||
| 774 | const sqlite_args = argv.?[0..step_func_args_len]; | ||
| 775 | |||
| 776 | var fn_args: StepFuncArgTuple = undefined; | ||
| 777 | |||
| 778 | // First argument is always the user-provided context | ||
| 779 | fn_args[0] = @ptrCast(ContextType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx))); | ||
| 780 | |||
| 781 | comptime var i: usize = 0; | ||
| 782 | inline while (i < step_func_args_len) : (i += 1) { | ||
| 783 | // we add 1 because we need to ignore the first argument which is the user-provided context, not a SQLite value. | ||
| 784 | const arg = step_fn_info.args[i + 1]; | ||
| 785 | const arg_ptr = &fn_args[i + 1]; | ||
| 786 | |||
| 787 | const ArgType = arg.arg_type.?; | ||
| 788 | setFunctionArgument(ArgType, arg_ptr, sqlite_args[i].?); | ||
| 789 | } | ||
| 790 | |||
| 791 | @call(.{}, step_func, fn_args); | ||
| 792 | } | ||
| 793 | }.xStep, | ||
| 794 | struct { | ||
| 795 | fn xFinal(ctx: ?*c.sqlite3_context) callconv(.C) void { | ||
| 796 | var fn_args: FinalizeFuncArgTuple = undefined; | ||
| 797 | // Only one argument, the user-provided context | ||
| 798 | fn_args[0] = @ptrCast(ContextType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx))); | ||
| 799 | |||
| 800 | const result = @call(.{}, finalize_func, fn_args); | ||
| 801 | |||
| 802 | setFunctionResult(ctx, result); | ||
| 803 | } | ||
| 804 | }.xFinal, | ||
| 805 | null, | ||
| 806 | ); | ||
| 807 | if (result != c.SQLITE_OK) { | ||
| 808 | return errors.errorFromResultCode(result); | ||
| 809 | } | ||
| 810 | } | ||
| 811 | |||
| 617 | /// Creates a scalar SQLite function with the given name. | 812 | /// Creates a scalar SQLite function with the given name. |
| 618 | /// | 813 | /// |
| 619 | /// When the SQLite function is called in a statement, `func` will be called with the input arguments. | 814 | /// When the SQLite function is called in a statement, `func` will be called with the input arguments. |
| @@ -652,81 +847,6 @@ pub const Db = struct { | |||
| 652 | flags, | 847 | flags, |
| 653 | null, | 848 | null, |
| 654 | struct { | 849 | struct { |
| 655 | fn sliceFromValue(sqlite_value: *c.sqlite3_value) []const u8 { | ||
| 656 | const size = @intCast(usize, c.sqlite3_value_bytes(sqlite_value)); | ||
| 657 | |||
| 658 | const value = c.sqlite3_value_text(sqlite_value); | ||
| 659 | debug.assert(value != null); // TODO(vincent): how do we handle this properly ? | ||
| 660 | |||
| 661 | return value.?[0..size]; | ||
| 662 | } | ||
| 663 | |||
| 664 | fn bindValue(comptime ArgType: type, arg: *ArgType, sqlite_value: *c.sqlite3_value) void { | ||
| 665 | switch (ArgType) { | ||
| 666 | Text => arg.*.data = sliceFromValue(sqlite_value), | ||
| 667 | Blob => arg.*.data = sliceFromValue(sqlite_value), | ||
| 668 | else => switch (@typeInfo(ArgType)) { | ||
| 669 | .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) { | ||
| 670 | const value = c.sqlite3_value_int(sqlite_value); | ||
| 671 | arg.* = @intCast(ArgType, value); | ||
| 672 | } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) { | ||
| 673 | const value = c.sqlite3_value_int64(sqlite_value); | ||
| 674 | arg.* = @intCast(ArgType, value); | ||
| 675 | } else { | ||
| 676 | @compileError("integer " ++ @typeName(ArgType) ++ " is not representable in sqlite"); | ||
| 677 | }, | ||
| 678 | .Float => { | ||
| 679 | const value = c.sqlite3_value_double(sqlite_value); | ||
| 680 | arg.* = @floatCast(ArgType, value); | ||
| 681 | }, | ||
| 682 | .Bool => { | ||
| 683 | const value = c.sqlite3_value_int(sqlite_value); | ||
| 684 | arg.* = value > 0; | ||
| 685 | }, | ||
| 686 | .Pointer => |ptr| switch (ptr.size) { | ||
| 687 | .Slice => switch (ptr.child) { | ||
| 688 | u8 => arg.* = sliceFromValue(sqlite_value), | ||
| 689 | else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), | ||
| 690 | }, | ||
| 691 | else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), | ||
| 692 | }, | ||
| 693 | else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), | ||
| 694 | }, | ||
| 695 | } | ||
| 696 | } | ||
| 697 | |||
| 698 | fn setResult(ctx: ?*c.sqlite3_context, result: anytype) void { | ||
| 699 | const ResultType = @TypeOf(result); | ||
| 700 | |||
| 701 | switch (ResultType) { | ||
| 702 | Text => c.sqlite3_result_text(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT), | ||
| 703 | Blob => c.sqlite3_result_blob(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT), | ||
| 704 | else => switch (@typeInfo(ResultType)) { | ||
| 705 | .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) { | ||
| 706 | c.sqlite3_result_int(ctx, result); | ||
| 707 | } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) { | ||
| 708 | c.sqlite3_result_int64(ctx, result); | ||
| 709 | } else { | ||
| 710 | @compileError("integer " ++ @typeName(ResultType) ++ " is not representable in sqlite"); | ||
| 711 | }, | ||
| 712 | .Float => c.sqlite3_result_double(ctx, result), | ||
| 713 | .Bool => c.sqlite3_result_int(ctx, if (result) 1 else 0), | ||
| 714 | .Array => |arr| switch (arr.child) { | ||
| 715 | u8 => c.sqlite3_result_blob(ctx, &result, arr.len, c.SQLITE_TRANSIENT), | ||
| 716 | else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), | ||
| 717 | }, | ||
| 718 | .Pointer => |ptr| switch (ptr.size) { | ||
| 719 | .Slice => switch (ptr.child) { | ||
| 720 | u8 => c.sqlite3_result_text(ctx, result.ptr, @intCast(c_int, result.len), c.SQLITE_TRANSIENT), | ||
| 721 | else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), | ||
| 722 | }, | ||
| 723 | else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), | ||
| 724 | }, | ||
| 725 | else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), | ||
| 726 | }, | ||
| 727 | } | ||
| 728 | } | ||
| 729 | |||
| 730 | fn xFunc(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { | 850 | fn xFunc(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { |
| 731 | debug.assert(argc == fn_info.args.len); | 851 | debug.assert(argc == fn_info.args.len); |
| 732 | 852 | ||
| @@ -735,13 +855,12 @@ pub const Db = struct { | |||
| 735 | var fn_args: ArgTuple = undefined; | 855 | var fn_args: ArgTuple = undefined; |
| 736 | inline for (fn_info.args) |arg, i| { | 856 | inline for (fn_info.args) |arg, i| { |
| 737 | const ArgType = arg.arg_type.?; | 857 | const ArgType = arg.arg_type.?; |
| 738 | 858 | setFunctionArgument(ArgType, &fn_args[i], sqlite_args[i].?); | |
| 739 | bindValue(ArgType, &fn_args[i], sqlite_args[i].?); | ||
| 740 | } | 859 | } |
| 741 | 860 | ||
| 742 | const result = @call(.{}, func, fn_args); | 861 | const result = @call(.{}, func, fn_args); |
| 743 | 862 | ||
| 744 | setResult(ctx, result); | 863 | setFunctionResult(ctx, result); |
| 745 | } | 864 | } |
| 746 | }.xFunc, | 865 | }.xFunc, |
| 747 | null, | 866 | null, |
| @@ -3477,6 +3596,68 @@ test "sqlite: create scalar function" { | |||
| 3477 | } | 3596 | } |
| 3478 | } | 3597 | } |
| 3479 | 3598 | ||
| 3599 | test "sqlite: create aggregate function" { | ||
| 3600 | // TODO(vincent): fix this, panics on incorrect pointer alignment when casting the SQLite user data to the context type | ||
| 3601 | // in the xStep function. | ||
| 3602 | if (builtin.cpu.arch.isAARCH64()) return error.SkipZigTest; | ||
| 3603 | |||
| 3604 | var db = try getTestDb(); | ||
| 3605 | defer db.deinit(); | ||
| 3606 | |||
| 3607 | var rand = std.rand.DefaultPrng.init(@intCast(u64, std.time.milliTimestamp())); | ||
| 3608 | |||
| 3609 | // Create an aggregate function working with a MyContext | ||
| 3610 | |||
| 3611 | const MyContext = struct { | ||
| 3612 | sum: u32, | ||
| 3613 | }; | ||
| 3614 | var my_ctx = MyContext{ .sum = 0 }; | ||
| 3615 | |||
| 3616 | try db.createAggregateFunction( | ||
| 3617 | "mySum", | ||
| 3618 | &my_ctx, | ||
| 3619 | struct { | ||
| 3620 | fn step(ctx: *MyContext, input: u32) void { | ||
| 3621 | ctx.sum += input; | ||
| 3622 | } | ||
| 3623 | }.step, | ||
| 3624 | struct { | ||
| 3625 | fn finalize(ctx: *MyContext) u32 { | ||
| 3626 | return ctx.sum; | ||
| 3627 | } | ||
| 3628 | }.finalize, | ||
| 3629 | .{}, | ||
| 3630 | ); | ||
| 3631 | |||
| 3632 | // Initialize some data | ||
| 3633 | |||
| 3634 | try db.exec("CREATE TABLE view(id integer PRIMARY KEY, nb integer)", .{}, .{}); | ||
| 3635 | var i: usize = 0; | ||
| 3636 | var exp: usize = 0; | ||
| 3637 | while (i < 20) : (i += 1) { | ||
| 3638 | const val = rand.random().intRangeAtMost(u32, 0, 5205905); | ||
| 3639 | exp += val; | ||
| 3640 | |||
| 3641 | try db.exec("INSERT INTO view(nb) VALUES(?{u32})", .{}, .{val}); | ||
| 3642 | } | ||
| 3643 | |||
| 3644 | // Get the sum and check the result | ||
| 3645 | |||
| 3646 | var diags = Diagnostics{}; | ||
| 3647 | const result = db.one( | ||
| 3648 | usize, | ||
| 3649 | "SELECT mySum(nb) FROM view", | ||
| 3650 | .{ .diags = &diags }, | ||
| 3651 | .{}, | ||
| 3652 | ) catch |err| { | ||
| 3653 | debug.print("err: {}\n", .{diags}); | ||
| 3654 | return err; | ||
| 3655 | }; | ||
| 3656 | |||
| 3657 | try testing.expect(result != null); | ||
| 3658 | try testing.expectEqual(@as(usize, exp), result.?); | ||
| 3659 | } | ||
| 3660 | |||
| 3480 | test "sqlite: empty slice" { | 3661 | test "sqlite: empty slice" { |
| 3481 | var arena = std.heap.ArenaAllocator.init(testing.allocator); | 3662 | var arena = std.heap.ArenaAllocator.init(testing.allocator); |
| 3482 | defer arena.deinit(); | 3663 | defer arena.deinit(); |