January 12th, 2023

Derived from: https://github.com/karpathy/nn-zero-to-hero/blob/master/lectures/makemore/makemore_part1_bigrams.ipynb

Makemore

Purpose: to make more of the examples you give it, e.g. names.

Training makemore on names will generate new, unique-sounding names.

This dataset will be used to train a character-level language model.

It models sequences of characters and is able to predict the next character in a sequence.

makemore implements a series of language-model neural nets:

Bigram, Bag of Words, MLP, RNN, GRU, Transformer (like GPT-2; NanoGPT yay!)

Describe the problem in terms of probabilities. For the name isabella, i is likely to be the first character of a name, s is likely to come after i, and so on, until a is likely to come after l and to be the last letter of the name.

Understand the dataset: load it into a list and check the minimum and maximum name lengths.

In [4]:
words = open('names.txt', 'r').read().splitlines()
In [5]:
words[:10]
Out[5]:
['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']
In [6]:
len(words)
Out[6]:
32033
In [7]:
min(len(w) for w in words)
Out[7]:
2
In [8]:
max(len(w) for w in words)
Out[8]:
15

A bigram model works with two characters at a time: given one character, it tries to predict the next character in the sequence.

zip(chs, chs[1:]) pairs each character with the character that follows it. For example, when the word is emma, chs = emma and chs[1:] = mma, so zip yields the tuples em, mm and ma. zip stops as soon as the shorter of its inputs (mma) is exhausted, so the final a gets no partner and only these 3 tuples are produced.
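A minimal sketch of that truncation behaviour, using the word emma from the text (no start/end tokens yet):

```python
w = "emma"
chs = list(w)

# zip pairs chs[i] with chs[i + 1] and stops when the shorter input (chs[1:]) runs out
pairs = list(zip(chs, chs[1:]))
print(pairs)                                # [('e', 'm'), ('m', 'm'), ('m', 'a')]
print(len(chs), len(chs[1:]), len(pairs))   # 4 3 3
```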

By adding special start (<S>) and end (<E>) tokens, we can also track which characters names start and end with.

The dictionary b stores the count of each bigram.

In [9]:
b = {}
for w in words:
  chs = ['<S>'] + list(w) + ['<E>']
  for ch1, ch2 in zip(chs, chs[1:]):
    bigram = (ch1, ch2)
    b[bigram] = b.get(bigram, 0) + 1

b.items() returns (bigram, count) tuples. sorted(..., key=lambda kv: -kv[1]) sorts them by count, and negating the count makes the order descending.
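The same counting and sorting can also be sketched with collections.Counter, whose most_common() returns (key, count) pairs in descending order of count. This is a swapped-in alternative to the plain dict used in the cell below, and the three-word list is a hypothetical stand-in for names.txt:

```python
from collections import Counter

# Hypothetical tiny word list standing in for names.txt
words = ["emma", "olivia", "ava"]

b = Counter()
for w in words:
    chs = ['<S>'] + list(w) + ['<E>']
    for ch1, ch2 in zip(chs, chs[1:]):
        b[(ch1, ch2)] += 1  # missing keys start at 0 in a Counter

# Same ordering as sorted(b.items(), key=lambda kv: -kv[1]); both sorts are stable
top = b.most_common()
print(top[:3])
```

All three names end in a, so ('a', '<E>') comes first with a count of 3.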

In [10]:
sorted(b.items(), key = lambda kv: -kv[1])
Out[10]:
[(('n', '<E>'), 6763),
 (('a', '<E>'), 6640),
 (('a', 'n'), 5438),
 (('<S>', 'a'), 4410),
 (('e', '<E>'), 3983),
 (('a', 'r'), 3264),
 (('e', 'l'), 3248),
 (('r', 'i'), 3033),
 (('n', 'a'), 2977),
 (('<S>', 'k'), 2963),
 (('l', 'e'), 2921),
 (('e', 'n'), 2675),
 (('l', 'a'), 2623),
 (('m', 'a'), 2590),
 (('<S>', 'm'), 2538),
 (('a', 'l'), 2528),
 (('i', '<E>'), 2489),
 (('l', 'i'), 2480),
 (('i', 'a'), 2445),
 (('<S>', 'j'), 2422),
 (('o', 'n'), 2411),
 (('h', '<E>'), 2409),
 (('r', 'a'), 2356),
 (('a', 'h'), 2332),
 (('h', 'a'), 2244),
 (('y', 'a'), 2143),
 (('i', 'n'), 2126),
 (('<S>', 's'), 2055),
 (('a', 'y'), 2050),
 (('y', '<E>'), 2007),
 (('e', 'r'), 1958),
 (('n', 'n'), 1906),
 (('y', 'n'), 1826),
 (('k', 'a'), 1731),
 (('n', 'i'), 1725),
 (('r', 'e'), 1697),
 (('<S>', 'd'), 1690),
 (('i', 'e'), 1653),
 (('a', 'i'), 1650),
 (('<S>', 'r'), 1639),
 (('a', 'm'), 1634),
 (('l', 'y'), 1588),
 (('<S>', 'l'), 1572),
 (('<S>', 'c'), 1542),
 (('<S>', 'e'), 1531),
 (('j', 'a'), 1473),
 (('r', '<E>'), 1377),
 (('n', 'e'), 1359),
 (('l', 'l'), 1345),
 (('i', 'l'), 1345),
 (('i', 's'), 1316),
 (('l', '<E>'), 1314),
 (('<S>', 't'), 1308),
 (('<S>', 'b'), 1306),
 (('d', 'a'), 1303),
 (('s', 'h'), 1285),
 (('d', 'e'), 1283),
 (('e', 'e'), 1271),
 (('m', 'i'), 1256),
 (('s', 'a'), 1201),
 (('s', '<E>'), 1169),
 (('<S>', 'n'), 1146),
 (('a', 's'), 1118),
 (('y', 'l'), 1104),
 (('e', 'y'), 1070),
 (('o', 'r'), 1059),
 (('a', 'd'), 1042),
 (('t', 'a'), 1027),
 (('<S>', 'z'), 929),
 (('v', 'i'), 911),
 (('k', 'e'), 895),
 (('s', 'e'), 884),
 (('<S>', 'h'), 874),
 (('r', 'o'), 869),
 (('e', 's'), 861),
 (('z', 'a'), 860),
 (('o', '<E>'), 855),
 (('i', 'r'), 849),
 (('b', 'r'), 842),
 (('a', 'v'), 834),
 (('m', 'e'), 818),
 (('e', 'i'), 818),
 (('c', 'a'), 815),
 (('i', 'y'), 779),
 (('r', 'y'), 773),
 (('e', 'm'), 769),
 (('s', 't'), 765),
 (('h', 'i'), 729),
 (('t', 'e'), 716),
 (('n', 'd'), 704),
 (('l', 'o'), 692),
 (('a', 'e'), 692),
 (('a', 't'), 687),
 (('s', 'i'), 684),
 (('e', 'a'), 679),
 (('d', 'i'), 674),
 (('h', 'e'), 674),
 (('<S>', 'g'), 669),
 (('t', 'o'), 667),
 (('c', 'h'), 664),
 (('b', 'e'), 655),
 (('t', 'h'), 647),
 (('v', 'a'), 642),
 (('o', 'l'), 619),
 (('<S>', 'i'), 591),
 (('i', 'o'), 588),
 (('e', 't'), 580),
 (('v', 'e'), 568),
 (('a', 'k'), 568),
 (('a', 'a'), 556),
 (('c', 'e'), 551),
 (('a', 'b'), 541),
 (('i', 't'), 541),
 (('<S>', 'y'), 535),
 (('t', 'i'), 532),
 (('s', 'o'), 531),
 (('m', '<E>'), 516),
 (('d', '<E>'), 516),
 (('<S>', 'p'), 515),
 (('i', 'c'), 509),
 (('k', 'i'), 509),
 (('o', 's'), 504),
 (('n', 'o'), 496),
 (('t', '<E>'), 483),
 (('j', 'o'), 479),
 (('u', 's'), 474),
 (('a', 'c'), 470),
 (('n', 'y'), 465),
 (('e', 'v'), 463),
 (('s', 's'), 461),
 (('m', 'o'), 452),
 (('i', 'k'), 445),
 (('n', 't'), 443),
 (('i', 'd'), 440),
 (('j', 'e'), 440),
 (('a', 'z'), 435),
 (('i', 'g'), 428),
 (('i', 'm'), 427),
 (('r', 'r'), 425),
 (('d', 'r'), 424),
 (('<S>', 'f'), 417),
 (('u', 'r'), 414),
 (('r', 'l'), 413),
 (('y', 's'), 401),
 (('<S>', 'o'), 394),
 (('e', 'd'), 384),
 (('a', 'u'), 381),
 (('c', 'o'), 380),
 (('k', 'y'), 379),
 (('d', 'o'), 378),
 (('<S>', 'v'), 376),
 (('t', 't'), 374),
 (('z', 'e'), 373),
 (('z', 'i'), 364),
 (('k', '<E>'), 363),
 (('g', 'h'), 360),
 (('t', 'r'), 352),
 (('k', 'o'), 344),
 (('t', 'y'), 341),
 (('g', 'e'), 334),
 (('g', 'a'), 330),
 (('l', 'u'), 324),
 (('b', 'a'), 321),
 (('d', 'y'), 317),
 (('c', 'k'), 316),
 (('<S>', 'w'), 307),
 (('k', 'h'), 307),
 (('u', 'l'), 301),
 (('y', 'e'), 301),
 (('y', 'r'), 291),
 (('m', 'y'), 287),
 (('h', 'o'), 287),
 (('w', 'a'), 280),
 (('s', 'l'), 279),
 (('n', 's'), 278),
 (('i', 'z'), 277),
 (('u', 'n'), 275),
 (('o', 'u'), 275),
 (('n', 'g'), 273),
 (('y', 'd'), 272),
 (('c', 'i'), 271),
 (('y', 'o'), 271),
 (('i', 'v'), 269),
 (('e', 'o'), 269),
 (('o', 'm'), 261),
 (('r', 'u'), 252),
 (('f', 'a'), 242),
 (('b', 'i'), 217),
 (('s', 'y'), 215),
 (('n', 'c'), 213),
 (('h', 'y'), 213),
 (('p', 'a'), 209),
 (('r', 't'), 208),
 (('q', 'u'), 206),
 (('p', 'h'), 204),
 (('h', 'r'), 204),
 (('j', 'u'), 202),
 (('g', 'r'), 201),
 (('p', 'e'), 197),
 (('n', 'l'), 195),
 (('y', 'i'), 192),
 (('g', 'i'), 190),
 (('o', 'd'), 190),
 (('r', 's'), 190),
 (('r', 'd'), 187),
 (('h', 'l'), 185),
 (('s', 'u'), 185),
 (('a', 'x'), 182),
 (('e', 'z'), 181),
 (('e', 'k'), 178),
 (('o', 'v'), 176),
 (('a', 'j'), 175),
 (('o', 'h'), 171),
 (('u', 'e'), 169),
 (('m', 'm'), 168),
 (('a', 'g'), 168),
 (('h', 'u'), 166),
 (('x', '<E>'), 164),
 (('u', 'a'), 163),
 (('r', 'm'), 162),
 (('a', 'w'), 161),
 (('f', 'i'), 160),
 (('z', '<E>'), 160),
 (('u', '<E>'), 155),
 (('u', 'm'), 154),
 (('e', 'c'), 153),
 (('v', 'o'), 153),
 (('e', 'h'), 152),
 (('p', 'r'), 151),
 (('d', 'd'), 149),
 (('o', 'a'), 149),
 (('w', 'e'), 149),
 (('w', 'i'), 148),
 (('y', 'm'), 148),
 (('z', 'y'), 147),
 (('n', 'z'), 145),
 (('y', 'u'), 141),
 (('r', 'n'), 140),
 (('o', 'b'), 140),
 (('k', 'l'), 139),
 (('m', 'u'), 139),
 (('l', 'd'), 138),
 (('h', 'n'), 138),
 (('u', 'd'), 136),
 (('<S>', 'x'), 134),
 (('t', 'l'), 134),
 (('a', 'f'), 134),
 (('o', 'e'), 132),
 (('e', 'x'), 132),
 (('e', 'g'), 125),
 (('f', 'e'), 123),
 (('z', 'l'), 123),
 (('u', 'i'), 121),
 (('v', 'y'), 121),
 (('e', 'b'), 121),
 (('r', 'h'), 121),
 (('j', 'i'), 119),
 (('o', 't'), 118),
 (('d', 'h'), 118),
 (('h', 'm'), 117),
 (('c', 'l'), 116),
 (('o', 'o'), 115),
 (('y', 'c'), 115),
 (('o', 'w'), 114),
 (('o', 'c'), 114),
 (('f', 'r'), 114),
 (('b', '<E>'), 114),
 (('m', 'b'), 112),
 (('z', 'o'), 110),
 (('i', 'b'), 110),
 (('i', 'u'), 109),
 (('k', 'r'), 109),
 (('g', '<E>'), 108),
 (('y', 'v'), 106),
 (('t', 'z'), 105),
 (('b', 'o'), 105),
 (('c', 'y'), 104),
 (('y', 't'), 104),
 (('u', 'b'), 103),
 (('u', 'c'), 103),
 (('x', 'a'), 103),
 (('b', 'l'), 103),
 (('o', 'y'), 103),
 (('x', 'i'), 102),
 (('i', 'f'), 101),
 (('r', 'c'), 99),
 (('c', '<E>'), 97),
 (('m', 'r'), 97),
 (('n', 'u'), 96),
 (('o', 'p'), 95),
 (('i', 'h'), 95),
 (('k', 's'), 95),
 (('l', 's'), 94),
 (('u', 'k'), 93),
 (('<S>', 'q'), 92),
 (('d', 'u'), 92),
 (('s', 'm'), 90),
 (('r', 'k'), 90),
 (('i', 'x'), 89),
 (('v', '<E>'), 88),
 (('y', 'k'), 86),
 (('u', 'w'), 86),
 (('g', 'u'), 85),
 (('b', 'y'), 83),
 (('e', 'p'), 83),
 (('g', 'o'), 83),
 (('s', 'k'), 82),
 (('u', 't'), 82),
 (('a', 'p'), 82),
 (('e', 'f'), 82),
 (('i', 'i'), 82),
 (('r', 'v'), 80),
 (('f', '<E>'), 80),
 (('t', 'u'), 78),
 (('y', 'z'), 78),
 (('<S>', 'u'), 78),
 (('l', 't'), 77),
 (('r', 'g'), 76),
 (('c', 'r'), 76),
 (('i', 'j'), 76),
 (('w', 'y'), 73),
 (('z', 'u'), 73),
 (('l', 'v'), 72),
 (('h', 't'), 71),
 (('j', '<E>'), 71),
 (('x', 't'), 70),
 (('o', 'i'), 69),
 (('e', 'u'), 69),
 (('o', 'k'), 68),
 (('b', 'd'), 65),
 (('a', 'o'), 63),
 (('p', 'i'), 61),
 (('s', 'c'), 60),
 (('d', 'l'), 60),
 (('l', 'm'), 60),
 (('a', 'q'), 60),
 (('f', 'o'), 60),
 (('p', 'o'), 59),
 (('n', 'k'), 58),
 (('w', 'n'), 58),
 (('u', 'h'), 58),
 (('e', 'j'), 55),
 (('n', 'v'), 55),
 (('s', 'r'), 55),
 (('o', 'z'), 54),
 (('i', 'p'), 53),
 (('l', 'b'), 52),
 (('i', 'q'), 52),
 (('w', '<E>'), 51),
 (('m', 'c'), 51),
 (('s', 'p'), 51),
 (('e', 'w'), 50),
 (('k', 'u'), 50),
 (('v', 'r'), 48),
 (('u', 'g'), 47),
 (('o', 'x'), 45),
 (('u', 'z'), 45),
 (('z', 'z'), 45),
 (('j', 'h'), 45),
 (('b', 'u'), 45),
 (('o', 'g'), 44),
 (('n', 'r'), 44),
 (('f', 'f'), 44),
 (('n', 'j'), 44),
 (('z', 'h'), 43),
 (('c', 'c'), 42),
 (('r', 'b'), 41),
 (('x', 'o'), 41),
 (('b', 'h'), 41),
 (('p', 'p'), 39),
 (('x', 'l'), 39),
 (('h', 'v'), 39),
 (('b', 'b'), 38),
 (('m', 'p'), 38),
 (('x', 'x'), 38),
 (('u', 'v'), 37),
 (('x', 'e'), 36),
 (('w', 'o'), 36),
 (('c', 't'), 35),
 (('z', 'm'), 35),
 (('t', 's'), 35),
 (('m', 's'), 35),
 (('c', 'u'), 35),
 (('o', 'f'), 34),
 (('u', 'x'), 34),
 (('k', 'w'), 34),
 (('p', '<E>'), 33),
 (('g', 'l'), 32),
 (('z', 'r'), 32),
 (('d', 'n'), 31),
 (('g', 't'), 31),
 (('g', 'y'), 31),
 (('h', 's'), 31),
 (('x', 's'), 31),
 (('g', 's'), 30),
 (('x', 'y'), 30),
 (('y', 'g'), 30),
 (('d', 'm'), 30),
 (('d', 's'), 29),
 (('h', 'k'), 29),
 (('y', 'x'), 28),
 (('q', '<E>'), 28),
 (('g', 'n'), 27),
 (('y', 'b'), 27),
 (('g', 'w'), 26),
 (('n', 'h'), 26),
 (('k', 'n'), 26),
 (('g', 'g'), 25),
 (('d', 'g'), 25),
 (('l', 'c'), 25),
 (('r', 'j'), 25),
 (('w', 'u'), 25),
 (('l', 'k'), 24),
 (('m', 'd'), 24),
 (('s', 'w'), 24),
 (('s', 'n'), 24),
 (('h', 'd'), 24),
 (('w', 'h'), 23),
 (('y', 'j'), 23),
 (('y', 'y'), 23),
 (('r', 'z'), 23),
 (('d', 'w'), 23),
 (('w', 'r'), 22),
 (('t', 'n'), 22),
 (('l', 'f'), 22),
 (('y', 'h'), 22),
 (('r', 'w'), 21),
 (('s', 'b'), 21),
 (('m', 'n'), 20),
 (('f', 'l'), 20),
 (('w', 's'), 20),
 (('k', 'k'), 20),
 (('h', 'z'), 20),
 (('g', 'd'), 19),
 (('l', 'h'), 19),
 (('n', 'm'), 19),
 (('x', 'z'), 19),
 (('u', 'f'), 19),
 (('f', 't'), 18),
 (('l', 'r'), 18),
 (('p', 't'), 17),
 (('t', 'c'), 17),
 (('k', 't'), 17),
 (('d', 'v'), 17),
 (('u', 'p'), 16),
 (('p', 'l'), 16),
 (('l', 'w'), 16),
 (('p', 's'), 16),
 (('o', 'j'), 16),
 (('r', 'q'), 16),
 (('y', 'p'), 15),
 (('l', 'p'), 15),
 (('t', 'v'), 15),
 (('r', 'p'), 14),
 (('l', 'n'), 14),
 (('e', 'q'), 14),
 (('f', 'y'), 14),
 (('s', 'v'), 14),
 (('u', 'j'), 14),
 (('v', 'l'), 14),
 (('q', 'a'), 13),
 (('u', 'y'), 13),
 (('q', 'i'), 13),
 (('w', 'l'), 13),
 (('p', 'y'), 12),
 (('y', 'f'), 12),
 (('c', 'q'), 11),
 (('j', 'r'), 11),
 (('n', 'w'), 11),
 (('n', 'f'), 11),
 (('t', 'w'), 11),
 (('m', 'z'), 11),
 (('u', 'o'), 10),
 (('f', 'u'), 10),
 (('l', 'z'), 10),
 (('h', 'w'), 10),
 (('u', 'q'), 10),
 (('j', 'y'), 10),
 (('s', 'z'), 10),
 (('s', 'd'), 9),
 (('j', 'l'), 9),
 (('d', 'j'), 9),
 (('k', 'm'), 9),
 (('r', 'f'), 9),
 (('h', 'j'), 9),
 (('v', 'n'), 8),
 (('n', 'b'), 8),
 (('i', 'w'), 8),
 (('h', 'b'), 8),
 (('b', 's'), 8),
 (('w', 't'), 8),
 (('w', 'd'), 8),
 (('v', 'v'), 7),
 (('v', 'u'), 7),
 (('j', 's'), 7),
 (('m', 'j'), 7),
 (('f', 's'), 6),
 (('l', 'g'), 6),
 (('l', 'j'), 6),
 (('j', 'w'), 6),
 (('n', 'x'), 6),
 (('y', 'q'), 6),
 (('w', 'k'), 6),
 (('g', 'm'), 6),
 (('x', 'u'), 5),
 (('m', 'h'), 5),
 (('m', 'l'), 5),
 (('j', 'm'), 5),
 (('c', 's'), 5),
 (('j', 'v'), 5),
 (('n', 'p'), 5),
 (('d', 'f'), 5),
 (('x', 'd'), 5),
 (('z', 'b'), 4),
 (('f', 'n'), 4),
 (('x', 'c'), 4),
 (('m', 't'), 4),
 (('t', 'm'), 4),
 (('z', 'n'), 4),
 (('z', 't'), 4),
 (('p', 'u'), 4),
 (('c', 'z'), 4),
 (('b', 'n'), 4),
 (('z', 's'), 4),
 (('f', 'w'), 4),
 (('d', 't'), 4),
 (('j', 'd'), 4),
 (('j', 'c'), 4),
 (('y', 'w'), 4),
 (('v', 'k'), 3),
 (('x', 'w'), 3),
 (('t', 'j'), 3),
 (('c', 'j'), 3),
 (('q', 'w'), 3),
 (('g', 'b'), 3),
 (('o', 'q'), 3),
 (('r', 'x'), 3),
 (('d', 'c'), 3),
 (('g', 'j'), 3),
 (('x', 'f'), 3),
 (('z', 'w'), 3),
 (('d', 'k'), 3),
 (('u', 'u'), 3),
 (('m', 'v'), 3),
 (('c', 'x'), 3),
 (('l', 'q'), 3),
 (('p', 'b'), 2),
 (('t', 'g'), 2),
 (('q', 's'), 2),
 (('t', 'x'), 2),
 (('f', 'k'), 2),
 (('b', 't'), 2),
 (('j', 'n'), 2),
 (('k', 'c'), 2),
 (('z', 'k'), 2),
 (('s', 'j'), 2),
 (('s', 'f'), 2),
 (('z', 'j'), 2),
 (('n', 'q'), 2),
 (('f', 'z'), 2),
 (('h', 'g'), 2),
 (('w', 'w'), 2),
 (('k', 'j'), 2),
 (('j', 'k'), 2),
 (('w', 'm'), 2),
 (('z', 'c'), 2),
 (('z', 'v'), 2),
 (('w', 'f'), 2),
 (('q', 'm'), 2),
 (('k', 'z'), 2),
 (('j', 'j'), 2),
 (('z', 'p'), 2),
 (('j', 't'), 2),
 (('k', 'b'), 2),
 (('m', 'w'), 2),
 (('h', 'f'), 2),
 (('c', 'g'), 2),
 (('t', 'f'), 2),
 (('h', 'c'), 2),
 (('q', 'o'), 2),
 (('k', 'd'), 2),
 (('k', 'v'), 2),
 (('s', 'g'), 2),
 (('z', 'd'), 2),
 (('q', 'r'), 1),
 (('d', 'z'), 1),
 (('p', 'j'), 1),
 (('q', 'l'), 1),
 (('p', 'f'), 1),
 (('q', 'e'), 1),
 (('b', 'c'), 1),
 (('c', 'd'), 1),
 (('m', 'f'), 1),
 (('p', 'n'), 1),
 (('w', 'b'), 1),
 (('p', 'c'), 1),
 (('h', 'p'), 1),
 (('f', 'h'), 1),
 (('b', 'j'), 1),
 (('f', 'g'), 1),
 (('z', 'g'), 1),
 (('c', 'p'), 1),
 (('p', 'k'), 1),
 (('p', 'm'), 1),
 (('x', 'n'), 1),
 (('s', 'q'), 1),
 (('k', 'f'), 1),
 (('m', 'k'), 1),
 (('x', 'h'), 1),
 (('g', 'f'), 1),
 (('v', 'b'), 1),
 (('j', 'p'), 1),
 (('g', 'z'), 1),
 (('v', 'd'), 1),
 (('d', 'b'), 1),
 (('v', 'h'), 1),
 (('h', 'h'), 1),
 (('g', 'v'), 1),
 (('d', 'q'), 1),
 (('x', 'b'), 1),
 (('w', 'z'), 1),
 (('h', 'q'), 1),
 (('j', 'b'), 1),
 (('x', 'm'), 1),
 (('w', 'g'), 1),
 (('t', 'b'), 1),
 (('z', 'x'), 1)]

What we see above is that n is the most common character at the end of a name, followed by a; n also often comes after a, and so on.

It is significantly more convenient to store this in a 2D array: the rows are the first character of the bigram and the columns are the second character. Each entry tells us how often the second character follows the first character in the dataset.

You can index into a tensor with comma-separated indices, e.g. a[1, 3] selects the same element as a[1][3].
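A small sketch of the two indexing forms on a made-up 3 x 4 tensor:

```python
import torch

a = torch.arange(12).reshape(3, 4)  # rows: [0..3], [4..7], [8..11]

# a[1, 3] is row 1, column 3 in a single indexing step;
# a[1][3] first takes row 1, then element 3 of that row
print(a[1, 3], a[1][3])  # tensor(7) tensor(7)
```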

In [11]:
import torch

27 X 27 because we use the 26 characters + 1 start/end token

In [12]:
N = torch.zeros((27, 27), dtype=torch.int32)

chars is the sorted list of unique characters in the dataset, stoi maps each character to an integer index (starting at 1, with index 0 reserved for '.'), and itos is the inverse lookup of stoi.

In [14]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
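A round-trip sketch of stoi and itos. It assumes the dataset contains exactly the 26 lowercase letters, so string.ascii_lowercase stands in for the characters collected from names.txt:

```python
import string

# Assumption: names.txt uses exactly the letters a-z
chars = sorted(set(string.ascii_lowercase))
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s, i in stoi.items()}

encoded = [stoi[c] for c in ".emma."]
print(encoded)   # [0, 5, 13, 13, 1, 0]
decoded = ''.join(itos[i] for i in encoded)
print(decoded)   # .emma.
```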

As we iterate over each bigram tuple, we increment the corresponding counter in the tensor.

Instead of two special tokens we only need one. With separate <S> and <E> tokens, the <E> row would be all zeros (no character ever follows the end token) and the <S> column would be all zeros (no character ever precedes the start token).

So we replace both with a single . token, which marks both the start and the end of a name and avoids storing a row and a column that are always empty.
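The wasted row and column can be checked directly. This sketch builds a hypothetical 28 x 28 count matrix with separate <S> and <E> tokens from a made-up three-word list:

```python
import torch

words = ["emma", "olivia", "ava"]  # hypothetical sample instead of names.txt
tokens = ['<S>'] + [chr(c) for c in range(ord('a'), ord('z') + 1)] + ['<E>']
idx = {t: i for i, t in enumerate(tokens)}

N28 = torch.zeros((28, 28), dtype=torch.int32)
for w in words:
    chs = ['<S>'] + list(w) + ['<E>']
    for ch1, ch2 in zip(chs, chs[1:]):
        N28[idx[ch1], idx[ch2]] += 1

# Nothing ever follows <E>, so its row is all zeros...
print(N28[idx['<E>']].sum().item())     # 0
# ...and nothing ever precedes <S>, so its column is all zeros
print(N28[:, idx['<S>']].sum().item())  # 0
```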

In [15]:
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    N[ix1, ix2] += 1
    

We use matplotlib to visualize the count matrix N as a figure.

We write N[i, j].item() because indexing into a tensor returns another (0-dimensional) tensor; .item() returns the plain Python int instead.
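A tiny sketch of the difference, using a freshly created count matrix where the value 5 is made up for illustration:

```python
import torch

N = torch.zeros((27, 27), dtype=torch.int32)
N[1, 2] = 5

v = N[1, 2]
print(type(v), v.dim())        # a torch.Tensor with 0 dimensions
print(type(v.item()), v.item())  # <class 'int'> 5
```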

In [16]:
import numpy
In [17]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.figure(figsize=(16,16))
plt.imshow(N, cmap='Blues')
for i in range(27):
    for j in range(27):
        chstr = itos[i] + itos[j]
        plt.text(j, i, chstr, ha="center", va="bottom", color='gray')
        plt.text(j, i, N[i, j].item(), ha="center", va="top", color='gray')
plt.axis('off');

this array has everything we need to sample from the bigram model

The first row, N[0], tells us how often each character starts a word (how often it follows the . token).

In [18]:
N[0]
Out[18]:
tensor([   0, 4410, 1306, 1542, 1690, 1531,  417,  669,  874,  591, 2422, 2963,
        1572, 2538, 1146,  394,  515,   92, 1639, 2055, 1308,   78,  376,  307,
         134,  535,  929], dtype=torch.int32)

We convert the counts to floats because we are going to normalize them into a probability distribution.

In [19]:
p = N[0].float()
p = p / p.sum()
p
Out[19]:
tensor([0.0000, 0.1377, 0.0408, 0.0481, 0.0528, 0.0478, 0.0130, 0.0209, 0.0273,
        0.0184, 0.0756, 0.0925, 0.0491, 0.0792, 0.0358, 0.0123, 0.0161, 0.0029,
        0.0512, 0.0642, 0.0408, 0.0024, 0.0117, 0.0096, 0.0042, 0.0167, 0.0290])

torch.multinomial samples from the multinomial distribution: you give it probabilities, and it gives you back integer indices sampled according to that probability distribution.

Using a specific seed makes the sampling deterministic, so the results are reproducible and line up with the examples in the lecture.
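A quick check of that determinism, assuming a made-up 3-way distribution p: two generators seeded with the same value produce identical samples.

```python
import torch

p = torch.tensor([0.6, 0.3, 0.1])  # made-up probabilities

g1 = torch.Generator().manual_seed(2147483647)
g2 = torch.Generator().manual_seed(2147483647)
s1 = torch.multinomial(p, num_samples=10, replacement=True, generator=g1)
s2 = torch.multinomial(p, num_samples=10, replacement=True, generator=g2)
print(torch.equal(s1, s2))  # True: same seed, same samples
```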

So we create a seeded generator and use it to draw 1 sample according to the probability distribution p. This returns a tensor holding a single sampled integer; we use .item() to pull out the int and then itos to convert it to a character.

In [20]:
g = torch.Generator().manual_seed(2147483647)
ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
itos[ix]
Out[20]:
'm'
In [24]:
g = torch.Generator().manual_seed(2147483647)
p = torch.rand(3, generator=g)
p = p / p.sum()
p
Out[24]:
tensor([0.6064, 0.3033, 0.0903])
In [25]:
torch.multinomial(p, num_samples=100, replacement=True, generator=g)
Out[25]:
tensor([1, 1, 2, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 2, 0, 0,
        1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1,
        0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 1, 0,
        0, 1, 1, 1])
In [26]:
p.shape
Out[26]:
torch.Size([3])
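To see that torch.multinomial really samples according to p, a sketch can draw many samples and compare the observed frequencies (counted with torch.bincount) against the 3-element p printed above. The sample size of 100,000 is an arbitrary choice:

```python
import torch

g = torch.Generator().manual_seed(2147483647)
p = torch.tensor([0.6064, 0.3033, 0.0903])  # the normalized p from the cell above

samples = torch.multinomial(p, num_samples=100_000, replacement=True, generator=g)
freq = torch.bincount(samples, minlength=3).float() / samples.numel()
print(freq)  # close to p when the number of samples is large
```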

We want to divide each row by its respective sum.

If keepdim is True, the output tensor has the same number of dimensions as the input, except that the reduced dimension(s) dim have size 1. If keepdim is False, that dimension is squeezed out.

Broadcasting semantics in torch: two tensors are broadcastable if
-each tensor has at least one dimension
-when iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes are either equal, one of them is 1, or one of them does not exist

Here the shapes are (27, 27) and (27, 1).

Broadcasting takes the dimension of size 1 and stretches it out, conceptually copying the column 27 times so that both operands behave as (27, 27).
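A sketch of these rules on a placeholder all-ones matrix standing in for the counts. torch.broadcast_shapes reports the resulting shape, and expand builds the "stretched" view explicitly:

```python
import torch

P = torch.ones((27, 27))                 # placeholder for the count matrix
row_sums = P.sum(1, keepdim=True)        # shape (27, 1)
print(P.shape, row_sums.shape)

# Trailing dims: 27 vs 1 -> broadcast to 27; leading dims: 27 vs 27 -> equal
print(torch.broadcast_shapes(P.shape, row_sums.shape))  # torch.Size([27, 27])

# Broadcasting behaves as if the (27, 1) column were repeated 27 times along dim 1
expanded = row_sums.expand(27, 27)       # a view; no data is actually copied
print(torch.equal(P / row_sums, P / expanded))          # True
```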

In [27]:
P = (N+1).float()
P /= P.sum(1, keepdims=True)

If we did not pass keepdims=True, then instead of a (27, 1) tensor the reduced dimension would be squeezed out, leaving a tensor of shape (27,). In the /= operation shapes are aligned right to left, so the missing dimension is added back on the left, giving (1, 27). Consequently we would divide each column by a row sum instead of each row by its own sum: a silent bug that gives us garbage. The check below tests whether a row is normalized properly; it should sum to 1.
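The bug is easy to reproduce on a made-up 2 x 2 count matrix whose rows have different sums:

```python
import torch

counts = torch.tensor([[1., 3.],
                       [2., 6.]])  # row sums: 4 and 8

good = counts / counts.sum(1, keepdim=True)  # (2, 2) / (2, 1): each row divided by its own sum
bad = counts / counts.sum(1)                 # (2, 2) / (2,) -> (1, 2): column j divided by row j's sum

print(counts.sum(1, keepdim=True).shape, counts.sum(1).shape)  # torch.Size([2, 1]) torch.Size([2])
print(good.sum(1))  # tensor([1., 1.]): rows are proper distributions
print(bad.sum(1))   # rows no longer sum to 1
```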

P /= P.sum(...) is an in-place operation, whereas P = P / P.sum(...) allocates a new tensor, which is wasteful.
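One way to see the difference is to compare data_ptr(), the address of a tensor's storage, before and after each form. A minimal sketch on placeholder all-ones matrices:

```python
import torch

P = torch.ones((27, 27))
ptr_before = P.data_ptr()
P /= P.sum(1, keepdim=True)         # in place: writes into the existing storage
print(P.data_ptr() == ptr_before)   # True

Q = torch.ones((27, 27))
ptr_q = Q.data_ptr()
Q = Q / Q.sum(1, keepdim=True)      # allocates a new tensor, then rebinds the name Q
print(Q.data_ptr() == ptr_q)        # False
```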

In [28]:
P[0].sum()
Out[28]:
tensor(1.)

The full broadcasting rules are described at https://pytorch.org/docs/stable/notes/broadcasting.html

We begin at index 0, the . token. In a loop, we grab the row of P for the character we are currently on; that row is already a probability distribution that sums to 1. Using the seeded generator, we draw a single sample from it, which gives us the index of the next character (and of the row to use next). We convert that index to a character with itos and append it to out. If we sample '.' (ix == 0), the name is finished and we break.

In [30]:
g = torch.Generator().manual_seed(2147483647)

for i in range(5):
  
  out = []
  ix = 0
  while True:
    p = P[ix]
    ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[ix])
    if ix == 0:
      break
  print(''.join(out))
mor.
axx.
minaymoryles.
kondlaisah.
anchshizarie.
In [329]:
# GOAL: maximize likelihood of the data w.r.t. model parameters (statistical modeling)
# equivalent to maximizing the log likelihood (because log is monotonic)
# equivalent to minimizing the negative log likelihood
# equivalent to minimizing the average negative log likelihood

# log(a*b*c) = log(a) + log(b) + log(c)

The likelihood is the product of the probabilities the model assigns to every bigram. Multiplying many numbers between 0 and 1 produces a vanishingly small number, so instead we use the log likelihood. log(p) is close to 0 when p is close to 1 and goes to -inf as p approaches 0, and, as noted in the comments above, log(a*b*c) = log(a) + log(b) + log(c), so the log likelihood is a sum of negative numbers. For a loss function we want lower to be better, so we take the negative log likelihood (nll), which flips the sign. Finally, using a counter n to track the number of bigrams, we normalize the nll by dividing by n to get the average negative log likelihood.
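Written out as formulas, with p_i (notation introduced here for illustration) denoting the probability P[ix1, ix2] that the model assigns to the i-th of the n bigrams, the chain from the comments above is:

```latex
\log \prod_{i=1}^{n} p_i = \sum_{i=1}^{n} \log p_i
\qquad
\text{nll} = -\sum_{i=1}^{n} \log p_i
\qquad
\text{loss} = \frac{\text{nll}}{n} = -\frac{1}{n} \sum_{i=1}^{n} \log p_i
```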

For the name andrejq, the bigram jq never occurs in the dataset, so its probability is 0 and the loss is infinite. To fix this we change P = N.float() to P = (N+1).float(), which gives every bigram a small but non-zero probability and avoids the infinity. Adding 1 to every count like this is known as add-one (Laplace) smoothing.
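A sketch of the effect on a hypothetical 2 x 2 count matrix in which bigram (0, 0) was never observed:

```python
import torch

counts = torch.tensor([[0., 4.],
                       [3., 1.]])  # hypothetical counts; bigram (0, 0) never seen

P_raw = counts / counts.sum(1, keepdim=True)
print(torch.log(P_raw[0, 0]))     # tensor(-inf): infinite loss for the unseen bigram

P_smooth = (counts + 1) / (counts + 1).sum(1, keepdim=True)
print(torch.log(P_smooth[0, 0]))  # finite: log(1/6)
```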

In [32]:
log_likelihood = 0.0
n = 0

for w in words:
#for w in ["andrejq"]:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    prob = P[ix1, ix2]
    logprob = torch.log(prob)
    log_likelihood += logprob
    n += 1
    #print(f'{ch1}{ch2}: {prob:.4f} {logprob:.4f}')

print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')
print(f'{nll/n}')
log_likelihood=tensor(-559951.5625)
nll=tensor(559951.5625)
2.4543561935424805

torch.tensor infers the dtype from its input, while torch.Tensor always returns a torch.FloatTensor. Stick with torch.tensor and set dtype explicitly when needed. In the example below we want integers, and torch.tensor infers int64 from a list of Python ints, so it is the easiest way to get them.
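A quick comparison of the two constructors, assuming the default float dtype has not been changed from float32:

```python
import torch

a = torch.tensor([0, 5, 13])                     # dtype inferred from Python ints
b = torch.Tensor([0, 5, 13])                     # always the default float dtype
c = torch.tensor([0, 5, 13], dtype=torch.int32)  # dtype set explicitly
print(a.dtype, b.dtype, c.dtype)                 # torch.int64 torch.float32 torch.int32
```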

In [35]:
# create the training set of bigrams (x,y)
xs, ys = [], []

for w in words[:1]:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    print(ch1, ch2)
    xs.append(ix1)
    ys.append(ix2)
    
xs = torch.tensor(xs)
ys = torch.tensor(ys)
. e
e m
m m
m a
a .
In [36]:
xs
Out[36]:
tensor([ 0,  5, 13, 13,  1])
In [37]:
ys
Out[37]:
tensor([ 5, 13, 13,  1,  0])

As we iterate through xs, we want to tune the neural net to assign high probabilities to the corresponding ys. For example, with .emma. we want the net to give e a high probability of starting a name, m a high probability of following e, m a high probability of following m, and so on.

One-hot encoding: we take an integer like 13 and make a vector of all zeros except at index 13, where it becomes a 1. This encoding is much better suited as input to a neural net than the raw integer 13. You also supply num_classes, the size of the set you want to encode (26 letters + 1 for .).

F.one_hot does not take a dtype argument and returns int64, so to get float32 instead we append .float() to the one_hot call.
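A short sketch confirming the dtypes, and that argmax over each one-hot row recovers the original indices:

```python
import torch
import torch.nn.functional as F

xs = torch.tensor([0, 5, 13, 13, 1])
enc = F.one_hot(xs, num_classes=27)
print(enc.dtype)                # torch.int64
xenc = enc.float()
print(xenc.dtype, xenc.shape)   # torch.float32 torch.Size([5, 27])
print(xenc.argmax(1))           # the position of each 1 is the original index in xs
```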

In [38]:
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=27).float()
xenc
Out[38]:
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.]])
In [39]:
xenc.shape
Out[39]:
torch.Size([5, 27])

What we see is .emma represented as 5 one-hot rows, one per input character, where the single 1 sits in the column given by stoi:
. => 0 => 1 in column 0
e => 5 => 1 in column 5
m => 13 => 1 in column 13
m => 13 => 1 in column 13
a => 1 => 1 in column 1

In [40]:
plt.imshow(xenc)
In [41]:
xenc.dtype
Out[41]:
torch.float32

torch.randn fills a tensor with random numbers drawn from a standard normal distribution (mean 0, standard deviation 1): most numbers will be around 0, and almost all of them fall between -3 and +3.
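A sketch checking those properties empirically; the sample size of 100,000 is arbitrary:

```python
import torch

g = torch.Generator().manual_seed(2147483647)
z = torch.randn(100_000, generator=g)

print(z.mean().item(), z.std().item())      # close to 0 and 1
print((z.abs() < 3).float().mean().item())  # about 0.997 of the values lie within ±3
```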

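Consider first a single neuron, i.e. W of size (27, 1). xenc is (5, 27), and @ is PyTorch's matrix-multiplication operator (each output entry is the dot product of a row of xenc with a column of W). Multiplying (5, 27) by (27, 1) gives a result of size (5, 1).

This leaves us with 5 activations of this neuron on these 5 inputs, evaluated in parallel. In the cell below we use 27 neurons instead, W of size (27, 27), so the result is (5, 27).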
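A shape-only sketch of both cases with random weights:

```python
import torch
import torch.nn.functional as F

xs = torch.tensor([0, 5, 13, 13, 1])
xenc = F.one_hot(xs, num_classes=27).float()  # (5, 27)

W1 = torch.randn((27, 1))   # a single neuron with 27 input weights
print((xenc @ W1).shape)    # torch.Size([5, 1]): one activation per input

W = torch.randn((27, 27))   # 27 neurons
print((xenc @ W).shape)     # torch.Size([5, 27])
```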

In [42]:
W = torch.randn((27, 27))
xenc @ W
Out[42]:
tensor([[ 4.1890e-01,  1.6821e+00, -3.8086e-01,  1.3436e-01, -5.2135e-01,
          1.5735e+00,  3.2776e-01,  1.9606e+00, -4.7497e-01,  5.5738e-01,
          9.1139e-02,  1.3858e+00, -9.8750e-02,  1.1440e-01, -2.2790e-01,
         -1.2869e+00,  5.7435e-04, -8.3552e-01, -1.3282e+00,  1.9281e-01,
          1.2909e+00,  1.5329e+00, -6.5998e-01, -1.5194e+00,  7.5818e-01,
          1.9944e-01,  2.2774e-01],
        [ 5.7299e-02,  1.0030e+00, -1.2168e-01,  1.6702e+00,  1.7145e-01,
         -6.2085e-01, -6.8763e-02,  7.7134e-02,  7.6783e-01,  1.4299e+00,
          1.3888e+00, -1.7189e-01,  1.9498e+00,  1.9764e+00,  1.6752e+00,
         -2.3225e-01,  1.5830e+00,  8.0632e-01, -2.8191e+00, -3.8667e-01,
          2.9416e-01, -1.1161e+00, -8.9780e-02, -7.1312e-01, -2.8513e-01,
          1.6922e+00, -2.1616e-03],
        [-2.7306e+00, -3.0366e-01, -5.3556e-01,  8.2630e-01, -3.8329e-01,
          3.1390e-01,  7.0886e-01, -5.6694e-01,  2.3444e-01,  1.2335e+00,
         -1.0487e+00,  2.2723e+00,  2.5501e-01, -1.1848e+00, -7.4882e-01,
         -1.5746e+00, -1.9665e-01, -3.8698e-01,  1.8836e+00,  1.0316e+00,
          3.2719e-01,  7.5979e-01,  1.3072e+00, -8.1478e-01,  2.5182e+00,
         -6.0154e-01,  8.3198e-02],
        [-2.7306e+00, -3.0366e-01, -5.3556e-01,  8.2630e-01, -3.8329e-01,
          3.1390e-01,  7.0886e-01, -5.6694e-01,  2.3444e-01,  1.2335e+00,
         -1.0487e+00,  2.2723e+00,  2.5501e-01, -1.1848e+00, -7.4882e-01,
         -1.5746e+00, -1.9665e-01, -3.8698e-01,  1.8836e+00,  1.0316e+00,
          3.2719e-01,  7.5979e-01,  1.3072e+00, -8.1478e-01,  2.5182e+00,
         -6.0154e-01,  8.3198e-02],
        [-1.0133e+00,  6.3266e-01, -5.0586e-01, -6.9402e-02,  1.3129e-01,
         -4.0221e-01, -1.9779e-01,  3.2179e-02, -9.3477e-01,  1.0209e+00,
         -1.2180e+00,  2.4028e+00, -1.2234e+00,  7.6524e-01,  1.2512e+00,
         -1.8544e+00, -1.2918e+00, -2.9848e-01, -8.2843e-01, -1.7549e+00,
          7.5878e-01, -1.6300e+00,  7.2076e-02, -4.5116e-01,  1.3405e+00,
          1.3706e+00,  4.9192e-01]])
In [44]:
(xenc @ W).shape
Out[44]:
torch.Size([5, 27])

The firing rate of neuron 13 (the 14th neuron, counting from 0) looking at input 3 (the 4th input, the second m of .emma):

In [46]:
(xenc @ W)[3, 13]
Out[46]:
tensor(-1.1848)
In [47]:
(xenc[3]) 
Out[47]:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.])
In [48]:
W[:, 13]
Out[48]:
tensor([ 0.1144,  0.7652, -0.5911, -0.6675,  0.4726,  1.9764, -0.3261,  0.8006,
        -1.5484,  0.7669, -1.6879, -1.6150, -0.9893, -1.1848, -0.5114, -0.0257,
         1.5935, -0.9236,  0.0286,  1.9230,  0.2960, -0.1901, -0.9464, -2.8281,
        -0.5484,  0.7248,  1.3250])
In [50]:
(xenc[3] * W[:, 13]).sum()
Out[50]:
tensor(-1.1848)
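Because each row of xenc is one-hot, that dot product just picks out a single entry of W: input 3 is the character index 13, so (xenc @ W)[3, 13] equals W[13, 13], the -1.1848 at index 13 of W[:, 13] shown above. More generally, xenc @ W is the same as indexing the rows of W with xs. A minimal sketch with a freshly seeded W (so its values differ from the cells above):

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g)
xs = torch.tensor([0, 5, 13, 13, 1])
xenc = F.one_hot(xs, num_classes=27).float()

# Multiplying a one-hot row by W simply selects the matching row of W
print(torch.allclose(xenc @ W, W[xs]))  # True
```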
In [45]:
logits = xenc @ W # log-counts
counts = logits.exp() # equivalent N
probs = counts / counts.sum(1, keepdims=True)
probs
Out[45]:
tensor([[0.0306, 0.1081, 0.0137, 0.0230, 0.0119, 0.0970, 0.0279, 0.1429, 0.0125,
         0.0351, 0.0220, 0.0804, 0.0182, 0.0226, 0.0160, 0.0056, 0.0201, 0.0087,
         0.0053, 0.0244, 0.0731, 0.0932, 0.0104, 0.0044, 0.0429, 0.0246, 0.0253],
        [0.0167, 0.0430, 0.0140, 0.0838, 0.0187, 0.0085, 0.0147, 0.0170, 0.0340,
         0.0659, 0.0633, 0.0133, 0.1109, 0.1139, 0.0842, 0.0125, 0.0768, 0.0353,
         0.0009, 0.0107, 0.0212, 0.0052, 0.0144, 0.0077, 0.0119, 0.0857, 0.0157],
        [0.0011, 0.0127, 0.0101, 0.0394, 0.0118, 0.0236, 0.0351, 0.0098, 0.0218,
         0.0593, 0.0060, 0.1675, 0.0223, 0.0053, 0.0082, 0.0036, 0.0142, 0.0117,
         0.1135, 0.0484, 0.0239, 0.0369, 0.0638, 0.0076, 0.2141, 0.0095, 0.0188],
        [0.0011, 0.0127, 0.0101, 0.0394, 0.0118, 0.0236, 0.0351, 0.0098, 0.0218,
         0.0593, 0.0060, 0.1675, 0.0223, 0.0053, 0.0082, 0.0036, 0.0142, 0.0117,
         0.1135, 0.0484, 0.0239, 0.0369, 0.0638, 0.0076, 0.2141, 0.0095, 0.0188],
        [0.0084, 0.0437, 0.0140, 0.0216, 0.0264, 0.0155, 0.0190, 0.0239, 0.0091,
         0.0644, 0.0069, 0.2564, 0.0068, 0.0498, 0.0810, 0.0036, 0.0064, 0.0172,
         0.0101, 0.0040, 0.0495, 0.0045, 0.0249, 0.0148, 0.0886, 0.0913, 0.0379]])
In [341]:
probs[0]
Out[341]:
tensor([0.0041, 0.0171, 0.0089, 0.0189, 0.0041, 0.0464, 0.0092, 0.0621, 0.0259,
        0.0087, 0.0147, 0.5287, 0.0373, 0.0151, 0.0480, 0.0040, 0.0096, 0.0117,
        0.0070, 0.0106, 0.0422, 0.0021, 0.0047, 0.0026, 0.0390, 0.0038, 0.0134])
In [342]:
probs[0].shape
Out[342]:
torch.Size([27])
In [343]:
probs[0].sum()
Out[343]:
tensor(1.0000)
In [344]:
# (5, 27) @ (27, 27) -> (5, 27)
In [345]:
# SUMMARY ------------------------------>>>>
In [51]:
xs
Out[51]:
tensor([ 0,  5, 13, 13,  1])
In [52]:
ys
Out[52]:
tensor([ 5, 13, 13,  1,  0])
In [53]:
# randomly initialize 27 neurons' weights. each neuron receives 27 inputs
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g)

Exponentiating the logits (which can be positive or negative) gives us positive floats that can play the role of counts, like the integer counts in N. That is why we interpret the logits as "log counts", and counts is logits.exp().

probs holds those counts normalized row-wise into probabilities.

Matrix multiplication (@), exp(), addition and division are all differentiable operations.

So we are:
-taking inputs
-applying differentiable operations that we can backpropagate through
-getting out probability distributions
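To make "we can backpropagate through" concrete, here is a minimal sketch that asks PyTorch to track gradients for W (requires_grad=True, which the cells in this section do not set). It computes the average negative log likelihood with vectorized indexing, probs[torch.arange(5), ys], in place of the per-example loop used below:

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)
xs = torch.tensor([0, 5, 13, 13, 1])
ys = torch.tensor([5, 13, 13, 1, 0])

xenc = F.one_hot(xs, num_classes=27).float()
logits = xenc @ W
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
loss = -probs[torch.arange(5), ys].log().mean()  # average negative log likelihood

loss.backward()  # backpropagate through @, exp, sum and /
print(loss.item(), W.grad.shape)  # W.grad has the same shape as W
```

Rows of W for characters that never appear in xs get a gradient of exactly zero, since they never influenced the logits.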

softmax is a very commonly used layer in a neural net: it takes logits, exponentiates them, and divides by their sum to normalize them. It takes neural net outputs that can be positive or negative and outputs a probability distribution that always sums to 1 (a normalizing function).
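A sketch comparing the manual exp-and-normalize steps with PyTorch's built-in F.softmax on made-up logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0, 0.5],
                       [-3.0, 0.0, 3.0]])  # made-up logits, positive and negative

counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)  # manual softmax over each row

print(torch.allclose(probs, F.softmax(logits, dim=1)))  # True
print(probs.sum(1))                                     # each row sums to 1
```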

In [54]:
xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding
logits = xenc @ W # predict log-counts
counts = logits.exp() # counts, equivalent to N
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
# btw: the last 2 lines here are together called a 'softmax'
In [55]:
probs.shape
Out[55]:
torch.Size([5, 27])
In [57]:
nlls = torch.zeros(5)
for i in range(5):
  # i-th bigram:
  x = xs[i].item() # input character index
  y = ys[i].item() # label character index
  print('--------')
  print(f'bigram example {i+1}: {itos[x]}{itos[y]} (indexes {x},{y})')
  print('input to the neural net:', x)
  print('output probabilities from the neural net:', probs[i])
  print('label (actual next character):', y)
  p = probs[i, y]
  print('probability assigned by the net to the the correct character:', p.item())
  logp = torch.log(p)
  print('log likelihood:', logp.item())
  nll = -logp
  print('negative log likelihood:', nll.item())
  nlls[i] = nll

print('=========')
print('average negative log likelihood, i.e. loss =', nlls.mean().item())
--------
bigram example 1: .e (indexes 0,5)
input to the neural net: 0
output probabilities from the neural net: tensor([0.0607, 0.0100, 0.0123, 0.0042, 0.0168, 0.0123, 0.0027, 0.0232, 0.0137,
        0.0313, 0.0079, 0.0278, 0.0091, 0.0082, 0.0500, 0.2378, 0.0603, 0.0025,
        0.0249, 0.0055, 0.0339, 0.0109, 0.0029, 0.0198, 0.0118, 0.1537, 0.1459])
label (actual next character): 5
probability assigned by the net to the the correct character: 0.01228625513613224
log likelihood: -4.399273872375488
negative log likelihood: 4.399273872375488
--------
bigram example 2: em (indexes 5,13)
input to the neural net: 5
output probabilities from the neural net: tensor([0.0290, 0.0796, 0.0248, 0.0521, 0.1989, 0.0289, 0.0094, 0.0335, 0.0097,
        0.0301, 0.0702, 0.0228, 0.0115, 0.0181, 0.0108, 0.0315, 0.0291, 0.0045,
        0.0916, 0.0215, 0.0486, 0.0300, 0.0501, 0.0027, 0.0118, 0.0022, 0.0472])
label (actual next character): 13
probability assigned by the net to the the correct character: 0.018050700426101685
log likelihood: -4.014570713043213
negative log likelihood: 4.014570713043213
--------
bigram example 3: mm (indexes 13,13)
input to the neural net: 13
output probabilities from the neural net: tensor([0.0312, 0.0737, 0.0484, 0.0333, 0.0674, 0.0200, 0.0263, 0.0249, 0.1226,
        0.0164, 0.0075, 0.0789, 0.0131, 0.0267, 0.0147, 0.0112, 0.0585, 0.0121,
        0.0650, 0.0058, 0.0208, 0.0078, 0.0133, 0.0203, 0.1204, 0.0469, 0.0126])
label (actual next character): 13
probability assigned by the net to the the correct character: 0.026691533625125885
log likelihood: -3.623408794403076
negative log likelihood: 3.623408794403076
--------
bigram example 4: ma (indexes 13,1)
input to the neural net: 13
output probabilities from the neural net: tensor([0.0312, 0.0737, 0.0484, 0.0333, 0.0674, 0.0200, 0.0263, 0.0249, 0.1226,
        0.0164, 0.0075, 0.0789, 0.0131, 0.0267, 0.0147, 0.0112, 0.0585, 0.0121,
        0.0650, 0.0058, 0.0208, 0.0078, 0.0133, 0.0203, 0.1204, 0.0469, 0.0126])
label (actual next character): 1
probability assigned by the net to the the correct character: 0.07367686182260513
log likelihood: -2.6080665588378906
negative log likelihood: 2.6080665588378906
--------
bigram example 5: a. (indexes 1,0)
input to the neural net: 1
output probabilities from the neural net: tensor([0.0150, 0.0086, 0.0396, 0.0100, 0.0606, 0.0308, 0.1084, 0.0131, 0.0125,
        0.0048, 0.1024, 0.0086, 0.0988, 0.0112, 0.0232, 0.0207, 0.0408, 0.0078,
        0.0899, 0.0531, 0.0463, 0.0309, 0.0051, 0.0329, 0.0654, 0.0503, 0.0091])
label (actual next character): 0
probability assigned by the net to the the correct character: 0.014977526850998402
log likelihood: -4.201204299926758
negative log likelihood: 4.201204299926758
=========
average negative log likelihood, i.e. loss = 3.7693049907684326
In [352]:
# --------- !!! OPTIMIZATION !!! yay --------------
In [353]:
xs
Out[353]:
tensor([ 0,  5, 13, 13,  1])
In [354]:
ys
Out[354]:
tensor([ 5, 13, 13,  1,  0])
In [58]:
# randomly initialize 27 neurons' weights. each neuron receives 27 inputs
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

in makemore we use the negative log likelihood because we are doing classification (predicting which of 27 characters comes next). the earlier autograd (micrograd) example used mean squared error because it was doing regression

the loss is the average negative log likelihood over the examples

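to do gradient descent, you repeat a loop of forward pass, backward pass, and parameter update over the labelled examples, tuning the NN a little each iteration. here every iteration processes all of the examples at once as a single batch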

In [59]:
# forward pass
xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding
logits = xenc @ W # predict log-counts
counts = logits.exp() # counts, equivalent to N
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
loss = -probs[torch.arange(5), ys].log().mean()
In [60]:
print(loss.item())
3.7693049907684326
In [61]:
# backward pass
W.grad = None # set to zero the gradient
loss.backward()
In [62]:
W.grad
Out[62]:
tensor([[ 0.0121,  0.0020,  0.0025,  0.0008,  0.0034, -0.1975,  0.0005,  0.0046,
          0.0027,  0.0063,  0.0016,  0.0056,  0.0018,  0.0016,  0.0100,  0.0476,
          0.0121,  0.0005,  0.0050,  0.0011,  0.0068,  0.0022,  0.0006,  0.0040,
          0.0024,  0.0307,  0.0292],
        [-0.1970,  0.0017,  0.0079,  0.0020,  0.0121,  0.0062,  0.0217,  0.0026,
          0.0025,  0.0010,  0.0205,  0.0017,  0.0198,  0.0022,  0.0046,  0.0041,
          0.0082,  0.0016,  0.0180,  0.0106,  0.0093,  0.0062,  0.0010,  0.0066,
          0.0131,  0.0101,  0.0018],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0058,  0.0159,  0.0050,  0.0104,  0.0398,  0.0058,  0.0019,  0.0067,
          0.0019,  0.0060,  0.0140,  0.0046,  0.0023, -0.1964,  0.0022,  0.0063,
          0.0058,  0.0009,  0.0183,  0.0043,  0.0097,  0.0060,  0.0100,  0.0005,
          0.0024,  0.0004,  0.0094],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0125, -0.1705,  0.0194,  0.0133,  0.0270,  0.0080,  0.0105,  0.0100,
          0.0490,  0.0066,  0.0030,  0.0316,  0.0052, -0.1893,  0.0059,  0.0045,
          0.0234,  0.0049,  0.0260,  0.0023,  0.0083,  0.0031,  0.0053,  0.0081,
          0.0482,  0.0187,  0.0051],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000]])
In [63]:
W.data += -0.1 * W.grad
In [292]:
# --------- !!! OPTIMIZATION !!! yay, but this time actually --------------
In [87]:
# create the dataset
xs, ys = [], []
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    xs.append(ix1)
    ys.append(ix2)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print('number of examples: ', num)

# initialize the 'network'
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)
number of examples:  228146
In [86]:
num
Out[86]:
228146

if the regularization loss is too strong, the data loss can't overcome it. W gets pushed toward zero, and the predictions become uniform
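To see why an overly strong regularizer gives uniform predictions, here is a sketch of the extreme case. A very large coefficient on `(W**2).mean()` drives W toward all zeros, so W is simply set to zeros. It uses the five "emma" bigrams (xs, ys) from earlier in these notes. With every logit equal to 0, every count is 1 and every probability is 1/27. The loss is then log(27), the loss of a model that knows nothing.

```python
import math
import torch
import torch.nn.functional as F

# the five bigrams of "emma" from earlier: .e, em, mm, ma, a.
xs = torch.tensor([0, 5, 13, 13, 1])
ys = torch.tensor([5, 13, 13, 1, 0])

# the limit of very strong regularization: W driven to all zeros
W = torch.zeros((27, 27))

xenc = F.one_hot(xs, num_classes=27).float()
logits = xenc @ W                             # all logits are 0
counts = logits.exp()                         # all counts are 1
probs = counts / counts.sum(1, keepdim=True)  # every entry is 1/27: uniform

loss = -probs[torch.arange(5), ys].log().mean()
print(loss.item(), math.log(27))  # both approximately 3.2958
```

A moderate coefficient (like the 0.01 in the loop below) only nudges W toward this uniform solution. The data loss still dominates.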

In [90]:
# gradient descent
for k in range(300):
  
  # forward pass
  xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding
  logits = xenc @ W # predict log-counts
  counts = logits.exp() # counts, equivalent to N
  probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
  loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean() # + 0.01*(W**2).mean() regularization loss
  print(loss.item())
  
  # backward pass
  W.grad = None # set to zero the gradient
  loss.backward()
  
  # update
  W.data += -50 * W.grad
2.696505546569824
2.6773719787597656
2.6608052253723145
2.6463515758514404
2.633665084838867
2.622471570968628
2.6125476360321045
2.6037068367004395
2.595794916152954
2.5886809825897217
2.582256317138672
2.5764293670654297
2.5711238384246826
2.566272735595703
2.5618226528167725
2.5577263832092285
2.5539441108703613
2.550442695617676
2.5471925735473633
2.5441696643829346
2.5413522720336914
2.538722038269043
2.536262035369873
2.5339581966400146
2.531797409057617
2.5297679901123047
2.527860164642334
2.5260636806488037
2.5243709087371826
2.522773265838623
2.521263837814331
2.519836664199829
2.5184857845306396
2.517204999923706
2.515990734100342
2.5148372650146484
2.5137410163879395
2.512697696685791
2.511704921722412
2.5107579231262207
2.509855031967163
2.5089924335479736
2.5081682205200195
2.507380485534668
2.5066256523132324
2.5059032440185547
2.5052108764648438
2.5045459270477295
2.503908157348633
2.503295421600342
2.502706289291382
2.5021398067474365
2.5015945434570312
2.5010693073272705
2.500563383102417
2.500075578689575
2.4996049404144287
2.499150514602661
2.4987120628356934
2.498288154602051
2.4978790283203125
2.4974827766418457
2.4970998764038086
2.4967293739318848
2.496370315551758
2.496022939682007
2.4956860542297363
2.4953596591949463
2.4950432777404785
2.4947361946105957
2.494438648223877
2.494149684906006
2.4938690662384033
2.4935967922210693
2.4933321475982666
2.493075132369995
2.4928252696990967
2.492582321166992
2.4923462867736816
2.492116928100586
2.4918932914733887
2.491675853729248
2.491464376449585
2.491258144378662
2.491058111190796
2.4908626079559326
2.4906723499298096
2.4904870986938477
2.4903063774108887
2.4901304244995117
2.4899587631225586
2.4897916316986084
2.489628553390503
2.489469289779663
2.489314079284668
2.4891624450683594
2.4890148639678955
2.488870859146118
2.48872971534729
2.4885921478271484
2.4884581565856934
2.4883270263671875
2.488198757171631
2.4880735874176025
2.4879512786865234
2.4878318309783936
2.4877147674560547
2.487600326538086
2.4874887466430664
2.487379550933838
2.4872729778289795
2.487168312072754
2.4870667457580566
2.486966133117676
2.486868381500244
2.4867727756500244
2.4866788387298584
2.4865872859954834
2.486497163772583
2.4864091873168945
2.486323118209839
2.486238718032837
2.4861559867858887
2.4860751628875732
2.4859957695007324
2.4859180450439453
2.485841989517212
2.485767126083374
2.4856936931610107
2.485621929168701
2.485551357269287
2.485482692718506
2.485414981842041
2.4853484630584717
2.485283374786377
2.4852192401885986
2.485156536102295
2.4850947856903076
2.485034465789795
2.484975576400757
2.484916925430298
2.4848597049713135
2.4848034381866455
2.484748363494873
2.484694719314575
2.4846413135528564
2.484589099884033
2.4845376014709473
2.4844870567321777
2.484437942504883
2.484389066696167
2.4843411445617676
2.4842941761016846
2.484247922897339
2.4842026233673096
2.4841578006744385
2.484113931655884
2.4840710163116455
2.4840285778045654
2.4839866161346436
2.483945369720459
2.48390531539917
2.483865261077881
2.483825922012329
2.483787775039673
2.4837498664855957
2.483712911605835
2.483675956726074
2.4836394786834717
2.4836041927337646
2.483569383621216
2.483534574508667
2.4835009574890137
2.4834673404693604
2.4834344387054443
2.4834022521972656
2.483370304107666
2.4833385944366455
2.4833078384399414
2.4832773208618164
2.4832472801208496
2.483217716217041
2.4831883907318115
2.4831600189208984
2.4831314086914062
2.4831035137176514
2.4830760955810547
2.483048439025879
2.4830219745635986
2.4829957485198975
2.4829697608947754
2.4829440116882324
2.4829187393188477
2.482893466949463
2.4828693866729736
2.4828450679779053
2.482820987701416
2.482797622680664
2.482774019241333
2.4827513694763184
2.4827287197113037
2.4827065467834473
2.482684373855591
2.4826626777648926
2.4826409816741943
2.4826200008392334
2.4825992584228516
2.4825785160064697
2.482558250427246
2.4825379848480225
2.482517957687378
2.48249888420105
2.4824793338775635
2.4824604988098145
2.4824414253234863
2.4824228286743164
2.4824042320251465
2.4823861122131348
2.4823684692382812
2.4823508262634277
2.482333183288574
2.482316255569458
2.4822990894317627
2.4822821617126465
2.4822654724121094
2.4822492599487305
2.4822330474853516
2.4822168350219727
2.482201337814331
2.4821856021881104
2.4821698665618896
2.482154369354248
2.4821391105651855
2.4821245670318604
2.482110023498535
2.4820950031280518
2.4820809364318848
2.4820668697357178
2.4820523262023926
2.482038736343384
2.482024908065796
2.482011556625366
2.4819979667663574
2.4819846153259277
2.4819717407226562
2.4819588661193848
2.481945753097534
2.481933116912842
2.4819209575653076
2.481908082962036
2.481895923614502
2.4818837642669678
2.4818718433380127
2.4818601608276367
2.4818480014801025
2.4818363189697266
2.481825113296509
2.481813669204712
2.4818027019500732
2.4817914962768555
2.481780529022217
2.481769561767578
2.4817588329315186
2.481748342514038
2.4817378520965576
2.481727361679077
2.4817168712615967
2.4817068576812744
2.481696844100952
2.48168683052063
2.4816770553588867
2.4816672801971436
2.4816575050354004
2.4816479682922363
2.481638193130493
2.4816291332244873
2.4816195964813232
2.4816102981567383
2.4816012382507324
2.4815924167633057
2.481583833694458
2.481574773788452
2.4815661907196045
2.481557607650757
2.48154878616333
2.4815402030944824
2.481531858444214
2.4815237522125244
2.481515645980835
2.4815075397491455
2.481499195098877
2.4814915657043457
2.4814834594726562
2.481475591659546
2.4814682006835938
2.4814603328704834
2.481452465057373
2.4814453125
2.4814376831054688
2.4814302921295166
In [91]:
# finally, sample from the 'neural net' model
g = torch.Generator().manual_seed(2147483647)

for i in range(5):
  
  out = []
  ix = 0
  while True:
    
    # ----------
    # BEFORE:
    #p = P[ix]
    # ----------
    # NOW:
    xenc = F.one_hot(torch.tensor([ix]), num_classes=27).float()
    logits = xenc @ W # predict log-counts
    counts = logits.exp() # counts, equivalent to N
    p = counts / counts.sum(1, keepdims=True) # probabilities for next character
    # ----------
    
    ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[ix])
    if ix == 0:
      break
  print(''.join(out))
mor.
axx.
minaymoryles.
kondlaisah.
anchthizarie.

both the counting-based bigram model (the count matrix N) and the weights of the neural net (W) end up describing essentially the same model, though they arrive at these numbers very differently:
- N is computed by counting how often each bigram occurs in the training set. normalizing each row of N gives the probabilities P
- W starts out randomly sampled. multiplying the one-hot encoded inputs by W gives the logits (because each input is one-hot, this just picks out the row of W for the current character). the logits exponentiated are the counts, and probs normalizes the counts into a probability distribution. after training, W.exp() ends up roughly proportional to N, row by row

the loss is:
loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
- probs[torch.arange(num), ys] plucks out, for each example, the probability that the neural net assigns to the correct next character
- .log() turns those into log probabilities
- .mean() takes the average
- the minus sign turns the average log likelihood into the negative log likelihood
- 0.01*(W**2).mean() is the regularization term, which pulls W toward zero
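The data part of that loss (everything except the regularization term) is the standard cross-entropy loss. Here is a sketch comparing the hand-written version with PyTorch's `F.cross_entropy`, using the five "emma" bigrams and a randomly initialized W. `F.cross_entropy` takes the raw logits and applies the log-softmax internally. The regularization term would still have to be added separately.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g)

# the five bigrams of "emma" from earlier: .e, em, mm, ma, a.
xs = torch.tensor([0, 5, 13, 13, 1])
ys = torch.tensor([5, 13, 13, 1, 0])
num = xs.nelement()

xenc = F.one_hot(xs, num_classes=27).float()
logits = xenc @ W
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)

# by hand: pluck the probability of the correct next character, log, average, negate
manual_loss = -probs[torch.arange(num), ys].log().mean()

# built-in: log-softmax of the logits, then the mean negative log likelihood
builtin_loss = F.cross_entropy(logits, ys)

print(manual_loss.item(), builtin_loss.item())  # the two agree up to float rounding
```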

calling loss.backward() fills in the gradient of the loss with respect to W, stored in W.grad
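As a check on what `.backward()` computes, here is a sketch that compares W.grad with a gradient derived by hand. It leaves out the regularization term to keep the math short. For a softmax followed by the mean negative log likelihood, the gradient with respect to the logits is (probs - one_hot(ys)) / num. Since logits = xenc @ W, the gradient with respect to W is xenc.T times that.

The same formula explains the pattern in the W.grad output earlier in these notes:

- rows for characters that never appear as an input are exactly zero;
- within each used row, the negative entries sit at the correct next characters.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

# the five bigrams of "emma" from earlier: .e, em, mm, ma, a.
xs = torch.tensor([0, 5, 13, 13, 1])
ys = torch.tensor([5, 13, 13, 1, 0])
num = xs.nelement()

# forward pass (no regularization term, to keep the hand-derived gradient simple)
xenc = F.one_hot(xs, num_classes=27).float()
logits = xenc @ W
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
loss = -probs[torch.arange(num), ys].log().mean()

# backward pass: autograd fills in W.grad
W.grad = None
loss.backward()

# the same gradient by hand:
# dloss/dlogits = (probs - one_hot(ys)) / num
# logits = xenc @ W, so dloss/dW = xenc.T @ dloss/dlogits
dlogits = (probs.detach() - F.one_hot(ys, num_classes=27).float()) / num
dW = xenc.T @ dlogits

print(torch.allclose(W.grad, dW, atol=1e-6))  # True
```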

Homework:

