当前位置: 首页 > article >正文

【深度学习】用LSTM写诗,生成式的方式写诗系列之一

Epoch 4: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=18.5, loss=5.8]
[5] loss: 5.828, accuracy: 18.389 , lr:0.001000
Epoch 5: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=19.2, loss=5.68]
[6] loss: 5.739, accuracy: 18.732 , lr:0.001000
Epoch 6: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=19.6, loss=5.57]
[7] loss: 5.629, accuracy: 19.197 , lr:0.001000
Epoch 7: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=19.9, loss=5.45]
[8] loss: 5.517, accuracy: 19.745 , lr:0.001000
Epoch 8: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=20.4, loss=5.38]
[9] loss: 5.402, accuracy: 20.316 , lr:0.001000
Epoch 9: 100%|████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=21, loss=5.27]
[10] loss: 5.299, accuracy: 20.887 , lr:0.001000
Epoch 10: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=21.8, loss=5.16]
[11] loss: 5.210, accuracy: 21.427 , lr:0.001000
Epoch 11: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=22.2, loss=5.1]
[12] loss: 5.136, accuracy: 21.923 , lr:0.001000
Epoch 12: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=22.5, loss=5.06]
[13] loss: 5.071, accuracy: 22.379 , lr:0.001000
Epoch 13: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=22.9, loss=5.02]
[14] loss: 5.011, accuracy: 22.819 , lr:0.001000
Epoch 14: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=23.3, loss=4.94]
[15] loss: 4.959, accuracy: 23.212 , lr:0.001000
Epoch 15: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=23.7, loss=4.91]
[16] loss: 4.910, accuracy: 23.564 , lr:0.001000
Epoch 16: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=24.3, loss=4.82]
[17] loss: 4.862, accuracy: 23.914 , lr:0.001000
Epoch 17: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=24, loss=4.83]
[18] loss: 4.818, accuracy: 24.228 , lr:0.001000
Epoch 18: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=24.7, loss=4.77]
[19] loss: 4.775, accuracy: 24.523 , lr:0.001000
Epoch 19: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=24.6, loss=4.73]
[20] loss: 4.734, accuracy: 24.808 , lr:0.001000
Epoch 20: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=25, loss=4.69]
[21] loss: 4.694, accuracy: 25.090 , lr:0.001000
Epoch 21: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=25, loss=4.71]
[22] loss: 4.657, accuracy: 25.346 , lr:0.001000
Epoch 22: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=25.8, loss=4.62]
[23] loss: 4.619, accuracy: 25.587 , lr:0.001000
Epoch 23: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=25.9, loss=4.59]
[24] loss: 4.584, accuracy: 25.825 , lr:0.001000
Epoch 24: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=26.3, loss=4.52]
[25] loss: 4.549, accuracy: 26.078 , lr:0.001000
Epoch 25: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=26.3, loss=4.53]
[26] loss: 4.516, accuracy: 26.280 , lr:0.001000
Epoch 26: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=26.6, loss=4.49]
[27] loss: 4.483, accuracy: 26.517 , lr:0.001000
Epoch 27: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=26.8, loss=4.46]
[28] loss: 4.451, accuracy: 26.746 , lr:0.001000
Epoch 28: 100%|███████████████████████████████████████████████████████| 63/63 [1:00:07<00:00, 57.26s/batch, acc=27.1, loss=4.41]
[29] loss: 4.422, accuracy: 26.937 , lr:0.001000
Epoch 29: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=27.2, loss=4.38]
[30] loss: 4.389, accuracy: 27.182 , lr:0.001000
Epoch 30: 100%|████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=27, loss=4.4]
[31] loss: 4.361, accuracy: 27.371 , lr:0.001000
Epoch 31: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=27.5, loss=4.34]
[32] loss: 4.332, accuracy: 27.589 , lr:0.001000
Epoch 32: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=27.6, loss=4.31]
[33] loss: 4.304, accuracy: 27.791 , lr:0.001000
Epoch 33: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.89batch/s, acc=27.9, loss=4.28]
[34] loss: 4.277, accuracy: 28.014 , lr:0.001000
Epoch 34: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=28, loss=4.26]
[35] loss: 4.248, accuracy: 28.200 , lr:0.001000
Epoch 35: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=28.6, loss=4.22]
[36] loss: 4.222, accuracy: 28.433 , lr:0.001000
Epoch 36: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=28.3, loss=4.21]
[37] loss: 4.196, accuracy: 28.625 , lr:0.001000
Epoch 37: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=29.1, loss=4.16]
[38] loss: 4.169, accuracy: 28.858 , lr:0.001000
Epoch 38: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=29.2, loss=4.13]
[39] loss: 4.142, accuracy: 29.056 , lr:0.001000
Epoch 39: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=29, loss=4.13]
[40] loss: 4.116, accuracy: 29.282 , lr:0.001000
Epoch 40: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=29.5, loss=4.12]
[41] loss: 4.092, accuracy: 29.477 , lr:0.001000
Epoch 41: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=29.7, loss=4.08]
[42] loss: 4.066, accuracy: 29.716 , lr:0.001000
Epoch 42: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=29.8, loss=4.06]
[43] loss: 4.042, accuracy: 29.918 , lr:0.001000
Epoch 43: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=30.5, loss=3.99]
[44] loss: 4.016, accuracy: 30.146 , lr:0.001000
Epoch 44: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=30.2, loss=4.01]
[45] loss: 3.990, accuracy: 30.398 , lr:0.001000
Epoch 45: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=30.6, loss=3.96]
[46] loss: 3.968, accuracy: 30.607 , lr:0.001000
Epoch 46: 100%|█████████████████████████████████████████████████████████| 63/63 [40:05<00:00, 38.19s/batch, acc=30.6, loss=3.96]
[47] loss: 3.945, accuracy: 30.814 , lr:0.001000
Epoch 47: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=30.9, loss=3.94]
[48] loss: 3.918, accuracy: 31.073 , lr:0.001000
Epoch 48: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=31.1, loss=3.91]
[49] loss: 3.893, accuracy: 31.322 , lr:0.001000
Epoch 49: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=32, loss=3.86]
[50] loss: 3.869, accuracy: 31.574 , lr:0.001000
Epoch 50: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=31.2, loss=3.9]
[51] loss: 3.846, accuracy: 31.811 , lr:0.001000
Epoch 51: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=31.7, loss=3.85]
[52] loss: 3.823, accuracy: 32.042 , lr:0.001000
Epoch 52: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=32.4, loss=3.8]
[53] loss: 3.798, accuracy: 32.325 , lr:0.001000
Epoch 53: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=32.2, loss=3.8]
[54] loss: 3.776, accuracy: 32.552 , lr:0.001000
Epoch 54: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=32.4, loss=3.79]
[55] loss: 3.755, accuracy: 32.794 , lr:0.001000
Epoch 55: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=32.8, loss=3.75]
[56] loss: 3.729, accuracy: 33.081 , lr:0.001000
Epoch 56: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=32.8, loss=3.74]
[57] loss: 3.708, accuracy: 33.301 , lr:0.001000
Epoch 57: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=33.8, loss=3.68]
[58] loss: 3.683, accuracy: 33.597 , lr:0.001000
Epoch 58: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=33.5, loss=3.67]
[59] loss: 3.661, accuracy: 33.838 , lr:0.001000
Epoch 59: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=34, loss=3.65]
[60] loss: 3.639, accuracy: 34.106 , lr:0.001000
Epoch 60: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=34, loss=3.65]
[61] loss: 3.619, accuracy: 34.350 , lr:0.001000
Epoch 61: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=34.1, loss=3.64]
[62] loss: 3.595, accuracy: 34.632 , lr:0.001000
Epoch 62: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=34.6, loss=3.57]
[63] loss: 3.573, accuracy: 34.872 , lr:0.001000
Epoch 63: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=34.8, loss=3.58]
[64] loss: 3.553, accuracy: 35.140 , lr:0.001000
Epoch 64: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=35.1, loss=3.53]
[65] loss: 3.531, accuracy: 35.394 , lr:0.001000
Epoch 65: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.81batch/s, acc=34.8, loss=3.56]
[66] loss: 3.512, accuracy: 35.636 , lr:0.001000
Epoch 66: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=35.1, loss=3.55]
[67] loss: 3.490, accuracy: 35.896 , lr:0.001000
Epoch 67: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=36.1, loss=3.49]
[68] loss: 3.471, accuracy: 36.147 , lr:0.001000
Epoch 68: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=36, loss=3.48]
[69] loss: 3.451, accuracy: 36.413 , lr:0.001000
Epoch 69: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=36.5, loss=3.44]
[70] loss: 3.436, accuracy: 36.595 , lr:0.001000
Epoch 70: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=36.5, loss=3.45]
[71] loss: 3.412, accuracy: 36.873 , lr:0.001000
Epoch 71: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=36.2, loss=3.44]
[72] loss: 3.393, accuracy: 37.130 , lr:0.001000
Epoch 72: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=36.4, loss=3.44]
[73] loss: 3.375, accuracy: 37.342 , lr:0.001000
Epoch 73: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=37.1, loss=3.4]
[74] loss: 3.355, accuracy: 37.608 , lr:0.001000
Epoch 74: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=37.2, loss=3.37]
[75] loss: 3.337, accuracy: 37.853 , lr:0.001000
Epoch 75: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.81batch/s, acc=37.9, loss=3.35]
[76] loss: 3.318, accuracy: 38.105 , lr:0.001000
Epoch 76: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=37.5, loss=3.35]
[77] loss: 3.303, accuracy: 38.282 , lr:0.001000
Epoch 77: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=37.9, loss=3.31]
[78] loss: 3.285, accuracy: 38.523 , lr:0.001000
Epoch 78: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=38.1, loss=3.3]
[79] loss: 3.267, accuracy: 38.738 , lr:0.001000
Epoch 79: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=38.9, loss=3.28]
[80] loss: 3.250, accuracy: 38.972 , lr:0.001000
Epoch 80: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=38.6, loss=3.27]
[81] loss: 3.230, accuracy: 39.248 , lr:0.001000
Epoch 81: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=39.1, loss=3.22]
[82] loss: 3.216, accuracy: 39.435 , lr:0.001000
Epoch 82: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=38.8, loss=3.25]
[83] loss: 3.197, accuracy: 39.675 , lr:0.001000
Epoch 83: 100%|██████████████████████████████████████████████████████████| 63/63 [00:38<00:00,  1.62batch/s, acc=39.7, loss=3.2]
[84] loss: 3.180, accuracy: 39.914 , lr:0.001000
Epoch 84: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=39.4, loss=3.2]
[85] loss: 3.165, accuracy: 40.108 , lr:0.001000
Epoch 85: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=40.1, loss=3.17]
[86] loss: 3.152, accuracy: 40.277 , lr:0.001000
Epoch 86: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=39.9, loss=3.18]
[87] loss: 3.135, accuracy: 40.508 , lr:0.001000
Epoch 87: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=40.4, loss=3.14]
[88] loss: 3.118, accuracy: 40.736 , lr:0.001000
Epoch 88: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=40.5, loss=3.14]
[89] loss: 3.104, accuracy: 40.918 , lr:0.001000
Epoch 89: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=40.7, loss=3.11]
[90] loss: 3.093, accuracy: 41.061 , lr:0.001000
Epoch 90: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=40.8, loss=3.1]
[91] loss: 3.074, accuracy: 41.315 , lr:0.001000
Epoch 91: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=41.5, loss=3.06]
[92] loss: 3.057, accuracy: 41.559 , lr:0.001000
Epoch 92: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=41.1, loss=3.09]
[93] loss: 3.043, accuracy: 41.745 , lr:0.001000
Epoch 93: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=41.5, loss=3.06]
[94] loss: 3.029, accuracy: 41.924 , lr:0.001000
Epoch 94: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=42, loss=3.03]
[95] loss: 3.015, accuracy: 42.133 , lr:0.001000
Epoch 95: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=41.6, loss=3.04]
[96] loss: 3.001, accuracy: 42.302 , lr:0.001000
Epoch 96: 100%|██████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=42, loss=3]
[97] loss: 2.988, accuracy: 42.483 , lr:0.001000
Epoch 97: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=42.9, loss=2.96]
[98] loss: 2.972, accuracy: 42.694 , lr:0.001000
Epoch 98: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=42.1, loss=3.01]
[99] loss: 2.964, accuracy: 42.804 , lr:0.001000
Epoch 99: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=42.2, loss=3.01]
[100] loss: 2.953, accuracy: 42.973 , lr:0.001000
Finished Training using %.3f seconds 6896.93013882637

