35 changes: 29 additions & 6 deletions src/cnlpt/CnlpModelForClassification.py
@@ -287,7 +287,7 @@ def __init__(
self.encoder = encoder_model.from_pretrained(config.encoder_name)
# part of the motivation for leaving this
# logic alone for character level models is that
# at the time of writing, CANINE and Flair are the only game in town.
# at the time of writing, CANINE and Flair are the only game in town.
# CANINE's hashable embeddings for unicode codepoints allows for
# additional parameterization, which rn doesn't seem so relevant
if not config.character_level:
@@ -329,12 +329,12 @@ def __init__(
head_size=config.rel_attention_head_dims,
)
if config.relations[task_name]:
hidden_size = config.num_rel_attention_heads
if config.use_prior_tasks:
hidden_size += total_prev_task_labels
# hidden_size = config.num_rel_attention_heads
# if config.use_prior_tasks:
# hidden_size += total_prev_task_labels

self.classifiers[task_name] = ClassificationHead(
config, task_num_labels, hidden_size=hidden_size
config, task_num_labels,
)
else:
self.classifiers[task_name] = ClassificationHead(
@@ -491,6 +491,30 @@ def compute_loss(
)
state["loss"] += task_weight * task_loss

def remove_task_classifiers(self, tasks: list[str] = None):
if tasks is None:
self.classifiers = nn.ModuleDict()
self.tasks = []
self.class_weights = {}
else:
for task in tasks:
self.classifiers.pop(task)
self.tasks.remove(task)
self.class_weights.pop(task)

def add_task_classifier(self, task_name: str, label_dictionary: dict[str, list]):
self.tasks.append(task_name)
self.classifiers[task_name] = ClassificationHead(
self.config, len(label_dictionary)
)
self.label_dictionary[task_name] = label_dictionary

def set_class_weights(self, class_weights: Union[list[float], None] = None):
if class_weights is None:
self.class_weights = {x: None for x in self.label_dictionary.keys()}
else:
self.class_weights = class_weights

def forward(
self,
input_ids=None,
@@ -531,7 +555,6 @@ def forward(

Returns: (`transformers.SequenceClassifierOutput`) the output of the model
"""

kwargs = generalize_encoder_forward_kwargs(
self.encoder,
attention_mask=attention_mask,
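The three new methods above (remove_task_classifiers, add_task_classifier, set_class_weights) let a fine-tuned model drop its existing classification heads and attach a fresh one for a new task. A minimal usage sketch, assuming a cnlp-transformers checkpoint; the checkpoint path, task name, and label set are hypothetical, and the label_dictionary argument is assumed to mirror the per-task entries already stored in model.label_dictionary (its keys are counted as the number of labels):

from cnlpt.CnlpModelForClassification import CnlpModelForClassification

# Hypothetical checkpoint path.
model = CnlpModelForClassification.from_pretrained("path/to/checkpoint")

model.remove_task_classifiers()                    # drop every head, task name, and class weight
new_labels = {"present": [], "absent": [], "possible": []}  # hypothetical labels; keys give num_labels
model.add_task_classifier("negation", new_labels)  # fresh ClassificationHead with 3 outputs
model.set_class_weights()                          # no explicit weights: every label_dictionary entry maps to None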
36 changes: 18 additions & 18 deletions src/cnlpt/cnlp_args.py
@@ -298,18 +298,18 @@ class DaptArguments:
"help": "Pretrained tokenizer name or path if not the same as model_name"
},
)
output_dir: Union[str, None] = field(
default=None, metadata={"help": "Directory path to write trained model to."}
)
overwrite_output_dir: bool = field(
default=False,
metadata={
"help": (
"Overwrite the content of the output directory. "
"Use this to continue training if output_dir points to a checkpoint directory."
)
},
)
# output_dir: Union[str, None] = field(
# default=None, metadata={"help": "Directory path to write trained model to."}
# )
# overwrite_output_dir: bool = field(
# default=False,
# metadata={
# "help": (
# "Overwrite the content of the output directory. "
# "Use this to continue training if output_dir points to a checkpoint directory."
# )
# },
# )
data_dir: Union[str, None] = field(
default=None, metadata={"help": "The data dir for domain-adaptive pretraining."}
)
@@ -333,12 +333,12 @@ class DaptArguments:
default=0.2,
metadata={"help": "The test split proportion for domain-adaptive pretraining."},
)
seed: int = field(
default=42,
metadata={
"help": "The random seed to use for a train/test split for domain-adaptive pretraining (requires --dapt-encoder)."
},
)
# seed: int = field(
# default=42,
# metadata={
# "help": "The random seed to use for a train/test split for domain-adaptive pretraining (requires --dapt-encoder)."
# },
# )
no_eval: bool = field(
default=False,
metadata={"help": "Don't split into train and test; just pretrain."},
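With output_dir, overwrite_output_dir, and seed commented out of DaptArguments, those settings now come from the training-arguments dataclass that dapt.py parses alongside it (see the dapt.py changes below). A minimal parsing sketch, assuming CnlpTrainingArguments extends transformers.TrainingArguments so the standard flag names apply; all values are placeholders:

from transformers import HfArgumentParser
from cnlpt.cnlp_args import CnlpTrainingArguments, DaptArguments

parser = HfArgumentParser((DaptArguments, CnlpTrainingArguments))
dapt_args, training_args = parser.parse_args_into_dataclasses([
    "--encoder_name", "bert-base-uncased",  # placeholder encoder
    "--data_dir", "/path/to/notes",         # placeholder data dir
    "--output_dir", "/tmp/dapt_out",        # now a training argument, not a DaptArgument
    "--seed", "13",
])
assert training_args.seed == 13 and dapt_args.data_dir == "/path/to/notes"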
7 changes: 3 additions & 4 deletions src/cnlpt/cnlp_data.py
@@ -1155,10 +1155,9 @@ def __init__(
batched=True,
remove_columns=list(remove_columns),
)
dataset = dataset.map(
functools.partial(group_texts, self.args.chunk_size),
batched=True,
)

dataset = dataset.remove_columns("word_ids")


if isinstance(dataset, (DatasetDict, IterableDatasetDict)) or args.no_eval:
self.dataset = dataset
2 changes: 1 addition & 1 deletion src/cnlpt/cnlp_processors.py
@@ -171,7 +171,7 @@ def __init__(self, data_dir: str, tasks: set[str] = None, max_train_items=-1):
else:
sep = "\t"

self.dataset = load_dataset("csv", sep=sep, data_files=data_files)
self.dataset = load_dataset("csv", sep=sep, data_files=data_files, keep_default_na=False)

## find out what tasks are available to this dataset, and see the overlap with what the
## user specified at the cli, remove those tasks so we don't also get them from other datasets
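load_dataset forwards keep_default_na to the underlying CSV reader (pandas.read_csv), so literal label strings such as "None" or "NA" and empty fields are kept as strings instead of being parsed into null values. A small sketch of the effect; the file name and column names are made up:

from datasets import load_dataset

# With the default NA handling, a "None" label would come back as a null value;
# keep_default_na=False preserves it as the literal string "None".
ds = load_dataset("csv", sep="\t", data_files={"train": "train.tsv"}, keep_default_na=False)
print(ds["train"][0])  # e.g. {"text": "...", "negation": "None"} rather than {"negation": None}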
87 changes: 74 additions & 13 deletions src/cnlpt/dapt.py
@@ -2,26 +2,83 @@
Domain-adaptive pretraining (see DAPT.md for details)
"""

import logging
import os
import sys
from typing import Any, Union

from transformers import (
AutoConfig,
AutoModelForMaskedLM,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,

[GitHub Actions / ruff — check failure: src/cnlpt/dapt.py:16:5: F401 `transformers.TrainingArguments` imported but unused]
set_seed,
)

from .cnlp_args import DaptArguments
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MaskedLMOutput
from transformers.modeling_utils import PreTrainedModel

from .CnlpModelForClassification import CnlpConfig, freeze_encoder_weights, generalize_encoder_forward_kwargs
from .cnlp_args import DaptArguments, CnlpTrainingArguments
from .cnlp_data import DaptDataset

logger = logging.getLogger(__name__)

[GitHub Actions / ruff — check failure: src/cnlpt/dapt.py:5:1: I001 Import block is un-sorted or un-formatted]


class DaptModel(PreTrainedModel):
base_model_prefix = "cnlpt"
config_class = CnlpConfig

def __init__(
self,
config: config_class,
freeze: float = -1.0,
):
super().__init__(config)
encoder_config = AutoConfig.from_pretrained(config._name_or_path)
encoder_config.vocab_size = config.vocab_size
config.encoder_config = encoder_config.to_dict()
model = AutoModelForMaskedLM.from_config(encoder_config)
self.encoder = model.from_pretrained(config._name_or_path)
# if not config.character_level:
self.encoder.resize_token_embeddings(encoder_config.vocab_size)

if freeze > 0:
freeze_encoder_weights(self.encoder.bert.encoder, freeze)

def forward(
self,
input_ids,
token_type_ids,
attention_mask,
labels,
):
kwargs = generalize_encoder_forward_kwargs(
self.encoder,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
output_hidden_states=True,
return_dict=True,
)

outputs = self.encoder(input_ids, **kwargs)
logits = outputs.logits

        loss = None  # avoid referencing an unbound local below when labels is None
        if labels is not None:
loss_fn = CrossEntropyLoss()
loss = loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))

return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


def main(
json_file: Union[str, None] = None, json_obj: Union[dict[str, Any], None] = None
):
@@ -39,30 +39,31 @@
:rtype: typing.Dict[str, typing.Dict[str, typing.Any]]
:return: the evaluation results (will be empty if ``--do_eval`` not passed)
"""
parser = HfArgumentParser((DaptArguments,))
parser = HfArgumentParser((DaptArguments, CnlpTrainingArguments))
dapt_args: DaptArguments
training_args: CnlpTrainingArguments

if json_file is not None and json_obj is not None:
raise ValueError("cannot specify json_file and json_obj")

if json_file is not None:
(dapt_args,) = parser.parse_json_file(json_file=json_file)
(dapt_args, training_args) = parser.parse_json_file(json_file=json_file)
elif json_obj is not None:
(dapt_args,) = parser.parse_dict(json_obj)
(dapt_args, training_args) = parser.parse_dict(json_obj)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
(dapt_args,) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
(dapt_args, training_args) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
(dapt_args,) = parser.parse_args_into_dataclasses()
(dapt_args, training_args) = parser.parse_args_into_dataclasses()

if (
os.path.exists(dapt_args.output_dir)
and os.listdir(dapt_args.output_dir)
and not dapt_args.overwrite_output_dir
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({dapt_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)

# Setup logging
@@ -85,9 +85,10 @@
# logger.info("Model parameters %s" % model_args)

logger.info(f"Domain adaptation parameters {dapt_args}")
logger.info(f"Training arguments {training_args}")

# Set seed
set_seed(dapt_args.seed)
set_seed(training_args.seed)

# Load tokenizer: Need this first for loading the datasets
tokenizer = AutoTokenizer.from_pretrained(
@@ -101,13 +101,15 @@
# additional_special_tokens=['<e>', '</e>', '<a1>', '</a1>', '<a2>', '</a2>', '<cr>', '<neg>']
)

model = AutoModelForMaskedLM.from_pretrained(dapt_args.encoder_name)
# model = AutoModelForMaskedLM.from_pretrained(dapt_args.encoder_name)
config = AutoConfig.from_pretrained(dapt_args.encoder_name)
model = DaptModel(config, freeze=training_args.freeze)

dataset = DaptDataset(dapt_args, tokenizer=tokenizer)

trainer = Trainer(
model=model,
args=TrainingArguments(output_dir=dapt_args.output_dir),
args=training_args,
train_dataset=dataset.train,
eval_dataset=dataset.test if not dapt_args.no_eval else None,
data_collator=dataset.data_collator,
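Putting the dapt.py changes together, the entry point now parses both argument groups, so a programmatic call passes one flat dict whose keys are split between DaptArguments and CnlpTrainingArguments. A hedged sketch; every path and model name below is a placeholder, and freeze is assumed to be the existing cnlpt training option that is forwarded to DaptModel:

from cnlpt import dapt

dapt.main(json_obj={
    "encoder_name": "emilyalsentzer/Bio_ClinicalBERT",  # placeholder encoder
    "tokenizer_name": "emilyalsentzer/Bio_ClinicalBERT",
    "data_dir": "/path/to/unlabeled_notes",
    "output_dir": "/tmp/dapt_model",        # parsed into CnlpTrainingArguments
    "overwrite_output_dir": True,
    "seed": 42,
    "freeze": 0.5,                          # fraction of encoder layers to freeze (assumed semantics)
})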