--[[
  @author Bruno Massa
  @file database.lua
  Ajato has an entire layer between SQL databases
  and modules. It helps them add simple and complex
  instructions easily.
]]

--[[
  @defgroup database Database abstraction layer
  @{
  Allow the use of different database servers using the same code base.

  Ajato provides a slim database abstraction layer to provide developers with
  the ability to support multiple database servers easily. The intent of this
  layer is to preserve the syntax and power of SQL as much as possible, while
  letting Ajato control the pieces of queries that need to be written
  differently for different servers and provide basic security checks.

  Most Ajato database queries are performed by a call to db_query() or
  db_query_range(). Module authors should also consider using db_pager() for
  queries that return results that need to be presented on multiple pages, and
  tablesort_sql() for generating appropriate queries for sortable tables.

  For example, one might wish to return a list of the 10 most recent nodes
  authored by a given user. Instead of directly issuing the SQL query
  @code
  SELECT n.title, n.body, n.created FROM node n WHERE n.uid = uid LIMIT 0, 10;
  @endcode
  one would instead call the Ajato functions:
  @code
  result = db_query_range('SELECT n.title, n.body, n.created' ..
    ' FROM {node} n WHERE n.uid = %d', 0, 10, uid)
  for node in db_rows(result) do
    -- Perform operations on node['body'], etc. here.
  end
  @endcode
  Curly braces are used around "node" to provide table prefixing. The explicit
  use of a user ID is pulled out into an argument passed to db_query() so
  that SQL injection attacks from user input can be caught and neutralized.
  The LIMIT syntax varies between database servers, so that is abstracted
  into db_query_range() arguments. Finally, note the common pattern of
  iterating over the result set.
]]

--[[
  Connect to a SQL server.
Ajato needs a SQL connection to a - SQL server (MySQL, PostgreSQL or any other supported - by Kepler/LuaSQL) - - @param db_type - String, the SQL server (mysql, postgres, odbc, oci8) - @param db_name - Any type, a series of argumments that will be substituted - @param db_username - Number, how much items per page. Default is 10 - @param db_password - String, the SQL instruction destinated to count the total items. Default is nil - @param db_host - String, the URL where the database server is located. Default is nil - @param db_port - Number, the port where the database server can be accessed. Default is nil ]] function db_init() -- Check if required variables are set if not sys['db_server'] or not sys['db_name'] then return false end -- Chekc if the database is already running if sys['db_conn'] then return false end -- Load the LuaSQL module require('luasql.'.. sys['db_server']) local server, err = luasql.mysql() if not server then ajato_add_message(t('This database is not supported or it is not installed.'), 'error') return false end -- Start the database connection sys['db_conn'], err = server:connect(sys['db_name'], sys['db_user'], sys['db_password'], sys['db_host'], tonumber(sys['db_port'])) if not sys['db_conn'] then ajato_add_message(t('No connection to the database.'), 'error') return false end -- Everything is ok return true end --[[ - Returns the last insert id. - - On each DBMS, the implementation of this feature is - different. So, we need to implement a customized command - to each of them. - - @param db_table - String, The name of the table you inserted into. - @param field - String, The name of the autoincrement field. - @return - Number, the last ID created - @note ensure all Big 6 DBMS are covered ]] function db_last_insert_id(db_table, field) if sys['db_server'] == 'mysql' then return db_result(db_query('SELECT LAST_INSERT_ID()')) elseif sys['db_server'] == 'postgres' then return db_result(db_query("SELECT currval('%s_seq')", db_prefix_tables('{'.. 
db_table ..'}') ..'_'.. field)) end end --[[ - Execute a SQL instruction as pagination - - Sometime users want to create a 'pager effect': limit the results - and show a '< previous 1 2 3 next >' to navigate. - - @param sql - String, the SQL instruction - @param variables - Any type, a series of argumments that will be substituted - @param limit - Number, how much items per page. Default is 10 - @param skip - Number, how many items will be skiped - @param sql_count - String, the SQL instruction destinated to count the total items. Default is nil - @param name - String, a indentifier - @return ]] function db_pager(sql, limit, skip, sql_count, name, ...) limit = limit or 1 skip = skip or 0 return db_query(sql ..' LIMIT '.. limit ..' OFFSET '.. skip, ...) end --[[ - Execute a SQL instruction. To avoid SQL injection, Ajato - deal with string substitution on SQL instructions. The - resulting SQL instruction is then executed. - - @param sql - String, the SQL instruction - @param ... - Table, a series of argumments that will be substituted ]] function db_query(sql, ...) if type(sql) ~= 'string' then return nil end -- Ajato uses prefixes to coordinate different installations -- into the same database. Marking all table names like {table_name} -- will ensure the correct prefix will be added. sql = string.gsub(sql, '{(.-)}', (sys['db_prefix'] or '') ..'%1') -- Insert all variables into the SQL instruction. -- It avoids SQL injection problems if ... 
then local arguments = {} for _, value in pairs({...}) do if type(value) == 'table' then for _, value2 in pairs(value) do if type(value2) == 'string' then arguments[#arguments + 1] = string.gsub(value2, '%%', '%%%%') else arguments[#arguments + 1] = value2 end end elseif type(value) == 'string' then arguments[#arguments + 1] = string.gsub(value, '%%', '%%%%') else arguments[#arguments + 1] = value end end sql = string.format(sql, unpack(arguments)) end -- Execute the SQL instruction local sql_result, err = sys.db_conn:execute(sql, unpack({...})) -- Return the error, if any if err then ajato_add_message(err .. string.format(': %q', sql), 'error') end return sql_result end --[[ - Execute a SQL instruction, but only a limited number of results - are wanted. - - @param sql - String, the SQL instruction - @param rows_from - String, starting from this row - @param rows_qty - Number, how many results are needed - @param ... - Table, a series of argumments that will be substituted ]] function db_query_range(sql, rows_from, rows_qty, ...) return db_query(sql ..' LIMIT '.. rows_from ..','.. rows_qty, ...) end --[[ - Rewrites node, taxonomy and comment queries. Use it for listing queries. Do not - use FROM table1, table2 syntax, use JOIN instead. - - @param query - String, Query to be rewritten. - @param primary_table - String, Name or alias of the table which has the primary key field for this query. - Possible values are: comments, forum, node, menu, term_data, vocabulary. - @param primary_field - String, Name of the primary field. - @param args - Table, Further arguments, passed to the implementations of hook_db_rewrite_sql. - @return - String, The original query with JOIN and WHERE statements inserted from hook_db_rewrite_sql - implementations. nid is rewritten if needed. 
]]
function db_rewrite_sql(sql, primary_table, primary_field, args)
  -- hook_db_rewrite_sql is not consulted yet: the defaults are prepared
  -- for it, but the query is returned unchanged.
  primary_table = primary_table or 'n'
  primary_field = primary_field or 'nid'
  args = args or {}
  return sql
