• 欢迎来到THBWiki!如果您是第一次来到这里,请点击右上角注册一个帐户
  • 有任何意见、建议、求助、反馈都可以在 讨论板 提出
  • THBWiki以专业性和准确性为目标,如果你发现了任何确定的错误或疏漏,可在登录后直接进行改正

模块:list

From THBWiki
Jump to navigation Jump to search
Lua-Logo.svg 模块文档[创建]
------------------------------------------------------------
-- 本模块提供 python 风格的列表
-- 使用 list(), list(table) 构造列表
-- 然后可以使用各种方法和运算, 如 l:append(), l1 == l2
-- 我在 lua 里写 python .jpg
--
-- 主要参考了 http://lua-users.org/wiki/PythonLists
-- 但是本模块实现的功能更全面, 性能更好
------------------------------------------------------------

-- Cache frequently used library functions in locals for faster access.
local table_insert = table.insert
local table_remove = table.remove
local table_concat = table.concat
local table_sort = table.sort
-- Scribunto's mw.ustring functions count and slice by Unicode character,
-- so list('...') splits a string into characters rather than bytes.
local string_len = mw.ustring.len
local string_sub = mw.ustring.sub
local type = type
local tostring = tostring
-- `unpack` is a global in Lua 5.1 but only exists as table.unpack from 5.2 on.
local unpack = unpack or table.unpack
local setmetatable = setmetatable
local rawset = rawset

-- The module table. It also serves as the metatable of every list object,
-- so the methods defined on it are reachable from each list.
local list = {}

