Upload modeling_keep.py
Browse files: modeling_keep.py (+6 -5)
modeling_keep.py
CHANGED
|
@@ -29,14 +29,15 @@ class KEEPModel(PreTrainedModel):
|
|
| 29 |
super().__init__(config)
|
| 30 |
|
| 31 |
# Vision Encoder (基于 timm 的 ViT)
|
|
|
|
| 32 |
self.visual = timm.create_model(
|
| 33 |
"vit_large_patch16_224",
|
| 34 |
pretrained=False,
|
| 35 |
-
img_size=224,
|
| 36 |
-
patch_size=16,
|
| 37 |
-
init_values=1e-5,
|
| 38 |
-
num_classes=0,
|
| 39 |
-
dynamic_img_size=True,
|
| 40 |
)
|
| 41 |
|
| 42 |
# 线性投影层,将 Vision Encoder 的输出投影到 768 维
|
|
|
|
| 29 |
super().__init__(config)
|
| 30 |
|
| 31 |
# Vision Encoder (基于 timm 的 ViT)
|
| 32 |
+
vision_config = config.vision_config
|
| 33 |
self.visual = timm.create_model(
|
| 34 |
"vit_large_patch16_224",
|
| 35 |
pretrained=False,
|
| 36 |
+
img_size=vision_config.get("img_size", 224),
|
| 37 |
+
patch_size=vision_config.get("patch_size", 16),
|
| 38 |
+
init_values=vision_config.get("init_values", 1e-5),
|
| 39 |
+
num_classes=vision_config.get("num_classes", 0),
|
| 40 |
+
dynamic_img_size=vision_config.get("dynamic_img_size", True),
|
| 41 |
)
|
| 42 |
|
| 43 |
# 线性投影层,将 Vision Encoder 的输出投影到 768 维
|