qninhdt committed on
Commit
3f792e3
·
1 Parent(s): 8db8077
configs/experiment/miniagent-bert-attn-m8.yaml CHANGED
@@ -15,22 +15,22 @@ model:
15
  inst_proj_model:
16
  _target_: src.models.attn_module.AttnProjection
17
  input_dim: 768
18
- n_heads: 1
19
  output_length: 8
20
 
21
  tool_proj_model:
22
  _target_: src.models.attn_module.AttnProjection
23
  input_dim: 768
24
- n_heads: 1
25
  output_length: 8
26
 
27
  pred_model:
28
  _target_: src.models.attn_module.BiAttnPrediction
29
  input_dim: 768
30
- n_heads: 1
31
 
32
  data:
33
  bert_model: bert-base-uncased
34
  seed: 42
35
- batch_size: 128
36
  tool_capacity: 16
 
15
  inst_proj_model:
16
  _target_: src.models.attn_module.AttnProjection
17
  input_dim: 768
18
+ n_heads: 4
19
  output_length: 8
20
 
21
  tool_proj_model:
22
  _target_: src.models.attn_module.AttnProjection
23
  input_dim: 768
24
+ n_heads: 4
25
  output_length: 8
26
 
27
  pred_model:
28
  _target_: src.models.attn_module.BiAttnPrediction
29
  input_dim: 768
30
+ n_heads: 4
31
 
32
  data:
33
  bert_model: bert-base-uncased
34
  seed: 42
35
+ batch_size: 64
36
  tool_capacity: 16
configs/experiment/{miniagent-bert-attn.yaml → miniagent-bert-attn-v1.yaml} RENAMED
@@ -13,24 +13,24 @@ model:
13
  bert_model: bert-base-uncased
14
 
15
  inst_proj_model:
16
- _target_: src.models.attn_module.AttnProjection
17
  input_dim: 768
18
- n_heads: 1
19
  output_length: 16
20
 
21
  tool_proj_model:
22
- _target_: src.models.attn_module.AttnProjection
23
  input_dim: 768
24
- n_heads: 1
25
  output_length: 16
26
 
27
  pred_model:
28
- _target_: src.models.attn_module.BiAttnPrediction
29
  input_dim: 768
30
- n_heads: 1
31
 
32
  data:
33
  bert_model: bert-base-uncased
34
  seed: 42
35
- batch_size: 128
36
  tool_capacity: 16
 
13
  bert_model: bert-base-uncased
14
 
15
  inst_proj_model:
16
+ _target_: src.models.attn_v1_module.AttnProjection
17
  input_dim: 768
18
+ n_heads: 4
19
  output_length: 16
20
 
21
  tool_proj_model:
22
+ _target_: src.models.attn_v1_module.AttnProjection
23
  input_dim: 768
24
+ n_heads: 4
25
  output_length: 16
26
 
27
  pred_model:
28
+ _target_: src.models.attn_v1_module.BiAttnPrediction
29
  input_dim: 768
30
+ n_heads: 4
31
 
32
  data:
33
  bert_model: bert-base-uncased
34
  seed: 42
35
+ batch_size: 64
36
  tool_capacity: 16
configs/experiment/miniagent-bert-attn-v2.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /data: mixed
5
+ - override /model: miniagent
6
+ - override /callbacks: default
7
+ - override /trainer: gpu
8
+
9
+ seed: 42
10
+
11
+ model:
12
+ lr: 1e-5
13
+ bert_model: bert-base-uncased
14
+
15
+ inst_proj_model:
16
+ _target_: src.models.attn_v2_module.AttnProjection
17
+ input_dim: 768
18
+ n_heads: 4
19
+ output_length: 16
20
+
21
+ tool_proj_model:
22
+ _target_: src.models.attn_v2_module.AttnProjection
23
+ input_dim: 768
24
+ n_heads: 4
25
+ output_length: 16
26
+
27
+ pred_model:
28
+ _target_: src.models.attn_v2_module.BiAttnPrediction
29
+ input_dim: 768
30
+ n_heads: 4
31
+
32
+ data:
33
+ bert_model: bert-base-uncased
34
+ seed: 42
35
+ batch_size: 64
36
+ tool_capacity: 16
configs/experiment/miniagent-bert-mlp-abs_diff-mult.yaml CHANGED
@@ -9,7 +9,7 @@ defaults:
9
  seed: 42
