--- List comprehensions implemented in Lua.
--
-- See http://lua-users.org/wiki/ListComprehensions
--
-- Example:
--
--    local comp = require 'pl.comprehension' . new()
--    assert(comp 'sum(x^2 for x)' {2,3,4} == 2^2+3^2+4^2)
--
-- (c) 2008 David Manura. Licensed under the same terms as Lua (MIT license).
--
-- @class module
-- @name pl.comprehension

local assert = assert
local loadstring = loadstring
local tonumber = tonumber
local math_max = math.max
local table_concat = table.concat
local getfenv = getfenv
local setfenv = setfenv
local ipairs = ipairs
local setmetatable = setmetatable
local _G = _G

local lb = require "pl.luabalanced"
local utils = require 'pl.utils'

-- Build the accumulator statement template shared by the min/max folds.
-- `cmp` is the Lua comparison operator ('<' or '>') deciding whether the
-- new value `__tmp` replaces the running `__result`; since both folds start
-- from nil, the first value seen always becomes the result.
local function extremum_accum(cmp)
  return ' local __tmp = %s ' ..
    ' if __result then if __tmp ' .. cmp .. ' __result then ' ..
    '__result = __tmp end else __result = __tmp end '
end

-- Fold operations, keyed by the operation name that may wrap a whole
-- comprehension (e.g. 'sum(x for x)'); a bare comprehension uses 'list'.
--   init  : Lua expression giving the initial value of `__result`.
--   accum : Lua statement template; its single `%s` is replaced by the
--           comprehension's output expression list.
-- http://en.wikipedia.org/wiki/Fold_(higher-order_function)
local ops = {
  list  = {init=' {} ', accum=' __result[#__result+1] = (%s) '},
  table = {init=' {} ', accum=' local __k, __v = %s __result[__k] = __v '},
  sum   = {init=' 0 ', accum=' __result = __result + (%s) '},
  min   = {init=' nil ', accum=extremum_accum('<')},
  max   = {init=' nil ', accum=extremum_accum('>')},
}
--- Parse a comprehension string.
--
-- A comprehension is equivalent to mathematical set-builder notation:
--
--    { <out> | <var> in <list>, <pred> }
--
-- Examples:
--
--    "x^2 for x"                       -- array values
--    "x^2 for x=1,10,2"                -- numeric for
--    "k^v for k,v in pairs(_1)"        -- iterator for
--    "(x+y)^2 for x for y if x > y"    -- nested
--
-- The whole comprehension may be wrapped in an operation name, as in
-- "sum(x^2 for x)"; without one the operation is "list".
--
-- Raises an error (string message) on any syntax error.
--
-- @param expr comprehension string
-- @return output expression list string (expressions joined by ', ')
-- @return array of for types, one per "for" clause: '=', 'in', or false
--         (false means "iterate a plain array")
-- @return array of input variable name arrays, one per "for" clause
-- @return array of input value expression arrays (false where the clause
--         has no '=' / 'in' part)
-- @return array of predicate expression strings
-- @return operation name string
-- @return number of placeholder parameters (highest N among names `_N`)
local function parse_comprehension(expr)
  local pos = 1

  -- Extract the operation name, but only when the entire expression has
  -- the form `name( ... )` with nothing but whitespace after the bracket.
  local opname
  local tok, post = expr:match('^%s*([%a_][%w_]*)%s*%(()', pos)
  local pose = #expr + 1  -- position at which parsing must finish
  if tok then
    local tok2, posb = lb.match_bracketed(expr, post-1)
    assert(tok2, 'syntax error')
    if expr:match('^%s*$', posb) then
      opname = tok
      pose = posb - 1  -- parsing must stop at the closing bracket
      pos = post
    end
  end
  opname = opname or "list"

  -- extract output expression list
  local out; out, pos = lb.match_explist(expr, pos)
  assert(out, "syntax error: missing expression list")
  out = table_concat(out, ', ')

  -- extract "for" clauses
  local fortypes = {}
  local invarlists = {}
  local invallists = {}
  while true do
    local post = expr:match('^%s*for%s+()', pos)
    if not post then break end
    pos = post

    -- extract input vars
    local iv; iv, pos = lb.match_namelist(expr, pos)
    assert(#iv > 0, 'syntax error: zero variables')
    for _,ident in ipairs(iv) do
      -- names starting with __ are reserved for the generated code
      assert(not ident:match'^__', "identifier " .. ident .. " may not contain __ prefix")
    end
    invarlists[#invarlists+1] = iv

    -- extract '=' or 'in' (optional)
    local fortype, post = expr:match('^(=)%s*()', pos)
    if not fortype then
      fortype, post = expr:match('^(in)%s+()', pos)
    end
    if fortype then
      pos = post
      -- extract input value range
      local il; il, pos = lb.match_explist(expr, pos)
      assert(#il > 0, 'syntax error: zero expressions')
      assert(fortype ~= '=' or #il == 2 or #il == 3,
        'syntax error: numeric for requires 2 or three expressions')
      fortypes[#invarlists] = fortype
      invallists[#invarlists] = il
    else
      fortypes[#invarlists] = false
      invallists[#invarlists] = false
    end
  end
  assert(#invarlists > 0, 'syntax error: missing "for" clause')

  -- extract "if" clauses
  local preds = {}
  while true do
    local post = expr:match('^%s*if%s+()', pos)
    if not post then break end
    pos = post
    local pred; pred, pos = lb.match_expression(expr, pos)
    assert(pred, 'syntax error: predicated expression not found')
    preds[#preds+1] = pred
  end

  -- Count placeholder parameters (names matching "_%d+"). Only the segments
  -- lb.gsub tags 'e' are kept, which strips comments/strings; the pieces
  -- are collected in a buffer to avoid quadratic string concatenation.
  local parts = {}
  lb.gsub(expr, function(u, sin)
    if u == 'e' then parts[#parts+1] = ' ' .. sin .. ' ' end
  end)
  local stmp = table_concat(parts)
  local max_param = 0
  for name in stmp:gmatch('[%a_][%w_]*') do
    local n = name:match('^_(%d+)$')
    if n then max_param = math_max(max_param, tonumber(n)) end
  end

  -- Everything up to `pose` must have been consumed.
  if pos ~= pose then
    -- level 0: raise the message unchanged, exactly as assert(false, msg) did
    error("syntax error: unrecognized " .. expr:sub(pos), 0)
  end

  return out, fortypes, invarlists, invallists, preds, opname, max_param
end
--- Create a Lua code string representing a comprehension.
-- Arguments are in the form returned by parse_comprehension.
--
-- The generated chunk declares `local __result`, initializes it from the
-- operation's `init`, and nests one loop per "for" clause (clause 1 is the
-- outermost). The "if" predicates wrap the accumulator inside the innermost
-- loop. A clause without '=' / 'in' iterates the array `__in<i>` (i = clause
-- number) by index and binds only its first variable. `max_param` is not
-- used here.
--
-- @return Lua source string (without the trailing `return __result`)
local function code_comprehension(
  out, fortypes, invarlists, invallists, preds, opname, max_param
)
  local op = assert(ops[opname], 'unknown comprehension operation: ' .. opname)

  -- A replacement *function* inserts `out` literally. A replacement string
  -- would treat '%' in the expressions (e.g. 'x % 2') as a capture escape.
  local code = op.accum:gsub('%%s', function() return out end)

  -- Wrap predicates around the accumulator; preds[1] ends up outermost.
  for i = #preds, 1, -1 do
    code = ' if ' .. preds[i] .. ' then ' .. code .. ' end '
  end

  -- Wrap loops from the innermost clause outward.
  for i = #invarlists, 1, -1 do
    if not fortypes[i] then
      local arrayname = '__in' .. i
      local idx = '__idx' .. i
      code = ' for ' .. idx .. ' = 1, #' .. arrayname .. ' do ' ..
             ' local ' .. invarlists[i][1] .. ' = ' .. arrayname .. '['..idx..'] ' ..
             code .. ' end '
    else
      code = ' for ' .. table_concat(invarlists[i], ', ') .. ' ' .. fortypes[i] .. ' ' ..
             table_concat(invallists[i], ', ') .. ' do ' .. code .. ' end '
    end
  end

  code = ' local __result = ( ' .. op.init .. ' ) ' .. code
  return code
end

--- Convert a code string produced by code_comprehension into a Lua function.
-- The function's parameters are `_1`..`_N` (N = max_param), followed by one
-- array `__in<i>` for each "for" clause that has no '=' / 'in' part.
-- @param code generated code string
-- @param ninputs number of "for" clauses (#invarlists); must be > 0
-- @param max_param number of placeholder parameters
-- @param invallists input value lists from parse_comprehension
-- @param env environment the compiled function is given via setfenv
-- @return compiled function returning the accumulated `__result`
local function wrap_comprehension(code, ninputs, max_param, invallists, env)
  assert(ninputs > 0)
  local ts = {}
  for i = 1, max_param do
    ts[#ts+1] = '_' .. i
  end
  for i = 1, ninputs do
    if not invallists[i] then
      ts[#ts+1] = '__in' .. i
    end
  end
  if #ts > 0 then
    code = ' local ' .. table_concat(ts, ', ') .. ' = ... ' .. code
  end
  code = code .. ' return __result '

  local f, err = loadstring(code)
  if not f then
    -- level 0: raise the message unchanged, exactly as assert(false, msg) did
    error(err .. ' with generated code ' .. code, 0)
  end
  setfenv(f, env)
  return f
end

--- Build a Lua function from a comprehension string.
-- @param expr comprehension string
-- @param env environment the resulting function runs in
-- @return compiled comprehension function
local function build_comprehension(expr, env)
  local out, fortypes, invarlists, invallists, preds, opname, max_param =
    parse_comprehension(expr)
  local code = code_comprehension(
    out, fortypes, invarlists, invallists, preds, opname, max_param)
  local f = wrap_comprehension(code, #invarlists, max_param, invallists, env)
  return f
end
--- Create a new comprehension cache.
--
-- Indexing or calling the returned cache with a comprehension string
-- builds the corresponding Lua function on first use, stores it in the
-- cache under that string, and returns it:
--
--    local comp = new()
--    local f = comp['x^2 for x']   -- or equivalently: comp 'x^2 for x'
--
-- Every function the cache builds is given the environment `env`, which
-- defaults to the environment of the function that called `new`.
--
-- Each call returns a separate cache on purpose. A single global cache
-- would have security implications, because code in one environment could
-- retrieve functions compiled for another. Looking up a per-environment
-- cache from the caller's environment on every access would avoid the need
-- to manage caches explicitly, but could carry an undue performance cost.
--
-- @param env optional environment for the built functions
-- @return the cache table (it also exposes `new`)
local function new(env)
  env = env or getfenv(2)

  -- Shared by __index and __call: compile, memoize, and return.
  local function lookup(cache, expr)
    local fn = build_comprehension(expr, env)
    cache[expr] = fn
    return fn
  end

  local cache = setmetatable({}, { __index = lookup, __call = lookup })
  cache.new = new
  return cache
end

local comprehension = { new = new }

-- A default instance. It is registered with pl.utils as the function
-- factory for strings, presumably so that string arguments accepted by
-- pl.utils can be comprehensions (see utils.add_function_factory).
local C = new()
utils.add_function_factory(getmetatable(""), function(s) return C(s) end)

return comprehension