end

--[[
  Get a single value from the query

  @param sql_object
    Userdata, the cursor returned by db_query, db_query_range or db_pager
  @return
    Number or String, the column values of the next row as separate return
    values (a single-column query yields just its value); nothing when
    sql_object is not a cursor.
]]
function db_result(sql_object)
  if type(sql_object) == 'userdata' then
    return sql_object:fetch()
  end
end

--[[
  Get the next row of the results, according to the SQL instruction.
  They should provide a valid SQL resulting cursor, from db_query,
  db_query_range or db_pager.

  @param sql_object
    Userdata, the cursor with the results
  @return
    Table, the next row keyed by column name (fetched in 'a' mode), or nil
    when there are no more rows; nothing when sql_object is not a cursor.
]]
function db_results(sql_object)
  if type(sql_object) == 'userdata' then
    return sql_object:fetch({}, 'a')
  end
end

--[[
  When the SQL instruction has many results, this function
  will help developers to walk them one by one. They should
  provide a valid SQL resulting cursor, from db_query, db_query_range
  or db_pager.

  @param sql_object
    Userdata, the cursor with the results
  @return
    Function, an iterator for generic for: each call returns the next row
    as a fresh table keyed by column name, nil at the end.
]]
function db_rows(sql_object)
  if type(sql_object) == 'userdata' then
    return function ()
      return sql_object:fetch({}, 'a')
    end
  end
end