10
 
11
  model:
12
- lr: 0.001
13
  bert_model: bert-base-uncased
14
 
15
  inst_proj_model:
@@ -33,5 +33,5 @@ model:
33
  data:
34
  bert_model: bert-base-uncased
35
  seed: 42
36
- batch_size: 128
37
  tool_capacity: 16
 
9
  seed: 42
10
 
11
  model:
12
+ lr: 0.0001
13
  bert_model: bert-base-uncased
14
 
15
  inst_proj_model:
 
33
  data:
34
  bert_model: bert-base-uncased
35
  seed: 42
36
+ batch_size: 64
37
  tool_capacity: 16
configs/experiment/miniagent-bert-mlp-abs_diff.yaml CHANGED
@@ -9,7 +9,7 @@ defaults:
9
  seed: 42
10
 
11
  model:
12
- lr: 0.001
13
  bert_model: bert-base-uncased
14
 
15
  inst_proj_model:
@@ -33,5 +33,5 @@ model:
33
  data:
34
  bert_model: bert-base-uncased
35
  seed: 42
36
- batch_size: 128
37
  tool_capacity: 16
 
9
  seed: 42
10
 
11
  model:
12
+ lr: 0.0001
13
  bert_model: bert-base-uncased
14
 
15
  inst_proj_model:
 
33
  data:
34
  bert_model: bert-base-uncased
35
  seed: 42
36
+ batch_size: 64
37
  tool_capacity: 16
configs/experiment/miniagent-bert-mlp-mult.yaml CHANGED
@@ -9,7 +9,7 @@ defaults:
9
  seed: 42
10
 
11
  model:
12
- lr: 0.001
13
  bert_model: bert-base-uncased
14
 
15
  inst_proj_model:
@@ -33,5 +33,5 @@ model:
33
  data:
34
  bert_model: bert-base-uncased
35
  seed: 42
36
- batch_size: 128
37
  tool_capacity: 16
 
9
  seed: 42
10
 
11
  model:
12
+ lr: 0.0001
13
  bert_model: bert-base-uncased
14
 
15
  inst_proj_model:
 
33
  data:
34
  bert_model: bert-base-uncased
35
  seed: 42
36
+ batch_size: 64
37
  tool_capacity: 16
configs/experiment/miniagent-bert-mlp.yaml CHANGED
@@ -9,7 +9,7 @@ defaults:
9
  seed: 42
10
 
11
  model:
12
- lr: 0.001
13
  bert_model: bert-base-uncased
14
 
15
  inst_proj_model:
@@ -33,5 +33,5 @@ model:
33
  data:
34
  bert_model: bert-base-uncased
35
  seed: 42
36
- batch_size: 128
37
  tool_capacity: 16
 
9
  seed: 42
10
 
11
  model:
12
+ lr: 0.0001
13
  bert_model: bert-base-uncased
14
 
15
  inst_proj_model:
 
33
  data:
34
  bert_model: bert-base-uncased
35
  seed: 42
36
+ batch_size: 64
37
  tool_capacity: 16
configs/logger/wandb.yaml CHANGED
@@ -13,4 +13,4 @@ wandb:
13
  # entity: "" # set to name of your wandb team
14
  group: ""
15
  tags: []
16
- job_type: ""
 
13
  # entity: "" # set to name of your wandb team
14
  group: ""
15
  tags: []
16
+ job_type: ""
configs/trainer/default.yaml CHANGED
@@ -3,7 +3,7 @@ _target_: lightning.pytorch.trainer.Trainer
3
  default_root_dir: ${paths.output_dir}
