Karol 3 anni fa
commit
da8add46f9
9 ha cambiato i file con 366 aggiunte e 0 eliminazioni
  1. 3 0
      .gitignore
  2. 21 0
      LICENSE
  3. 87 0
      README.md
  4. 92 0
      chatgpt.go
  5. 132 0
      context.go
  6. 12 0
      errors.go
  7. 12 0
      format.go
  8. 5 0
      go.mod
  9. 2 0
      go.sum

+ 3 - 0
.gitignore

@@ -0,0 +1,3 @@
+.idea
+chatgpt_test.go
+test

+ 21 - 0
LICENSE

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Shihao
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 87 - 0
README.md

@@ -0,0 +1,87 @@
+# chatgpt
+
+> chartgpt client for golang
+
+## Usege
+
+Download the package first:
+
+```shell
+go get github.com/solywsh/chatgpt
+```
+
+Conversation without context:
+
+```go
+package main
+
+import (
+	"fmt"
+	"github.com/solywsh/chatgpt"
+	"time"
+)
+
+func main() {
+	// The timeout is used to control the situation that the session is in a long and multi session situation.
+	// If it is set to 0, there will be no timeout. Note that a single request still has a timeout setting of 30s.
+	chat := chatgpt.New("openai_key", "user_id(not required)", 30*time.Second)
+	defer chat.Close()
+	//
+	//select {
+	//case <-chat.GetDoneChan():
+	//	fmt.Println("time out/finish")
+	//}
+	question := "你认为2022年世界杯的冠军是谁?"
+	fmt.Printf("Q: %s\n", question)
+	answer, err := chat.Chat(question)
+	if err != nil {
+		fmt.Println(err)
+	}
+	fmt.Printf("A: %s\n", answer)
+
+	//Q: 你认为2022年世界杯的冠军是谁?
+	//A: 这个问题很难回答,因为2022年世界杯还没有开始,所以没有人知道冠军是谁。
+}
+
+```
+
+Conversation with context:
+
+```golang
+package main
+
+import (
+	"fmt"
+	"github.com/solywsh/chatgpt"
+	"time"
+)
+
+func main() {
+	chat := chatgpt.New("openai_key", "user_id(not required)", 10*time.Second)
+	defer chat.Close()
+	//select {
+	//case <-chat.GetDoneChan():
+	//	fmt.Println("time out")
+	//}
+	question := "现在你是一只猫,接下来你只能用\"喵喵喵\"回答."
+	fmt.Printf("Q: %s\n", question)
+	answer, err := chat.ChatWithContext(question)
+	if err != nil {
+		fmt.Println(err)
+	}
+	fmt.Printf("A: %s\n", answer)
+	question = "你是一只猫吗?"
+	fmt.Printf("Q: %s\n", question)
+	answer, err = chat.ChatWithContext(question)
+	if err != nil {
+		fmt.Println(err)
+	}
+	fmt.Printf("A: %s\n", answer)
+
+	// Q: 现在你是一只猫,接下来你只能用"喵喵喵"回答.
+	// A: 喵喵喵!
+	// Q: 你是一只猫吗?
+	// A: 喵喵~!
+}
+```
+

+ 92 - 0
chatgpt.go

