--[[ Copyright (c) 2016-present, Facebook, Inc. All rights reserved. This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. An additional grant of patent rights can be found in the PATENTS file in the same directory. ]]--

local tnt = require 'torchnet.env'
local Threads = require 'threads'
local argcheck = require 'argcheck'
local doc = require 'argcheck.doc'

-- Dataset iterator that fetches samples through a pool of worker threads.
-- The full user-facing contract is in the `doc` string of __init below.
local ParallelDatasetIterator =
   torch.class('tnt.ParallelDatasetIterator', 'tnt.DatasetIterator', tnt)

ParallelDatasetIterator.__init = argcheck{
   doc = [[ #### tnt.ParallelDatasetIterator(@ARGP) @ARGT Allows to iterate over a dataset in a thread manner. `tnt.ParallelDatasetIterator:run()` guarantees that all samples will be seen, but does not guarantee the order unless `ordered` is set to true. The purpose of this class is to have a zero pre-processing cost. When reading datasets on the fly from disk (not loading them fully in memory), or performing complex pre-processing this can be of interest. The number of threads used to parallelize is specified by `nthread`. `init(threadid)` (where threadid=1..nthread) is a closure which may initialize the specified thread as needed, if needed. It is doing nothing by default. `closure(threadid)` will be called on each thread and must return a `tnt.Dataset` instance. `perm(idx)` is a permutation used to shuffle the examples. If shuffling is needed, one can use this closure, or (better) use [tnt.ShuffleDataset](#ShuffleDataset) on the underlying dataset (returned by `closure()`). `filter(sample)` is a closure which returns `true` if the given sample should be considered or `false` if not. Note that filter is called _after_ fetching the data in a threaded manner. `transform(sample)` is a function which maps the given sample to a new value. This transformation occurs before filtering. When `ordered` is set to true the ordering of samples returned by the iterator is guaranteed. 
This option is particularly useful for repeatable experiments. By default `ordered` is false, which means that order is not guaranteed by `run()` (though often the ordering is similar in practice). A common error raised by this dataset is when `closure()` is not serializable. Make sure that all [upvalues](http://www.lua.org/pil/27.3.3.html) of `closure()` are serializable. It is recommended to avoid [upvalues](http://www.lua.org/pil/27.3.3.html) at all cost, and to make sure you require all the appropriate torch packages needed to (de-)serialize `closure()` in the `init()` function. For more information, check out the [threads package](https://github.com/torch/threads), on which `tnt.ParallelDatasetIterator` relies. ]],
   {name='self', type='tnt.ParallelDatasetIterator'},
   {name='init', type='function', default=function(idx) end},
   {name='closure', type='function'},
   {name='nthread', type='number'},
   {name='perm', type='function', default=function(idx) return idx end},
   {name='filter', type='function', default=function(sample) return true end},
   {name='transform', type='function', default=function(sample) return sample end},
   {name='ordered', type='boolean', default=false},
   call = function(self, init, closure, nthread, perm, filter, transform, ordered)
      -- Runs once in each worker thread after `init`.
      -- NOTE: `gdataset` is deliberately a *global* in the worker thread so
      -- that jobs submitted later (in run(), execSingle(), exec()) can reach
      -- the per-thread dataset instance.
      local function main(idx)
         gdataset = closure(idx)
         assert(torch.isTypeOf(gdataset, 'tnt.Dataset'),
            'closure should return a Dataset class')
         return gdataset:size()
      end

      -- Shared serialization avoids copying tensors between threads.
      Threads.serialization('threads.sharedserialize')
      local threads, sizes = Threads(nthread, init, main)
      self.__threads = threads
      self.__nthread = nthread
      -- `sizes[1][1]` is the size returned by `main` on the first thread;
      -- it is refreshed at the start of every run() below.
      local size = sizes[1][1]
      -- Upvalues written by the addjob end-callbacks; each dojob() fills
      -- them with the most recently completed job's results.
      local sample -- beware: do not put this line in loop()
      local sampleOrigIdx
      function self.run()
         -- make sure we are not in the middle of something
         threads:synchronize()
         -- loading size of the dataset each time run() is called
         threads:addjob(
            function()
               local size = gdataset:size()
               return size
            end,
            function(_size_)
               size = _size_
            end
         )
         threads:dojob()
         local idx = 1
         -- Keep worker queues full: each job fetches sample perm(idx) while
         -- origIdx records the sequential position (used by `ordered` mode
         -- to restore the original order).
         local function enqueue()
            while idx <= size and threads:acceptsjob() do
               threads:addjob(
                  function(origIdx, idx)
                     local sample = gdataset:get(idx)
                     collectgarbage()
                     collectgarbage()
                     return sample, origIdx
                  end,
                  function(_sample_, _origIdx_)
                     sample, sampleOrigIdx = _sample_, _origIdx_
                  end,
                  idx, perm(idx)
               )
               idx = idx + 1
            end
         end
         enqueue()
         local iterFunction
         if ordered then
            -- Index of the next sample (in original order) to hand out.
            local curSampleIdx = 1
            -- Completed samples keyed by original index, buffered until
            -- their turn comes up.
            local storedSamples = {}
            -- `samplePlaceholder` stands in for samples which have been
            -- filtered out by the `filter` function
            local samplePlaceholder = {}
            -- Move past placeholders (filtered out samples) in
            -- `storedSamples`
            local function advancePastPlaceholders()
               while storedSamples[curSampleIdx] == samplePlaceholder do
                  storedSamples[curSampleIdx] = nil
                  curSampleIdx = curSampleIdx + 1
               end
            end
            iterFunction = function()
               advancePastPlaceholders()
               -- Load into storedSamples until we find the next sample in
               -- the sequence or we run out of samples
               while storedSamples[curSampleIdx] == nil and threads:hasjob() do
                  enqueue()
                  threads:dojob()
                  if threads:haserror() then
                     threads:synchronize()
                  end
                  enqueue()
                  sample = transform(sample)
                  if filter(sample) then
                     -- Store sample
                     storedSamples[sampleOrigIdx] = sample
                  else
                     -- Mark sample as "filtered out"
                     storedSamples[sampleOrigIdx] = samplePlaceholder
                  end
                  advancePastPlaceholders()
               end
               enqueue()
               -- Hand out the sample (nil here ends the for-loop iteration).
               local curSample = storedSamples[curSampleIdx]
               storedSamples[curSampleIdx] = nil
               curSampleIdx = curSampleIdx + 1
               return curSample
            end
         else
            -- Unordered: return each completed sample as soon as it passes
            -- `filter`; iteration ends when no jobs remain.
            iterFunction = function()
               while threads:hasjob() do
                  enqueue()
                  threads:dojob()
                  if threads:haserror() then
                     threads:synchronize()
                  end
                  enqueue()
                  sample = transform(sample)
                  if filter(sample) then
                     return sample
                  end
               end
            end
         end
         return iterFunction
      end
   end
}

doc[[ #### tnt.ParallelDatasetIterator.execSingle(tnt.DatasetIterator, name, ...) Execute the given method `name` on the dataset corresponding to the first available thread, passing it the subsequent arguments, and returns what the `name` method returns. 
For example: ```lua local iterator = tnt.ParallelDatasetIterator{...} print(iterator:execSingle("size")) ``` will print the size of the dataset loaded in the first available thread. ]]

-- Run dataset method `name` on whichever worker thread picks the job up
-- first, forwarding the extra arguments, and return that method's results.
ParallelDatasetIterator.execSingle = function(self, name, ...)
   assert(not self.__threads:hasjob(), 'cannot execSingle during loop')
   local callArgs = {...}
   local returned
   self.__threads:addjob(
      function()
         -- `gdataset` is the per-thread global set up in __init's main().
         return gdataset:exec(name, table.unpack(callArgs))
      end,
      function(...)
         returned = {...}
      end)
   -- Block until the job (and its end-callback) has completed.
   self.__threads:synchronize()
   return table.unpack(returned)
end

doc[[ #### tnt.ParallelDatasetIterator.exec(tnt.DatasetIterator, name, ...) Execute the given method `name` on the underlying datasets in each thread, passing to each of them the subsequent arguments, and returns a table of what the `name` method returns for each thread. For example: ```lua local iterator = tnt.ParallelDatasetIterator{...} for _, v in pairs(iterator:exec("size")) do print(v) end ``` will print the size of the datasets loaded in each thread. ]]

-- Run dataset method `name` on every worker thread, forwarding the extra
-- arguments, and collect one result entry per thread.
ParallelDatasetIterator.exec = function(self, name, ...)
   assert(not self.__threads:hasjob(), 'cannot exec during loop')
   local callArgs = {...}
   local perThread = {}
   -- Specific mode lets us target each thread index explicitly.
   self.__threads:specific(true)
   for threadIdx = 1, self.__nthread do
      self.__threads:addjob(threadIdx,
         function()
            return gdataset:exec(name, table.unpack(callArgs))
         end,
         function(...)
            local returned = {...}
            -- Single result: store the value itself; multiple: keep the table.
            if #returned > 1 then
               perThread[threadIdx] = returned
            else
               perThread[threadIdx] = returned[1]
            end
         end)
   end
   -- NOTE(review): torch/threads synchronizes pending jobs when switching
   -- specific mode, so `perThread` should be fully populated here — confirm
   -- against the threads package version in use before relying on it.
   self.__threads:specific(false)
   return perThread
end