diff options
| author | 2020-10-30 13:49:05 +0100 | |
|---|---|---|
| committer | 2020-11-11 13:31:50 +0100 | |
| commit | be6902d0736178c113d6b11c9b056c4a33a966f3 (patch) | |
| tree | 634b14cf2624f77aa4f79e808b9dae2a85ee8f2c | |
| parent | update requirements (diff) | |
| download | zig-sqlite-be6902d0736178c113d6b11c9b056c4a33a966f3.tar.gz zig-sqlite-be6902d0736178c113d6b11c9b056c4a33a966f3.tar.xz zig-sqlite-be6902d0736178c113d6b11c9b056c4a33a966f3.zip | |
add types to bind markers and check them at comptime
| -rw-r--r-- | query.zig | 190 | ||||
| -rw-r--r-- | sqlite.zig | 76 |
2 files changed, 221 insertions, 45 deletions
diff --git a/query.zig b/query.zig new file mode 100644 index 0000000..a2a5c5c --- /dev/null +++ b/query.zig | |||
| @@ -0,0 +1,190 @@ | |||
| 1 | const builtin = @import("builtin"); | ||
| 2 | const std = @import("std"); | ||
| 3 | const mem = std.mem; | ||
| 4 | const testing = std.testing; | ||
| 5 | |||
| 6 | /// Blob is used to represent a SQLite BLOB value when binding a parameter or reading a column. | ||
| 7 | pub const Blob = struct { data: []const u8 }; | ||
| 8 | |||
| 9 | /// Text is used to represent a SQLite TEXT value when binding a parameter or reading a column. | ||
| 10 | pub const Text = struct { data: []const u8 }; | ||
| 11 | |||
| 12 | const BindMarker = union(enum) { | ||
| 13 | Type: type, | ||
| 14 | None: void, | ||
| 15 | }; | ||
| 16 | |||
| 17 | pub const ParsedQuery = struct { | ||
| 18 | const Self = @This(); | ||
| 19 | |||
| 20 | bind_markers: [128]BindMarker, | ||
| 21 | nb_bind_markers: usize, | ||
| 22 | |||
| 23 | query: [1024]u8, | ||
| 24 | query_size: usize, | ||
| 25 | |||
| 26 | pub fn from(comptime query: []const u8) Self { | ||
| 27 | const State = enum { | ||
| 28 | Start, | ||
| 29 | BindMarker, | ||
| 30 | BindMarkerType, | ||
| 31 | }; | ||
| 32 | |||
| 33 | comptime var buf: [query.len]u8 = undefined; | ||
| 34 | comptime var pos = 0; | ||
| 35 | comptime var state = .Start; | ||
| 36 | |||
| 37 | comptime var current_bind_marker_type: [256]u8 = undefined; | ||
| 38 | comptime var current_bind_marker_type_pos = 0; | ||
| 39 | |||
| 40 | comptime var parsed_query: ParsedQuery = undefined; | ||
| 41 | parsed_query.nb_bind_markers = 0; | ||
| 42 | |||
| 43 | inline for (query) |c, i| { | ||
| 44 | switch (state) { | ||
| 45 | .Start => switch (c) { | ||
| 46 | '?' => { | ||
| 47 | state = .BindMarker; | ||
| 48 | buf[pos] = c; | ||
| 49 | pos += 1; | ||
| 50 | }, | ||
| 51 | else => { | ||
| 52 | buf[pos] = c; | ||
| 53 | pos += 1; | ||
| 54 | }, | ||
| 55 | }, | ||
| 56 | .BindMarker => switch (c) { | ||
| 57 | '{' => { | ||
| 58 | state = .BindMarkerType; | ||
| 59 | current_bind_marker_type_pos = 0; | ||
| 60 | }, | ||
| 61 | else => { | ||
| 62 | @compileError("a bind marker start (the character ?) must be followed by a bind marker type, eg {integer}"); | ||
| 63 | }, | ||
| 64 | }, | ||
| 65 | .BindMarkerType => switch (c) { | ||
| 66 | '}' => { | ||
| 67 | state = .Start; | ||
| 68 | |||
| 69 | const typ = parsed_query.parseType(current_bind_marker_type[0..current_bind_marker_type_pos]); | ||
| 70 | |||
| 71 | parsed_query.bind_markers[parsed_query.nb_bind_markers] = BindMarker{ .Type = typ }; | ||
| 72 | parsed_query.nb_bind_markers += 1; | ||
| 73 | }, | ||
| 74 | else => { | ||
| 75 | current_bind_marker_type[current_bind_marker_type_pos] = c; | ||
| 76 | current_bind_marker_type_pos += 1; | ||
| 77 | }, | ||
| 78 | }, | ||
| 79 | else => { | ||
| 80 | @compileError("invalid state " ++ @tagName(state)); | ||
| 81 | }, | ||
| 82 | } | ||
| 83 | } | ||
| 84 | if (state == .BindMarker) { | ||
| 85 | @compileError("invalid final state " ++ @tagName(state) ++ ", this means you wrote a ? in last position without a bind marker type"); | ||
| 86 | } | ||
| 87 | if (state == .BindMarkerType) { | ||
| 88 | @compileError("invalid final state " ++ @tagName(state) ++ ", this means you wrote an incomplete bind marker type"); | ||
| 89 | } | ||
| 90 | |||
| 91 | mem.copy(u8, &parsed_query.query, &buf); | ||
| 92 | parsed_query.query_size = pos; | ||
| 93 | |||
| 94 | return parsed_query; | ||
| 95 | } | ||
| 96 | |||
| 97 | fn parseType(comptime self: *Self, type_info: []const u8) type { | ||
| 98 | if (type_info.len <= 0) @compileError("invalid type info " ++ type_info); | ||
| 99 | |||
| 100 | // Integer | ||
| 101 | if (mem.eql(u8, "usize", type_info)) return usize; | ||
| 102 | if (mem.eql(u8, "isize", type_info)) return isize; | ||
| 103 | |||
| 104 | if (type_info[0] == 'u' or type_info[0] == 'i') { | ||
| 105 | return @Type(builtin.TypeInfo{ | ||
| 106 | .Int = builtin.TypeInfo.Int{ | ||
| 107 | .is_signed = type_info[0] == 'i', | ||
| 108 | .bits = std.fmt.parseInt(usize, type_info[1..type_info.len], 10) catch { | ||
| 109 | @compileError("invalid type info " ++ type_info); | ||
| 110 | }, | ||
| 111 | }, | ||
| 112 | }); | ||
| 113 | } | ||
| 114 | |||
| 115 | // Float | ||
| 116 | if (mem.eql(u8, "f16", type_info)) return f16; | ||
| 117 | if (mem.eql(u8, "f32", type_info)) return f32; | ||
| 118 | if (mem.eql(u8, "f64", type_info)) return f64; | ||
| 119 | if (mem.eql(u8, "f128", type_info)) return f128; | ||
| 120 | |||
| 121 | // Strings | ||
| 122 | if (mem.eql(u8, "[]const u8", type_info) or mem.eql(u8, "[]u8", type_info)) { | ||
| 123 | return []const u8; | ||
| 124 | } | ||
| 125 | if (mem.eql(u8, "text", type_info)) return Text; | ||
| 126 | if (mem.eql(u8, "blob", type_info)) return Blob; | ||
| 127 | |||
| 128 | @compileError("invalid type info " ++ type_info); | ||
| 129 | } | ||
| 130 | |||
| 131 | pub fn getQuery(comptime self: *const Self) []const u8 { | ||
| 132 | return self.query[0..self.query_size]; | ||
| 133 | } | ||
| 134 | }; | ||
| 135 | |||
| 136 | test "parsed query: query" { | ||
| 137 | const testCase = struct { | ||
| 138 | query: []const u8, | ||
| 139 | expected_query: []const u8, | ||
| 140 | }; | ||
| 141 | |||
| 142 | const testCases = &[_]testCase{ | ||
| 143 | .{ | ||
| 144 | .query = "INSERT INTO user(id, name, age) VALUES(?{usize}, ?{[]const u8}, ?{u32})", | ||
| 145 | .expected_query = "INSERT INTO user(id, name, age) VALUES(?, ?, ?)", | ||
| 146 | }, | ||
| 147 | .{ | ||
| 148 | .query = "SELECT id, name, age FROM user WHER age > ?{u32} AND age < ?{u32}", | ||
| 149 | .expected_query = "SELECT id, name, age FROM user WHER age > ? AND age < ?", | ||
| 150 | }, | ||
| 151 | }; | ||
| 152 | |||
| 153 | inline for (testCases) |tc| { | ||
| 154 | comptime var parsed_query = ParsedQuery.from(tc.query); | ||
| 155 | std.debug.print("parsed query: {}\n", .{parsed_query.getQuery()}); | ||
| 156 | testing.expectEqualStrings(tc.expected_query, parsed_query.getQuery()); | ||
| 157 | } | ||
| 158 | } | ||
| 159 | |||
| 160 | test "parsed query: bind markers types" { | ||
| 161 | const testCase = struct { | ||
| 162 | query: []const u8, | ||
| 163 | expected_marker: BindMarker, | ||
| 164 | }; | ||
| 165 | |||
| 166 | const testCases = &[_]testCase{ | ||
| 167 | .{ | ||
| 168 | .query = "foobar ?{usize}", | ||
| 169 | .expected_marker = .{ .Type = usize }, | ||
| 170 | }, | ||
| 171 | .{ | ||
| 172 | .query = "foobar ?{text}", | ||
| 173 | .expected_marker = .{ .Type = Text }, | ||
| 174 | }, | ||
| 175 | .{ | ||
| 176 | .query = "foobar ?{blob}", | ||
| 177 | .expected_marker = .{ .Type = Blob }, | ||
| 178 | }, | ||
| 179 | }; | ||
| 180 | |||
| 181 | inline for (testCases) |tc| { | ||
| 182 | comptime var parsed_query = ParsedQuery.from(tc.query); | ||
| 183 | std.debug.print("parsed query: {}\n", .{parsed_query.getQuery()}); | ||
| 184 | |||
| 185 | testing.expectEqual(1, parsed_query.nb_bind_markers); | ||
| 186 | |||
| 187 | const bind_marker = parsed_query.bind_markers[0]; | ||
| 188 | testing.expectEqual(tc.expected_marker.Type, bind_marker.Type); | ||
| 189 | } | ||
| 190 | } | ||
| @@ -8,6 +8,8 @@ const c = @cImport({ | |||
| 8 | @cInclude("sqlite3.h"); | 8 | @cInclude("sqlite3.h"); |
| 9 | }); | 9 | }); |
| 10 | 10 | ||
| 11 | usingnamespace @import("query.zig"); | ||
| 12 | |||
| 11 | const logger = std.log.scoped(.sqlite); | 13 | const logger = std.log.scoped(.sqlite); |
| 12 | 14 | ||
| 13 | /// Db is a wrapper around a SQLite database, providing high-level functions for executing queries. | 15 | /// Db is a wrapper around a SQLite database, providing high-level functions for executing queries. |
| @@ -106,8 +108,9 @@ pub const Db = struct { | |||
| 106 | /// The statement returned is only compatible with the number of bind markers in the input query. | 108 | /// The statement returned is only compatible with the number of bind markers in the input query. |
| 107 | /// This is done because we type check the bind parameters when executing the statement later. | 109 | /// This is done because we type check the bind parameters when executing the statement later. |
| 108 | /// | 110 | /// |
| 109 | pub fn prepare(self: *Self, comptime query: []const u8) !Statement(StatementOptions.from(query)) { | 111 | pub fn prepare(self: *Self, comptime query: []const u8) !Statement(.{}, ParsedQuery.from(query)) { |
| 110 | return Statement(comptime StatementOptions.from(query)).prepare(self, 0, query); | 112 | const parsed_query = ParsedQuery.from(query); |
| 113 | return Statement(.{}, comptime parsed_query).prepare(self, 0); | ||
| 111 | } | 114 | } |
| 112 | 115 | ||
| 113 | /// rowsAffected returns the number of rows affected by the last statement executed. | 116 | /// rowsAffected returns the number of rows affected by the last statement executed. |
| @@ -116,28 +119,7 @@ pub const Db = struct { | |||
| 116 | } | 119 | } |
| 117 | }; | 120 | }; |
| 118 | 121 | ||
| 119 | /// Bytes is used to represent a byte slice with its SQLite datatype. | 122 | pub const StatementOptions = struct {}; |
| 120 | /// | ||
| 121 | /// Since Zig doesn't have strings we can't tell if a []u8 must be stored as a SQLite TEXT or BLOB, | ||
| 122 | /// this type can be used to communicate this when executing a statement. | ||
| 123 | /// | ||
| 124 | /// If a []u8 or []const u8 is passed as bind parameter it will be treated as TEXT. | ||
| 125 | pub const Bytes = union(enum) { | ||
| 126 | Blob: []const u8, | ||
| 127 | Text: []const u8, | ||
| 128 | }; | ||
| 129 | |||
| 130 | pub const StatementOptions = struct { | ||
| 131 | const Self = @This(); | ||
| 132 | |||
| 133 | bind_markers: usize, | ||
| 134 | |||
| 135 | fn from(comptime query: []const u8) Self { | ||
| 136 | return Self{ | ||
| 137 | .bind_markers = std.mem.count(u8, query, "?"), | ||
| 138 | }; | ||
| 139 | } | ||
| 140 | }; | ||
| 141 | 123 | ||
| 142 | /// Statement is a wrapper around a SQLite statement, providing high-level functions to execute | 124 | /// Statement is a wrapper around a SQLite statement, providing high-level functions to execute |
| 143 | /// a statement and retrieve rows for SELECT queries. | 125 | /// a statement and retrieve rows for SELECT queries. |
| @@ -172,19 +154,21 @@ pub const StatementOptions = struct { | |||
| 172 | /// | 154 | /// |
| 173 | /// Look at aach function for more complete documentation. | 155 | /// Look at aach function for more complete documentation. |
| 174 | /// | 156 | /// |
| 175 | pub fn Statement(comptime opts: StatementOptions) type { | 157 | pub fn Statement(comptime opts: StatementOptions, comptime query: ParsedQuery) type { |
| 176 | return struct { | 158 | return struct { |
| 177 | const Self = @This(); | 159 | const Self = @This(); |
| 178 | 160 | ||
| 179 | stmt: *c.sqlite3_stmt, | 161 | stmt: *c.sqlite3_stmt, |
| 180 | 162 | ||
| 181 | fn prepare(db: *Db, flags: c_uint, comptime query: []const u8) !Self { | 163 | fn prepare(db: *Db, flags: c_uint) !Self { |
| 182 | var stmt = blk: { | 164 | var stmt = blk: { |
| 165 | const real_query = query.getQuery(); | ||
| 166 | |||
| 183 | var tmp: ?*c.sqlite3_stmt = undefined; | 167 | var tmp: ?*c.sqlite3_stmt = undefined; |
| 184 | const result = c.sqlite3_prepare_v3( | 168 | const result = c.sqlite3_prepare_v3( |
| 185 | db.db, | 169 | db.db, |
| 186 | query.ptr, | 170 | real_query.ptr, |
| 187 | @intCast(c_int, query.len), | 171 | @intCast(c_int, real_query.len), |
| 188 | flags, | 172 | flags, |
| 189 | &tmp, | 173 | &tmp, |
| 190 | null, | 174 | null, |
| @@ -212,11 +196,15 @@ pub fn Statement(comptime opts: StatementOptions) type { | |||
| 212 | const StructType = @TypeOf(values); | 196 | const StructType = @TypeOf(values); |
| 213 | const StructTypeInfo = @typeInfo(StructType).Struct; | 197 | const StructTypeInfo = @typeInfo(StructType).Struct; |
| 214 | 198 | ||
| 215 | if (comptime opts.bind_markers != StructTypeInfo.fields.len) { | 199 | if (comptime query.nb_bind_markers != StructTypeInfo.fields.len) { |
| 216 | @compileError("number of bind markers not equal to number of fields"); | 200 | @compileError("number of bind markers not equal to number of fields"); |
| 217 | } | 201 | } |
| 218 | 202 | ||
| 219 | inline for (StructTypeInfo.fields) |struct_field, _i| { | 203 | inline for (StructTypeInfo.fields) |struct_field, _i| { |
| 204 | if (struct_field.field_type != query.bind_markers[_i].Type) { | ||
| 205 | @compileError("value type " ++ @typeName(struct_field.field_type) ++ " is not the bind marker type " ++ @typeName(query.bind_markers[_i].Type)); | ||
| 206 | } | ||
| 207 | |||
| 220 | const i = @as(usize, _i); | 208 | const i = @as(usize, _i); |
| 221 | const field_type_info = @typeInfo(struct_field.field_type); | 209 | const field_type_info = @typeInfo(struct_field.field_type); |
| 222 | const field_value = @field(values, struct_field.name); | 210 | const field_value = @field(values, struct_field.name); |
| @@ -226,10 +214,8 @@ pub fn Statement(comptime opts: StatementOptions) type { | |||
| 226 | []const u8, []u8 => { | 214 | []const u8, []u8 => { |
| 227 | _ = c.sqlite3_bind_text(self.stmt, column, field_value.ptr, @intCast(c_int, field_value.len), null); | 215 | _ = c.sqlite3_bind_text(self.stmt, column, field_value.ptr, @intCast(c_int, field_value.len), null); |
| 228 | }, | 216 | }, |
| 229 | Bytes => switch (field_value) { | 217 | Text => _ = c.sqlite3_bind_text(self.stmt, column, field_value.data.ptr, @intCast(c_int, field_value.data.len), null), |
| 230 | .Text => |v| _ = c.sqlite3_bind_text(self.stmt, column, v.ptr, @intCast(c_int, v.len), null), | 218 | Blob => _ = c.sqlite3_bind_blob(self.stmt, column, field_value.data.ptr, @intCast(c_int, field_value.data.len), null), |
| 231 | .Blob => |v| _ = c.sqlite3_bind_blob(self.stmt, column, v.ptr, @intCast(c_int, v.len), null), | ||
| 232 | }, | ||
| 233 | else => switch (field_type_info) { | 219 | else => switch (field_type_info) { |
| 234 | .Int, .ComptimeInt => _ = c.sqlite3_bind_int64(self.stmt, column, @intCast(c_longlong, field_value)), | 220 | .Int, .ComptimeInt => _ = c.sqlite3_bind_int64(self.stmt, column, @intCast(c_longlong, field_value)), |
| 235 | .Float, .ComptimeFloat => _ = c.sqlite3_bind_double(self.stmt, column, field_value), | 221 | .Float, .ComptimeFloat => _ = c.sqlite3_bind_double(self.stmt, column, field_value), |
| @@ -490,7 +476,7 @@ test "sqlite: statement exec" { | |||
| 490 | }; | 476 | }; |
| 491 | 477 | ||
| 492 | for (users) |user| { | 478 | for (users) |user| { |
| 493 | try db.exec("INSERT INTO user(id, name, age) VALUES(?, ?, ?)", user); | 479 | try db.exec("INSERT INTO user(id, name, age) VALUES(?{usize}, ?{[]const u8}, ?{usize})", user); |
| 494 | 480 | ||
| 495 | const rows_inserted = db.rowsAffected(); | 481 | const rows_inserted = db.rowsAffected(); |
| 496 | testing.expectEqual(@as(usize, 1), rows_inserted); | 482 | testing.expectEqual(@as(usize, 1), rows_inserted); |
| @@ -499,10 +485,10 @@ test "sqlite: statement exec" { | |||
| 499 | // Read a single user | 485 | // Read a single user |
| 500 | 486 | ||
| 501 | { | 487 | { |
| 502 | var stmt = try db.prepare("SELECT id, name, age FROM user WHERE id = ?"); | 488 | var stmt = try db.prepare("SELECT id, name, age FROM user WHERE id = ?{usize}"); |
| 503 | defer stmt.deinit(); | 489 | defer stmt.deinit(); |
| 504 | 490 | ||
| 505 | var rows = try stmt.all(User, .{ .allocator = allocator }, .{ .id = 20 }); | 491 | var rows = try stmt.all(User, .{ .allocator = allocator }, .{ .id = @as(usize, 20) }); |
| 506 | for (rows) |row| { | 492 | for (rows) |row| { |
| 507 | testing.expectEqual(users[0].id, row.id); | 493 | testing.expectEqual(users[0].id, row.id); |
| 508 | testing.expectEqualStrings(users[0].name, row.name); | 494 | testing.expectEqualStrings(users[0].name, row.name); |
| @@ -529,7 +515,7 @@ test "sqlite: statement exec" { | |||
| 529 | // Test with anonymous structs | 515 | // Test with anonymous structs |
| 530 | 516 | ||
| 531 | { | 517 | { |
| 532 | var stmt = try db.prepare("SELECT id, name, age FROM user WHERE id = ?"); | 518 | var stmt = try db.prepare("SELECT id, name, age FROM user WHERE id = ?{usize}"); |
| 533 | defer stmt.deinit(); | 519 | defer stmt.deinit(); |
| 534 | 520 | ||
| 535 | var row = try stmt.one( | 521 | var row = try stmt.one( |
| @@ -539,7 +525,7 @@ test "sqlite: statement exec" { | |||
| 539 | age: usize, | 525 | age: usize, |
| 540 | }, | 526 | }, |
| 541 | .{ .allocator = allocator }, | 527 | .{ .allocator = allocator }, |
| 542 | .{ .id = 20 }, | 528 | .{ .id = @as(usize, 20) }, |
| 543 | ); | 529 | ); |
| 544 | testing.expect(row != null); | 530 | testing.expect(row != null); |
| 545 | 531 | ||
| @@ -552,12 +538,12 @@ test "sqlite: statement exec" { | |||
| 552 | // Test with a single integer | 538 | // Test with a single integer |
| 553 | 539 | ||
| 554 | { | 540 | { |
| 555 | const query = "SELECT age FROM user WHERE id = ?"; | 541 | const query = "SELECT age FROM user WHERE id = ?{usize}"; |
| 556 | 542 | ||
| 557 | var stmt: Statement(StatementOptions.from(query)) = try db.prepare(query); | 543 | var stmt: Statement(.{}, ParsedQuery.from(query)) = try db.prepare(query); |
| 558 | defer stmt.deinit(); | 544 | defer stmt.deinit(); |
| 559 | 545 | ||
| 560 | var age = try stmt.one(usize, .{}, .{ .id = 20 }); | 546 | var age = try stmt.one(usize, .{}, .{ .id = @as(usize, 20) }); |
| 561 | testing.expect(age != null); | 547 | testing.expect(age != null); |
| 562 | 548 | ||
| 563 | testing.expectEqual(@as(usize, 33), age.?); | 549 | testing.expectEqual(@as(usize, 33), age.?); |
| @@ -566,10 +552,10 @@ test "sqlite: statement exec" { | |||
| 566 | // Test with a Bytes struct | 552 | // Test with a Bytes struct |
| 567 | 553 | ||
| 568 | { | 554 | { |
| 569 | try db.exec("INSERT INTO user(id, name, age) VALUES(?, ?, ?)", .{ | 555 | try db.exec("INSERT INTO user(id, name, age) VALUES(?{usize}, ?{blob}, ?{u32})", .{ |
| 570 | .id = 200, | 556 | .id = @as(usize, 200), |
| 571 | .name = Bytes{ .Text = "hello" }, | 557 | .name = Blob{ .data = "hello" }, |
| 572 | .age = 20, | 558 | .age = @as(u32, 20), |
| 573 | }); | 559 | }); |
| 574 | } | 560 | } |
| 575 | } | 561 | } |