-- Constructor behind list(...). Converts its argument into a new list object:
--   list()               -> empty list
--   list(t)              -> shallow copy of t[1..#t]
--   list(s)              -> one element per character of the string s
--   list(f, state, ctrl) -> the values produced by a generic-for iterator;
--                           when the iterator yields several values per step,
--                           each element is a plain table holding all of them
-- The first parameter is the `list` table itself (passed by __call); ignored.
local function construct(_, ...)
    local src, state, ctrl = ...
    local out = {}

    if src == nil then
        return setmetatable(out, list)
    end

    local kind = type(src)
    if kind == 'table' then
        for k = 1, #src do
            out[k] = src[k]
        end
    elseif kind == 'string' then
        for k = 1, string_len(src) do
            out[k] = string_sub(src, k, k)
        end
    elseif kind == 'function' then
        -- Probe once to learn how many values the iterator yields per step.
        local first = { src(state, ctrl) }
        local width = #first
        local k = 0
        if width == 1 then
            local value = first[1]
            while value ~= nil do
                k = k + 1
                out[k] = value
                value = src(state, value)
            end
        elseif width > 1 then
            local step = first
            while step[1] ~= nil do
                k = k + 1
                out[k] = step
                step = { src(state, step[1]) }
            end
        end
    else
        error('list() expects table, string, or iterator function')
    end

    return setmetatable(out, list)
end

-- Make the module table callable: list(x) forwards to construct(list, x).
local list_meta = { __call = construct }
setmetatable(list, list_meta)

------------------------------------------------------------
-- 1. 索引, 列表方法与运算符
-- 基本上可以完全还原 python 的用法
------------------------------------------------------------

-- __index fallback; Lua consults it only when a key is missing from the list.
-- String keys resolve to methods on the module table. Negative numeric keys
-- count from the end (l[-1] is the last element). Anything else yields nil.
function list:__index(i)
    local kind = type(i)
    if kind == 'string' then
        return list[i]
    end
    if kind == 'number' and i < 0 then
        local size = #self
        if -size <= i then
            return self[size + i + 1]
        end
    end
    return nil
end

-- __newindex guard; Lua consults it only when assigning to a key the list lacks.
-- An in-range negative index is redirected to the matching positive slot.
-- Every other new key is rejected, mirroring Python's error on out-of-range
-- assignment. Methods that grow the list use rawset to bypass this guard.
function list:__newindex(i, v)
    local size = #self
    local is_number = type(i) == 'number'
    if is_number and i < 0 and i >= -size then
        self[size + i + 1] = v
        return
    end
    if is_number and i % 1 == 0 then
        error(string.format('list assignment index out of range (index=%d, size=%d)', i, size))
    end
    error('invalid list assignment at index ' .. tostring(i))
end

-- Append `el` at the end (Python list.append). rawset is required because
-- __newindex would reject the new index #self + 1.
function list:append(el)
    local slot = #self + 1
    rawset(self, slot, el)
end

-- Append every element of `t` to the end of the list, in order
-- (Python list.extend). `t` may be a table with a sequence part, a string
-- (appended character by character) or an iterator function; strings and
-- functions are first converted with list(), because t[i] on a string is nil
-- and # on a function is an error.
-- rawset bypasses __newindex, which rejects indices past the end.
function list:extend(t)
    local kind = type(t)
    if kind == 'string' or kind == 'function' then
        t = list(t)
    end
    local n = #self
    -- The numeric for reads #t once, so l:extend(l) doubles the list instead
    -- of looping forever.
    for i = 1, #t do
        rawset(self, n + i, t[i])
    end
end

-- Insert `el` before position `i` (Python list.insert). Negative positions
-- count from the end. Positions outside the list are clamped, so the element
-- lands at the front or at the back.
function list:insert(i, el)
    local size = #self
    local pos = i
    if pos < 0 then
        pos = pos + size + 1
    end
    if pos < 1 then
        pos = 1
    elseif pos > size + 1 then
        pos = size + 1
    end
    table_insert(self, pos, el)
end

-- Remove and return the element at index `i` (default: the last element).
-- Negative indices count from the end. The resolved index is passed to
-- table.remove unchanged.
function list:pop(i)
    local size = #self
    if not i then
        return table_remove(self, size)
    end
    local pos = i < 0 and i + size + 1 or i
    return table_remove(self, pos)
end

-- Remove the first element equal to `el` (Python list.remove).
-- Raises an error when no element matches.
function list:remove(el)
    local size = #self
    local pos = 1
    while pos <= size do
        if self[pos] == el then
            table_remove(self, pos)
            return
        end
        pos = pos + 1
    end
    error(string.format('list:remove(%s) failed, value not in list', tostring(el)))
end

-- 1-based position of the first element equal to `el`, or nil if absent.
-- Note: Python's list.index raises instead of returning a sentinel.
function list:index(el)
    local size = #self
    local pos = 1
    while pos <= size do
        if self[pos] == el then
            return pos
        end
        pos = pos + 1
    end
    return nil
end

-- Number of elements equal to `el` (Python list.count).
function list:count(el)
    local hits = 0
    local size = #self
    for pos = 1, size do
        if self[pos] == el then
            hits = hits + 1
        end
    end
    return hits
end

-- Reverse the list in place (Python list.reverse) by swapping from both ends.
function list:reverse()
    local lo, hi = 1, #self
    while lo < hi do
        local tmp = self[lo]
        self[lo] = self[hi]
        self[hi] = tmp
        lo, hi = lo + 1, hi - 1
    end
end

-- Sort the list in place, stably (Python list.sort).
-- args (optional table):
--   key     - function mapping an element to the value actually compared;
--             defaults to the element itself. Called once per element.
--   reverse - when truthy, sort in descending order.
-- table.sort is not stable, so each element is decorated with its original
-- position, and ties on the key fall back to that position. The tie-break is
-- ascending in BOTH directions, so equal elements keep their original order
-- even with reverse = true. Sorting ascending and then reversing would flip
-- them, which Python does not do.
function list:sort(args)
    local key, reverse
    if args then
        key = args.key
        reverse = args.reverse
    end

    local n = #self
    local decorated = {}
    for i = 1, n do
        local v = self[i]
        local k = v
        if key then
            k = key(v)
        end
        decorated[i] = { k, i, v }
    end

    table_sort(decorated, function(a, b)
        local ka, kb = a[1], b[1]
        if ka ~= kb then
            if reverse then
                return kb < ka
            end
            return ka < kb
        end
        return a[2] < b[2]  -- original position: keeps the sort stable
    end)

    for i = 1, n do
        self[i] = decorated[i][3]
    end
end

-- Remove every element in place (Python list.clear), clearing from the end
-- so the list stays a proper sequence throughout.
function list:clear()
    local i = #self
    while i > 0 do
        self[i] = nil
        i = i - 1
    end
end

-- Shallow copy (Python list.copy): a new list holding the same elements.
function list:copy()
    local duplicate = list(self)
    return duplicate
end

-- Element-wise equality for `l1 == l2`. Nested lists compare recursively,
-- because `~=` on two lists re-enters this metamethod.
function list.__eq(a, b)
    local size = #a
    if #b ~= size then
        return false
    end
    local i = 1
    while i <= size do
        if a[i] ~= b[i] then
            return false
        end
        i = i + 1
    end
    return true
end

-- Lexicographic `l1 < l2`: the first position where one element is smaller
-- decides. If all shared positions tie, the shorter list is the smaller one.
function list.__lt(a, b)
    local len_a, len_b = #a, #b
    local common = len_a
    if len_b < common then
        common = len_b
    end
    for i = 1, common do
        local x, y = a[i], b[i]
        if x < y then
            return true
        end
        if y < x then
            return false
        end
    end
    return len_a < len_b
end

-- `l1 <= l2`: true when l1 is strictly less than or equal to l2.
function list.__le(a, b)
    if a < b then
        return true
    end
    return a == b
end

-- `l1 .. l2`: a new list with l1's elements followed by l2's; neither operand
-- is modified. Lua's concatenation operator stands in for Python's `+`.
function list.__concat(a, b)
    local joined = list(a)
    joined:extend(b)
    return joined
end

-- `lst * num` or `num * lst`: a new list with lst's elements repeated num times.
-- A count below 1 yields an empty list; any other operand types are an error.
function list.__mul(a, b)
    local ta, tb = type(a), type(b)
    local lst, num
    if ta == 'table' and tb == 'number' then
        lst, num = a, b
    elseif ta == 'number' and tb == 'table' then
        lst, num = b, a
    elseif ta == 'table' and tb == 'table' then
        error('attempt to multiply two lists')
    else
        local other = ta
        if ta == 'table' then
            other = tb
        end
        error('attempt to multiply list and ' .. other)
    end

    local out = {}
    local size = #lst
    local pos = 0
    for _ = 1, num do
        for j = 1, size do
            pos = pos + 1
            out[pos] = lst[j]
        end
    end
    return setmetatable(out, list)
end

-- Human-readable form `{a, b, c}`, converting each element with tostring.
function list:__tostring()
    local parts = {}
    local size = #self
    local i = 1
    while i <= size do
        parts[i] = tostring(self[i])
        i = i + 1
    end
    return '{' .. table_concat(parts, ', ') .. '}'
end

------------------------------------------------------------
-- 2. 切片语法, del 语句与内置函数
-- 这部分只能作为方法实现, 不能完全还原
------------------------------------------------------------

-- Resolve the 1-based, inclusive slice bounds used by slice/slice_set/slice_del.
-- n is the list length. A missing step means 1; step 0 is an error.
-- Missing bounds default to the whole list in the direction of `step`.
-- Negative bounds count from the end (-1 is the last element).
-- Clamping: with a positive step, start is raised to >= 1 and stop lowered
-- to <= n; with a negative step, start is lowered to <= n and stop raised
-- to >= 1. Returns start, stop, step.
local function normalize_slice(n, start, stop, step)
    local dir = step or 1
    if dir == 0 then
        error('slice step cannot be zero')
    end
    local forward = dir > 0

    local first = start
    if not first then
        first = forward and 1 or n
    elseif first < 0 then
        first = first + n + 1
    end

    local last = stop
    if not last then
        last = forward and n or 1
    elseif last < 0 then
        last = last + n + 1
    end

    if forward then
        if first < 1 then first = 1 end
        if last > n then last = n end
    else
        if first > n then first = n end
        if last < 1 then last = 1 end
    end

    return first, last, dir
end

-- `l:slice{a, b, c}` is the counterpart of Python's l[a:b:c], but with
-- 1-based, inclusive bounds (see normalize_slice). Returns a new list.
function list:slice(args)
    local first, last, step = normalize_slice(#self, unpack(args, 1, 3))
    local out, count = {}, 0
    for i = first, last, step do
        count = count + 1
        out[count] = self[i]
    end
    return setmetatable(out, list)
end

-- Slice assignment: `l:slice_set({a, b, c}, t)` is the counterpart of
-- Python's l[a:b:c] = t, with 1-based, inclusive bounds (see normalize_slice).
-- With step 1 the slice may grow or shrink the list. With any other step,
-- `t` must have exactly as many elements as the slice selects, otherwise an
-- error is raised.
function list:slice_set(args, t)
    local n = #self
    local start, stop, step = normalize_slice(n, unpack(args, 1, 3))

    -- Snapshot t when it is this very list; otherwise the shifting below would
    -- read elements it has already overwritten (Python copies here as well).
    if rawequal(t, self) then
        local snapshot = {}
        for i = 1, n do
            snapshot[i] = self[i]
        end
        t = snapshot
    end

    local n_new = #t

    if step ~= 1 then
        -- Extended slice: element-for-element replacement only.
        local indices = {}
        local j = 1
        for i = start, stop, step do
            indices[j] = i
            j = j + 1
        end
        local n_old = #indices
        if n_old ~= n_new then
            error('attempt to assign sequence of size ' .. n_new .. ' to extended slice of size ' .. n_old)
        end
        for i = 1, n_old do
            self[indices[i]] = t[i]
        end
        return
    end

    -- Contiguous slice. normalize_slice does not cap start from above for a
    -- positive step, so cap it at n + 1 (i.e. append) to avoid writing past
    -- the end and leaving a hole, as Python does for l[len(l) + 5:] = t.
    if start > n + 1 then
        start = n + 1
    end
    local n_old = stop >= start and stop - start + 1 or 0
    -- First index of the untouched tail. For an empty slice (stop < start)
    -- the tail begins at start itself, not at stop + 1.
    local tail = start + n_old
    local d = n_new - n_old

    if d > 0 then
        -- Growing: move the tail right, last element first. rawset is needed
        -- because the new slots do not exist yet (__newindex would reject them).
        for i = n, tail, -1 do
            rawset(self, i + d, self[i])
        end
    elseif d < 0 then
        -- Shrinking: move the tail left, then drop the now-stale end slots.
        for i = tail, n do
            self[i + d] = self[i]
        end
        for i = n, n + d + 1, -1 do
            self[i] = nil
        end
    end

    for i = 1, n_new do
        rawset(self, start + i - 1, t[i])
    end
end

-- Slice deletion: `l:slice_del{a, b, c}` is the counterpart of Python's
-- del l[a:b:c], with 1-based, inclusive bounds (see normalize_slice).
function list:slice_del(args)
    if (args[3] or 1) == 1 then
        -- A contiguous range is just assignment of an empty sequence.
        return self:slice_set(args, {})
    end

    local size = #self
    local first, last, stride = normalize_slice(size, unpack(args, 1, 3))

    local doomed = {}
    for i = first, last, stride do
        doomed[i] = true
    end

    -- Compact the survivors leftwards. No index below this starting point is
    -- deleted, so the compaction can begin there.
    local write = stride > 0 and first or last
    for read = write, size do
        if not doomed[read] then
            self[write] = self[read]
            write = write + 1
        end
    end

    -- Clear the slots left over at the end.
    for i = size, write, -1 do
        self[i] = nil
    end
end

-- Calling a list with a slice spec, `l{a, b, c}`, is shorthand for l:slice{a, b, c}.
function list:__call(args)
    local sliced = self:slice(args)
    return sliced
end

-- `l:del(i, j, ...)` mirrors Python's `del l[i], l[j], ...`. Indices are
-- removed one after another, left to right, each resolved against the list as
-- it stands after the previous removal. Negative indices count from the end.
function list:del(...)
    local targets = {...}
    for k = 1, #targets do
        local pos = targets[k]
        if pos < 0 then
            pos = #self + pos + 1
        end
        table_remove(self, pos)
    end
end

-- Build a list of numbers. Unlike Python, the bounds are inclusive:
--   list.range(n)          -> {1, 2, ..., n}
--   list.range(a, b)       -> {a, a + 1, ..., b}
--   list.range(a, b, step) -> {a, a + step, ...} up to b
-- A step of 0 is an error.
function list.range(start, stop, step)
    if step == 0 then
        error('range step cannot be zero')
    end

    local first, last = start, stop
    if not stop and not step then
        first, last = 1, start
    end

    local out, count = {}, 0
    for value = first, last, step or 1 do
        count = count + 1
        out[count] = value
    end
    return setmetatable(out, list)
end

-- Number of elements: the method form of Python's len(l).
function list:len()
    local size = #self
    return size
end

-- Largest element, or nil for an empty list. With `key`, elements are compared
-- by key(element) instead; either way the first maximal element wins.
function list:max(key)
    local best = self[1]
    if best == nil then
        return nil
    end
    local size = #self

    if not key then
        for i = 2, size do
            local candidate = self[i]
            if best < candidate then
                best = candidate
            end
        end
        return best
    end

    local best_key = key(best)
    for i = 2, size do
        local candidate = self[i]
        local candidate_key = key(candidate)
        if best_key < candidate_key then
            best = candidate
            best_key = candidate_key
        end
    end
    return best
end

-- Smallest element, or nil for an empty list. With `key`, elements are compared
-- by key(element) instead; either way the first minimal element wins.
function list:min(key)
    local best = self[1]
    if best == nil then
        return nil
    end
    local size = #self

    if not key then
        for i = 2, size do
            local candidate = self[i]
            if candidate < best then
                best = candidate
            end
        end
        return best
    end

    local best_key = key(best)
    for i = 2, size do
        local candidate = self[i]
        local candidate_key = key(candidate)
        if candidate_key < best_key then
            best = candidate
            best_key = candidate_key
        end
    end
    return best
end

-- Sum of all elements added to `start` (default 0), like Python's sum(l, start).
function list:sum(start)
    local total = start or 0
    local i, size = 1, #self
    while i <= size do
        total = total + self[i]
        i = i + 1
    end
    return total
end

-- Join the elements.
--   l:join(sep), sep a string: concatenate the elements with sep between them,
--     like Python's sep.join(l).
--   l:join(): same as l:join(''), matching table.concat's default.
--   l:join(sep), sep a table: the elements are expected to be sequences; build
--     a new list from them with sep's elements spliced between each pair. So
--     l:join({}) flattens one level, like Python's sum(l, []); `..` is Lua's
--     list-concatenation operator, hence the join-based spelling.
-- Any other separator type raises an error.
function list:join(sep)
    if sep == nil then
        sep = ''
    end
    local Type = type(sep)
    if Type == 'string' then
        return table_concat(self, sep)
    elseif Type == 'table' then
        local l = list(self[1])
        for i = 2, #self do
            l:extend(sep)
            l:extend(self[i])
        end
        return l
    end
    -- Level 2 reports the error at the caller, where the bad argument came from.
    error('separator should be string or table, not ' .. tostring(sep), 2)
end

-- True when every element satisfies the predicate, false otherwise.
-- With `f`, each element is tested by f(value, index, list) (a JS-style
-- callback); without it, each element must itself be truthy.
-- An empty list yields true.
function list:all(f)
    for i = 1, #self do
        local ok
        if f then
            ok = f(self[i], i, self)
        else
            ok = self[i]
        end
        if not ok then
            return false
        end
    end
    return true
end

-- True when at least one element satisfies the predicate.
-- With `f`, elements are tested by f(value, index, list) (a JS-style
-- callback); without it, the element's own truthiness is used.
-- An empty list yields false.
function list:any(f)
    for i = 1, #self do
        local hit
        if f then
            hit = f(self[i], i, self)
        else
            hit = self[i]
        end
        if hit then
            return true
        end
    end
    return false
end

-- Reversed copy; the list itself is left untouched. Like Python's reversed(),
-- but it returns a list instead of an iterator.
function list:reversed()
    local out, size = {}, #self
    for i = 1, size do
        out[i] = self[size - i + 1]
    end
    return setmetatable(out, list)
end

-- Sorted copy (Python's sorted()). Takes the same `args` table as list:sort
-- and leaves the original list untouched.
function list:sorted(args)
    local result = list(self)
    result:sort(args)
    return result
end

-- New list with duplicates removed, keeping the first occurrence of each value
-- in its original order. Duplicates are detected through table keys:
-- - numbers, strings and booleans are deduplicated by value;
-- - tables (including lists) are deduplicated only by identity.
-- NaN cannot be used as a table key and NaN ~= NaN, so every NaN is kept.
function list:set()
    local seen, l = {}, {}
    local j = 1
    for i = 1, #self do
        local v = self[i]
        if v ~= v then
            -- NaN: writing seen[v] would raise "table index is NaN".
            l[j] = v
            j = j + 1
        elseif not seen[v] then
            seen[v] = true
            l[j] = v
            j = j + 1
        end
    end
    return setmetatable(l, list)
end

-- New list of f(value) for each element (Python's map()).
-- NOTE: an f that returns nil leaves a hole in the result.
function list:map(f)
    local out = {}
    local size = #self
    local i = 1
    while i <= size do
        out[i] = f(self[i])
        i = i + 1
    end
    return setmetatable(out, list)
end

-- New list of the elements for which f(value) is truthy (Python's filter()).
function list:filter(f)
    local kept, count = {}, 0
    for i = 1, #self do
        local value = self[i]
        if f(value) then
            count = count + 1
            kept[count] = value
        end
    end
    return setmetatable(kept, list)
end

-- Iterate several lists in lockstep, like Python's zip():
--   for a, b in list.zip(l1, l2) do ... end   or   for a, b in l1:zip(l2) do ... end
-- Each step yields the i-th element of every list. Iteration stops as soon
-- as ANY list is exhausted, so the shortest list decides the length,
-- whichever argument position it is in. Exhaustion is checked on every step.
function list.zip(...)
    local lists = {...}
    local m = #lists
    local i = 0
    return function()
        i = i + 1
        local t = {}
        for j = 1, m do
            local v = lists[j][i]
            if v == nil then
                -- A nil first value ends a generic for (and list() construction).
                return nil
            end
            t[j] = v
        end
        return unpack(t, 1, m)
    end
end

-- Single-list iterator: `for v in l:iter() do ... end` visits each element.
function list:iter()
    return list.zip(self)
end

------------------------------------------------------------
-- 3. 扩展方法
-- 为了方便使用而添加的内容, 有些借鉴了其他语言
------------------------------------------------------------

-- Build a list from the keys of an arbitrary table.
-- The order follows pairs() and is therefore unspecified.
function list.fromkeys(t)
    local out = {}
    for k in pairs(t) do
        out[#out + 1] = k
    end
    return setmetatable(out, list)
end

-- Build a list from the values of an arbitrary table.
-- The order follows pairs() and is therefore unspecified.
function list.fromvalues(t)
    local out = {}
    for _, value in pairs(t) do
        out[#out + 1] = value
    end
    return setmetatable(out, list)
end

-- Build a list of {key, value} plain tables from an arbitrary table, in
-- pairs() order. This relies on list()'s multi-value iterator handling.
function list.frompairs(t)
    local iter, state, init = pairs(t)
    return list(iter, state, init)
end

-- Like map, but f also receives the index and the list:
-- f(value, index, list), in the style of JS Array.prototype.map.
function list:mapi(f)
    local out, size = {}, #self
    local i = 1
    while i <= size do
        out[i] = f(self[i], i, self)
        i = i + 1
    end
    return setmetatable(out, list)
end

-- Like filter, but the predicate receives f(value, index, list),
-- in the style of JS Array.prototype.filter.
function list:filteri(f)
    local kept, count = {}, 0
    for i = 1, #self do
        local value = self[i]
        if f(value, i, self) then
            count = count + 1
            kept[count] = value
        end
    end
    return setmetatable(kept, list)
end

-- Left fold (Python functools.reduce / JS Array.prototype.reduce).
-- f(acc, value, index, list) returns the new accumulator.
-- When `init` is given (any non-nil value, including false), it seeds the
-- accumulator and folding starts at index 1. Otherwise the first element is
-- the seed and folding starts at index 2.
-- An empty list without `init` yields nil.
function list:reduce(f, init)
    local acc, start = init, 1
    -- Explicit nil test: `init or self[1]` would discard a false seed.
    if init == nil then
        acc, start = self[1], 2
    end
    for i = start, #self do
        acc = f(acc, self[i], i, self)
    end
    return acc
end

-- Index of the first element for which f(value, index, list) is truthy, or
-- nil when nothing matches. This is JS Array.prototype.findIndex, but with
-- 1-based indices and nil instead of -1.
function list:find_index(f)
    local size = #self
    local i = 1
    while i <= size do
        if f(self[i], i, self) then
            return i
        end
        i = i + 1
    end
    return nil
end

-- First element for which f(value, index, list) is truthy, or nil
-- (JS Array.prototype.find).
function list:find(f)
    local pos = self:find_index(f)
    if pos == nil then
        return nil
    end
    return self[pos]
end

-- Export the module: the callable `list` constructor, which also carries every method.
return list