先训练了100轮次,后面应该还能增长,但是不等了

数据初探:

class DictObj(object):
    def __init__(self, map):
        self.map = map
        
    def __getattr__(self, attr):
        if attr in self.map:
            return self.map[attr]
        else:
            raise AttributeError("No such attribute: " + attr)
        
Config = DictObj({
    'poem_path':os.path.join(base_dir, "tang.npz"),
    "tensorboard_path":os.path.join(base_dir, "tensorboard"),
    "model_save_path":os.path.join(base_dir,"modelDict"),
    "embedding_dim":100,
    "hidden_dim":1024,
    "lr":0.001,
    "LSTM_layers":2,
    'batch_size':512,
    'epochs':500,
    'dropout':0.2,
    'ealier_stop':10,
    'device':torch.device('cuda' if torch.cuda.is_available() else 'cpu')
})
def view_data(poem_path):
    datas = np.load(poem_path, allow_pickle=True)
    data = datas['data'] #(57580,125)
    ix2word = datas['ix2word'].item() # datas['word2ix'].item() 8293
    word2ix = datas['word2ix'].item() # datas['word2ix'].item() 8293
    word_data = np.zeros((1,data.shape[1]), dtype = str) # 将所有的0 转化成 ''
    # 看一下其中一行的数据是什么?
    row = np.random.randint(0, data.shape[0]) # 随机选一行,左闭右开没问题
    print(data[row])
    for i in range(data.shape[1]):
        word_data[0][i] = ix2word[data[row][i]]
    print(word_data)

