Upload model, config, and documentation
Browse files- README.md +77 -0
- best_model.pt +3 -0
- config.yaml +107 -0
README.md
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- pattern-classification
|
| 4 |
+
- multi-label-classification
|
| 5 |
+
datasets:
|
| 6 |
+
- maximuspowers/muat-mean-std
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# Pattern Classifier
|
| 10 |
+
|
| 11 |
+
This model was trained to classify which patterns a subject model was trained on, based on neuron activation signatures.
|
| 12 |
+
|
| 13 |
+
## Dataset
|
| 14 |
+
|
| 15 |
+
- **Training Dataset**: [maximuspowers/muat-mean-std](https://huggingface.co/datasets/maximuspowers/muat-mean-std)
|
| 16 |
+
- **Input Mode**: signature
|
| 17 |
+
- **Number of Patterns**: 14
|
| 18 |
+
|
| 19 |
+
## Patterns
|
| 20 |
+
|
| 21 |
+
The model predicts which of the following 14 patterns the subject model was trained to classify as positive:
|
| 22 |
+
|
| 23 |
+
1. `palindrome`
|
| 24 |
+
2. `sorted_ascending`
|
| 25 |
+
3. `sorted_descending`
|
| 26 |
+
4. `alternating`
|
| 27 |
+
5. `contains_abc`
|
| 28 |
+
6. `starts_with`
|
| 29 |
+
7. `ends_with`
|
| 30 |
+
8. `no_repeats`
|
| 31 |
+
9. `has_majority`
|
| 32 |
+
10. `increasing_pairs`
|
| 33 |
+
11. `decreasing_pairs`
|
| 34 |
+
12. `vowel_consonant`
|
| 35 |
+
13. `first_last_match`
|
| 36 |
+
14. `mountain_pattern`
|
| 37 |
+
|
| 38 |
+
## Model Architecture
|
| 39 |
+
|
| 40 |
+
- **Signature Encoder**: [512, 256, 256, 128]
|
| 41 |
+
- **Activation**: relu
|
| 42 |
+
- **Dropout**: 0.2
|
| 43 |
+
- **Batch Normalization**: True
|
| 44 |
+
|
| 45 |
+
## Training Configuration
|
| 46 |
+
|
| 47 |
+
- **Optimizer**: adam
|
| 48 |
+
- **Learning Rate**: 0.001
|
| 49 |
+
- **Batch Size**: 16
|
| 50 |
+
- **Loss Function**: BCE with Logits (with pos_weight for training, unweighted for validation)
|
| 51 |
+
|
| 52 |
+
## Test Set Performance
|
| 53 |
+
|
| 54 |
+
- **F1 Macro**: 0.2832
|
| 55 |
+
- **F1 Micro**: 0.2498
|
| 56 |
+
- **Hamming Accuracy**: 0.7306
|
| 57 |
+
- **Exact Match Accuracy**: 0.0140
|
| 58 |
+
- **BCE Loss**: 0.4844
|
| 59 |
+
|
| 60 |
+
### Per-Pattern Performance (Test Set)
|
| 61 |
+
|
| 62 |
+
| Pattern | Precision | Recall | F1 Score |
|
| 63 |
+
|---------|-----------|--------|----------|
|
| 64 |
+
| palindrome | 11.1% | 89.8% | 19.8% |
|
| 65 |
+
| sorted_ascending | 59.7% | 56.6% | 58.1% |
|
| 66 |
+
| sorted_descending | 15.8% | 66.2% | 25.5% |
|
| 67 |
+
| alternating | 19.8% | 72.4% | 31.1% |
|
| 68 |
+
| contains_abc | 30.8% | 57.6% | 40.1% |
|
| 69 |
+
| starts_with | 9.1% | 59.4% | 15.8% |
|
| 70 |
+
| ends_with | 10.3% | 73.8% | 18.1% |
|
| 71 |
+
| no_repeats | 17.8% | 32.1% | 22.9% |
|
| 72 |
+
| has_majority | 33.3% | 60.5% | 43.0% |
|
| 73 |
+
| increasing_pairs | 23.3% | 35.3% | 28.1% |
|
| 74 |
+
| decreasing_pairs | 19.4% | 60.9% | 29.5% |
|
| 75 |
+
| vowel_consonant | 9.8% | 76.9% | 17.4% |
|
| 76 |
+
| first_last_match | 15.3% | 96.6% | 26.5% |
|
| 77 |
+
| mountain_pattern | 15.3% | 31.7% | 20.6% |
|
best_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:64c68dc903e28b122d405966fd3d9c471a17b9da24d2d87b46959bbb797c5911
|
| 3 |
+
size 4439480
|
config.yaml
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataloader:
|
| 2 |
+
num_workers: 0
|
| 3 |
+
pin_memory: true
|
| 4 |
+
dataset:
|
| 5 |
+
cache_dir: .cache/classifier_data
|
| 6 |
+
hf_dataset: maximuspowers/muat-mean-std
|
| 7 |
+
input_mode: signature
|
| 8 |
+
max_dimensions:
|
| 9 |
+
max_layers: 13
|
| 10 |
+
max_neurons_per_layer: 8
|
| 11 |
+
max_sequence_length: 5
|
| 12 |
+
neuron_profile:
|
| 13 |
+
methods:
|
| 14 |
+
mean: {}
|
| 15 |
+
std: {}
|
| 16 |
+
patterns:
|
| 17 |
+
- palindrome
|
| 18 |
+
- sorted_ascending
|
| 19 |
+
- sorted_descending
|
| 20 |
+
- alternating
|
| 21 |
+
- contains_abc
|
| 22 |
+
- starts_with
|
| 23 |
+
- ends_with
|
| 24 |
+
- no_repeats
|
| 25 |
+
- has_majority
|
| 26 |
+
- increasing_pairs
|
| 27 |
+
- decreasing_pairs
|
| 28 |
+
- vowel_consonant
|
| 29 |
+
- first_last_match
|
| 30 |
+
- mountain_pattern
|
| 31 |
+
random_seed: 42
|
| 32 |
+
test_split: 0.1
|
| 33 |
+
train_split: 0.8
|
| 34 |
+
val_split: 0.1
|
| 35 |
+
device:
|
| 36 |
+
type: auto
|
| 37 |
+
evaluation:
|
| 38 |
+
decision_threshold: 0.5
|
| 39 |
+
metrics:
|
| 40 |
+
- accuracy_exact_match
|
| 41 |
+
- accuracy_hamming
|
| 42 |
+
- precision_macro
|
| 43 |
+
- recall_macro
|
| 44 |
+
- f1_macro
|
| 45 |
+
- f1_micro
|
| 46 |
+
per_pattern_metrics: true
|
| 47 |
+
hub:
|
| 48 |
+
enabled: true
|
| 49 |
+
private: false
|
| 50 |
+
push_frequency: epoch
|
| 51 |
+
push_logs: true
|
| 52 |
+
push_metrics: true
|
| 53 |
+
push_model: true
|
| 54 |
+
repo_id: maximuspowers/muat-mean-std-classifier
|
| 55 |
+
token: <REDACTED>
|
| 56 |
+
logging:
|
| 57 |
+
checkpoint:
|
| 58 |
+
enabled: true
|
| 59 |
+
mode: max
|
| 60 |
+
monitor: val_f1_macro
|
| 61 |
+
save_best_only: true
|
| 62 |
+
save_dir: ./checkpoints/classifier_mean_std
|
| 63 |
+
tensorboard:
|
| 64 |
+
enabled: true
|
| 65 |
+
log_dir: ./runs/classifier_mean_std
|
| 66 |
+
log_interval: 10
|
| 67 |
+
verbose: true
|
| 68 |
+
model:
|
| 69 |
+
fusion:
|
| 70 |
+
activation: relu
|
| 71 |
+
dropout: 0.2
|
| 72 |
+
hidden_dims:
|
| 73 |
+
- 128
|
| 74 |
+
- 64
|
| 75 |
+
output:
|
| 76 |
+
num_patterns: 14
|
| 77 |
+
signature_encoder:
|
| 78 |
+
activation: relu
|
| 79 |
+
dropout: 0.2
|
| 80 |
+
hidden_dims:
|
| 81 |
+
- 512
|
| 82 |
+
- 256
|
| 83 |
+
- 256
|
| 84 |
+
- 128
|
| 85 |
+
use_batch_norm: true
|
| 86 |
+
weight_encoder:
|
| 87 |
+
activation: relu
|
| 88 |
+
dropout: 0.2
|
| 89 |
+
training:
|
| 90 |
+
batch_size: 16
|
| 91 |
+
early_stopping:
|
| 92 |
+
enabled: true
|
| 93 |
+
mode: min
|
| 94 |
+
monitor: val_loss
|
| 95 |
+
patience: 50
|
| 96 |
+
epochs: 1000
|
| 97 |
+
learning_rate: 0.001
|
| 98 |
+
loss: bce_with_logits
|
| 99 |
+
lr_scheduler:
|
| 100 |
+
enabled: true
|
| 101 |
+
factor: 0.5
|
| 102 |
+
min_lr: 1.0e-05
|
| 103 |
+
patience: 20
|
| 104 |
+
type: reduce_on_plateau
|
| 105 |
+
optimizer: adam
|
| 106 |
+
pos_weight: null
|
| 107 |
+
weight_decay: 0.0001
|