From 88e7283d2de4cf820b5d698c804c9ba1a6c8a6c4 Mon Sep 17 00:00:00 2001 From: Vincent Rischmann Date: Tue, 22 Mar 2022 00:33:49 +0100 Subject: add createScalarFunction to create a user-defined scalar function --- sqlite.zig | 292 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 292 insertions(+) (limited to 'sqlite.zig') diff --git a/sqlite.zig b/sqlite.zig index f6632b6..e27abb5 100644 --- a/sqlite.zig +++ b/sqlite.zig @@ -600,6 +600,150 @@ pub const Db = struct { pub fn savepoint(self: *Self, name: []const u8) Savepoint.InitError!Savepoint { return Savepoint.init(self, name); } + + pub const CreateFunctionFlag = struct { + deterministic: bool = true, + direct_only: bool = true, + }; + + /// Creates a scalar SQLite function with the given name. + /// + /// When the SQLite function is called in a statement, `func` will be called with the input arguments. + /// Each SQLite argument is converted to a Zig value according to the following rules: + /// * TEXT values can be either sqlite.Text or []const u8 + /// * BLOB values can be either sqlite.Blob or []const u8 + /// * INTEGER values can be any Zig integer + /// * REAL values can be any Zig float + /// + /// The return type of the function is converted to a SQLite value according to the same rules but reversed. + /// + pub fn createScalarFunction(self: *Self, func_name: [:0]const u8, comptime func: anytype, comptime create_flags: CreateFunctionFlag) !void { + const Type = @TypeOf(func); + + const fn_info = switch (@typeInfo(Type)) { + .Fn => |fn_info| fn_info, + else => @compileError("expecting a function"), + }; + if (fn_info.is_generic) @compileError("function can't be generic"); + if (fn_info.is_var_args) @compileError("function can't be variadic"); + + const ArgTuple = std.meta.ArgsTuple(Type); + + var flags: c_int = c.SQLITE_UTF8; + if (create_flags.deterministic) { + flags |= c.SQLITE_DETERMINISTIC; + } + if (create_flags.direct_only) { + flags |= c.SQLITE_DIRECTONLY; + } + + const result = c.sqlite3_create_function_v2( + self.db, + func_name, + fn_info.args.len, + flags, + null, + struct { + fn sliceFromValue(sqlite_value: *c.sqlite3_value) []const u8 { + const size = @intCast(usize, c.sqlite3_value_bytes(sqlite_value)); + + const value = c.sqlite3_value_text(sqlite_value); + debug.assert(value != null); // TODO(vincent): how do we handle this properly ? + + return value.?[0..size]; + } + + fn bindValue(comptime ArgType: type, arg: *ArgType, sqlite_value: *c.sqlite3_value) void { + switch (ArgType) { + Text => arg.*.data = sliceFromValue(sqlite_value), + Blob => arg.*.data = sliceFromValue(sqlite_value), + else => switch (@typeInfo(ArgType)) { + .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) { + const value = c.sqlite3_value_int(sqlite_value); + arg.* = @intCast(ArgType, value); + } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) { + const value = c.sqlite3_value_int64(sqlite_value); + arg.* = @intCast(ArgType, value); + } else { + @compileError("integer " ++ @typeName(ArgType) ++ " is not representable in sqlite"); + }, + .Float => { + const value = c.sqlite3_value_double(sqlite_value); + arg.* = @floatCast(ArgType, value); + }, + .Bool => { + const value = c.sqlite3_value_int(sqlite_value); + arg.* = value > 0; + }, + .Pointer => |ptr| switch (ptr.size) { + .Slice => switch (ptr.child) { + u8 => arg.* = sliceFromValue(sqlite_value), + else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), + }, + else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), + }, + else => @compileError("cannot use an argument of type " ++ @typeName(ArgType)), + }, + } + } + + fn setResult(ctx: ?*c.sqlite3_context, result: anytype) void { + const ResultType = @TypeOf(result); + + switch (ResultType) { + Text => c.sqlite3_result_text(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT), + Blob => c.sqlite3_result_blob(ctx, result.data.ptr, @intCast(c_int, result.data.len), c.SQLITE_TRANSIENT), + else => switch (@typeInfo(ResultType)) { + .Int => |info| if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 32) { + c.sqlite3_result_int(ctx, result); + } else if ((info.bits + if (info.signedness == .unsigned) 1 else 0) <= 64) { + c.sqlite3_result_int64(ctx, result); + } else { + @compileError("integer " ++ @typeName(ResultType) ++ " is not representable in sqlite"); + }, + .Float => c.sqlite3_result_double(ctx, result), + .Bool => c.sqlite3_result_int(ctx, if (result) 1 else 0), + .Array => |arr| switch (arr.child) { + u8 => c.sqlite3_result_blob(ctx, &result, arr.len, c.SQLITE_TRANSIENT), + else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), + }, + .Pointer => |ptr| switch (ptr.size) { + .Slice => switch (ptr.child) { + u8 => c.sqlite3_result_text(ctx, result.ptr, @intCast(c_int, result.len), c.SQLITE_TRANSIENT), + else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), + }, + else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), + }, + else => @compileError("cannot use a result of type " ++ @typeName(ResultType)), + }, + } + } + + fn xFunc(ctx: ?*c.sqlite3_context, argc: c_int, argv: [*c]?*c.sqlite3_value) callconv(.C) void { + debug.assert(argc == fn_info.args.len); + + const sqlite_args = argv.?[0..fn_info.args.len]; + + var fn_args: ArgTuple = undefined; + inline for (fn_info.args) |arg, i| { + const ArgType = arg.arg_type.?; + + bindValue(ArgType, &fn_args[i], sqlite_args[i].?); + } + + const result = @call(.{}, func, fn_args); + + setResult(ctx, result); + } + }.xFunc, + null, + null, + null, + ); + if (result != c.SQLITE_OK) { + return errors.errorFromResultCode(result); + } + } }; /// Savepoint is a helper type for managing savepoints. @@ -3177,6 +3321,154 @@ test "sqlite: one with all named parameters" { try testing.expectEqual(@as(usize, 20), id.?); } +test "sqlite: create scalar function" { + var db = try getTestDb(); + defer db.deinit(); + + { + try db.createScalarFunction( + "myInteger", + struct { + fn run(input: u16) u16 { + return input * 2; + } + }.run, + .{}, + ); + + const result = try db.one(usize, "SELECT myInteger(20)", .{}, .{}); + + try testing.expect(result != null); + try testing.expectEqual(@as(usize, 40), result.?); + } + + { + try db.createScalarFunction( + "myInteger64", + struct { + fn run(input: i64) i64 { + return @intCast(i64, input) * 2; + } + }.run, + .{}, + ); + + const result = try db.one(usize, "SELECT myInteger64(20)", .{}, .{}); + + try testing.expect(result != null); + try testing.expectEqual(@as(usize, 40), result.?); + } + + { + try db.createScalarFunction( + "myMax", + struct { + fn run(a: f64, b: f64) f64 { + return std.math.max(a, b); + } + }.run, + .{}, + ); + + const result = try db.one(f64, "SELECT myMax(2.0, 23.4)", .{}, .{}); + + try testing.expect(result != null); + try testing.expectEqual(@as(f64, 23.4), result.?); + } + + { + try db.createScalarFunction( + "myBool", + struct { + fn run() bool { + return true; + } + }.run, + .{}, + ); + + const result = try db.one(bool, "SELECT myBool()", .{}, .{}); + + try testing.expect(result != null); + try testing.expectEqual(true, result.?); + } + + { + try db.createScalarFunction( + "mySlice", + struct { + fn run() []const u8 { + return "foobar"; + } + }.run, + .{}, + ); + + const result = try db.oneAlloc([]const u8, testing.allocator, "SELECT mySlice()", .{}, .{}); + try testing.expect(result != null); + try testing.expectEqualStrings("foobar", result.?); + testing.allocator.free(result.?); + } + + { + const Blake3 = std.crypto.hash.Blake3; + + var expected_hash: [Blake3.digest_length]u8 = undefined; + Blake3.hash("hello", &expected_hash, .{}); + + try db.createScalarFunction( + "blake3", + struct { + fn run(input: []const u8) [std.crypto.hash.Blake3.digest_length]u8 { + var hash: [Blake3.digest_length]u8 = undefined; + Blake3.hash(input, &hash, .{}); + return hash; + } + }.run, + .{}, + ); + + const hash = try db.one([Blake3.digest_length]u8, "SELECT blake3('hello')", .{}, .{}); + + try testing.expect(hash != null); + try testing.expectEqual(expected_hash, hash.?); + } + + { + try db.createScalarFunction( + "myText", + struct { + fn run() Text { + return Text{ .data = "foobar" }; + } + }.run, + .{}, + ); + + const result = try db.oneAlloc(Text, testing.allocator, "SELECT myText()", .{}, .{}); + try testing.expect(result != null); + try testing.expectEqualStrings("foobar", result.?.data); + testing.allocator.free(result.?.data); + } + + { + try db.createScalarFunction( + "myBlob", + struct { + fn run() Blob { + return Blob{ .data = "barbaz" }; + } + }.run, + .{}, + ); + + const result = try db.oneAlloc(Blob, testing.allocator, "SELECT myBlob()", .{}, .{}); + try testing.expect(result != null); + try testing.expectEqualStrings("barbaz", result.?.data); + testing.allocator.free(result.?.data); + } +} + test "sqlite: empty slice" { var arena = std.heap.ArenaAllocator.init(testing.allocator); defer arena.deinit(); -- cgit v1.2.3