view_data(Config.poem_path)

数据处理:

class PoemDataset(Dataset):
    def __init__(self, poem_path, seq_len):
        super().__init__()
        # np 文件的地址
        self.poem_path = poem_path
        # 序列长度,48 是认为规定的,也可以是其它值,因为大部分是5言或者7言,加上表达就是 6,或8, 取48确保是整句话
        self.seq_len = seq_len
        self.poem_data, self.ix2word, self.word2ix = self.get_raw_data()
        self.no_space_data = self.filter_space()
        print("no_space_data len:", self.no_space_data[0:200])
    
    def __len__(self):
        return len(self.no_space_data)//(self.seq_len)

    def __getitem__(self, idx):
        txt = self.no_space_data[idx*self.seq_len:(idx+1)*self.seq_len]
        label = self.no_space_data[idx*self.seq_len+1:(idx+1)*self.seq_len+1]
        return torch.LongTensor(txt), torch.LongTensor(label)
    
    def filter_space(self):
        # 7197500 个文本
        tensor_data = torch.from_numpy(self.poem_data).view(-1)
        no_space_data = []  
        for i in range(tensor_data.shape[0]):
            word_idx = tensor_data[i].item()
            if word_idx!= 8292:
                no_space_data.append(word_idx)
        return no_space_data

        
    def get_raw_data(self):
        datas = np.load(self.poem_path, allow_pickle=True)
        data = datas['data']
        ix2word = datas['ix2word'].item()
        word2ix = datas['word2ix'].item()
        return data, ix2word, word2ix