@@ -0,0 +1,92 @@
+package chatgpt
+
+import (
+	"context"
+	gogpt "github.com/sashabaranov/go-gpt3"
+	"time"
+)
+
+type ChatGPT struct {
+	client         *gogpt.Client
+	ctx            context.Context
+	userId         string
+	maxQuestionLen int
+	maxText        int
+	maxAnswerLen   int
+	timeOut        time.Duration // 超时时间, 0表示不超时
+	doneChan       chan struct{}
+	cancel         func()
+
+	ChatContext *ChatContext
+}
+
+func New(ApiKey, UserId string, timeOut time.Duration) *ChatGPT {
+	var ctx context.Context
+	var cancel func()
+	if timeOut == 0 {
+		ctx, cancel = context.WithCancel(context.Background())
+	} else {
+		ctx, cancel = context.WithTimeout(context.Background(), timeOut)
+	}
+	timeOutChan := make(chan struct{}, 1)
+	go func() {
+		<-ctx.Done()
+		timeOutChan <- struct{}{} // 发送超时信号,或是提示结束,用于聊天机器人场景,配合GetTimeOutChan() 使用
+	}()
+	return &ChatGPT{
+		client:         gogpt.NewClient(ApiKey),
+		ctx:            ctx,
+		userId:         UserId,
+		maxQuestionLen: 1024, // 最大问题长度
+		maxAnswerLen:   1024, // 最大答案长度
+		maxText:        4096, // 最大文本 = 问题 + 回答, 接口限制
+		timeOut:        timeOut,
+		doneChan:       timeOutChan,
+		cancel: func() {
+			cancel()
+		},
+		ChatContext: NewContext(),
+	}
+}
+func (c *ChatGPT) Close() {
+	c.cancel()
+}
+
+func (c *ChatGPT) GetDoneChan() chan struct{} {
+	return c.doneChan
+}
+
+func (c *ChatGPT) SetMaxQuestionLen(maxQuestionLen int) int {
+	if maxQuestionLen > c.maxText-c.maxAnswerLen {
+		maxQuestionLen = c.maxText - c.maxAnswerLen
+	}
+	c.maxQuestionLen = maxQuestionLen
+	return c.maxQuestionLen
+}
+
+func (c *ChatGPT) Chat(question string) (answer string, err error) {
+	question = question + "."
+	if len(question) > c.maxQuestionLen {
+		return "", OverMaxQuestionLength
+	}
+	if len(question)+c.maxAnswerLen > c.maxText {
+		question = question[:c.maxText-c.maxAnswerLen]
+	}
+	req := gogpt.CompletionRequest{
+		Model:            gogpt.GPT3TextDavinci003,
+		MaxTokens:        c.maxAnswerLen,
+		Prompt:           question,
+		Temperature:      0.9,
+		TopP:             1,
+		N:                1,
+		FrequencyPenalty: 0,
+		PresencePenalty:  0.5,
+		User:             c.userId,
+		Stop:             []string{},
+	}
+	resp, err := c.client.CreateCompletion(c.ctx, req)
+	if err != nil {
+		return "", err
+	}
+	return formatAnswer(resp.Choices[0].Text), err
+}

+ 132 - 0
context.go

