Skip to content

Commit e7a44ae

Browse files
authored
feat(transformers/models): add models of longt5, longformer, etc. (mindspore-lab#1234)
1 parent 90e2bd4 commit e7a44ae

21 files changed

Lines changed: 7110 additions & 1 deletion

mindone/transformers/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,16 @@
655655
LlavaOnevisionProcessor,
656656
LlavaOnevisionVideoProcessor,
657657
)
658+
from .models.longformer import (
659+
LongformerForMaskedLM,
660+
LongformerForMultipleChoice,
661+
LongformerForQuestionAnswering,
662+
LongformerForSequenceClassification,
663+
LongformerForTokenClassification,
664+
LongformerModel,
665+
LongformerPreTrainedModel,
666+
)
667+
from .models.longt5 import LongT5EncoderModel, LongT5ForConditionalGeneration, LongT5Model, LongT5PreTrainedModel
658668
from .models.luke import (
659669
LukeForEntityClassification,
660670
LukeForEntityPairClassification,
@@ -1020,6 +1030,7 @@
10201030
TapasModel,
10211031
TapasPreTrainedModel,
10221032
)
1033+
from .models.timesformer import TimesformerForVideoClassification, TimesformerModel, TimesformerPreTrainedModel
10231034
from .models.trocr import TrOCRForCausalLM, TrOCRPreTrainedModel
10241035
from .models.tvp import TvpForVideoGrounding, TvpModel, TvpPreTrainedModel
10251036
from .models.umt5 import (
@@ -1076,6 +1087,7 @@
10761087
from .models.vitpose import VitPoseForPoseEstimation, VitPosePreTrainedModel
10771088
from .models.vitpose_backbone import VitPoseBackbone, VitPoseBackbonePreTrainedModel
10781089
from .models.vits import VitsModel, VitsPreTrainedModel
1090+
from .models.vivit import VivitForVideoClassification, VivitModel, VivitPreTrainedModel
10791091
from .models.wav2vec2 import (
10801092
Wav2Vec2FeatureExtractor,
10811093
Wav2Vec2ForAudioFrameClassification,

mindone/transformers/activations.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,6 @@ def get_activation(activation_string):
225225
return ACT2FN[activation_string]
226226
else:
227227
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
228+
229+
230+
gelu = get_activation("gelu")

mindone/transformers/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@
102102
llava_next,
103103
llava_next_video,
104104
llava_onevision,
105+
longformer,
106+
longt5,
105107
luke,
106108
m2m_100,
107109
mamba,
@@ -169,6 +171,7 @@
169171
switch_transformers,
170172
t5,
171173
tapas,
174+
timesformer,
172175
trocr,
173176
tvp,
174177
umt5,
@@ -187,6 +190,7 @@
187190
vitpose,
188191
vitpose_backbone,
189192
vits,
193+
vivit,
190194
wav2vec2,
191195
x_clip,
192196
xlm_roberta,

mindone/transformers/models/auto/configuration_auto.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@
126126
("llava_next", "LlavaNextConfig"),
127127
("llava_next_video", "LlavaNextVideoConfig"),
128128
("llava_onevision", "LlavaOnevisionConfig"),
129+
("longformer", "LongformerConfig"),
130+
("longt5", "LongT5Config"),
129131
("luke", "LukeConfig"),
130132
("mamba", "MambaConfig"),
131133
("mamba2", "Mamba2Config"),
@@ -193,6 +195,7 @@
193195
("swin2sr", "Swin2SRConfig"),
194196
("t5", "T5Config"),
195197
("tapas", "TapasConfig"),
198+
("timesformer", "TimesformerConfig"),
196199
("trocr", "TrOCRConfig"),
197200
("tvp", "TvpConfig"),
198201
("umt5", "UMT5Config"),
@@ -209,6 +212,7 @@
209212
("vitdet", "VitDetConfig"),
210213
("vitpose", "VitPoseConfig"),
211214
("vitpose_backbone", "VitPoseBackboneConfig"),
215+
("vivit", "VivitConfig"),
212216
("wav2vec2", "Wav2Vec2Config"),
213217
("mvp", "MvpConfig"),
214218
("whisper", "WhisperConfig"),
@@ -330,6 +334,8 @@
330334
("llava_next", "LLaVA-NeXT"),
331335
("llava_next_video", "LLaVa-NeXT-Video"),
332336
("llava_onevision", "LLaVA-Onevision"),
337+
("longformer", "Longformer"),
338+
("longt5", "LongT5"),
333339
("mimi", "Mimi"),
334340
("mistral", "Mistral"),
335341
("mllama", "Mllama"),
@@ -401,6 +407,7 @@
401407
("t5", "T5"),
402408
("t5v1.1", "T5v1.1"),
403409
("tapas", "TAPAS"),
410+
("timesformer", "TimeSformer"),
404411
("trocr", "TrOCR"),
405412
("tvp", "TVP"),
406413
("umt5", "UMT5"),
@@ -417,6 +424,7 @@
417424
("vitdet", "VitDet"),
418425
("vitpose", "ViTPose"),
419426
("vitpose_backbone", "ViTPoseBackbone"),
427+
("vivit", "ViViT"),
420428
("wav2vec2", "Wav2Vec2"),
421429
("whisper", "Whisper"),
422430
("xclip", "X-CLIP"),

mindone/transformers/models/auto/modeling_auto.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@
119119
("levit", "LevitModel"),
120120
("lilt", "LiltModel"),
121121
("llama", "LlamaModel"),
122+
("longformer", "LongformerModel"),
123+
("longt5", "LongT5Model"),
122124
("luke", "LukeModel"),
123125
("m2m_100", "M2M100Model"),
124126
("mamba", "MambaModel"),
@@ -177,6 +179,7 @@
177179
("swin2sr", "Swin2SRModel"),
178180
("t5", "T5Model"),
179181
("tapas", "TapasModel"),
182+
("timesformer", "TimesformerModel"),
180183
("tvp", "TvpModel"),
181184
("umt5", "UMT5Model"),
182185
("unispeech", "UniSpeechModel"),
@@ -187,6 +190,7 @@
187190
("vit", "ViTModel"),
188191
("vit_msn", "ViTMSNModel"),
189192
("vitdet", "VitDetModel"),
193+
("vivit", "VivitModel"),
190194
("wav2vec2", "Wav2Vec2Model"),
191195
("whisper", "WhisperModel"),
192196
("xclip", "XCLIPModel"),
@@ -234,7 +238,11 @@
234238
("llava_next", "LlavaNextForConditionalGeneration"),
235239
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
236240
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
241+
("longformer", "LongformerForMaskedLM"),
237242
("luke", "LukeForMaskedLM"),
243+
("mobilebert", "MobileBertForPreTraining"),
244+
("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
245+
("roberta", "RobertaForMaskedLM"),
238246
("megatron-bert", "MegatronBertForPreTraining"),
239247
("mistral3", "Mistral3ForConditionalGeneration"),
240248
("mllama", "MllamaForConditionalGeneration"),
@@ -287,7 +295,10 @@
287295
("gpt2", "GPT2LMHeadModel"),
288296
("ibert", "IBertForMaskedLM"),
289297
("led", "LEDForConditionalGeneration"),
298+
("longformer", "LongformerForMaskedLM"),
299+
("longt5", "LongT5ForConditionalGeneration"),
290300
("luke", "LukeForMaskedLM"),
301+
("camembert", "CamembertForMaskedLM"),
291302
("roberta", "RobertaForMaskedLM"),
292303
("mamba", "MambaForCausalLM"),
293304
("mamba2", "Mamba2ForCausalLM"),
@@ -413,9 +424,11 @@
413424
("segformer", "SegformerModel"),
414425
("siglip_vision_model", "SiglipVisionModel"),
415426
("swin2sr", "Swin2SRModel"),
427+
("timesformer", "TimesformerModel"),
416428
("vit", "ViTModel"),
417429
("vit_msn", "ViTMSNModel"),
418430
("vitdet", "VitDetModel"),
431+
("vivit", "VivitModel"),
419432
("yolos", "YolosModel"),
420433
("zamba2", "Zamba2ForCausalLM"),
421434
]
@@ -503,7 +516,12 @@
503516
]
504517
)
505518

506-
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict()
519+
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
520+
[
521+
("timesformer", "TimesformerForVideoClassification"),
522+
("vivit", "VivitForVideoClassification"),
523+
]
524+
)
507525

508526
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
509527
[
@@ -584,6 +602,7 @@
584602
("electra", "ElectraForMaskedLM"),
585603
("funnel", "FunnelForMaskedLM"),
586604
("ibert", "IBertForMaskedLM"),
605+
("longformer", "LongformerForMaskedLM"),
587606
("luke", "LukeForMaskedLM"),
588607
("mobilebert", "MobileBertForMaskedLM"),
589608
("mpnet", "MPNetForMaskedLM"),
@@ -640,6 +659,7 @@
640659
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
641660
("fsmt", "FSMTForConditionalGeneration"),
642661
("led", "LEDForConditionalGeneration"),
662+
("longt5", "LongT5ForConditionalGeneration"),
643663
("m2m_100", "M2M100ForConditionalGeneration"),
644664
("mvp", "MvpForConditionalGeneration"),
645665
("nllb-moe", "NllbMoeForConditionalGeneration"),
@@ -698,6 +718,7 @@
698718
("canine", "CanineForSequenceClassification"),
699719
("lilt", "LiltForSequenceClassification"),
700720
("llama", "LlamaForSequenceClassification"),
721+
("longformer", "LongformerForSequenceClassification"),
701722
("opt", "OPTForSequenceClassification"),
702723
("persimmon", "PersimmonForSequenceClassification"),
703724
("mbart", "MBartForSequenceClassification"),
@@ -752,6 +773,7 @@
752773
("luke", "LukeForQuestionAnswering"),
753774
("convbert", "ConvBertForQuestionAnswering"),
754775
("llama", "LlamaForQuestionAnswering"),
776+
("longformer", "LongformerForQuestionAnswering"),
755777
("mistral", "MistralForQuestionAnswering"),
756778
("mobilebert", "MobileBertForQuestionAnswering"),
757779
("mpnet", "MPNetForQuestionAnswering"),
@@ -812,6 +834,7 @@
812834
("helium", "HeliumForTokenClassification"),
813835
("ibert", "IBertForTokenClassification"),
814836
("lilt", "LiltForTokenClassification"),
837+
("longformer", "LongformerForTokenClassification"),
815838
("luke", "LukeForTokenClassification"),
816839
("mistral", "MistralForTokenClassification"),
817840
("mobilebert", "MobileBertForTokenClassification"),
@@ -852,6 +875,7 @@
852875
("distilbert", "DistilBertForMultipleChoice"),
853876
("funnel", "FunnelForMultipleChoice"),
854877
("ibert", "IBertForMultipleChoice"),
878+
("longformer", "LongformerForMultipleChoice"),
855879
("luke", "LukeForMultipleChoice"),
856880
("megatron-bert", "MegatronBertForMultipleChoice"),
857881
("mobilebert", "MobileBertForMultipleChoice"),
@@ -971,6 +995,7 @@
971995
("distilbert", "DistilBertModel"),
972996
("emu3", "Emu3TextModel"),
973997
("ibert", "IBertModel"),
998+
("longformer", "LongformerModel"),
974999
("mllama", "MllamaTextModel"),
9751000
("mobilebert", "MobileBertModel"),
9761001
("mt5", "MT5EncoderModel"),
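mindone/transformers/models/longformer/__init__.py (file header missing from the extracted page; path inferred from the `from .modeling_longformer import *` line in this section — confirm against the commit)

Lines changed: 17 additions & 0 deletions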
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# This code is adapted from https://github.com/huggingface/transformers
4+
# with modifications to run transformers on mindspore.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
from .modeling_longformer import *

0 commit comments

Comments
 (0)