--[[
Copyright (c) 2016-present, Facebook, Inc.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--
local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'
-- Declare `tnt.ShuffleDataset` as a subclass of `tnt.ResampleDataset`
-- (registered under the `tnt` table passed as the third argument).
-- torch.class returns the new class table and its parent class table; the
-- parent is kept so the constructor can call ResampleDataset.__init directly.
local ShuffleDataset, ResampleDataset =
torch.class('tnt.ShuffleDataset', 'tnt.ResampleDataset', tnt)
ShuffleDataset.__init = argcheck{
   doc = [[
#### tnt.ShuffleDataset(@ARGP)
@ARGT
`tnt.ShuffleDataset` is a sub-class of
[tnt.ResampleDataset](#ResampleDataset) provided for convenience.
It samples uniformly from the given `dataset` with, or without
`replacement`. The chosen partition can be redrawn by calling
[resample()](#ShuffleDataset.resample).
If `replacement` is `true`, then the specified `size` may be larger than
the underlying `dataset`.
If `size` is not provided, then the new dataset size will be equal to the
underlying `dataset` size.
Purpose: the easiest way to shuffle a dataset!
]],
   {name='self', type='tnt.ShuffleDataset'},
   {name='dataset', type='tnt.Dataset'},
   {name='size', type='number', opt=true},
   {name='replacement', type='boolean', default=false},
   call =
      function(self, dataset, size, replacement)
         -- Sampling without replacement cannot yield more distinct indices
         -- than the wrapped dataset holds, so reject such a request early.
         -- (dataset:size() is only queried when a size was actually given
         -- and replacement is off.)
         local tooLarge = size and not replacement and size > dataset:size()
         if tooLarge then
            error('size cannot be larger than underlying dataset size when sampling without replacement')
         end

         -- Must be set before resample() below, which branches on it.
         self.__replacement = replacement

         -- The parent class does the index remapping; our sampler ignores
         -- the dataset argument and looks the index up in self.__perm,
         -- which resample() (re)builds on demand.
         ResampleDataset.__init(self, {
            dataset = dataset,
            size = size,
            sampler = function(_, idx)
               return self.__perm[idx]
            end,
         })

         -- Draw the initial index table so the dataset is usable right away.
         self:resample()
      end
}
ShuffleDataset.resample = argcheck{
   doc = [[
##### tnt.ShuffleDataset.resample(@ARGP)
The permutation associated to `tnt.ShuffleDataset` is fixed, such that two
calls to the same index will return the same sample from the underlying
dataset.
Call `resample()` to draw randomly a new permutation.
]],
   {name='self', type='tnt.ShuffleDataset'},
   call =
      function(self)
         -- Number of entries exposed by this dataset vs. number of entries
         -- in the wrapped dataset.
         local sampleCount = self:size()
         local sourceCount = self.__dataset:size()

         local perm
         if self.__replacement then
            -- With replacement: one random underlying index per position;
            -- duplicates are possible.
            perm = torch.LongTensor(sampleCount):random(sourceCount)
         else
            -- Without replacement: keep the leading sampleCount entries of a
            -- fresh random permutation of the underlying indices.
            perm = torch.randperm(sourceCount):long():narrow(1, 1, sampleCount)
         end

         self.__perm = perm
      end
}