c_polynomial
#include <stdio.h> // 包含标准输入输出库
#include <stdlib.h> // 包含标准库函数,如malloc等
#include <string.h> // 包含字符串处理函数,如memcpy
#include <intrin.h> // 注意GCC可能没有这个头文件,需要使用__builtin函数或者自己实现_bittest64
#include <stdint.h>
#include <windows.h>
#include <float.h> // For _control87
#include <fenv.h>
int coeffs[9]; // the nine polynomial coefficients read by main(); coeffs[k] multiplies x^k
// Bitmask consulted by main() via _bittest64 with bit index (i + 37).
// Bits 0, 28, 33, 42, 43 and 54 are set, i.e. i = -37, -9, -4, 5, 6, 17.
uint64_t v5 = 0x400C0210000001LL;
unsigned int v12; // scratch: holds (i + 37) while main() tests a bit of v5
typedef struct output { // receives the nine coefficients before they are XORed into the flag
int a;
int b;
int c;
int d;
short e; // the last five members are short (narrower than int) to save space
short f;
short g;
short h;
short i;
} output;
long total = 0; // running sum of the polynomial evaluation in main()
long power = 1; // current power of the evaluation point in main()
int result = 0; // copy of total that main() compares against zero
// XOR key applied byte-by-byte to the 52-byte buffer printed as the flag
char xorcode[60] = "\xd5\x2a\x00\xd8\xec\xf1\x77\x43\x4d\xc5\xbc\x9c\xab\x3a\x65\xd9\x6b\x1f\x5d\x3a\x61\x9b\x9b\xcc\x39\x2d\xcb\x2a\x1a\xda\xfc\xf6\x65\x1c\x03\x96\xef\x86\xe9\x6b\x3d\x90\x37\x51\x03\x63\x36\xc5\xc3\x8e\x66\x7d";
/*
 * Reads nine integer coefficients, checks that the polynomial they describe
 * is zero exactly at the points selected by v5 (for i in [-60, 59]), checks
 * two of the normalised coefficients, and on success prints the coefficients
 * XORed with xorcode. IsDebuggerPresent() is polled at several points and
 * terminates the program when it reports a debugger.
 * Returns 1 on a wrong answer; falls off the end of main on success.
 */
