xiyuanz committed
Commit 64a50f1 (verified)
1 Parent(s): 4e9b425

add demo and blog

Files changed (1)
  1. README.md +58 -1
README.md CHANGED
@@ -11,6 +11,63 @@ Mitra regressor is a tabular foundation model that is pre-trained on purely synt
 
 Mitra is based on a 12-layer Transformer of 72 M parameters, pre-trained by incorporating an in-context learning paradigm.
 
+## Usage
+
+To use Mitra regressor, install AutoGluon by running:
+
+```sh
+pip install uv
+uv pip install autogluon.tabular[mitra]
+```
+
+A minimal example showing how to perform inference using the Mitra regressor:
+
+```python
+import pandas as pd
+from autogluon.tabular import TabularDataset, TabularPredictor
+from sklearn.model_selection import train_test_split
+from sklearn.datasets import fetch_california_housing
+
+# Load datasets
+housing_data = fetch_california_housing()
+housing_df = pd.DataFrame(housing_data.data, columns=housing_data.feature_names)
+housing_df['target'] = housing_data.target
+
+print("Dataset shapes:")
+print(f"California Housing: {housing_df.shape}")
+
+# Create train/test splits (80/20)
+housing_train, housing_test = train_test_split(housing_df, test_size=0.2, random_state=42)
+
+print("Training set sizes:")
+print(f"Housing: {len(housing_train)} samples")
+
+# Convert to TabularDataset
+housing_train_data = TabularDataset(housing_train)
+housing_test_data = TabularDataset(housing_test)
+
+# Create predictor with Mitra for regression
+print("Training Mitra regressor on California Housing dataset...")
+mitra_reg_predictor = TabularPredictor(
+    label='target',
+    path='./mitra_regressor_model',
+    problem_type='regression'
+)
+mitra_reg_predictor.fit(
+    housing_train_data.sample(1000), # sample 1000 rows
+    hyperparameters={
+        'MITRA': {'fine_tune': False}
+    },
+)
+
+# Evaluate regression performance
+mitra_reg_predictor.leaderboard(housing_test_data)
+```
+
 ## License
 
-This project is licensed under the Apache-2.0 License.
+This project is licensed under the Apache-2.0 License.
+
+## Reference
+
+Amazon Science blog: [Mitra: Mixed synthetic priors for enhancing tabular foundation models](https://www.amazon.science/blog/mitra-mixed-synthetic-priors-for-enhancing-tabular-foundation-models?utm_campaign=mitra-mixed-synthetic-priors-for-enhancing-tabular-foundation-models&utm_medium=organic-asw&utm_source=linkedin&utm_content=2025-7-22-mitra-mixed-synthetic-priors-for-enhancing-tabular-foundation-models&utm_term=2025-july)
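
Note on the usage example in this change: the 1000-row subsample is drawn with `.sample(1000)` and no seed, so repeated runs train on different rows, and the only evaluation shown is `leaderboard`. The following is a minimal sketch, not part of the commit, of seeded subsampling and an explicit hold-out RMSE. The helper names `subsample` and `holdout_rmse` are hypothetical, and the sketch assumes only that the predictor object exposes a `predict(features)` method returning one numeric prediction per row, as `TabularPredictor.predict` does.

```python
import numpy as np
import pandas as pd


def subsample(train_df: pd.DataFrame, n_rows: int = 1000, seed: int = 0) -> pd.DataFrame:
    """Return at most n_rows rows of train_df, drawn reproducibly.

    Hypothetical helper: a fixed random_state makes the subsample
    identical across runs, unlike a bare .sample(1000).
    """
    n = min(n_rows, len(train_df))
    return train_df.sample(n=n, random_state=seed)


def holdout_rmse(predictor, test_df: pd.DataFrame, label: str = "target") -> float:
    """Root mean squared error of predictor.predict on held-out rows.

    Hypothetical helper: `predictor` is any object with a
    predict(features) method, e.g. a fitted TabularPredictor.
    """
    features = test_df.drop(columns=[label])
    preds = np.asarray(predictor.predict(features), dtype=float)
    errors = preds - test_df[label].to_numpy(dtype=float)
    return float(np.sqrt(np.mean(errors ** 2)))
```

With the names from the README example, this would be called as `mitra_reg_predictor.fit(subsample(housing_train_data), hyperparameters={'MITRA': {'fine_tune': False}})` followed by `holdout_rmse(mitra_reg_predictor, housing_test_data)`.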