poem_dataset = PoemDataset(Config.poem_path, 96)
[8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8291 5428 6933 3469 7066 3465 6407
 8248 7009   82 7435  925 3469 3576  232  786 5272 2296 7066 4807 6103
 6663 2958 2003 2173   28 7066 1987 8061 4299  848 4874 7435 8290]
[['<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'
  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'
  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'
  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'
  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'
  '<' '<' '冬' '月' '内' ',' '无' '叶' '艾' '枝' '枯' '。' '草' '内' '急' '寻' '蛇' '床'
  '子' ',' '烧' '烟' '入' '中' '自' '消' '除' ',' '速' '救' '免' '灾' '虞' '。' '<']]
no_space_data len: [8291, 6731, 4770, 1787, 8118, 7577, 7066, 4817, 648, 7121, 1542, 6483, 7435, 7686, 2889, 1671, 5862, 1949, 7066, 2596, 4785, 3629, 1379, 2703, 7435, 6064, 6041, 4666, 4038, 4881, 7066, 4747, 1534, 70, 3788, 3823, 7435, 4907, 5567, 201, 2834, 1519, 7066, 782, 782, 2063, 2031, 846, 7435, 8290, 8291, 2309, 2596, 6483, 2260, 7316, 7066, 6332, 5274, 2125, 5029, 7792, 7435, 4186, 8087, 7047, 6622, 6933, 7066, 6134, 3564, 3766, 6920, 6157, 7435, 7086, 4770, 5849, 4776, 4981, 7066, 4857, 2649, 3020, 332, 1727, 7435, 7458, 7294, 3465, 5149, 1671, 7066, 2834, 6000, 3942, 3534, 1534, 7435, 4102, 7460, 758, 3961, 3374, 7066, 7904, 6811, 4449, 2121, 6802, 7435, 6182, 27, 7912, 1756, 7440, 7066, 201, 7909, 8118, 201, 4662, 7435, 7824, 1508, 3154, 152, 5862, 7066, 7976, 6043, 258, 47, 7878, 7435, 8290, 8291, 3495, 70, 7113, 4839, 5237, 7066, 65, 3941, 2031, 2260, 5418, 7435, 411, 6773, 2878, 4686, 482, 7066, 1989, 5617, 4992, 8245, 676, 7435, 4236, 1418, 4915, 7686, 7363, 7066, 5708, 7541, 7440, 5237, 2192, 7435, 3114, 5913, 7989, 3069, 1845, 7066, 7047, 3534, 4921, 6622, 6933, 7435, 1664, 2260, 2003, 4816, 7151, 7066, 5036, 2219, 5849, 4898, 174, 7435, 201, 7228, 222]