int main() {
feclearexcept(FE_ALL_EXCEPT);
// Check whether any floating-point exception flags are set
if (fetestexcept(FE_INVALID | FE_DENORMAL)) {
printf("Floating point exceptions are set, possibly being debugged.\n");
exit(EXIT_FAILURE);
}
if (IsDebuggerPresent()) {
printf("This program is being debugged. Exiting!\n");
exit(1);
}
printf("I'm practicing my neuro math skills. Give me nine integers: "); // prompt for the nine integers
// NOTE(review): the return value of scanf is not checked; on short input the remaining coeffs stay 0.
scanf("%d %d %d %d %d %d %d %d %d", &coeffs[0], &coeffs[1], &coeffs[2], &coeffs[3], &coeffs[4], &coeffs[5], &coeffs[6], &coeffs[7], &coeffs[8]);
printf("Hmm, let me think");
fflush(stdout);
if (IsDebuggerPresent()) {
printf("This program is being debugged. Exiting!\n");
exit(1);
}
printf(".");
fflush(stdout);
printf(".");
fflush(stdout);
printf(".\n");
// Evaluate the polynomial at every integer i in [-60, 59] and check whether it is zero there
for (int i = -60; i < 60; i++) {
total = 0;
power = 1;
for (int x = 0; x < 9; x++) { // walk the coefficient array
// NOTE(review): power * coeffs[x] can exceed the range of long (signed overflow is undefined in C);
// the check appears to rely on the target toolchain wrapping - confirm before changing types.
total += power * coeffs[x]; // add coeffs[x] * i^x, accumulating from i^0 up to i^8
power *= i; // advance to the next power of i
result =total;
}
if( i == 114 || i == 514 || (v12 = (unsigned int)(i + 37), v12 <= 54) && _bittest64((const int64_t*)&v5, v12) ) {
if (IsDebuggerPresent()) {
printf("This program is being debugged. Exiting!\n");
exit(1);
} // This branch is taken for i = -37, -9, -4, 5, 6, 17 (the bits set in v5). The i == 114 and i == 514
// tests can never be true inside this loop's range. NOTE(review): an earlier comment listed 44 and 58
// as roots, which does not match v5; the -606 / 44114 coefficient checks below are consistent with the
// two remaining roots being 114 and 514.
if (result != 0) { // the polynomial must evaluate to zero at an expected root
printf("Those aren't the right numbers. Try again!\n");
return 1; // reject
}
} else {
if (result == 0) { if (IsDebuggerPresent()) {
printf("This program is being debugged. Exiting!\n");
exit(1);
} // the polynomial must NOT evaluate to zero anywhere else in the range
printf("Those aren't the right numbers. Try again!\n");
return 1; // reject
}
}
}
// If the leading coefficient is not 1, divide every coefficient by it (integer division) to normalise.
// coeffs[8] itself is divided last, so the divisor stays unchanged for indices 0..7.
if (coeffs[8] != 1) {
for (int i = 0; i < 9; i++) {
coeffs[i] /= coeffs[8];
}
}
if (IsDebuggerPresent()) {
printf("This program is being debugged. Exiting!\n");
exit(1);
}
if (coeffs[7] != -606)
{
printf("WRONG") ;
return 1;}
if (coeffs[6] != 44114)
{
printf("WRONG") ;
return 1;}
// All checks passed: derive and print the flag
printf("Correct! Here's the flag: ");
output o; // packed copy of the coefficients
o.a = coeffs[0]; // copy each coefficient into the matching struct member (e..i are narrowed to short)
o.b = coeffs[1];
o.c = coeffs[2];
o.d = coeffs[3];
o.e = coeffs[4];
o.f = coeffs[5];
o.g = coeffs[6];
o.h = coeffs[7];
o.i = coeffs[8];
// XOR the struct bytes with xorcode to produce the flag text
unsigned char xorbuf[52];
memcpy(xorbuf, (char*)&o, 26); // first half: the first 26 bytes of the struct
memcpy(xorbuf+26, (char*)&o, 26); // second half: the same 26 bytes again (the source pointer is &o both times)
for (int i = 0; i < 52; i++) {
xorbuf[i] ^= xorcode[i]; // XOR each byte with the key byte at the same index
printf("%c", xorbuf[i]); // print the resulting character
}
printf("\n"); // trailing newline
}
c_sm4
备注:只有 1 个魔改点,在密钥扩展里:K[i] = MK[i] ^ (SM4_FK[i] + (i + 1)),i = 0..3(即源码中的 K[0] = MK[0] ^ SM4_FK[0]+1; 以及后面的 +2/+3/+4)。注意 C 语言中 + 的优先级高于 ^,所以是先对 FK 做加法,再与 MK 异或。
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#define SM4_BLOCK_SIZE 16
#define SM4_ROUNDS 32
/* ====== Standard SM4 S-box ====== */
/* 256-entry byte substitution table; tau() indexes it with each byte of a 32-bit word. */
static const uint8_t SM4_SBOX[256] = {
0xd6,0x90,0xe9,0xfe,0xcc,0xe1,0x3d,0xb7,0x16,0xb6,0x14,0xc2,0x28,0xfb,0x2c,0x05,
0x2b,0x67,0x9a,0x76,0x2a,0xbe,0x04,0xc3,0xaa,0x44,0x13,0x26,0x49,0x86,0x06,0x99,
0x9c,0x42,0x50,0xf4,0x91,0xef,0x98,0x7a,0x33,0x54,0x0b,0x43,0xed,0xcf,0xac,0x62,
0xe4,0xb3,0x1c,0xa9,0xc9,0x08,0xe8,0x95,0x80,0xdf,0x94,0xfa,0x75,0x8f,0x3f,0xa6,
0x47,0x07,0xa7,0xfc,0xf3,0x73,0x17,0xba,0x83,0x59,0x3c,0x19,0xe6,0x85,0x4f,0xa8,
0x68,0x6b,0x81,0xb2,0x71,0x64,0xda,0x8b,0xf8,0xeb,0x0f,0x4b,0x70,0x56,0x9d,0x35,
0x1e,0x24,0x0e,0x5e,0x63,0x58,0xd1,0xa2,0x25,0x22,0x7c,0x3b,0x01,0x21,0x78,0x87,
0xd4,0x00,0x46,0x57,0x9f,0xd3,0x27,0x52,0x4c,0x36,0x02,0xe7,0xa0,0xc4,0xc8,0x9e,
0xea,0xbf,0x8a,0xd2,0x40,0xc7,0x38,0xb5,0xa3,0xf7,0xf2,0xce,0xf9,0x61,0x15,0xa1,
0xe0,0xae,0x5d,0xa4,0x9b,0x34,0x1a,0x55,0xad,0x93,0x32,0x30,0xf5,0x8c,0xb1,0xe3,
0x1d,0xf6,0xe2,0x2e,0x82,0x66,0xca,0x60,0xc0,0x29,0x23,0xab,0x0d,0x53,0x4e,0x6f,
0xd5,0xdb,0x37,0x45,0xde,0xfd,0x8e,0x2f,0x03,0xff,0x6a,0x72,0x6d,0x6c,0x5b,0x51,
0x8d,0x1b,0xaf,0x92,0xbb,0xdd,0xbc,0x7f,0x11,0xd9,0x5c,0x41,0x1f,0x10,0x5a,0xd8,
0x0a,0xc1,0x31,0x88,0xa5,0xcd,0x7b,0xbd,0x2d,0x74,0xd0,0x12,0xb8,0xe5,0xb4,0xb0,
0x89,0x69,0x97,0x4a,0x0c,0x96,0x77,0x7e,0x65,0xb9,0xf1,0x09,0xc5,0x6e,0xc6,0x84,
0x18,0xf0,0x7d,0xec,0x3a,0xdc,0x4d,0x20,0x79,0xee,0x5f,0x3e,0xd7,0xcb,0x39,0x48
};
/* ====== FK / CK ====== */
/* FK: four words XORed with the master key words in sm4_setkey_enc (after the +1..+4 tweak applied there). */
static const uint32_t SM4_FK[4] = {
0xa3b1bac6u, 0x56aa3350u, 0x677d9197u, 0xb27022dcu
};
/* CK: one constant per round, mixed into the key schedule in sm4_setkey_enc. */
static const uint32_t SM4_CK[32] = {
0x00070e15u,0x1c232a31u,0x383f464du,0x545b6269u,
0x70777e85u,0x8c939aa1u,0xa8afb6bdu,0xc4cbd2d9u,
0xe0e7eef5u,0xfc030a11u,0x181f262du,0x343b4249u,
0x50575e65u,0x6c737a81u,0x888f969du,0xa4abb2b9u,
0xc0c7ced5u,0xdce3eaf1u,0xf8ff060du,0x141b2229u,
0x30373e45u,0x4c535a61u,0x686f767du,0x848b9299u,
0xa0a7aeb5u,0xbcc3cad1u,0xd8dfe6edu,0xf4fb0209u,
0x10171e25u,0x2c333a41u,0x484f565du,0x646b7279u
};
/* ====== Helper functions (byte-order independent) ====== */
/*
 * Rotate the 32-bit value x left by n bits.
 *
 * The count is reduced modulo 32 before shifting, so n == 0 (and n == 32)
 * return x unchanged. The plain form "(x << n) | (x >> (32 - n))" shifts a
 * 32-bit value by 32 when n == 0, which is undefined behaviour in C.
 */
