raze/polymer/eduke32/source/lunatic/bitar.lua
2012-08-10 19:11:43 +00:00

276 lines
7.7 KiB
Lua

-- "Bit array" module based on LuaJIT's BitOp.
local bit = require "bit"
local math = require "math"
local ffi = require "ffi"
local assert = assert
local error = error
local type = type
local tostring = tostring
module(...)
ffi.cdef[[
struct bitar { const double maxbidx, maxidx; const intptr_t arptr; }
]]
local bitar_ct = ffi.typeof("struct bitar")
local ptr_to_int = ffi.typeof("int32_t *")
local anchor = {}
-- population count of a nibble
local nibpop = ffi.new("double [?]", 16,
{ 0, 1, 1, 2, 1, 2, 2, 3,
1, 2, 2, 3, 2, 3, 3, 4 })
-- ...and of a byte
local bytepop = ffi.new("double [?]", 256)
for i=0,255 do
bytepop[i] = nibpop[bit.band(i, 15)] + nibpop[bit.rshift(i, 4)]
end
nibpop = nil
local function bitar_from_intar(maxbidx, maxidx, ar)
-- We need to have the int32_t[?] array be reachable so that it will not be
-- garbage collected
local ar_intptr = ffi.cast("intptr_t", ar)
anchor[tostring(ar_intptr)] = ar
-- Leaving the (potential) high trailing bits at 0 lets us not worry
-- about them in the population count calculation (__len metamethod).
-- Also, this is correct for maxbidx%32 == 0, since BitOp's shifts
-- mask the 5 lower bits of the counts.
local numremain = bit.band(maxbidx+1, 31)
ar[maxidx] = bit.band(ar[maxidx], bit.rshift(-1, 32-numremain))
return bitar_ct(maxbidx, maxidx, ar_intptr)
end
local function setop_common_rel(s1, s2)
if (s1.maxbidx ~= s2.maxbidx) then
error("bad arguments to bit array set op: must be of same length", 4)
end
local ar1 = ffi.cast(ptr_to_int, s1.arptr)
local ar2 = ffi.cast(ptr_to_int, s2.arptr)
return ar1, ar2
end
local function setop_common(s1, s2)
if (not ffi.istype(bitar_ct, s1) or not ffi.istype(bitar_ct, s2)) then
error("bad arguments to bit array set op: both must be 'bitar' types", 3)
end
local ar1, ar2 = setop_common_rel(s1, s2)
local ar = ffi.new("int32_t [?]", s1.maxidx+1)
return ar, ar1, ar2
end
local mt = {
--- Operational methods
__add = function(s1, s2) -- set union
local ar, ar1, ar2 = setop_common(s1, s2)
for i=0,s1.maxidx do
ar[i] = bit.bor(ar1[i], ar2[i])
end
return bitar_from_intar(s1.maxbidx, s1.maxidx, ar)
end,
__mul = function(s1, s2) -- set intersection
local ar, ar1, ar2 = setop_common(s1, s2)
for i=0,s1.maxidx do
ar[i] = bit.band(ar1[i], ar2[i])
end
return bitar_from_intar(s1.maxbidx, s1.maxidx, ar)
end,
__sub = function(s1, s2) -- set difference
local ar, ar1, ar2 = setop_common(s1, s2)
for i=0,s1.maxidx do
ar[i] = bit.band(ar1[i], bit.bnot(ar2[i]))
end
return bitar_from_intar(s1.maxbidx, s1.maxidx, ar)
end,
__unm = function(s) -- bitwise NOT
local newar = ffi.new("int32_t [?]", s.maxidx+1)
local oldar = ffi.cast(ptr_to_int, s.arptr)
for i=0,s.maxidx do
newar[i] = bit.bnot(oldar[i])
end
return bitar_from_intar(s.maxbidx, s.maxidx, newar)
end,
--- Additional operations
__index = {
-- Is bit i set?
isset = function(s, i)
if (not (i >= 0 and i<=s.maxbidx)) then
error("bad bit index for isset: must be in [0.."..s.maxbidx.."]", 2)
end
s = ffi.cast(ptr_to_int, s.arptr)
return (bit.band(s[bit.rshift(i, 5)], bit.lshift(1, i)) ~= 0)
end,
-- Clear bit i.
set0 = function(s, i)
if (not (i >= 0 and i<=s.maxbidx)) then
error("bad bit index for set0: must be in [0.."..s.maxbidx.."]", 2)
end
s = ffi.cast(ptr_to_int, s.arptr)
local jx = bit.rshift(i, 5)
s[jx] = bit.band(s[jx], bit.rol(0xfffffffe, i))
end,
-- Set bit i.
set1 = function(s, i)
if (not (i >= 0 and i<=s.maxbidx)) then
error("bad bit index for set1: must be in [0.."..s.maxbidx.."]", 2)
end
s = ffi.cast(ptr_to_int, s.arptr)
local jx = bit.rshift(i, 5)
s[jx] = bit.bor(s[jx], bit.rol(0x00000001, i))
end
},
--- Relational methods
__eq = function(s1, s2) -- set identity
local ar1, ar2 = setop_common_rel(s1, s2)
for i=0,s1.maxidx do
if (bit.bxor(ar1[i], ar2[i]) ~= 0) then
return false
end
end
return true
end,
__le = function(s1, s2)
local ar1, ar2 = setop_common_rel(s1, s2)
for i=0,s1.maxidx do
if (bit.band(ar1[i], bit.bnot(ar2[i])) ~= 0) then
return false
end
end
return true
end,
__lt = function(s1, s2)
return s1 <= s2 and not (s2 == s1)
end,
-- The length operator gets the population count of the bit array, i.e. the
-- number of set bits.
__len = function(s)
local ar = ffi.cast(ptr_to_int, s.arptr)
local popcnt = 0
for i=0,s.maxidx do
popcnt = popcnt + bytepop[bit.band(ar[i], 255)] +
bytepop[bit.band(bit.rshift(ar[i], 8), 255)] +
bytepop[bit.band(bit.rshift(ar[i], 16), 255)] +
bytepop[bit.rshift(ar[i], 24)]
end
return popcnt
end,
-- serialization
__tostring = function(s)
local size=s.maxidx+1
local ar = ffi.cast(ptr_to_int, s.arptr)
local hdr = "bitar.new("..s.maxbidx..", '"
local ofs = #hdr
local totalstrlen = ofs+8*size+2
local str = ffi.new("char [?]", totalstrlen)
ffi.copy(str, hdr, ofs)
for i=0,s.maxidx do
-- 'a' is ASCII 97
for nib=0,7 do
str[ofs + 8*i + nib] = 97 + bit.band(bit.rshift(ar[i], 4*nib), 0x0000000f)
end
end
ffi.copy(str+totalstrlen-2, "')", 2)
return ffi.string(str, totalstrlen)
end,
-- On garbage collection of the bitar, clear the array's anchor so that it
-- can be collected too.
__gc = function(s)
anchor[tostring(s.arptr)] = nil
end,
}
local bitar = ffi.metatype("struct bitar", mt)
-- Create new bit array.
function new(maxbidx, initval)
if (type(maxbidx) ~= "number" or not (maxbidx >= 0 and maxbidx <= (2^31)-2)) then
error("bad argument #1 to bitar.new (must be a number in [0..(2^31)-2])", 2)
end
if (math.floor(maxbidx) ~= maxbidx) then
error("bad argument #1 to bitar.new (must be an integral number)")
end
if (type(initval)=="string") then
-- string containing hex digits (a..p) given, for INTERNAL use
local lstr = initval
local numnibs = #lstr
assert(numnibs%8 == 0)
local size = numnibs/8
local maxidx = size-1
local ar = ffi.new("int32_t [?]", size)
local str = ffi.new("char [?]", numnibs)
ffi.copy(str, lstr, numnibs)
for i=0,maxidx do
ar[i] = 0
for nib=0,7 do
local hexdig = str[8*i + nib]
assert(hexdig >= 97 and hexdig < 97+16)
ar[i] = bit.bor(ar[i], bit.lshift(hexdig-97, 4*nib))
end
end
return bitar_from_intar(maxbidx, maxidx, ar)
else
-- User-requested bitar creation.
if (initval ~= 0 and initval ~= 1) then
error("bad argument #2 to bitar.new (must be either 0 or 1)", 2)
end
local maxidx = math.floor(maxbidx/32)
local size = maxidx+1
local ar = ffi.new("int32_t [?]", size)
if (initval==1) then
ffi.fill(ar, size*4, -1)
end
return bitar_from_intar(maxbidx, maxidx, ar)
end
end