
Adversarial-MidiBERT

This description was generated by Grok 3.

Model Details

Model Description

Adversarial-MidiBERT is a transformer-based model for symbolic music understanding. It builds on the MidiBERT-Piano framework and extends it with adversarial pre-training to improve performance on music-related tasks. The model consumes symbolic music in an octuple token format and can be fine-tuned for downstream tasks such as music generation, classification, and analysis.

  • Architecture: Transformer-based (based on MidiBERT)
  • Input Format: Octuple representation of symbolic music, shaped (batch_size, sequence_length, 8); see the sketch after this list
  • Output Format: Hidden states of dimension [batch_size, sequence_length, 768]
  • Hidden Size: 768
  • Training Objective: Adversarial pre-training followed by task-specific fine-tuning
  • Tasks Supported: Symbolic music understanding tasks
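
Each position in the input sequence is an eight-field octuple token, and the authoritative field order is defined by the bundled Octuple.pkl dictionary. The shape-only sketch below labels the fields following the OctupleMIDI encoding (bar, position, instrument, pitch, duration, velocity, time signature, tempo); treat those names as an assumption and consult Octuple.pkl for the real mapping.

import torch

# Field names assumed from the OctupleMIDI encoding; the real order is
# defined by the Octuple.pkl dictionary shipped with this model.
OCTUPLE_FIELDS = ["bar", "position", "instrument", "pitch",
                  "duration", "velocity", "time_signature", "tempo"]

batch_size, seq_len = 2, 1024
# One integer ID per field per time step: (batch_size, seq_len, 8)
tokens = torch.zeros((batch_size, seq_len, len(OCTUPLE_FIELDS)), dtype=torch.long)
print(tokens.shape)  # torch.Size([2, 1024, 8])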

Training Data

The model was pre-trained and fine-tuned on the following datasets:

  • POP1K7: A dataset of popular music MIDI files.
  • POP909: A dataset of 909 pop songs in MIDI format.
  • Pianist8: A dataset of piano performances.
  • EMOPIA: A dataset for emotion-based music analysis.
  • GiantMIDI: A large-scale MIDI dataset.

For details on dataset preprocessing and dictionary files, refer to the PianoBART repository. Pre-training data should be placed in ./Data/output_pretrain.
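
As a rough illustration, preprocessed pre-training data could then be loaded as NumPy arrays from that folder; the file name pretrain.npy below is hypothetical, since the actual file names and array layout are determined by the PianoBART preprocessing scripts.

import numpy as np

# Hypothetical file name; use whatever the PianoBART preprocessing
# scripts actually write to ./Data/output_pretrain
data = np.load("./Data/output_pretrain/pretrain.npy", allow_pickle=True)
print(data.shape)  # expected: (num_sequences, sequence_length, 8)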

Usage

Installation

git clone https://huggingface.co/RS2002/Adversarial-MidiBERT

Please ensure that the model.py and Octuple.pkl files are located in the same folder.

Example Code

import torch
from model import Adversarial_MidiBERT

# Load the pre-trained model
model = Adversarial_MidiBERT.from_pretrained("RS2002/Adversarial-MidiBERT")

# Example input: random octuple token IDs, shape (batch_size, sequence_length, 8)
input_ids = torch.randint(0, 10, (2, 1024, 8))
# Attention mask with one entry per time step, shape (batch_size, sequence_length)
attention_mask = torch.zeros((2, 1024))

# Forward pass
y = model(input_ids, attention_mask)
print(y.last_hidden_state.shape)  # Output: [2, 1024, 768]
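
Since the encoder returns hidden states rather than task predictions, downstream use typically adds a task head on top. Below is a minimal sketch of a sequence-level classification head that mean-pools the 768-dimensional hidden states; the head itself (its name, pooling choice, and class count) is an illustrative assumption, not part of the released model.

import torch
import torch.nn as nn

class ClassificationHead(nn.Module):
    """Hypothetical head: mean-pool hidden states, then project to class logits."""
    def __init__(self, hidden_size=768, num_classes=4):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, hidden_states):
        # hidden_states: (batch, sequence, 768) -> pooled: (batch, 768)
        pooled = hidden_states.mean(dim=1)
        return self.classifier(pooled)  # (batch, num_classes)

head = ClassificationHead()
logits = head(y.last_hidden_state)  # y comes from the forward pass above
print(logits.shape)  # torch.Size([2, 4])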