summaryrefslogtreecommitdiff
path: root/sqlite.zig
diff options
context:
space:
mode:
authorGravatar Vincent Rischmann2022-04-17 00:52:31 +0200
committerGravatar Vincent Rischmann2022-04-17 01:21:08 +0200
commit64848442f900f56b06c3953ee5b3cc6cd97b9bc7 (patch)
tree200fb19f3be8c5d259581d1debc27a8550131496 /sqlite.zig
parentdocument CreateFunctionFlag (diff)
downloadzig-sqlite-64848442f900f56b06c3953ee5b3cc6cd97b9bc7.tar.gz
zig-sqlite-64848442f900f56b06c3953ee5b3cc6cd97b9bc7.tar.xz
zig-sqlite-64848442f900f56b06c3953ee5b3cc6cd97b9bc7.zip
work on supporting aggregate SQL functions
Diffstat (limited to '')
-rw-r--r--sqlite.zig337
1 files changed, 259 insertions, 78 deletions
diff --git a/sqlite.zig b/sqlite.zig
index 52a28dd..ad92cb9 100644
--- a/sqlite.zig
+++ b/sqlite.zig
@@ -1,4 +1,5 @@
1const std = @import("std"); 1const std = @import("std");
2const builtin = @import("builtin");
2const build_options = @import("build_options"); 3const build_options = @import("build_options");
3const debug = std.debug; 4const debug = std.debug;
4const io = std.io; 5const io = std.io;
@@ -601,6 +602,89 @@ pub const Db = struct {
601 return Savepoint.init(self, name); 602 return Savepoint.init(self, name);
602 } 603 }
603 604
605 // Helpers functions to implement SQLite functions.
606
607 fn sliceFromValue(sqlite_value: *c.sqlite3_value) []const u8 {
608 const size = @intCast(usize, c.sqlite3_value_bytes(sqlite_value));
609
610 const value = c.sqlite3_value_text(sqlite_value);
611 debug.assert(value != null); // TODO(vincent): how do we handle this properly ?
612
613 return value.?[0..size];
614 }
615
616 /// Sets the result of a function call in the context `ctx`.
617 ///
618 /// Determines at compile time which sqlite3_result_XYZ function to use based on the type of `result`.
619 fn setFunctionResult(ctx: ?*c.sqlite3_context, result: anytype) void {
620 const ResultType = @TypeOf(result);
621
622 switch (ResultType) {
623 Text => c.sqlite3_result_text(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT),
624 Blob => c.sqlite3_result_blob(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT),
625 else => switch (@typeInfo(ResultType)) {
626 .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) {
627 c.sqlite3_result_int(ctx, result);
628 } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) {
629 c.sqlite3_result_int64(ctx, result);
630 } else {
631 @compileError("integer " ++ @typeName(ResultType) ++ " is not representable in sqlite");
632 },
633 .Float => c.sqlite3_result_double(ctx, result),
634 .Bool => c.sqlite3_result_int(ctx, if (result) 1 else 0),
635 .Array => |arr| switch (arr.child) {
636 u8 => c.sqlite3_result_blob(ctx, &result, arr.len, c.SQLITE_TRANSIENT),
637 else => @compileError("cannot use a result of type " ++ @typeName(ResultType)),
638 },
639 .Pointer => |ptr| switch (ptr.size) {
640 .Slice => switch (ptr.child) {
641 u8 => c.sqlite3_result_text(ctx, result.ptr, @intCast(c_int, result.len), c.SQLITE_TRANSIENT),
642 else => @compileError("cannot use a result of type " ++ @typeName(ResultType)),
643 },
644 else => @compileError("cannot use a result of type " ++ @typeName(ResultType)),
645 },
646 else => @compileError("cannot use a result of type " ++ @typeName(ResultType)),
647 },
648 }
649 }
650
651 /// Sets a function argument using the provided value.
652 ///
653 /// Determines at compile time which sqlite3_value_XYZ function to use based on the type `ArgType`.
654 fn setFunctionArgument(comptime ArgType: type, arg: *ArgType, sqlite_value: *c.sqlite3_value) void {
655 switch (ArgType) {
656 Text => arg.*.data = sliceFromValue(sqlite_value),
657 Blob => arg.*.data = sliceFromValue(sqlite_value),
658 else => switch (@typeInfo(ArgType)) {
659 .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) {
660 const value = c.sqlite3_value_int(sqlite_value);
661 arg.* = @intCast(ArgType, value);
662 } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) {
663 const value = c.sqlite3_value_int64(sqlite_value);
664 arg.* = @intCast(ArgType, value);
665 } else {
666 @compileError("integer " ++ @typeName(ArgType) ++ " is not representable in sqlite");
667 },
668 .Float => {
669 const value = c.sqlite3_value_double(sqlite_value);
670 arg.* = @floatCast(ArgType, value);
671 },
672 .Bool => {
673 const value = c.sqlite3_value_int(sqlite_value);
674 arg.* = value > 0;
675 },
676 .Pointer => |ptr| switch (ptr.size) {
677 .Slice => switch (ptr.child) {
678 u8 => arg.* = sliceFromValue(sqlite_value),
679 else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)),
680 },
681 else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)),
682 },
683 else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)),
684 },
685 }
686 }
687
604 /// CreateFunctionFlag controls the flags used when creating a custom SQL function. 688 /// CreateFunctionFlag controls the flags used when creating a custom SQL function.
605 /// See https://sqlite.org/c3ref/c_deterministic.html. 689 /// See https://sqlite.org/c3ref/c_deterministic.html.
606 /// 690 ///
@@ -614,6 +698,117 @@ pub const Db = struct {
614 direct_only: bool = true, 698 direct_only: bool = true,
615 }; 699 };
616 700
701 /// Creates an aggregate SQLite function with the given name.
702 ///
703 /// When the SQLite function is called in a statement, `step_func` will be called for each row with the input arguments.
704 /// Each SQLite argument is converted to a Zig value according to the following rules:
705 /// * TEXT values can be either sqlite.Text or []const u8
706 /// * BLOB values can be either sqlite.Blob or []const u8
707 /// * INTEGER values can be any Zig integer
708 /// * REAL values can be any Zig float
709 ///
710 /// The final result of the SQL function call will be what `finalize_func` returns.
711 ///
712 /// The context `my_ctx` contains the state necessary to perform the aggregation, both `step_func` and `finalize_func` must have at least the first argument of type ContextType.
713 ///
714 pub fn createAggregateFunction(self: *Self, func_name: [:0]const u8, my_ctx: anytype, comptime step_func: anytype, comptime finalize_func: anytype, comptime create_flags: CreateFunctionFlag) Error!void {
715 // Check that the context type is usable
716 const ContextType = @TypeOf(my_ctx);
717 switch (@typeInfo(ContextType)) {
718 .Pointer => |ptr_info| switch (ptr_info.size) {
719 .One => {},
720 else => @compileError("cannot use context of type " ++ @typeName(ContextType) ++ ", must be a single-item pointer"),
721 },
722 else => @compileError("cannot use context of type " ++ @typeName(ContextType) ++ ", must be a single-item pointer"),
723 }
724
725 // Validate the step function
726
727 const StepFuncType = @TypeOf(step_func);
728
729 const step_fn_info = switch (@typeInfo(StepFuncType)) {
730 .Fn => |fn_info| fn_info,
731 else => @compileError("cannot use step_fn, expecting a function"),
732 };
733 if (step_fn_info.is_generic) @compileError("step_fn function can't be generic");
734 if (step_fn_info.is_var_args) @compileError("step_fn function can't be variadic");
735
736 const StepFuncArgTuple = std.meta.ArgsTuple(StepFuncType);
737
738 // subtract one because the user-provided function always takes an additional context
739 const step_func_args_len = step_fn_info.args.len - 1;
740
741 // Validate the finalize function
742
743 const FinalizeFuncType = @TypeOf(finalize_func);
744
745 const finalize_fn_info = switch (@typeInfo(FinalizeFuncType)) {
746 .Fn => |fn_info| fn_info,
747 else => @compileError("cannot use finalize_fn, expecting a function"),
748 };
749 if (finalize_fn_info.args.len != 1) @compileError("finalize_fn must take exactly one argument");
750
751 const FinalizeFuncArgTuple = std.meta.ArgsTuple(FinalizeFuncType);
752
753 //
754
755 var flags: c_int = c.SQLITE_UTF8;
756 if (create_flags.deterministic) {
757 flags |= c.SQLITE_DETERMINISTIC;
758 }
759 if (create_flags.direct_only) {
760 flags |= c.SQLITE_DIRECTONLY;
761 }
762
763 const result = c.sqlite3_create_function_v2(
764 self.db,
765 func_name,
766 step_func_args_len,
767 flags,
768 my_ctx,
769 null, // xFunc
770 struct {
771 fn xStep(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void {
772 debug.assert(argc == step_func_args_len);
773
774 const sqlite_args = argv.?[0..step_func_args_len];
775
776 var fn_args: StepFuncArgTuple = undefined;
777
778 // First argument is always the user-provided context
779 fn_args[0] = @ptrCast(ContextType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx)));
780
781 comptime var i: usize = 0;
782 inline while (i < step_func_args_len) : (i += 1) {
783 // we add 1 because we need to ignore the first argument which is the user-provided context, not a SQLite value.
784 const arg = step_fn_info.args[i + 1];
785 const arg_ptr = &fn_args[i + 1];
786
787 const ArgType = arg.arg_type.?;
788 setFunctionArgument(ArgType, arg_ptr, sqlite_args[i].?);
789 }
790
791 @call(.{}, step_func, fn_args);
792 }
793 }.xStep,
794 struct {
795 fn xFinal(ctx: ?*c.sqlite3_context) callconv(.C) void {
796 var fn_args: FinalizeFuncArgTuple = undefined;
797 // Only one argument, the user-provided context
798 fn_args[0] = @ptrCast(ContextType, @alignCast(@alignOf(ContextType), c.sqlite3_user_data(ctx)));
799
800 const result = @call(.{}, finalize_func, fn_args);
801
802 setFunctionResult(ctx, result);
803 }
804 }.xFinal,
805 null,
806 );
807 if (result != c.SQLITE_OK) {
808 return errors.errorFromResultCode(result);
809 }
810 }
811
617 /// Creates a scalar SQLite function with the given name. 812 /// Creates a scalar SQLite function with the given name.
618 /// 813 ///
619 /// When the SQLite function is called in a statement, `func` will be called with the input arguments. 814 /// When the SQLite function is called in a statement, `func` will be called with the input arguments.
@@ -652,81 +847,6 @@ pub const Db = struct {
652 flags, 847 flags,
653 null, 848 null,
654 struct { 849 struct {
655 fn sliceFromValue(sqlite_value: *c.sqlite3_value) []const u8 {
656 const size = @intCast(usize, c.sqlite3_value_bytes(sqlite_value));
657
658 const value = c.sqlite3_value_text(sqlite_value);
659 debug.assert(value != null); // TODO(vincent): how do we handle this properly ?
660
661 return value.?[0..size];
662 }
663
664 fn bindValue(comptime ArgType: type, arg: *ArgType, sqlite_value: *c.sqlite3_value) void {
665 switch (ArgType) {
666 Text => arg.*.data = sliceFromValue(sqlite_value),
667 Blob => arg.*.data = sliceFromValue(sqlite_value),
668 else => switch (@typeInfo(ArgType)) {
669 .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) {
670 const value = c.sqlite3_value_int(sqlite_value);
671 arg.* = @intCast(ArgType, value);
672 } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) {
673 const value = c.sqlite3_value_int64(sqlite_value);
674 arg.* = @intCast(ArgType, value);
675 } else {
676 @compileError("integer " ++ @typeName(ArgType) ++ " is not representable in sqlite");
677 },
678 .Float => {
679 const value = c.sqlite3_value_double(sqlite_value);
680 arg.* = @floatCast(ArgType, value);
681 },
682 .Bool => {
683 const value = c.sqlite3_value_int(sqlite_value);
684 arg.* = value > 0;
685 },
686 .Pointer => |ptr| switch (ptr.size) {
687 .Slice => switch (ptr.child) {
688 u8 => arg.* = sliceFromValue(sqlite_value),
689 else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)),
690 },
691 else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)),
692 },
693 else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)),
694 },
695 }
696 }
697
698 fn setResult(ctx: ?*c.sqlite3_context, result: anytype) void {
699 const ResultType = @TypeOf(result);
700
701 switch (ResultType) {
702 Text => c.sqlite3_result_text(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT),
703 Blob => c.sqlite3_result_blob(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT),
704 else => switch (@typeInfo(ResultType)) {
705 .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) {
706 c.sqlite3_result_int(ctx, result);
707 } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) {
708 c.sqlite3_result_int64(ctx, result);
709 } else {
710 @compileError("integer " ++ @typeName(ResultType) ++ " is not representable in sqlite");
711 },
712 .Float => c.sqlite3_result_double(ctx, result),
713 .Bool => c.sqlite3_result_int(ctx, if (result) 1 else 0),
714 .Array => |arr| switch (arr.child) {
715 u8 => c.sqlite3_result_blob(ctx, &result, arr.len, c.SQLITE_TRANSIENT),
716 else => @compileError("cannot use a result of type " ++ @typeName(ResultType)),
717 },
718 .Pointer => |ptr| switch (ptr.size) {
719 .Slice => switch (ptr.child) {
720 u8 => c.sqlite3_result_text(ctx, result.ptr, @intCast(c_int, result.len), c.SQLITE_TRANSIENT),
721 else => @compileError("cannot use a result of type " ++ @typeName(ResultType)),
722 },
723 else => @compileError("cannot use a result of type " ++ @typeName(ResultType)),
724 },
725 else => @compileError("cannot use a result of type " ++ @typeName(ResultType)),
726 },
727 }
728 }
729
730 fn xFunc(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { 850 fn xFunc(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void {
731 debug.assert(argc == fn_info.args.len); 851 debug.assert(argc == fn_info.args.len);
732 852
@@ -735,13 +855,12 @@ pub const Db = struct {
735 var fn_args: ArgTuple = undefined; 855 var fn_args: ArgTuple = undefined;
736 inline for (fn_info.args) |arg, i| { 856 inline for (fn_info.args) |arg, i| {
737 const ArgType = arg.arg_type.?; 857 const ArgType = arg.arg_type.?;
738 858 setFunctionArgument(ArgType, &fn_args[i], sqlite_args[i].?);
739 bindValue(ArgType, &fn_args[i], sqlite_args[i].?);
740 } 859 }
741 860
742 const result = @call(.{}, func, fn_args); 861 const result = @call(.{}, func, fn_args);
743 862
744 setResult(ctx, result); 863 setFunctionResult(ctx, result);
745 } 864 }
746 }.xFunc, 865 }.xFunc,
747 null, 866 null,
@@ -3477,6 +3596,68 @@ test "sqlite: create scalar function" {
3477 } 3596 }
3478} 3597}
3479 3598
3599test "sqlite: create aggregate function" {
3600 // TODO(vincent): fix this, panics on incorrect pointer alignment when casting the SQLite user data to the context type
3601 // in the xStep function.
3602 if (builtin.cpu.arch.isAARCH64()) return error.SkipZigTest;
3603
3604 var db = try getTestDb();
3605 defer db.deinit();
3606
3607 var rand = std.rand.DefaultPrng.init(@intCast(u64, std.time.milliTimestamp()));
3608
3609 // Create an aggregate function working with a MyContext
3610
3611 const MyContext = struct {
3612 sum: u32,
3613 };
3614 var my_ctx = MyContext{ .sum = 0 };
3615
3616 try db.createAggregateFunction(
3617 "mySum",
3618 &my_ctx,
3619 struct {
3620 fn step(ctx: *MyContext, input: u32) void {
3621 ctx.sum += input;
3622 }
3623 }.step,
3624 struct {
3625 fn finalize(ctx: *MyContext) u32 {
3626 return ctx.sum;
3627 }
3628 }.finalize,
3629 .{},
3630 );
3631
3632 // Initialize some data
3633
3634 try db.exec("CREATE TABLE view(id integer PRIMARY KEY, nb integer)", .{}, .{});
3635 var i: usize = 0;
3636 var exp: usize = 0;
3637 while (i < 20) : (i += 1) {
3638 const val = rand.random().intRangeAtMost(u32, 0, 5205905);
3639 exp += val;
3640
3641 try db.exec("INSERT INTO view(nb) VALUES(?{u32})", .{}, .{val});
3642 }
3643
3644 // Get the sum and check the result
3645
3646 var diags = Diagnostics{};
3647 const result = db.one(
3648 usize,
3649 "SELECT mySum(nb) FROM view",
3650 .{ .diags = &diags },
3651 .{},
3652 ) catch |err| {
3653 debug.print("err: {}\n", .{diags});
3654 return err;
3655 };
3656
3657 try testing.expect(result != null);
3658 try testing.expectEqual(@as(usize, exp), result.?);
3659}
3660
3480test "sqlite: empty slice" { 3661test "sqlite: empty slice" {
3481 var arena = std.heap.ArenaAllocator.init(testing.allocator); 3662 var arena = std.heap.ArenaAllocator.init(testing.allocator);
3482 defer arena.deinit(); 3663 defer arena.deinit();