static uint32_t rotl32(uint32_t x, int n) {
    const unsigned int k = (unsigned int)n & 31u;
    return (x << k) | (x >> ((32u - k) & 31u));
}
/* Assemble a 32-bit word from four bytes, most significant byte first. */
static uint32_t load_be32(const uint8_t b[4]) {
    uint32_t word = 0;
    for (int k = 0; k < 4; k++) {
        word = (word << 8) | (uint32_t)b[k];
    }
    return word;
}
/* Write the 32-bit word v into b[0..3], most significant byte first. */
static void store_be32(uint8_t b[4], uint32_t v) {
    for (int k = 3; k >= 0; k--) {
        b[k] = (uint8_t)(v & 0xffu);
        v >>= 8;
    }
}
/* Substitute each of the four bytes of a through SM4_SBOX and reassemble the word. */
static uint32_t tau(uint32_t a) {
    uint8_t bytes[4];
    store_be32(bytes, a);
    for (int k = 0; k < 4; k++) {
        bytes[k] = SM4_SBOX[bytes[k]];
    }
    return load_be32(bytes);
}
/* Linear transform L used by the encryption/decryption round function:
 * b XOR its left-rotations by 2, 10, 18 and 24 bits. */
static uint32_t L(uint32_t b) {
    static const int rotations[4] = {2, 10, 18, 24};
    uint32_t acc = b;
    for (int k = 0; k < 4; k++) {
        acc ^= rotl32(b, rotations[k]);
    }
    return acc;
}
/* Linear transform L' used by the key schedule:
 * b XOR its left-rotations by 13 and 23 bits. */
static uint32_t Lp(uint32_t b) {
    const uint32_t r13 = rotl32(b, 13);
    const uint32_t r23 = rotl32(b, 23);
    return b ^ r13 ^ r23;
}
/* Round transform for encryption/decryption: S-box substitution followed by L. */
static uint32_t T(uint32_t x) {
    const uint32_t substituted = tau(x);
    return L(substituted);
}
/* Round transform for the key schedule: S-box substitution followed by L'. */
static uint32_t Tp(uint32_t x) {
    const uint32_t substituted = tau(x);
    return Lp(substituted);
}
/* ====== SM4 key schedule ====== */
/*
 * Expand the 16-byte key into 32 encryption round keys rk[0..31].
 *
 * This deliberately differs from a plain "MK ^ FK" whitening step: each FK
 * word has (index + 1) added to it before the XOR. '+' binds tighter than
 * '^' in C, so the parentheses below only make the existing evaluation
 * order explicit.
 */
static void sm4_setkey_enc(uint32_t rk[32], const uint8_t key[16]) {
    uint32_t K[36];

    for (int j = 0; j < 4; j++) {
        const uint32_t mk = load_be32(key + 4 * j);
        K[j] = mk ^ (SM4_FK[j] + (uint32_t)(j + 1)); /* the one modified step */
    }

    for (int i = 0; i < 32; i++) {
        const uint32_t mixed = K[i + 1] ^ K[i + 2] ^ K[i + 3] ^ SM4_CK[i];
        K[i + 4] = K[i] ^ Tp(mixed);
        rk[i] = K[i + 4];
    }
}
/* Produce the decryption round keys: the encryption round keys in reverse order. */
static void sm4_setkey_dec(uint32_t rk_dec[32], const uint8_t key[16]) {
    uint32_t forward[32];
    sm4_setkey_enc(forward, key);
    for (int lo = 0, hi = 31; lo < 32; lo++, hi--) {
        rk_dec[lo] = forward[hi];
    }
}
/* ====== SM4 single-block encrypt/decrypt (same routine, different rk order) ====== */
/*
 * Process one 16-byte block with the given 32 round keys.
 * The whole input is loaded before anything is written to out.
 */
