Lunatic: rewrite bitar to use arrays of int32, more convenience ops.

git-svn-id: https://svn.eduke32.com/eduke32@2872 1a8010ca-5511-0410-912e-c29ae57300e0
This commit is contained in:
helixhorned 2012-08-10 19:11:43 +00:00
parent 761c2d1c84
commit 26df580dac
3 changed files with 274 additions and 101 deletions

View file

@ -1,83 +1,203 @@
-- "Bit array" module based on LuaJIT's BitOp.
local ffi = require "ffi"
local bit = require "bit"
local math = require "math"
local ffi = require "ffi"
local error = error
local assert = assert
local error = error
local type = type
local setmetatable=setmetatable
local tostring = tostring
module(...)
-- Is bit i set in bit array ar?
function isset(ar, i)
return bit.band(ar[bit.rshift(i, 5)], bit.lshift(1, i)) ~= 0
ffi.cdef[[
struct bitar { const double maxbidx, maxidx; const intptr_t arptr; }
]]
local bitar_ct = ffi.typeof("struct bitar")
local ptr_to_int = ffi.typeof("int32_t *")
local anchor = {}
-- Lookup table giving the number of set bits of every 4-bit value. It is only
-- needed while the byte-wide table below is being filled.
local nibpop = ffi.new("double [?]", 16,
    { 0, 1, 1, 2,  1, 2, 2, 3,  1, 2, 2, 3,  2, 3, 3, 4 })

-- Lookup table giving the number of set bits of every 8-bit value: the sum of
-- the counts of its low and its high nibble.
local bytepop = ffi.new("double [?]", 256)
for byte=0,255 do
    local lonib = bit.band(byte, 15)
    local hinib = bit.rshift(byte, 4)
    bytepop[byte] = nibpop[lonib] + nibpop[hinib]
end

-- Drop the reference to the nibble table now that 'bytepop' has been filled.
nibpop = nil
-- Wrap the int32_t[?] array 'ar' in a 'struct bitar' object.
--  maxbidx: highest valid bit index
--  maxidx: highest valid index into 'ar'
-- The last array element is masked in place so that bits above 'maxbidx' are
-- cleared.
local function bitar_from_intar(maxbidx, maxidx, ar)
    -- The struct only stores the array's address as an integer, so the array
    -- itself has to stay reachable from Lua or it would be garbage collected.
    -- It is kept in 'anchor', keyed by the string form of its address.
    local addr = ffi.cast("intptr_t", ar)
    anchor[tostring(addr)] = ar

    -- Keeping the (potential) high trailing bits of the last element at 0
    -- means the population count (__len metamethod) does not need to treat
    -- them specially.
    -- When (maxbidx+1)%32 == 0, 'used' is 0 and the shift count is 32. BitOp's
    -- shifts only use the lower 5 bits of the count, so the mask is then all
    -- ones and the whole last element is kept, which is what we want.
    local used = bit.band(maxbidx+1, 31)
    local keepmask = bit.rshift(-1, 32-used)
    ar[maxidx] = bit.band(ar[maxidx], keepmask)

    return bitar_ct(maxbidx, maxidx, addr)
