{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT and `MixedDataModule` smoke test\n",
    "\n",
    "Two quick checks:\n",
    "\n",
    "1. `bert-base-uncased` loads and encodes a sentence.\n",
    "2. `MixedDataModule` yields training and validation batches, whose tensor shapes are displayed.\n",
    "\n",
    "The notebook is meant to be run top to bottom with **Restart Kernel and Run All**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import BertModel, BertTokenizer\n",
    "\n",
    "from src.data.mixed_datamodule import MixedDataModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything a reader might want to tune lives here.\n",
    "BERT_MODEL_NAME = \"bert-base-uncased\"\n",
    "DATASET_PATH = \"./datasets/mixed\"\n",
    "BATCH_SIZE = 32\n",
    "NUM_WORKERS = 4\n",
    "TOOL_CAPACITY = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_shapes(dataloader) -> dict[str, torch.Size]:\n",
    "    \"\"\"Return the shape of every entry in the first batch of `dataloader`.\n",
    "\n",
    "    The first batch is fetched with `next(iter(dataloader))`; it is expected\n",
    "    to be a mapping whose values expose a `.shape` attribute.\n",
    "    \"\"\"\n",
    "    first_batch = next(iter(dataloader))\n",
    "    return {key: value.shape for key, value in first_batch.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BERT forward pass\n",
    "\n",
    "Encode one sentence and run it through the pretrained model. Only the output shapes are displayed; printing the full output object dumps every value of the pooled vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)\n",
    "model = BertModel.from_pretrained(BERT_MODEL_NAME)\n",
    "\n",
    "text = \"Replace me by any text you'd like.\"\n",
    "encoded_input = tokenizer(text, return_tensors=\"pt\")\n",
    "\n",
    "# Inference only: no gradients are needed, so skip autograd bookkeeping.\n",
    "with torch.no_grad():\n",
    "    output = model(**encoded_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "{\n",
    "    \"last_hidden_state\": output.last_hidden_state.shape,\n",
    "    \"pooler_output\": output.pooler_output.shape,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data module\n",
    "\n",
    "Build the data module, prepare it for fitting, and look at the first batch of each split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datamodule = MixedDataModule(\n",
    "    dataset_path=DATASET_PATH,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    bert_model=BERT_MODEL_NAME,\n",
    "    tool_capacity=TOOL_CAPACITY,\n",
    ")\n",
    "datamodule.setup(stage=\"fit\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First training batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = datamodule.train_dataloader()\n",
    "batch_shapes(train_dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First validation batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_dataloader = datamodule.val_dataloader()\n",
    "batch_shapes(val_dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shapes recorded in an earlier run\n",
    "\n",
    "Cell outputs are not stored in this notebook. For reference, a previous execution with `batch_size=32` and `tool_capacity=16` displayed:\n",
    "\n",
    "| key | train batch | val batch |\n",
    "| --- | --- | --- |\n",
    "| `instruction` | `[32, 128]` | `[32, 128]` |\n",
    "| `instruction_mask` | `[32, 128]` | `[32, 128]` |\n",
    "| `tool_desc_emb` | `[32, 128]` | `[32, 16, 128]` |\n",
    "| `tool_desc_mask` | `[32, 128]` | `[32, 16, 128]` |\n",
    "\n",
    "**TODO (confirm against `MixedDataModule`):** the extra validation axis of size 16 equals `tool_capacity`, so validation batches presumably carry one description per candidate tool while training batches carry a single one."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "swim",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}