|
from transformers.models.bert import BertConfig |
|
from transformers.models.xlm_roberta import XLMRobertaConfig |
|
|
|
|
|
def _init_function( |
|
self, |
|
entity_vocab_size: int | None = 10000, |
|
entity_embedding_size: int = 768, |
|
entity_fusion_method: str = "multihead_attention", |
|
use_entity_position_embeddings: bool = True, |
|
entity_fusion_activation: str = "softmax", |
|
num_entity_fusion_attention_heads: int = 12, |
|
similarity_function: str = "dot", |
|
similarity_temperature: float = 1.0, |
|
*args, |
|
**kwargs, |
|
): |
|
self.entity_vocab_size = entity_vocab_size |
|
self.entity_embedding_size = entity_embedding_size |
|
self.entity_fusion_method = entity_fusion_method |
|
self.use_entity_position_embeddings = use_entity_position_embeddings |
|
self.entity_fusion_activation = entity_fusion_activation |
|
self.num_entity_fusion_attention_heads = num_entity_fusion_attention_heads |
|
self.similarity_function = similarity_function |
|
self.similarity_temperature = similarity_temperature |
|
|
|
super(self.__class__, self).__init__(*args, **kwargs) |
|
|
|
|
|
class KPRConfigForBert(BertConfig):
    """``BertConfig`` subclass registered under the ``"kpr-bert"`` model type.

    Construction is handled by the module-level ``_init_function``, which
    stores the entity-related settings before delegating to the parent
    initializer.
    """

    # Identifier string for this configuration class.
    model_type = "kpr-bert"

    # Reuse the initializer shared by all KPR config classes in this module.
    __init__ = _init_function
|
|
|
|
|
class KPRConfigForXLMRoberta(XLMRobertaConfig):
    """``XLMRobertaConfig`` subclass registered under ``"kpr-xlm-roberta"``.

    Construction is handled by the module-level ``_init_function``, which
    stores the entity-related settings before delegating to the parent
    initializer.
    """

    # Identifier string for this configuration class.
    model_type = "kpr-xlm-roberta"

    # Reuse the initializer shared by all KPR config classes in this module.
    __init__ = _init_function
|
|