from typing import Optional

import torch
from torch.nn import CosineEmbeddingLoss, CosineSimilarity
from transformers import BertForSequenceClassification, BertModel


class SentenceEncoderModel(BertForSequenceClassification):
    """Dual ("Siamese") sentence encoder built on a shared BERT backbone.

    Expects batched *pairs* of sentences: ``input_ids`` (and, when given,
    ``attention_mask`` / ``token_type_ids``) has shape ``(batch, 2, seq_len)``,
    where index 0 along dim 1 is sentence A and index 1 is sentence B — TODO
    confirm against the data collator.  Both sentences are encoded with the
    same BERT weights; the cosine similarity of the two pooled outputs is the
    model's score.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # Replace the backbone created by the parent class with a fresh
        # BertModel, then (re-)initialize weights as transformers requires.
        self.bert = BertModel(config)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode a batch of sentence pairs and score them by cosine similarity.

        Args:
            input_ids / attention_mask / token_type_ids: paired tensors of
                shape ``(batch, 2, seq_len)``; the optional masks/type ids may
                be ``None`` (BertModel handles the ``None`` case itself).
            labels: per-pair targets for ``CosineEmbeddingLoss``.
                NOTE(review): that loss expects targets in {1, -1} per the
                PyTorch docs — verify the dataset encodes labels that way.

        Returns:
            ``(cos,)`` — the per-pair cosine similarities — or
            ``(loss, cos)`` when ``labels`` is provided.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        def _split_pair(t: Optional[torch.Tensor]):
            # Split a (batch, 2, seq_len) tensor into the A/B halves.
            # Guard None: the original code crashed with TypeError when an
            # optional tensor (e.g. token_type_ids) was omitted by the caller.
            if t is None:
                return None, None
            return t[:, 0], t[:, 1]

        senA_input_ids, senB_input_ids = _split_pair(input_ids)
        senA_attention_mask, senB_attention_mask = _split_pair(attention_mask)
        senA_token_type_ids, senB_token_type_ids = _split_pair(token_type_ids)

        def _encode(ids, mask, type_ids):
            # Run one sentence of the pair through the shared backbone and
            # return the pooled ([CLS]) representation (index 1 of the output).
            outputs = self.bert(
                ids,
                attention_mask=mask,
                token_type_ids=type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            return outputs[1]

        senA_pooled_output = _encode(senA_input_ids, senA_attention_mask, senA_token_type_ids)
        senB_pooled_output = _encode(senB_input_ids, senB_attention_mask, senB_token_type_ids)

        # Per-pair cosine similarity of the two pooled embeddings.
        cos = CosineSimilarity()(senA_pooled_output, senB_pooled_output)

        loss = None
        if labels is not None:
            # Margin 0.3: dissimilar pairs are only penalized while their
            # cosine similarity stays above the margin.
            loss_fct = CosineEmbeddingLoss(0.3)
            loss = loss_fct(senA_pooled_output, senB_pooled_output, labels)

        output = (cos,)
        return ((loss,) + output) if loss is not None else output