--[[
   Copyright (c) 2016-present, Facebook, Inc.
   All rights reserved.

   This source code is licensed under the BSD-style license found in the
   LICENSE file in the root directory of this source tree. An additional grant
   of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local IndexedDataset, Dataset = torch.class('tnt.IndexedDataset', 'tnt.Dataset', tnt)
local IndexedDatasetReader = torch.class('tnt.IndexedDatasetReader', tnt)
local IndexedDatasetWriter = torch.class('tnt.IndexedDatasetWriter', tnt)

-- dataset built on index reader
IndexedDataset.__init = argcheck{
   doc = [[
#### tnt.IndexedDataset(@ARGP)
@ARGT

A `tnt.IndexedDataset()` is a data structure built upon (possibly several)
data archives, each containing a bunch of tensors of the same type.

See [tnt.IndexedDatasetWriter](#IndexedDatasetWriter) and
[tnt.IndexedDatasetReader](#IndexedDatasetReader) to see how to create and
read a single archive.

Purpose: large datasets (containing a lot of files) are often not handled
very well by filesystems (especially over a network). `tnt.IndexedDataset`
provides a convenient and efficient way to bundle them into a single archive
file, associated with an index file.

If `path` is provided, then `fields` must be a Lua array (keys being numbers),
whose values are strings representing a filename prefix to an (index, archive)
pair. In other words, `path/field.{idx,bin}` must exist. The i-th sample
returned by this dataset will be a table containing each field as key, and the
tensor found at index i in the corresponding archive.

If `path` is not provided, then `fields` must be a Lua hash. Each key
represents a sample field and the corresponding value must be a table
containing the keys `idx` (the index filename path) and `bin` (the archive
filename path).

If provided (and positive), `maxload` limits the dataset size to the specified
number of samples.

Archives and/or indexes can also be memory mapped with the `mmap` and
`mmapidx` flags.

If `standalone` is true, the constructor expects only one field to be
provided. The i-th sample returned by the dataset will then be the item found
at index i in the archive itself. This is particularly useful with `table`
archives.
]],
   noordered = true,
   {name='self', type='tnt.IndexedDataset'},
   {name='fields', type='table'},
   {name='path', type='string', opt=true},
   {name='maxload', type='number', opt=true},
   {name='mmap', type='boolean', default=false},
   {name='mmapidx', type='boolean', default=false},
   {name='standalone', type='boolean', default=false},
   call = function(self, fields, path, maxload, mmap, mmapidx, standalone)
      self.__fields = {}
      if path then
         for _, fieldname in ipairs(fields) do
            assert(type(fieldname) == 'string',
               'fields should be a list of strings (fieldnames)')
            table.insert(
               self.__fields, {
                  name = fieldname,
                  idx = string.format('%s/%s.idx', path, fieldname),
                  bin = string.format('%s/%s.bin', path, fieldname)})
         end
      else
         for fieldname, paths in pairs(fields) do
            assert(
               type(fieldname) == 'string'
                  and type(paths.bin) == 'string'
                  and type(paths.idx) == 'string',
               'fields should be a hash of string (fieldname) -> table (idx=string, bin=string)')
            table.insert(self.__fields, {
               name = fieldname,
               idx = paths.idx,
               bin = paths.bin})
         end
      end
      assert(#self.__fields > 0, 'fields should not be empty')
      assert(not standalone or #self.__fields == 1,
         'only one field expected when the standalone flag is true')
      local size
      for _, field in ipairs(self.__fields) do
         field.data = tnt.IndexedDatasetReader(field.idx, field.bin, mmap, mmapidx)
         if size then
            assert(field.data:size() == size, 'inconsistent index data size')
         else
            size = field.data:size()
         end
      end
      if maxload and maxload > 0 and maxload < size then
         size = maxload
      end
      self.__standalone = standalone
      self.__size = size
      print(string.format(
         "| IndexedDataset: loaded %s with %d examples", path or '', size))
   end
}
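
-- Usage sketch (illustration only, not part of the module): reading a dataset
-- made of two hypothetical field archives, ./data/input.{idx,bin} and
-- ./data/target.{idx,bin}, assumed to have been created beforehand with
-- tnt.IndexedDatasetWriter.
--
--    local tnt = require 'torchnet'
--    local dataset = tnt.IndexedDataset{
--       path    = './data',
--       fields  = {'input', 'target'},
--       maxload = 1000, -- optional: cap the dataset at 1000 samples
--    }
--    print(dataset:size())
--    local sample = dataset:get(1) -- {input = <tensor>, target = <tensor>}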

IndexedDataset.size = argcheck{
   {name='self', type='tnt.IndexedDataset'},
   call = function(self)
      return self.__size
   end
}

IndexedDataset.get = argcheck{
   {name='self', type='tnt.IndexedDataset'},
   {name='idx', type='number'},
   call = function(self, idx)
      assert(idx >= 1 and idx <= self.__size, 'index out of bounds')
      if self.__standalone then
         return self.__fields[1].data:get(idx)
      else
         local sample = {}
         for _, field in ipairs(self.__fields) do
            sample[field.name] = field.data:get(idx)
         end
         return sample
      end
   end
}

-- supported tensor types
local IndexedDatasetIndexTypes = {}
for code, type in ipairs{'byte', 'char', 'short', 'int', 'long', 'float', 'double', 'table'} do
   local Type = (type == 'table') and 'Char' or type:sub(1,1):upper() .. type:sub(2)
   IndexedDatasetIndexTypes[type] = {
      code = code,
      name = string.format('torch.%sTensor', Type),
      read = string.format('read%s', Type),
      write = string.format('write%s', Type),
      size = torch[string.format('%sStorage', Type)].elementSize(),
      storage = torch[string.format('%sStorage', Type)],
      tensor = torch[string.format('%sTensor', Type)]
   }
end
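
-- On-disk index layout, as written by IndexedDatasetWriter:close() and parsed
-- by readindex() below. Every entry is stored as a long:
--    magic number (0x584449544E54), format version (1), type code,
--    element size, N (number of items), S (total number of dimensions),
--    dimoffsets (N+1 longs), datoffsets (N+1 longs), sizes (S longs).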

-- index reader helper function
local function readindex(self, indexfilename)
   local f = torch.DiskFile(indexfilename):binary()
   assert(f:readLong() == 0x584449544E54, "unrecognized index format")
   assert(f:readLong() == 1, "unsupported format version")
   local code = f:readLong()
   for typename, type in pairs(IndexedDatasetIndexTypes) do
      if type.code == code then
         self.type = type
         self.typename = typename
      end
   end
   assert(self.type, "unrecognized type")
   assert(f:readLong() == self.type.size, "type size does not match")
   self.N = f:readLong()
   self.S = f:readLong()
   self.dimoffsets = torch.LongTensor(f:readLong(self.N+1))
   self.datoffsets = torch.LongTensor(f:readLong(self.N+1))
   self.sizes = torch.LongTensor(f:readLong(self.S))
   f:close()
end

-- index writer
IndexedDatasetWriter.__init = argcheck{
   doc = [[
##### tnt.IndexedDatasetWriter(@ARGP)
@ARGT

Creates an (archive, index) file pair. The archive will contain tensors of the
specified `type`. `type` must be a string chosen from {`byte`, `char`,
`short`, `int`, `long`, `float`, `double`, `table`}.

`indexfilename` is the full path to the index file to be created.
`datafilename` is the full path to the data archive file to be created.

Tensors are added to the archive with [add()](#IndexedDataset.add). Note that
you *must* call [close()](#IndexedDataset.close) to ensure all data is written
to disk and to create the index file.

The type `table` is special: data will be stored in a CharTensor, serialized
from a Lua table object. IndexedDatasetReader will then deserialize the
CharTensor back into a table at read time. This allows storing heterogeneous
data easily in an IndexedDataset.
]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   {name='indexfilename', type='string'},
   {name='datafilename', type='string'},
   {name='type', type='string'},
   call = function(self, indexfilename, datafilename, type)
      self.BLOCKSZ = 1024
      self.indexfilename = indexfilename
      self.datafilename = datafilename
      assert(IndexedDatasetIndexTypes[type],
         'invalid type (byte, char, short, int, long, float, double or table expected)')
      self.dimoffsets = torch.LongTensor(self.BLOCKSZ)
      self.datoffsets = torch.LongTensor(self.BLOCKSZ)
      self.sizes = torch.LongTensor(self.BLOCKSZ)
      self.N = 0
      self.S = 0
      self.dimoffsets[1] = 0
      self.datoffsets[1] = 0
      self.type = IndexedDatasetIndexTypes[type]
      self.datafile = torch.DiskFile(datafilename, 'w'):binary()
   end
}
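
-- Usage sketch (illustration only; filenames are hypothetical): writing a
-- float archive one tensor at a time. close() must be called at the end,
-- otherwise the index file is never created.
--
--    local tnt = require 'torchnet'
--    local writer = tnt.IndexedDatasetWriter(
--       './data/input.idx', './data/input.bin', 'float')
--    for i = 1, 100 do
--       writer:add(torch.FloatTensor(10):fill(i)) -- must match the 'float' type
--    end
--    writer:close()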

-- append mode
IndexedDatasetWriter.__init = argcheck{
   doc = [[
##### tnt.IndexedDatasetWriter(@ARGP)
@ARGT

Opens an existing (archive, index) file pair for appending. The tensor type is
inferred from the provided index file.

`indexfilename` is the full path to the index file to be opened.
`datafilename` is the full path to the data archive file to be opened.
]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   {name='indexfilename', type='string'},
   {name='datafilename', type='string'},
   overload = IndexedDatasetWriter.__init,
   call = function(self, indexfilename, datafilename)
      self.BLOCKSZ = 1024
      self.indexfilename = indexfilename
      self.datafilename = datafilename
      readindex(self, indexfilename)
      self.datafile = torch.DiskFile(datafilename, 'rw'):binary()
      self.datafile:seekEnd()
   end
}

IndexedDatasetWriter.add = argcheck{
   doc = [[
###### tnt.IndexedDatasetWriter.add(@ARGP)
@ARGT

Adds a tensor to the archive and records its index position. The tensor must
be of the same type as the one specified at the creation of the
`tnt.IndexedDatasetWriter`.
]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   {name='tensor', type='torch.*Tensor'},
   call = function(self, tensor)
      assert(torch.typename(tensor) == self.type.name, 'invalid tensor type')
      local size = tensor:size()
      local dim = size:size()
      local N = self.N + 1
      local S = self.S + dim
      if self.dimoffsets:size(1) < N+1 then -- +1 for the first 0 value
         self.dimoffsets:resize(N+self.BLOCKSZ)
         self.datoffsets:resize(N+self.BLOCKSZ)
      end
      if self.sizes:size(1) < S then
         self.sizes:resize(S+self.BLOCKSZ)
      end
      self.dimoffsets[N+1] = self.dimoffsets[N] + dim
      self.datoffsets[N+1] = self.datoffsets[N] + tensor:nElement()
      if dim > 0 then
         self.sizes:narrow(1, self.S+1, dim):copy(torch.LongTensor(size))
      end
      self.N = N
      self.S = S
      if tensor:nElement() > 0 then
         self.datafile[self.type.write](self.datafile, tensor:clone():storage())
      end
   end
}

IndexedDatasetWriter.add = argcheck{
   doc = [[
###### tnt.IndexedDatasetWriter.add(@ARGP)
@ARGT

Convenience method which, given a `filename`, opens the corresponding file in
`binary` mode and reads all its data as if it were of the type specified at
the `tnt.IndexedDatasetWriter` construction. A corresponding tensor is then
added to the archive/index pair.
]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   {name='filename', type='string'},
   overload = IndexedDatasetWriter.add,
   call = function(self, filename)
      local f = torch.DiskFile(filename):binary()
      f:seekEnd()
      local sz = f:position()-1
      f:seek(1)
      local storage = f[self.type.read](f, sz/self.type.size)
      f:close()
      self:add(self.type.tensor(storage))
   end
}

IndexedDatasetWriter.add = argcheck{
   doc = [[
###### tnt.IndexedDatasetWriter.add(@ARGP)
@ARGT

Convenience method only available for `table` (or `char`) type
IndexedDataset. The table will be serialized into a CharTensor.
]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   {name='table', type='table'},
   nonamed = true, -- ambiguity possible with table arg
   overload = IndexedDatasetWriter.add,
   call = function(self, tbl)
      assert(
         self.type == IndexedDatasetIndexTypes.table
            or self.type == IndexedDatasetIndexTypes.char,
         'table convenience method is only available for "table" or "char"-based datasets')
      tbl = torch.CharTensor(torch.serializeToStorage(tbl))
      self:add(tbl)
   end
}

IndexedDatasetWriter.close = argcheck{
   doc = [[
###### tnt.IndexedDatasetWriter.close(@ARGP)
@ARGT

Finalizes the index and closes the archive/index file pair. This method must
be called to ensure the index is written and all the archive data is flushed
to disk.
]],
   {name='self', type='tnt.IndexedDatasetWriter'},
   call = function(self)
      local f = torch.DiskFile(self.indexfilename, 'w'):binary()
      f:writeLong(0x584449544E54) -- magic number
      f:writeLong(1) -- version
      f:writeLong(self.type.code) -- type code
      f:writeLong(self.type.size)
      f:writeLong(self.N)
      f:writeLong(self.S)
      -- resize the underlying storages to their exact size
      self.dimoffsets = torch.LongTensor(
         self.dimoffsets:storage():resize(self.N+1)
      )
      self.datoffsets = torch.LongTensor(
         self.datoffsets:storage():resize(self.N+1)
      )
      self.sizes = torch.LongTensor(
         self.sizes:storage():resize(self.S)
      )
      -- write index to disk
      f:writeLong(self.dimoffsets:storage())
      f:writeLong(self.datoffsets:storage())
      f:writeLong(self.sizes:storage())
      f:close()
      self.datafile:close()
   end
}
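
-- Sketch of the 'table' type and of append mode (illustration only; filenames
-- are hypothetical). Each add() serializes a Lua table into a CharTensor, so
-- heterogeneous records can be stored in a single archive; an existing pair
-- can later be re-opened for appending by passing only the two filenames.
--
--    local tnt = require 'torchnet'
--    local writer = tnt.IndexedDatasetWriter(
--       './data/doc.idx', './data/doc.bin', 'table')
--    writer:add({text = 'hello world', label = 1})
--    writer:close()
--
--    -- append mode: the type is inferred from the existing index file
--    local appender = tnt.IndexedDatasetWriter('./data/doc.idx', './data/doc.bin')
--    appender:add({text = 'second record', label = 2})
--    appender:close()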

-- helper function that updates the metatable when data is on file:
local function updatemetatable(data, datafilename)
   local data_mt = {}
   local f = torch.DiskFile(datafilename):binary():noBuffer()
   function data_mt:narrow(dim, offset, size)
      f:seek((offset - 1) * self.type.size + 1)
      return self.type.tensor(f[self.type.read](f, size))
   end
   setmetatable(data, {__index = data_mt})
   return data
end

-- index reader
IndexedDatasetReader.__init = argcheck{
   doc = [[
##### tnt.IndexedDatasetReader(@ARGP)
@ARGT

Reads an archive/index pair previously created by
[tnt.IndexedDatasetWriter](#IndexedDatasetWriter).

`indexfilename` is the full path to the index file. `datafilename` is the
full path to the archive file.

Memory mapping can be specified for both the archive and the index through
the optional `mmap` and `mmapidx` flags.
]],
   {name='self', type='tnt.IndexedDatasetReader'},
   {name='indexfilename', type='string'},
   {name='datafilename', type='string'},
   {name='mmap', type='boolean', default=false},
   {name='mmapidx', type='boolean', default=false},
   call = function(self, indexfilename, datafilename, mmap, mmapidx)
      self.indexfilename = indexfilename
      self.datafilename = datafilename
      if mmapidx then
         -- memory mapped index
         local idx = torch.LongStorage(indexfilename)
         local offset = 1
         assert(idx[offset] == 0x584449544E54, "unrecognized index format")
         offset = offset + 1
         assert(idx[offset] == 1, "unsupported format version")
         offset = offset + 1
         local code = idx[offset]
         offset = offset + 1
         for typename, type in pairs(IndexedDatasetIndexTypes) do
            if type.code == code then
               self.type = type
               self.typename = typename
            end
         end
         assert(self.type, "unrecognized type")
         assert(idx[offset] == self.type.size, "type size does not match")
         offset = offset + 1
         self.N = idx[offset]
         offset = offset + 1
         self.S = idx[offset]
         offset = offset + 1
         self.dimoffsets = torch.LongTensor(idx, offset, self.N+1)
         offset = offset + self.N+1
         self.datoffsets = torch.LongTensor(idx, offset, self.N+1)
         offset = offset + self.N+1
         self.sizes = torch.LongTensor(idx, offset, self.S)
      else
         -- index on file
         readindex(self, indexfilename)
      end
      if mmap then
         -- memory mapped data
         self.data = self.type.tensor(self.type.storage(datafilename))
      else
         -- data on file
         local data = {type=self.type}
         self.data = updatemetatable(data, datafilename)
      end
   end
}

function IndexedDatasetReader:__write(file)
   local obj = {}
   for k,v in pairs(self) do
      obj[k] = v
   end
   obj.type = nil
   if type(self.data) == 'table' then
      obj.data.type = nil
   end
   file:writeObject(obj)
   if type(self.data) == 'table' then
      obj.data.type = self.type
   end
end

function IndexedDatasetReader:__read(file)
   for k,v in pairs(file:readObject()) do
      self[k] = v
   end
   self.type = IndexedDatasetIndexTypes[self.typename]
   if type(self.data) == 'table' then
      self.data.type = self.type
      updatemetatable(self.data, self.datafilename)
   end
end
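
-- Usage sketch (illustration only; filenames are hypothetical): reading a
-- single archive directly, optionally memory mapping the data and the index.
--
--    local tnt = require 'torchnet'
--    local reader = tnt.IndexedDatasetReader(
--       './data/input.idx', './data/input.bin',
--       true,  -- mmap the archive
--       true)  -- mmap the index
--    print(reader:size())    -- number of tensors in the archive
--    local t = reader:get(1) -- first tensor (or Lua table for 'table' archives)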

IndexedDatasetReader.size = argcheck{
   doc = [[
###### tnt.IndexedDatasetReader.size(@ARGP)

Returns the number of tensors present in the archive.
]],
   {name='self', type='tnt.IndexedDatasetReader'},
   call = function(self)
      return self.N
   end
}

IndexedDatasetReader.get = argcheck{
   doc = [[
###### tnt.IndexedDatasetReader.get(@ARGP)

Returns the tensor at the specified `index` in the archive.
]],
   {name='self', type='tnt.IndexedDatasetReader'},
   {name='index', type='number'},
   call = function(self, index)
      assert(index > 0 and index <= self.N, 'index out of range')
      local ndim = self.dimoffsets[index+1]-self.dimoffsets[index]
      if ndim == 0 then
         return self.type.tensor()
      end
      local size = self.sizes:narrow(
         1, self.dimoffsets[index]+1, ndim
      )
      size = size:clone():storage()
      local data = self.data:narrow(
         1,
         self.datoffsets[index]+1,
         self.datoffsets[index+1]-self.datoffsets[index]
      ):view(size)
      if self.type == IndexedDatasetIndexTypes.table then
         return torch.deserializeFromStorage(data)
      else
         return data:clone()
      end
   end
}
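
-- End-to-end sketch (illustration only; filenames are hypothetical): a 'table'
-- archive consumed through tnt.IndexedDataset with standalone = true, so each
-- sample is the deserialized table itself rather than a {field = value} wrapper.
--
--    local tnt = require 'torchnet'
--    local writer = tnt.IndexedDatasetWriter(
--       './data/samples.idx', './data/samples.bin', 'table')
--    writer:add({input = torch.randn(3), target = 2})
--    writer:close()
--
--    local dataset = tnt.IndexedDataset{
--       path       = './data',
--       fields     = {'samples'},
--       standalone = true,
--    }
--    print(dataset:get(1).target) -- 2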