16 changes: 16 additions & 0 deletions README.md
@@ -137,6 +137,22 @@ For a demo of how to run the system in colab: [![Open In Colab](https://colab.re
For chemical NER and the following for gene NER (`GENE-Y` are normalizable gene names and `GENE-N` are non-normalizable):
```gene_ner = {'acc': 0.9835720382010852, 'token_f1': [0.5503875968992249, 0.8223125230882896, 0.0, 0.0, 0.9949847518170479], 'f1': 0.7477207783371888, 'report': '\n precision recall f1-score support\n\n GENE-N 0.70 0.45 0.55 2355\n GENE-Y 0.76 0.88 0.82 5013\n\n micro avg 0.75 0.75 0.75 7368\n macro avg 0.73 0.67 0.68 7368\nweighted avg 0.74 0.75 0.73 7368\n'}```

### Fine-tuning for concept normalization: End-to-end example
1. Prepare the concept normalization input data (train.tsv, dev.tsv, and test.tsv) using the following tab-separated (.tsv) format:

| text | conceptnorm |
| --- | --- |
| <e> pleural effusion </e> | C0032227 |
| <e> pulmonary consolidation </e> | C0521530 |
| <e> aorta tortuous </e> | CUI-less |

2. Prepare an ontology CUI text file (.txt); each line of the file is a single CUI.

3. Prepare a concept embeddings file (concept_embeddings.npy) and save it into a folder $concept_norm_path: a single matrix with one row per CUI, where the row order follows the order of the ontology CUI text file. The model also looks for a cuiless_threshold.txt file in the same folder holding the CUI-less similarity threshold. See the sketch after this list for one way to generate these files.

4. Fine-tune with something like:
```python -m cnlpt.train_system --do_train --do_eval --task_name conceptnorm --data_dir cnlp_concept_norm/ --encoder_name cambridgeltl/SapBERT-from-PubMedBERT-fulltext-mean-token --output_dir temp/ --concept_norm_path $concept_norm_path --overwrite_output_dir --cache cache --token true --num_train_epochs 5 --learning_rate 3e-5 --per_device_train_batch_size 64 --max_seq_length 16 --layer 12 --seed 24 --evals_per_epoch 1```
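
The files in $concept_norm_path can be produced however you like; the sketch below is one possible way to do it with the same SapBERT encoder, assuming a `cui_to_term` dictionary mapping each CUI to a representative term (that dictionary, the output folder name, and the pooling choice are illustrative, not part of the package):

```
# Illustrative sketch: build ontology_cui.txt, concept_embeddings.npy and
# cuiless_threshold.txt for $concept_norm_path. cui_to_term is assumed to be
# a {CUI: preferred term} dict that you supply from your ontology.
import os
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

encoder = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext-mean-token"
tokenizer = AutoTokenizer.from_pretrained(encoder)
model = AutoModel.from_pretrained(encoder).eval()

cui_to_term = {"C0032227": "pleural effusion", "C0521530": "pulmonary consolidation"}
cuis = list(cui_to_term)

embeddings = []
with torch.no_grad():
    for cui in cuis:
        enc = tokenizer(cui_to_term[cui], return_tensors="pt", truncation=True, max_length=16)
        hidden = model(**enc).last_hidden_state              # (1, seq_len, hidden_size)
        embeddings.append(hidden.mean(dim=1).squeeze(0).numpy())  # mean-token pooling

os.makedirs("concept_norm_path", exist_ok=True)
# Row i of the matrix must correspond to line i of ontology_cui.txt
np.save("concept_norm_path/concept_embeddings.npy", np.stack(embeddings))
with open("concept_norm_path/ontology_cui.txt", "w") as f:
    f.write("\n".join(cuis))
# 0.35 mirrors the default CUI-less threshold used when none is precomputed
with open("concept_norm_path/cuiless_threshold.txt", "w") as f:
    f.write("0.35")
```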

### Fine-tuning options
Run ```python -m cnlpt.train_system -h``` to see all the available options. In addition to inherited Huggingface Transformers options, there are options to do the following:
* Run simple baselines (use ``--model cnn --tokenizer_name roberta-base``; since there is no HF model, you must specify the tokenizer explicitly)
1 change: 1 addition & 0 deletions setup.cfg
@@ -43,6 +43,7 @@ install_requires =

[options.entry_points]
console_scripts =
cnlpt_conceptnorm_rest = cnlpt.api.concept_norm_rest:rest
cnlpt_dtr_rest = cnlpt.api.dtr_rest:rest
cnlpt_event_rest = cnlpt.api.event_rest:rest
cnlpt_negation_rest = cnlpt.api.negation_rest:rest
74 changes: 64 additions & 10 deletions src/cnlpt/CnlpModelForClassification.py
@@ -13,11 +13,13 @@
import torch
from torch import nn
import logging
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import CrossEntropyLoss, MSELoss, Parameter
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.nn.functional import softmax, relu
from torch.nn.functional import softmax, relu, normalize
import math
import random
import numpy as np
import os

logger = logging.getLogger(__name__)

@@ -33,9 +35,50 @@ def forward(self, features, *kwargs):
x = self.out_proj(x)
return x

class CosineLayer(nn.Module):
def __init__(
self,
config,
concept_dim=88150
):
super(CosineLayer, self).__init__()

self.dropout = nn.Dropout(0.1)
self.cos = nn.CosineSimilarity(dim=-1)
concept_dims = (concept_dim,config.hidden_size)

if config.concept_norm is not None:
weights_matrix = np.load(os.path.join(config.concept_norm,
"concept_embeddings.npy")).astype(np.float32)
self.weight = Parameter(torch.from_numpy(weights_matrix),
requires_grad=True)
threshold_value = np.loadtxt(
os.path.join(config.concept_norm, "cuiless_threshold.txt")).astype(np.float32)
self.threshold = Parameter(torch.tensor(threshold_value),
requires_grad=False)
else:
self.weight = Parameter(torch.rand(concept_dims),
requires_grad=True)
torch.nn.init.xavier_uniform_(self.weight)
self.threshold = Parameter(torch.tensor(0.35), requires_grad=True)

def forward(self, features):
batch_size, fea_size = features.shape
# cosine similarity between each mention representation and every concept embedding
features_norm = normalize(features)
weight_norm = normalize(self.weight)
sim_mt = torch.mm(features_norm, weight_norm.transpose(0, 1))
# append a constant CUI-less column equal to the threshold, so mentions whose best
# concept similarity falls below it are mapped to CUI-less
cui_less_score = torch.full((batch_size, 1), 1).to(
features.device) * self.threshold.to(features.device)
similarity_score = torch.cat((sim_mt, cui_less_score), 1)
# if self.config.finetuning_task[task_ind] == "conceptnorm":
# #### TODO add scaling as a hyper-parameter for concept normalization
scaling = 0.03
similarity_score = similarity_score / scaling
return similarity_score


class RepresentationProjectionLayer(nn.Module):
def __init__(self, config, layer=10, tokens=False, tagger=False, relations=False, num_attention_heads=-1, head_size=64):
def __init__(self, config, layer=10, tokens=False, tagger=False, relations=False, skip_projection=False, num_attention_heads=-1, head_size=64):
super().__init__()
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if relations:
@@ -47,6 +90,7 @@ def __init__(self, config, layer=10, tokens=False, tagger=False, relations=False
self.tokens = tokens
self.tagger = tagger
self.relations = relations
self.skip_projection = skip_projection
self.hidden_size = config.hidden_size

if num_attention_heads <= 0 and relations:
@@ -95,9 +139,13 @@ def forward(self, features, event_tokens, **kwargs):
# take <s> token (equiv. to [CLS])
x = features[self.layer_to_use][..., 0, :]

x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
# for normal classification we pass through a dense layer, for cosine layer
# classification we just grab the representation directly:
if not self.skip_projection:
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)

return x


@@ -131,6 +179,7 @@ def __init__(
tagger = [False],
relations = [False],
use_prior_tasks=False,
concept_norm=None,
**kwargs
):
super().__init__(**kwargs)
@@ -146,6 +195,7 @@ def __init__(
self.use_prior_tasks = use_prior_tasks
self.encoder_name = encoder_name
self.encoder_config = AutoConfig.from_pretrained(encoder_name).to_dict()
self.concept_norm = concept_norm
if encoder_name.startswith('distilbert'):
self.hidden_dropout_prob = self.encoder_config['dropout']
self.hidden_size = self.encoder_config['dim']
@@ -222,13 +272,17 @@ def __init__(self,
self.classifiers = nn.ModuleList()
total_prev_task_labels = 0
for task_ind,task_num_labels in enumerate(self.num_labels):
self.feature_extractors.append(RepresentationProjectionLayer(config, layer=config.layer, tokens=config.tokens, tagger=config.tagger[task_ind], relations=config.relations[task_ind], num_attention_heads=config.num_rel_attention_heads, head_size=config.rel_attention_head_dims))
conceptnorm = config.finetuning_task[task_ind] == "conceptnorm"
self.feature_extractors.append(RepresentationProjectionLayer(config, layer=config.layer, tokens=config.tokens, tagger=config.tagger[task_ind], relations=config.relations[task_ind], skip_projection=conceptnorm, num_attention_heads=config.num_rel_attention_heads, head_size=config.rel_attention_head_dims))
if config.relations[task_ind]:
hidden_size = config.num_rel_attention_heads
if config.use_prior_tasks:
hidden_size += total_prev_task_labels

self.classifiers.append(ClassificationHead(config, task_num_labels, hidden_size=hidden_size))
elif conceptnorm:
self.classifiers.append(CosineLayer(config,
concept_dim=task_num_labels - 1))
else:
self.classifiers.append(ClassificationHead(config, task_num_labels))
total_prev_task_labels += task_num_labels
@@ -424,7 +478,7 @@ def forward(
labels=None,
output_attentions=None,
output_hidden_states=None,
event_tokens=None,
event_mask=None,
):
r"""
Forward method.
@@ -449,7 +503,7 @@ def forward(
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers.
output_hidden_states: not used.
event_tokens: a mask defining which tokens in the input are to be averaged for input to classifier head; only used when self.tokens==True.
event_mask: a mask defining which tokens in the input are to be averaged for input to classifier head; only used when self.tokens==True.

Returns: (`transformers.SequenceClassifierOutput`) the output of the model
"""
@@ -480,7 +534,7 @@ def forward(
)

for task_ind,task_num_labels in enumerate(self.num_labels):
features = self.feature_extractors[task_ind](outputs.hidden_states, event_tokens)
features = self.feature_extractors[task_ind](outputs.hidden_states, event_mask)
if self.use_prior_tasks:
# note: this specific way of incorporating previous logits doesn't help in my experiments with thyme/clinical tempeval
if self.relations[task_ind]:
5 changes: 5 additions & 0 deletions src/cnlpt/api/cnlp_rest.py
@@ -1,6 +1,7 @@

# Core python imports
import os
import os.path

# FastAPI imports
from pydantic import BaseModel
@@ -70,6 +71,10 @@ def initialize_cnlpt_model(app, model_name, cuda=True, batch_size=8):
AutoModel.register(CnlpConfig, CnlpModelForClassification)

config = AutoConfig.from_pretrained(model_name)

# concept normalization models keep their concept files one directory above the checkpoint
if 'concept_norm' in config.__dict__:
config.concept_norm = os.path.join(model_name, '..')

app.state.tokenizer = AutoTokenizer.from_pretrained(model_name,
config=config)
model = CnlpModelForClassification.from_pretrained(model_name, cache_dir=os.getenv('HF_CACHE'), config=config)
120 changes: 120 additions & 0 deletions src/cnlpt/api/concept_norm_rest.py
@@ -0,0 +1,120 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from fastapi import FastAPI
from pydantic import BaseModel

from typing import List, Tuple, Dict

# Modeling imports
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
)

# from .api.cnlp_rest import get_dataset
from datasets import Dataset

from ..CnlpModelForClassification import CnlpModelForClassification, CnlpConfig
from .cnlp_rest import get_dataset, initialize_cnlpt_model
import numpy as np
import torch

import logging, os, json
from time import time

app = FastAPI()
model_name = "/lab-share/CHIP-Savova-e2/Public/resources/cnlpt/concept_norm/share/checkpoint-57456/"
logger = logging.getLogger('Concept_Normalization_REST_Processor')
logger.setLevel(logging.DEBUG)

task = 'conceptnorm'
with open(os.path.join(model_name, "../ontology_cui.txt"), 'r') as cui_file:
labels = json.load(cui_file)

labels = labels + ["CUI-less"]

max_length = 32

class Entity(BaseModel):
''' entity_text: the text of a single entity mention to normalize '''
entity_text: str


class ConceptNormResults(BaseModel):
''' statuses: list of predicted CUIs (or "CUI-less"), one per input entity, in input order'''
statuses: List[str]

@app.on_event("startup")
async def startup_event():
initialize_cnlpt_model(app, model_name)

@app.post("/concept_norm/process")
async def process(entity: Entity):
text = entity.entity_text
logger.warning('Received an entity of length %d characters to process' % (len(text)))
instances = [text]
start_time = time()

dataset = get_dataset(instances, app.state.tokenizer, [labels,], [task,], max_length)
preproc_end = time()

output = app.state.trainer.predict(test_dataset=dataset)
predictions = output.predictions[0]
predictions = np.argmax(predictions, axis=1)

pred_end = time()

results = []
for ent_ind in range(len(dataset)):
results.append(labels[predictions[ent_ind]])

output = ConceptNormResults(statuses=results)

postproc_end = time()

preproc_time = preproc_end - start_time
pred_time = pred_end - preproc_end
postproc_time = postproc_end - pred_end

logger.warning("Pre-processing time: %f, processing time: %f, post-processing time %f" % (preproc_time, pred_time, postproc_time))

return output

@app.get("/conceptnorm/{test_str}")
async def test(test_str: str):
return {'argument': test_str}


def rest():
import argparse

parser = argparse.ArgumentParser(description='Run the http server for concept normalization')
parser.add_argument('-p', '--port', type=int, help='The port number to run the server on', default=8000)

args = parser.parse_args()

import uvicorn
uvicorn.run("cnlpt.api.concept_norm_rest:app", host='0.0.0.0', port=args.port, reload=True)

if __name__ == '__main__':
rest()
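
As a quick sanity check, here is a minimal client sketch for the new `/concept_norm/process` endpoint; the entity text is illustrative, and the port matches the server's default of 8000:

```
# Illustrative client sketch, not part of the PR. Assumes the server was started
# with the cnlpt_conceptnorm_rest console script (default port 8000).
import requests

resp = requests.post(
    "http://localhost:8000/concept_norm/process",
    json={"entity_text": "pleural effusion"},
)
resp.raise_for_status()
print(resp.json())  # e.g. {"statuses": ["C0032227"]}
```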