# nazemi committed on
# Commit
# 6a5d9aa
# ·
# verified ·
# 1 Parent(s): ada8a65
#
# Upload fine_ast.py
#
# Browse files
# Files changed (1) hide show
#   1. fine_ast.py +83 -0
# fine_ast.py ADDED
# @@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import evaluate
import numpy as np
from datasets import Audio
from datasets import load_dataset
from transformers import ASTFeatureExtractor
from transformers import ASTForAudioClassification
from transformers import Trainer
from transformers import TrainingArguments

dataset = load_dataset("audiofolder", data_dir="data")
#dataset= dataset["train"].train_test_split(seed=42, shuffle=True, test_size=0.1)
# --- Training configuration ---------------------------------------------
batch_size = 8  # reused for both the train and eval per-device batch size
gradient_accumulation_steps = 1
num_train_epochs = 10
labels = ["noise", "speech"]
# Derived from `labels` so the head size cannot drift from the label list.
num_labels = len(labels)
max_duration = 5  # multiplied by the sampling rate in preprocess_function
model_id = "bookbot/distil-ast-audioset"
model_name = "speechVSnoise"  # first positional argument of TrainingArguments

# Label <-> id lookup tables handed to the model config. Ids are kept as
# strings, exactly as the original loop produced them.
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

# Load the pretrained checkpoint named by `model_id` and give it a
# classification head sized for `num_labels` with our label mappings.
model = ASTForAudioClassification.from_pretrained(
    model_id,
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
    # NOTE(review): presumably needed because the checkpoint's head has a
    # different number of classes than `num_labels` -- confirm.
    ignore_mismatched_sizes=True,
)
# Matching feature extractor, loaded from the same checkpoint id.
feature_extractor = ASTFeatureExtractor.from_pretrained(
    model_id,
    do_normalize=True,
    return_attention_mask=False,
)


def preprocess_function(examples):
    """Run the feature extractor over one batch of audio examples.

    Args:
        examples: Batch mapping with an ``"audio"`` column whose entries
            expose the raw waveform under the ``"array"`` key.

    Returns:
        Whatever ``feature_extractor`` returns for the batch of waveforms.
    """
    audio_arrays = [x["array"] for x in examples["audio"]]
    # The waveforms are *declared* to be at the extractor's sampling rate;
    # nothing in this function resamples them, so the caller must ensure
    # the audio column already matches that rate.
    # NOTE(review): `max_length`/`truncation` are sample-based here; confirm
    # ASTFeatureExtractor honours them rather than silently ignoring them.
    return feature_extractor(
        audio_arrays,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=int(feature_extractor.sampling_rate * max_duration),
        truncation=True,
    )


# Decode/resample every clip to the rate the feature extractor expects
# before extracting features: preprocess_function passes
# `feature_extractor.sampling_rate` regardless of the files' native rate,
# so mismatched audio would otherwise be featurised incorrectly.
dataset_encoded = dataset.cast_column(
    "audio", Audio(sampling_rate=feature_extractor.sampling_rate)
).map(
    preprocess_function,
    batched=True,
    # NOTE(review): 1674 clips are held in memory per map batch; presumably
    # chosen to match the dataset size -- lower it if memory is tight.
    batch_size=1674,
    num_proc=1,
)
metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score a batch of evaluation predictions with the loaded metric.

    Args:
        eval_pred: Object exposing ``predictions`` (per-class scores, with
            classes along axis 1) and ``label_ids`` (reference labels).

    Returns:
        The result of ``metric.compute`` for the arg-max predictions.
    """
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)


# Evaluate and checkpoint once per epoch so that `load_best_model_at_end`
# has a saved checkpoint for every evaluation it compares.
training_args = TrainingArguments(
    model_name,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_train_epochs,
    warmup_ratio=0.1,
    logging_steps=5,
    # NOTE(review): with `metric_for_best_model` left unset, the "best"
    # checkpoint is presumably chosen by evaluation loss -- confirm.
    load_best_model_at_end=True,
    # metric_for_best_model="accuracy",
    # push_to_hub=True,
)

72
+ from transformers import Trainer
73
+
74
+ trainer = Trainer(
75
+ model,
76
+ training_args,
77
+ train_dataset=dataset_encoded["train"],
78
+ eval_dataset=dataset_encoded["train"],
79
+ tokenizer=feature_extractor,
80
+ # compute_metrics=compute_metrics,
81
+ )
82
+ trainer.train()
83
+