百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

我的模型我做主02——训练自己的大模型:简易入门指南

itomcoil 2025-05-02 18:57 7 浏览

模型训练往往需要较高的配置,为了满足友友们的好奇心,这里我们不要内存,不要gpu,用最简单的方式,让大家感受一下什么是模型训练。基于你的硬件配置,我们可以设计一个完全在CPU上运行的简易模型训练方案。以下是具体步骤:

环境准备

这里以mac为例,其他系统原理类似,也可不使用miniconda,本文主要集中在训练代码和推理代码上。

安装Miniconda(推荐)

# 下载Miniconda
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh
# 安装
bash Miniconda3-latest-MacOSX-arm64.sh

创建虚拟环境

conda create -n tinyai python=3.9
conda activate tinyai

安装PyTorch

# 安装pytorch,也可通过官网选择合适的安装语句
pip install torch torchvision torchaudio

超简易模型训练方案

纯CPU训练微型文本模型

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# 超简单数据集
class TextDataset(Dataset):
    def __init__(self):
        self.data = ["hello world", "deep learning", "apple silicon", "metal acceleration"]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        text = self.data[idx]
        # 简单字符级编码
        x = [ord(c) for c in text[:-1]]
        y = [ord(c) for c in text[1:]]
        return torch.tensor(x), torch.tensor(y)

# 超简单模型
class TinyLM(nn.Module):
    def __init__(self, vocab_size=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 32)
        self.rnn = nn.RNN(32, 64, batch_first=True)
        self.fc = nn.Linear(64, vocab_size)
    
    def forward(self, x):
        x = self.embed(x)
        out, _ = self.rnn(x)
        return self.fc(out)
def custom_collate_fn(batch):
    # batch是包含多个(__getitem__返回结果)的列表
    x_batch, y_batch = zip(*batch)
    
    # 找到本批次中的最大长度
    max_len = max(len(x) for x in x_batch)
    
    # 填充每个样本
    x_padded = torch.stack([
        torch.cat([x, torch.zeros(max_len - len(x), dtype=torch.long)]) 
        for x in x_batch
    ])
    
    y_padded = torch.stack([
        torch.cat([y, torch.zeros(max_len - len(y), dtype=torch.long)]) 
        for y in y_batch
    ])
    
    return x_padded, y_padded



# 训练设置
dataset = TextDataset()
# loader = DataLoader(dataset, batch_size=2)
# 然后修改DataLoader
loader = DataLoader(dataset, batch_size=2, collate_fn=custom_collate_fn)
model = TinyLM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练循环
for epoch in range(10):
    for x, y in loader:
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output.view(-1, 128), y.view(-1))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# 保存模型和tokenizer(虽然我们用的是简单编码)
torch.save(model.state_dict(), 'tinylm.pth')

# 同时保存词汇表信息(这里只是示例,实际字符编码是固定的)
import pickle
with open('char_vocab.pkl', 'wb') as f:
    pickle.dump({'vocab_size': 128}, f)  # ASCII字符范围

模型推理

创建一个新的Python文件inference.py:

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 32)
        self.rnn = nn.RNN(32, 64, batch_first=True)
        self.fc = nn.Linear(64, vocab_size)
    
    def forward(self, x):
        x = self.embed(x)
        out, _ = self.rnn(x)
        return self.fc(out)

# 加载模型
model = TinyLM()
model.load_state_dict(torch.load('tinylm.pth'))
model.eval()  # 设置为评估模式

# 简单的字符编码函数
def text_to_tensor(text):
    return torch.tensor([[ord(c) for c in text]])

# 推理函数
def generate_text(start_str, length=10):
    input_seq = text_to_tensor(start_str)
    hidden = None
    
    for _ in range(length):
        with torch.no_grad():  # 禁用梯度计算
            output = model(input_seq)
            # 获取最后一个字符的预测
            last_char_logits = output[0, -1, :]
            # 选择概率最高的字符
            predicted_char = torch.argmax(last_char_logits).item()
            # 添加到输入序列中
            input_seq = torch.cat([
                input_seq, 
                torch.tensor([[predicted_char]])
            ], dim=1)
    
    # 将数字转换回字符
    generated_text = ''.join([chr(c) for c in input_seq[0].tolist()])
    return generated_text

