From e5185f65051f881bf61e88542a1acd4957f8383b Mon Sep 17 00:00:00 2001 From: Uko Kokņevičs Date: Sun, 3 Aug 2025 12:54:12 +0300 Subject: Move bot configuration to SQL land --- src/Bot.zig | 9 ++- src/Config.zig | 1 + src/DB.zig | 100 ++++++++++++++++++++++++++++++++ src/inline_bots.zig | 60 +++++++++---------- src/main.zig | 74 ++++++++++++++++++++--- src/types.zig | 1 + src/types/AnswerCallbackQueryParams.zig | 5 ++ src/types/InlineKeyboardMarkup.zig | 2 +- src/types/SendMessageParams.zig | 4 +- 9 files changed, 211 insertions(+), 45 deletions(-) create mode 100644 src/DB.zig create mode 100644 src/types/AnswerCallbackQueryParams.zig (limited to 'src') diff --git a/src/Bot.zig b/src/Bot.zig index b0eb972..fb91b3f 100644 --- a/src/Bot.zig +++ b/src/Bot.zig @@ -6,6 +6,7 @@ const Allocator = std.mem.Allocator; const ArrayList = std.ArrayList; const Bot = @This(); const Config = @import("Config.zig"); +const DB = @import("DB.zig"); const HttpClient = std.http.Client; const HttpMethod = std.http.Method; const Parsed = std.json.Parsed; @@ -14,6 +15,7 @@ const Uri = std.Uri; allocator: Allocator, http_client: HttpClient, config: Config, +db: *DB, base_uri: Uri = Uri.parse("https://api.telegram.org/") catch unreachable, uri_path_data: ArrayList(u8), poweron: bool = true, @@ -21,7 +23,7 @@ server_header_buffer: [4096]u8 = undefined, username: ?[]const u8 = null, id: ?i64 = null, -pub fn init(allocator: Allocator, config: Config) !Bot { +pub fn init(allocator: Allocator, config: Config, db: *DB) !Bot { var uri_path_data = try ArrayList(u8).initCapacity(allocator, 5 + config.bot_token.len); errdefer uri_path_data.deinit(); @@ -35,6 +37,7 @@ pub fn init(allocator: Allocator, config: Config) !Bot { .allocator = allocator, }, .config = config, + .db = db, .uri_path_data = uri_path_data, }; } @@ -47,6 +50,10 @@ pub fn deinit(self: *Bot) void { self.* = undefined; } +pub inline fn answerCallbackQuery(self: *Bot, args: types.AnswerCallbackQueryParams) !void { + (try self.post(bool, "answerCallbackQuery", args)).deinit(); +} + pub inline fn deleteMessage(self: *Bot, args: types.DeleteMessageParams) !void { (try self.post(bool, "deleteMessage", args)).deinit(); } diff --git a/src/Config.zig b/src/Config.zig index f9d6dab..2732df3 100644 --- a/src/Config.zig +++ b/src/Config.zig @@ -41,6 +41,7 @@ pub const Wrapper = struct { }; bot_token: []const u8, +db_path: [:0]const u8, dev_group: i64, owner: i64, diff --git a/src/DB.zig b/src/DB.zig new file mode 100644 index 0000000..6510e7c --- /dev/null +++ b/src/DB.zig @@ -0,0 +1,100 @@ +const sqlite = @import("sqlite"); +const std = @import("std"); + +const DB = @This(); + +const target_version = 1; + +sql: sqlite.Db, + +pub const InlineBotType = enum(u32) { + blacklisted = 0, + whitelisted = 1, +}; + +pub fn init(db_path: [:0]const u8) !DB { + const sql = try sqlite.Db.init(.{ + .mode = .{ .File = db_path }, + .open_flags = .{ + .write = true, + .create = true, + }, + .threading_mode = .MultiThread, + }); + + return DB{ + .sql = sql, + }; +} + +pub fn deinit(self: *DB) void { + self.sql.deinit(); +} + +pub fn getInlineBotType(self: *DB, id: i64) !?InlineBotType { + const row = try self.sql.one(u32, "SELECT type FROM inline_bots WHERE id = ?", .{}, .{ .id = id }); + if (row) |r| { + return @enumFromInt(r); + } + return null; +} + +pub fn setInlineBotType(self: *DB, id: i64, ty: InlineBotType) !void { + try self.sql.exec("INSERT OR REPLACE INTO inline_bots (id, type) VALUES (?, ?)", .{}, .{ id, @intFromEnum(ty) }); +} + +pub fn upgrade(self: *DB) !void { + try self.sql.exec("CREATE TABLE IF NOT EXISTS version(id INTEGER PRIMARY KEY, version INTEGER)", .{}, .{}); + const row = try self.sql.one(struct { version: u32 }, "SELECT version FROM version WHERE id = 0", .{}, .{}); + var current_ver: u32 = if (row) |r| r.version else 0; + + if (current_ver == target_version) { + std.log.info("Database is up to date", .{}); + return; + } else if (current_ver > target_version) { + std.log.err("Database has a higher version than supported?", .{}); + return error.CorruptedDatabase; + } + + std.log.info("Updating database from version {} to {}", .{ current_ver, target_version }); + + var setVerStmt = try self.sql.prepare("INSERT OR REPLACE INTO version(id, version) VALUES (0, ?)"); + defer setVerStmt.deinit(); + + while (current_ver < target_version) : (current_ver += 1) { + std.log.info("Updating database step from {}", .{current_ver}); + try self.upgradeStep(current_ver + 1); + setVerStmt.reset(); + try setVerStmt.exec(.{}, .{ current_ver + 1 }); + } +} + +fn upgradeStep(self: *DB, new_version: u32) !void { + switch (new_version) { + 1 => { + try self.sql.exec("DROP TABLE IF EXISTS inline_bots_enum", .{}, .{}); + try self.sql.exec( + \\CREATE TABLE inline_bots_enum ( + \\ id INTEGER PRIMARY KEY, + \\ value TEXT UNIQUE + \\) + , .{}, .{}); + try self.sql.exec( + \\INSERT INTO inline_bots_enum(id, value) + \\VALUES (?, 'blacklisted'), (?, 'whitelisted') + , .{}, .{ + .blacklisted = @intFromEnum(InlineBotType.blacklisted), + .whitelisted = @intFromEnum(InlineBotType.whitelisted), + }); + + try self.sql.exec("DROP TABLE IF EXISTS inline_bots", .{}, .{}); + try self.sql.exec( + \\CREATE TABLE inline_bots ( + \\ id INTEGER PRIMARY KEY, + \\ type INTEGER REFERENCES inline_bots_enum(id) + \\) + , .{}, .{}); + }, + else => unreachable, + } +} diff --git a/src/inline_bots.zig b/src/inline_bots.zig index c6fa2b7..29824eb 100644 --- a/src/inline_bots.zig +++ b/src/inline_bots.zig @@ -4,44 +4,18 @@ const utils = @import("utils.zig"); const Bot = @import("Bot.zig"); -const whitelist = [_]i64{ - 90832338, // @vid - 109158646, // @bing - 114528005, // @pic - 136269978, // @ImageFetcherBot - 140267078, // @gif - 154595593, // @wiki - 184730458, // @UnitConversionBot - 223493268, // @minroobot - 296635833, // @lastfmrobot - 473587803, // @LyBot - 595898211, // @DeezerMusicBot - 733460033, // @crabravebot - 870410041, // @HowGayBot - 7904498194, // @tanstiktokbot -}; - -const blacklist = [_]i64{ - 6465471545, // @DickGrowerBot - 7759097490, // @CookieGrowerBot -}; - -comptime { - std.testing.expect(utils.isSorted(i64, &whitelist)) catch unreachable; - std.testing.expect(utils.isSorted(i64, &blacklist)) catch unreachable; +pub inline fn blacklistBot(bot: *Bot, inline_bot_id: i64) !void { + return bot.db.setInlineBotType(inline_bot_id, .blacklisted); } -inline fn isWhitelisted(bot: types.User) bool { - return utils.isIn(i64, bot.id, &whitelist); -} - -inline fn isBlacklisted(bot: types.User) bool { - return utils.isIn(i64, bot.id, &blacklist); +pub inline fn whitelistBot(bot: *Bot, inline_bot_id: i64) !void { + return bot.db.setInlineBotType(inline_bot_id, .whitelisted); } // Returns true if processing of message should continue pub fn onInlineBot(bot: *Bot, msg: types.Message, via: types.User) !bool { - if (isWhitelisted(via)) { + const ty = try bot.db.getInlineBotType(via.id); + if (ty == .whitelisted) { return true; } @@ -51,7 +25,7 @@ pub fn onInlineBot(bot: *Bot, msg: types.Message, via: types.User) !bool { .message_id = msg.message_id, }); - if (!isBlacklisted(via)) { + if (ty != .blacklisted) { // Not explicitly blacklisted, notify dev group const text = try std.fmt.allocPrint( bot.allocator, @@ -60,10 +34,30 @@ pub fn onInlineBot(bot: *Bot, msg: types.Message, via: types.User) !bool { ); defer bot.allocator.free(text); + const whitelist_cb = try std.fmt.allocPrint( + bot.allocator, + "bwl:{}", + .{ via.id }, + ); + defer bot.allocator.free(whitelist_cb); + + const blacklist_cb = try std.fmt.allocPrint( + bot.allocator, + "bbl:{}", + .{ via.id }, + ); + defer bot.allocator.free(blacklist_cb); + try bot.sendMessage_(.{ .chat_id = bot.config.dev_group, .text = text, .parse_mode = .html, + .reply_markup = .{ + .inline_keyboard = &.{&.{ + .{ .text = "Whitelist", .callback_data = whitelist_cb }, + .{ .text = "Blacklist", .callback_data = blacklist_cb }, + }}, + }, }); } diff --git a/src/main.zig b/src/main.zig index 942fd90..5931250 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,3 +1,4 @@ +const inline_bots = @import("inline_bots.zig"); const std = @import("std"); const types = @import("types.zig"); const utils = @import("utils.zig"); @@ -6,10 +7,9 @@ const Allocator = std.mem.Allocator; const ArrayList = std.ArrayList; const Bot = @import("Bot.zig"); const Config = @import("Config.zig"); +const DB = @import("DB.zig"); const GPA = std.heap.GeneralPurposeAllocator(.{}); -const onInlineBot = @import("inline_bots.zig").onInlineBot; - pub fn main() !void { defer std.log.info("We're done", .{}); @@ -22,7 +22,11 @@ pub fn main() !void { defer config.deinit(); try config.merge("config.json"); - var bot = try Bot.init(allocator, config.config); + var db = try DB.init(config.config.db_path); + defer db.deinit(); + try db.upgrade(); + + var bot = try Bot.init(allocator, config.config, &db); defer bot.deinit(); // TODO: Catch fatal errors, report them @@ -48,15 +52,15 @@ fn loadConfig(allocator: Allocator, filename: []const u8) !std.json.Parsed(Confi ); } -fn reportError(bot: *Bot, msg: types.Message, err: anyerror) !void { - std.log.err("While handling {}: {}", .{ msg, err }); - const msgStr = try std.json.stringifyAlloc(bot.allocator, msg, .{ +fn reportError(bot: *Bot, evt: anytype, err: anyerror) !void { + std.log.err("While handling {}: {}", .{ evt, err }); + const evtStr = try std.json.stringifyAlloc(bot.allocator, evt, .{ .whitespace = .indent_2, .emit_null_optional_fields = false, }); - defer bot.allocator.free(msgStr); + defer bot.allocator.free(evtStr); - const devMsg = try std.fmt.allocPrint(bot.allocator, "{} while handling\n
{s}
", .{ err, msgStr }); + const devMsg = try std.fmt.allocPrint(bot.allocator, "{} while handling\n
{s}
", .{ err, evtStr }); defer bot.allocator.free(devMsg); bot.sendMessage_(.{ @@ -90,6 +94,12 @@ fn wrappedMain(bot: *Bot) !void { try reportError(bot, message, err); }; } + + if (update.callback_query) |cb| { + onCallbackQuery(bot, cb) catch |err| { + try reportError(bot, cb, err); + }; + } } } @@ -104,9 +114,55 @@ fn wrappedMain(bot: *Bot) !void { }); } +fn onCallbackQuery(bot: *Bot, cb: types.CallbackQuery) !void { + if (cb.data) |cb_data| blk: { + if (std.mem.startsWith(u8, cb_data, "bbl:")) { + if (cb.from.id != bot.config.owner) { + break :blk; + } + + const inline_bot_id = try std.fmt.parseInt(i64, cb_data[4..], 10); + try inline_bots.blacklistBot(bot, inline_bot_id); + if (cb.message) |msg| { + try bot.deleteMessage(.{ + .chat_id = msg.chat.id, + .message_id = msg.message_id, + }); + } + } else if (std.mem.startsWith(u8, cb_data, "bwl:")) { + if (cb.from.id != bot.config.owner) { + break :blk; + } + + const inline_bot_id = try std.fmt.parseInt(i64, cb_data[4..], 10); + try inline_bots.whitelistBot(bot, inline_bot_id); + if (cb.message) |msg| { + try bot.deleteMessage(.{ + .chat_id = msg.chat.id, + .message_id = msg.message_id, + }); + } + } else { + break :blk; + } + + return bot.answerCallbackQuery(.{ + .callback_query_id = cb.id, + .text = "OK", + }); + } + + std.log.info("Unrecognised callback query data: {?s}", .{ cb.data }); + return bot.answerCallbackQuery(.{ + .callback_query_id = cb.id, + .text = "Unallowed callback query, don't press the button again", + .show_alert = true, + }); +} + fn onMessage(bot: *Bot, msg: types.Message) !void { if (msg.via_bot) |via| { - if (!try onInlineBot(bot, msg, via)) { + if (!try inline_bots.onInlineBot(bot, msg, via)) { return; } } diff --git a/src/types.zig b/src/types.zig index b99d24e..d203652 100644 --- a/src/types.zig +++ b/src/types.zig @@ -1,4 +1,5 @@ pub const Animation = @import("types/Animation.zig"); +pub const AnswerCallbackQueryParams = @import("types/AnswerCallbackQueryParams.zig"); pub const Audio = @import("types/Audio.zig"); pub const BackgroundFill = @import("types/background_fill.zig").BackgroundFill; pub const BackgroundType = @import("types/background_type.zig").BackgroundType; diff --git a/src/types/AnswerCallbackQueryParams.zig b/src/types/AnswerCallbackQueryParams.zig new file mode 100644 index 0000000..875cec1 --- /dev/null +++ b/src/types/AnswerCallbackQueryParams.zig @@ -0,0 +1,5 @@ +callback_query_id: []const u8, +text: ?[]const u8 = null, +show_alert: bool = false, +url: ?[]const u8 = null, +cache_time: u64 = 0, diff --git a/src/types/InlineKeyboardMarkup.zig b/src/types/InlineKeyboardMarkup.zig index 388d4fc..a246851 100644 --- a/src/types/InlineKeyboardMarkup.zig +++ b/src/types/InlineKeyboardMarkup.zig @@ -1,3 +1,3 @@ const InlineKeyboardButton = @import("InlineKeyboardButton.zig"); -inline_keyboard: [][]InlineKeyboardButton, +inline_keyboard: []const []const InlineKeyboardButton, diff --git a/src/types/SendMessageParams.zig b/src/types/SendMessageParams.zig index 8c84940..587055c 100644 --- a/src/types/SendMessageParams.zig +++ b/src/types/SendMessageParams.zig @@ -1,3 +1,4 @@ +const InlineKeyboardMarkup = @import("InlineKeyboardMarkup.zig"); const LinkPreviewOptions = @import("LinkPreviewOptions.zig"); const MessageEntity = @import("MessageEntity.zig"); const ParseMode = @import("parse_mode.zig").ParseMode; @@ -15,4 +16,5 @@ disable_notification: ?bool = null, protect_content: ?bool = null, message_effect_id: ?[]const u8 = null, reply_parameters: ?ReplyParameters = null, -// TODO: reply_markup: InlineKeyboardMarkup OR ReplyKeyboardMarkup OR ReplyKeyboardRemove OR ForceReply +// TODO: InlineKeyboardMarkup OR ReplyKeyboardMarkup OR ReplyKeyboardRemove OR ForceReply +reply_markup: ?InlineKeyboardMarkup = null, -- cgit v1.2.3