Jianw
2025-05-13 3b39fe3810c3ee2ec9ec97236c1769c5c85e062c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
local unicode = require "luacheck.unicode"
local utils = require "luacheck.utils"
 
local decoder = {}
 
local sbyte = string.byte
local sfind = string.find
local sgsub = string.gsub
local ssub = string.sub
 
-- `LatinChars` and `UnicodeChars` objects represent source strings
-- and provide Unicode-aware access to them with a common interface.
-- Source bytes should not be accessed directly.
-- Provided methods are:
-- `Chars:get_codepoint(index)`: returns codepoint at given index as integer or nil if index is out of range.
-- `Chars:get_substring(from, to)`: returns substring of original bytes corresponding to characters from `from` to `to`.
-- `Chars:get_printable_substring(from. to)`: like get_substring but escapes not printable characters.
-- `Chars:get_length()`: returns total number of characters.
-- `Chars:find(pattern, from)`: `string.find` but `from` is in characters. Return values are still in bytes.
 
-- `LatinChars` is an optimized special case for latin1 strings.
local LatinChars = utils.class()
 
function LatinChars:__init(bytes)
   self._bytes = bytes
end
 
function LatinChars:get_codepoint(index)
   return sbyte(self._bytes, index)
end
 
function LatinChars:get_substring(from, to)
   return ssub(self._bytes, from, to)
end
 
local function hexadecimal_escaper(byte)
   return ("\\x%02X"):format(sbyte(byte))
end
 
function LatinChars:get_printable_substring(from, to)
   return (sgsub(ssub(self._bytes, from, to), "[^\32-\126]", hexadecimal_escaper))
end
 
function LatinChars:get_length()
   return #self._bytes
end
 
function LatinChars:find(pattern, from)
   return sfind(self._bytes, pattern, from)
end
 
-- Decodes `bytes` as UTF8. Returns arrays of codepoints as integers and their byte offsets.
-- Byte offsets have one extra item pointing to one byte past the end of `bytes`.
-- On decoding error returns nothing.
local function get_codepoints_and_byte_offsets(bytes)
   local codepoints = {}
   local byte_offsets = {}
 
   local byte_index = 1
   local codepoint_index = 1
 
   while true do
      byte_offsets[codepoint_index] = byte_index
 
      -- Attempt to decode the next codepoint from UTF8.
      local codepoint = sbyte(bytes, byte_index)
 
      if not codepoint then
         return codepoints, byte_offsets
      end
 
      byte_index = byte_index + 1
 
      if codepoint >= 0x80 then
         -- Not ASCII.
 
         if codepoint < 0xC0 then
            return
         end
 
         local cont = (sbyte(bytes, byte_index) or 0) - 0x80
 
         if cont < 0 or cont >= 0x40 then
            return
         end
 
         byte_index = byte_index + 1
 
         if codepoint < 0xE0 then
            -- Two bytes.
            codepoint = cont + (codepoint - 0xC0) * 0x40
         elseif codepoint < 0xF0 then
            -- Three bytes.
            codepoint = cont + (codepoint - 0xE0) * 0x40
 
            cont = (sbyte(bytes, byte_index) or 0) - 0x80
 
            if cont < 0 or cont >= 0x40 then
               return
            end
 
            byte_index = byte_index + 1
 
            codepoint = cont + codepoint * 0x40
         elseif codepoint < 0xF8 then
            -- Four bytes.
            codepoint = cont + (codepoint - 0xF0) * 0x40
 
            cont = (sbyte(bytes, byte_index) or 0) - 0x80
 
            if cont < 0 or cont >= 0x40 then
               return
            end
 
            byte_index = byte_index + 1
 
            codepoint = cont + codepoint * 0x40
 
            cont = (sbyte(bytes, byte_index) or 0) - 0x80
 
            if cont < 0 or cont >= 0x40 then
               return
            end
 
            byte_index = byte_index + 1
 
            codepoint = cont + codepoint * 0x40
 
            if codepoint > 0x10FFFF then
               return
            end
         else
            return
         end
      end
 
      codepoints[codepoint_index] = codepoint
      codepoint_index = codepoint_index + 1
   end
end
 
-- `UnicodeChars` is the general case for non-latin1 strings.
-- Assumes UTF8, on decoding error falls back to latin1.
local UnicodeChars = utils.class()
 
function UnicodeChars:__init(bytes, codepoints, byte_offsets)
   self._bytes = bytes
   self._codepoints = codepoints
   self._byte_offsets = byte_offsets
end
 
function UnicodeChars:get_codepoint(index)
   return self._codepoints[index]
end
 
function UnicodeChars:get_substring(from, to)
   local byte_offsets = self._byte_offsets
   return ssub(self._bytes, byte_offsets[from], byte_offsets[to + 1] - 1)
end
 
function UnicodeChars:get_printable_substring(from, to)
   -- This is only called on syntax error, it's okay to be slow.
   local parts = {}
 
   for index = from, to do
      local codepoint = self._codepoints[index]
 
      if unicode.is_printable(codepoint) then
         table.insert(parts, self:get_substring(index, index))
      else
         table.insert(parts, (codepoint > 255 and "\\u{%X}" or "\\x%02X"):format(codepoint))
      end
   end
 
   return table.concat(parts)
end
 
function UnicodeChars:get_length()
   return #self._codepoints
end
 
function UnicodeChars:find(pattern, from)
   return sfind(self._bytes, pattern, self._byte_offsets[from])
end
 
function decoder.decode(bytes)
   -- Only use UnicodeChars if necessary. LatinChars isn't much faster but noticeably more memory efficient.
   if sfind(bytes, "[\128-\255]") then
      local codepoints, byte_offsets = get_codepoints_and_byte_offsets(bytes)
 
      if codepoints then
         return UnicodeChars(bytes, codepoints, byte_offsets)
      end
   end
 
   return LatinChars(bytes)
end
 
return decoder