Layer 0, Head 0 KEY: closest word mappings, ranked by match amplitude (.abs() value)

Space-delineated word mappings:
(( demic (( (( ]( Tags (( (( (( (( ]( (( (( (( | (( (( (( (( (( (( (( (( demic (( ({
(( (( (( (( (( ({ )))) ({ (( (( demic (( ( : (( (( (( ( )): Although (( | : ({ [{ ((
{( )))) )): : )))) Although ...   Although [{ ( <0x00> demic ... {( ingly }) )))) )):
)") demic ingly occasionally ingly )))) aneous Although Although {[ : Although
Although ingly   Although ingly ingly ^{( ... File : ological Although ingly
Although 。 <0x00> ${ ingly }) ingly approximately ingly ingly {[ ( : }) ${ ( ingly :
aneous )}} aneous [{ )))) *( aneous ). : )): )^{ )))) db
Corresponding .abs() values:
tensor([0.4941, 0.4922, 0.4414, 0.4375, 0.4375, 0.4355, 0.4023, 0.3887, 0.3789,
        0.3789, 0.3750, 0.3730, 0.3730, 0.3730, 0.3730, 0.3672, 0.3652, 0.3633,
        0.3594, 0.3535, 0.3496, 0.3438, 0.3359, 0.3320, 0.3223, 0.3203, 0.3105,
        0.3105, 0.3066, 0.2910, 0.2910, 0.2891, 0.2891, 0.2891, 0.2891, 0.2871,
        0.2871, 0.2812, 0.2715, 0.2715, 0.2695, 0.2676, 0.2656, 0.2617, 0.2598,
        0.2539, 0.2520, 0.2490, 0.2461, 0.2461, 0.2451, 0.2412, 0.2334, 0.2324,
        0.2324, 0.2324, 0.2275, 0.2266, 0.2207, 0.2197, 0.2139, 0.2139, 0.2129,
        0.2080, 0.2070, 0.2061, 0.2031, 0.2031, 0.2031, 0.2012, 0.1934, 0.1904,
        0.1885, 0.1875, 0.1865, 0.1865, 0.1846, 0.1826, 0.1826, 0.1826, 0.1816,
        0.1816, 0.1787, 0.1777, 0.1777, 0.1777, 0.1768, 0.1748, 0.1748, 0.1748,
        0.1748, 0.1729, 0.1719, 0.1709, 0.1699, 0.1689, 0.1689, 0.1680, 0.1670,
        0.1650, 0.1641, 0.1641, 0.1621, 0.1611, 0.1611, 0.1602, 0.1592, 0.1572,
        0.1562, 0.1562, 0.1553, 0.1523, 0.1484, 0.1484, 0.1455, 0.1445, 0.1416,
        0.1348, 0.1328, 0.1309, 0.1299, 0.1270, 0.1177, 0.1177, 0.1143, 0.1118,
        0.1030, 0.0806])
Corresponding Wk_vector.norm() values (vector amplitude):
tensor([1.1953, 1.6328, 1.0781, 1.1484, 1.7109, 1.4922, 1.0781, 1.0469, 1.0938,
        1.0703, 1.1562, 1.0234, 1.0547, 1.0938, 1.6562, 1.0781, 1.0938, 1.0156,
        0.9844, 1.0234, 1.1094, 0.9727, 1.2344, 1.3516, 1.0625, 0.9961, 0.8438,
        1.1328, 1.0234, 0.8750, 0.9258, 0.9453, 0.9258, 0.8945, 0.9297, 0.7852,
        1.2656, 0.9883, 1.0859, 0.8711, 0.8867, 0.8203, 0.7773, 0.8086, 0.9336,
        1.4688, 0.7578, 1.2578, 0.8906, 0.9727, 1.0000, 1.0547, 0.7773, 1.0703,
        0.8086, 0.9688, 0.7500, 1.4375, 1.1172, 1.1016, 1.2812, 1.0312, 0.9727,
        1.1484, 1.0781, 0.7812, 0.7266, 1.3281, 1.0938, 1.1250, 0.8398, 1.0234,
        1.0078, 1.2578, 1.1484, 1.2422, 0.6758, 1.0859, 1.2969, 1.2812, 0.7695,
        0.7969, 1.1094, 1.2422, 1.1641, 1.0859, 1.2812, 1.1328, 1.1328, 0.7266,
        1.1016, 1.1328, 1.0078, 1.1406, 1.1094, 1.2188, 1.0625, 1.0938, 1.0312,
        0.9961, 1.0078, 1.0938, 0.9531, 1.0781, 1.0312, 1.1484, 1.0156, 0.7539,
        0.8633, 0.5312, 1.0625, 1.0859, 0.8555, 1.0859, 0.6680, 0.9961, 0.7344,
        0.9805, 0.8242, 0.5781, 0.7617, 0.9141, 0.7344, 0.4785, 0.5938, 0.8242,
        0.5312, 0.6445])

Layer 0, Head 0 QUERY: closest word mappings, ranked by match amplitude (.abs() value)

Space-delineated word mappings:
<s> ]) <s> ]) ]) ]) ]) ]]) ]) }}) )))) ]) ]) ]) )). }}) }}) )))) }}) )))) }}) ]) ))))
]). )))) }] ) ]) ]) )}} <s> )). ]). }}) <s> }}) >) )))) *) )): ]) "] ]) )}} }] ]) }})
}}) )))) )") )))) "] ]) }}) : ) Ve }} <s> <s> ) <s> <s> ( ), <s> ") <s> <s> <s> : Sur
)), <s> <s> Mi )), .) ?” ) Anthony Sur :} Fu – Ara <s> <s> storage )))) '] } )))) ))))
)- ). Anthony ) Guy Bon La ( Su A ?) ( )), <s> Cent Sur Stuart )/ ery ICATION UES </s>
The ] Philadelphia oo ) :} : ... .) Broadcast antry Mc
Corresponding .abs() values:
tensor([0.8750, 0.7695, 0.6992, 0.6172, 0.5078, 0.5000, 0.4961, 0.4863, 0.4805,
        0.4746, 0.4688, 0.4668, 0.4570, 0.4512, 0.4453, 0.4434, 0.4434, 0.4316,
        0.4297, 0.4238, 0.4199, 0.4141, 0.4004, 0.3809, 0.3809, 0.3750, 0.3711,
        0.3691, 0.3652, 0.3535, 0.3516, 0.3418, 0.3418, 0.3398, 0.3379, 0.3359,
        0.3359, 0.3164, 0.3125, 0.3125, 0.3066, 0.3066, 0.3008, 0.3008, 0.2930,
        0.2891, 0.2852, 0.2812, 0.2773, 0.2715, 0.2695, 0.2695, 0.2617, 0.2617,
        0.2617, 0.2578, 0.2578, 0.2520, 0.2520, 0.2500, 0.2412, 0.2363, 0.2354,
        0.2275, 0.2236, 0.2236, 0.2227, 0.2207, 0.2207, 0.2178, 0.2178, 0.2178,
        0.2070, 0.2061, 0.2051, 0.2031, 0.2031, 0.1992, 0.1934, 0.1934, 0.1934,
        0.1904, 0.1904, 0.1885, 0.1855, 0.1816, 0.1738, 0.1719, 0.1699, 0.1699,
        0.1699, 0.1680, 0.1660, 0.1660, 0.1650, 0.1631, 0.1631, 0.1621, 0.1621,
        0.1621, 0.1611, 0.1592, 0.1582, 0.1553, 0.1533, 0.1514, 0.1494, 0.1494,
        0.1484, 0.1387, 0.1357, 0.1348, 0.1338, 0.1318, 0.1318, 0.1318, 0.1309,
        0.1299, 0.1299, 0.1279, 0.1235, 0.1235, 0.1216, 0.1196, 0.1172, 0.1104,
        0.1011, 0.0913])
Corresponding Wq_vector.norm() values (vector amplitude):
tensor([2.1094, 1.5312, 1.7812, 1.2422, 1.1328, 1.1328, 1.1172, 1.1641, 1.1016,
        1.0859, 1.0625, 0.9727, 1.0625, 1.1250, 1.1719, 1.1094, 1.0000, 1.0938,
        0.9766, 1.0938, 1.0156, 1.0000, 0.9453, 0.9766, 1.0312, 0.9492, 0.8828,
        0.9492, 0.9805, 0.9492, 1.2109, 1.0078, 1.1172, 0.9219, 1.1094, 0.9258,
        0.8789, 0.8906, 0.9375, 0.8711, 0.8828, 0.9219, 0.8320, 0.8008, 0.8555,
        0.7617, 0.8555, 0.9570, 0.7227, 0.7578, 0.7773, 0.8789, 0.8945, 0.9570,
        0.9023, 0.8164, 1.2734, 0.7852, 1.0156, 0.8281, 0.7500, 0.9883, 0.9531,
        0.6914, 1.1562, 0.8906, 1.2500, 0.9062, 0.9570, 0.9805, 0.6328, 1.2344,
        0.8594, 0.8203, 0.8438, 1.0938, 0.8789, 1.0938, 0.7461, 0.8203, 1.1797,
        1.0547, 0.6211, 1.0703, 0.8164, 1.0781, 0.6875, 0.7617, 1.0156, 0.9844,
        0.7266, 0.8867, 0.8320, 0.7695, 0.9453, 0.6602, 0.9453, 0.6797, 0.9609,
        0.9570, 0.9062, 0.6641, 0.8438, 0.7656, 0.7656, 0.7461, 0.7539, 0.8555,
        0.9141, 0.8359, 0.8281, 0.7461, 0.7852, 0.7031, 0.7695, 0.7891, 0.6250,
        0.5703, 0.7734, 0.7617, 0.6133, 0.5430, 0.6055, 0.6680, 0.5195, 0.7305,
        0.6836, 0.5977])
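
The listings above come from projecting every token embedding through one attention head's key (Wk) or query (Wq) matrix and, for each of the head's 128 row vectors, recording the vocabulary token whose embedding has the largest |dot product| with it. The helper below does exactly that; it assumes `loaded` (the checkpoint state dict) and `tokenizer` (the model's SentencePiece tokenizer) are already in scope.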
#####code####
import torch

# LLaMA-7B: 32 heads over a 4096-dim model, so 128 dims per head
dim_per_head = 4096 // 32

# `loaded` (checkpoint state dict) and `tokenizer` are assumed to be in scope
def find_words_by_head(head=5, all_dat=loaded, mat=loaded['layers.0.attention.wk.weight']):
    # Rows of the projection matrix belonging to this head
    headRange = (slice(head * dim_per_head, (head + 1) * dim_per_head), slice(0, 4096))
    headMat = mat[headRange].clone()  # [128, 4096]

    # |dot product| of every token embedding with each of the head's 128 vectors
    vects_ME = (all_dat['tok_embeddings.weight'] @ headMat.transpose(0, 1)).abs()  # [32000, 128]

    best_vect = torch.zeros(128, 3)
    for pl in range(128):
        idx = int(vects_ME[:, pl].argmax())       # token whose embedding matches best
        best_vect[pl, 0] = idx                    # token index
        best_vect[pl, 1] = vects_ME[idx, pl]      # match amplitude (.abs() value)
        best_vect[pl, 2] = headMat[pl, :].norm()  # the head vector's own norm

    # Rank the 128 head vectors by the amplitude of their best word match
    sorted_best = best_vect[:, 1].sort(descending=True)
    print("Closest word mappings for head #", head)
    print("Space-delineated word mappings:")
    print(' '.join(tokenizer.decode(int(best_vect[sorted_best.indices, 0][pl])) for pl in range(128)))
    print("Corresponding .abs() values:")
    print(sorted_best.values)
    print("Corresponding vector norm values (vector amplitude):")
    print(best_vect[sorted_best.indices, 2])
#####code####
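
A minimal usage sketch, assuming the original-format LLaMA-7B checkpoint, where the layer-0 key and query projections live under the keys `layers.0.attention.wk.weight` and `layers.0.attention.wq.weight`. Calls like these would reproduce the two listings above:
#####code####
# Hypothetical reproduction of the listings above; `loaded` as in find_words_by_head
find_words_by_head(head=0, mat=loaded['layers.0.attention.wk.weight'])  # KEY listing
find_words_by_head(head=0, mat=loaded['layers.0.attention.wq.weight'])  # QUERY listing
#####code####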