因为有空格,啥的,要先吧空格之类的去掉。

def show_dataset():
    idx,label = poem_dataset[0]
    for id in idx:
        print(poem_dataset.ix2word[id.item()], end=' ')
    print("\n")
    for la in label:
        print(poem_dataset.ix2word[la.item()], end=' ')
    '''
    <START> 度 门 能 不 访 , 冒 雪 屡 西 东 。 已 想 人 如 玉 , 遥 怜 马 似 骢 。 乍 迷 金 谷 路 , 稍 变 上 阳 宫 。 还 比 相 思 意 , 纷 纷 正 满 空 。 <EOP> <START> 逍 遥 东 城 隅 , 双 树 寒 葱 蒨 。 广 庭 流 华 月 , 高 阁 凝 余 霰 。 杜 门 非 养 素 , 抱 疾 阻 良 䜩 。 孰 谓 无 他 人 , 思 君 岁 

    度 门 能 不 访 , 冒 雪 屡 西 东 。 已 想 人 如 玉 , 遥 怜 马 似 骢 。 乍 迷 金 谷 路 , 稍 变 上 阳 宫 。 还 比 相 思 意 , 纷 纷 正 满 空 。 <EOP> <START> 逍 遥 东 城 隅 , 双 树 寒 葱 蒨 。 广 庭 流 华 月 , 高 阁 凝 余 霰 。 杜 门 非 养 素 , 抱 疾 阻 良 䜩 。 孰 谓 无 他 人 , 思 君 岁 云
    '''
