Browse Source

add colab result

liuyuqi-dellpc 2 years ago
parent
commit
8b839d0777
2 changed files with 766 additions and 362 deletions
  1. 0 362
      1.baseline.ipynb
  2. 766 0
      1_baseline.ipynb

+ 0 - 362
1.baseline.ipynb

@@ -1,362 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## 声音识别项目介绍\n",
-    "\n",
-    "\n",
-    "\n",
-    "## 开发环境\n",
-    "\n",
-    "* TensorFlow的版本:2.0 +\n",
-    "* keras\n",
-    "* sklearn\n",
-    "* librosa\n",
-    "\n",
-    "## 下载数据\n",
-    "\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/train_sample.zip\n",
-    "!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/test_a.zip\n",
-    "\n",
-    "\n",
-    "!unzip -qq train_sample.zip\n",
-    "!\\rm train_sample.zip\n",
-    "\n",
-    "!unzip -qq test_a.zip\n",
-    "!\\rm test_a.zip"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# 安装语音处理依赖\n",
-    "!pip install librosa --user"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# 基本库\n",
-    "\n",
-    "import pandas as pd\n",
-    "import numpy as np\n",
-    "\n",
-    "from sklearn.model_selection import train_test_split\n",
-    "from sklearn.metrics import classification_report\n",
-    "from sklearn.model_selection import GridSearchCV\n",
-    "\n",
-    "from sklearn.preprocessing import MinMaxScaler\n",
-    "\n",
-    "\n",
-    "from tensorflow.keras.models import Sequential\n",
-    "from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPool2D, Dropout\n",
-    "from tensorflow.keras.utils import to_categorical \n",
-    "\n",
-    "from sklearn.ensemble import RandomForestClassifier\n",
-    "from sklearn.svm import SVC\n",
-    "\n",
-    "\n",
-    "import os\n",
-    "import librosa\n",
-    "import librosa.display\n",
-    "import glob "
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## 数据预处理\n",
-    "\n",
-    "特征提取以及数据集的建立\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "feature = []\n",
-    "label = []\n",
-    "# 建立类别标签,不同类别对应不同的数字。\n",
-    "label_dict = {'aloe': 0, 'burger': 1, 'cabbage': 2,'candied_fruits':3, 'carrots': 4, 'chips':5,\n",
-    "                  'chocolate': 6, 'drinks': 7, 'fries': 8, 'grapes': 9, 'gummies': 10, 'ice-cream':11,\n",
-    "                  'jelly': 12, 'noodles': 13, 'pickles': 14, 'pizza': 15, 'ribs': 16, 'salmon':17,\n",
-    "                  'soup': 18, 'wings': 19}\n",
-    "label_dict_inv = {v:k for k,v in label_dict.items()}"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from tqdm import tqdm\n",
-    "def extract_features(parent_dir, sub_dirs, max_file=10, file_ext=\"*.wav\"):\n",
-    "    c = 0\n",
-    "    label, feature = [], []\n",
-    "    for sub_dir in sub_dirs:\n",
-    "        for fn in tqdm(glob.glob(os.path.join(parent_dir, sub_dir, file_ext))[:max_file]): # 遍历数据集的所有文件\n",
-    "            \n",
-    "           # segment_log_specgrams, segment_labels = [], []\n",
-    "            #sound_clip,sr = librosa.load(fn)\n",
-    "            #print(fn)\n",
-    "            label_name = fn.split('/')[-2]\n",
-    "            label.extend([label_dict[label_name]])\n",
-    "            X, sample_rate = librosa.load(fn,res_type='kaiser_fast')\n",
-    "            mels = np.mean(librosa.feature.melspectrogram(y=X,sr=sample_rate).T,axis=0) # 计算梅尔频谱(mel spectrogram),并把它作为特征\n",
-    "            feature.extend([mels])\n",
-    "            \n",
-    "    return [feature, label]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# 自己更改目录\n",
-    "parent_dir = './train_sample/'\n",
-    "save_dir = \"./\"\n",
-    "folds = sub_dirs = np.array(['aloe','burger','cabbage','candied_fruits',\n",
-    "                             'carrots','chips','chocolate','drinks','fries',\n",
-    "                            'grapes','gummies','ice-cream','jelly','noodles','pickles',\n",
-    "                            'pizza','ribs','salmon','soup','wings'])\n",
-    "\n",
-    "# 获取特征feature以及类别的label\n",
-    "temp = extract_features(parent_dir,sub_dirs,max_file=100)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "temp = np.array(temp)\n",
-    "data = temp.transpose()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# 获取特征\n",
-    "X = np.vstack(data[:, 0])\n",
-    "\n",
-    "# 获取标签\n",
-    "Y = np.array(data[:, 1])\n",
-    "print('X的特征尺寸是:',X.shape)\n",
-    "print('Y的特征尺寸是:',Y.shape)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# 在Keras库中:to_categorical就是将类别向量转换为二进制(只有0和1)的矩阵类型表示\n",
-    "Y = to_categorical(Y)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "'''最终数据'''\n",
-    "print(X.shape)\n",
-    "print(Y.shape)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state = 1, stratify=Y)\n",
-    "print('训练集的大小',len(X_train))\n",
-    "print('测试集的大小',len(X_test))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "X_train = X_train.reshape(-1, 16, 8, 1)\n",
-    "X_test = X_test.reshape(-1, 16, 8, 1)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## 搭建CNN网络¶\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "model = Sequential()\n",
-    "\n",
-    "# 输入的大小\n",
-    "input_dim = (16, 8, 1)\n",
-    "\n",
-    "model.add(Conv2D(64, (3, 3), padding = \"same\", activation = \"tanh\", input_shape = input_dim))# 卷积层\n",
-    "model.add(MaxPool2D(pool_size=(2, 2)))# 最大池化\n",
-    "model.add(Conv2D(128, (3, 3), padding = \"same\", activation = \"tanh\")) #卷积层\n",
-    "model.add(MaxPool2D(pool_size=(2, 2))) # 最大池化层\n",
-    "model.add(Dropout(0.1))\n",
-    "model.add(Flatten()) # 展开\n",
-    "model.add(Dense(1024, activation = \"tanh\"))\n",
-    "model.add(Dense(20, activation = \"softmax\")) # 输出层:20个units输出20个类的概率\n",
-    "\n",
-    "# 编译模型,设置损失函数,优化方法以及评价标准\n",
-    "model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "model.summary()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# 训练模型\n",
-    "model.fit(X_train, Y_train, epochs = 20, batch_size = 15, validation_data = (X_test, Y_test))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#预测\n",
-    "\n",
-    "def extract_features(test_dir, file_ext=\"*.wav\"):\n",
-    "    feature = []\n",
-    "    for fn in tqdm(glob.glob(os.path.join(test_dir, file_ext))[:]): # 遍历数据集的所有文件\n",
-    "        X, sample_rate = librosa.load(fn,res_type='kaiser_fast')\n",
-    "        mels = np.mean(librosa.feature.melspectrogram(y=X,sr=sample_rate).T,axis=0) # 计算梅尔频谱(mel spectrogram),并把它作为特征\n",
-    "        feature.extend([mels])\n",
-    "    return feature\n",
-    "    "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "X_test = extract_features('./test_a/')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "X_test = np.vstack(X_test)\n",
-    "predictions = model.predict(X_test.reshape(-1, 16, 8, 1))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "preds = np.argmax(predictions, axis = 1)\n",
-    "preds = [label_dict_inv[x] for x in preds]\n",
-    "\n",
-    "path = glob.glob('./test_a/*.wav')\n",
-    "result = pd.DataFrame({'name':path, 'label': preds})\n",
-    "\n",
-    "result['name'] = result['name'].apply(lambda x: x.split('/')[-1])\n",
-    "result.to_csv('submit.csv',index=None)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "!ls ./test_a/*.wav | wc -l"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "!wc -l submit.csv"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "language_info": {
-   "name": "plaintext"
-  },
-  "orig_nbformat": 4
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}

+ 766 - 0
1_baseline.ipynb

@@ -0,0 +1,766 @@
+{
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "xQZFoVjZweQH"
+      },
+      "outputs": [],
+      "source": [
+        ""
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "2EIqvi83weQL"
+      },
+      "source": [
+        "## 声音识别项目介绍\n",
+        "\n",
+        "数据集来自 Eating Sound Collection,数据集中包含20种不同食物的咀嚼声音,赛题任务是给这些声音数据建模,准确分类。\n",
+        "\n",
+        "train文件夹:完整的训练集;\n",
+        "train_sample文件夹:部分训练集;\n",
+        "test文件夹:测试集;\n",
+        "\n",
+        "赛题包含的类别:\n",
+        "\n",
+        "\n",
+        "```\n",
+        "aloe\n",
+        "ice-cream\n",
+        "ribs\n",
+        "chocolate\n",
+        "cabbage\n",
+        "candied_fruits\n",
+        "soup\n",
+        "jelly\n",
+        "grapes\n",
+        "pizza\n",
+        "gummies\n",
+        "salmon\n",
+        "wings\n",
+        "burger\n",
+        "pickles\n",
+        "carrots\n",
+        "fries\n",
+        "chips\n",
+        "noodles\n",
+        "drinks\n",
+        "```\n",
+        "\n",
+        "训练数据通过不同目录来区分不同类型的声音:\n",
+        "```\n",
+        "/content/train_sample/aloe\n",
+        "/content/train_sample/ice-cream\n",
+        "/content/train_sample/carrots\n",
+        "\n",
+        "```\n",
+        "\n",
+        "\n",
+        "## 开发环境\n",
+        "\n",
+        "* TensorFlow的版本:2.0 +\n",
+        "* keras\n",
+        "* sklearn\n",
+        "* librosa\n",
+        "\n",
+        "## 下载数据\n",
+        "\n",
+        "\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 1,
+      "metadata": {
+        "id": "jMnYHwWvweQS",
+        "outputId": "6f169e94-edb9-4ddb-f295-c0a5204a5513",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "--2022-01-07 08:58:21--  http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/train_sample.zip\n",
+            "Resolving tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)... 118.31.232.194\n",
+            "Connecting to tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)|118.31.232.194|:80... connected.\n",
+            "HTTP request sent, awaiting response... 200 OK\n",
+            "Length: 540689175 (516M) [application/zip]\n",
+            "Saving to: ‘train_sample.zip’\n",
+            "\n",
+            "train_sample.zip    100%[===================>] 515.64M  4.95MB/s    in 1m 52s  \n",
+            "\n",
+            "2022-01-07 09:00:13 (4.61 MB/s) - ‘train_sample.zip’ saved [540689175/540689175]\n",
+            "\n",
+            "--2022-01-07 09:00:13--  http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/test_a.zip\n",
+            "Resolving tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)... 118.31.232.194\n",
+            "Connecting to tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)|118.31.232.194|:80... connected.\n",
+            "HTTP request sent, awaiting response... 200 OK\n",
+            "Length: 1092637852 (1.0G) [application/zip]\n",
+            "Saving to: ‘test_a.zip’\n",
+            "\n",
+            "test_a.zip          100%[===================>]   1.02G  4.98MB/s    in 3m 38s  \n",
+            "\n",
+            "2022-01-07 09:03:52 (4.78 MB/s) - ‘test_a.zip’ saved [1092637852/1092637852]\n",
+            "\n"
+          ]
+        }
+      ],
+      "source": [
+        "# 这里只下载部分训练集\n",
+        "!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/train_sample.zip\n",
+        "!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/test_a.zip\n",
+        "\n",
+        "\n",
+        "!unzip -qq train_sample.zip\n",
+        "!\\rm train_sample.zip\n",
+        "\n",
+        "!unzip -qq test_a.zip\n",
+        "!\\rm test_a.zip"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 2,
+      "metadata": {
+        "id": "fxVp5dadweQV",
+        "outputId": "14a7ae70-c2f0-4bb7-b45a-42bd215e4df0",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Requirement already satisfied: librosa in /usr/local/lib/python3.7/dist-packages (0.8.1)\n",
+            "Requirement already satisfied: soundfile>=0.10.2 in /usr/local/lib/python3.7/dist-packages (from librosa) (0.10.3.post1)\n",
+            "Requirement already satisfied: decorator>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from librosa) (4.4.2)\n",
+            "Requirement already satisfied: resampy>=0.2.2 in /usr/local/lib/python3.7/dist-packages (from librosa) (0.2.2)\n",
+            "Requirement already satisfied: numba>=0.43.0 in /usr/local/lib/python3.7/dist-packages (from librosa) (0.51.2)\n",
+            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from librosa) (21.3)\n",
+            "Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from librosa) (1.0.1)\n",
+            "Requirement already satisfied: joblib>=0.14 in /usr/local/lib/python3.7/dist-packages (from librosa) (1.1.0)\n",
+            "Requirement already satisfied: scipy>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from librosa) (1.4.1)\n",
+            "Requirement already satisfied: audioread>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from librosa) (2.1.9)\n",
+            "Requirement already satisfied: pooch>=1.0 in /usr/local/lib/python3.7/dist-packages (from librosa) (1.5.2)\n",
+            "Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from librosa) (1.19.5)\n",
+            "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba>=0.43.0->librosa) (57.4.0)\n",
+            "Requirement already satisfied: llvmlite<0.35,>=0.34.0.dev0 in /usr/local/lib/python3.7/dist-packages (from numba>=0.43.0->librosa) (0.34.0)\n",
+            "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->librosa) (3.0.6)\n",
+            "Requirement already satisfied: appdirs in /usr/local/lib/python3.7/dist-packages (from pooch>=1.0->librosa) (1.4.4)\n",
+            "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from pooch>=1.0->librosa) (2.23.0)\n",
+            "Requirement already satisfied: six>=1.3 in /usr/local/lib/python3.7/dist-packages (from resampy>=0.2.2->librosa) (1.15.0)\n",
+            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn!=0.19.0,>=0.14.0->librosa) (3.0.0)\n",
+            "Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.7/dist-packages (from soundfile>=0.10.2->librosa) (1.15.0)\n",
+            "Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi>=1.0->soundfile>=0.10.2->librosa) (2.21)\n",
+            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->pooch>=1.0->librosa) (2.10)\n",
+            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->pooch>=1.0->librosa) (3.0.4)\n",
+            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->pooch>=1.0->librosa) (1.24.3)\n",
+            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->pooch>=1.0->librosa) (2021.10.8)\n"
+          ]
+        }
+      ],
+      "source": [
+        "# 安装语音处理依赖\n",
+        "!pip install librosa --user"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 8,
+      "metadata": {
+        "id": "DvN1C1LdweQW"
+      },
+      "outputs": [],
+      "source": [
+        "# 基本库\n",
+        "\n",
+        "import pandas as pd\n",
+        "import numpy as np\n",
+        "\n",
+        "from sklearn.model_selection import train_test_split\n",
+        "from sklearn.metrics import classification_report\n",
+        "from sklearn.model_selection import GridSearchCV\n",
+        "\n",
+        "from sklearn.preprocessing import MinMaxScaler\n",
+        "\n",
+        "\n",
+        "from tensorflow.keras.models import Sequential\n",
+        "from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPool2D, Dropout\n",
+        "from tensorflow.keras.utils import to_categorical \n",
+        "\n",
+        "from sklearn.ensemble import RandomForestClassifier\n",
+        "from sklearn.svm import SVC\n",
+        "\n",
+        "\n",
+        "import os\n",
+        "import librosa\n",
+        "import librosa.display\n",
+        "import glob \n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "0SlgEw5UweQY"
+      },
+      "source": [
+        "## 数据预处理\n",
+        "\n",
+        "特征提取以及数据集的建立\n",
+        "\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 4,
+      "metadata": {
+        "id": "Su3uS8-vweQY"
+      },
+      "outputs": [],
+      "source": [
+        "feature = []\n",
+        "label = []\n",
+        "# 建立类别标签,不同类别对应不同的数字。\n",
+        "label_dict = {'aloe': 0, 'burger': 1, 'cabbage': 2,'candied_fruits':3, 'carrots': 4, 'chips':5,\n",
+        "                  'chocolate': 6, 'drinks': 7, 'fries': 8, 'grapes': 9, 'gummies': 10, 'ice-cream':11,\n",
+        "                  'jelly': 12, 'noodles': 13, 'pickles': 14, 'pizza': 15, 'ribs': 16, 'salmon':17,\n",
+        "                  'soup': 18, 'wings': 19}\n",
+        "# key和value对换\n",
+        "label_dict_inv = {v:k for k,v in label_dict.items()}"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from tqdm import tqdm\n",
+        "def extract_features(parent_dir, sub_dirs, max_file=10, file_ext=\"*.wav\"):\n",
+        "    c = 0\n",
+        "    label, feature = [], []\n",
+        "    for sub_dir in sub_dirs:\n",
+        "        for fn in tqdm(glob.glob(os.path.join(parent_dir, sub_dir, file_ext))[:max_file]): # 遍历数据集的所有文件\n",
+        "            \n",
+        "           # segment_log_specgrams, segment_labels = [], []\n",
+        "            #sound_clip,sr = librosa.load(fn)\n",
+        "            #print(fn)\n",
+        "            label_name = fn.split('/')[-2]\n",
+        "            label.extend([label_dict[label_name]])\n",
+        "            # librosa读取声音wmv\n",
+        "            X, sample_rate = librosa.load(fn,res_type='kaiser_fast')\n",
+        "            mels = np.mean(librosa.feature.melspectrogram(y=X,sr=sample_rate).T,axis=0) # 计算梅尔频谱(mel spectrogram),并把它作为特征\n",
+        "            feature.extend([mels])\n",
+        "            \n",
+        "    return [feature, label]"
+      ],
+      "metadata": {
+        "id": "VL30pSGX6GxC"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 9,
+      "metadata": {
+        "id": "d6oDYvdEweQb",
+        "outputId": "5559d623-cc15-4d64-c165-476934db52a5",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "100%|██████████| 45/45 [00:03<00:00, 11.35it/s]\n",
+            "100%|██████████| 64/64 [00:04<00:00, 14.91it/s]\n",
+            "100%|██████████| 48/48 [00:05<00:00,  8.80it/s]\n",
+            "100%|██████████| 74/74 [00:08<00:00,  8.59it/s]\n",
+            "100%|██████████| 49/49 [00:05<00:00,  9.47it/s]\n",
+            "100%|██████████| 57/57 [00:05<00:00,  9.97it/s]\n",
+            "100%|██████████| 27/27 [00:02<00:00, 10.60it/s]\n",
+            "100%|██████████| 27/27 [00:02<00:00, 11.07it/s]\n",
+            "100%|██████████| 57/57 [00:05<00:00, 11.06it/s]\n",
+            "100%|██████████| 61/61 [00:06<00:00, 10.02it/s]\n",
+            "100%|██████████| 65/65 [00:06<00:00,  9.68it/s]\n",
+            "100%|██████████| 69/69 [00:07<00:00,  9.22it/s]\n",
+            "100%|██████████| 43/43 [00:04<00:00, 10.22it/s]\n",
+            "100%|██████████| 33/33 [00:02<00:00, 11.10it/s]\n",
+            "100%|██████████| 75/75 [00:07<00:00,  9.45it/s]\n",
+            "100%|██████████| 55/55 [00:06<00:00,  8.70it/s]\n",
+            "100%|██████████| 47/47 [00:04<00:00,  9.46it/s]\n",
+            "100%|██████████| 37/37 [00:04<00:00,  8.60it/s]\n",
+            "100%|██████████| 32/32 [00:02<00:00, 14.34it/s]\n",
+            "100%|██████████| 35/35 [00:03<00:00,  9.46it/s]\n"
+          ]
+        }
+      ],
+      "source": [
+        "# 自己更改目录\n",
+        "parent_dir = './train_sample/'\n",
+        "save_dir = \"./\"\n",
+        "folds = sub_dirs = np.array(['aloe','burger','cabbage','candied_fruits',\n",
+        "                             'carrots','chips','chocolate','drinks','fries',\n",
+        "                            'grapes','gummies','ice-cream','jelly','noodles','pickles',\n",
+        "                            'pizza','ribs','salmon','soup','wings'])\n",
+        "\n",
+        "# 获取特征feature以及类别的label\n",
+        "temp = extract_features(parent_dir,sub_dirs,max_file=100)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 10,
+      "metadata": {
+        "id": "-8P75hTKweQc",
+        "outputId": "3f50b6f2-a590-4c0b-f6fb-61663ee30e2d",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
+            "  \"\"\"Entry point for launching an IPython kernel.\n"
+          ]
+        }
+      ],
+      "source": [
+        "\n",
+        "temp = np.array(temp)  #(2, 1000)\n",
+        "data = temp.transpose()"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 11,
+      "metadata": {
+        "id": "Z7jm020wweQd",
+        "outputId": "3f1f41eb-40c4-426a-ef05-4462929b8a89",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "X的特征尺寸是: (1000, 128)\n",
+            "Y的特征尺寸是: (1000,)\n"
+          ]
+        }
+      ],
+      "source": [
+        "# 获取特征 (1000, 128)\n",
+        "X = np.vstack(data[:, 0])\n",
+        "\n",
+        "# 获取标签 (1000, 20)\n",
+        "Y = np.array(data[:, 1])\n",
+        "print('X的特征尺寸是:',X.shape)\n",
+        "print('Y的特征尺寸是:',Y.shape)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 38,
+      "metadata": {
+        "id": "GQ0d4mhdweQd"
+      },
+      "outputs": [],
+      "source": [
+        "# 在Keras库中:to_categorical就是将类别向量转换为二进制(只有0和1)的矩阵类型表示\n",
+        "Y = to_categorical(Y)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 13,
+      "metadata": {
+        "id": "bi7-ecyUweQe",
+        "outputId": "b09bfa64-7f0d-459c-fa9c-f4002f053d36",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "(1000, 128)\n",
+            "(1000, 20)\n"
+          ]
+        }
+      ],
+      "source": [
+        "'''最终数据'''\n",
+        "print(X.shape)\n",
+        "print(Y.shape)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 14,
+      "metadata": {
+        "id": "HE-DBp-1weQf",
+        "outputId": "d1100b6c-9c0e-4691-a3d0-0f0b378f490c",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "训练集的大小 750\n",
+            "测试集的大小 250\n"
+          ]
+        }
+      ],
+      "source": [
+        "# 训练集划分\n",
+        "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state = 1, stratify=Y)\n",
+        "print('训练集的大小',len(X_train))\n",
+        "print('测试集的大小',len(X_test))"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 15,
+      "metadata": {
+        "id": "7YD80EA0weQf"
+      },
+      "outputs": [],
+      "source": [
+        "X_train = X_train.reshape(-1, 16, 8, 1)\n",
+        "X_test = X_test.reshape(-1, 16, 8, 1)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "pSN5FqOZweQf"
+      },
+      "source": [
+        "## 搭建CNN网络¶\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 16,
+      "metadata": {
+        "id": "mK1yzAM8weQg"
+      },
+      "outputs": [],
+      "source": [
+        "model = Sequential()\n",
+        "\n",
+        "# 输入的大小\n",
+        "input_dim = (16, 8, 1)\n",
+        "\n",
+        "model.add(Conv2D(64, (3, 3), padding = \"same\", activation = \"tanh\", input_shape = input_dim))# 卷积层\n",
+        "model.add(MaxPool2D(pool_size=(2, 2)))# 最大池化\n",
+        "model.add(Conv2D(128, (3, 3), padding = \"same\", activation = \"tanh\")) #卷积层\n",
+        "model.add(MaxPool2D(pool_size=(2, 2))) # 最大池化层\n",
+        "model.add(Dropout(0.1))\n",
+        "model.add(Flatten()) # 展开\n",
+        "model.add(Dense(1024, activation = \"tanh\"))\n",
+        "model.add(Dense(20, activation = \"softmax\")) # 输出层:20个units输出20个类的概率\n",
+        "\n",
+        "# 编译模型,设置损失函数,优化方法以及评价标准\n",
+        "model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 17,
+      "metadata": {
+        "id": "npluL5MRweQg",
+        "outputId": "21762dbb-a80d-484b-9e3e-e548569df98e",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Model: \"sequential\"\n",
+            "_________________________________________________________________\n",
+            " Layer (type)                Output Shape              Param #   \n",
+            "=================================================================\n",
+            " conv2d (Conv2D)             (None, 16, 8, 64)         640       \n",
+            "                                                                 \n",
+            " max_pooling2d (MaxPooling2D  (None, 8, 4, 64)         0         \n",
+            " )                                                               \n",
+            "                                                                 \n",
+            " conv2d_1 (Conv2D)           (None, 8, 4, 128)         73856     \n",
+            "                                                                 \n",
+            " max_pooling2d_1 (MaxPooling  (None, 4, 2, 128)        0         \n",
+            " 2D)                                                             \n",
+            "                                                                 \n",
+            " dropout (Dropout)           (None, 4, 2, 128)         0         \n",
+            "                                                                 \n",
+            " flatten (Flatten)           (None, 1024)              0         \n",
+            "                                                                 \n",
+            " dense (Dense)               (None, 1024)              1049600   \n",
+            "                                                                 \n",
+            " dense_1 (Dense)             (None, 20)                20500     \n",
+            "                                                                 \n",
+            "=================================================================\n",
+            "Total params: 1,144,596\n",
+            "Trainable params: 1,144,596\n",
+            "Non-trainable params: 0\n",
+            "_________________________________________________________________\n"
+          ]
+        }
+      ],
+      "source": [
+        "model.summary()"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 18,
+      "metadata": {
+        "id": "aJa_zbQlweQg",
+        "outputId": "ec47225d-e60c-48e6-9583-1a1235d99785",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Epoch 1/20\n",
+            "50/50 [==============================] - 3s 33ms/step - loss: 2.8777 - accuracy: 0.1387 - val_loss: 2.6630 - val_accuracy: 0.1800\n",
+            "Epoch 2/20\n",
+            "50/50 [==============================] - 1s 27ms/step - loss: 2.5199 - accuracy: 0.2360 - val_loss: 2.4908 - val_accuracy: 0.2320\n",
+            "Epoch 3/20\n",
+            "50/50 [==============================] - 1s 26ms/step - loss: 2.2987 - accuracy: 0.2973 - val_loss: 2.4248 - val_accuracy: 0.2880\n",
+            "Epoch 4/20\n",
+            "50/50 [==============================] - 1s 26ms/step - loss: 2.0837 - accuracy: 0.3480 - val_loss: 2.3719 - val_accuracy: 0.3000\n",
+            "Epoch 5/20\n",
+            "50/50 [==============================] - 1s 25ms/step - loss: 1.9110 - accuracy: 0.4187 - val_loss: 2.4793 - val_accuracy: 0.3160\n",
+            "Epoch 6/20\n",
+            "50/50 [==============================] - 1s 25ms/step - loss: 1.7646 - accuracy: 0.4667 - val_loss: 2.3191 - val_accuracy: 0.3520\n",
+            "Epoch 7/20\n",
+            "50/50 [==============================] - 1s 26ms/step - loss: 1.5766 - accuracy: 0.5213 - val_loss: 2.4034 - val_accuracy: 0.3760\n",
+            "Epoch 8/20\n",
+            "50/50 [==============================] - 1s 27ms/step - loss: 1.4424 - accuracy: 0.5573 - val_loss: 2.3988 - val_accuracy: 0.3520\n",
+            "Epoch 9/20\n",
+            "50/50 [==============================] - 1s 26ms/step - loss: 1.3804 - accuracy: 0.5840 - val_loss: 2.4564 - val_accuracy: 0.3480\n",
+            "Epoch 10/20\n",
+            "50/50 [==============================] - 1s 25ms/step - loss: 1.1348 - accuracy: 0.6680 - val_loss: 2.4833 - val_accuracy: 0.3480\n",
+            "Epoch 11/20\n",
+            "50/50 [==============================] - 1s 27ms/step - loss: 1.0722 - accuracy: 0.6773 - val_loss: 2.7686 - val_accuracy: 0.3560\n",
+            "Epoch 12/20\n",
+            "50/50 [==============================] - 1s 25ms/step - loss: 1.0049 - accuracy: 0.7067 - val_loss: 2.7160 - val_accuracy: 0.3800\n",
+            "Epoch 13/20\n",
+            "50/50 [==============================] - 1s 27ms/step - loss: 0.8915 - accuracy: 0.7427 - val_loss: 2.7581 - val_accuracy: 0.4160\n",
+            "Epoch 14/20\n",
+            "50/50 [==============================] - 1s 28ms/step - loss: 0.7965 - accuracy: 0.7627 - val_loss: 3.0566 - val_accuracy: 0.3480\n",
+            "Epoch 15/20\n",
+            "50/50 [==============================] - 1s 26ms/step - loss: 0.7574 - accuracy: 0.7707 - val_loss: 3.2272 - val_accuracy: 0.3800\n",
+            "Epoch 16/20\n",
+            "50/50 [==============================] - 1s 25ms/step - loss: 0.7050 - accuracy: 0.8027 - val_loss: 3.1865 - val_accuracy: 0.3880\n",
+            "Epoch 17/20\n",
+            "50/50 [==============================] - 1s 27ms/step - loss: 0.6664 - accuracy: 0.8307 - val_loss: 3.1371 - val_accuracy: 0.4160\n",
+            "Epoch 18/20\n",
+            "50/50 [==============================] - 1s 27ms/step - loss: 0.5803 - accuracy: 0.8320 - val_loss: 3.2395 - val_accuracy: 0.4000\n",
+            "Epoch 19/20\n",
+            "50/50 [==============================] - 1s 25ms/step - loss: 0.4792 - accuracy: 0.8533 - val_loss: 3.0263 - val_accuracy: 0.4440\n",
+            "Epoch 20/20\n",
+            "50/50 [==============================] - 1s 25ms/step - loss: 0.4119 - accuracy: 0.8947 - val_loss: 3.2198 - val_accuracy: 0.4000\n"
+          ]
+        },
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "<keras.callbacks.History at 0x7f520e3ea590>"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 18
+        }
+      ],
+      "source": [
+        "# 训练模型\n",
+        "model.fit(X_train, Y_train, epochs = 20, batch_size = 15, validation_data = (X_test, Y_test))"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 20,
+      "metadata": {
+        "id": "I3OAlap_weQg"
+      },
+      "outputs": [],
+      "source": [
+        "#预测\n",
+        "\n",
+        "def extract_features(test_dir, file_ext=\"*.wav\"):\n",
+        "    feature = []\n",
+        "    for fn in tqdm(glob.glob(os.path.join(test_dir, file_ext))[:]): # 遍历数据集的所有文件\n",
+        "        X, sample_rate = librosa.load(fn,res_type='kaiser_fast')\n",
+        "        mels = np.mean(librosa.feature.melspectrogram(y=X,sr=sample_rate).T,axis=0) # 计算梅尔频谱(mel spectrogram),并把它作为特征\n",
+        "        feature.extend([mels])\n",
+        "    return feature\n",
+        "    "
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 21,
+      "metadata": {
+        "id": "9pHaRW8UweQh",
+        "outputId": "8e37a519-b182-4d21-fed4-9481ce479812",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "100%|██████████| 2000/2000 [03:14<00:00, 10.26it/s]\n"
+          ]
+        }
+      ],
+      "source": [
+        "X_test = extract_features('./test_a/')"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 22,
+      "metadata": {
+        "id": "8ODoB2LIweQh"
+      },
+      "outputs": [],
+      "source": [
+        "X_test = np.vstack(X_test)\n",
+        "predictions = model.predict(X_test.reshape(-1, 16, 8, 1))"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 23,
+      "metadata": {
+        "id": "aumnZDzIweQi"
+      },
+      "outputs": [],
+      "source": [
+        "preds = np.argmax(predictions, axis = 1)\n",
+        "preds = [label_dict_inv[x] for x in preds]\n",
+        "\n",
+        "path = glob.glob('./test_a/*.wav')\n",
+        "result = pd.DataFrame({'name':path, 'label': preds})\n",
+        "\n",
+        "result['name'] = result['name'].apply(lambda x: x.split('/')[-1])\n",
+        "result.to_csv('submit.csv',index=None)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        ""
+      ],
+      "metadata": {
+        "id": "lLRtsuwE21Aq"
+      }
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 24,
+      "metadata": {
+        "id": "hJjx2q_8weQi",
+        "outputId": "80d85d36-f65a-40bd-fdb4-2f8fafd2cdbc",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "2000\n"
+          ]
+        }
+      ],
+      "source": [
+        "!ls ./test_a/*.wav | wc -l"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 25,
+      "metadata": {
+        "id": "ToTUBB8MweQi",
+        "outputId": "41b370ab-ff95-41bf-96cb-f1c2c7dce3ac",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        }
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "2001 submit.csv\n"
+          ]
+        }
+      ],
+      "source": [
+        "!wc -l submit.csv"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "rUl6daGYweQi"
+      },
+      "outputs": [],
+      "source": [
+        ""
+      ]
+    }
+  ],
+  "metadata": {
+    "language_info": {
+      "name": "python"
+    },
+    "orig_nbformat": 4,
+    "colab": {
+      "name": "1.baseline.ipynb",
+      "provenance": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}