static void sm4_crypt_block(const uint32_t rk[32], const uint8_t in[16], uint8_t out[16]) {
    /* Sliding window over the last four state words. */
    uint32_t s0 = load_be32(in + 0);
    uint32_t s1 = load_be32(in + 4);
    uint32_t s2 = load_be32(in + 8);
    uint32_t s3 = load_be32(in + 12);

    for (int round = 0; round < 32; round++) {
        const uint32_t next = s0 ^ T(s1 ^ s2 ^ s3 ^ rk[round]);
        s0 = s1;
        s1 = s2;
        s2 = s3;
        s3 = next;
    }

    /* Output is the final four words in reverse order. */
    store_be32(out + 0, s3);
    store_be32(out + 4, s2);
    store_be32(out + 8, s1);
    store_be32(out + 12, s0);
}
/* ====== PKCS#7 padding ====== */
/*
 * Return a newly allocated copy of in[0..inlen) followed by PKCS#7 padding
 * up to the next multiple of SM4_BLOCK_SIZE. A full block of padding is
 * added when inlen is already a multiple of the block size.
 *
 * On success *outlen holds the padded length and the caller owns (and must
 * free) the returned buffer. Returns NULL if the padded length would
 * overflow size_t or if allocation fails.
 */
static uint8_t* pkcs7_pad(const uint8_t* in, size_t inlen, size_t* outlen) {
    /* Always in 1..SM4_BLOCK_SIZE, so no special case for 0 is needed. */
    const size_t pad = SM4_BLOCK_SIZE - (inlen % SM4_BLOCK_SIZE);

    if (inlen > SIZE_MAX - pad) {
        return NULL; /* inlen + pad would wrap around */
    }
    *outlen = inlen + pad;

    uint8_t *out = malloc(*outlen);
    if (!out) {
        return NULL;
    }
    if (inlen > 0) {
        /* Skipped for empty input so that a NULL `in` is never passed to memcpy. */
        memcpy(out, in, inlen);
    }
    memset(out + inlen, (int)pad, pad);
    return out;
}
/*
 * Validate and strip PKCS#7 padding in place.
 * *len must be a non-zero multiple of SM4_BLOCK_SIZE. On success *len is
 * reduced by the padding length and 1 is returned; otherwise 0 is returned
 * and *len is left untouched.
 */
static int pkcs7_unpad(uint8_t* buf, size_t* len) {
    const size_t total = *len;
    if (total == 0 || (total % SM4_BLOCK_SIZE) != 0) {
        return 0;
    }

    const uint8_t pad = buf[total - 1];
    if (pad < 1 || pad > SM4_BLOCK_SIZE) {
        return 0;
    }

    /* Every one of the last `pad` bytes must equal `pad`. */
    for (size_t idx = total - pad; idx < total; idx++) {
        if (buf[idx] != pad) {
            return 0;
        }
    }

    *len = total - pad;
    return 1;
}
/* ====== ECB encrypt ====== */
/*
 * Pad the plaintext with PKCS#7 and encrypt it block by block.
 * Returns a malloc'd ciphertext buffer (caller frees) and stores its length
 * in *clen; returns NULL, leaving *clen untouched, if an allocation fails.
 */
static uint8_t* sm4_ecb_encrypt(const uint8_t* plaintext, size_t plen,
                                const uint8_t key[16], size_t* clen) {
    uint32_t rk[32];
    sm4_setkey_enc(rk, key);

    size_t padded_len = 0;
    uint8_t *padded = pkcs7_pad(plaintext, plen, &padded_len);
    if (padded == NULL) {
        return NULL;
    }

    uint8_t *ct = (uint8_t*)malloc(padded_len);
    if (ct != NULL) {
        size_t off = 0;
        while (off < padded_len) {
            sm4_crypt_block(rk, padded + off, ct + off);
            off += SM4_BLOCK_SIZE;
        }
        *clen = padded_len;
    }

    /* The padded copy is released on both the success and the failure path. */
    free(padded);
    return ct;
}
/* ====== Hex printing ====== */
/* Print len bytes as lowercase two-digit hex, followed by a newline. */
static void print_hex(const uint8_t* data, size_t len) {
    for (size_t remaining = len; remaining > 0; remaining--) {
        printf("%02x", *data++);
    }
    printf("\n");
}
/*
 * Read one line from stdin, encrypt it with the modified SM4 in ECB mode
 * under a fixed key, and print the ciphertext as hex.
 * Returns 0 on success, 1 if input could not be read or encryption failed.
 */