--[[
  Generate a SQL instruction to create a new table
  from a SQL API table. Developers don't need to know
  SQL to deal with the database. It prevents some errors
  too.

  @param name
    String, the name of the table to create.
  @param sql_table
    Table, original SQL API table
  @return
    Table, SQL statements to create the table: the CREATE TABLE statement
    first, followed by one CREATE INDEX per index.
]]
function db_sqlapi_create_table(name, sql_table)
  local sql_fields = {}
  local sql = {}

  -- Add the SQL statement for each field.
  for field_name, field in pairs(sql_table['fields']) do
    sql_fields[#sql_fields + 1] = field_name ..' '.. db_sqlapi_create_table_field(field)
  end

  -- Process keys & indexes.
  if sql_table['primary key'] then
    sql_fields[#sql_fields + 1] = 'PRIMARY KEY ('.. table.concat(sql_table['primary key'], ',') ..')'
  end

  -- Add unique keys
  if sql_table['unique keys'] then
    for key_name, key in pairs(sql_table['unique keys']) do
      sql_fields[#sql_fields + 1] = 'CONSTRAINT {'.. name ..'}_'.. key_name ..'_key UNIQUE ('.. table.concat(key, ',') ..')'
    end
  end

  sql[1] = 'CREATE TABLE {'.. name ..'} ('.. table.concat(sql_fields, ',') .. ')'
  if sys['db_server'] == 'mysql' then
    sql[1] = sql[1] .. ' /*!40100 DEFAULT CHARACTER SET UTF8 */ '
  end

  -- Add indexes
  if sql_table['indexes'] then
    for key_name, key in pairs(sql_table['indexes']) do
      sql[#sql + 1] = 'CREATE INDEX {'.. name ..'}_'.. key_name .. '_idx ON {'.. name ..'} ('.. db_sqlapi_create_table_key(key) ..')'
    end
  end

  return sql
end

--[[
  Generate a single line with the field definition.

  @param field
    Table, the field data: 'type' plus the optional 'length', 'precision'
    and 'scale', 'unsigned', 'not null' and 'default' entries. The table
    is not modified.
  @return
    String, the field definition in SQL instruction
]]
function db_sqlapi_create_table_field(field)
  -- On MySQL a serial field becomes an unsigned auto-incremented INT;
  -- other servers receive the 'serial' type name as is.
  local mysql_serial = field['type'] == 'serial' and sys['db_server'] == 'mysql'
  local sql = field['type']
  if mysql_serial then
    sql = ' INT'
  end

  -- Limit the data length. SQL declares NUMERIC/DECIMAL as
  -- (precision, scale), in that order.
  if field['length'] then
    sql = sql .. '('.. field['length'] ..')'
  elseif field['precision'] and field['scale'] then
    sql = sql .. '('.. field['precision'] ..', '.. field['scale'] ..')'
  end

  -- Add the Unsigned (only positive numbers) flag. It goes after the
  -- length so a MySQL serial with a length stays valid.
  if mysql_serial then
    sql = sql .. ' UNSIGNED AUTO_INCREMENT'
  elseif field['unsigned'] and field['type'] ~= 'serial' then
    sql = sql .. ' UNSIGNED'
  end

  -- Add the NOT NULL flag
  if field['not null'] then
    sql = sql .. ' NOT NULL'
  end

  -- Add the DEFAULT value. Strings become single-quoted SQL literals,
  -- because double quotes denote identifiers in standard SQL. The value
  -- is kept in a local so the caller's schema table is left untouched.
  local default = field['default']
  if default then
    if type(default) == 'string' then
      local literal = string.gsub(default, "'", "''")
      if sys['db_server'] == 'mysql' then
        literal = string.gsub(literal, '\\', '\\\\')
      end
      default = "'".. literal .."'"
    end
    sql = sql .. ' DEFAULT '.. tostring(default)
  end

  return sql
end

--[[
  List all fields used in this key definition

  @param fields
    Table, the key columns: each entry is either a field name or a
    {name, size} pair for a prefix of the column
  @return
    String, the comma-separated list of all fields used by this table key
]]
function db_sqlapi_create_table_key(fields)
  local ret = {}
  for _, field in pairs(fields) do
    if type(field) == 'table' then
      if sys['db_server'] == 'postgres' then
        ret[#ret + 1] = 'substr('.. field[1] ..', 1, '.. field[2] ..')'
      else
        ret[#ret + 1] = field[1] ..'('.. field[2] ..')'
      end
    else
      ret[#ret + 1] = field
    end
  end
  return table.concat(ret, ',')
end

-- Return the schema tables of a module, or nil when it declares none.
-- Both lookup conventions are accepted: a global function
-- <module>_schema, or a schema function stored in the global module
-- table (_G[module].schema).
local function db_sqlapi_schema(module)
  local schema = _G[module ..'_schema']
  if type(schema) ~= 'function' and type(_G[module]) == 'table' then
    schema = _G[module]['schema']
  end
  if type(schema) == 'function' then
    return schema()
  end
end

--[[
  Install tables using SQL API tables. Does nothing when the module
  provides no schema.

  @param module
    String, the module name whose schema is installed
]]
function db_sqlapi_install(module)
  local tables = db_sqlapi_schema(module)
  if not tables then
    return
  end
  for table_name, sql_table in pairs(tables) do
    -- ipairs keeps CREATE TABLE ahead of the CREATE INDEX statements.
    for _, instruction in ipairs(db_sqlapi_create_table(table_name, sql_table)) do
      db_query(instruction)
    end
  end
end

--[[
  Uninstall tables using SQL API tables. Does nothing when the module
  provides no schema.

  @param module
    String, the module name whose schema is removed
]]
function db_sqlapi_uninstall(module)
  local tables = db_sqlapi_schema(module)
  if tables then
    for table_name in pairs(tables) do
      db_query('DROP TABLE {'.. table_name ..'}')
    end
  end
end

--[[
  @} End of "defgroup database".
]]