# 构建模型
class PoemModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        embeds = self.embedding(input)
        batch_size,seq_len,embedding_dim = embeds.shape
        if hidden is None:
            h0 = torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim).to(Config.device)
            c0 = torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim).to(Config.device)
        else:
            h0,c0 = hidden
        output, hidden = self.lstm(embeds, (h0, c0))
        # output = torch.tanh(self.dropout(self.fc1(output)))
        output = self.fc(output)
        return output, hidden

vocab_size = len(poem_dataset.word2ix)
model = PoemModel(vocab_size, Config.embedding_dim, Config.hidden_dim, Config.LSTM_layers, Config.dropout).to(Config.device)
input_data, label_data = next(iter(dataloader))
print(input_data.shape, label_data.shape)
output, hidden = model(input_data.to(Config.device))
# output.shape torch.Size([1024, 96, 8293]) hidden[0].shape torch.Size([3, 1024, 1024]) hidden[1].shape torch.Size([3, 1024, 1024])  label_data.shape torch.Size([1024, 96])
a = 1

def accuracy(output, label_data):
    pred = output.argmax(dim=2)
    correct = (pred == label_data).sum().item()
    total = label_data.numel()
    return correct / total * 100
# 训练模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=Config.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=1)
def train(model, dataloader, criterion, optimizer, scheduler, epochs):
    if not os.path.exists(Config.model_save_path):
        os.makedirs(Config.model_save_path)
    best_acc = 0.0
    early_stop = 0
    start_time = time.time()
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        last_acc = 0.0
        with tqdm(dataloader, unit="batch") as tepoch:
            for input_data, label_data in tepoch:
                tepoch.set_description(f"Epoch {epoch}")
                input_data, label_data = input_data.to(Config.device), label_data.to(Config.device)
                optimizer.zero_grad()
                output, hidden = model(input_data)
                current_acc = accuracy(output, label_data)
                running_acc += current_acc
                loss = criterion(output.view(-1, vocab_size), label_data.view(-1))
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                tepoch.set_postfix(loss=loss.item(), acc=current_acc)
        scheduler.step()
        last_acc = running_acc / len(dataloader)
        if last_acc > best_acc:
            best_acc = last_acc
            torch.save(model.state_dict(), os.path.join(Config.model_save_path, "best_model.pth"))
        else:
            early_stop += 1
        torch.save(model.state_dict(), os.path.join(Config.model_save_path, "last_model.pth"))
        print('[%d] loss: %.3f, accuracy: %.3f , lr:%.6f'  % (epoch + 1, running_loss / len(dataloader), last_acc,scheduler.get_last_lr()[0]))

        if early_stop >= Config.ealier_stop:
            print("Early Stop")
            print("Best Accuracy: %.3f" % best_acc)
            break
    print('Finished Training using %.3f seconds', time.time() - start_time)
train(model, dataloader, criterion, optimizer, scheduler, Config.epochs)  

模型构建以及训练如上,

现在看 500轮次,50个忍耐度的效果比较好

Epoch 145: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.81batch/s, acc=61.6, loss=1.62]
[146] loss: 1.580, accuracy: 62.374 , lr:0.001000
Epoch 146: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.80batch/s, acc=62.5, loss=1.57]
[147] loss: 1.581, accuracy: 62.360 , lr:0.001000
Epoch 147: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=61.9, loss=1.61]
[148] loss: 1.582, accuracy: 62.324 , lr:0.001000
Early Stop
Best Accuracy: 62.397
Finished Training using %.3f seconds 1205.000694513321