int main(void) {
    /* Replace with any 16-byte key you like. */
    const uint8_t key[16] = {
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
        0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10
    };
    char input[4096];

    printf("Enter a string to encrypt and decrypt: ");
    if (fgets(input, sizeof(input), stdin) == NULL) {
        return 1;
    }

    /* Drop the trailing newline, if any. fgets stops at the first newline,
     * so cutting at the first '\n' is the same as cutting the last byte. */
    size_t inlen = strcspn(input, "\n");
    input[inlen] = '\0';

    /* Encrypt. */
    size_t clen = 0;
    uint8_t *ct = sm4_ecb_encrypt((const uint8_t*)input, inlen, key, &clen);
    if (ct == NULL) {
        printf("Encrypt failed.\n");
        return 1;
    }

    printf("Encrypted (Hex): ");
    print_hex(ct, clen);
    free(ct);
    return 0;
}
r_zip
use std::env;
use std::fs;
use std::io;
// How far back (in bytes) find_match searches for an earlier occurrence.
const WINDOW_SIZE: usize = 256;
// Shortest run that is emitted as a back-reference instead of literal bytes.
const MIN_MATCH: usize = 3;
// Longest run find_match reports; 15 fits the low 4 bits of the second token byte.
const MAX_MATCH: usize = 15;
/// Search the window before `pos` for the longest run that also starts at `pos`.
///
/// Returns `(length, distance)` where `distance = pos - start_of_match`.
/// Runs shorter than `MIN_MATCH` are ignored and runs are capped at
/// `MAX_MATCH`; among equally long candidates the earliest one wins.
/// Returns `(0, 0)` when nothing qualifies.
fn find_match(input: &[u8], pos: usize) -> (usize, usize) {
    // Nothing can match when there are no bytes at or after `pos`.
    let lookahead = match input.get(pos..) {
        Some(rest) if !rest.is_empty() => rest,
        _ => return (0, 0),
    };

    let window_start = pos.saturating_sub(WINDOW_SIZE);
    let mut best = (0usize, 0usize);

    for candidate in window_start..pos {
        let run = input[candidate..]
            .iter()
            .zip(lookahead)
            .take(MAX_MATCH)
            .take_while(|(a, b)| a == b)
            .count();

        // Strictly greater, so the first candidate of a given length is kept.
        if run >= MIN_MATCH && run > best.0 {
            best = (run, pos - candidate);
        }
    }

    best
}
fn compress_bytes(input: &[u8]) -> Vec {
let mut output = Vec::with_capacity(input.len() * 2);
let mut pos = 0usize;
while pos < input.len() { let (match_len, distance) = find_match(input, pos); if match_len >= MIN_MATCH {
let b0 = 0x80u8 | ((distance >> 4) as u8);
let b1 = (((distance & 0x0F) << 4) as u8) | (match_len as u8); output.push(b0); output.push(b1); pos += match_len; } else { output.push(input[pos]); pos += 1; } } output } fn usage() { eprintln!("用法: compress <input> <output>\n示例: compress a.txt a.z"); } fn main() -> io::Result<()> {
let args: Vec<String> = env::args().collect();
if args.len() != 3 {
usage();
std::process::exit(1);
}
let input_file = &args[1];
let output_file = &args[2];
let input_data = fs::read(input_file)?;
let compressed = compress_bytes(&input_data);
println!("原始大小: {} 字节", input_data.len());
println!("压缩后大小: {} 字节", compressed.len());
if !input_data.is_empty() {
let ratio = (1.0 - (compressed.len() as f64) / (input_data.len() as f64)) * 100.0;
println!("压缩率: {:.1}%", ratio);
}
fs::write(output_file, compressed)?;
println!("压缩完成: {}", output_file);
Ok(())
}
r_png
use std::env;
use std::fs;
use std::io;
/// Return true when `s` is exactly four bytes long and every byte is an ASCII digit.
fn is_4digit_key(s: &str) -> bool {
    let bytes = s.as_bytes();
    if bytes.len() != 4 {
        return false;
    }
    bytes.iter().all(u8::is_ascii_digit)
}
/// XOR `data` in place with an RC4-style keystream derived from `key`.
///
/// The key scheduling and byte generation follow the usual RC4 structure,
/// except that 69 is added (wrapping) to every keystream byte before the
/// XOR. Applying the function twice with the same key restores the data.
/// Panics if `key` is empty (the key index is taken modulo `key.len()`).
fn rc4_crypt(data: &mut [u8], key: &[u8]) {
    // Identity permutation 0..=255.
    let mut s = [0u8; 256];
    for idx in 0..s.len() {
        s[idx] = idx as u8;
    }

    // Key scheduling (KSA)
    let mut j = 0u8;
    for i in 0..s.len() {
        let kb = key[i % key.len()];
        j = j.wrapping_add(s[i]).wrapping_add(kb);
        s.swap(i, usize::from(j));
    }

    // Keystream generation (PRGA)
    let (mut i, mut j) = (0u8, 0u8);
    for byte in data.iter_mut() {
        i = i.wrapping_add(1);
        j = j.wrapping_add(s[usize::from(i)]);
        s.swap(usize::from(i), usize::from(j));

        let idx = s[usize::from(i)].wrapping_add(s[usize::from(j)]);
        // 325 mod 256 = 69 (matches the u8 truncation of the C++ version)
        let stream = s[usize::from(idx)].wrapping_add(69);
        *byte ^= stream;
    }
}
/// Entry point: `<prog> <input_file> <4digit_key>`.
///
/// Encrypts the input file with `rc4_crypt` under the given 4-digit key and
/// writes the result next to it as `<input_file>.rc4`. Exits with status 1
/// on bad arguments; read/write failures are reported on stderr and then
/// returned as the error.
fn main() -> io::Result<()> {
    let args: Vec<String> = env::args().collect();
    if args.len() != 3 {
        let prog = args.first().map(String::as_str).unwrap_or("enc");
        eprintln!("用法: {} <input_file> <4digit_key>", prog);
        eprintln!("示例: {} a.png 0123", prog);
        std::process::exit(1);
    }

    let (in_path, key_str) = (&args[1], &args[2]);
    if !is_4digit_key(key_str) {
        eprintln!("[!] key 必须是 4 位数字,比如 0123");
        std::process::exit(1);
    }

    let mut data = match fs::read(in_path) {
        Ok(bytes) => bytes,
        Err(e) => {
            eprintln!("[!] 读取失败: {} ({})", in_path, e);
            return Err(e);
        }
    };

    rc4_crypt(&mut data, key_str.as_bytes());

    let out_path = format!("{}.rc4", in_path);
    if let Err(e) = fs::write(&out_path, &data) {
        eprintln!("[!] 写入失败: {} ({})", out_path, e);
        return Err(e);
    }

    println!("[+] 加密完成: {}", out_path);
    Ok(())
}
im_revenge&im
只需要学习源码后得出逆向模型的思路后解题即可。如果 im 认真做的话,这道题不是难事。
懒得开虚拟机截图了,exp 不是重点,重点在于逆向过程,不再过多赘述基础的东西:看其他选手的 WP 或者源码喂 AI 都可以。(看了一圈全是codex,有点扎心)
(我所有题目都把源码放出来了,随意使用,想怎么用怎么用)
源码如下(只修改了灯阵):
源码
import dataclasses
import enum
import functools
import logging
import operator
import pickle
import string
import zstandard as zstd
from .better_tracr_compiling import compile_rasp_to_model
from tracr.rasp import rasp
logger = logging.getLogger(__name__)
# Example 10x10 Light Up board. Legend, as interpreted by Checker:
#   "_"      empty cell (may hold a bulb, must end up lit)
#   "0".."4" numbered cell: exactly that many bulbs on its four sides
#   any other character (here "#") is a cell that may not hold a bulb and blocks light
INITIAL_BOARD_EXAMPLE = """
#__#_____#
_______#__
_3____0___
__2__#___1
___10#____
____1##___
#___2__2__
___#____#_
__1_______
0_____1__0
""".strip()
# Bulb placement (1 = bulb) with 100 entries, flattened row by row.
# Presumably the solution to INITIAL_BOARD_EXAMPLE - not verified here.
# fmt: off
REFERENCE_ANSWER_EXAMPLE = [
0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
1, 0, 1, 0, 0, 0, 0, 0, 0, 1,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
0, 0, 1, 0, 0, 1, 0, 0, 1, 0,
1, 0, 0, 0, 1, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
]
# fmt: on
# --- New Light Up board (11x11) ---
# This is the board passed to Checker in the __main__ section.
INITIAL_BOARD = """
_#_________
_#_2_#_0_11
___________
_1__#____0_
_____#_#___
_#__2_2__#_
___0_3_____
_#____#__#_
___________
##_2_3_#_#_
_________#_
""".strip()
# --- Solution of the new board, flattened row-major to 121 entries (1 = bulb) ---
# Described by the author as the unique solution; the __main__ section runs
# Checker.check on it and prints the result.
# fmt: off
REFERENCE_ANSWER = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0,
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
]
# fmt: on
# Begin-of-sequence marker handed to the compiler and prepended to model inputs.
_BOS = "BOS"
# Value returned by _message_character_at for positions past the end of the message.
_EOS = "EOS"
# Rejection messages. Only INCORRECT_MESSAGES[0] is referenced by _message_character_at.
INCORRECT_MESSAGES = [
"This flag is like a mismatched sock, it just doesn't belong. Pair it correctly and come back.",
"Oh, so close! And by close, I mean not even in the same timezone. Try again!",
"Nice try, but this flag wouldn't even pass a Turing test. Give it another shot!",
"Error 404: Your flag not found in our universe. Check your coordinates.",
"This flag is so wrong, it's not even wrong. Back to the drawing board!",
"If this were a game of hot and cold, you'd be a popsicle. Try again, warmer this time.",
"Beep boop, flag rejected. My silicon heart remains unmoved by your attempt.",
"You must've used a random string generator. Because that's not it, human.",
"Your flag has been flagged for being egregiously incorrect. Please revise.",
"I appreciate your creativity, but the flag needs to be correct, not imaginative.",
]
# Message emitted when every condition holds.
CORRECT_MESSAGE = (
"Congratulations! The flag is unictf{hashlib.sha256(your_input).hexdigest()}."
)
def _message_character_at(i, ok):
    """Return character ``i`` of the verdict message, or ``_EOS`` past its end.

    The message is ``CORRECT_MESSAGE`` when ``ok`` is truthy and the first
    entry of ``INCORRECT_MESSAGES`` otherwise.
    """
    msg = CORRECT_MESSAGE if ok else INCORRECT_MESSAGES[0]
    return msg[i] if i < len(msg) else _EOS
