当前位置: 首页 > article >正文

python:music21 构建 LSTM+GAN 模型生成爵士风格音乐

keras_lstm_gan_midi.py 这是一个结合 LSTM 和 GAN 生成爵士风格音乐的完整Python脚本。这个实现包含音乐特征提取、对抗训练机制和MIDI生成功能:

import numpy as np
from music21 import converter, instrument, note, chord, stream
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (LSTM, Dense, Dropout, 
         Input, Embedding, Reshape, Bidirectional, Conv1D, Flatten)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

# 配置参数
MIDI_FILE = "jazz_swing.mid"  # 爵士训练数据
SEQ_LENGTH = 32               # 序列长度
NOISE_DIM = 100               # 噪声向量维度
BATCH_SIZE = 64
EPOCHS = 2000
STEPS_PER_EPOCH = 50
SAVE_INTERVAL = 100

class JazzGAN:
    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
        self.seq_length = SEQ_LENGTH
        self.noise_dim = NOISE_DIM
        
        # 构建模型
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()
        self.gan = self.build_gan()
        
        # 配置优化器
        self.d_optimizer = Adam(0.0002, 0.5)
        self.g_optimizer = Adam(0.0001, 0.5)
        
    def build_generator(self):
        """构建LSTM生成器"""
        model = Sequential([
            Input(shape=(self.noise_dim,)),
            Dense(256),
            Reshape((1, 256)),
            LSTM(512, return_sequences=True),
            Dropout(0.3),
            LSTM(256),
            Dense(self.vocab_size, activation='softmax')
        ])
        return model
    
    def build_discriminator(self):
        """构建CNN-LSTM判别器""" 
        model = Sequential([
            Input(shape=(self.seq_length,)),
            Embedding(self.vocab_size, 128),
            Conv1D(64, 3, strides=2, padding='same'),
            Bidirectional(LSTM(128)),
            Dense(64, activation='relu'),
            Dropout(0.2),
            Dense(1, activation='sigmoid')
        ])
        return model
    
    def build_gan(self):
        """组合GAN模型"""
        self.discriminator.trainable = False
        gan_input = Input(shape=(self.noise_dim,))
        generated_seq = self.generator(gan_input)
        validity = self.discriminator(generated_seq)
        return Model(gan_input, validity)
    
    def preprocess_midi(self, file_path):
        """处理MIDI数据"""
        notes = []
        midi = converter.parse(file_path)
        
        print("Extracting notes...")
        for element in midi.flat.notes:
            if isinstance(element, note.Note):
                notes.append(str(element.pitch))
            elif isinstance(element, chord.Chord):
                notes.append('.'.join(str(n) for n in element.normalOrder))
        
        # 创建字典映射
        unique_notes = sorted(set(notes))
        self.note_to_int = {n:i for i, n in enumerate(unique_notes)}
        self.int_to_note = {i:n for i, n in enumerate(unique_notes)}
        self.vocab_size = len(unique_notes)
        
        # 转换为整数序列
        int_sequence = [self.note_to_int[n] for n in notes]
        
        # 创建训练序列
        sequences = []
        for i in range(len(int_sequence) - self.seq_length):
            seq = int_sequence[i:i+self.seq_length]
            sequences.append(seq)
            
        return np.array(sequences)
    
    def train(self, X_train):
        # 标签平滑
        valid = np.ones((BATCH_SIZE, 1)) * 0.9
        fake = np.zeros((BATCH_SIZE, 1))
        
        for epoch in range(EPOCHS):
            # 训练判别器
            idx = np.random.randint(0, X_train.shape[0], BATCH_SIZE)
            real_seqs = X_train[idx]
            
            noise = np.random.normal(0, 1, (BATCH_SIZE, self.noise_dim))
            gen_seqs = self.generator.predict(noise)
            
            d_loss_real = self.discriminator.train_on_batch(real_seqs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_seqs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            
            # 训练生成器
            noise = np.random.normal(0, 1, (BATCH_SIZE, self.noise_dim))
            g_loss = self.gan.train_on_batch(noise, valid)
            
            # 输出训练进度
            if epoch % 100 == 0:
                print(f"Epoch {epoch} | D Loss: {d_loss[0]} | G Loss: {g_loss}")
                
            # 保存示例
            if epoch % SAVE_INTERVAL == 0:
                self.generate_and_save(epoch)
                
    def generate_and_save(self, epoch):
        """生成示例音乐"""
        noise = np.random.normal(0, 1, (1, self.noise_dim))
        generated = self.generator.predict(noise)
        generated_indices = np.argmax(generated, axis=-1)
        
        # 转换为音符
        output_notes = []
        for idx in generated_indices[0]:
            output_notes.append(self.int_to_note[idx])
            
        # 创建MIDI流
        midi_stream = stream.Stream()
        
        for pattern in output_notes:
            # 添加爵士和弦扩展
            if '.' in pattern:
                notes_in_chord = pattern.split('.')
                chord_notes = [note.Note(int(n)) for n in notes_in_chord]
                # 添加7th扩展
                if len(chord_notes) == 3:
                    root = chord_notes[0].pitch
                    chord_notes.append(root.transpose(10))
                new_chord = chord.Chord(chord_notes)
                midi_stream.append(new_chord)
            else:
                new_note = note.Note(int(pattern))
                new_note.storedInstrument = instrument.Saxophone()
                midi_stream.append(new_note)
                
        # 添加摇摆节奏
        self.add_swing_rhythm(midi_stream)
        
        midi_stream.write('midi', fp=f'jazz_gan_{epoch}.mid')
        print(f"Saved sample at epoch {epoch}")
    
    def add_swing_rhythm(self, stream_obj):
        """添加爵士摇摆节奏"""
        for i, n in enumerate(stream_obj.flat.notes):
            if i % 2 == 1:  # 每两个音符中的第二个
                n.offset += 0.08  # 轻微延迟
                if isinstance(n, note.Note):
                    n.duration.quarterLength *= 0.9

if __name__ == "__main__":
    # 初始化并预处理数据
    gan = JazzGAN(vocab_size=0)  # 初始占位
    sequences = gan.preprocess_midi(MIDI_FILE)
    
    # 训练GAN
    gan.train(sequences)

使用说明:

  1. 准备数据

    • 需要至少10-20个爵士MIDI文件(建议包含钢琴三重奏、大乐队等风格)

    • 推荐数据集:Jazzomat Research Project
      MeloSpySuite 捆绑了3个命令行工具:用于转换旋律文件格式的 melconv、用于特征提取的 melfeature 和用于模式挖掘的 melpat。

  2. 环境配置
    pip install tensorflow
    pip install music

实现亮点说明:

1.模型架构改进

  1. 混合架构设计

    • 生成器:使用双层LSTM结构,适合处理音乐时序特征

    • 判别器:结合CNN和BiLSTM,有效捕捉局部与全局模式

    • 加入Embedding层处理离散音符符号

  2. 爵士特征增强

    # 在和弦中自动添加7th音
    if len(chord_notes) == 3:
        root = chord_notes[0].pitch
        chord_notes.append(root.transpose(10))
    
    # 摇摆节奏处理
    n.offset += 0.08
    n.duration.quarterLength *= 0.9
    
  3. 训练优化技巧

    • 使用标签平滑(Label Smoothing)提升判别器鲁棒性

    • 分离生成器和判别器的学习率(0.0001 vs 0.0002)

    • 定期保存生成样本监控训练进度

  4.  训练建议

    • 使用GPU加速(至少需要8GB显存)

    • 初始阶段设置EPOCHS=500 以获得较好效果

    • 调整SEQ_LENGTH(32-64)匹配音乐片段长度

  5. 生成样本后处理

    • 使用DAW(如 Ableton Live)添加爵士乐器的真实音色

    • 人工调整和声进行确保功能性(II-V-I等典型进行)

性能优化方向

  1. 模型架构改进

    # 在生成器加入注意力机制
    from tensorflow.keras.layers import Attention
    
    def build_generator(self):
        inputs = Input(shape=(self.noise_dim,))
        x = Dense(256)(inputs)
        x = Reshape((1, 256))(x)
        x = LSTM(512, return_sequences=True)(x)
        x = Attention()([x, x])  # 自注意力
        x = LSTM(256)(x)
        outputs = Dense(self.vocab_size, activation='softmax')(x)
        return Model(inputs, outputs)
    

    2.数据增强: 

    # 实时数据增强
    def augment_sequence(seq):
        # 随机转调
        if np.random.rand() > 0.5:
            shift = np.random.randint(-3, 4)
            seq = (seq + shift) % self.vocab_size
        # 随机节奏缩放
        return seq
    

该脚本生成的爵士音乐将具备以下特征:

  • 复杂的和弦扩展(7th、9th、11th)

  • 摇摆节奏(Swing Feel)

  • 即兴化的旋律走向

  • 符合爵士和声进行规则(如替代和弦使用)

建议配合使用MIDI效果器(如iReal Pro的和声引擎)进行后期处理,可以获得更专业的爵士乐效果。


http://www.kler.cn/a/596938.html

相关文章:

  • SpringBoot+VUE(Ant Design Vue)实现图片下载预览功能
  • 仿函数 VS 函数指针实现回调
  • 存算分离是否真的有必要?从架构之争到 Doris 实战解析
  • 关于网络中的超参数小记
  • RTOS系列文章(17)-- 为什么RTOS选择PendSV实现任务切换?(从硬件机制到RTOS设计的终极答案)
  • NocoBase 本周更新汇总:优化表格区块的列和操作
  • Vue 中的日期格式化实践:从原生 Date 到可视化展示!!!
  • 青少年编程与数学 02-011 MySQL数据库应用 10课题、记录的操作
  • 【微服务架构】SpringCloud(二):Eureka原理、服务注册、Euraka单独使用
  • 蓝桥杯备考:二分答案之路标设置
  • 掌握新编程语言的秘诀:利用 AI 快速上手 Python、Go、Java 和 Rust
  • AI大白话(六):强化学习——AI如何通过“试错“成为大师?
  • 隋卞做 隋卞一探 视频下载
  • 配置DHCP(centos+OUS)
  • QHDBO基于量子计算和多策略融合的蜣螂优化算法
  • Fiddler抓包工具最快入门
  • 人工智能之数学基础:矩阵条件数在线性方程组求解中的应用
  • 律师解读《无人驾驶航空器飞行管理暂行条例》第二十二条
  • illustrate:一款蛋白/核酸结构快速渲染为“卡通风格”的小工具
  • Vue学习笔记集--路由