-- "Bit array" module based on LuaJIT's BitOp.
--
-- A bit array ("bitar") holds bits 0..maxbidx packed into an int32_t[?] word
-- array. Operators: `a + b` union, `a * b` intersection, `a - b` difference,
-- `-a` complement, `a == b` equality, `a <= b` subset, `a < b` proper
-- subset, `#a` population count, `tostring(a)` serialization.
-- Methods: a:isset(i), a:set0(i), a:set1(i).

local bit = require "bit"
local math = require "math"
local ffi = require "ffi"

-- module(...) replaces the chunk environment with the module table, so every
-- global used below must be captured in a local before that call.
local assert = assert
local error = error
local type = type
local tonumber = tonumber
local tostring = tostring

-- NOTE(review): module() is deprecated after Lua 5.1; kept because callers
-- may rely on the global / package.loaded entry it creates.
module(...)

-- Header with const fields: maxbidx is the highest valid bit index, maxidx
-- the highest valid word index, arptr the word array's address as an integer.
local bitar_ct = ffi.typeof("struct { const double maxbidx, maxidx; const intptr_t arptr; }")
local ptr_to_int = ffi.typeof("int32_t *")

-- Keeps each word array reachable while a bitar refers to it. Keys are the
-- array address converted with tonumber(): the key built from the
-- ffi.cast("intptr_t", ...) value at creation and the one built from the
-- arptr field in __gc are then the same plain Lua number on both 32- and
-- 64-bit targets (tostring() of a boxed 32-bit cdata would yield the box's
-- address instead). Presumes addresses below 2^53 so the double is exact.
local anchor = {}

-- Population count of a nibble...
local nibpop = ffi.new("double [?]", 16,
  { 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4 })
-- ...and of a byte.
local bytepop = ffi.new("double [?]", 256)
for i = 0, 255 do
  bytepop[i] = nibpop[bit.band(i, 15)] + nibpop[bit.rshift(i, 4)]
end
nibpop = nil

-- Wrap the freshly filled word array `ar` (words 0..maxidx) in a new bitar
-- covering bits 0..maxbidx. Bits above maxbidx in the last word are cleared.
local function bitar_from_intar(maxbidx, maxidx, ar)
  -- The header only stores the address, so the int32_t[?] array must stay
  -- reachable from Lua or it would be garbage collected; __gc releases it.
  local ar_intptr = ffi.cast("intptr_t", ar)
  anchor[tonumber(ar_intptr)] = ar
  -- Leaving the (potential) high trailing bits at 0 lets us not worry about
  -- them in the population count (__len) or in comparisons. This is also
  -- correct when (maxbidx+1) % 32 == 0 (last word full): numremain is 0 and,
  -- since BitOp's shifts mask the 5 lower bits of the count, rshift(-1, 32)
  -- is rshift(-1, 0) == -1, so nothing is cleared.
  local numremain = bit.band(maxbidx + 1, 31)
  ar[maxidx] = bit.band(ar[maxidx], bit.rshift(-1, 32 - numremain))
  return bitar_ct(maxbidx, maxidx, ar_intptr)
end

-- Require both bitars to have the same length and return their word arrays.
-- `level` is the error() level that points at the user's code; it depends
-- on how many frames sit between the user and this function.
local function setop_common_rel(s1, s2, level)
  if s1.maxbidx ~= s2.maxbidx then
    error("bad arguments to bit array set op: must be of same length", level)
  end
  return ffi.cast(ptr_to_int, s1.arptr), ffi.cast(ptr_to_int, s2.arptr)
end

-- Validate the operands of a binary set operation (called directly from its
-- metamethod) and allocate the zero-filled result word array.
local function setop_common(s1, s2)
  if not ffi.istype(bitar_ct, s1) or not ffi.istype(bitar_ct, s2) then
    error("bad arguments to bit array set op: both must be 'bitar' types", 3)
  end
  -- Frames: setop_common_rel(1), setop_common(2), metamethod(3), user(4).
  local ar1, ar2 = setop_common_rel(s1, s2, 4)
  local ar = ffi.new("int32_t [?]", s1.maxidx + 1)
  return ar, ar1, ar2
end