@@ -0,0 +1,132 @@
+package chatgpt
+
+import (
+	"fmt"
+	gogpt "github.com/sashabaranov/go-gpt3"
+	"strings"
+)
+
+var (
+	DefaultAiRole    = "AI"
+	DefaultHumanRole = "Human"
+
+	DefaultCharacter  = []string{"helpful", "creative", "clever", "friendly", "lovely", "talkative"}
+	DefaultBackground = "The following is a conversation with AI assistant. The assistant is %s"
+	DefaultPreset     = "\n%s: Cześć!\n%s: Jestem asystentem AI, co mogę dla Ciebie zrobić?"
+)
+
+type ChatContext struct {
+	background  string // 对话背景
+	preset      string // 预设对话
+	maxSeqTimes int    // 最大对话次数
+	aiRole      *role  // AI角色
+	humanRole   *role  // 人类角色
+
+	old        []conversation // 旧对话
+	restartSeq string         // 重新开始对话的标识
+	startSeq   string         // 开始对话的标识
+
+	seqTimes int // 对话次数
+}
+
+type conversation struct {
+	role   *role
+	prompt string
+}
+
+type role struct {
+	name string
+}
+
+func NewContext() *ChatContext {
+	return &ChatContext{
+		aiRole:      &role{name: DefaultAiRole},
+		humanRole:   &role{name: DefaultHumanRole},
+		background:  fmt.Sprintf(DefaultBackground, strings.Join(DefaultCharacter, ", ")+"."),
+		maxSeqTimes: 10,
+		preset:      fmt.Sprintf(DefaultPreset, DefaultHumanRole, DefaultAiRole),
+		old:         []conversation{},
+		seqTimes:    0,
+		restartSeq:  "\n" + DefaultHumanRole + ": ",
+		startSeq:    "\n" + DefaultAiRole + ": ",
+	}
+}
+
+func (c *ChatContext) SetHumanRole(role string) {
+	c.humanRole.name = role
+	c.restartSeq = "\n" + c.humanRole.name + ": "
+}
+
+func (c *ChatContext) SetAiRole(role string) {
+	c.aiRole.name = role
+	c.startSeq = "\n" + c.aiRole.name + ": "
+}
+
+func (c *ChatContext) SetMaxSeqTimes(times int) {
+	c.maxSeqTimes = times
+}
+
+func (c *ChatContext) GetMaxSeqTimes() int {
+	return c.maxSeqTimes
+}
+
+func (c *ChatContext) SetBackground(background string) {
+	c.background = background
+}
+
+func (c *ChatContext) SetPreset(preset string) {
+	c.preset = preset
+}
+
+func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
+	question = question + "."
+	if len(question) > c.maxQuestionLen {
+		return "", OverMaxQuestionLength
+	}
+	if c.ChatContext.seqTimes >= c.ChatContext.maxSeqTimes {
+		return "", OverMaxSequenceTimes
+	}
+	var promptTable []string
+	promptTable = append(promptTable, c.ChatContext.background)
+	promptTable = append(promptTable, c.ChatContext.preset)
+	for _, v := range c.ChatContext.old {
+		if v.role == c.ChatContext.humanRole {
+			promptTable = append(promptTable, "\n"+v.role.name+": "+v.prompt)
+		} else {
+			promptTable = append(promptTable, v.role.name+": "+v.prompt)
+		}
+	}
+	promptTable = append(promptTable, "\n"+c.ChatContext.restartSeq+question)
+	prompt := strings.Join(promptTable, "\n")
+	prompt += c.ChatContext.startSeq
+	if len(prompt) > c.maxText-c.maxAnswerLen {
+		return "", OverMaxTextLength
+	}
+	req := gogpt.CompletionRequest{
+		Model:            gogpt.GPT3TextDavinci003,
+		MaxTokens:        c.maxAnswerLen,
+		Prompt:           prompt,
+		Temperature:      0.9,
+		TopP:             1,
+		N:                1,
+		FrequencyPenalty: 0,
+		PresencePenalty:  0.5,
+		User:             c.userId,
+		Stop:             []string{c.ChatContext.aiRole.name + ":", c.ChatContext.humanRole.name + ":"},
+	}
+	resp, err := c.client.CreateCompletion(c.ctx, req)
+	if err != nil {
+		return "", err
+	}
+	resp.Choices[0].Text = formatAnswer(resp.Choices[0].Text)
+	c.ChatContext.old = append(c.ChatContext.old, conversation{
+		role:   c.ChatContext.humanRole,
+		prompt: question,
+	})
+	c.ChatContext.old = append(c.ChatContext.old, conversation{
+		role:   c.ChatContext.aiRole,
+		prompt: resp.Choices[0].Text,
+	})
+	c.ChatContext.seqTimes++
+	return resp.Choices[0].Text, nil
+}

+ 12 - 0
errors.go

@@ -0,0 +1,12 @@
+package chatgpt
+
+import "errors"
+
+// OverMaxSequenceTimes 超过最大对话时间
+var OverMaxSequenceTimes = errors.New("maximum conversation times exceeded")
+
+// OverMaxTextLength 超过最大文本长度
+var OverMaxTextLength = errors.New("maximum text length exceeded")
+
+// OverMaxQuestionLength 超过最大问题长度
+var OverMaxQuestionLength = errors.New("maximum question length exceeded")

+ 12 - 0
format.go

@@ -0,0 +1,12 @@
+package chatgpt
+
+func formatAnswer(answer string) string {
+	for len(answer) > 0 {
+		if answer[:1] == "\n" || answer[0] == ' ' {
+			answer = answer[1:]
+		} else {
+			break
+		}
+	}
+	return answer
+}

+ 5 - 0
go.mod

@@ -0,0 +1,5 @@
+module github.com/solywsh/chatgpt
+
+go 1.19
+
+require github.com/sashabaranov/go-gpt3 v0.0.0-20221202105456-0f9f4aa343ad

+ 2 - 0
go.sum

@@ -0,0 +1,2 @@
+github.com/sashabaranov/go-gpt3 v0.0.0-20221202105456-0f9f4aa343ad h1:xuO52/X2TArHVbDISrxXasZrDhdsSGABAR5sJrePVGU=
+github.com/sashabaranov/go-gpt3 v0.0.0-20221202105456-0f9f4aa343ad/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=