Skip to content

Commit

Permalink
Merge pull request #253 from ASUS-AICS/preprocessor-loading-dataframe
Browse files Browse the repository at this point in the history
Linear preprocessor loading dataframe.
  • Loading branch information
Gordon119 authored Feb 16, 2023
2 parents a9e2cae + 285f143 commit 5d47eeb
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions libmultilabel/linear/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,24 @@ def __init__(self, data_format: str) -> None:
"""Initializes the preprocessor.
Args:
data_format (str): The data format used. 'svm' for LibSVM format and 'txt' for LibMultiLabel format.
data_format (str): The data format used. 'svm' for LibSVM format, 'txt' for LibMultiLabel format in file and 'dataframe' for LibMultiLabel format in dataframe .
"""
if not data_format in {'txt', 'svm'}:
if not data_format in {'txt', 'svm', 'dataframe'}:
raise ValueError(f'unsupported data format {data_format}')

self.data_format = data_format

def load_data(self, training_file: str = None,
test_file: str = None,
def load_data(self, training_data: Union[str, pd.DataFrame] = None,
test_data: Union[str, pd.DataFrame] = None,
eval: bool = False,
label_file: str = None,
include_test_labels: bool = False,
remove_no_label_data: bool = False) -> 'dict[str, dict]':
"""Loads and preprocesses data.
Args:
training_file (str): Training data file. Ignored if eval is True. Defaults to None.
test_file (str): Test data file. Ignored if test_file doesn't exist. Defaults to None.
training_data (Union[str, pd.DataFrame]): Training data file or dataframe in LibMultiLabel format. Ignored if eval is True. Defaults to None.
test_data (Union[str, pd.DataFrame]): Test data file or dataframe in LibMultiLabel format. Ignored if test_data doesn't exist. Defaults to None.
eval (bool): If True, ignores training data and uses previously loaded state to preprocess test data.
label_file (str, optional): Path to a file holding all labels.
include_test_labels (bool, optional): Whether to include labels in the test dataset. Defaults to False.
Expand All @@ -58,41 +58,47 @@ def load_data(self, training_file: str = None,
with open(label_file, 'r') as fp:
self.classes = sorted([s.strip() for s in fp.readlines()])
else:
if test_file is None and include_test_labels:
if test_data is None and include_test_labels:
raise ValueError(
f'Specified the inclusion of test labels but test file does not exist')
self.classes = None
self.include_test_labels = include_test_labels

if self.data_format == 'txt':
data = self._load_txt(training_file, test_file, eval)
if self.data_format == 'txt' or 'dataframe':
data = self._load_libmultilabel(training_data, test_data, eval)
elif self.data_format == 'svm':
data = self._load_svm(training_file, test_file, eval)
data = self._load_svm(training_data, test_data, eval)

if 'train' in data:
num_labels = data['train']['y'].getnnz(axis=1)
num_no_label_data = np.count_nonzero(num_labels == 0)
if num_no_label_data > 0:
if remove_no_label_data:
logging.info(
f'Remove {num_no_label_data} instances that have no labels from {training_file}.',
f'Remove {num_no_label_data} instances that have no labels from {training_data}.',
extra={'collect': True})
data['train']['x'] = data['train']['x'][num_labels > 0]
data['train']['y'] = data['train']['y'][num_labels > 0]
else:
logging.info(
f'Keep {num_no_label_data} instances that have no labels from {training_file}.',
f'Keep {num_no_label_data} instances that have no labels from {training_data}.',
extra={'collect': True})

return data

def _load_txt(self, training_file, test_file, eval) -> 'dict[str, dict]':
def _load_libmultilabel(self, training_data, test_data, eval) -> 'dict[str, dict]':
datasets = defaultdict(dict)
if test_file is not None:
test = read_libmultilabel_format(test_file)
if test_data is not None:
if self.data_format == 'txt':
test_data = pd.read_csv(test_data, sep='\t', header=None,
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
test = read_libmultilabel_format(test_data)

if not eval:
train = read_libmultilabel_format(training_file)
if self.data_format == 'txt':
training_data = pd.read_csv(training_data, sep='\t', header=None,
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
train = read_libmultilabel_format(training_data)
self._generate_tfidf(train['text'])

if self.classes or not self.include_test_labels:
Expand All @@ -103,28 +109,28 @@ def _load_txt(self, training_file, test_file, eval) -> 'dict[str, dict]':
datasets['train']['y'] = self.binarizer.transform(
train['label']).astype('d')

if test_file is not None:
if test_data is not None:
datasets['test']['x'] = self.vectorizer.transform(test['text'])
datasets['test']['y'] = self.binarizer.transform(
test['label']).astype('d')

return dict(datasets)

def _load_svm(self, training_file, test_file, eval) -> 'dict[str, dict]':
def _load_svm(self, training_data, test_data, eval) -> 'dict[str, dict]':
datasets = defaultdict(dict)
if test_file is not None:
ty, tx = read_libsvm_format(test_file)
if test_data is not None:
ty, tx = read_libsvm_format(test_data)

if not eval:
y, x = read_libsvm_format(training_file)
y, x = read_libsvm_format(training_data)
if self.classes or not self.include_test_labels:
self._generate_label_mapping(y, self.classes)
else:
self._generate_label_mapping(y + ty)
datasets['train']['x'] = x
datasets['train']['y'] = self.binarizer.transform(y).astype('d')

if test_file is not None:
if test_data is not None:
datasets['test']['x'] = tx
datasets['test']['y'] = self.binarizer.transform(ty).astype('d')
return dict(datasets)
Expand All @@ -139,10 +145,8 @@ def _generate_label_mapping(self, labels, classes=None):
self.binarizer.fit(labels)


def read_libmultilabel_format(path: str) -> 'dict[str,list[str]]':
data = pd.read_csv(path, sep='\t', header=None,
dtype=str,
on_bad_lines='skip', quoting=csv.QUOTE_NONE).fillna('')
def read_libmultilabel_format(data: pd.DataFrame) -> 'dict[str,list[str]]':
data = data.astype(str)
if data.shape[1] == 2:
data.columns = ['label', 'text']
data = data.reset_index()
Expand All @@ -153,7 +157,6 @@ def read_libmultilabel_format(path: str) -> 'dict[str,list[str]]':
data['label'] = data['label'].map(lambda s: s.split())
return data.to_dict('list')


def read_libsvm_format(file_path: str) -> 'tuple[list[list[int]], sparse.csr_matrix]':
"""Read multi-label LIBSVM-format data.
Expand Down

0 comments on commit 5d47eeb

Please sign in to comment.