From 70bc70375f8e82843830d93631ab005302057a15 Mon Sep 17 00:00:00 2001 From: owl Date: Fri, 30 Aug 2024 12:23:33 +0200 Subject: count codepoints instead of bytes, to determine width A complete solution would be to count grapheme clusters, but that would require adding a dependency on something like zg. Counting codepoints will ensure that typical non-ASCII text is supported, but you can still throw it off with more complex Unicode constructions, which might not be so useful in help text. Fixes #75 --- clap/codepoint_counting_writer.zig | 102 +++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 clap/codepoint_counting_writer.zig (limited to 'clap/codepoint_counting_writer.zig') diff --git a/clap/codepoint_counting_writer.zig b/clap/codepoint_counting_writer.zig new file mode 100644 index 0000000..e6b9d1c --- /dev/null +++ b/clap/codepoint_counting_writer.zig @@ -0,0 +1,102 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const native_endian = builtin.cpu.arch.endian(); + +/// A Writer that counts how many codepoints has been written to it. +/// Expects valid UTF-8 input, and does not validate the input. +pub fn CodepointCountingWriter(comptime WriterType: type) type { + return struct { + codepoints_written: u64, + child_stream: WriterType, + + pub const Error = WriterType.Error || error{Utf8InvalidStartByte}; + pub const Writer = std.io.Writer(*Self, Error, write); + + const Self = @This(); + + pub fn write(self: *Self, bytes: []const u8) Error!usize { + const bytes_and_codepoints = try utf8CountCodepointsAllowTruncate(bytes); + // Might not be the full input, so the leftover bytes are written on the next call. + const bytes_to_write = bytes[0..bytes_and_codepoints.bytes]; + const amt = try self.child_stream.write(bytes_to_write); + const bytes_written = bytes_to_write[0..amt]; + self.codepoints_written += (try utf8CountCodepointsAllowTruncate(bytes_written)).codepoints; + return amt; + } + + pub fn writer(self: *Self) Writer { + return .{ .context = self }; + } + }; +} + +// Like `std.unicode.utf8CountCodepoints`, but on truncated input, it returns +// the number of codepoints up to that point. +// Does not validate UTF-8 beyond checking the start byte. +fn utf8CountCodepointsAllowTruncate(s: []const u8) !struct { bytes: usize, codepoints: usize } { + var len: usize = 0; + + const N = @sizeOf(usize); + const MASK = 0x80 * (std.math.maxInt(usize) / 0xff); + + var i: usize = 0; + while (i < s.len) { + // Fast path for ASCII sequences + while (i + N <= s.len) : (i += N) { + const v = std.mem.readInt(usize, s[i..][0..N], native_endian); + if (v & MASK != 0) break; + len += N; + } + + if (i < s.len) { + const n = try std.unicode.utf8ByteSequenceLength(s[i]); + // Truncated input; return the current counts. + if (i + n > s.len) return .{ .bytes = i, .codepoints = len }; + + i += n; + len += 1; + } + } + + return .{ .bytes = i, .codepoints = len }; +} + +pub fn codepointCountingWriter(child_stream: anytype) CodepointCountingWriter(@TypeOf(child_stream)) { + return .{ .codepoints_written = 0, .child_stream = child_stream }; +} + +const testing = std.testing; + +test CodepointCountingWriter { + var counting_stream = codepointCountingWriter(std.io.null_writer); + const stream = counting_stream.writer(); + + const utf8_text = "blåhaj" ** 100; + stream.writeAll(utf8_text) catch unreachable; + const expected_count = try std.unicode.utf8CountCodepoints(utf8_text); + try testing.expectEqual(expected_count, counting_stream.codepoints_written); +} + +test "handles partial UTF-8 writes" { + var buf: [100]u8 = undefined; + var fbs = std.io.fixedBufferStream(&buf); + var counting_stream = codepointCountingWriter(fbs.writer()); + const stream = counting_stream.writer(); + + const utf8_text = "ååå"; + // `å` is represented as `\xC5\xA5`, write 1.5 `å`s. + var wc = try stream.write(utf8_text[0..3]); + // One should have been written fully. + try testing.expectEqual("å".len, wc); + try testing.expectEqual(1, counting_stream.codepoints_written); + + // Write the rest, continuing from the reported number of bytes written. + wc = try stream.write(utf8_text[wc..]); + try testing.expectEqual(4, wc); + try testing.expectEqual(3, counting_stream.codepoints_written); + + const expected_count = try std.unicode.utf8CountCodepoints(utf8_text); + try testing.expectEqual(expected_count, counting_stream.codepoints_written); + + try testing.expectEqualSlices(u8, utf8_text, fbs.getWritten()); +} -- cgit v1.2.3