class Predicate(enum.Enum):
    """Comparison applied between a bulb count and a condition's target value."""

    EQ = 0  # count == value
    LT = 1  # count < value
    GT = 2  # count > value
# Maps each Predicate to the two-argument comparison function that implements it;
# callers invoke it as op(actual_sum, expected_value).
_PREDICTATE_TO_OPERATOR = {
Predicate.EQ: operator.eq,
Predicate.LT: operator.lt,
Predicate.GT: operator.gt,
}
def _apply_pred(id2pv, pvid: rasp.Value, value: rasp.Value) -> bool:
    """Evaluate the (predicate, expected) pair registered under ``pvid`` against ``value``.

    ``id2pv`` maps an integer id to a ``(Predicate, expected_value)`` tuple.
    Both ``pvid`` and ``value`` must be ints.
    """
    assert isinstance(pvid, int) and isinstance(value, int)
    pred, expected = id2pv[pvid]
    return _PREDICTATE_TO_OPERATOR[pred](value, expected)
@dataclasses.dataclass
class Condition:
    """
    Checks that |solution[coords].sum() <predicate> value| holds.
    """

    # Flattened (row-major) cell indices whose bulb counts are summed.
    coords: set[int]
    # Value the sum is compared against.
    value: int
    # Comparison to apply between the sum and `value`.
    predicate: Predicate
