Emergent generative agents
Revision | 35fbd054cec56752630a2152d93e3ac3e10e3830 (tree)
---|---
Date | 2023-04-15 01:38:46
Author | Corbin <cds@corb...>
Committer | Corbin
Give sentence indices an inigo behavior.
The exponentially-decaying behavior appears to work well with the
given parsimony metric. I spitballed the exponent, based on a couple of
thought files from some bots.
Temperature really should be lower, so I tweaked that too.
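
For orientation, here is a minimal sketch of the pruning maths this commit adds to the common module below. The constants (the 1/50 exponent, the 1/5 rate for the decaying index) are the spitballed values from the diff, not tuned:

    import random

    def parsimony(s):
        # Approaches 1 as the thought gets longer, so longer thoughts
        # are more likely to be pruned.
        return 1 - 2 ** -(len(s) * (1 / 50))

    for thought in ("short.", "a somewhat longer thought about IRC",
                    "a very long rumination, repeated " * 3):
        print(len(thought), round(parsimony(thought), 3))

    # Candidate selection: an exponentially-decaying offset from the end
    # of the database, so recent thoughts are examined most often.
    print([int(random.expovariate(1 / 5)) for _ in range(10)])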
@@ -1,6 +1,7 @@
 #!/usr/bin/env nix-shell
 #! nix-shell -i python3 -p python3Packages.irc python3Packages.faiss python3Packages.transformers python3Packages.torch
 
+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from heapq import nsmallest
@@ -10,13 +11,10 @@ import random
 import re
 import sys
 
-from faiss import IndexFlatL2
-import numpy as np
-
 from irc.bot import SingleServerIRCBot
 from irc.strings import lower
 
-from common import Log, Timer, breakAt
+from common import Log, SentenceIndex, breakAt
 from gens.trans import Flavor, HFGen
 from gens.camelid import CamelidGen
 
@@ -46,29 +44,6 @@ def load_character(path):
     with open(os.path.join(path, "character.json"), "r") as handle:
         return json.load(handle)
 
-class SentenceIndex:
-    def __init__(self, path, dimensions):
-        self.path = path
-        self.index = IndexFlatL2(dimensions)
-
-    def load(self):
-        with open(self.path, "r") as handle:
-            data = json.load(handle)
-        self.db = list(data.items())
-        self.index.add(np.array([row[1] for row in self.db], dtype="float32"))
-
-    def save(self):
-        with open(self.path, "w") as f: json.dump(dict(self.db), f)
-
-    def search(self, embedding, k):
-        with Timer("%d nearest neighbors" % k):
-            D, I = self.index.search(np.array([embedding], dtype="float32"), k)
-            return [self.db[i][0] for i in I[0] if i >= 0]
-
-    def add(self, s, embedding):
-        self.index.add(np.array([embedding], dtype="float32"))
-        self.db.append((s, embedding))
-
 logpath = sys.argv[2]
 character = load_character(logpath)
 startingChannels = character.pop("startingChannels")
@@ -88,7 +63,7 @@ max_context_length = gen.contextLength()
 thought_index = SentenceIndex(os.path.join(logpath, "thoughts.json"),
                               llama_gen.embedding_width)
 thought_index.load()
-print("~ Thought index:", thought_index.index.ntotal, "thoughts")
+print("~ Thought index:", thought_index.size(), "thoughts")
 
 executor = ThreadPoolExecutor(max_workers=1)
 
@@ -108,6 +83,7 @@ class Agent(SingleServerIRCBot):
         self.startingChannels = startingChannels
         self.logpath = logpath
         self.logs = {}
+        self.willReply = defaultdict(bool)
 
     def on_join(self, c, e):
         channel = e.target
@@ -132,21 +108,19 @@ class Agent(SingleServerIRCBot):
         # https://github.com/jaraco/irc/blob/main/scripts/testbot.py
         nick = lower(self.connection.get_nickname())
         lowered = lower(line)
-        if self.thinking: print("~ Already thinking")
-        else:
-            self.thinkAbout(channel)
-            if (nick in lowered and random.random() <= 0.875):
-                self.generateReply(c, channel)
-            elif random.random() <= 0.125: self.generateReply(c, channel)
+        if nick in lowered or random.random() <= 0.125:
+            self.willReply[channel] = True
+        if not self.thinking: self.thinkAbout(c, channel)
 
     def thoughtPrompt(self):
         key = NO_THOUGHTS_EMBED if self.recent_thought is None else self.recent_thought[1]
-        # Fetch more thoughts than necessary, and always prefer shorter
+        # Fetch more thoughts than necessary, and then always prefer shorter
         # thoughts. This is an attempt to prevent exponential rumination.
-        new_thoughts = thought_index.search(key, 10)
-        # .search() returns most relevant thoughts first; reversing the list
-        # creates more focused chains of thought.
-        new_thoughts = nsmallest(5, new_thoughts.reverse(), key=len)
+        new_thoughts = thought_index.search(key, 20)
+        # XXX .search() returns most relevant thoughts first; reversing the list
+        # would create more focused chains of thought.
+        # Smaller thoughts are better.
+        new_thoughts = nsmallest(10, new_thoughts, key=len)
         if self.recent_thought is not None:
             new_thoughts.append(self.recent_thought[0])
         print("~ Thoughts:", *new_thoughts)
@@ -180,8 +154,10 @@ Users: {users}"""
         return "\n".join(lines)
 
     def generateReply(self, c, channel):
+        print("~ Will reply to channel:", channel)
+        self.willReply.pop(channel, None)
         log = self.logs[channel]
-        nick = self.connection.get_nickname()
+        nick = c.get_nickname()
         prefix = f"{datetime.now():%H:%M:%S} <{nick}>"
         examples = self.examplesFromOtherChannels(channel)
         # NB: "full" prompt needs log lines from current channel...
@@ -190,18 +166,20 @@ Users: {users}"""
         log.bumpCutoff(max_context_length, gen.countTokens, fullPrompt, prefix)
         # ...and current channel's log lines are added here.
         s = log.finishPrompt(fullPrompt, prefix)
-        print("~ log length:", len(log.l) - log.cutoff,
-              "prompt length (tokens):", gen.countTokens(s))
+        # print("~ log length:", len(log.l) - log.cutoff,
+        #       "prompt length (tokens):", gen.countTokens(s))
         # NB: At this point, execution is kicked out to a thread.
         def cb(completion):
             self.thinking = False
             reply = breakIRCLine(completion.result())
             log.irc(datetime.now(), nick, reply)
             c.privmsg(channel, reply)
+            if self.willReply:
+                self.generateReply(c, next(iter(self.willReply)))
         self.thinking = True
         executor.submit(lambda: gen.complete(s)).add_done_callback(cb)
 
-    def thinkAbout(self, channel):
+    def thinkAbout(self, c, channel):
         print("~ Will ponder channel:", channel)
         s = prompt + self.newThoughtPrompt(channel)
         def cb(completion):
@@ -211,6 +189,8 @@ Users: {users}"""
             embedding = llama_gen.embed(thought)
             self.recent_thought = thought, embedding
             thought_index.add(thought, embedding)
+            thought_index.prune()
+            if self.willReply[channel]: self.generateReply(c, channel)
         self.thinking = True
         executor.submit(lambda: gen.complete(s)).add_done_callback(cb)
 
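The control-flow change above amounts to a small deferred-reply queue in front of the single-worker executor. A simplified, self-contained sketch of the pattern (not the actual Agent class; IRC and generation are stubbed out):

    import random

    class ReplyQueue:
        def __init__(self):
            self.willReply = {}   # channel -> True, insertion-ordered
            self.thinking = False

        def on_message(self, channel, mentioned):
            # Mentions always queue a reply; other lines one time in eight.
            if mentioned or random.random() <= 0.125:
                self.willReply[channel] = True

        def on_done(self):
            # Mirrors the completion callbacks above: when a generation
            # finishes, drain one queued channel.
            self.thinking = False
            if self.willReply:
                channel = next(iter(self.willReply))
                self.willReply.pop(channel, None)
                print("~ Will reply to channel:", channel)

    q = ReplyQueue()
    q.on_message("#example", mentioned=True)
    q.on_done()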
@@ -1,21 +1,24 @@
 #!/usr/bin/env nix-shell
-#! nix-shell -i python3 -p python3
+#! nix-shell -i python3 -p python3Packages.faiss
 
-import json, sys
+import sys
 
+from common import SentenceIndex
 from gens.camelid import CamelidGen
 
 path = sys.argv[-1]
 gen = CamelidGen()
+index = SentenceIndex(path, gen.embedding_width)
 
-with open(path, "r") as handle: db = json.load(handle)
-print("Thought database:", len(db), "entries")
+index.load()
+print("Thought database:", index.size(), "entries")
 
 while True:
     try: thought = input("> ").strip()
     except EOFError: break
     if not thought: break
-    db[thought] = gen.embed(thought)
+    index.add(thought, gen.embed(thought))
+    index.prune()
 
-print("Saving thought database:", len(db), "entries")
-with open(path, "w") as handle: json.dump(db, handle)
+print("Saving thought database:", index.size(), "entries")
+index.save()
@@ -1,6 +1,11 @@
 from bisect import bisect
+import json
+import random
 from time import perf_counter
 
+from faiss import IndexFlatL2
+import numpy as np
+
 class Timer:
     "Basic context manager for timing an operation."
     def __init__(self, label): self.l = label
@@ -59,3 +64,48 @@ def parseLine(line, speakers):
     for edge in speakers:
         if not line.startswith(edge): line = breakAt(line, edge)
     return line.strip()
+
+def parsimony(s): return 1 - 2 ** -(len(s) * (1 / 50))
+
+class SentenceIndex:
+    def __init__(self, path, dimensions):
+        self.path = path
+        self.dimensions = dimensions
+        self.index = None
+
+    def size(self): return self.index.ntotal
+
+    def rebuild(self):
+        with Timer("rebuilding sentence index"):
+            self.index = IndexFlatL2(self.dimensions)
+            self.index.add(np.array([row[1] for row in self.db], dtype="float32"))
+
+    def load(self):
+        with open(self.path, "r") as handle:
+            data = json.load(handle)
+        self.db = list(data.items())
+        self.rebuild()
+
+    def save(self):
+        with open(self.path, "w") as f: json.dump(dict(self.db), f)
+
+    def search(self, embedding, k):
+        with Timer("%d nearest neighbors" % k):
+            D, I = self.index.search(np.array([embedding], dtype="float32"), k)
+            return [self.db[i][0] for i in I[0] if i >= 0]
+
+    def add(self, s, embedding):
+        self.index.add(np.array([embedding], dtype="float32"))
+        self.db.append((s, embedding))
+
+    def prune(self):
+        # NB: This is the same maths as an inigo. Older thoughts are less
+        # likely to be removed.
+        i = int(random.expovariate(1 / 5))
+        if not (0 < i < len(self.db)): return
+        thought = self.db[-i][0]
+        print("~ Is it short enough?", thought)
+        if random.random() <= parsimony(thought):
+            print("~ Would prune thought:", thought)
+            # self.db.pop(-i)
+            # self.rebuild()
@@ -57,6 +57,6 @@ class HFGen:
             do_sample=True,
             # Force responses.
             min_length=5,
-            # Slightly sharpen results.
-            temperature=0.875, repetition_penalty=1.0625,
+            # Sharpen results.
+            temperature=0.75, repetition_penalty=1.125,
         )[0]["generated_text"]
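
As a toy illustration of why lowering temperature sharpens results (this is standard softmax-with-temperature, not code from this repository):

    import numpy as np

    def softmax(logits, temperature):
        # Dividing logits by a smaller temperature exaggerates their
        # differences before normalization.
        z = np.asarray(logits) / temperature
        e = np.exp(z - z.max())
        return e / e.sum()

    logits = [2.0, 1.0, 0.5]
    print(softmax(logits, 0.875))  # old setting: flatter distribution
    print(softmax(logits, 0.75))   # new setting: more mass on the top token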