# 使用示例
if __name__ == "__main__":
    while True:
        seed = input("输入起始字符串(或输入q退出): ")
        if seed.lower() == 'q':
            break
        generated = generate_text(seed, length=20)
        print(f"生成结果: {generated}")

运行推理示例

输入起始字符串(或输入q退出): hello
生成结果: hello world deep lear

结果分析

生成的文本无意义

主要原因是模型太小或训练不足,后续的解决方案是增加训练epoch或扩大模型,当然本文的目的就是让大家熟悉一下基本的模型训练和推理流程。

相关推荐

python学习——029统计【字典的列表】里符合条件元素数量

方式一:使用普通for循环students=[{'name':'jack','points':100},{'...

玩转Python—列表使用教程

上一讲给大家介绍了Python的列表,今天继续给大家介绍Python中列表的使用。1.列表的元素的赋值#实例>>>num=[1,2,3,4,5,6,7,7,8,8,9]>...

python学习——030如何将列表中的元素按要求分类

方法一:原代码方法(使用while循环结合pop方法)创建了numbers列表的一个副本temp_numbers,在循环中对temp_numbers进行操作,保证原列表numbers的内...

Python 条件判断教程

Let'sdivein!1.基本的if语句(ifStatement)在Python中,if语句用来根据条件执行代码块。当条件为True(真)时,代码块将被执行;否则,...

list列表基本操作

【实验目的】1、掌握list列表的基本操作【实验原理】列表是Python中最基本的数据结构,列表是最常用的Python数据类型,列表的数据项不需要具有相同的类型。列表中的每个元素都分配一个数字-它...

Python变量类型判断方法详解

技术背景在Python编程中,变量类型的判断是一项基础且重要的操作。由于Python是动态类型语言,变量的类型在运行时才能确定,因此在开发过程中,我们常常需要明确变量的类型,以便进行相应的操作。同时,...

基础知识详解:Python any()函数的使用方法(含示例代码)

前言:今天为大家带来的内容是:Pythonany()函数的使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,要是喜欢本文章内容的朋友,一定不忘点赞转发关注收藏不迷...

新手学Python避坑,学习效率狂飙! 八、Python 布尔值判断

布尔值判断系统知识在Python里,布尔类型仅有两个值:True和False,它们常被用于条件判断。下面从几个方面展开介绍:1.布尔运算逻辑与(and):只有当两个操作数都为True时,...

Python新手必看|列表操作全攻略(增删改查+切片+推导式)

一、为什么列表是Python的"万能容器"?作为最灵活的序列类型,列表支持:存储任意类型数据(数字/字符串/对象)动态增减元素快速索引访问丰富的内置方法python#创建包含不同数据...

Python列表集合操作介绍?

列表和集合是在Python编程中比较常用,而且比较常见的两种数据结构,他们有着各自的特点以及使用场景,下面我们就来详细的介绍一下列表和集合在实际使用中的一些操作对比。List(列表)列表操作的特点就是...

对Python中序列的个人理解

序列指的是一块连续内存空间存放多个值,在python中,序列类型包括字符串、列表、字典、元组、集合。其中包括字符串、列表、元组为有序序列,字典、集合属于无序序列。例如:#定义变量name,并赋值为字...

python 数据结构之列表(list)简述及演示

(一)list列表定义使用中括号[],里面元素可以是任意类型,包括列表本身,也可以是字典、元组等。(二)在Python中,第一个列表元素的索引为0,而不是1。(三)要访问列表的任何元素,都可将...

Python 列表(List)详解

列表是Python中最基本、最常用的数据结构之一,它是一个有序的、可变的元素集合。一、列表的基本操作1.创建列表#空列表empty_list=[]empty_list=list()...

自学Python第九天——操作列表

自学Python第九天——操作列表一、遍历整个列表1、需要对列表中的每个元素进行相同的操作时,使用for循环例如前几张我们学过的一些内容,想将列表中的每个元素打印出来,需要不断地重复代码,而且因列表长...

Python 列表(List)完全指南:数据操作的利器

在Python中,列表(list)是一种可变序列(mutablesequence),它允许我们存储和操作一组有序数据(ordereddata)。本教程将从基础定义(basicdefiniti...