4
 
5
  min_epochs: 1 # prevents early stopping
6
- max_epochs: 50
7
 
8
  accelerator: cpu
9
  devices: 1
@@ -11,11 +11,11 @@ devices: 1
11
  log_every_n_steps: 10
12
 
13
  # mixed precision for extra speed-up
14
- # precision: 16
15
 
16
  # perform a validation loop every N training epochs
17
  check_val_every_n_epoch: 1
18
 
19
  # set True to ensure deterministic results
20
  # makes training slower but gives more reproducibility than just setting seeds
21
- deterministic: True
 
3
  default_root_dir: ${paths.output_dir}
4
 
5
  min_epochs: 1 # prevents early stopping
6
+ max_epochs: 201
7
 
8
  accelerator: cpu
9
  devices: 1
 
11
  log_every_n_steps: 10
12
 
13
  # mixed precision for extra speed-up
14
+ precision: 16-mixed
15
 
16
  # perform a validation loop every N training epochs
17
  check_val_every_n_epoch: 1
18
 
19
  # set True to ensure deterministic results
20
  # makes training slower but gives more reproducibility than just setting seeds
21
+ deterministic: False
src/models/{attn_module.py → attn_v1_module.py} RENAMED
@@ -10,41 +10,31 @@ class AttnProjection(nn.Module):
10
 
11
  self.query = nn.Parameter(torch.randn(output_length, input_dim))
12
 
13
- self.attn = nn.MultiheadAttention(input_dim, n_heads, batch_first=True)
 
 
14
  self.norm1 = nn.LayerNorm(input_dim)
15
- self.dropout1 = nn.Dropout(0.5)
16
-
17
- # self.self_attn = nn.MultiheadAttention(input_dim, n_heads, batch_first=True)
18
- # self.norm2 = nn.LayerNorm(input_dim)
19
- # self.dropout2 = nn.Dropout(0.5)
20
 
21
- # self.ff = nn.Sequential(
22
- # nn.Linear(input_dim, input_dim * 4),
23
- # nn.SiLU(),
24
- # nn.Dropout(0.5),
25
- # nn.Linear(input_dim * 4, input_dim),
26
- # )
27
- # self.norm3 = nn.LayerNorm(input_dim)
28
- # self.dropout3 = nn.Dropout(0.5)
29
 
30
- nn.init.xavier_uniform_(self.query)
31
 
32
  def forward(self, x):
33
  B = x.shape[0]
34
 
35
  query = self.query.unsqueeze(0).repeat(B, 1, 1)
36
 
37
- z = self.attn(query, x, x)[0]
38
- z = self.norm1(z)
39
- z = self.dropout1(z)
40
 
41
- # z = self.self_attn(z, z, z)[0] + z
42
- # z = self.norm2(z)
43
- # z = self.dropout2(z)
44
-
45
- # z = self.ff(z) + z
46
- # z = self.norm3(z)
47
- # z = self.dropout3(z)
48
 
49
  z = z.contiguous().view(B, -1)
50
 
@@ -58,43 +48,28 @@ class BiAttnPrediction(nn.Module):
58
 
59
  self.input_dim = input_dim
60
 
61
- self.attn1 = nn.MultiheadAttention(input_dim, n_heads, batch_first=True)
 
 
62
  self.norm1 = nn.LayerNorm(input_dim)
63
- self.dropout1 = nn.Dropout(0.5)
64
 
65
- self.attn2 = nn.MultiheadAttention(input_dim, n_heads, batch_first=True)
 
 
66
  self.norm2 = nn.LayerNorm(input_dim)
