qninhdt
committed on
Commit
·
3f792e3
1
Parent(s):
8db8077
cc
Browse files- configs/experiment/miniagent-bert-attn-m8.yaml +4 -4
- configs/experiment/{miniagent-bert-attn.yaml → miniagent-bert-attn-v1.yaml} +7 -7
- configs/experiment/miniagent-bert-attn-v2.yaml +36 -0
- configs/experiment/miniagent-bert-mlp-abs_diff-mult.yaml +2 -2
- configs/experiment/miniagent-bert-mlp-abs_diff.yaml +2 -2
- configs/experiment/miniagent-bert-mlp-mult.yaml +2 -2
- configs/experiment/miniagent-bert-mlp.yaml +2 -2
- configs/logger/wandb.yaml +1 -1
- configs/trainer/default.yaml +3 -3
- src/models/{attn_module.py → attn_v1_module.py} +34 -68
- src/models/attn_v2_module.py +124 -0
- src/models/miniagent_module.py +1 -0
- src/train.py +2 -0
configs/experiment/miniagent-bert-attn-m8.yaml
CHANGED
|
@@ -15,22 +15,22 @@ model:
|
|
| 15 |
inst_proj_model:
|
| 16 |
_target_: src.models.attn_module.AttnProjection
|
| 17 |
input_dim: 768
|
| 18 |
-
n_heads:
|
| 19 |
output_length: 8
|
| 20 |
|
| 21 |
tool_proj_model:
|
| 22 |
_target_: src.models.attn_module.AttnProjection
|
| 23 |
input_dim: 768
|
| 24 |
-
n_heads:
|
| 25 |
output_length: 8
|
| 26 |
|
| 27 |
pred_model:
|
| 28 |
_target_: src.models.attn_module.BiAttnPrediction
|
| 29 |
input_dim: 768
|
| 30 |
-
n_heads:
|
| 31 |
|
| 32 |
data:
|
| 33 |
bert_model: bert-base-uncased
|
| 34 |
seed: 42
|
| 35 |
-
batch_size:
|
| 36 |
tool_capacity: 16
|
|
|
|
| 15 |
inst_proj_model:
|
| 16 |
_target_: src.models.attn_module.AttnProjection
|
| 17 |
input_dim: 768
|
| 18 |
+
n_heads: 4
|
| 19 |
output_length: 8
|
| 20 |
|
| 21 |
tool_proj_model:
|
| 22 |
_target_: src.models.attn_module.AttnProjection
|
| 23 |
input_dim: 768
|
| 24 |
+
n_heads: 4
|
| 25 |
output_length: 8
|
| 26 |
|
| 27 |
pred_model:
|
| 28 |
_target_: src.models.attn_module.BiAttnPrediction
|
| 29 |
input_dim: 768
|
| 30 |
+
n_heads: 4
|
| 31 |
|
| 32 |
data:
|
| 33 |
bert_model: bert-base-uncased
|
| 34 |
seed: 42
|
| 35 |
+
batch_size: 64
|
| 36 |
tool_capacity: 16
|
configs/experiment/{miniagent-bert-attn.yaml → miniagent-bert-attn-v1.yaml}
RENAMED
|
@@ -13,24 +13,24 @@ model:
|
|
| 13 |
bert_model: bert-base-uncased
|
| 14 |
|
| 15 |
inst_proj_model:
|
| 16 |
-
_target_: src.models.
|
| 17 |
input_dim: 768
|
| 18 |
-
n_heads:
|
| 19 |
output_length: 16
|
| 20 |
|
| 21 |
tool_proj_model:
|
| 22 |
-
_target_: src.models.
|
| 23 |
input_dim: 768
|
| 24 |
-
n_heads:
|
| 25 |
output_length: 16
|
| 26 |
|
| 27 |
pred_model:
|
| 28 |
-
_target_: src.models.
|
| 29 |
input_dim: 768
|
| 30 |
-
n_heads:
|
| 31 |
|
| 32 |
data:
|
| 33 |
bert_model: bert-base-uncased
|
| 34 |
seed: 42
|
| 35 |
-
batch_size:
|
| 36 |
tool_capacity: 16
|
|
|
|
| 13 |
bert_model: bert-base-uncased
|
| 14 |
|
| 15 |
inst_proj_model:
|
| 16 |
+
_target_: src.models.attn_v1_module.AttnProjection
|
| 17 |
input_dim: 768
|
| 18 |
+
n_heads: 4
|
| 19 |
output_length: 16
|
| 20 |
|
| 21 |
tool_proj_model:
|
| 22 |
+
_target_: src.models.attn_v1_module.AttnProjection
|
| 23 |
input_dim: 768
|
| 24 |
+
n_heads: 4
|
| 25 |
output_length: 16
|
| 26 |
|
| 27 |
pred_model:
|
| 28 |
+
_target_: src.models.attn_v1_module.BiAttnPrediction
|
| 29 |
input_dim: 768
|
| 30 |
+
n_heads: 4
|
| 31 |
|
| 32 |
data:
|
| 33 |
bert_model: bert-base-uncased
|
| 34 |
seed: 42
|
| 35 |
+
batch_size: 64
|
| 36 |
tool_capacity: 16
|
configs/experiment/miniagent-bert-attn-v2.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /data: mixed
|
| 5 |
+
- override /model: miniagent
|
| 6 |
+
- override /callbacks: default
|
| 7 |
+
- override /trainer: gpu
|
| 8 |
+
|
| 9 |
+
seed: 42
|
| 10 |
+
|
| 11 |
+
model:
|
| 12 |
+
lr: 1e-5
|
| 13 |
+
bert_model: bert-base-uncased
|
| 14 |
+
|
| 15 |
+
inst_proj_model:
|
| 16 |
+
_target_: src.models.attn_v2_module.AttnProjection
|
| 17 |
+
input_dim: 768
|
| 18 |
+
n_heads: 4
|
| 19 |
+
output_length: 16
|
| 20 |
+
|
| 21 |
+
tool_proj_model:
|
| 22 |
+
_target_: src.models.attn_v2_module.AttnProjection
|
| 23 |
+
input_dim: 768
|
| 24 |
+
n_heads: 4
|
| 25 |
+
output_length: 16
|
| 26 |
+
|
| 27 |
+
pred_model:
|
| 28 |
+
_target_: src.models.attn_v2_module.BiAttnPrediction
|
| 29 |
+
input_dim: 768
|
| 30 |
+
n_heads: 4
|
| 31 |
+
|
| 32 |
+
data:
|
| 33 |
+
bert_model: bert-base-uncased
|
| 34 |
+
seed: 42
|
| 35 |
+
batch_size: 64
|
| 36 |
+
tool_capacity: 16
|
configs/experiment/miniagent-bert-mlp-abs_diff-mult.yaml
CHANGED
|
@@ -9,7 +9,7 @@ defaults:
|
|
| 9 |
seed: 42
|
| 10 |
|
| 11 |
model:
|
| 12 |
-
lr: 0.
|
| 13 |
bert_model: bert-base-uncased
|
| 14 |
|
| 15 |
inst_proj_model:
|
|
@@ -33,5 +33,5 @@ model:
|
|
| 33 |
data:
|
| 34 |
bert_model: bert-base-uncased
|
| 35 |
seed: 42
|
| 36 |
-
batch_size:
|
| 37 |
tool_capacity: 16
|
|
|
|
| 9 |
seed: 42
|
| 10 |
|
| 11 |
model:
|
| 12 |
+
lr: 0.0001
|
| 13 |
bert_model: bert-base-uncased
|
| 14 |
|
| 15 |
inst_proj_model:
|
|
|
|
| 33 |
data:
|
| 34 |
bert_model: bert-base-uncased
|
| 35 |
seed: 42
|
| 36 |
+
batch_size: 64
|
| 37 |
tool_capacity: 16
|
configs/experiment/miniagent-bert-mlp-abs_diff.yaml
CHANGED
|
@@ -9,7 +9,7 @@ defaults:
|
|
| 9 |
seed: 42
|
| 10 |
|
| 11 |
model:
|
| 12 |
-
lr: 0.
|
| 13 |
bert_model: bert-base-uncased
|
| 14 |
|
| 15 |
inst_proj_model:
|
|
@@ -33,5 +33,5 @@ model:
|
|
| 33 |
data:
|
| 34 |
bert_model: bert-base-uncased
|
| 35 |
seed: 42
|
| 36 |
-
batch_size:
|
| 37 |
tool_capacity: 16
|
|
|
|
| 9 |
seed: 42
|
| 10 |
|
| 11 |
model:
|
| 12 |
+
lr: 0.0001
|
| 13 |
bert_model: bert-base-uncased
|
| 14 |
|
| 15 |
inst_proj_model:
|
|
|
|
| 33 |
data:
|
| 34 |
bert_model: bert-base-uncased
|
| 35 |
seed: 42
|
| 36 |
+
batch_size: 64
|
| 37 |
tool_capacity: 16
|
configs/experiment/miniagent-bert-mlp-mult.yaml
CHANGED
|
@@ -9,7 +9,7 @@ defaults:
|
|
| 9 |
seed: 42
|
| 10 |
|
| 11 |
model:
|
| 12 |
-
lr: 0.
|
| 13 |
bert_model: bert-base-uncased
|
| 14 |
|
| 15 |
inst_proj_model:
|
|
@@ -33,5 +33,5 @@ model:
|
|
| 33 |
data:
|
| 34 |
bert_model: bert-base-uncased
|
| 35 |
seed: 42
|
| 36 |
-
batch_size:
|
| 37 |
tool_capacity: 16
|
|
|
|
| 9 |
seed: 42
|
| 10 |
|
| 11 |
model:
|
| 12 |
+
lr: 0.0001
|
| 13 |
bert_model: bert-base-uncased
|
| 14 |
|
| 15 |
inst_proj_model:
|
|
|
|
| 33 |
data:
|
| 34 |
bert_model: bert-base-uncased
|
| 35 |
seed: 42
|
| 36 |
+
batch_size: 64
|
| 37 |
tool_capacity: 16
|
configs/experiment/miniagent-bert-mlp.yaml
CHANGED
|
@@ -9,7 +9,7 @@ defaults:
|
|
| 9 |
seed: 42
|
| 10 |
|
| 11 |
model:
|
| 12 |
-
lr: 0.
|
| 13 |
bert_model: bert-base-uncased
|
| 14 |
|
| 15 |
inst_proj_model:
|
|
@@ -33,5 +33,5 @@ model:
|
|
| 33 |
data:
|
| 34 |
bert_model: bert-base-uncased
|
| 35 |
seed: 42
|
| 36 |
-
batch_size:
|
| 37 |
tool_capacity: 16
|
|
|
|
| 9 |
seed: 42
|
| 10 |
|
| 11 |
model:
|
| 12 |
+
lr: 0.0001
|
| 13 |
bert_model: bert-base-uncased
|
| 14 |
|
| 15 |
inst_proj_model:
|
|
|
|
| 33 |
data:
|
| 34 |
bert_model: bert-base-uncased
|
| 35 |
seed: 42
|
| 36 |
+
batch_size: 64
|
| 37 |
tool_capacity: 16
|
configs/logger/wandb.yaml
CHANGED
|
@@ -13,4 +13,4 @@ wandb:
|
|
| 13 |
# entity: "" # set to name of your wandb team
|
| 14 |
group: ""
|
| 15 |
tags: []
|
| 16 |
-
job_type: ""
|
|
|
|
| 13 |
# entity: "" # set to name of your wandb team
|
| 14 |
group: ""
|
| 15 |
tags: []
|
| 16 |
+
job_type: ""
|
configs/trainer/default.yaml
CHANGED
|
@@ -3,7 +3,7 @@ _target_: lightning.pytorch.trainer.Trainer
|
|
| 3 |
default_root_dir: ${paths.output_dir}
|
| 4 |
|
| 5 |
min_epochs: 1 # prevents early stopping
|
| 6 |
-
max_epochs:
|
| 7 |
|
| 8 |
accelerator: cpu
|
| 9 |
devices: 1
|
|
@@ -11,11 +11,11 @@ devices: 1
|
|
| 11 |
log_every_n_steps: 10
|
| 12 |
|
| 13 |
# mixed precision for extra speed-up
|
| 14 |
-
|
| 15 |
|
| 16 |
# perform a validation loop every N training epochs
|
| 17 |
check_val_every_n_epoch: 1
|
| 18 |
|
| 19 |
# set True to ensure deterministic results
|
| 20 |
# makes training slower but gives more reproducibility than just setting seeds
|
| 21 |
-
deterministic:
|
|
|
|
| 3 |
default_root_dir: ${paths.output_dir}
|
| 4 |
|
| 5 |
min_epochs: 1 # prevents early stopping
|
| 6 |
+
max_epochs: 201
|
| 7 |
|
| 8 |
accelerator: cpu
|
| 9 |
devices: 1
|
|
|
|
| 11 |
log_every_n_steps: 10
|
| 12 |
|
| 13 |
# mixed precision for extra speed-up
|
| 14 |
+
precision: 16-mixed
|
| 15 |
|
| 16 |
# perform a validation loop every N training epochs
|
| 17 |
check_val_every_n_epoch: 1
|
| 18 |
|
| 19 |
# set True to ensure deterministic results
|
| 20 |
# makes training slower but gives more reproducibility than just setting seeds
|
| 21 |
+
deterministic: False
|
src/models/{attn_module.py → attn_v1_module.py}
RENAMED
|
@@ -10,41 +10,31 @@ class AttnProjection(nn.Module):
|
|
| 10 |
|
| 11 |
self.query = nn.Parameter(torch.randn(output_length, input_dim))
|
| 12 |
|
| 13 |
-
self.attn = nn.MultiheadAttention(
|
|
|
|
|
|
|
| 14 |
self.norm1 = nn.LayerNorm(input_dim)
|
| 15 |
-
self.dropout1 = nn.Dropout(0.5)
|
| 16 |
-
|
| 17 |
-
# self.self_attn = nn.MultiheadAttention(input_dim, n_heads, batch_first=True)
|
| 18 |
-
# self.norm2 = nn.LayerNorm(input_dim)
|
| 19 |
-
# self.dropout2 = nn.Dropout(0.5)
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# )
|
| 27 |
-
# self.norm3 = nn.LayerNorm(input_dim)
|
| 28 |
-
# self.dropout3 = nn.Dropout(0.5)
|
| 29 |
|
| 30 |
-
nn.init.
|
| 31 |
|
| 32 |
def forward(self, x):
|
| 33 |
B = x.shape[0]
|
| 34 |
|
| 35 |
query = self.query.unsqueeze(0).repeat(B, 1, 1)
|
| 36 |
|
| 37 |
-
z = self.
|
| 38 |
-
|
| 39 |
-
z =
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
# z = self.ff(z) + z
|
| 46 |
-
# z = self.norm3(z)
|
| 47 |
-
# z = self.dropout3(z)
|
| 48 |
|
| 49 |
z = z.contiguous().view(B, -1)
|
| 50 |
|
|
@@ -58,43 +48,28 @@ class BiAttnPrediction(nn.Module):
|
|
| 58 |
|
| 59 |
self.input_dim = input_dim
|
| 60 |
|
| 61 |
-
self.attn1 = nn.MultiheadAttention(
|
|
|
|
|
|
|
| 62 |
self.norm1 = nn.LayerNorm(input_dim)
|
| 63 |
-
self.dropout1 = nn.Dropout(0.
|
| 64 |
|
| 65 |
-
self.attn2 = nn.MultiheadAttention(
|
|
|
|
|
|
|
| 66 |
self.norm2 = nn.LayerNorm(input_dim)
|
| 67 |
-
self.dropout2 = nn.Dropout(0.
|
| 68 |
-
|
| 69 |
-
# self.ff1 = nn.Sequential(
|
| 70 |
-
# nn.Linear(input_dim, input_dim * 4),
|
| 71 |
-
# nn.SiLU(),
|
| 72 |
-
# nn.Dropout(0.5),
|
| 73 |
-
# nn.Linear(input_dim * 4, input_dim),
|
| 74 |
-
# )
|
| 75 |
-
# self.norm_ff1 = nn.LayerNorm(input_dim)
|
| 76 |
-
# self.dropout_ff1 = nn.Dropout(0.5)
|
| 77 |
-
|
| 78 |
-
# self.ff2 = nn.Sequential(
|
| 79 |
-
# nn.Linear(input_dim, input_dim * 4),
|
| 80 |
-
# nn.SiLU(),
|
| 81 |
-
# nn.Dropout(0.5),
|
| 82 |
-
# nn.Linear(input_dim * 4, input_dim),
|
| 83 |
-
# )
|
| 84 |
-
|
| 85 |
-
# self.norm_ff2 = nn.LayerNorm(input_dim)
|
| 86 |
-
# self.dropout_ff2 = nn.Dropout(0.5)
|
| 87 |
|
| 88 |
self.mlp = nn.Sequential(
|
| 89 |
-
nn.Linear(input_dim *
|
| 90 |
nn.SiLU(),
|
| 91 |
-
nn.Dropout(0.
|
| 92 |
nn.Linear(1024, 512),
|
| 93 |
nn.SiLU(),
|
| 94 |
-
nn.Dropout(0.
|
| 95 |
nn.Linear(512, 256),
|
| 96 |
nn.SiLU(),
|
| 97 |
-
nn.Dropout(0.
|
| 98 |
nn.Linear(256, 1),
|
| 99 |
)
|
| 100 |
|
|
@@ -103,28 +78,19 @@ class BiAttnPrediction(nn.Module):
|
|
| 103 |
x1 = x1.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
|
| 104 |
x2 = x2.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
z1 = self.dropout1(z1)
|
| 109 |
-
|
| 110 |
-
z2 = self.attn2(x1, x2, x2)[0] + x2
|
| 111 |
-
z2 = self.norm2(z2)
|
| 112 |
-
z2 = self.dropout2(z2)
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
# z1 = self.dropout_ff1(z1)
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
# z2 = self.dropout_ff2(z2)
|
| 121 |
|
| 122 |
-
# z1 = torch.cat([z1.mean(dim=1), z1.max(dim=1).values], dim=1) # [B, D * 2]
|
| 123 |
-
# z2 = torch.cat([z2.mean(dim=1), z2.max(dim=1).values], dim=1) # [B, D * 2]
|
| 124 |
z1 = z1.mean(dim=1)
|
| 125 |
z2 = z2.mean(dim=1)
|
| 126 |
|
| 127 |
-
z = torch.cat([z1, z2], dim=1) # [B, D * 4]
|
| 128 |
|
| 129 |
z = self.mlp(z)
|
| 130 |
|
|
|
|
| 10 |
|
| 11 |
self.query = nn.Parameter(torch.randn(output_length, input_dim))
|
| 12 |
|
| 13 |
+
self.attn = nn.MultiheadAttention(
|
| 14 |
+
input_dim, n_heads, dropout=0.2, batch_first=True
|
| 15 |
+
)
|
| 16 |
self.norm1 = nn.LayerNorm(input_dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
self.self_attn = nn.MultiheadAttention(
|
| 19 |
+
input_dim, n_heads, dropout=0.2, batch_first=True
|
| 20 |
+
)
|
| 21 |
+
self.norm2 = nn.LayerNorm(input_dim)
|
| 22 |
+
self.dropout = nn.Dropout(0.2)
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
nn.init.xavier_normal_(self.query)
|
| 25 |
|
| 26 |
def forward(self, x):
|
| 27 |
B = x.shape[0]
|
| 28 |
|
| 29 |
query = self.query.unsqueeze(0).repeat(B, 1, 1)
|
| 30 |
|
| 31 |
+
z = self.norm1(x)
|
| 32 |
+
z_attn = self.attn(query, z, z)[0]
|
| 33 |
+
z = z_attn
|
| 34 |
|
| 35 |
+
z = self.norm2(z)
|
| 36 |
+
z_attn = self.self_attn(z, z, z)[0]
|
| 37 |
+
z = z + self.dropout(z_attn)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
z = z.contiguous().view(B, -1)
|
| 40 |
|
|
|
|
| 48 |
|
| 49 |
self.input_dim = input_dim
|
| 50 |
|
| 51 |
+
self.attn1 = nn.MultiheadAttention(
|
| 52 |
+
input_dim, n_heads, dropout=0.2, batch_first=True
|
| 53 |
+
)
|
| 54 |
self.norm1 = nn.LayerNorm(input_dim)
|
| 55 |
+
self.dropout1 = nn.Dropout(0.2)
|
| 56 |
|
| 57 |
+
self.attn2 = nn.MultiheadAttention(
|
| 58 |
+
input_dim, n_heads, dropout=0.2, batch_first=True
|
| 59 |
+
)
|
| 60 |
self.norm2 = nn.LayerNorm(input_dim)
|
| 61 |
+
self.dropout2 = nn.Dropout(0.2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
self.mlp = nn.Sequential(
|
| 64 |
+
nn.Linear(input_dim * 3, 1024),
|
| 65 |
nn.SiLU(),
|
| 66 |
+
nn.Dropout(0.2),
|
| 67 |
nn.Linear(1024, 512),
|
| 68 |
nn.SiLU(),
|
| 69 |
+
nn.Dropout(0.2),
|
| 70 |
nn.Linear(512, 256),
|
| 71 |
nn.SiLU(),
|
| 72 |
+
nn.Dropout(0.2),
|
| 73 |
nn.Linear(256, 1),
|
| 74 |
)
|
| 75 |
|
|
|
|
| 78 |
x1 = x1.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
|
| 79 |
x2 = x2.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
|
| 80 |
|
| 81 |
+
x1 = self.norm1(x1)
|
| 82 |
+
x2 = self.norm2(x2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
z1_attn = self.attn1(x2, x1, x1)[0]
|
| 85 |
+
z1 = x1 + self.dropout1(z1_attn)
|
|
|
|
| 86 |
|
| 87 |
+
z2_attn = self.attn2(x1, x2, x2)[0]
|
| 88 |
+
z2 = x2 + self.dropout2(z2_attn)
|
|
|
|
| 89 |
|
|
|
|
|
|
|
| 90 |
z1 = z1.mean(dim=1)
|
| 91 |
z2 = z2.mean(dim=1)
|
| 92 |
|
| 93 |
+
z = torch.cat([z1, z2, torch.abs(z1 - z2)], dim=1) # [B, D * 3]
|
| 94 |
|
| 95 |
z = self.mlp(z)
|
| 96 |
|
src/models/attn_v2_module.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AttnProjection(nn.Module):
|
| 7 |
+
|
| 8 |
+
def __init__(self, input_dim, n_heads, output_length):
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
self.query = nn.Parameter(torch.randn(output_length, input_dim))
|
| 12 |
+
|
| 13 |
+
self.attn = nn.MultiheadAttention(
|
| 14 |
+
input_dim, n_heads, dropout=0.2, batch_first=True
|
| 15 |
+
)
|
| 16 |
+
self.norm1 = nn.LayerNorm(input_dim)
|
| 17 |
+
self.dropout1 = nn.Dropout(0.2)
|
| 18 |
+
|
| 19 |
+
self.self_attn = nn.MultiheadAttention(
|
| 20 |
+
input_dim, n_heads, dropout=0.2, batch_first=True
|
| 21 |
+
)
|
| 22 |
+
self.norm2 = nn.LayerNorm(input_dim)
|
| 23 |
+
self.dropout2 = nn.Dropout(0.2)
|
| 24 |
+
|
| 25 |
+
self.cls_mlp = nn.Sequential(
|
| 26 |
+
nn.Linear(input_dim, input_dim), nn.SiLU(), nn.Dropout(0.2)
|
| 27 |
+
)
|
| 28 |
+
self.norm3 = nn.LayerNorm(input_dim)
|
| 29 |
+
|
| 30 |
+
nn.init.xavier_normal_(self.query)
|
| 31 |
+
|
| 32 |
+
def forward(self, x):
|
| 33 |
+
B = x.shape[0]
|
| 34 |
+
|
| 35 |
+
query = self.query.unsqueeze(0).repeat(B, 1, 1)
|
| 36 |
+
|
| 37 |
+
x_cls = x[:, 0, :]
|
| 38 |
+
x_other = x[:, 1:, :]
|
| 39 |
+
|
| 40 |
+
z_other = self.norm1(x_other)
|
| 41 |
+
z_attn = self.attn(query, z_other, z_other)[0]
|
| 42 |
+
z_other = self.dropout1(z_attn)
|
| 43 |
+
|
| 44 |
+
z_other = self.norm2(z_other)
|
| 45 |
+
z_attn = self.self_attn(z_other, z_other, z_other)[0]
|
| 46 |
+
z_other = z_other + self.dropout1(z_attn)
|
| 47 |
+
|
| 48 |
+
z_cls = x_cls + self.cls_mlp(self.norm3(x_cls))
|
| 49 |
+
|
| 50 |
+
z = torch.cat([z_cls.unsqueeze(1), z_other], dim=1)
|
| 51 |
+
|
| 52 |
+
z = z.contiguous().view(B, -1)
|
| 53 |
+
|
| 54 |
+
return z
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class BiAttnPrediction(nn.Module):
|
| 58 |
+
|
| 59 |
+
def __init__(self, input_dim, n_heads):
|
| 60 |
+
super().__init__()
|
| 61 |
+
|
| 62 |
+
self.input_dim = input_dim
|
| 63 |
+
|
| 64 |
+
self.attn1 = nn.MultiheadAttention(
|
| 65 |
+
input_dim, n_heads, dropout=0.2, batch_first=True
|
| 66 |
+
)
|
| 67 |
+
self.norm1 = nn.LayerNorm(input_dim)
|
| 68 |
+
self.dropout1 = nn.Dropout(0.2)
|
| 69 |
+
|
| 70 |
+
self.attn2 = nn.MultiheadAttention(
|
| 71 |
+
input_dim, n_heads, dropout=0.2, batch_first=True
|
| 72 |
+
)
|
| 73 |
+
self.norm2 = nn.LayerNorm(input_dim)
|
| 74 |
+
self.dropout2 = nn.Dropout(0.2)
|
| 75 |
+
|
| 76 |
+
self.mlp = nn.Sequential(
|
| 77 |
+
nn.Linear(input_dim * 6, 1024),
|
| 78 |
+
nn.SiLU(),
|
| 79 |
+
nn.Dropout(0.2),
|
| 80 |
+
nn.Linear(1024, 512),
|
| 81 |
+
nn.SiLU(),
|
| 82 |
+
nn.Dropout(0.2),
|
| 83 |
+
nn.Linear(512, 256),
|
| 84 |
+
nn.SiLU(),
|
| 85 |
+
nn.Dropout(0.2),
|
| 86 |
+
nn.Linear(256, 1),
|
| 87 |
+
)
|
| 88 |
+
self.norm3 = nn.LayerNorm(input_dim)
|
| 89 |
+
|
| 90 |
+
def forward(self, x1, x2):
|
| 91 |
+
B = x1.shape[0]
|
| 92 |
+
x1 = x1.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
|
| 93 |
+
x2 = x2.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
|
| 94 |
+
|
| 95 |
+
z1_cls = x1[:, 0, :]
|
| 96 |
+
z2_cls = x2[:, 0, :]
|
| 97 |
+
|
| 98 |
+
x1_other = self.norm1(x1[:, 1:, :])
|
| 99 |
+
x2_other = self.norm2(x2[:, 1:, :])
|
| 100 |
+
|
| 101 |
+
z1_attn = self.attn1(x2_other, x1_other, x1_other)[0]
|
| 102 |
+
z1_other = x1_other + self.dropout1(z1_attn)
|
| 103 |
+
|
| 104 |
+
z2_attn = self.attn2(x1_other, x2_other, x2_other)[0]
|
| 105 |
+
z2_other = x2_other + self.dropout2(z2_attn)
|
| 106 |
+
|
| 107 |
+
z1_other = z1_other.mean(dim=1)
|
| 108 |
+
z2_other = z2_other.mean(dim=1)
|
| 109 |
+
|
| 110 |
+
z = torch.cat(
|
| 111 |
+
[
|
| 112 |
+
z1_cls,
|
| 113 |
+
z1_other,
|
| 114 |
+
z2_cls,
|
| 115 |
+
z2_other,
|
| 116 |
+
torch.abs(z1_cls - z2_cls),
|
| 117 |
+
torch.abs(z1_other - z2_other),
|
| 118 |
+
],
|
| 119 |
+
dim=1,
|
| 120 |
+
) # [B, D * 6]
|
| 121 |
+
|
| 122 |
+
z = self.mlp(z)
|
| 123 |
+
|
| 124 |
+
return z
|
src/models/miniagent_module.py
CHANGED
|
@@ -76,6 +76,7 @@ class MiniAgentModule(LightningModule):
|
|
| 76 |
target = torch.eye(B, device=pred.device).float()
|
| 77 |
|
| 78 |
pos_weight = torch.tensor([B - 1], device=pred.device)
|
|
|
|
| 79 |
loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
|
| 80 |
|
| 81 |
self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
|
|
|
|
| 76 |
target = torch.eye(B, device=pred.device).float()
|
| 77 |
|
| 78 |
pos_weight = torch.tensor([B - 1], device=pred.device)
|
| 79 |
+
# pos_weight = torch.tensor([1], device=pred.device)
|
| 80 |
loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
|
| 81 |
|
| 82 |
self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
|
src/train.py
CHANGED
|
@@ -73,6 +73,8 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
| 73 |
cfg.trainer, callbacks=callbacks, logger=logger
|
| 74 |
)
|
| 75 |
|
|
|
|
|
|
|
| 76 |
object_dict = {
|
| 77 |
"cfg": cfg,
|
| 78 |
"datamodule": datamodule,
|
|
|
|
| 73 |
cfg.trainer, callbacks=callbacks, logger=logger
|
| 74 |
)
|
| 75 |
|
| 76 |
+
trainer.fit_loop.max_epochs = 150
|
| 77 |
+
|
| 78 |
object_dict = {
|
| 79 |
"cfg": cfg,
|
| 80 |
"datamodule": datamodule,
|