import torch
from torch.utils.data import Dataset, DataLoader, BatchSampler, RandomSampler
import numpy as np
import scipy.sparse as sparse
import scipy.sparse.sputils as sputils
import recoder.utils as utils
# Maximum number of row indices used in a single fancy-indexing call on a
# sparse matrix. Index sequences longer than this are extracted in chunks of
# this size and the chunks are stacked back together.
CSR_MATRIX_INDEX_SIZE_LIMIT = 2000
class UsersInteractions:
    """
    Pairs a set of users with the sparse matrix holding their interactions.

    Args:
        users (np.array): the users being represented.
        interactions_matrix (scipy.sparse.csr_matrix): user-item interactions matrix, where
            row ``i`` of the matrix holds the interactions of ``users[i]``.
    """

    def __init__(self, users, interactions_matrix):
        # Both values are stored as given; no copy or validation is performed.
        self.users, self.interactions_matrix = users, interactions_matrix
class RecommendationDataset(Dataset):
    """
    Represents a :class:`torch.utils.data.Dataset` that iterates through the users interactions with items.

    Indexing this dataset returns a :class:`UsersInteractions` containing the interactions
    of the users in the index.

    Args:
        interactions_matrix (scipy.sparse.csr_matrix): the user-item interactions matrix.
        target_interactions_matrix (scipy.sparse.csr_matrix, optional): the target user-item interactions
            matrix. Mainly used for evaluation, representing the items to recommend.
    """

    def __init__(self, interactions_matrix, target_interactions_matrix=None):
        self.interactions_matrix = interactions_matrix  # type: sparse.csr_matrix
        self.target_interactions_matrix = target_interactions_matrix  # type: sparse.csr_matrix

        # Users and items are identified by their row / column position.
        self.users = np.arange(self.interactions_matrix.shape[0])
        self.items = np.arange(self.interactions_matrix.shape[1])

    def __len__(self):
        """Return the number of users (rows of the interactions matrix)."""
        return self.interactions_matrix.shape[0]

    def __getitem__(self, index):
        """
        Fetch the interactions of one user or of a sequence of users.

        Args:
            index (int or sequence of int): a single row index, or a list, tuple
                or 1-d array of row indices.

        Returns:
            tuple: ``(input, target)`` where ``input`` is a :class:`UsersInteractions`
            and ``target`` is a :class:`UsersInteractions`, or ``None`` when the
            dataset has no target interactions matrix.
        """
        # Kept as an assert so the raised exception type is unchanged.
        assert self._is_sequence(index) or self._is_int_like(index), \
            'index should be an integer or a sequence of integers'

        users = np.array(index).reshape(-1,)

        input_interactions = UsersInteractions(
            users=users,
            interactions_matrix=self._extract(self.interactions_matrix, index))

        if self.target_interactions_matrix is None:
            return input_interactions, None

        target_interactions = UsersInteractions(
            users=users,
            interactions_matrix=self._extract(self.target_interactions_matrix, index))

        return input_interactions, target_interactions

    @staticmethod
    def _is_sequence(index):
        """Tell whether ``index`` is a list/tuple of scalars or a 1-d array."""
        # Local check used instead of the deprecated ``scipy.sparse.sputils``
        # helper, which lives in a private SciPy namespace.
        if isinstance(index, (list, tuple)):
            return len(index) == 0 or np.isscalar(index[0])
        return isinstance(index, np.ndarray) and index.ndim == 1

    @staticmethod
    def _is_int_like(index):
        """Tell whether ``index`` is a scalar holding an integral value."""
        if np.ndim(index) != 0:
            return False
        try:
            return bool(int(index) == index)
        except (TypeError, ValueError):
            return False

    def _extract(self, sparse_matrix, index, chunk_size=None):
        """
        Select the rows ``index`` of ``sparse_matrix``.

        Args:
            sparse_matrix: the sparse matrix to select rows from.
            index (int or sequence of int): the row(s) to select.
            chunk_size (int, optional): maximum number of rows selected by a single
                indexing call. Defaults to ``CSR_MATRIX_INDEX_SIZE_LIMIT``.

        Returns:
            The selected rows as a sparse matrix.
        """
        if not self._is_sequence(index):
            return sparse_matrix[index]

        if chunk_size is None:
            chunk_size = CSR_MATRIX_INDEX_SIZE_LIMIT

        if len(index) <= chunk_size:
            return sparse_matrix[index]

        # It happens that scipy implements the indexing of a csr_matrix with a list using
        # matrix multiplication, which gets to be an issue if the size of the index list is
        # large and leads to memory issues.
        # Reference: https://stackoverflow.com/questions/46034212/sparse-matrix-slicing-memory-error/46040827#46040827
        # To work around this, chunk the index into smaller indices of size
        # ``chunk_size`` and then stack the extracted chunks.
        sparse_matrix_slices = [
            sparse_matrix[index[offset: offset + chunk_size]]
            for offset in range(0, len(index), chunk_size)
        ]

        # Request CSR explicitly so the stacked result supports row slicing
        # whatever format the chunks come back in.
        return sparse.vstack(sparse_matrix_slices, format='csr')
