diff --git a/lnrpc.c b/lnrpc.c index d7e3cb1..6fd5dfa 100644 --- a/lnrpc.c +++ b/lnrpc.c @@ -34,7 +34,7 @@ int main(int argc, const char *argv[]) //int verbose = 1; timeout_str = getenv("LNRPC_TIMEOUT"); - int timeout_ms = timeout_str ? atoi(timeout_str) : 5000; + int timeout_ms = timeout_str ? atoi(timeout_str) : 50000000; timeout.tv_sec = timeout_ms / 1000; timeout.tv_usec = (timeout_ms % 1000) * 1000; @@ -101,6 +101,10 @@ int main(int argc, const char *argv[]) case COMMANDO_REPLY_CONTINUES: printf("%.*s", len - 8, buf + 8); continue; + case WIRE_PING: + if (!lnsocket_pong(ln, buf, len)) { + fprintf(stderr, "pong error...\n"); + } default: // ignore extra interleaved messages which can happen continue; diff --git a/lnsocket.c b/lnsocket.c index 38a9418..a3df26e 100644 --- a/lnsocket.c +++ b/lnsocket.c @@ -374,6 +374,63 @@ int lnsocket_make_network_tlv(unsigned char *buf, int buflen, return 1; } +int lnsocket_make_pong_msg(unsigned char *buf, int buflen, u16 num_pong_bytes) +{ + struct cursor msg; + + make_cursor(buf, buf + buflen, &msg); + + if (!cursor_push_u16(&msg, WIRE_PONG)) + return 0; + + // don't include itself + num_pong_bytes = num_pong_bytes <= 4 ? 0 : num_pong_bytes - 4; + + if (!cursor_push_u16(&msg, num_pong_bytes)) + return 0; + + if (msg.p + num_pong_bytes > msg.end) + return 0; + + memset(msg.p, 0, num_pong_bytes); + msg.p += num_pong_bytes; + + return msg.p - msg.start; +} + +static int lnsocket_decode_ping_payload(const unsigned char *payload, int payload_len, u16 *pong_bytes) +{ + struct cursor msg; + + make_cursor((u8*)payload, (u8*)payload + payload_len, &msg); + + if (!cursor_pull_u16(&msg, pong_bytes)) + return 0; + + return 1; +} + +int lnsocket_make_pong_from_ping(unsigned char *buf, int buflen, const unsigned char *ping, u16 ping_len) +{ + u16 pong_bytes; + + if (!lnsocket_decode_ping_payload(ping, ping_len, &pong_bytes)) + return 0; + + return lnsocket_make_pong_msg(buf, buflen, pong_bytes); +} + +int lnsocket_pong(struct lnsocket *ln, const unsigned char *ping, u16 ping_len) +{ + unsigned char pong[0xFFFF]; + u16 len; + + if (!(len = lnsocket_make_pong_from_ping(pong, sizeof(pong), ping, ping_len))) + return 0; + + return lnsocket_write(ln, pong, len); +} + int lnsocket_make_ping_msg(unsigned char *buf, int buflen, u16 num_pong_bytes, u16 ignored_bytes) { struct cursor msg; diff --git a/lnsocket.h b/lnsocket.h index 5099806..1509143 100644 --- a/lnsocket.h +++ b/lnsocket.h @@ -70,7 +70,6 @@ struct lnsocket EXPORT *lnsocket_create(); /* messages */ int lnsocket_make_network_tlv(unsigned char *buf, int buflen, const unsigned char **blockids, int num_blockids, struct tlv *tlv_out); -int EXPORT lnsocket_make_ping_msg(unsigned char *buf, int buflen, unsigned short num_pong_bytes, unsigned short ignored_bytes); int lnsocket_make_init_msg(unsigned char *buf, int buflen, const unsigned char *globalfeatures, unsigned short gflen, const unsigned char *features, unsigned short flen, const struct tlv **tlvs, unsigned short num_tlvs, unsigned short *outlen); int lnsocket_perform_init(struct lnsocket *ln); @@ -84,7 +83,12 @@ int lnsocket_read(struct lnsocket *, unsigned char **buf, unsigned short *len); int lnsocket_send(struct lnsocket *, unsigned short msg_type, const unsigned char *payload, unsigned short payload_len); int lnsocket_recv(struct lnsocket *, unsigned short *msg_type, unsigned char **payload, unsigned short *payload_len); +int lnsocket_pong(struct lnsocket *, const unsigned char *ping, unsigned short ping_len); + +int EXPORT lnsocket_make_pong_from_ping(unsigned char *buf, int buflen, const unsigned char *ping, unsigned short ping_len); +int EXPORT lnsocket_make_ping_msg(unsigned char *buf, int buflen, unsigned short num_pong_bytes, unsigned short ignored_bytes); +int EXPORT lnsocket_make_pong_msg(unsigned char *buf, int buflen, unsigned short num_pong_bytes); void* EXPORT lnsocket_secp(struct lnsocket *); void EXPORT lnsocket_genkey(struct lnsocket *); int EXPORT lnsocket_setkey(struct lnsocket *, const unsigned char key[32]); diff --git a/test.c b/test.c index 431e5c5..631bdae 100644 --- a/test.c +++ b/test.c @@ -20,6 +20,7 @@ int main(int argc, const char *argv[]) { static u8 msgbuf[4096]; u8 *buf; + u16 msgtype; struct lnsocket *ln; static unsigned char key[32] = {0}; @@ -58,11 +59,17 @@ int main(int argc, const char *argv[]) printf("sent ping "); print_data(msgbuf, len); - if (!(ok = lnsocket_read(ln, &buf, &len))) - goto done; + for (int packets = 0; packets < 3; packets++) { + if (!(ok = lnsocket_recv(ln, &msgtype, &buf, &len))) + goto done; + + if (msgtype == WIRE_PONG) { + printf("got pong! "); + print_data(buf, len); + break; + } + } - printf("got "); - print_data(buf, len); done: lnsocket_print_errors(ln); lnsocket_destroy(ln);