使用效果

不满意的地方,写死96=seq_len 是不对的。
应该是 配合 padding 使用,并mask padding来指导损失 @todo, 下一篇文章我会搞定!

import torch
from train03 import Config
from train03 import PoemModel

from train03 import PoemDataset
import os

poem_dataset = PoemDataset(Config.poem_path, 96)

vocab_size = len(poem_dataset.word2ix)
model = PoemModel(
    vocab_size, 
    Config.embedding_dim, 
    Config.hidden_dim, 
    Config.LSTM_layers, 
    Config.dropout).to(Config.device)
model.load_state_dict(torch.load(os.path.join(Config.model_save_path, "best_model.pth")))

def generate(model, start_words, ix2word, word2ix, device):
     results = list(start_words)
     start_words_len = len(start_words)
     # 第一个词语是<START>
     input = torch.Tensor([word2ix['<START>']]).view(1, 1).long()
     # 最开始的隐状态初始为0矩阵
     # torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim)
     hidden = torch.zeros((2,Config.LSTM_layers * 1, 1, Config.hidden_dim), dtype=torch.float32).to(Config.device)
     input = input.to(Config.device)
     hidden = hidden.to(Config.device)
     model.eval()
     with torch.no_grad():
        for i in range(48):
            output, hidden = model(input, hidden)
            # 如果在给定的句首中,input为句首中的下一个字
            if i < start_words_len:
                w = results[i]
                input = input.data.new([word2ix[w]]).view(1, 1)
            else:
                top_index = output.data[0].topk(1)[1][0].item()
                w = ix2word[top_index]
                results.append(w)
                input = input.data.new([top_index]).view(1, 1)
            if w == '<EOP>':
                del results[-1]
                break
        return results

雨 余 虚 馆 竹 阴 清 , 独 坐 寒 窗 昼 未 醒 。 云 布 远 村 红 叶 返 , 水 深 秋 竹 翠 梢 寒 。 泉 声 入 阁 慙 嘉 石 , 山 色 题 诗 好 赋 诗 。

但是我有一点不太理解。
他是输入一个字,输出一个字,这一点好像不妥。不应该是 输入一个 生成1个, 然后输入两个,生成1个,然后输入3个生成1个么。。。 大神请指教一下吧。


http://www.kler.cn/a/373665.html

相关文章:

  • G1原理—5.G1垃圾回收过程之Mixed GC
  • IntelliJ IDEA中Maven项目的配置、创建与导入全攻略
  • java项目之房屋租赁系统源码(springboot+mysql+vue)
  • 79 Openssl3.0 RSA公钥加密数据
  • 【机器学习:八、逻辑回归】
  • imageio 图片转mp4 保存mp4
  • 下一代「自动化测试框架」WebdriverIO
  • STM32--STM32 微控制器详解
  • unity3d————Mathf.Lerp() 函数详解
  • 从0开始深度学习(21)——读写数据和GPU
  • 【Nas】X-DOC:Mac mini 安装 ZeroTier 并替换 planet 实现内网穿透
  • 人工智能中的机器学习和模型评价
  • RNN在训练中存在的问题
  • 常见的机器学习模型汇总
  • C++ 复习记录(个人记录)
  • 基于Multisim的四位抢答器设计与仿真
  • 数据结构,问题 A: 翻转字符串
  • 野火鲁班猫4 (RK3588)系统配置
  • Mybatis 统计sql运行时间
  • 嵌入式linux跨平台基于mongoose的TCP C++类的源码
  • 如何在macOS开发中给 PKG 签名和公证(productsign+notarytool)
  • Vue中path和component属性
  • JAVA基础练习题
  • 攻防世界 MISC miao~详解
  • 无人机测绘遥感技术算法概述!
  • Q-learning原理及代码实现