GithubHelp home page GithubHelp logo

Comments (6)

vaxilicaihouxian avatar vaxilicaihouxian commented on May 22, 2024
@Slf4j
@Service
@Component
public class StoreService {

    @Autowired
    private StoreDao storeDao;

    @Autowired
    private StoreServiceImpl storeService;

    public Store create(Store store) {
        Assert.notNull(store, "store can't be null");
        store.setName(StrUtils.emptyToNull(store.getName()));
        store.setPrice(store.getPrice() == null ? 0.0 : store.getPrice());
        if (store.getPrice() < 100) {
            throw new BizException("商品价格必须大于100");
        }
        if ("001".equals(store.getType())) {
            if (store.getPrice() < 200) {
                throw new BizException("商品类型为 '001',价格必须大于200");
            }
        }
        if (storeDao.selectByName(store.getName()) != null) {
            throw new BizException("商品名称不能重复");
        }
        storeDao.insert(store);
        return storeService.selectById(store.getId());
    }

    public Store update(Store store) {
        Assert.notNull(store, "store can't be null");
        store.setName(StrUtils.emptyToNull(store.getName()));
        store.setPrice(store.getPrice() == null ? 0.0 : store.getPrice());
        if (store.getPrice() < 100) {
            throw new BizException("商品价格必须大于100");
        }
        if ("001".equals(store.getType())) {
            if (store.getPrice() < 200) {
                throw new BizException("商品类型为 '001',价格必须大于200");
            }
        }
        if (storeDao.selectByName(store.getName()) != null) {
            throw new BizException("商品名称不能重复");
        }
        storeDao.updateById(store);
        return storeService.selectById(store.getId());
    }

    public Store delete(String id) {
        Store store = storeDao.selectById(id);
        storeDao.deleteById(id);
        return store;
    }

    public Store selectById(String id) {
        Store store = storeDao.selectById(id);
        return store;
    }

    public Store selectById(Store store) {
        Store store1 = storeDao.selectById(store.getId());
        return store1;
    }

    public List<Store> selectAll() {
        return storeDao.selectAll();
    }
}

我用的int4量化的版本,用chatglm.cpp 跑的,在macbook pro m2乞丐版上

from codegeex2.

vaxilicaihouxian avatar vaxilicaihouxian commented on May 22, 2024

你要不试试在你的prompt最后再加一个\n ?我发现只有一个\n结尾的时候经常出现重复生成同样代码的问题

from codegeex2.

Stanislas0 avatar Stanislas0 commented on May 22, 2024

Environment

- OS: macos Ventura 13.2.1
- Python: 3.11
- Transformers: 4.30.2
- PyTorch: 2.0.1
- CUDA Support: False

Current Behavior

使用 Mac M2Max 进行推理异常:

  1. 内存最高吃到94G;
  2. 要求 Java 语言,推理结果 Python;
  3. 简单prompt推理时长超长(几十秒到3分钟);
  4. 复杂prompt经常不会出结果(10分钟);
  5. 错误的、重复的推理结果。

代码 demo

model_path = "/xxxxxxxxx"
model_id = 'ZhipuAI/codegeex2-6b'

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).half().to("mps")
model = model.eval()
# remember adding a language tag for better performance
prompt = "// language: java\n// 使用Mybatis-plus 的分页查询用户\n"
# prompt = "language: Python\n# write a bubble sort function\n"
# prompt = "language: Java\n# write a bubble sort function\n"
# prompt = "language: Java\n# 使用Mybatis-plus 的分页查询用户\n"
#prompt = "# language: Java\n# 使用Mybatis-plus 写一个关于【商城Service】的业务代码,商城的 Service 命名为 StoreService.工具:1. 字符串处理使用hutool的StrUtils; 2. 抛异常使用hutool的Assert; 3. 业务异常使用 BizException; 实体类有字段 :String id;String name;Double price;String type;业务-【新增商品】,业务规则:1. 必填名称;2. 价格必须大于100;3. 如果商品类型为 '001',价格必须大于200;4. 商品名称不能重复。\n"
# prompt = "# language:Java\n# 写一个冒泡排序函数"
inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_length=888)
response = tokenizer.decode(outputs[0])
print(response)

具体案例如下:

Case 1:官方 demo 的 prompt = “# language: Python\n# write a bubble sort function\n”

可推理出结果,时间大约10秒。

Case 2:官方 demo 改为 java prompt = “# language: Java\n# write a bubble sort function\n”

写出了Python, 时间大约10秒。