end
-- Clear bit j in bit array ar.
function set0(ar, j)
local jx = bit.rshift(j, 5)
ar[jx] = bit.band(ar[jx], bit.rol(0xfffffffe, j))
end
-- Set bit j in bit array ar.
function set1(ar, j)
local jx = bit.rshift(j, 5)
ar[jx] = bit.bor(ar[jx], bit.rol(0x00000001, j))
end
local ops = { isset=isset, set0=set0, set1=set1 }
local mt
mt = {
__index=ops,
-- set ops disguised as arithmetic ones...
__mul = function(ar1, ar2) -- set intersection
assert(#ar1 == #ar2)
local p = {}
for i=0,#ar1 do
p[i] = bit.band(ar1[i], ar2[i])
local function setop_common_rel(s1, s2)
if (s1.maxbidx ~= s2.maxbidx) then
error("bad arguments to bit array set op: must be of same length", 4)
end
return setmetatable(p, mt)
local ar1 = ffi.cast(ptr_to_int, s1.arptr)
local ar2 = ffi.cast(ptr_to_int, s2.arptr)
return ar1, ar2
end
-- Shared setup for the binary set operations (+, *, -).
-- Raises an error unless both operands are 'bitar' objects, then lets
-- setop_common_rel() do its own checking and obtain the int32_t pointers of
-- both operands.
-- Returns: a newly allocated int32_t[?] array with s1.maxidx+1 elements for
-- the result, followed by the two operand pointers.
local function setop_common(s1, s2)
    local both_bitar = ffi.istype(bitar_ct, s1) and ffi.istype(bitar_ct, s2)
    if (not both_bitar) then
        error("bad arguments to bit array set op: both must be 'bitar' types", 3)
    end

    local src1, src2 = setop_common_rel(s1, s2)
    local dst = ffi.new("int32_t [?]", s1.maxidx+1)

    return dst, src1, src2
end
local mt = {
--- Operational methods
__add = function(s1, s2)  -- set union
    -- 'union' is the freshly allocated result array, 'lhs'/'rhs' point to the
    -- operands' arrays.
    local union, lhs, rhs = setop_common(s1, s2)
    local lastidx = s1.maxidx

    -- A bit is set in the result if it is set in either operand.
    for w=0,lastidx do
        union[w] = bit.bor(lhs[w], rhs[w])
    end

    return bitar_from_intar(s1.maxbidx, lastidx, union)
end,
__add = function(ar1, ar2) -- set union
assert(#ar1 == #ar2)
local p = {}
for i=0,#ar1 do
p[i] = bit.bor(ar1[i], ar2[i])
__mul = function(s1, s2) -- set intersection
local ar, ar1, ar2 = setop_common(s1, s2)
for i=0,s1.maxidx do
ar[i] = bit.band(ar1[i], ar2[i])
end
return setmetatable(p, mt)
return bitar_from_intar(s1.maxbidx, s1.maxidx, ar)
end,
__sub = function(ar1, ar2) -- set difference
assert(#ar1 == #ar2)
local p = {}
for i=0,#ar1 do
p[i] = bit.band(ar1[i], bit.bnot(ar2[i]))
__sub = function(s1, s2) -- set difference
local ar, ar1, ar2 = setop_common(s1, s2)
for i=0,s1.maxidx do
ar[i] = bit.band(ar1[i], bit.bnot(ar2[i]))
end
return setmetatable(p, mt)
return bitar_from_intar(s1.maxbidx, s1.maxidx, ar)
end,
__unm = function(s)  -- bitwise NOT
    local src = ffi.cast(ptr_to_int, s.arptr)
    local dst = ffi.new("int32_t [?]", s.maxidx+1)

    -- Invert every element; the result is then wrapped by
    -- bitar_from_intar() like that of the other operations.
    for w=0,s.maxidx do
        dst[w] = bit.bnot(src[w])
    end

    return bitar_from_intar(s.maxbidx, s.maxidx, dst)
end,
--- Additional operations
__index = {
    -- Test bit i: returns true if it is set, false otherwise.
    -- Raises an error if i is not in [0 .. s.maxbidx].
    isset = function(s, i)
        local inrange = (i >= 0 and i<=s.maxbidx)
        if (not inrange) then
            error("bad bit index for isset: must be in [0.."..s.maxbidx.."]", 2)
        end

        local words = ffi.cast(ptr_to_int, s.arptr)
        -- Bit i lives in element i>>5; bit.lshift only uses the low 5 bits
        -- of its count, so the mask selects bit (i mod 32) of that element.
        local word = words[bit.rshift(i, 5)]
        return (bit.band(word, bit.lshift(1, i)) ~= 0)
    end,

    -- Clear bit i.
    -- Raises an error if i is not in [0 .. s.maxbidx].
    set0 = function(s, i)
        local inrange = (i >= 0 and i<=s.maxbidx)
        if (not inrange) then
            error("bad bit index for set0: must be in [0.."..s.maxbidx.."]", 2)
        end

        local words = ffi.cast(ptr_to_int, s.arptr)
        local widx = bit.rshift(i, 5)
        -- Rotating 0xfffffffe left by i gives a mask with only bit (i mod 32)
        -- clear.
        local clearmask = bit.rol(0xfffffffe, i)
        words[widx] = bit.band(words[widx], clearmask)
    end,

    -- Set bit i.
    -- Raises an error if i is not in [0 .. s.maxbidx].
    set1 = function(s, i)
        local inrange = (i >= 0 and i<=s.maxbidx)
        if (not inrange) then
            error("bad bit index for set1: must be in [0.."..s.maxbidx.."]", 2)
        end

        local words = ffi.cast(ptr_to_int, s.arptr)
        local widx = bit.rshift(i, 5)
        -- Rotating 1 left by i gives a mask with only bit (i mod 32) set.
        local setmask = bit.rol(0x00000001, i)
        words[widx] = bit.bor(words[widx], setmask)
    end
},
--- Relational methods
__eq = function(s1, s2)  -- set identity
    -- LuaJIT calls a cdata's __eq metamethod for mixed operand types as well
    -- (for example 'p == nil'). Such a comparison is simply unequal; without
    -- this guard, reading '.maxbidx' from the non-bitar operand would raise
    -- an unrelated indexing error.
    if (not ffi.istype(bitar_ct, s1) or not ffi.istype(bitar_ct, s2)) then
        return false
    end

    -- Raises an error if the two bit arrays have different lengths.
    local ar1, ar2 = setop_common_rel(s1, s2)

    -- The sets are identical iff no array element differs in any bit.
    for i=0,s1.maxidx do
        if (bit.bxor(ar1[i], ar2[i]) ~= 0) then
            return false
        end
    end

    return true
end,
__le = function(s1, s2)
    -- Subset test: true iff every bit set in s1 is also set in s2.
    local lhs, rhs = setop_common_rel(s1, s2)

    for w=0,s1.maxidx do
        -- Bits that are set in s1 but not in s2.
        local extra = bit.band(lhs[w], bit.bnot(rhs[w]))
        if (extra ~= 0) then
            return false
        end
    end

    return true
end,
__lt = function(s1, s2)
    -- Expressed through the <= and == operators: s1 <= s2 must hold and the
    -- two must not compare equal.
    if (not (s1 <= s2)) then
        return false
    end
    return not (s2 == s1)
end,
-- The length operator yields the population count of the bit array, that is,
-- the number of bits that are set.
__len = function(s)
    local words = ffi.cast(ptr_to_int, s.arptr)
    local total = 0

    for w=0,s.maxidx do
        -- Split the 32-bit element into its four bytes and look up the
        -- number of set bits of each one.
        local word = words[w]
        local b0 = bit.band(word, 255)
        local b1 = bit.band(bit.rshift(word, 8), 255)
        local b2 = bit.band(bit.rshift(word, 16), 255)
        local b3 = bit.rshift(word, 24)

        total = total + bytepop[b0] + bytepop[b1] + bytepop[b2] + bytepop[b3]
    end

    return total
end,
-- serialization
__tostring = function(ar)
local maxidx=#ar
local size=maxidx+1
__tostring = function(s)
local size=s.maxidx+1
local ar = ffi.cast(ptr_to_int, s.arptr)
local hdr = "bitar.new('"
local hdr = "bitar.new("..s.maxbidx..", '"
local ofs = #hdr
local totalstrlen = ofs+8*size+2
local str = ffi.new("char [?]", totalstrlen)
ffi.copy(str, hdr, ofs)
for i=0,maxidx do
for i=0,s.maxidx do
-- 'a' is ASCII 97
for nib=0,7 do
str[ofs + 8*i + nib] = 97 + bit.band(bit.rshift(ar[i], 4*nib), 0x0000000f)
@ -88,51 +208,69 @@ mt = {
return ffi.string(str, totalstrlen)
end,
-- When the bitar object is garbage collected, remove the 'anchor' entry
-- stored under the string form of its array address, so that the underlying
-- array can be collected as well.
__gc = function(s)
    local key = tostring(s.arptr)
    anchor[key] = nil
end,
}
-- Create new bit array.
-- Returns a table p in which entries p[0] through p[floor((maxbidx+31)/32)]
-- are set to an initialization value: 0 if 0 has been passed, -1 if 1
-- has been passed.
-- Storage: 4 bits/bit + O(1)? (per 32 bits: 64 bits key, 64 bits value)
function new(maxbidx, initval)
local p = {}
local bitar = ffi.metatype("struct bitar", mt)
if (type(maxbidx)=="string") then
-- string containing hex digits (a..p) given, internal
local lstr = maxbidx
-- Create new bit array.
function new(maxbidx, initval)
if (type(maxbidx) ~= "number" or not (maxbidx >= 0 and maxbidx <= (2^31)-2)) then
error("bad argument #1 to bitar.new (must be a number in [0..(2^31)-2])", 2)
end
if (math.floor(maxbidx) ~= maxbidx) then
error("bad argument #1 to bitar.new (must be an integral number)")
end
if (type(initval)=="string") then
-- string containing hex digits (a..p) given, for INTERNAL use
local lstr = initval
local numnibs = #lstr
assert(numnibs%8 == 0)
local size = numnibs/8
local maxidx = size-1
local ar = ffi.new("int32_t [?]", size)
local str = ffi.new("char [?]", numnibs)
ffi.copy(str, lstr, numnibs)
for i=0,maxidx do
p[i] = 0
ar[i] = 0
for nib=0,7 do
local hexdig = str[8*i + nib]
assert(hexdig >= 97 and hexdig < 97+16)
p[i] = bit.bor(p[i], bit.lshift(hexdig-97, 4*nib))
ar[i] = bit.bor(ar[i], bit.lshift(hexdig-97, 4*nib))
end
end
return bitar_from_intar(maxbidx, maxidx, ar)
else
if (type(maxbidx) ~= "number" or not (maxbidx >= 0)) then
error("bad argument #1 to newarray (must be a nonnegative number)", 2)
end
-- User-requested bitar creation.
if (initval ~= 0 and initval ~= 1) then
error("bad argument #2 to newarray (must be either 0 or 1)", 2)
error("bad argument #2 to bitar.new (must be either 0 or 1)", 2)
end
for i=0,maxbidx/32 do
p[i] = -initval
end
local maxidx = math.floor(maxbidx/32)
local size = maxidx+1
local ar = ffi.new("int32_t [?]", size)
if (initval==1) then
ffi.fill(ar, size*4, -1)
end
return setmetatable(p, mt)
return bitar_from_intar(maxbidx, maxidx, ar)
end
end

View file

@ -3,8 +3,8 @@
-- Usage: luajit bittest.lua <number or "x"> [-ffi] [-bchk]
local string = require "string"
local math = require "math"
local bit = require("bit")
local bitar = require "bitar"
local print = print
@ -28,6 +28,7 @@ end
-- based on example from http://bitop.luajit.org/api.html
local m = string.dump and tonumber(arg[1]) or 1e7
local maxidx = math.floor(m/32)
local ffiar_p, boundchk_p = false, false
@ -46,24 +47,23 @@ function sieve()
local p = {}
if (ffiar_p) then
-- stand-alone using unchecked int32_t array instead of table:
-- on x86_64 approx. 100 vs. 160 ms for m = 1e7
-- (enabling bound checking makes it be around 170 ms)
-- stand-alone using unchecked int32_t array: on x86_64 approx. 80 ms
-- for m = 1e7 (enabling bound checking makes it be around 100 ms)
local ffi = require "ffi"
local pp = ffi.new("int32_t [?]", (m+31)/32 + 1)
local pp = ffi.new("int32_t [?]", maxidx + 1)
p = pp
if (boundchk_p) then
local mt = {
__index = function(tab,idx)
if (idx >= 0 and idx <= (m+31)/32) then
if (idx >= 0 and idx <= maxidx) then
return pp[idx]
end
end,
__newindex = function(tab,idx,val)
if (idx >= 0 and idx <= (m+31)/32) then
if (idx >= 0 and idx <= maxidx) then
pp[idx] = val
end
end,
@ -72,43 +72,73 @@ function sieve()
p = setmetatable({}, mt)
end
for i=0,(m+31)/32 do p[i] = -1; end
for i=0,maxidx do p[i] = -1; end
else
p = bitar.new(m, 1)
end
local t = getticks()
if (ffiar_p) then
local bit = require "bit"
for i=2,m do
if (bit.band(p[bit.rshift(i, 5)], bit.lshift(1, i)) ~= 0) then
count = count + 1
for j=i+i,m,i do
local jx = bit.rshift(j, 5)
p[jx] = bit.band(p[jx], bit.rol(0xfffffffe, j));
end
end
end
else
for i=2,m do
if (p:isset(i)) then
count = count + 1
for j=i+i,m,i do p:set0(j); end
end
end
end
-- When using bitar module: x86_64: approx. 110 ms
print(string.format("[%s] Found %d primes up to %d (%.02f ms)",
ffiar_p and "ffi-ar"..(boundchk_p and ", bchk" or "") or "tab-ar",
count, m, getticks()-t))
return p
return p, count
end
if (string.dump) then
local p = sieve()
local function printf(fmt, ...) print(string.format(fmt, ...)) end
local p, count = sieve()
local t = getticks()
if (ffiar_p) then
return
end
-- test serialization
local p2 = bitar.new(string.match(tostring(p), "'(.*)'"))
print(getticks()-t)
local ser = tostring(p)
local maxbidx_str = string.match(ser, '%(([0-9]+),')
local p2 = bitar.new(tonumber(maxbidx_str), string.match(ser, "'(.*)'"))
printf("serialization + new: %.02f ms", tostring(getticks()-t))
for i=0,#p do
assert(p[i]==p2[i])
assert(p==p2)
if (m >= 2) then
assert(#p == count+2) -- +2 is because 0 and 1 are set even though they're not primes
end
for i = 3,#p do
p[i] = nil
if (not ffiar_p) then
math.randomseed(os.time())
local maxbidx = math.random(0, 65536)
local p3 = bitar.new(maxbidx, 1)
assert(#p3 == maxbidx+1) -- bits 0 to maxbidx inclusive are set
end
--[[
print(p)
print(p-p) -- test set difference
print(-p)
--]]
assert(p-p == p*(-p))
end

View file

@ -592,12 +592,17 @@ local function check_literal_am(am)
end
end
local actor_ptr_ct = ffi.typeof("actor_u_t *") -- an unrestricted actor_t pointer
local con_action_ct = ffi.typeof("con_action_t")
local con_move_ct = ffi.typeof("con_move_t")
local con_ai_ct = ffi.typeof("con_ai_t")
local actor_mt = {
__index = {
-- action
set_action = function(a, act)
a = ffi.cast("actor_u_t *", a)
if (ffi.istype("con_action_t", act)) then
a = ffi.cast(actor_ptr_ct, a)
if (ffi.istype(con_action_ct, act)) then
a.t_data[4] = act.id
a.ac = act.ac
else
@ -611,8 +616,8 @@ local actor_mt = {
end,
has_action = function(a, act)
a = ffi.cast("actor_u_t *", a)
if (ffi.istype("con_action_t", act)) then
a = ffi.cast(actor_ptr_ct, a)
if (ffi.istype(con_action_ct, act)) then
return (a.t_data[4]==act.id)
else
check_literal_am(act)
@ -622,26 +627,26 @@ local actor_mt = {
-- count
set_count = function(a)
ffi.cast("actor_u_t *", a).t_data[0] = 0
ffi.cast(actor_ptr_ct, a).t_data[0] = 0
end,
get_count = function(a)
return ffi.cast("actor_u_t *", a).t_data[0]
return ffi.cast(actor_ptr_ct, a).t_data[0]
end,
-- action count
reset_acount = function(a)
ffi.cast("actor_u_t *", a).t_data[2] = 0
ffi.cast(actor_ptr_ct, a).t_data[2] = 0
end,
get_acount = function(a)
return ffi.cast("actor_u_t *", a).t_data[2]
return ffi.cast(actor_ptr_ct, a).t_data[2]
end,
-- move
set_move = function(a, mov, movflags)
a = ffi.cast("actor_u_t *", a)
if (ffi.istype("con_move_t", mov)) then
a = ffi.cast(actor_ptr_ct, a)
if (ffi.istype(con_move_ct, mov)) then
a.t_data[1] = mov.id
a.mv = mov.mv
else
@ -651,15 +656,15 @@ local actor_mt = {
end
a.t_data[0] = 0
local i = a-ffi.cast("actor_u_t *", ffiC.actor[0])
local i = a-ffi.cast(actor_ptr_ct, ffiC.actor[0])
ffiC.sprite[i].hitag = movflags or 0
-- TODO: random angle moveflag
end,
has_move = function(a, mov)
a = ffi.cast("actor_u_t *", a)
if (ffi.istype("con_move_t", mov)) then
a = ffi.cast(actor_ptr_ct, a)
if (ffi.istype(con_move_ct, mov)) then
return (a.t_data[1]==mov.id)
else
check_literal_am(mov)
@ -670,10 +675,10 @@ local actor_mt = {
-- ai
set_ai = function(a, ai)
local oa = a
a = ffi.cast("actor_u_t *", a)
a = ffi.cast(actor_ptr_ct, a)
-- TODO: literal number AIs?
assert(ffi.istype("con_ai_t", ai))
assert(ffi.istype(con_ai_ct, ai))
-- NOTE: compare with gameexec.c
a.t_data[5] = ai.id
@ -685,9 +690,9 @@ local actor_mt = {
end,
has_ai = function(a, ai)
a = ffi.cast("actor_u_t *", a)
a = ffi.cast(actor_ptr_ct, a)
if (ffi.istype("con_ai_t", ai)) then
if (ffi.istype(con_ai_ct, ai)) then
return (a.t_data[5]==ai.id)
else
check_literal_am(ai)