-- Validate the operands of an ordering comparison (called directly from
-- __le/__lt). LuaJIT calls cdata metamethods on any mix of operand types, so
-- e.g. `b <= 5` arrives here and needs a meaningful error.
local function check_rel_types(s1, s2)
  if not ffi.istype(bitar_ct, s1) or not ffi.istype(bitar_ct, s2) then
    error("bad arguments to bit array comparison: both must be 'bitar' types", 3)
  end
end

-- Validate bit index `i` of bitar `s` on behalf of method `fname` (called
-- directly from that method, so level 3 points at the method's caller).
local function check_bidx(s, i, fname)
  if not (i >= 0 and i <= s.maxbidx) then
    error("bad bit index for " .. fname .. ": must be in [0.." .. s.maxbidx .. "]", 3)
  end
end

local mt = {
  --- Operational methods

  -- Set union.
  __add = function(s1, s2)
    local ar, ar1, ar2 = setop_common(s1, s2)
    for i = 0, s1.maxidx do
      ar[i] = bit.bor(ar1[i], ar2[i])
    end
    return bitar_from_intar(s1.maxbidx, s1.maxidx, ar)
  end,

  -- Set intersection.
  __mul = function(s1, s2)
    local ar, ar1, ar2 = setop_common(s1, s2)
    for i = 0, s1.maxidx do
      ar[i] = bit.band(ar1[i], ar2[i])
    end
    return bitar_from_intar(s1.maxbidx, s1.maxidx, ar)
  end,

  -- Set difference (bits in s1 but not in s2).
  __sub = function(s1, s2)
    local ar, ar1, ar2 = setop_common(s1, s2)
    for i = 0, s1.maxidx do
      ar[i] = bit.band(ar1[i], bit.bnot(ar2[i]))
    end
    return bitar_from_intar(s1.maxbidx, s1.maxidx, ar)
  end,

  -- Bitwise NOT; bitar_from_intar clears the bits flipped above maxbidx.
  __unm = function(s)
    local newar = ffi.new("int32_t [?]", s.maxidx + 1)
    local oldar = ffi.cast(ptr_to_int, s.arptr)
    for i = 0, s.maxidx do
      newar[i] = bit.bnot(oldar[i])
    end
    return bitar_from_intar(s.maxbidx, s.maxidx, newar)
  end,

  --- Additional operations
  __index = {
    -- Is bit i set?
    isset = function(s, i)
      check_bidx(s, i, "isset")
      local ar = ffi.cast(ptr_to_int, s.arptr)
      -- lshift masks its count to 5 bits, i.e. uses i % 32.
      return bit.band(ar[bit.rshift(i, 5)], bit.lshift(1, i)) ~= 0
    end,

    -- Clear bit i.
    set0 = function(s, i)
      check_bidx(s, i, "set0")
      local ar = ffi.cast(ptr_to_int, s.arptr)
      local jx = bit.rshift(i, 5)
      ar[jx] = bit.band(ar[jx], bit.rol(0xfffffffe, i))
    end,

    -- Set bit i.
    set1 = function(s, i)
      check_bidx(s, i, "set1")
      local ar = ffi.cast(ptr_to_int, s.arptr)
      local jx = bit.rshift(i, 5)
      ar[jx] = bit.bor(ar[jx], bit.rol(0x00000001, i))
    end,
  },

  --- Relational methods

  -- Set identity. LuaJIT calls this on any mix of operand types (e.g.
  -- `b == nil`); a non-bitar operand is simply never equal to a bitar.
  __eq = function(s1, s2)
    if not ffi.istype(bitar_ct, s1) or not ffi.istype(bitar_ct, s2) then
      return false
    end
    -- Frames: setop_common_rel(1), __eq(2), user(3).
    local ar1, ar2 = setop_common_rel(s1, s2, 3)
    for i = 0, s1.maxidx do
      if ar1[i] ~= ar2[i] then
        return false
      end
    end
    return true
  end,

  -- Subset: every bit set in s1 is also set in s2.
  __le = function(s1, s2)
    check_rel_types(s1, s2)
    local ar1, ar2 = setop_common_rel(s1, s2, 3)
    for i = 0, s1.maxidx do
      if bit.band(ar1[i], bit.bnot(ar2[i])) ~= 0 then
        return false
      end
    end
    return true
  end,

  -- Proper subset: s1 <= s2 and s1 ~= s2, checked in a single pass so that
  -- error levels still point at the user's comparison.
  __lt = function(s1, s2)
    check_rel_types(s1, s2)
    local ar1, ar2 = setop_common_rel(s1, s2, 3)
    local differ = false
    for i = 0, s1.maxidx do
      local w1, w2 = ar1[i], ar2[i]
      if bit.band(w1, bit.bnot(w2)) ~= 0 then
        return false
      end
      if w1 ~= w2 then
        differ = true
      end
    end
    return differ
  end,

  -- The length operator gets the population count of the bit array, i.e.
  -- the number of set bits (bits above maxbidx are always 0).
  __len = function(s)
    local ar = ffi.cast(ptr_to_int, s.arptr)
    local popcnt = 0
    for i = 0, s.maxidx do
      local w = ar[i]
      popcnt = popcnt
        + bytepop[bit.band(w, 255)]
        + bytepop[bit.band(bit.rshift(w, 8), 255)]
        + bytepop[bit.band(bit.rshift(w, 16), 255)]
        + bytepop[bit.rshift(w, 24)]
    end
    return popcnt
  end,

  -- Serialization: "bitar.new(<maxbidx>, '<nibbles>')", one letter 'a'..'p'
  -- per nibble, least significant nibble of each word first. new() accepts
  -- this string form as its second argument.
  __tostring = function(s)
    local size = s.maxidx + 1
    local ar = ffi.cast(ptr_to_int, s.arptr)
    local hdr = "bitar.new(" .. s.maxbidx .. ", '"
    local ofs = #hdr
    local totalstrlen = ofs + 8 * size + 2
    local str = ffi.new("char [?]", totalstrlen)
    ffi.copy(str, hdr, ofs)
    for i = 0, s.maxidx do
      for nib = 0, 7 do
        -- 'a' is ASCII 97
        str[ofs + 8 * i + nib] = 97 + bit.band(bit.rshift(ar[i], 4 * nib), 0x0000000f)
      end
    end
    ffi.copy(str + totalstrlen - 2, "')", 2)
    return ffi.string(str, totalstrlen)
  end,

  -- On garbage collection of the bitar, clear the array's anchor so that it
  -- can be collected too (same tonumber() key as in bitar_from_intar).
  __gc = function(s)
    anchor[tonumber(s.arptr)] = nil
  end,
}

ffi.metatype(bitar_ct, mt)

-- Create a new bit array with bits 0..maxbidx.
-- maxbidx: integral number in [0..(2^31)-2].
-- initval: 0 (all clear) or 1 (all set); a string of nibble letters 'a'..'p'
--   (as produced by __tostring) is also accepted, for INTERNAL use.
function new(maxbidx, initval)
  if type(maxbidx) ~= "number" or not (maxbidx >= 0 and maxbidx <= (2^31)-2) then
    error("bad argument #1 to bitar.new (must be a number in [0..(2^31)-2])", 2)
  end
  if math.floor(maxbidx) ~= maxbidx then
    error("bad argument #1 to bitar.new (must be an integral number)", 2)
  end

  local maxidx = math.floor(maxbidx / 32)

  if type(initval) == "string" then
    -- String containing hex digits (a..p) given, for INTERNAL use.
    local numnibs = #initval
    assert(numnibs % 8 == 0)
    -- The encoded word count must match maxbidx, otherwise word accesses
    -- would run past the array (an empty string would even write ar[-1]).
    assert(numnibs / 8 == maxidx + 1)
    local ar = ffi.new("int32_t [?]", maxidx + 1)
    local str = ffi.new("char [?]", numnibs)
    ffi.copy(str, initval, numnibs)
    for i = 0, maxidx do
      local word = 0
      for nib = 0, 7 do
        local hexdig = str[8 * i + nib]
        assert(hexdig >= 97 and hexdig < 97 + 16)
        word = bit.bor(word, bit.lshift(hexdig - 97, 4 * nib))
      end
      ar[i] = word
    end
    return bitar_from_intar(maxbidx, maxidx, ar)
  end

  -- User-requested bitar creation.
  if initval ~= 0 and initval ~= 1 then
    error("bad argument #2 to bitar.new (must be either 0 or 1)", 2)
  end
  local size = maxidx + 1
  local ar = ffi.new("int32_t [?]", size)  -- zero-filled by ffi.new
  if initval == 1 then
    ffi.fill(ar, size * 4, -1)
  end
  return bitar_from_intar(maxbidx, maxidx, ar)
end