Update README.md
README.md
# Adversarial-MidiBERT

This description was generated by Grok 3.

## Model Details

- **Model Name**: Adversarial-MidiBERT
- **Model Type**: Transformer-based model for symbolic music understanding
- **Version**: 1.0
- **Release Date**: August 2025
- **Developer**: Zijian Zhao
- **Organization**: SYSU
- **License**: Apache License 2.0
- **Paper**: [Let Network Decide What to Learn: Symbolic Music Understanding Model Based on Large-scale Adversarial Pre-training](https://dl.acm.org/doi/abs/10.1145/3731715.3733483), ACM ICMR 2025
- **Citation**:

```
@inproceedings{zhao2025let,
  title={Let Network Decide What to Learn: Symbolic Music Understanding Model Based on Large-scale Adversarial Pre-training},
  author={Zhao, Zijian},
  booktitle={Proceedings of the 2025 International Conference on Multimedia Retrieval},
  pages={2128--2132},
  year={2025}
}
```

- **Contact**: [email protected]
- **Repository**: https://github.com/RS2002/Adversarial-MidiBERT

## Model Description

Adversarial-MidiBERT is a transformer-based model for symbolic music understanding that leverages large-scale adversarial pre-training. It builds on the [MidiBERT-Piano](https://github.com/wazenmai/MIDI-BERT) framework and extends it with adversarial pre-training to improve performance on music-related tasks. The model consumes symbolic music in octuple token format and can be fine-tuned for downstream understanding tasks such as music classification and analysis; a minimal fine-tuning sketch follows the list below.

- **Architecture**: Transformer-based (based on MidiBERT)
- **Input Format**: Octuple representation of symbolic music, shape (batch_size, sequence_length, 8)
- **Output Format**: Hidden states of shape (batch_size, sequence_length, 768)
- **Hidden Size**: 768
- **Training Objective**: Adversarial pre-training followed by task-specific fine-tuning
- **Tasks Supported**: Symbolic music understanding tasks
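As a rough illustration of the fine-tuning setup mentioned above, the sketch below stacks a classification head on top of the 768-dimensional hidden states. The `Adversarial_MidiBERT` class, the `from_pretrained` call, and the `last_hidden_state` field follow the usage example later in this card; the `SequenceClassifier` head, the mean pooling, and the training step are illustrative assumptions, not the procedure described in the paper.

```python
import torch
import torch.nn as nn

from model import Adversarial_MidiBERT  # provided by this repository


class SequenceClassifier(nn.Module):
    """Hypothetical classification head on top of the 768-d hidden states."""

    def __init__(self, backbone, num_classes, hidden_size=768):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        # The backbone returns hidden states of shape [batch, seq_len, 768]
        hidden = self.backbone(input_ids, attention_mask).last_hidden_state
        # Mean-pool over the time dimension to get one vector per sequence
        pooled = hidden.mean(dim=1)
        return self.head(pooled)


# Hypothetical 4-class task (e.g., emotion classification)
backbone = Adversarial_MidiBERT.from_pretrained("RS2002/Adversarial-MidiBERT")
model = SequenceClassifier(backbone, num_classes=4)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

# Dummy batch in the same octuple format as the usage example below
input_ids = torch.randint(0, 10, (2, 1024, 8))
attention_mask = torch.zeros((2, 1024))
labels = torch.randint(0, 4, (2,))

logits = model(input_ids, attention_mask)  # [2, 4]
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
```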
## Training Data

The model was pre-trained and fine-tuned on the following datasets:

- **POP1K7**: A dataset of around 1,700 pop piano performances in MIDI format.
- **POP909**: A dataset of 909 pop songs in MIDI format.
- **Pianist8**: A dataset of piano performances by eight pianists.
- **EMOPIA**: A dataset of pop piano music annotated with emotion labels.
- **GiantMIDI**: A large-scale piano MIDI dataset.

For details on dataset preprocessing and the dictionary files, refer to the [PianoBART repository](https://github.com/RS2002/PianoBart). Pre-training data should be placed in `./Data/output_pretrain`.
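A minimal sketch of preparing that folder; the actual file names come from the PianoBART preprocessing scripts and are not reproduced here.

```python
import os

# Expected location of the preprocessed pre-training data,
# relative to the repository root
pretrain_dir = "./Data/output_pretrain"
os.makedirs(pretrain_dir, exist_ok=True)

# Copy the files produced by the PianoBART preprocessing scripts into
# this folder before launching pre-training.
print(f"Put the preprocessed pre-training files in {os.path.abspath(pretrain_dir)}")
```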
## Usage

### Installation

```shell
git clone https://huggingface.co/RS2002/Adversarial-MidiBERT
```

Please ensure that the `model.py` and `Octuple.pkl` files are located in the same folder.
### Example Code

```python
import torch

from model import Adversarial_MidiBERT

# Load the pre-trained model from the Hugging Face Hub
model = Adversarial_MidiBERT.from_pretrained("RS2002/Adversarial-MidiBERT")

# Dummy input: batch of 2 sequences, 1024 time steps, 8 octuple token fields
input_ids = torch.randint(0, 10, (2, 1024, 8))
attention_mask = torch.zeros((2, 1024))

# Forward pass returns the hidden states
y = model(input_ids, attention_mask)
print(y.last_hidden_state.shape)  # Output: [2, 1024, 768]
```
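For plain inference, the example above can be wrapped as follows. This assumes the model behaves as a standard `torch.nn.Module` (supporting `.to()`, `.eval()`, and `torch.no_grad()`), which the repository does not explicitly document.

```python
import torch

# Run the same example on GPU (if available) without gradient tracking
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

with torch.no_grad():
    y = model(input_ids.to(device), attention_mask.to(device))

print(y.last_hidden_state.shape)  # still [2, 1024, 768]
```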