language: Java
# write a bubble sort function
def bubble_sort(arr):
    for i in range(len(arr) - 1):
        for j in range(len(arr) - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return arr
print(bubble_sort([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))

Case 3:prompt = “# language:Java\n# 写一个冒泡排序函数”

推理出结果了,但是要求是 Java,写成了 Python,而且中文有大量冗余。

# languageJava
# 写个冒泡排序函数
# 冒泡排序:
# 1.比较相邻的元素。如果第个比第二个大,就交换他们两个。
# 2.对每对相邻元素作同样的工作,从开始第对到结尾的最后对。这步做完后,最后的元素会是最大的数。
# 3.针对所有的元素重复以上的步骤,除了最后个。
# 4.持续每次对越来越少的元素重复上面的步骤,直到没有任何对数字需要比较。
# 冒泡排序的原理:
# 1.比较相邻的元素。如果第个比第二个大,就交换他们两个。
# 2.对每对相邻元素作同样的工作,从开始第对到结尾的最后对。这步做完后,最后的元素会是最大的数。
# 3.针对所有的元素重复以上的步骤,除了最后个。
# 4.持续每次对越来越少的元素重复上面的步骤,直到没有任何对数字需要比较。
# 冒泡排序的代码实现:
def bubble_sort(alist):
    for i in range(len(alist) - 1, 0, -1):
        for j in range(i):
            if alist[j] > alist[j + 1]:
                alist[j], alist[j + 1] = alist[j + 1], alist[j]
    return alist
alist = [54, 26, 93, 17, 77, 31, 44, 55, 20]
print(bubble_sort(alist))

Case4:prompt = "# language:Java\n# 冒泡排序“

内存吃到90G,3分钟不出结果

Case5:prompt = "使用Mybatis-plus 的分页查询用户"

  • idea 插件:表现正常
  • 本地执行:内存吃到90G, 70%卡死,30%可出结果并且结果正常。

Case6:带有大量上下文,有业务场景的长 prompt

prompt = "# language: Java\n# 使用Mybatis-plus 写一个关于【商城Service】的业务代码,商城的 Service 命名为 StoreService.工具:1. 字符串处理使用hutool的StrUtils; 2. 抛异常使用hutool的Assert; 3. 业务异常使用 BizException; 实体类有字段 :String id;String name;Double price;String type;业务-【新增商品】,业务规则:1. 必填名称;2. 价格必须大于100;3. 如果商品类型为 '001',价格必须大于200;4. 商品名称不能重复。\n"

  • idea 的插件:表现良好,
  • 网友(cuda 3090):基本秒出结果,执行结果正常。
  • 本地执行时:内存吃到94G,最长执行了10分钟不出结果,只出过一次结果,显示如下:
Loading checkpoint shards: 100%|██████████| 7/7 [00:05<00:00,  1.35it/s]
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
/Users/nacol/Projects/llm/CodeGeeX2/venv/lib/python3.11/site-packages/transformers/generation/utils.py:2419: UserWarning: MPS: no support for int64 min/max ops, casting it to int32 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/ReduceOps.mm:1271.)
  if unfinished_sequences.max() == 0:
# language: Java
# 使用Mybatis-plus 写一个关于【商城Service】的业务代码,商城的 Service 命名为 StoreService.工具:1. 字符串处理使用hutool的StrUtils; 2. 抛异常使用hutool的Assert; 3. 业务异常使用 BizException; 实体类有字段 :String id;String name;Double price;String type;业务-【新增商品】,业务规则:1. 必填名称;2. 价格必须大于100;3. 如果商品类型为 '001',价格必须大于200;4. 商品名称不能重复。
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.log.Log;
import cn.hutool.log.LogFactory;
import cn.hutool.

Prompt写得有问题,开源的 CodeGeeX2-6B 是一个基座代码模型,它的使用方式是偏补全的,要按照某种语言一般的编程习惯使用就可以了。比如prompt需要使用相应语言的注释符号,Python用"# [prompt]",Java则应该用“// [prompt]”。也可以加一些关键字来引导模型生成函数或类,比如Java用“// [prompt]\npublic class”

from codegeex2.

NacolZero avatar NacolZero commented on May 22, 2024

Prompt写得有问题,开源的 CodeGeeX2-6B 是一个基座代码模型,它的使用方式是偏补全的,要按照某种语言一般的编程习惯使用就可以了。比如prompt需要使用相应语言的注释符号,Python用"# [prompt]",Java则应该用“// [prompt]”。也可以加一些关键字来引导模型生成函数或类,比如Java用“// [prompt]\npublic class”

感谢你的指出,
我修正了 prompt 如下:
prompt = "// language: Java\n// 使用Mybatis-plus 的分页查询用户\n\npublic class"
依然有以下问题:

  1. 推理速度极慢(1~3分钟);
  2. 内存耗用极高 40G;
  3. 输出结果有时候会异常,将我的prompt输出;
  4. 输出结果有时候正确。

from codegeex2.

NacolZero avatar NacolZero commented on May 22, 2024

你要不试试在你的prompt最后再加一个\n ?我发现只有一个\n结尾的时候经常出现重复生成同样代码的问题

加了,但是依然没解决生成速度、吃内存等问题。
应该是其它问题。

from codegeex2.

Stanislas0 avatar Stanislas0 commented on May 22, 2024

Prompt写得有问题,开源的 CodeGeeX2-6B 是一个基座代码模型,它的使用方式是偏补全的,要按照某种语言一般的编程习惯使用就可以了。比如prompt需要使用相应语言的注释符号,Python用"# [prompt]",Java则应该用“// [prompt]”。也可以加一些关键字来引导模型生成函数或类,比如Java用“// [prompt]\npublic class”

感谢你的指出, 我修正了 prompt 如下: prompt = "// language: Java\n// 使用Mybatis-plus 的分页查询用户\n\npublic class" 依然有以下问题:

  1. 推理速度极慢(1~3分钟);
  2. 内存耗用极高 40G;
  3. 输出结果有时候会异常,将我的prompt输出;
  4. 输出结果有时候正确。
  1. 推理速度慢是正常的,毕竟是CPU不是GPU,性能差距是很大的。
  2. 内存占用高是实现问题,没有针对mac做过优化,你可以试一下开启use_cache=True。
  3. 输出中本来就会带prompt部分,需要自己做截取。
  4. 输出结果和prompt有关,和超参数设置也有关系。

from codegeex2.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.