1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
|
/// A Writer that counts how many codepoints have been written to it.
/// Expects valid UTF-8 input, and does not validate the input beyond
/// inspecting the start byte of each sequence.
///
/// Only complete codepoints are forwarded to the child stream. If a chunk
/// ends in the middle of a multi-byte sequence, the trailing partial bytes
/// are not consumed; the caller is expected to present them again, followed
/// by the rest of the sequence, on a later write.
pub const CodepointCountingWriter = struct {
    /// Running total of complete codepoints forwarded to `child_stream`.
    codepoints_written: u64 = 0,
    /// Destination that receives the bytes.
    child_stream: *std.Io.Writer,
    /// Writer interface with an empty buffer; use `&self.interface` to write.
    interface: std.Io.Writer = .{
        .buffer = &.{},
        .vtable = &.{ .drain = drain },
    },

    const Self = @This();

    /// Creates a counting writer that forwards to `child_stream`.
    pub fn init(child_stream: *std.Io.Writer) Self {
        return .{
            .child_stream = child_stream,
        };
    }

    /// `std.Io.Writer` drain implementation.
    ///
    /// Walks the logical chunk sequence (every element of `data` once, with
    /// the final element repeated `splat` times, per the `std.Io.Writer`
    /// drain contract) and forwards the complete-codepoint prefix of each
    /// chunk to the child stream. Stops at the first chunk that ends in a
    /// truncated sequence so that the returned count always describes a
    /// contiguous prefix of the input.
    ///
    /// Returns the number of bytes consumed from `data`. Fails with
    /// `error.WriteFailed` if the child stream fails or a chunk starts a
    /// sequence with an invalid UTF-8 start byte.
    fn drain(w: *std.Io.Writer, data: []const []const u8, splat: usize) std.Io.Writer.Error!usize {
        const self: *Self = @alignCast(@fieldParentPtr("interface", w));
        if (data.len == 0) return 0;

        const last_index = data.len - 1;
        const n_chunks = last_index + splat;

        var n_bytes_written: usize = 0;
        var i: usize = 0;
        while (i < n_chunks) : (i += 1) {
            // Indices past the end of `data` are repetitions of the last element.
            const chunk = data[@min(i, last_index)];
            const counts = utf8CountCodepointsAllowTruncate(chunk) catch return error.WriteFailed;
            const complete = chunk[0..counts.bytes];

            // `writeAll` rather than `write`: a short write could split a
            // codepoint, and continuing to the next chunk after a short write
            // would emit bytes out of order.
            try self.child_stream.writeAll(complete);
            n_bytes_written += complete.len;
            self.codepoints_written += counts.codepoints;

            // Truncated sequence at the end of this chunk: the leftover bytes
            // must be written before anything that follows, so stop here.
            if (complete.len < chunk.len) break;
        }
        return n_bytes_written;
    }
};
// Counts codepoints the way `std.unicode.utf8CountCodepoints` does, except
// that a multi-byte sequence cut off by the end of the input is not an error:
// counting stops there and the totals gathered so far are returned.
// `bytes` is the length of the prefix that holds only complete sequences and
// `codepoints` is the number of codepoints in that prefix.
// Only the start byte of each sequence is checked; continuation bytes are
// not validated.
fn utf8CountCodepointsAllowTruncate(s: []const u8) !struct { bytes: usize, codepoints: usize } {
    const endian = @import("builtin").cpu.arch.endian();
    const word_size = @sizeOf(usize);
    // A word with the top bit of every byte set; any overlap means non-ASCII.
    const high_bits = 0x80 * (std.math.maxInt(usize) / 0xff);

    var pos: usize = 0;
    var count: usize = 0;
    while (pos < s.len) {
        // Consume whole machine words for as long as they are pure ASCII.
        while (s.len - pos >= word_size) {
            const word = std.mem.readInt(usize, s[pos..][0..word_size], endian);
            if (word & high_bits != 0) break;
            pos += word_size;
            count += word_size;
        }
        if (pos >= s.len) break;

        // Slow path: step over a single sequence based on its start byte.
        const seq_len = try std.unicode.utf8ByteSequenceLength(s[pos]);
        // The sequence runs past the end of the input; report what we have.
        if (seq_len > s.len - pos) break;
        pos += seq_len;
        count += 1;
    }
    return .{ .bytes = pos, .codepoints = count };
}
const testing = std.testing;
// Writes a long mixed ASCII / multi-byte string in one go and checks that the
// tally matches `std.unicode.utf8CountCodepoints`.
test CodepointCountingWriter {
    // The output itself is irrelevant here; only the count is inspected.
    var discarding = std.Io.Writer.Discarding.init(&.{});
    var counting_stream = CodepointCountingWriter.init(&discarding.writer);

    const utf8_text = "blåhaj" ** 100;
    // Propagate a write failure as a test failure instead of asserting that
    // it cannot happen.
    try counting_stream.interface.writeAll(utf8_text);

    const expected_count = try std.unicode.utf8CountCodepoints(utf8_text);
    try testing.expectEqual(expected_count, counting_stream.codepoints_written);
}
// Checks that a write ending in the middle of a multi-byte sequence only
// consumes the complete codepoints, and that resuming from the reported byte
// count yields the full text and the correct total.
test "handles partial UTF-8 writes" {
var buf: [100]u8 = undefined;
var fbs = std.Io.Writer.fixed(&buf);
var counting_stream = CodepointCountingWriter.init(&fbs);
const utf8_text = "ååå";
// `å` is encoded as `\xC3\xA5` (two bytes), so three bytes are 1.5 `å`s.
var wc = try counting_stream.interface.write(utf8_text[0..3]);
// Only the first `å` should have been written; the dangling start byte is left for the next write.
try testing.expectEqual("å".len, wc);
try testing.expectEqual(1, counting_stream.codepoints_written);
// Write the rest, continuing from the reported number of bytes written.
wc = try counting_stream.interface.write(utf8_text[wc..]);
// The remaining two `å`s are four bytes.
try testing.expectEqual(4, wc);
try testing.expectEqual(3, counting_stream.codepoints_written);
// Cross-check the total against the standard library's count.
const expected_count = try std.unicode.utf8CountCodepoints(utf8_text);
try testing.expectEqual(expected_count, counting_stream.codepoints_written);
// The child stream must have received the text unchanged.
try testing.expectEqualSlices(u8, utf8_text, fbs.buffered());
}
const std = @import("std");
|