# Adversarial-MidiBERT

This description was generated by Grok 3.

## Model Details

- **Model Name**: Adversarial-MidiBERT
- **Model Type**: Transformer-based model for symbolic music understanding
- **Version**: 1.0
- **Release Date**: August 2025
- **Developers**: Zijian Zhao
- **Organization**: SYSU
- **License**: Apache License 2.0
- **Paper**: [Let Network Decide What to Learn: Symbolic Music Understanding Model Based on Large-scale Adversarial Pre-training](https://dl.acm.org/doi/abs/10.1145/3731715.3733483), ACM ICMR 2025
- **Citation**:

  ```
  @inproceedings{zhao2025let,
    title={Let Network Decide What to Learn: Symbolic Music Understanding Model Based on Large-scale Adversarial Pre-training},
    author={Zhao, Zijian},
    booktitle={Proceedings of the 2025 International Conference on Multimedia Retrieval},
    pages={2128--2132},
    year={2025}
  }
  ```

- **Contact**: [email protected]
- **Repository**: https://github.com/RS2002/Adversarial-MidiBERT

## Model Description

Adversarial-MidiBERT is a transformer-based model for symbolic music understanding, trained with large-scale adversarial pre-training. It builds on the [MidiBERT-Piano](https://github.com/wazenmai/MIDI-BERT) framework and extends it with adversarial pre-training to improve performance on music-related tasks. The model processes symbolic music in an octuple token format and can be fine-tuned for downstream tasks such as music generation, classification, and analysis.

- **Architecture**: Transformer-based (based on MidiBERT)
- **Input Format**: Octuple representation of symbolic music, shape (batch_size, sequence_length, 8)
- **Output Format**: Hidden states of shape (batch_size, sequence_length, 768)
- **Hidden Size**: 768
- **Training Objective**: Adversarial pre-training followed by task-specific fine-tuning
- **Tasks Supported**: Symbolic music understanding tasks

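Variable-length pieces need padding to a common length plus a matching attention mask before batching. A minimal sketch, assuming (as in the usage example below) that 0 marks positions to attend and a nonzero value marks padding, and using a hypothetical padding id of 0 (check the repository's `Octuple.pkl` dictionary for the real one):

```python
import torch

# Two octuple sequences of different lengths (shape: [length, 8])
seq_a = torch.randint(0, 10, (600, 8))
seq_b = torch.randint(0, 10, (1024, 8))

def pad_octuple(seq, max_len, pad_id=0):
    """Pad a [length, 8] octuple sequence to max_len and build its mask."""
    length = seq.shape[0]
    padded = torch.full((max_len, 8), pad_id, dtype=seq.dtype)
    padded[:length] = seq
    mask = torch.ones(max_len)  # 1 = padding position (masked out)
    mask[:length] = 0           # 0 = real token, matching the all-zeros
    return padded, mask         #     mask in the usage example below

pairs = [pad_octuple(s, 1024) for s in (seq_a, seq_b)]
input_ids = torch.stack([p[0] for p in pairs])       # [2, 1024, 8]
attention_mask = torch.stack([p[1] for p in pairs])  # [2, 1024]
```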
## Training Data

The model was pre-trained and fine-tuned on the following datasets:

- **POP1K7**: A dataset of popular music MIDI files.
- **POP909**: A dataset of 909 pop songs in MIDI format.
- **Pianist8**: A dataset of piano performances.
- **EMOPIA**: A dataset for emotion-based music analysis.
- **GiantMIDI**: A large-scale MIDI dataset.

For details on dataset preprocessing and dictionary files, refer to the [PianoBART repository](https://github.com/RS2002/PianoBart). Pre-training data should be placed in `./Data/output_pretrain`.

## Usage

### Installation

```shell
git clone https://huggingface.co/RS2002/Adversarial-MidiBERT
```

Please ensure that the `model.py` and `Octuple.pkl` files are located in the same folder.

### Example Code

```python
import torch
from model import Adversarial_MidiBERT

# Load the pre-trained model from the Hub
model = Adversarial_MidiBERT.from_pretrained("RS2002/Adversarial-MidiBERT")

# Example input: a batch of 2 sequences of 1024 octuple tokens
input_ids = torch.randint(0, 10, (2, 1024, 8))
attention_mask = torch.zeros((2, 1024))

# Forward pass
y = model(input_ids, attention_mask)
print(y.last_hidden_state.shape)  # torch.Size([2, 1024, 768])
```
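For downstream classification, the per-token hidden states are typically pooled into one piece-level embedding. A minimal sketch on a dummy tensor shaped like `last_hidden_state`, assuming a masked mean over valid positions (mask convention as in the example above, 0 = attend):

```python
import torch

# Dummy hidden states shaped like last_hidden_state: [batch, seq_len, 768]
hidden = torch.randn(2, 1024, 768)
attention_mask = torch.zeros(2, 1024)  # 0 = attend, as in the example above
attention_mask[1, 600:] = 1            # pretend sequence 2 is padded after 600

# Mean-pool over non-padded positions only
valid = (attention_mask == 0).unsqueeze(-1).float()      # [2, 1024, 1]
pooled = (hidden * valid).sum(dim=1) / valid.sum(dim=1)  # [2, 768]
print(pooled.shape)  # torch.Size([2, 768])
```

The resulting 768-dimensional vectors can feed a small classifier head (e.g. a linear layer) for fine-tuning.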