67
- self.dropout2 = nn.Dropout(0.5)
68
-
69
- # self.ff1 = nn.Sequential(
70
- # nn.Linear(input_dim, input_dim * 4),
71
- # nn.SiLU(),
72
- # nn.Dropout(0.5),
73
- # nn.Linear(input_dim * 4, input_dim),
74
- # )
75
- # self.norm_ff1 = nn.LayerNorm(input_dim)
76
- # self.dropout_ff1 = nn.Dropout(0.5)
77
-
78
- # self.ff2 = nn.Sequential(
79
- # nn.Linear(input_dim, input_dim * 4),
80
- # nn.SiLU(),
81
- # nn.Dropout(0.5),
82
- # nn.Linear(input_dim * 4, input_dim),
83
- # )
84
-
85
- # self.norm_ff2 = nn.LayerNorm(input_dim)
86
- # self.dropout_ff2 = nn.Dropout(0.5)
87
 
88
  self.mlp = nn.Sequential(
89
- nn.Linear(input_dim * 2, 1024),
90
  nn.SiLU(),
91
- nn.Dropout(0.5),
92
  nn.Linear(1024, 512),
93
  nn.SiLU(),
94
- nn.Dropout(0.5),
95
  nn.Linear(512, 256),
96
  nn.SiLU(),
97
- nn.Dropout(0.5),
98
  nn.Linear(256, 1),
99
  )
100
 
@@ -103,28 +78,19 @@ class BiAttnPrediction(nn.Module):
103
  x1 = x1.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
104
  x2 = x2.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
105
 
106
- z1 = self.attn1(x2, x1, x1)[0] + x1
107
- z1 = self.norm1(z1)
108
- z1 = self.dropout1(z1)
109
-
110
- z2 = self.attn2(x1, x2, x2)[0] + x2
111
- z2 = self.norm2(z2)
112
- z2 = self.dropout2(z2)
113
 
114
- # z1 = self.ff1(z1) + z1
115
- # z1 = self.norm_ff1(z1)
116
- # z1 = self.dropout_ff1(z1)
117
 
118
- # z2 = self.ff2(z2) + z2
119
- # z2 = self.norm_ff2(z2)
120
- # z2 = self.dropout_ff2(z2)
121
 
122
- # z1 = torch.cat([z1.mean(dim=1), z1.max(dim=1).values], dim=1) # [B, D * 2]
123
- # z2 = torch.cat([z2.mean(dim=1), z2.max(dim=1).values], dim=1) # [B, D * 2]
124
  z1 = z1.mean(dim=1)
125
  z2 = z2.mean(dim=1)
126
 
127
- z = torch.cat([z1, z2], dim=1) # [B, D * 4]
128
 
129
  z = self.mlp(z)
130
 
 
10
 
11
  self.query = nn.Parameter(torch.randn(output_length, input_dim))
12
 
13
+ self.attn = nn.MultiheadAttention(
14
+ input_dim, n_heads, dropout=0.2, batch_first=True
15
+ )
16
  self.norm1 = nn.LayerNorm(input_dim)
 
 
 
 
 
17
 
18
+ self.self_attn = nn.MultiheadAttention(
19
+ input_dim, n_heads, dropout=0.2, batch_first=True
20
+ )
21
+ self.norm2 = nn.LayerNorm(input_dim)
22
+ self.dropout = nn.Dropout(0.2)
 
 
 
23
 
24
+ nn.init.xavier_normal_(self.query)
25
 
26
  def forward(self, x):
27
  B = x.shape[0]
28
 
29
  query = self.query.unsqueeze(0).repeat(B, 1, 1)
30
 
31
+ z = self.norm1(x)
32
+ z_attn = self.attn(query, z, z)[0]
33
+ z = z_attn
34
 
35
+ z = self.norm2(z)
36
+ z_attn = self.self_attn(z, z, z)[0]
37
+ z = z + self.dropout(z_attn)
 
 
 
 
38
 
39
  z = z.contiguous().view(B, -1)
40
 
 
48
 
49
  self.input_dim = input_dim
50
 
51
+ self.attn1 = nn.MultiheadAttention(
52
+ input_dim, n_heads, dropout=0.2, batch_first=True
53
+ )
54
  self.norm1 = nn.LayerNorm(input_dim)