class Checker:
    """
    Checks whether the solution is correct for the given initial board of Light Up
    (Akari), a binary-determination logic puzzle published by Nikoli.

    Light Up is played on a rectangular grid of white and black cells. The player places
    light bulbs in white cells such that no two bulbs shine on each other, until the
    entire grid is lit up. A bulb sends rays of light horizontally and vertically,
    illuminating its entire row and column unless its light is blocked by a black cell.
    A black cell may have a number on it from 0 to 4, indicating how many bulbs must be
    placed adjacent to its four sides; for example, a cell with a 4 must have four bulbs
    around it, one on each side, and a cell with a 0 cannot have a bulb next to any of
    its sides. An unnumbered black cell may have any number of light bulbs adjacent to
    it, or none. Bulbs placed diagonally adjacent to a numbered cell do not contribute
    to the bulb count.

    Checker first generates a list of conditions to be satisfied by the solution, and
    then when the solution is provided, it checks whether all the conditions are met.
    """

    _grid: list[str]
    _n: int
    _m: int
    _conditions: list[Condition]

    # The four orthogonal neighbour offsets (di, dj).
    _DXDY = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    def __init__(self, board: str):
        """Parse ``board`` (one text row per line) and precompute its conditions.

        All rows must have the same length.
        """
        self._grid = board.strip().splitlines()
        self._n = len(self._grid)
        self._m = len(self._grid[0])
        assert all(len(row) == self._m for row in self._grid)
        self._conditions = self._build_conditions()

    def _build_conditions(self) -> list[Condition]:
        """Translate the board into the list of sum-conditions a solution must meet."""
        result = []

        # Numbered cells or walls must not have bulbs.
        must_be_zero_cells = {
            self._coord(i, j)
            for i in range(self._n)
            for j in range(self._m)
            if self._grid[i][j] != "_"
        }
        result.append(Condition(must_be_zero_cells, 0, Predicate.EQ))

        # Numbered cells must have the correct number of bulbs around them.
        for i in range(self._n):
            for j in range(self._m):
                if self._grid[i][j] in "01234":
                    result.append(
                        Condition(
                            set(self._adjacent_coord(i, j)),
                            int(self._grid[i][j]),
                            Predicate.EQ,
                        )
                    )

        # Bulbs must not shine on each other.
        # Scan horizontally. The range runs one past the row end so that the
        # final run of empty cells is flushed as well.
        for i in range(self._n):
            cur = []
            for j in range(self._m + 1):
                if j >= self._m or self._grid[i][j] != "_":
                    if cur:
                        result.append(Condition(set(cur), 2, Predicate.LT))
                    cur = []
                else:
                    cur.append(self._coord(i, j))

        # Scan vertically.
        for j in range(self._m):
            cur = []
            for i in range(self._n + 1):
                if i >= self._n or self._grid[i][j] != "_":
                    if cur:
                        result.append(Condition(set(cur), 2, Predicate.LT))
                    cur = []
                else:
                    cur.append(self._coord(i, j))

        # All empty cells must be lit up: at least one bulb among the cell
        # itself and every empty cell reachable in a straight line from it.
        for i in range(self._n):
            for j in range(self._m):
                if self._grid[i][j] != "_":
                    continue
                visible_from_here = set()
                for di, dj in self._DXDY:
                    ii, jj = i, j
                    while self._in_bound(ii, jj) and self._grid[ii][jj] == "_":
                        visible_from_here.add(self._coord(ii, jj))
                        ii += di
                        jj += dj
                result.append(Condition(visible_from_here, 0, Predicate.GT))

        return result

    def _coord(self, i: int, j: int) -> int:
        """Flatten (row, column) into a row-major index."""
        return i * self._m + j

    def _in_bound(self, i: int, j: int) -> bool:
        """Return True when (i, j) lies inside the grid."""
        return 0 <= i < self._n and 0 <= j < self._m

    def _xy(self, coord: int) -> tuple[int, int]:
        """Inverse of ``_coord``: flattened index back to (row, column)."""
        return divmod(coord, self._m)

    def _adjacent(self, i: int, j: int) -> list[tuple[int, int]]:
        """Return the in-bounds orthogonal neighbours of (i, j)."""
        return [
            (i + di, j + dj) for di, dj in self._DXDY if self._in_bound(i + di, j + dj)
        ]

    def _adjacent_coord(self, i: int, j: int) -> list[int]:
        """Return the flattened indices of the in-bounds orthogonal neighbours of (i, j)."""
        return [self._coord(ai, aj) for ai, aj in self._adjacent(i, j)]

    def _debug(self, solution, condition: Condition) -> str:
        """Format the sum and the per-cell values of ``condition`` for a debug log."""
        result = f"Sum: {sum(solution[coord] for coord in condition.coords)}\n"
        for coord in condition.coords:
            x, y = self._xy(coord)
            result += f"{x}, {y} = {solution[coord]} ({self._grid[x][y]})\n"
        return result

    def check(self, solution: list[int]) -> bool:
        """Return True when ``solution`` (flattened row-major) meets every condition.

        A solution of the wrong length is rejected immediately. The first
        failing condition is logged at DEBUG level.
        """
        if len(solution) != self._n * self._m:
            return False
        for condition in self._conditions:
            total = sum(solution[coord] for coord in condition.coords)
            opr = _PREDICTATE_TO_OPERATOR[condition.predicate]
            if not opr(total, condition.value):
                logger.debug(
                    "Condition not met: %s\n%s",
                    condition,
                    self._debug(solution, condition),
                )
                return False
        return True

    def to_tracr_program(self):
        """Build a RASP program that emits the verdict message for an input string.

        Tokens "0"/"1" are decoded to integers; the conditions are processed
        in chunks and combined into a single ok flag that selects the message
        via ``_message_character_at``.
        """
        from . import miprim
        from tracr.compiler.lib import length, make_count

        decoded_input = rasp.Map(
            lambda x: ord(x) - 0x30 if isinstance(x, str) and x in "01" else 0,
            rasp.tokens,
        )
        sn = self._n * self._m
        # NOTE(review): the bound here is 200 while max_seq_len below (and in
        # the __main__ compile call) is 128 - confirm which limit is intended.
        assert sn <= 200
        length_ok = length == self._n * self._m
        input_format_ok = rasp.Map(
            lambda x: isinstance(x, str) and x in "01", rasp.tokens
        )
        all_ok = length_ok & input_format_ok

        # Split the conditions into chunks of fewer than sn entries each.
        parts = (len(self._conditions) + sn - 1) // sn
        chunk_size = min((len(self._conditions) + parts - 1) // parts, sn - 1)

        # Assign a small integer id to every distinct (predicate, value) pair.
        pv2id = {}
        for cond in self._conditions:
            pv2id.setdefault((cond.predicate, cond.value), len(pv2id))
        id2pv = {v: k for k, v in pv2id.items()}

        for i in range(0, len(self._conditions), chunk_size):
            chunk = self._conditions[i : i + chunk_size]
            coords = [cond.coords for cond in chunk]
            pvids = [pv2id[(cond.predicate, cond.value)] for cond in chunk]
            pvids = miprim.make_constant_sequence(pvids, default=0)
            values = miprim.sum_01_sequence(decoded_input, coords, max_seq_len=128)
            cur = rasp.SequenceMap(functools.partial(_apply_pred, id2pv), pvids, values)
            cur = miprim.set_out_of_range_value_to_true(cur, range(len(chunk)))
            all_ok &= cur

        # The verdict is positive only if the flag is True at all sn positions.
        all_ok = make_count(all_ok, True) == sn
        return rasp.SequenceMap(_message_character_at, rasp.indices, all_ok)
def _prettify_output(out):
if out[0] == "BOS":
out = out[1:]
if "EOS" in out:
out = out[: out.index("EOS")]
if isinstance(out[0], str):
out = "".join(out)
return out
if __name__ == "__main__":
    # Build the checker for INITIAL_BOARD, compile it to a transformer model,
    # sanity-check it against the reference answer, and dump the model.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")
    parser.add_argument(
        "-o", "--output", help="Output file", default="challenge.pkl.zst"
    )
    args = parser.parse_args()

    # --debug was parsed but previously had no effect; it now enables DEBUG logging.
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)

    checker = Checker(INITIAL_BOARD)
    print("Interpret:", checker.check(REFERENCE_ANSWER))

    prog = checker.to_tracr_program()
    result = compile_rasp_to_model(
        prog,
        vocab=set(string.printable),
        max_seq_len=128,
        compiler_bos=_BOS,
        mlp_exactness=100000,
    )
    model = result.assembled_model

    encoded_answer = list(map(str, REFERENCE_ANSWER))
    # Flip the last cell (0 in REFERENCE_ANSWER) to get a known-wrong input.
    wrong_answer = list(map(str, REFERENCE_ANSWER))
    wrong_answer[-1] = "1"

    print("RASP presented with correct answer:", _prettify_output(prog(encoded_answer)))
    print("RASP presented with wrong answer:", _prettify_output(prog(wrong_answer)))
    print("Correct:", _prettify_output(model.apply([_BOS] + encoded_answer).decoded))  # type: ignore
    print("Wrong:", _prettify_output(model.apply([_BOS] + wrong_answer).decoded))  # type: ignore

    # Dump
    cctx = zstd.ZstdCompressor()
    with open(args.output, "wb") as fp, cctx.stream_writer(fp) as cfp:
        pickle.dump(
            {
                "config": {
                    "num_heads": model.model_config.num_heads,
                    "num_layers": model.model_config.num_layers,
                    "key_size": model.model_config.key_size,
                    "mlp_hidden_size": model.model_config.mlp_hidden_size,
                    "dropout_rate": model.model_config.dropout_rate,
                    "activation_function": "relu",
                    "layer_norm": model.model_config.layer_norm,
                    "causal": model.model_config.causal,
                },
                "params": model.params,
                "input_encoder": model.input_encoder,
                "output_encoder": model.output_encoder,
                "residual_labels": model.residual_labels,
                "embed_spaces": result.embed_spaces,
            },
            cfp,
        )
发表回复