class RecommendationDataLoader:
    """
    A ``DataLoader`` similar to ``torch.utils.data.DataLoader`` that handles
    :class:`RecommendationDataset` and generates batches with negative sampling.

    By default, if no ``collate_fn`` is provided, the :func:`BatchCollator.collate` will
    be used, and iterating through this dataloader will return a :class:`Batch` at each
    iteration.

    Args:
        dataset (RecommendationDataset): dataset from which to load the data
        batch_size (int): number of samples per batch
        negative_sampling (bool, optional): whether to apply mini-batch based negative sampling or not.
        num_sampling_users (int, optional): number of users to consider for mini-batch based negative
            sampling. This is useful for increasing the number of negative samples while keeping the
            batch-size small. If 0, then num_sampling_users will be equal to batch_size.
        num_workers (int, optional): how many subprocesses to use for data loading.
        collate_fn (callable, optional): A function that transforms a :class:`UsersInteractions` into
            a mini-batch.
    """

    def __init__(self, dataset, batch_size, negative_sampling=False,
                 num_sampling_users=0, num_workers=0, collate_fn=None):
        self.dataset = dataset  # type: RecommendationDataset
        self.num_sampling_users = num_sampling_users
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.negative_sampling = negative_sampling

        if self.num_sampling_users == 0:
            self.num_sampling_users = batch_size

        assert self.num_sampling_users >= batch_size, 'num_sampling_users should be at least equal to the batch_size'

        self.batch_collator = BatchCollator(batch_size=self.batch_size, negative_sampling=self.negative_sampling)

        # Wrapping a BatchSampler within a BatchSampler
        # in order to fetch the whole mini-batch at once
        # from the dataset instead of fetching each sample on its own
        batch_sampler = BatchSampler(BatchSampler(RandomSampler(dataset),
                                                  batch_size=self.num_sampling_users, drop_last=False),
                                     batch_size=1, drop_last=False)

        if collate_fn is None:
            self._collate_fn = self.batch_collator.collate
            self._use_default_data_generator = True
        else:
            self._collate_fn = collate_fn
            self._use_default_data_generator = False

        self._dataloader = DataLoader(dataset, batch_sampler=batch_sampler,
                                      num_workers=num_workers, collate_fn=self._collate)

    def _default_data_generator(self):
        """Yield ``(input, target)`` pairs one collated batch at a time."""
        # Each item of the wrapped dataloader is a whole sampling group that
        # the collate function has already split into a list of batches.
        for input_batches, target_batches in self._dataloader:
            if target_batches is None:
                for input_batch in input_batches:
                    yield input_batch, None
            else:
                yield from zip(input_batches, target_batches)

    def _collate(self, batch):
        """Apply the collate function to the input and, if present, the target."""
        _input_batch, _target_batch = utils.unzip(batch)

        # _input_batch is a list of size 1, where the only
        # element is the UsersInteractions batch
        collated_input = self._collate_fn(_input_batch[0])

        if _target_batch[0] is None:
            collated_target = None
        else:
            collated_target = self._collate_fn(_target_batch[0])

        return collated_input, collated_target

    def __iter__(self):
        if self._use_default_data_generator:
            return self._default_data_generator()

        return self._dataloader.__iter__()

    def __len__(self):
        """
        Return the number of items produced by one pass over this loader.

        Users are first grouped into sampling groups of ``num_sampling_users``
        users, the last group holding whatever remains. With a custom
        ``collate_fn`` one item is produced per group. With the default
        collator every group is split independently into batches of
        ``batch_size``, so each group may end with a smaller batch.
        """
        num_users = len(self.dataset)
        full_groups, leftover_users = divmod(num_users, self.num_sampling_users)

        if not self._use_default_data_generator:
            return full_groups + (1 if leftover_users else 0)

        batch_size = self.batch_collator.batch_size

        def batches_in_group(group_size):
            # Integer ceiling division.
            return -(-group_size // batch_size)

        return (full_groups * batches_in_group(self.num_sampling_users)
                + batches_in_group(leftover_users))
class Batch:
    """
    A sparse mini-batch of user-item interactions.

    Args:
        users (torch.LongTensor): users that are in the batch
        items (torch.LongTensor): items that are in the batch. :func:`BatchCollator.collate`
            passes ``None`` here when negative sampling is disabled.
        indices (torch.LongTensor): the indices of the interactions in the sparse matrix
        values (torch.Tensor): the values of the interactions. :func:`BatchCollator.collate`
            builds them as a ``torch.FloatTensor``.
        size (torch.Size): the size of the sparse interactions matrix
    """

    def __init__(self, users, items,
                 indices, values, size):
        # Which users / items the batch refers to.
        self.users, self.items = users, items

        # Sparse description of the interactions: coordinates, values and shape.
        self.indices, self.values = indices, values
        self.size = size
class BatchCollator:
    """
    Collator of :class:`UsersInteractions`. It collates the users interactions into multiple :class:`Batch`
    based on ``batch_size``.

    Args:
        batch_size (int): number of samples per batch
        negative_sampling (bool, optional): whether to apply mini-batch based negative sampling or not.
    """

    def __init__(self, batch_size, negative_sampling=False):
        self.batch_size = batch_size
        self.negative_sampling = negative_sampling

    @staticmethod
    def _nonzero_coordinates(interactions_matrix):
        """Return row-sorted ``(rows, cols, values)`` arrays of the non-zero entries."""
        # Going through CSR first makes the COO entries come out ordered by row.
        coo = interactions_matrix.tocsr().tocoo()

        # Explicitly stored zeros are dropped from all three arrays at once so
        # that coordinates and values always stay aligned.
        keep = coo.data != 0
        return coo.row[keep], coo.col[keep], coo.data[keep]

    def collate(self, users_interactions):
        """
        Collates :class:`UsersInteractions` into batches of size ``batch_size``.

        Args:
            users_interactions (UsersInteractions): a :class:`UsersInteractions`.

        Returns:
            list[Batch]: list of batches. Every batch covers up to ``batch_size``
            consecutive users; its ``indices`` hold the user position within the
            batch (first row) and the item index (second row).
        """
        interactions_matrix = users_interactions.interactions_matrix
        num_users = interactions_matrix.shape[0]

        users_inds, items_inds, inter_vals = self._nonzero_coordinates(interactions_matrix)

        if self.negative_sampling:
            # The positive item ids in the batch
            # This is simply equivalent to only selecting the non-zero columns
            # in the sparse matrix; item indices are remapped to positions
            # within those selected columns.
            batch_items, items_inds = np.unique(items_inds, return_inverse=True)
            vector_dim = len(batch_items)
            batch_items = torch.as_tensor(batch_items, dtype=torch.long)
        else:
            vector_dim = interactions_matrix.shape[1]
            batch_items = None

        batch_users = torch.as_tensor(users_interactions.users, dtype=torch.long)

        slices = []
        for offset in range(0, num_users, self.batch_size):
            stop = min(offset + self.batch_size, num_users)

            # The user indices are sorted, so the entries of the users in
            # [offset, stop) form one contiguous run.
            first, last = np.searchsorted(users_inds, [offset, stop])

            # Stack into a single (2, nnz) array before handing it to torch;
            # user indices are shifted to be relative to the slice.
            indices = torch.as_tensor(
                np.stack([users_inds[first:last] - offset, items_inds[first:last]]),
                dtype=torch.long)
            values = torch.as_tensor(inter_vals[first:last], dtype=torch.float32)

            slices.append(Batch(items=batch_items, users=batch_users[offset:stop],
                                indices=indices, values=values,
                                size=torch.Size([stop - offset, vector_dim])))

        return slices