--[[
   Copyright (c) 2016-present, Facebook, Inc.
   All rights reserved.

   This source code is licensed under the BSD-style license found in the
   LICENSE file in the root directory of this source tree. An additional grant
   of patent rights can be found in the PATENTS file in the same directory.
]]--

local tnt = require 'torchnet.env'
local argcheck = require 'argcheck'

local SplitDataset = torch.class('tnt.SplitDataset', 'tnt.Dataset', tnt)

SplitDataset.__init = argcheck{
   doc = [[
#### tnt.SplitDataset(@ARGP)
@ARGT

Partition a given `dataset`, according to the specified `partitions`. Use
the method [select()](#SplitDataset.select) to select the current partition
in use.

The Lua hash table `partitions` is of the form (key, value) where key is a
user-chosen string naming the partition, and value is a number representing
the weight (as a number between 0 and 1) or the size (in number of samples)
of the corresponding partition.

Partitioning is achieved linearly (no shuffling). See
[tnt.ShuffleDataset](#ShuffleDataset) if you want to shuffle the dataset
before partitioning.

The optional variable `initialpartition` specifies the partition that is loaded
initially.

Purpose: useful in machine learning to perform validation procedures.
]],
   {name='self', type='tnt.SplitDataset'},
   {name='dataset', type='tnt.Dataset'},
   {name='partitions', type='table'},
   {name='initialpartition', type='string', opt=true},
   call =
      function(self, dataset, partitions, initialpartition)

         -- create a sorted list of partition names (for determinism: the
         -- iteration order of pairs() is unspecified):
         local partitionnames = {}
         for name, _ in pairs(partitions) do
            table.insert(partitionnames, name)
         end
         table.sort(partitionnames)

         -- create the partition size tensor and the name -> index table;
         -- partition n (in sorted-name order) has size __partitionsizes[n]:
         self.__dataset = dataset
         self.__partitionsizes = torch.DoubleTensor(#partitionnames)
         self.__names = {}
         for n, key in ipairs(partitionnames) do
            local val = partitions[key]
            self.__partitionsizes[n] = val
            self.__names[key] = n
         end

         -- assertions:
         assert(
            self.__partitionsizes:nElement() >= 2,
            'SplitDataset should have at least two partitions'
         )
         assert(
            self.__partitionsizes:min() >= 0,
            'some partition sizes are negative'
         )
         assert(
            self.__partitionsizes:max() > 0,
            'all partitions are empty'
         )

         -- values summing to at most 1 are interpreted as weights and are
         -- converted to sample counts (rounded down); otherwise the values
         -- are taken as sample counts and must be whole numbers:
         if self.__partitionsizes:sum() <= 1 then
            -- the tensor is already a DoubleTensor, so it is scaled and
            -- floored in place without any type conversion
            self.__partitionsizes:mul(self.__dataset:size()):floor()
         else
            assert(
               torch.eq(
                  self.__partitionsizes,
                  torch.floor(self.__partitionsizes)
               ):all(),
               'partition sizes should be integer numbers, or sum up to <= 1 '
            )
         end
         self.__partitionsizes = self.__partitionsizes:long()
         assert(
            self.__partitionsizes:sum() <= self.__dataset:size(),
            'split cannot involve more samples than dataset size'
         )

         -- select the initial partition, if one was requested; otherwise
         -- select() must be called before size() or get():
         if initialpartition then
            self:select(initialpartition)
         end
      end
}

SplitDataset.select = argcheck{
   doc = [[
##### tnt.SplitDataset.select(@ARGP)
@ARGT

Switch the current partition in use to the one specified by `partition`,
which must be a string corresponding to one of the names provided at
construction.

The current dataset size changes accordingly, as well as the samples returned
by the `get()` method.
]],
   {name='self', type='tnt.SplitDataset'},
   {name='partition', type='string'},
   call =
      function(self, partition)
         local id = self.__names[partition]
         if not id then
            error('partition not found')
         end
         -- the current partition index is kept in a one-element LongTensor,
         -- which is created on first use and updated in place afterwards
         if self.__partition then
            self.__partition[1] = id
         else
            self.__partition = torch.LongTensor{id}
         end
      end
}

-- Return the number of samples in the currently selected partition.
-- Raises an error if no partition has been selected yet.
SplitDataset.size = argcheck{
   {name='self', type='tnt.SplitDataset'},
   call =
      function(self)
         assert(self.__partition, 'select a partition before accessing data')
         return self.__partitionsizes[self.__partition[1]]
      end
}

-- Return sample `idx` (1-based) of the currently selected partition by
-- delegating to the wrapped dataset's get().
-- Raises an error if no partition has been selected yet, or if `idx` lies
-- outside [1, size()].
SplitDataset.get = argcheck{
   {name='self', type='tnt.SplitDataset'},
   {name='idx', type='number'},
   call =
      function(self, idx)
         assert(self.__partition, 'select a partition before accessing data')
         assert(idx >= 1 and idx <= self:size(), 'index out of bounds')

         -- partitions are laid out back to back in the wrapped dataset, so
         -- the offset of the current partition is the total size of all
         -- partitions that precede it (zero for the first partition, where
         -- a narrow() of length 0 is avoided)
         local current = self.__partition[1]
         local offset = 0
         if current ~= 1 then
            offset = self.__partitionsizes:narrow(1, 1, current - 1):sum()
         end
         return self.__dataset:get(offset + idx)
      end
}