1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
|
const std = @import("std");
const builtin = @import("builtin");
const native_endian = builtin.cpu.arch.endian();
/// A writer adapter that forwards bytes to `child_stream` and counts how many
/// codepoints have been written to it.
/// Expects valid UTF-8 input, and does not validate the input beyond looking
/// at the start byte of each sequence.
pub fn CodepointCountingWriter(comptime WriterType: type) type {
    return struct {
        /// Number of complete codepoints handed to `child_stream` so far.
        codepoints_written: u64,
        /// The writer that receives the bytes.
        child_stream: WriterType,

        pub const Error = WriterType.Error || error{Utf8InvalidStartByte};
        pub const Writer = std.io.Writer(*Self, Error, write);

        const Self = @This();

        /// Forwards the complete codepoints at the start of `bytes` to the
        /// child stream and adds them to `codepoints_written`.
        ///
        /// Returns how many bytes of `bytes` were consumed. An incomplete
        /// sequence at the end of `bytes` is not consumed; the caller is
        /// expected to pass it again together with its remaining bytes.
        /// The returned count always ends on a codepoint boundary.
        ///
        /// Fails with `error.Utf8InvalidStartByte` when a sequence starts with
        /// a byte that is not a valid UTF-8 start byte, or with any error of
        /// the child stream.
        pub fn write(self: *Self, bytes: []const u8) Error!usize {
            const scanned = try utf8CountCodepointsAllowTruncate(bytes);
            // Might not be the full input, so the leftover bytes are written on the next call.
            const whole = bytes[0..scanned.bytes];

            var amt = try self.child_stream.write(whole);
            const accepted = try utf8CountCodepointsAllowTruncate(whole[0..amt]);
            self.codepoints_written += accepted.codepoints;

            if (accepted.bytes < amt) {
                // The child stopped in the middle of a codepoint. Returning
                // `amt` here would make the caller's next call start on a
                // continuation byte and fail with `Utf8InvalidStartByte`, so
                // push out the rest of that single codepoint first.
                const seq_len = try std.unicode.utf8ByteSequenceLength(whole[accepted.bytes]);
                // `whole` only holds complete sequences, so this stays in bounds.
                const boundary = accepted.bytes + seq_len;
                while (amt < boundary) {
                    amt += try self.child_stream.write(whole[amt..boundary]);
                }
                self.codepoints_written += 1;
            }

            return amt;
        }

        /// Returns a `Writer` whose context is a pointer to this counting
        /// writer, so `self` must outlive the returned value.
        pub fn writer(self: *Self) Writer {
            return .{ .context = self };
        }
    };
}
// Counts codepoints the way `std.unicode.utf8CountCodepoints` does, except
// that a sequence cut off by the end of `s` is not an error: counting simply
// stops in front of it.
// Returns the number of bytes covered by complete sequences together with the
// number of codepoints they hold.
// Only the start byte of each sequence is inspected; continuation bytes are
// skipped without being validated.
fn utf8CountCodepointsAllowTruncate(s: []const u8) !struct { bytes: usize, codepoints: usize } {
    const word_size = @sizeOf(usize);
    // 0x80 repeated in every byte of a `usize`: the bits that are set only in
    // non-ASCII bytes.
    const high_bits = 0x80 * (std.math.maxInt(usize) / 0xff);

    var offset: usize = 0;
    var count: usize = 0;

    while (offset < s.len) {
        // Fast path: swallow whole machine words for as long as they are pure ASCII.
        while (s.len - offset >= word_size) {
            const word = std.mem.readInt(usize, s[offset..][0..word_size], native_endian);
            if (word & high_bits != 0) break;
            offset += word_size;
            count += word_size;
        }
        if (offset == s.len) break;

        // Slow path: step over exactly one sequence, based on its start byte.
        const seq_len = try std.unicode.utf8ByteSequenceLength(s[offset]);
        // The sequence runs past the end of the input; report what we have so far.
        if (seq_len > s.len - offset) break;
        offset += seq_len;
        count += 1;
    }

    return .{ .bytes = offset, .codepoints = count };
}
/// Wraps `child_stream` in a `CodepointCountingWriter` whose codepoint count
/// starts at zero.
pub fn codepointCountingWriter(child_stream: anytype) CodepointCountingWriter(@TypeOf(child_stream)) {
    const Counting = CodepointCountingWriter(@TypeOf(child_stream));
    return Counting{
        .child_stream = child_stream,
        .codepoints_written = 0,
    };
}
const testing = std.testing;
test CodepointCountingWriter {
    var counting_stream = codepointCountingWriter(std.io.null_writer);
    const stream = counting_stream.writer();

    const utf8_text = "blåhaj" ** 100;
    // `writeAll` can fail (the writer's error set includes
    // `Utf8InvalidStartByte`), so let a failure surface as a test failure
    // rather than asserting it away with `catch unreachable`.
    try stream.writeAll(utf8_text);

    // The writer must agree with the standard library's codepoint count.
    const expected_count = try std.unicode.utf8CountCodepoints(utf8_text);
    try testing.expectEqual(expected_count, counting_stream.codepoints_written);
}
test "handles partial UTF-8 writes" {
    var storage: [100]u8 = undefined;
    var fbs = std.io.fixedBufferStream(&storage);
    var counting_stream = codepointCountingWriter(fbs.writer());
    const stream = counting_stream.writer();

    const utf8_text = "ååå";

    // `å` (U+00E5) is encoded as the two bytes `\xC3\xA5`, so the first three
    // bytes are one and a half `å`s.
    const first_len = try stream.write(utf8_text[0..3]);
    // Only the complete `å` may be consumed and counted.
    try testing.expectEqual("å".len, first_len);
    try testing.expectEqual(1, counting_stream.codepoints_written);

    // Resume from the reported byte count; the remaining two `å`s go through whole.
    const rest_len = try stream.write(utf8_text[first_len..]);
    try testing.expectEqual(4, rest_len);
    try testing.expectEqual(3, counting_stream.codepoints_written);

    // The total matches the standard library, and every byte reached the buffer.
    const expected_count = try std.unicode.utf8CountCodepoints(utf8_text);
    try testing.expectEqual(expected_count, counting_stream.codepoints_written);
    try testing.expectEqualSlices(u8, utf8_text, fbs.getWritten());
}
|