55
+ self.dropout1 = nn.Dropout(0.2)
56
 
57
+ self.attn2 = nn.MultiheadAttention(
58
+ input_dim, n_heads, dropout=0.2, batch_first=True
59
+ )
60
  self.norm2 = nn.LayerNorm(input_dim)
61
+ self.dropout2 = nn.Dropout(0.2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  self.mlp = nn.Sequential(
64
+ nn.Linear(input_dim * 3, 1024),
65
  nn.SiLU(),
66
+ nn.Dropout(0.2),
67
  nn.Linear(1024, 512),
68
  nn.SiLU(),
69
+ nn.Dropout(0.2),
70
  nn.Linear(512, 256),
71
  nn.SiLU(),
72
+ nn.Dropout(0.2),
73
  nn.Linear(256, 1),
74
  )
75
 
 
78
  x1 = x1.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
79
  x2 = x2.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
80
 
81
+ x1 = self.norm1(x1)
82
+ x2 = self.norm2(x2)
 
 
 
 
 
83
 
84
+ z1_attn = self.attn1(x2, x1, x1)[0]
85
+ z1 = x1 + self.dropout1(z1_attn)
 
86
 
87
+ z2_attn = self.attn2(x1, x2, x2)[0]
88
+ z2 = x2 + self.dropout2(z2_attn)
 
89
 
 
 
90
  z1 = z1.mean(dim=1)
91
  z2 = z2.mean(dim=1)
92
 
93
+ z = torch.cat([z1, z2, torch.abs(z1 - z2)], dim=1) # [B, D * 3]
94
 
95
  z = self.mlp(z)
96
 
src/models/attn_v2_module.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class AttnProjection(nn.Module):
7
+
8
+ def __init__(self, input_dim, n_heads, output_length):
9
+ super().__init__()
10
+
11
+ self.query = nn.Parameter(torch.randn(output_length, input_dim))
12
+
13
+ self.attn = nn.MultiheadAttention(
14
+ input_dim, n_heads, dropout=0.2, batch_first=True
15
+ )
16
+ self.norm1 = nn.LayerNorm(input_dim)
17
+ self.dropout1 = nn.Dropout(0.2)
18
+
19
+ self.self_attn = nn.MultiheadAttention(
20
+ input_dim, n_heads, dropout=0.2, batch_first=True
21
+ )
22
+ self.norm2 = nn.LayerNorm(input_dim)
23
+ self.dropout2 = nn.Dropout(0.2)
24
+
25
+ self.cls_mlp = nn.Sequential(
26
+ nn.Linear(input_dim, input_dim), nn.SiLU(), nn.Dropout(0.2)
27
+ )
28
+ self.norm3 = nn.LayerNorm(input_dim)
29
+
30
+ nn.init.xavier_normal_(self.query)
31
+
32
+ def forward(self, x):
33
+ B = x.shape[0]
34
+
35
+ query = self.query.unsqueeze(0).repeat(B, 1, 1)
36
+
37
+ x_cls = x[:, 0, :]
38
+ x_other = x[:, 1:, :]
39
+
40
+ z_other = self.norm1(x_other)
41
+ z_attn = self.attn(query, z_other, z_other)[0]
42
+ z_other = self.dropout1(z_attn)
43
+
44
+ z_other = self.norm2(z_other)
45
+ z_attn = self.self_attn(z_other, z_other, z_other)[0]
46
+ z_other = z_other + self.dropout1(z_attn)
47
+
48
+ z_cls = x_cls + self.cls_mlp(self.norm3(x_cls))
49
+
50
+ z = torch.cat([z_cls.unsqueeze(1), z_other], dim=1)
51
+
52
+ z = z.contiguous().view(B, -1)
53
+
54
+ return z
55
+
56
+
57
+ class BiAttnPrediction(nn.Module):
58
+
59
+ def __init__(self, input_dim, n_heads):
60
+ super().__init__()
61
+
62
+ self.input_dim = input_dim
63
+
64
+ self.attn1 = nn.MultiheadAttention(
65
+ input_dim, n_heads, dropout=0.2, batch_first=True
66
+ )
67
+ self.norm1 = nn.LayerNorm(input_dim)
68
+ self.dropout1 = nn.Dropout(0.2)
69
+
70
+ self.attn2 = nn.MultiheadAttention(
71
+ input_dim, n_heads, dropout=0.2, batch_first=True
72
+ )
73
+ self.norm2 = nn.LayerNorm(input_dim)
74
+ self.dropout2 = nn.Dropout(0.2)
75
+
76
+ self.mlp = nn.Sequential(
77
+ nn.Linear(input_dim * 6, 1024),
78
+ nn.SiLU(),
79
+ nn.Dropout(0.2),
80
+ nn.Linear(1024, 512),
81
+ nn.SiLU(),
82
+ nn.Dropout(0.2),
83
+ nn.Linear(512, 256),
84
+ nn.SiLU(),
85
+ nn.Dropout(0.2),
86
+ nn.Linear(256, 1),
87
+ )
88
+ self.norm3 = nn.LayerNorm(input_dim)
89
+
90
+ def forward(self, x1, x2):
91
+ B = x1.shape[0]
92
+ x1 = x1.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
93
+ x2 = x2.view(B, -1, self.input_dim) # [B, M x D] -> [B, M, D]
94
+
95
+ z1_cls = x1[:, 0, :]
96
+ z2_cls = x2[:, 0, :]
97
+
98
+ x1_other = self.norm1(x1[:, 1:, :])
99
+ x2_other = self.norm2(x2[:, 1:, :])
100
+
101
+ z1_attn = self.attn1(x2_other, x1_other, x1_other)[0]
102
+ z1_other = x1_other + self.dropout1(z1_attn)
103
+
104
+ z2_attn = self.attn2(x1_other, x2_other, x2_other)[0]
105
+ z2_other = x2_other + self.dropout2(z2_attn)
106
+
107
+ z1_other = z1_other.mean(dim=1)
108
+ z2_other = z2_other.mean(dim=1)
109
+
110
+ z = torch.cat(
111
+ [
112
+ z1_cls,
113
+ z1_other,
114
+ z2_cls,
115
+ z2_other,
116
+ torch.abs(z1_cls - z2_cls),
117
+ torch.abs(z1_other - z2_other),
118
+ ],
119
+ dim=1,
120
+ ) # [B, D * 6]
121
+
122
+ z = self.mlp(z)
123
+
124
+ return z
src/models/miniagent_module.py CHANGED
@@ -76,6 +76,7 @@ class MiniAgentModule(LightningModule):
76
  target = torch.eye(B, device=pred.device).float()
77
 
78
  pos_weight = torch.tensor([B - 1], device=pred.device)
 
79
  loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
80
 
81
  self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
 
76
  target = torch.eye(B, device=pred.device).float()
77
 
78
  pos_weight = torch.tensor([B - 1], device=pred.device)
79
+ # pos_weight = torch.tensor([1], device=pred.device)
80
  loss = F.binary_cross_entropy_with_logits(pred, target, pos_weight=pos_weight)
81
 
82
  self.log("train/loss", loss, on_step=True, sync_dist=True, prog_bar=True)
src/train.py CHANGED
@@ -73,6 +73,8 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
73
  cfg.trainer, callbacks=callbacks, logger=logger
74
  )
75
 
 
 
76
  object_dict = {
77
  "cfg": cfg,
78
  "datamodule": datamodule,
 
73
  cfg.trainer, callbacks=callbacks, logger=logger
74
  )
75
 
76
+ trainer.fit_loop.max_epochs = 150
77
+
78
  object_dict = {
79
  "cfg": cfg,
80
  "datamodule": datamodule,