--[[
   Copyright (c) 2016-present, Facebook, Inc.
   All rights reserved.

   This source code is licensed under the BSD-style license found in the
   LICENSE file in the root directory of this source tree. An additional grant
   of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'
local utils = require 'torchnet.utils'

local TransformDataset =
   torch.class('tnt.TransformDataset', 'tnt.Dataset', tnt)

-- Raise an error when `sample` holds no value under `key`.
-- Only `nil` counts as missing: a field whose value is `false` is a
-- legitimate value, and a plain `assert(sample[key])` would wrongly reject it.
local function assertfield(sample, key)
   assert(sample[key] ~= nil, 'missing key in sample')
end

-- Overload 1: a single closure, applied either to one field (`key` given)
-- or to the whole sample (`key` omitted).
TransformDataset.__init = argcheck{
   doc = [[
#### tnt.TransformDataset(@ARGP)
@ARGT

Given a closure `transform()`, and a `dataset`, `tnt.TransformDataset`
applies the closure in an on-the-fly manner when querying a sample with
`tnt.Dataset:get()`.

If key is provided, the closure is applied to the sample field specified
by `key` (only). The closure must return the new corresponding field value.

If key is not provided, the closure is applied on the full sample. The
closure must return the new sample table.

The size of the new dataset is equal to the size of the underlying `dataset`.

Purpose: when performing pre-processing operations, it is convenient to be
able to perform on-the-fly transformations to a dataset.
]],
   {name='self', type='tnt.TransformDataset'},
   {name='dataset', type='tnt.Dataset'},
   {name='transform', type='function'},
   {name='key', type='string', opt=true},
   call =
      function(self, dataset, transform, key)
         self.dataset = dataset
         if key then
            -- Field mode: only sample[key] is rewritten, in place, on the
            -- table returned by the wrapped dataset. The sample index is
            -- forwarded as the closure's second argument.
            function self.__transform(z, idx)
               assertfield(z, key)
               z[key] = transform(z[key], idx)
               return z
            end
         else
            -- Whole-sample mode: whatever the closure returns becomes the
            -- sample; the index is forwarded as the second argument.
            function self.__transform(z, idx)
               return transform(z, idx)
            end
         end
      end
}

-- Overload 2: a table mapping sample field names to per-field closures.
TransformDataset.__init = argcheck{
   doc = [[
#### tnt.TransformDataset(@ARGP)
@ARGT

Given a set of closures and a `dataset`, `tnt.TransformDataset` applies
these closures in an on-the-fly manner when querying a sample with
`tnt.Dataset:get()`.

Closures are provided in `transforms`, a Lua table, where a (key,value)
pair represents a (sample field name, corresponding closure to be applied
to the field name).

Each closure must return the new value of the corresponding field.
]],
   {name='self', type='tnt.TransformDataset'},
   {name='dataset', type='tnt.Dataset'},
   {name='transforms', type='table'},
   overload = TransformDataset.__init,
   call =
      function(self, dataset, transforms)
         -- Validate up front so a bad table fails at construction time
         -- rather than on the first get().
         for _, fn in pairs(transforms) do
            assert(type(fn) == 'function',
               'key/function table expected for transforms')
         end
         self.dataset = dataset
         -- Keep a private copy (via utils.table.copy) so later edits to the
         -- caller's table do not change which closures this dataset applies.
         transforms = utils.table.copy(transforms)
         -- NOTE(review): unlike overload 1, these closures receive only the
         -- field value, not the sample index. The fields are rewritten in
         -- place on the sample table returned by the wrapped dataset.
         function self.__transform(z)
            for key, transform in pairs(transforms) do
               assertfield(z, key)
               z[key] = transform(z[key])
            end
            return z
         end
      end
}

-- Number of samples: identical to the wrapped dataset.
TransformDataset.size = argcheck{
   {name='self', type='tnt.TransformDataset'},
   call =
      function(self)
         return self.dataset:size()
      end
}

-- Fetch sample `idx` from the wrapped dataset and pass it, along with its
-- index, through the transform selected at construction time.
TransformDataset.get = argcheck{
   {name='self', type='tnt.TransformDataset'},
   {name='idx', type='number'},
   call =
      function(self, idx)
         return self.__transform(
            self.dataset:get(idx), idx
         )
      end
}